libcrux_ml_kem/
types.rs

/// Defines a fixed-size, byte-array-backed newtype `$name<SIZE>` together with
/// its constructors, conversions and accessors.
///
/// The generated struct stores exactly `SIZE` bytes in a `pub(crate)` field
/// named `value`; `$doc` becomes the struct's rustdoc text.
macro_rules! impl_generic_struct {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name<const SIZE: usize> {
            pub(crate) value: [u8; SIZE],
        }

        impl<const SIZE: usize> Default for $name<SIZE> {
            /// Returns a value whose `SIZE` bytes are all zero.
            fn default() -> Self {
                Self { value: [0u8; SIZE] }
            }
        }

        #[hax_lib::attributes]
        impl<const SIZE: usize> AsRef<[u8]> for $name<SIZE> {
            #[ensures(|result| fstar!(r#"$result = ${self_}.f_value"#))]
            fn as_ref(&self) -> &[u8] {
                &self.value
            }
        }

        #[hax_lib::attributes]
        impl<const SIZE: usize> From<[u8; SIZE]> for $name<SIZE> {
            #[ensures(|result| fstar!(r#"${result}.f_value = $value"#))]
            fn from(value: [u8; SIZE]) -> Self {
                Self { value }
            }
        }

        impl<const SIZE: usize> From<&[u8; SIZE]> for $name<SIZE> {
            fn from(value: &[u8; SIZE]) -> Self {
                // `[u8; SIZE]` is `Copy`, so a plain dereference copies the
                // bytes; `.clone()` here would trip `clippy::clone_on_copy`.
                Self { value: *value }
            }
        }

        impl<const SIZE: usize> From<$name<SIZE>> for [u8; SIZE] {
            fn from(value: $name<SIZE>) -> Self {
                value.value
            }
        }

        impl<const SIZE: usize> TryFrom<&[u8]> for $name<SIZE> {
            type Error = core::array::TryFromSliceError;

            /// Copies `value` into a new instance.
            ///
            /// Fails with `TryFromSliceError` when `value.len() != SIZE`.
            fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
                // An explicit `match` (rather than a closure combinator) is
                // kept deliberately; presumably it is friendlier to the
                // hax/C extraction tooling — TODO confirm.
                match value.try_into() {
                    Ok(value) => Ok(Self { value }),
                    Err(e) => Err(e),
                }
            }
        }

        #[hax_lib::attributes]
        impl<const SIZE: usize> $name<SIZE> {
            /// A reference to the raw byte slice.
            #[ensures(|result| fstar!(r#"$result == self.f_value"#))]
            pub fn as_slice(&self) -> &[u8; SIZE] {
                &self.value
            }

            /// The number of bytes held by this type (always `SIZE`).
            pub const fn len() -> usize {
                SIZE
            }
        }
    };
}
/// Implements `Index`/`IndexMut` for a byte-array newtype produced by
/// `impl_generic_struct!`, forwarding every access to its inner `value` array.
///
/// Supported index types:
/// - `usize` yields a single `u8`;
/// - `Range<usize>`, `RangeTo<usize>` and `RangeFrom<usize>` yield `[u8]`.
///
/// Out-of-bounds indices panic exactly as indexing the array directly does.
macro_rules! impl_index_impls_for_generic_struct {
    // Internal arm: one `Index` + `IndexMut` pair for a range type that
    // slices the inner array.
    (@slice $name:ident, $range:ty) => {
        impl<const SIZE: usize> core::ops::Index<$range> for $name<SIZE> {
            type Output = [u8];

            fn index(&self, bounds: $range) -> &Self::Output {
                &self.value[bounds]
            }
        }

        impl<const SIZE: usize> core::ops::IndexMut<$range> for $name<SIZE> {
            fn index_mut(&mut self, bounds: $range) -> &mut Self::Output {
                &mut self.value[bounds]
            }
        }
    };
    ($name:ident) => {
        impl<const SIZE: usize> core::ops::Index<usize> for $name<SIZE> {
            type Output = u8;

            fn index(&self, position: usize) -> &Self::Output {
                &self.value[position]
            }
        }

        impl<const SIZE: usize> core::ops::IndexMut<usize> for $name<SIZE> {
            fn index_mut(&mut self, position: usize) -> &mut Self::Output {
                &mut self.value[position]
            }
        }

        impl_index_impls_for_generic_struct!(@slice $name, core::ops::Range<usize>);
        impl_index_impls_for_generic_struct!(@slice $name, core::ops::RangeTo<usize>);
        impl_index_impls_for_generic_struct!(@slice $name, core::ops::RangeFrom<usize>);
    };
}

137impl_generic_struct!(MlKemCiphertext, "An ML-KEM Ciphertext");
138impl_generic_struct!(MlKemPrivateKey, "An ML-KEM Private key");
139impl_generic_struct!(MlKemPublicKey, "An ML-KEM Public key");
140
141// These traits are used only in `ind_cpa` for kyber cipher text.
142mod index_impls {
143    use super::*;
144    impl_index_impls_for_generic_struct!(MlKemCiphertext);
145    impl_index_impls_for_generic_struct!(MlKemPrivateKey);
146    impl_index_impls_for_generic_struct!(MlKemPublicKey);
147}
148
149/// An ML-KEM key pair
150pub struct MlKemKeyPair<const PRIVATE_KEY_SIZE: usize, const PUBLIC_KEY_SIZE: usize> {
151    pub(crate) sk: MlKemPrivateKey<PRIVATE_KEY_SIZE>,
152    pub(crate) pk: MlKemPublicKey<PUBLIC_KEY_SIZE>,
153}
154
155#[hax_lib::attributes]
156impl<const PRIVATE_KEY_SIZE: usize, const PUBLIC_KEY_SIZE: usize>
157    MlKemKeyPair<PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE>
158{
159    /// Creates a new [`MlKemKeyPair`].
160    pub fn new(sk: [u8; PRIVATE_KEY_SIZE], pk: [u8; PUBLIC_KEY_SIZE]) -> Self {
161        Self {
162            sk: sk.into(),
163            pk: pk.into(),
164        }
165    }
166
167    /// Create a new [`MlKemKeyPair`] from the secret and public key.
168    #[ensures(|result| fstar!(r#"${result}.f_sk == $sk /\ ${result}.f_pk == $pk"#))]
169    pub fn from(
170        sk: MlKemPrivateKey<PRIVATE_KEY_SIZE>,
171        pk: MlKemPublicKey<PUBLIC_KEY_SIZE>,
172    ) -> Self {
173        Self { sk, pk }
174    }
175
176    /// Get a reference to the [`MlKemPublicKey<PUBLIC_KEY_SIZE>`].
177    pub fn public_key(&self) -> &MlKemPublicKey<PUBLIC_KEY_SIZE> {
178        &self.pk
179    }
180
181    /// Get a reference to the [`MlKemPrivateKey<PRIVATE_KEY_SIZE>`].
182    pub fn private_key(&self) -> &MlKemPrivateKey<PRIVATE_KEY_SIZE> {
183        &self.sk
184    }
185
186    /// Get a reference to the raw public key bytes.
187    pub fn pk(&self) -> &[u8; PUBLIC_KEY_SIZE] {
188        self.pk.as_slice()
189    }
190
191    /// Get a reference to the raw private key bytes.
192    pub fn sk(&self) -> &[u8; PRIVATE_KEY_SIZE] {
193        self.sk.as_slice()
194    }
195
196    /// Separate this key into the public and private key.
197    pub fn into_parts(
198        self,
199    ) -> (
200        MlKemPrivateKey<PRIVATE_KEY_SIZE>,
201        MlKemPublicKey<PUBLIC_KEY_SIZE>,
202    ) {
203        (self.sk, self.pk)
204    }
205}
206
207/// Unpack an incoming private key into it's different parts.
208///
209/// We have this here in types to extract into a common core for C.
210#[hax_lib::requires(fstar!(r#"Seq.length private_key >= 
211                            v v_CPA_SECRET_KEY_SIZE + v v_PUBLIC_KEY_SIZE + 
212                            v Libcrux_ml_kem.Constants.v_H_DIGEST_SIZE"#))]
213#[hax_lib::ensures(|result| fstar!(r#"
214           let (ind_cpa_secret_key_s,rest) = split $private_key $CPA_SECRET_KEY_SIZE in
215           let (ind_cpa_public_key_s,rest) = split rest $PUBLIC_KEY_SIZE in
216           let (ind_cpa_public_key_hash_s,implicit_rejection_value_s) = split rest Libcrux_ml_kem.Constants.v_H_DIGEST_SIZE in
217           let (ind_cpa_secret_key,ind_cpa_public_key,ind_cpa_public_key_hash,implicit_rejection_value)
218               = result in
219           ind_cpa_secret_key_s == ind_cpa_secret_key /\
220           ind_cpa_public_key_s == ind_cpa_public_key /\
221           ind_cpa_public_key_hash_s == ind_cpa_public_key_hash /\
222           implicit_rejection_value_s == implicit_rejection_value /\
223           Seq.length ind_cpa_secret_key == v v_CPA_SECRET_KEY_SIZE /\
224           Seq.length ind_cpa_public_key == v v_PUBLIC_KEY_SIZE /\
225           Seq.length ind_cpa_public_key_hash == v Libcrux_ml_kem.Constants.v_H_DIGEST_SIZE /\
226           Seq.length implicit_rejection_value == 
227           Seq.length private_key - 
228             (v v_CPA_SECRET_KEY_SIZE + v v_PUBLIC_KEY_SIZE + v Libcrux_ml_kem.Constants.v_H_DIGEST_SIZE)
229           "#))]
230pub(crate) fn unpack_private_key<const CPA_SECRET_KEY_SIZE: usize, const PUBLIC_KEY_SIZE: usize>(
231    private_key: &[u8], // len: SECRET_KEY_SIZE
232) -> (&[u8], &[u8], &[u8], &[u8]) {
233    let (ind_cpa_secret_key, secret_key) = private_key.split_at(CPA_SECRET_KEY_SIZE);
234    let (ind_cpa_public_key, secret_key) = secret_key.split_at(PUBLIC_KEY_SIZE);
235    let (ind_cpa_public_key_hash, implicit_rejection_value) =
236        secret_key.split_at(crate::constants::H_DIGEST_SIZE);
237    (
238        ind_cpa_secret_key,
239        ind_cpa_public_key,
240        ind_cpa_public_key_hash,
241        implicit_rejection_value,
242    )
243}