kyber_rs/dh/
dh_impl.rs

1use core::marker::PhantomData;
2
3use aes_gcm::{
4    aead::{Aead, Payload},
5    Aes256Gcm, KeyInit,
6};
7use digest::{
8    block_buffer::Eager,
9    consts::U256,
10    core_api::{BufferKindUser, CoreProxy, FixedOutputCore, UpdateCore},
11    generic_array::GenericArray,
12    typenum::{IsLess, Le, NonZero},
13    HashMarker, OutputSizeUser,
14};
15use hkdf::Hkdf;
16use thiserror::Error;
17
18use crate::{encoding::MarshallingError, group::HashFactory, share::vss::suite::Suite, Point};
19
/// Length in bytes of the nonce passed to AES-256-GCM throughout this module
/// (a 96-bit nonce, the size `Aes256Gcm` expects).
pub(crate) const NONCE_SIZE: usize = 96 / 8;
21
22pub trait HmacCompatible: OutputSizeUser + CoreProxy<Core = Self::C> {
23    type C: HmacCompatibleCore;
24}
25
26impl<T: CoreProxy + OutputSizeUser> HmacCompatible for T
27where
28    <T as CoreProxy>::Core: HmacCompatibleCore,
29{
30    type C = T::Core;
31}
32
33pub trait HmacCompatibleCore:
34    FixedOutputCore<BlockSize = Self::B>
35    + HashMarker
36    + UpdateCore
37    + BufferKindUser<BufferKind = Eager>
38    + Default
39    + Clone
40{
41    type B: HmacBlockSize;
42}
43
44impl<
45        T: HashMarker
46            + UpdateCore
47            + BufferKindUser<BufferKind = Eager>
48            + Default
49            + Clone
50            + FixedOutputCore,
51    > HmacCompatibleCore for T
52where
53    Self::BlockSize: IsLess<U256>,
54    Le<Self::BlockSize, U256>: NonZero,
55{
56    type B = Self::BlockSize;
57}
58
59pub trait HmacBlockSize: IsLess<U256, Output = Self::O> {
60    type O: NonZero;
61}
62
63impl<T: IsLess<U256>> HmacBlockSize for T
64where
65    Self::Output: NonZero,
66{
67    type O = Self::Output;
68}
69
70pub trait Dh {
71    type H: HmacCompatible;
72
73    /// [`dh_exchange()`] computes the shared key from a private key and a public key
74    fn dh_exchange<SUITE: Suite>(
75        suite: SUITE,
76        own_private: <SUITE::POINT as Point>::SCALAR,
77        remote_public: SUITE::POINT,
78    ) -> SUITE::POINT {
79        suite.point().mul(&own_private, Some(&remote_public))
80    }
81
82    fn hkdf(ikm: &[u8], info: &[u8], output_size: Option<usize>) -> Result<Vec<u8>, DhError> {
83        let size = output_size.unwrap_or(32);
84        let h = Hkdf::<Self::H>::new(None, ikm);
85        let mut out = vec![0; size];
86        h.expand(info, &mut out)
87            .map_err(|e| DhError::HkdfFailure(e.to_string()))?;
88
89        Ok(out)
90    }
91
92    fn aes_encrypt(
93        key: &[u8],
94        nonce: &[u8; NONCE_SIZE],
95        data: &[u8],
96        additional_data: Option<&[u8]>,
97    ) -> Result<Vec<u8>, DhError> {
98        let key_len = key.len();
99        if key_len != 32 {
100            return Err(DhError::WrongKeyLength(format!(
101                "expected 32, got {key_len}"
102            )));
103        }
104        let key = GenericArray::from_slice(key);
105        let aes_gcm = Aes256Gcm::new(key);
106        let nonce = GenericArray::from_slice(nonce);
107
108        let payload: Payload = match additional_data {
109            None => Payload::from(data),
110            Some(add_data) => Payload {
111                aad: add_data,
112                msg: data,
113            },
114        };
115
116        let ciphertext = aes_gcm
117            .encrypt(nonce, payload)
118            .map_err(|e| DhError::DecryptionFailed(e.to_string()))?;
119
120        Ok(ciphertext)
121    }
122
123    fn aes_decrypt(
124        key: &[u8],
125        nonce: &[u8; NONCE_SIZE],
126        ciphertext: &[u8],
127        additional_data: Option<&[u8]>,
128    ) -> Result<Vec<u8>, DhError> {
129        let key_len = key.len();
130        if key_len != 32 {
131            return Err(DhError::WrongKeyLength(format!(
132                "expected 32, got {key_len}"
133            )));
134        }
135        let key = GenericArray::from_slice(key);
136        let aes_gcm = Aes256Gcm::new(key);
137        let nonce = GenericArray::from_slice(nonce);
138
139        let payload: Payload = match additional_data {
140            None => Payload::from(ciphertext),
141            Some(add_data) => Payload {
142                aad: add_data,
143                msg: ciphertext,
144            },
145        };
146
147        let decrypted = aes_gcm
148            .decrypt(nonce, payload)
149            .map_err(|e| DhError::DecryptionFailed(e.to_string()))?;
150
151        Ok(decrypted)
152    }
153
154    fn encrypt<POINT: Point>(
155        pre_key: &POINT,
156        info: &[u8],
157        nonce: &[u8; NONCE_SIZE],
158        data: &[u8],
159    ) -> Result<Vec<u8>, DhError> {
160        let pre_buff = pre_key.marshal_binary()?;
161        let key = Self::hkdf(&pre_buff, info, None)?;
162        let encrypted = Self::aes_encrypt(&key, nonce, data, Some(info))?;
163
164        Ok(encrypted)
165    }
166
167    fn decrypt<POINT: Point>(
168        pre_key: &POINT,
169        info: &[u8],
170        nonce: &[u8; NONCE_SIZE],
171        cipher: &[u8],
172    ) -> Result<Vec<u8>, DhError> {
173        let pre_buff = pre_key.marshal_binary()?;
174        let key = Self::hkdf(&pre_buff, info, None)?;
175        let decrypted = Self::aes_decrypt(&key, nonce, cipher, Some(info))?;
176
177        Ok(decrypted)
178    }
179}
180
181impl<T: HashFactory> Dh for T {
182    type H = T::T;
183}
184
185pub struct AEAD<T: Dh> {
186    key: Vec<u8>,
187    phantom: PhantomData<T>,
188}
189
190impl<DH: Dh> AEAD<DH> {
191    pub fn new<POINT: Point>(pre: POINT, hkfd_context: &[u8]) -> Result<Self, DhError> {
192        let pre_buff = pre.marshal_binary()?;
193        let key = DH::hkdf(&pre_buff, hkfd_context, None)?;
194        let key_len = key.len();
195        if key_len != 32 {
196            return Err(DhError::WrongKeyLength(format!(
197                "expected 32, got {key_len}"
198            )));
199        }
200        Ok(AEAD {
201            key,
202            phantom: PhantomData,
203        })
204    }
205
206    /// [`seal()`] encrypts and authenticates `plaintext`, authenticates the
207    /// `additional_data` and appends the result to `dst`, returning the updated
208    /// slice. The nonce must be [`NONCE_SIZE`] bytes long and unique for all
209    /// time, for a given key.
210    ///
211    /// To reuse `plaintext`'s storage for the encrypted output, use `plaintext[..0]`
212    /// as `dst`. Otherwise, the remaining capacity of dst must not overlap plaintext.
213    pub fn seal(
214        &self,
215        dst: Option<&mut [u8]>,
216        nonce: &[u8; NONCE_SIZE],
217        plaintext: &[u8],
218        additional_data: Option<&[u8]>,
219    ) -> Result<Vec<u8>, DhError> {
220        let encrypted = DH::aes_encrypt(&self.key, nonce, plaintext, additional_data)?;
221        if let Some(d) = dst {
222            d.copy_from_slice(&encrypted);
223        }
224        Ok(encrypted)
225    }
226
227    /// [`open()`] decrypts and authenticates `ciphertext`, authenticates the
228    /// `additional_data` and, if successful, appends the resulting `plaintext`
229    /// to `dst`, returning the updated slice. The `nonce` must be [`NONCE_SIZE`]
230    /// bytes long and both it and the additional data must match the
231    /// value passed to [`seal()`].
232    ///
233    /// To reuse ciphertext's storage for the decrypted output, use `ciphertext[..0]`
234    /// as `dst`. Otherwise, the remaining capacity of `dst` must not overlap `plaintext`.
235    ///
236    /// Even if the function fails, the contents of `dst`, up to its capacity,
237    /// may be overwritten.
238    pub fn open(
239        &self,
240        dst: Option<&mut [u8]>,
241        nonce: &[u8; NONCE_SIZE],
242        ciphertext: &[u8],
243        additional_data: Option<&[u8]>,
244    ) -> Result<Vec<u8>, DhError> {
245        let decrypted = DH::aes_decrypt(&self.key, nonce, ciphertext, additional_data)?;
246        if let Some(d) = dst {
247            d.copy_from_slice(&decrypted);
248        }
249        Ok(decrypted)
250    }
251
252    pub const fn nonce_size() -> usize {
253        NONCE_SIZE
254    }
255}
256
257#[derive(Debug, Error)]
258pub enum DhError {
259    #[error("marshalling error")]
260    MarshalingError(#[from] MarshallingError),
261    #[error("wrong key length")]
262    WrongKeyLength(String),
263    #[error("aes decryption failed")]
264    DecryptionFailed(String),
265    #[error("aes encryption failed")]
266    EncryptionFailed(String),
267    #[error("unexpected error in hkdf_sha256")]
268    HkdfFailure(String),
269}