1use core::fmt::{Debug, Formatter};
17
18use curve25519_dalek::constants::RISTRETTO_BASEPOINT_TABLE;
19use curve25519_dalek::ristretto::RistrettoPoint;
20use curve25519_dalek::scalar::Scalar;
21use rand_core::{CryptoRng, RngCore};
22
23#[cfg(feature = "enable-serde")]
24use serde::{Deserialize, Deserializer, Serialize, de::Visitor};
25
26use crate::{Ciphertext, EncryptionKey};
27
/// Secret key for ElGamal decryption over the Ristretto group.
///
/// Holds the secret scalar together with its matching [`EncryptionKey`]
/// (`secret * B`, with `B` the Ristretto basepoint). The encryption key is
/// computed once, when the key is constructed, so it can later be handed out
/// by reference without recomputation.
///
/// With the `enable-serde` feature only `secret` is serialized; `ek` is
/// skipped and has to be recomputed from the secret when deserializing.
///
/// NOTE(review): deriving `Copy` lets secret key material be duplicated
/// implicitly on every move; confirm that is intended.
#[derive(Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "enable-serde", derive(Serialize))]
pub struct DecryptionKey {
    /// The secret scalar used to strip the mask from ciphertexts.
    pub(crate) secret: Scalar,
    /// Cached encryption (public) key derived from `secret`; never serialized.
    #[cfg_attr(feature = "enable-serde", serde(skip_serializing))]
    pub(crate) ek: EncryptionKey,
}
36
37impl DecryptionKey {
38 pub fn new<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
52 let secret = Scalar::random(rng);
53 let ek = EncryptionKey(&secret * &RISTRETTO_BASEPOINT_TABLE);
54 Self { secret, ek }
55 }
56
57 pub fn decrypt(&self, ct: Ciphertext) -> RistrettoPoint {
75 ct.1 - ct.0 * &self.secret
76 }
77
78 pub fn encryption_key(&self) -> &EncryptionKey {
80 &self.ek
81 }
82}
83
84impl Debug for DecryptionKey {
85 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
86 write!(f, "DecryptionKey({:?})", self.secret)
87 }
88}
89
90impl From<Scalar> for DecryptionKey {
93 fn from(secret: Scalar) -> Self {
94 let ek = EncryptionKey(&secret * &RISTRETTO_BASEPOINT_TABLE);
95 Self { secret, ek }
96 }
97}
98
99impl AsRef<Scalar> for DecryptionKey {
100 fn as_ref(&self) -> &Scalar {
101 &self.secret
102 }
103}
104
#[cfg(feature = "enable-serde")]
impl<'de> Deserialize<'de> for DecryptionKey {
    /// Deserializes a decryption key in the shape produced by the derived
    /// `Serialize` impl: a struct named `DecryptionKey` with a single `secret`
    /// field (the `ek` field is skipped when serializing).
    ///
    /// Two encodings are accepted:
    /// - sequence-encoded structs, as used by bincode, where the output is just
    ///   the serialized scalar;
    /// - map-encoded structs, as emitted by self-describing formats, for
    ///   example `{"secret": ...}`.
    ///
    /// The encryption key is never read from the input; it is recomputed from
    /// the secret.
    ///
    /// # Errors
    ///
    /// Returns `D::Error` if any of the following holds:
    /// - the `secret` field is missing or repeated;
    /// - an unknown field is present;
    /// - the scalar itself fails to decode.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        const NAME: &str = "DecryptionKey";
        const FIELDS: &[&str] = &["secret"];

        // Identifier for the only serialized field, `secret`.
        struct SecretField;

        impl<'fde> Deserialize<'fde> for SecretField {
            fn deserialize<FD>(deserializer: FD) -> Result<Self, FD::Error>
            where
                FD: Deserializer<'fde>,
            {
                struct SecretFieldVisitor;

                impl<'fv> Visitor<'fv> for SecretFieldVisitor {
                    type Value = SecretField;

                    fn expecting(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
                        formatter.write_str("field identifier `secret`")
                    }

                    // Compact formats may identify struct fields by index.
                    fn visit_u64<E>(self, index: u64) -> Result<SecretField, E>
                    where
                        E: serde::de::Error,
                    {
                        match index {
                            0 => Ok(SecretField),
                            _ => Err(E::invalid_value(
                                serde::de::Unexpected::Unsigned(index),
                                &"field index 0",
                            )),
                        }
                    }

                    fn visit_str<E>(self, name: &str) -> Result<SecretField, E>
                    where
                        E: serde::de::Error,
                    {
                        match name {
                            "secret" => Ok(SecretField),
                            _ => Err(E::unknown_field(name, FIELDS)),
                        }
                    }

                    fn visit_bytes<E>(self, name: &[u8]) -> Result<SecretField, E>
                    where
                        E: serde::de::Error,
                    {
                        match name {
                            b"secret" => Ok(SecretField),
                            _ => Err(E::invalid_value(serde::de::Unexpected::Bytes(name), &self)),
                        }
                    }
                }

                deserializer.deserialize_identifier(SecretFieldVisitor)
            }
        }

        struct DecryptionKeyVisitor;

        impl<'v> Visitor<'v> for DecryptionKeyVisitor {
            type Value = DecryptionKey;

            fn expecting(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("a valid ElGamal decryption key")
            }

            // Sequence form (e.g. bincode): exactly one element, the secret scalar.
            fn visit_seq<A>(self, mut seq: A) -> Result<DecryptionKey, A::Error>
            where
                A: serde::de::SeqAccess<'v>,
            {
                let secret: Scalar = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
                Ok(DecryptionKey::from(secret))
            }

            // Map form (self-describing formats): a single `secret` entry.
            fn visit_map<A>(self, mut map: A) -> Result<DecryptionKey, A::Error>
            where
                A: serde::de::MapAccess<'v>,
            {
                let mut secret: Option<Scalar> = None;
                while let Some(SecretField) = map.next_key::<SecretField>()? {
                    if secret.is_some() {
                        return Err(serde::de::Error::duplicate_field("secret"));
                    }
                    secret = Some(map.next_value()?);
                }
                let secret = secret.ok_or_else(|| serde::de::Error::missing_field("secret"))?;
                Ok(DecryptionKey::from(secret))
            }
        }

        // Mirror the derived `Serialize`, which emits a one-field struct; the
        // previous `deserialize_tuple(32, ..)` only matched bincode by accident.
        deserializer.deserialize_struct(NAME, FIELDS, DecryptionKeyVisitor)
    }
}
135
#[cfg(feature = "enable-serde")]
#[cfg(test)]
mod tests {
    use rand::prelude::StdRng;
    use rand_core::SeedableRng;

    use crate::DecryptionKey;

    /// Round-trips randomly generated keys through bincode. For each key it
    /// checks that the encoding is exactly 32 bytes and that it decodes back
    /// to an equal key.
    #[test]
    fn serde_decryption_key() {
        const N: usize = 100;

        let mut rng = StdRng::from_entropy();

        (0..N)
            .map(|_| DecryptionKey::new(&mut rng))
            .for_each(|original| {
                let bytes = bincode::serialize(&original).unwrap();
                assert_eq!(bytes.len(), 32);

                let roundtripped: DecryptionKey = bincode::deserialize(&bytes).unwrap();
                assert_eq!(original, roundtripped);
            });
    }
}