ml_dsa/
param.rs

//! This module encapsulates all of the compile-time logic related to parameter-set dependent sizes
//! of objects.  `ParameterSet` captures the parameters in the form described by the ML-DSA
//! specification.  `SamplingSize` and `MaskSamplingSize` are "upstream" of `ParameterSet`; they
//! describe the ranges from which the secret vectors and the masking vector are drawn.
//! `SigningKeyParams`, `VerifyingKeyParams`, and `SignatureParams` are "downstream" of
//! `ParameterSet`; they define the derived sizes and encodings of signing keys, verifying keys, and
//! signatures, and `MlDsaParams` bundles all of them together.
//!
//! While the primary purpose of these traits is to describe the sizes of objects, in order to
//! avoid leakage of complicated trait bounds, they also need to provide any logic that needs to
//! know any details about object sizes.  For example, `SigningKeyParams::split_sk` needs to know
//! the encoded size of every component in order to split an encoded signing key into its parts.
12
13use core::fmt::Debug;
14use core::ops::{Add, Div, Mul, Rem, Sub};
15
16use crate::module_lattice::encode::{
17    ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize,
18};
19use hybrid_array::{
20    Array,
21    typenum::{
22        Diff, Len, Length, Prod, Shleft, Sum, U0, U1, U2, U4, U13, U23, U32, U64, U128, U320, U416,
23        Unsigned,
24    },
25};
26
27use crate::algebra::{Polynomial, Vector};
28use crate::encode::{
29    BitPack, RangeEncodedPolynomialSize, RangeEncodedVectorSize, RangeEncodingSize,
30};
31use crate::util::{B32, B64};
32
33/// Some useful compile-time constants
34pub type SpecQ = Sum<Diff<Shleft<U1, U23>, Shleft<U1, U13>>, U1>;
35pub type SpecD = U13;
36pub type QMinus1 = Diff<SpecQ, U1>;
37pub type BitlenQMinusD = Diff<Length<SpecQ>, SpecD>;
38pub type Pow2DMinus1 = Shleft<U1, Diff<SpecD, U1>>;
39pub type Pow2DMinus1Minus1 = Diff<Pow2DMinus1, U1>;
40
41/// An integer that describes a bit length to be used in sampling
42pub trait SamplingSize: ArraySize + Len {
43    const ETA: Eta;
44}
45
/// The supported values of the secret-coefficient bound `eta`, as a runtime value.
///
/// Each variant is tied to a type-level integer through [`SamplingSize::ETA`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Eta {
    /// `eta = 2`
    Two,
    /// `eta = 4`
    Four,
}
51
52impl SamplingSize for U2 {
53    const ETA: Eta = Eta::Two;
54}
55
56impl SamplingSize for U4 {
57    const ETA: Eta = Eta::Four;
58}
59
60/// An integer that describes a mask sampling size
61pub trait MaskSamplingSize: Unsigned {
62    type SampleSize: ArraySize;
63
64    fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial;
65}
66
67impl<G> MaskSamplingSize for G
68where
69    G: Unsigned + Sub<U1>,
70    (Diff<G, U1>, G): RangeEncodingSize,
71{
72    type SampleSize = RangeEncodedPolynomialSize<Diff<G, U1>, G>;
73
74    fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial {
75        BitPack::<Diff<G, U1>, G>::unpack(v)
76    }
77}
78
79/// A `ParameterSet` captures the parameters that describe a particular instance of ML-DSA.  There
80/// are three variants, corresponding to three different security levels.
81pub trait ParameterSet {
82    /// Number of rows in the A matrix
83    type K: ArraySize;
84
85    /// Number of columns in the A matrix
86    type L: ArraySize;
87
88    /// Private key range
89    type Eta: SamplingSize;
90
91    /// Error size bound for y
92    type Gamma1: MaskSamplingSize;
93
94    /// Low-order rounding range
95    type Gamma2: Unsigned;
96
97    /// Low-order rounding range (2 * gamma2 in terms of the spec)
98    type TwoGamma2: Unsigned;
99
100    /// Encoding width of the W1 polynomial, namely bitlen((q - 1) / (2 * gamma2) - 1)
101    type W1Bits: EncodingSize;
102
103    /// Collision strength of `c_tilde`, in bytes (lambda / 4 in the spec)
104    type Lambda: ArraySize;
105
106    /// Max number of true values in the hint
107    type Omega: ArraySize;
108
109    /// Number of nonzero values in the polynomial c
110    const TAU: usize;
111
112    /// Beta = Tau * Eta
113    #[allow(clippy::as_conversions)]
114    #[allow(clippy::cast_possible_truncation)]
115    const BETA: u32 = (Self::TAU as u32) * Self::Eta::U32;
116}
117
118pub trait SigningKeyParams: ParameterSet {
119    type S1Size: ArraySize;
120    type S2Size: ArraySize;
121    type T0Size: ArraySize;
122    type SigningKeySize: ArraySize;
123
124    fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self>;
125    fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L>;
126
127    fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self>;
128    fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K>;
129
130    fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self>;
131    fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K>;
132
133    fn concat_sk(
134        rho: B32,
135        K: B32,
136        tr: B64,
137        s1: EncodedS1<Self>,
138        s2: EncodedS2<Self>,
139        t0: EncodedT0<Self>,
140    ) -> EncodedSigningKey<Self>;
141    fn split_sk(
142        enc: &EncodedSigningKey<Self>,
143    ) -> (
144        &B32,
145        &B32,
146        &B64,
147        &EncodedS1<Self>,
148        &EncodedS2<Self>,
149        &EncodedT0<Self>,
150    );
151}
152
153pub type EncodedS1<P> = Array<u8, <P as SigningKeyParams>::S1Size>;
154pub type EncodedS2<P> = Array<u8, <P as SigningKeyParams>::S2Size>;
155pub type EncodedT0<P> = Array<u8, <P as SigningKeyParams>::T0Size>;
156
157pub type SigningKeySize<P> = <P as SigningKeyParams>::SigningKeySize;
158
159/// A signing key encoded as a byte array
160pub type EncodedSigningKey<P> = Array<u8, SigningKeySize<P>>;
161
162impl<P> SigningKeyParams for P
163where
164    P: ParameterSet,
165    // General rules about Eta
166    P::Eta: Add<P::Eta>,
167    Sum<P::Eta, P::Eta>: Len,
168    Length<Sum<P::Eta, P::Eta>>: EncodingSize,
169    // S1 encoding with Eta (L-size)
170    EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::L>,
171    Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>: ArraySize
172        + Div<P::L, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
173        + Rem<P::L, Output = U0>,
174    // S2 encoding with Eta (K-size)
175    EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::K>,
176    Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>: ArraySize
177        + Div<P::K, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
178        + Rem<P::K, Output = U0>,
179    // T0 encoding in -2^{d-1}-1 .. 2^{d-1} (D bits) (416 = 32 * D)
180    U416: Mul<P::K>,
181    Prod<U416, P::K>: ArraySize + Div<P::K, Output = U416> + Rem<P::K, Output = U0>,
182    // Signing key encoding rules
183    U128: Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
184    Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>: ArraySize
185        + Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>>
186        + Sub<U128, Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
187    Sum<
188        Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
189        Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
190    >: ArraySize
191        + Add<Prod<U416, P::K>>
192        + Sub<
193            Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
194            Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
195        >,
196    Sum<
197        Sum<
198            Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
199            Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
200        >,
201        Prod<U416, P::K>,
202    >: ArraySize
203        + Sub<
204            Sum<
205                Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
206                Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
207            >,
208            Output = Prod<U416, P::K>,
209        >,
210{
211    type S1Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::L>;
212    type S2Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::K>;
213    type T0Size = RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>;
214    type SigningKeySize = Sum<
215        Sum<
216            Sum<U128, RangeEncodedVectorSize<P::Eta, P::Eta, P::L>>,
217            RangeEncodedVectorSize<P::Eta, P::Eta, P::K>,
218        >,
219        RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>,
220    >;
221
222    fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self> {
223        BitPack::<P::Eta, P::Eta>::pack(s1)
224    }
225
226    fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L> {
227        BitPack::<P::Eta, P::Eta>::unpack(enc)
228    }
229
230    fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self> {
231        BitPack::<P::Eta, P::Eta>::pack(s2)
232    }
233
234    fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K> {
235        BitPack::<P::Eta, P::Eta>::unpack(enc)
236    }
237
238    fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self> {
239        BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::pack(t0)
240    }
241
242    fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K> {
243        BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::unpack(enc)
244    }
245
246    fn concat_sk(
247        rho: B32,
248        K: B32,
249        tr: B64,
250        s1: EncodedS1<Self>,
251        s2: EncodedS2<Self>,
252        t0: EncodedT0<Self>,
253    ) -> EncodedSigningKey<Self> {
254        rho.concat(K).concat(tr).concat(s1).concat(s2).concat(t0)
255    }
256
257    fn split_sk(
258        enc: &EncodedSigningKey<Self>,
259    ) -> (
260        &B32,
261        &B32,
262        &B64,
263        &EncodedS1<Self>,
264        &EncodedS2<Self>,
265        &EncodedT0<Self>,
266    ) {
267        let (enc, t0) = enc.split_ref();
268        let (enc, s2) = enc.split_ref();
269        let (enc, s1) = enc.split_ref();
270        let (enc, tr) = enc.split_ref::<U64>();
271        let (rho, K) = enc.split_ref();
272        (rho, K, tr, s1, s2, t0)
273    }
274}
275
276pub trait VerifyingKeyParams: ParameterSet {
277    type T1Size: ArraySize;
278    type VerifyingKeySize: ArraySize;
279
280    fn encode_t1(t1: &Vector<Self::K>) -> EncodedT1<Self>;
281    fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K>;
282
283    fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self>;
284    fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>);
285}
286
287pub type VerifyingKeySize<P> = <P as VerifyingKeyParams>::VerifyingKeySize;
288
289pub type EncodedT1<P> = Array<u8, <P as VerifyingKeyParams>::T1Size>;
290
291/// A verifying key encoded as a byte array
292pub type EncodedVerifyingKey<P> = Array<u8, VerifyingKeySize<P>>;
293
294impl<P> VerifyingKeyParams for P
295where
296    P: ParameterSet,
297    // T1 encoding rules
298    U320: Mul<P::K>,
299    Prod<U320, P::K>: ArraySize + Div<P::K, Output = U320> + Rem<P::K, Output = U0>,
300    // Verifying key encoding rules
301    U32: Add<Prod<U320, P::K>>,
302    Sum<U32, U32>: ArraySize,
303    Sum<U32, Prod<U320, P::K>>: ArraySize + Sub<U32, Output = Prod<U320, P::K>>,
304{
305    type T1Size = EncodedVectorSize<BitlenQMinusD, P::K>;
306    type VerifyingKeySize = Sum<U32, Self::T1Size>;
307
308    fn encode_t1(t1: &Vector<P::K>) -> EncodedT1<Self> {
309        Encode::<BitlenQMinusD>::encode(t1)
310    }
311
312    fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K> {
313        Encode::<BitlenQMinusD>::decode(enc)
314    }
315
316    fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self> {
317        rho.concat(t1)
318    }
319
320    fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>) {
321        enc.split_ref()
322    }
323}
324
325pub trait SignatureParams: ParameterSet {
326    type W1Size: ArraySize;
327    type ZSize: ArraySize;
328    type HintSize: ArraySize;
329    type SignatureSize: ArraySize;
330
331    const GAMMA1_MINUS_BETA: u32;
332    const GAMMA2_MINUS_BETA: u32;
333
334    fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>);
335
336    fn encode_w1(t1: &Vector<Self::K>) -> EncodedW1<Self>;
337    fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K>;
338
339    fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self>;
340    fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L>;
341
342    fn concat_sig(
343        c_tilde: EncodedCTilde<Self>,
344        z: EncodedZ<Self>,
345        h: EncodedHint<Self>,
346    ) -> EncodedSignature<Self>;
347    fn split_sig(
348        enc: &EncodedSignature<Self>,
349    ) -> (&EncodedCTilde<Self>, &EncodedZ<Self>, &EncodedHint<Self>);
350}
351
352pub type SignatureSize<P> = <P as SignatureParams>::SignatureSize;
353
354pub type EncodedCTilde<P> = Array<u8, <P as ParameterSet>::Lambda>;
355pub type EncodedW1<P> = Array<u8, <P as SignatureParams>::W1Size>;
356pub type EncodedZ<P> = Array<u8, <P as SignatureParams>::ZSize>;
357pub type EncodedHintIndices<P> = Array<u8, <P as ParameterSet>::Omega>;
358pub type EncodedHintCuts<P> = Array<u8, <P as ParameterSet>::K>;
359pub type EncodedHint<P> = Array<u8, <P as SignatureParams>::HintSize>;
360
361/// A signature encoded as a byte array
362pub type EncodedSignature<P> = Array<u8, SignatureSize<P>>;
363
364impl<P> SignatureParams for P
365where
366    P: ParameterSet,
367    // W1
368    U32: Mul<P::W1Bits>,
369    EncodedPolynomialSize<P::W1Bits>: Mul<P::K>,
370    Prod<EncodedPolynomialSize<P::W1Bits>, P::K>:
371        ArraySize + Div<P::K, Output = EncodedPolynomialSize<P::W1Bits>> + Rem<P::K, Output = U0>,
372    // Z
373    P::Gamma1: Sub<U1>,
374    (Diff<P::Gamma1, U1>, P::Gamma1): RangeEncodingSize,
375    RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>: Mul<P::L>,
376    Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>: ArraySize
377        + Div<P::L, Output = RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>>
378        + Rem<P::L, Output = U0>,
379    // Hint
380    P::Omega: Add<P::K>,
381    Sum<P::Omega, P::K>: ArraySize + Sub<P::Omega, Output = P::K>,
382    // Signature
383    P::Lambda: Add<Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
384    Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>:
385        ArraySize
386            + Add<Sum<P::Omega, P::K>>
387            + Sub<
388                P::Lambda,
389                Output = Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>,
390            >,
391    Sum<
392        Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
393        Sum<P::Omega, P::K>,
394    >: ArraySize
395        + Sub<
396            Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
397            Output = Sum<P::Omega, P::K>,
398        >,
399{
400    type W1Size = EncodedVectorSize<Self::W1Bits, P::K>;
401    type ZSize = RangeEncodedVectorSize<Diff<P::Gamma1, U1>, P::Gamma1, P::L>;
402    type HintSize = Sum<P::Omega, P::K>;
403    type SignatureSize = Sum<Sum<P::Lambda, Self::ZSize>, Self::HintSize>;
404
405    const GAMMA1_MINUS_BETA: u32 = P::Gamma1::U32 - P::BETA;
406    const GAMMA2_MINUS_BETA: u32 = P::Gamma2::U32 - P::BETA;
407
408    fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>) {
409        y.split_ref()
410    }
411
412    fn encode_w1(w1: &Vector<Self::K>) -> EncodedW1<Self> {
413        Encode::<Self::W1Bits>::encode(w1)
414    }
415
416    fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K> {
417        Encode::<Self::W1Bits>::decode(enc)
418    }
419
420    fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self> {
421        BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::pack(z)
422    }
423
424    fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L> {
425        BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::unpack(enc)
426    }
427
428    fn concat_sig(
429        c_tilde: EncodedCTilde<P>,
430        z: EncodedZ<P>,
431        h: EncodedHint<P>,
432    ) -> EncodedSignature<P> {
433        c_tilde.concat(z).concat(h)
434    }
435
436    fn split_sig(enc: &EncodedSignature<P>) -> (&EncodedCTilde<P>, &EncodedZ<P>, &EncodedHint<P>) {
437        let (enc, h) = enc.split_ref();
438        let (c_tilde, z) = enc.split_ref();
439        (c_tilde, z, h)
440    }
441}
442
443/// An instance of `MlDsaParams` defines all of the parameters necessary for ML-DSA operations.
444/// Typically this is done by implementing `ParameterSet` with values that will fit into the
445/// blanket implementations of `SigningKeyParams`, `VerifyingKeyParams`, and `SignatureParams`.
446pub trait MlDsaParams:
447    SigningKeyParams + VerifyingKeyParams + SignatureParams + Debug + Default + PartialEq + Clone
448{
449}
450
451impl<T> MlDsaParams for T where
452    T: SigningKeyParams
453        + VerifyingKeyParams
454        + SignatureParams
455        + Debug
456        + Default
457        + PartialEq
458        + Clone
459{
460}