1use core::marker::PhantomData;
2
3use aes_gcm::{
4 aead::{Aead, Payload},
5 Aes256Gcm, KeyInit,
6};
7use digest::{
8 block_buffer::Eager,
9 consts::U256,
10 core_api::{BufferKindUser, CoreProxy, FixedOutputCore, UpdateCore},
11 generic_array::GenericArray,
12 typenum::{IsLess, Le, NonZero},
13 HashMarker, OutputSizeUser,
14};
15use hkdf::Hkdf;
16use thiserror::Error;
17
18use crate::{encoding::MarshallingError, group::HashFactory, share::vss::suite::Suite, Point};
19
/// Length in bytes of the AES-256-GCM nonce accepted by `Dh::aes_encrypt`,
/// `Dh::aes_decrypt`, `Dh::encrypt`, `Dh::decrypt` and the `AEAD` methods.
///
/// Written as a bit count (96 bits = 12 bytes), the standard GCM nonce length.
pub(crate) const NONCE_SIZE: usize = 96 / 8;
21
22pub trait HmacCompatible: OutputSizeUser + CoreProxy<Core = Self::C> {
23 type C: HmacCompatibleCore;
24}
25
26impl<T: CoreProxy + OutputSizeUser> HmacCompatible for T
27where
28 <T as CoreProxy>::Core: HmacCompatibleCore,
29{
30 type C = T::Core;
31}
32
33pub trait HmacCompatibleCore:
34 FixedOutputCore<BlockSize = Self::B>
35 + HashMarker
36 + UpdateCore
37 + BufferKindUser<BufferKind = Eager>
38 + Default
39 + Clone
40{
41 type B: HmacBlockSize;
42}
43
44impl<
45 T: HashMarker
46 + UpdateCore
47 + BufferKindUser<BufferKind = Eager>
48 + Default
49 + Clone
50 + FixedOutputCore,
51 > HmacCompatibleCore for T
52where
53 Self::BlockSize: IsLess<U256>,
54 Le<Self::BlockSize, U256>: NonZero,
55{
56 type B = Self::BlockSize;
57}
58
59pub trait HmacBlockSize: IsLess<U256, Output = Self::O> {
60 type O: NonZero;
61}
62
63impl<T: IsLess<U256>> HmacBlockSize for T
64where
65 Self::Output: NonZero,
66{
67 type O = Self::Output;
68}
69
70pub trait Dh {
71 type H: HmacCompatible;
72
73 fn dh_exchange<SUITE: Suite>(
75 suite: SUITE,
76 own_private: <SUITE::POINT as Point>::SCALAR,
77 remote_public: SUITE::POINT,
78 ) -> SUITE::POINT {
79 suite.point().mul(&own_private, Some(&remote_public))
80 }
81
82 fn hkdf(ikm: &[u8], info: &[u8], output_size: Option<usize>) -> Result<Vec<u8>, DhError> {
83 let size = output_size.unwrap_or(32);
84 let h = Hkdf::<Self::H>::new(None, ikm);
85 let mut out = vec![0; size];
86 h.expand(info, &mut out)
87 .map_err(|e| DhError::HkdfFailure(e.to_string()))?;
88
89 Ok(out)
90 }
91
92 fn aes_encrypt(
93 key: &[u8],
94 nonce: &[u8; NONCE_SIZE],
95 data: &[u8],
96 additional_data: Option<&[u8]>,
97 ) -> Result<Vec<u8>, DhError> {
98 let key_len = key.len();
99 if key_len != 32 {
100 return Err(DhError::WrongKeyLength(format!(
101 "expected 32, got {key_len}"
102 )));
103 }
104 let key = GenericArray::from_slice(key);
105 let aes_gcm = Aes256Gcm::new(key);
106 let nonce = GenericArray::from_slice(nonce);
107
108 let payload: Payload = match additional_data {
109 None => Payload::from(data),
110 Some(add_data) => Payload {
111 aad: add_data,
112 msg: data,
113 },
114 };
115
116 let ciphertext = aes_gcm
117 .encrypt(nonce, payload)
118 .map_err(|e| DhError::DecryptionFailed(e.to_string()))?;
119
120 Ok(ciphertext)
121 }
122
123 fn aes_decrypt(
124 key: &[u8],
125 nonce: &[u8; NONCE_SIZE],
126 ciphertext: &[u8],
127 additional_data: Option<&[u8]>,
128 ) -> Result<Vec<u8>, DhError> {
129 let key_len = key.len();
130 if key_len != 32 {
131 return Err(DhError::WrongKeyLength(format!(
132 "expected 32, got {key_len}"
133 )));
134 }
135 let key = GenericArray::from_slice(key);
136 let aes_gcm = Aes256Gcm::new(key);
137 let nonce = GenericArray::from_slice(nonce);
138
139 let payload: Payload = match additional_data {
140 None => Payload::from(ciphertext),
141 Some(add_data) => Payload {
142 aad: add_data,
143 msg: ciphertext,
144 },
145 };
146
147 let decrypted = aes_gcm
148 .decrypt(nonce, payload)
149 .map_err(|e| DhError::DecryptionFailed(e.to_string()))?;
150
151 Ok(decrypted)
152 }
153
154 fn encrypt<POINT: Point>(
155 pre_key: &POINT,
156 info: &[u8],
157 nonce: &[u8; NONCE_SIZE],
158 data: &[u8],
159 ) -> Result<Vec<u8>, DhError> {
160 let pre_buff = pre_key.marshal_binary()?;
161 let key = Self::hkdf(&pre_buff, info, None)?;
162 let encrypted = Self::aes_encrypt(&key, nonce, data, Some(info))?;
163
164 Ok(encrypted)
165 }
166
167 fn decrypt<POINT: Point>(
168 pre_key: &POINT,
169 info: &[u8],
170 nonce: &[u8; NONCE_SIZE],
171 cipher: &[u8],
172 ) -> Result<Vec<u8>, DhError> {
173 let pre_buff = pre_key.marshal_binary()?;
174 let key = Self::hkdf(&pre_buff, info, None)?;
175 let decrypted = Self::aes_decrypt(&key, nonce, cipher, Some(info))?;
176
177 Ok(decrypted)
178 }
179}
180
181impl<T: HashFactory> Dh for T {
182 type H = T::T;
183}
184
185pub struct AEAD<T: Dh> {
186 key: Vec<u8>,
187 phantom: PhantomData<T>,
188}
189
190impl<DH: Dh> AEAD<DH> {
191 pub fn new<POINT: Point>(pre: POINT, hkfd_context: &[u8]) -> Result<Self, DhError> {
192 let pre_buff = pre.marshal_binary()?;
193 let key = DH::hkdf(&pre_buff, hkfd_context, None)?;
194 let key_len = key.len();
195 if key_len != 32 {
196 return Err(DhError::WrongKeyLength(format!(
197 "expected 32, got {key_len}"
198 )));
199 }
200 Ok(AEAD {
201 key,
202 phantom: PhantomData,
203 })
204 }
205
206 pub fn seal(
214 &self,
215 dst: Option<&mut [u8]>,
216 nonce: &[u8; NONCE_SIZE],
217 plaintext: &[u8],
218 additional_data: Option<&[u8]>,
219 ) -> Result<Vec<u8>, DhError> {
220 let encrypted = DH::aes_encrypt(&self.key, nonce, plaintext, additional_data)?;
221 if let Some(d) = dst {
222 d.copy_from_slice(&encrypted);
223 }
224 Ok(encrypted)
225 }
226
227 pub fn open(
239 &self,
240 dst: Option<&mut [u8]>,
241 nonce: &[u8; NONCE_SIZE],
242 ciphertext: &[u8],
243 additional_data: Option<&[u8]>,
244 ) -> Result<Vec<u8>, DhError> {
245 let decrypted = DH::aes_decrypt(&self.key, nonce, ciphertext, additional_data)?;
246 if let Some(d) = dst {
247 d.copy_from_slice(&decrypted);
248 }
249 Ok(decrypted)
250 }
251
252 pub const fn nonce_size() -> usize {
253 NONCE_SIZE
254 }
255}
256
257#[derive(Debug, Error)]
258pub enum DhError {
259 #[error("marshalling error")]
260 MarshalingError(#[from] MarshallingError),
261 #[error("wrong key length")]
262 WrongKeyLength(String),
263 #[error("aes decryption failed")]
264 DecryptionFailed(String),
265 #[error("aes encryption failed")]
266 EncryptionFailed(String),
267 #[error("unexpected error in hkdf_sha256")]
268 HkdfFailure(String),
269}