1use chacha20poly1305::{
32 aead::{Aead, Error as AError, NewAead},
33 ChaCha20Poly1305, Key,
34};
35pub use chacha20poly1305::aead::Error as EciesError;
37use hkdf::Hkdf;
38use rand_core::RngCore;
39use serde::{Deserialize, Serialize};
40use sha2::Sha256;
41
42use crate::group::{Curve, Element};
43
/// Size in bytes of the nonce stored in an `EciesCipher` and handed to the AEAD.
const NONCE_LEN: usize = 12;

/// Size in bytes of the symmetric key produced by `derive`.
const KEY_LEN: usize = 32;

/// Fixed bytes passed as the HKDF `info` argument when expanding the key.
const DOMAIN: [u8; 4] = [1u8, 9, 6, 9];
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct EciesCipher<C: Curve> {
57 aead: Vec<u8>,
59 ephemeral: C::Point,
62 nonce: [u8; NONCE_LEN],
64}
65
66pub fn encrypt<C: Curve, R: RngCore>(to: &C::Point, msg: &[u8], rng: &mut R) -> EciesCipher<C> {
68 let eph_secret = C::Scalar::rand(rng);
69
70 let mut ephemeral = C::Point::one();
71 ephemeral.mul(&eph_secret);
72
73 let mut dh = to.clone();
75 dh.mul(&eph_secret);
76
77 let ephemeral_key = derive::<C>(&dh);
79
80 let aead = ChaCha20Poly1305::new(Key::from_slice(ephemeral_key.as_slice()));
82
83 let mut nonce: [u8; NONCE_LEN] = [0u8; NONCE_LEN];
85 rng.fill_bytes(&mut nonce);
86
87 let aead = aead
89 .encrypt(&nonce.into(), msg)
90 .expect("aead should not fail");
91
92 EciesCipher {
93 aead,
94 nonce,
95 ephemeral,
96 }
97}
98
99pub fn decrypt<C: Curve>(private: &C::Scalar, cipher: &EciesCipher<C>) -> Result<Vec<u8>, AError> {
101 let mut dh = cipher.ephemeral.clone();
103 dh.mul(private);
104
105 let ephemeral_key = derive::<C>(&dh);
106
107 let aead = ChaCha20Poly1305::new(Key::from_slice(ephemeral_key.as_slice()));
108
109 aead.decrypt(&cipher.nonce.into(), &cipher.aead[..])
110}
111
112fn derive<C: Curve>(dh: &C::Point) -> [u8; KEY_LEN] {
114 let serialized = bincode::serialize(dh).expect("could not serialize element");
115
116 let h = Hkdf::<Sha256>::new(None, &serialized);
118 let mut ephemeral_key = [0u8; KEY_LEN];
119 h.expand(&DOMAIN, &mut ephemeral_key)
120 .expect("hkdf should not fail");
121
122 debug_assert!(ephemeral_key.len() == KEY_LEN);
123
124 ephemeral_key
125}
126
#[cfg(feature = "bls12_381")]
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::curve::bls12381::{G1Curve as Curve, Scalar, G1};

    use super::*;

    /// Builds a random key pair: a scalar and `G1::one()` multiplied by it.
    fn kp() -> (Scalar, G1) {
        let sk = Scalar::rand(&mut thread_rng());
        let mut pk = G1::one();
        pk.mul(&sk);
        (sk, pk)
    }

    /// Round-trips a message, then checks that a wrong key and a replaced
    /// payload are both rejected.
    #[test]
    fn test_decryption() {
        let (unrelated_secret, _) = kp();
        let (secret, public) = kp();
        let plaintext = vec![1, 2, 3, 4];

        let mut cipher = encrypt::<Curve, _>(&public, &plaintext, &mut thread_rng());

        // The intended recipient recovers the original bytes.
        let recovered = decrypt::<Curve>(&secret, &cipher).unwrap();
        assert_eq!(plaintext, recovered);

        // A different private scalar must not be able to open the cipher.
        assert!(decrypt::<Curve>(&unrelated_secret, &cipher).is_err());

        // Overwriting the AEAD payload must make decryption fail even with the right key.
        cipher.aead = vec![0; 32];
        assert!(decrypt::<Curve>(&secret, &cipher).is_err());
    }
}
161}