Skip to main content

ml_kem/
param.rs

1//! This module encapsulates all of the compile-time logic related to parameter-set dependent sizes
2//! of objects.  `ParameterSet` captures the parameters in the form described by the ML-KEM
3//! specification.  `EncodingSize`, `VectorEncodingSize`, and `CbdSamplingSize` are "upstream" of
4//! `ParameterSet`; they provide basic logic about the size of encoded objects.  `PkeParams` and
5//! `KemParams` are "downstream" of `ParameterSet`; they define derived parameters relevant to
6//! K-PKE and ML-KEM.
7//!
8//! While the primary purpose of these traits is to describe the sizes of objects, in order to
9//! avoid leakage of complicated trait bounds, they also need to provide any logic that needs to
10//! know any details about object sizes.  For example, `VectorEncodingSize::flatten` needs to know
11//! that the size of an encoded vector is `K` times the size of an encoded polynomial.
12
13use crate::{
14    B32, Ciphertext, Kem,
15    algebra::{BaseField, Elem, NttVector},
16};
17use array::{
18    Array,
19    typenum::{
20        Const, ToUInt, U0, U2, U3, U4, U6, U12, U16, U32, U64, U384,
21        operator_aliases::{Prod, Sum},
22    },
23};
24use core::{
25    fmt::Debug,
26    ops::{Add, Div, Mul, Rem, Sub},
27};
28use module_lattice::{
29    ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize, Field,
30    VectorEncodingSize,
31};
32
33#[cfg(doc)]
34use crate::Seed;
35
36/// To speed up CBD sampling, we pre-compute all the bit-manipulations:
37///
38/// * Splitting a sampled integer into two parts
39/// * Counting the ones in each part
40/// * Taking the difference between the two counts mod q
41#[allow(clippy::integer_division_remainder_used, reason = "constant")]
42const fn ones_array<const B: usize, const N: usize, U>() -> Array<Elem, U>
43where
44    U: ArraySize<ArrayType<Elem> = [Elem; N]>,
45    Const<N>: ToUInt<Output = U>,
46{
47    let max = 1 << B;
48    let mut out = [Elem::new(0); N];
49    let mut x = 0usize;
50    while x < max {
51        let mut y = 0usize;
52        while y < max {
53            let x_ones = (x.count_ones() & 0xFFFF) as u16;
54            let y_ones = (y.count_ones() & 0xFFFF) as u16;
55            let i = x + (y << B);
56            out[i] = Elem::new((x_ones + BaseField::Q - y_ones) % BaseField::Q);
57
58            y += 1;
59        }
60        x += 1;
61    }
62    Array(out)
63}
64
65/// An integer that describes a bit length to be used in CBD sampling
66#[allow(unreachable_pub)]
67pub trait CbdSamplingSize: ArraySize {
68    type SampleSize: EncodingSize;
69    type OnesSize: ArraySize;
70    const ONES: Array<Elem, Self::OnesSize>;
71}
72
73impl CbdSamplingSize for U2 {
74    type SampleSize = U4;
75    type OnesSize = U16;
76    const ONES: Array<Elem, U16> = ones_array::<2, 16, U16>();
77}
78
79impl CbdSamplingSize for U3 {
80    type SampleSize = U6;
81    type OnesSize = U64;
82    const ONES: Array<Elem, U64> = ones_array::<3, 64, U64>();
83}
84
85/// A `ParameterSet` captures the parameters that describe a particular instance of ML-KEM.
86///
87/// There are three variants, corresponding to three different security levels.
88pub trait ParameterSet: Default + Clone + Debug + PartialEq {
89    /// The dimensionality of vectors and arrays
90    type K: ArraySize;
91
92    /// The bit width of the centered binary distribution used when sampling random polynomials in
93    /// key generation and encryption.
94    type Eta1: CbdSamplingSize;
95
96    /// The bit width of the centered binary distribution used when sampling error vectors during
97    /// encryption.
98    type Eta2: CbdSamplingSize;
99
100    /// The bit width of encoded integers in the `u` vector in a ciphertext
101    type Du: VectorEncodingSize<Self::K>;
102
103    /// The bit width of encoded integers in the `v` polynomial in a ciphertext
104    type Dv: EncodingSize;
105}
106
107pub(crate) type EncodedUSize<P> =
108    EncodedVectorSize<<P as ParameterSet>::Du, <P as ParameterSet>::K>;
109pub(crate) type EncodedVSize<P> = EncodedPolynomialSize<<P as ParameterSet>::Dv>;
110
111type EncodedU<P> = Array<u8, EncodedUSize<P>>;
112type EncodedV<P> = Array<u8, EncodedVSize<P>>;
113
114/// Derived parameter relevant to K-PKE
115pub trait PkeParams: Kem<SharedKeySize = U32> + ParameterSet {
116    type NttVectorSize: ArraySize;
117    type EncryptionKeySize: ArraySize;
118
119    fn encode_u12(p: &NttVector<Self::K>) -> EncodedNttVector<Self>;
120    fn decode_u12(v: &EncodedNttVector<Self>) -> NttVector<Self::K>;
121
122    fn concat_ct(u: EncodedU<Self>, v: EncodedV<Self>) -> Ciphertext<Self>;
123    fn split_ct(ct: &Ciphertext<Self>) -> (&EncodedU<Self>, &EncodedV<Self>);
124
125    fn concat_ek(t_hat: EncodedNttVector<Self>, rho: B32) -> EncodedEncryptionKey<Self>;
126    fn split_ek(ek: &EncodedEncryptionKey<Self>) -> (&EncodedNttVector<Self>, &B32);
127}
128
129pub(crate) type EncodedNttVector<P> = Array<u8, <P as PkeParams>::NttVectorSize>;
130pub(crate) type EncodedDecryptionKey<P> = Array<u8, <P as PkeParams>::NttVectorSize>;
131pub(crate) type EncodedEncryptionKey<P> = Array<u8, <P as PkeParams>::EncryptionKeySize>;
132
133impl<P> PkeParams for P
134where
135    P: Kem<CiphertextSize = Sum<EncodedUSize<P>, EncodedVSize<P>>, SharedKeySize = U32>
136        + ParameterSet,
137    U384: Mul<P::K>,
138    Prod<U384, P::K>: ArraySize + Add<U32> + Div<P::K, Output = U384> + Rem<P::K, Output = U0>,
139    EncodedUSize<P>: Add<EncodedVSize<P>>,
140    Sum<EncodedUSize<P>, EncodedVSize<P>>:
141        ArraySize + Sub<EncodedUSize<P>, Output = EncodedVSize<P>>,
142    EncodedVectorSize<U12, P::K>: Add<U32>,
143    Sum<EncodedVectorSize<U12, P::K>, U32>:
144        ArraySize + Sub<EncodedVectorSize<U12, P::K>, Output = U32>,
145{
146    type NttVectorSize = EncodedVectorSize<U12, P::K>;
147    type EncryptionKeySize = Sum<Self::NttVectorSize, U32>;
148
149    fn encode_u12(p: &NttVector<Self::K>) -> EncodedNttVector<Self> {
150        Encode::<U12>::encode(p)
151    }
152
153    fn decode_u12(v: &EncodedNttVector<Self>) -> NttVector<Self::K> {
154        Encode::<U12>::decode(v)
155    }
156
157    fn concat_ct(u: EncodedU<Self>, v: EncodedV<Self>) -> Ciphertext<Self> {
158        u.concat(v)
159    }
160
161    fn split_ct(ct: &Ciphertext<Self>) -> (&EncodedU<Self>, &EncodedV<Self>) {
162        ct.split_ref()
163    }
164
165    fn concat_ek(t_hat: EncodedNttVector<Self>, rho: B32) -> EncodedEncryptionKey<Self> {
166        t_hat.concat(rho)
167    }
168
169    fn split_ek(ek: &EncodedEncryptionKey<Self>) -> (&EncodedNttVector<Self>, &B32) {
170        ek.split_ref()
171    }
172}
173
174/// Derived parameters relevant to ML-KEM
175pub trait KemParams: PkeParams {
176    type DecapsulationKeySize: ArraySize;
177
178    fn concat_dk(
179        dk: EncodedDecryptionKey<Self>,
180        ek: EncodedEncryptionKey<Self>,
181        h: B32,
182        z: B32,
183    ) -> ExpandedDecapsulationKey<Self>;
184
185    fn split_dk(
186        enc: &ExpandedDecapsulationKey<Self>,
187    ) -> (
188        &EncodedDecryptionKey<Self>,
189        &EncodedEncryptionKey<Self>,
190        &B32,
191        &B32,
192    );
193}
194
195pub(crate) type DecapsulationKeySize<P> = <P as KemParams>::DecapsulationKeySize;
196pub(crate) type EncapsulationKeySize<P> = <P as PkeParams>::EncryptionKeySize;
197
198/// Serialized decapsulation key after having been expanded from a [`Seed`].
199pub type ExpandedDecapsulationKey<P> = Array<u8, <P as KemParams>::DecapsulationKeySize>;
200
201impl<P> KemParams for P
202where
203    P: PkeParams,
204    P::NttVectorSize: Add<P::EncryptionKeySize>,
205    Sum<P::NttVectorSize, P::EncryptionKeySize>:
206        ArraySize + Add<U32> + Sub<P::NttVectorSize, Output = P::EncryptionKeySize>,
207    Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>:
208        ArraySize + Add<U32> + Sub<Sum<P::NttVectorSize, P::EncryptionKeySize>, Output = U32>,
209    Sum<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, U32>:
210        ArraySize + Sub<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, Output = U32>,
211{
212    type DecapsulationKeySize = Sum<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, U32>;
213
214    fn concat_dk(
215        dk: EncodedDecryptionKey<Self>,
216        ek: EncodedEncryptionKey<Self>,
217        h: B32,
218        z: B32,
219    ) -> ExpandedDecapsulationKey<Self> {
220        dk.concat(ek).concat(h).concat(z)
221    }
222
223    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
224    fn split_dk(
225        enc: &ExpandedDecapsulationKey<Self>,
226    ) -> (
227        &EncodedDecryptionKey<Self>,
228        &EncodedEncryptionKey<Self>,
229        &B32,
230        &B32,
231    ) {
232        // We parse from right to left to make it easier to write the trait bounds above
233        let (enc, z) = enc.split_ref();
234        let (enc, h) = enc.split_ref();
235        let (dk_pke, ek_pke) = enc.split_ref();
236        (dk_pke, ek_pke, h, z)
237    }
238}