1use crate::arithmetic::group_elements::{GroupElement, G};
4use crate::arithmetic::scalars::ScalarNonZero;
5use base64::engine::general_purpose;
6use base64::Engine;
7use rand_core::{CryptoRng, RngCore};
8#[cfg(feature = "serde")]
9use serde::de::{Error, Visitor};
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12#[cfg(feature = "serde")]
13use std::fmt::Formatter;
14
/// Length in bytes of the encoding produced by [`ElGamal::to_bytes`].
///
/// Each component is a 32-byte group-element encoding: two components (`gb`, `gc`) by
/// default, three (`gb`, `gc`, `gy`) when the `elgamal3` feature is enabled.
pub const ELGAMAL_LENGTH: usize = if cfg!(feature = "elgamal3") { 96 } else { 64 };
21
22#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
24pub struct ElGamal {
25 pub gb: GroupElement,
26 pub gc: GroupElement,
27 #[cfg(feature = "elgamal3")]
28 pub gy: GroupElement,
29}
30
31impl ElGamal {
32 pub fn from_bytes(v: &[u8; ELGAMAL_LENGTH]) -> Option<Self> {
34 Some(Self {
35 gb: GroupElement::from_slice(&v[0..32])?,
36 gc: GroupElement::from_slice(&v[32..64])?,
37 #[cfg(feature = "elgamal3")]
38 gy: GroupElement::from_slice(&v[64..96])?,
39 })
40 }
41
42 pub fn from_slice(v: &[u8]) -> Option<Self> {
44 if v.len() != ELGAMAL_LENGTH {
45 None
46 } else {
47 let mut arr = [0u8; ELGAMAL_LENGTH];
48 arr.copy_from_slice(v);
49 Self::from_bytes(&arr)
50 }
51 }
52
53 pub fn to_bytes(&self) -> [u8; ELGAMAL_LENGTH] {
55 let mut retval = [0u8; ELGAMAL_LENGTH];
56 retval[0..32].clone_from_slice(self.gb.to_bytes().as_ref());
57 retval[32..64].clone_from_slice(self.gc.to_bytes().as_ref());
58 #[cfg(feature = "elgamal3")]
59 retval[64..96].clone_from_slice(self.gy.to_bytes().as_ref());
60 retval
61 }
62
63 pub fn into_bytes(self) -> [u8; ELGAMAL_LENGTH] {
66 self.to_bytes()
67 }
68
69 pub fn to_base64(&self) -> String {
71 general_purpose::URL_SAFE.encode(self.to_bytes())
72 }
73
74 pub fn from_base64(s: &str) -> Option<Self> {
76 general_purpose::URL_SAFE
77 .decode(s)
78 .ok()
79 .and_then(|v| Self::from_slice(&v))
80 }
81}
82
#[cfg(feature = "serde")]
impl Serialize for ElGamal {
    /// Serializes the ciphertext as a string holding [`ElGamal::to_base64`].
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let encoded = self.to_base64();
        serializer.serialize_str(&encoded)
    }
}
92
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ElGamal {
    /// Deserializes a ciphertext from a string in the [`ElGamal::from_base64`] format.
    ///
    /// # Errors
    /// Returns a custom deserializer error if the string is rejected by
    /// [`ElGamal::from_base64`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ElGamalVisitor;

        impl Visitor<'_> for ElGamalVisitor {
            type Value = ElGamal;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                formatter.write_str("a base64 encoded string representing an ElGamal ciphertext")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: Error,
            {
                // Build the error lazily: with `ok_or`, the `format!` call (and its
                // allocation) would run even when decoding succeeds.
                ElGamal::from_base64(v)
                    .ok_or_else(|| E::custom(format!("invalid base64 encoded string: {v}")))
            }
        }

        deserializer.deserialize_str(ElGamalVisitor)
    }
}
118
119pub fn encrypt<R: RngCore + CryptoRng>(
125 gm: &GroupElement,
126 gy: &GroupElement,
127 rng: &mut R,
128) -> ElGamal {
129 assert_ne!(gy, &GroupElement::identity()); let r = ScalarNonZero::random(rng); ElGamal {
132 gb: r * G,
133 gc: gm + r * gy,
134 #[cfg(feature = "elgamal3")]
135 gy: *gy,
136 }
137}
138
#[cfg(feature = "elgamal3")]
/// Decrypts `encrypted` with the secret key `y`, returning `gc - y * gb`.
///
/// Returns `None` when `y * G` differs from the public key stored in the ciphertext,
/// i.e. when `y` is not the secret key matching `encrypted.gy`.
pub fn decrypt(encrypted: &ElGamal, y: &ScalarNonZero) -> Option<GroupElement> {
    let key_matches = y * G == encrypted.gy;
    key_matches.then(|| encrypted.gc - y * encrypted.gb)
}
148
149#[cfg(not(feature = "elgamal3"))]
151pub fn decrypt(encrypted: &ElGamal, y: &ScalarNonZero) -> GroupElement {
152 encrypted.gc - y * encrypted.gb
153}
154
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Encrypting under a public key and decrypting with its secret key returns the message.
    #[test]
    fn encrypt_decrypt_roundtrip() {
        let mut rng = rand::rng();
        let secret_key = ScalarNonZero::random(&mut rng);
        let public_key = secret_key * G;
        let message = GroupElement::random(&mut rng);

        let encrypted = encrypt(&message, &public_key, &mut rng);
        #[cfg(feature = "elgamal3")]
        let decrypted = decrypt(&encrypted, &secret_key).expect("decryption should succeed");
        #[cfg(not(feature = "elgamal3"))]
        let decrypted = decrypt(&encrypted, &secret_key);

        assert_eq!(message, decrypted);
    }

    /// With `elgamal3`, a key that does not match the stored public key is rejected.
    #[cfg(feature = "elgamal3")]
    #[test]
    fn decrypt_with_wrong_key_is_rejected() {
        let mut rng = rand::rng();
        let secret_key = ScalarNonZero::random(&mut rng);
        let wrong_key = ScalarNonZero::random(&mut rng);
        let public_key = secret_key * G;
        let message = GroupElement::random(&mut rng);

        let encrypted = encrypt(&message, &public_key, &mut rng);
        // Two independently sampled keys collide only with negligible probability.
        assert!(decrypt(&encrypted, &wrong_key).is_none());
    }

    /// The raw byte encoding round-trips through every byte-based constructor.
    #[test]
    fn bytes_roundtrip() {
        let mut rng = rand::rng();
        let message = GroupElement::random(&mut rng);
        let public_key = GroupElement::random(&mut rng);
        let encrypted = encrypt(&message, &public_key, &mut rng);

        let bytes = encrypted.to_bytes();
        assert_eq!(ElGamal::from_bytes(&bytes), Some(encrypted));
        assert_eq!(ElGamal::from_slice(&bytes), Some(encrypted));
        assert_eq!(encrypted.into_bytes(), bytes);
    }

    /// Slices that are not exactly ELGAMAL_LENGTH bytes long are rejected.
    #[test]
    fn from_slice_rejects_wrong_length() {
        assert!(ElGamal::from_slice(&[]).is_none());
        assert!(ElGamal::from_slice(&[0u8; ELGAMAL_LENGTH - 1]).is_none());
        assert!(ElGamal::from_slice(&[0u8; ELGAMAL_LENGTH + 1]).is_none());
    }

    #[test]
    fn base64_roundtrip() {
        let mut rng = rand::rng();
        let message = GroupElement::random(&mut rng);
        let public_key = GroupElement::random(&mut rng);
        let encrypted = encrypt(&message, &public_key, &mut rng);

        let encoded = encrypted.to_base64();
        let decoded = ElGamal::from_base64(&encoded).expect("base64 decoding should succeed");

        assert_eq!(encrypted, decoded);
    }

    /// Malformed base64, and well-formed base64 of the wrong size, are rejected.
    #[test]
    fn from_base64_rejects_invalid_input() {
        assert!(ElGamal::from_base64("").is_none());
        assert!(ElGamal::from_base64("not valid base64!").is_none());
        // Valid base64, but it decodes to 3 bytes instead of ELGAMAL_LENGTH.
        assert!(ElGamal::from_base64("AAAA").is_none());
    }

    #[test]
    fn known_base64_decoding() {
        #[cfg(feature = "elgamal3")]
        let base64 = "NESP1FCKkF7nWbqM9cvuUEUPgHaF8qnLeW9RLe_5FCMs-daoTGSyJKa5HRKxk0jFMHVuZ77pJMacNLmtRnlkZEpkKEPWnLzh_s8ievM3gTqeBYm20E23K6hExSxMOw8D";
        #[cfg(not(feature = "elgamal3"))]
        let base64 =
            "xGOnBZzbSrvKUQYBtww0vi8jZWzN9qkrm5OnI2pnEFJu4DkZP2jLLGT-yWa_qnkC_ScCwQwcQtZk_z_z7s_gVQ==";

        let decoded = ElGamal::from_base64(base64).expect("decoding should succeed");
        let re_encoded = decoded.to_base64();

        assert_eq!(base64, re_encoded);
    }
}