// ml_dsa/param.rs

//! This module encapsulates all of the compile-time logic related to parameter-set dependent sizes
//! of objects.  `ParameterSet` captures the parameters in the form described by the ML-DSA
//! specification (FIPS 204).  `SamplingSize` and `MaskSamplingSize` are "upstream" of
//! `ParameterSet`; they describe the ranges from which private-key and mask coefficients are
//! drawn.  `SigningKeyParams`, `VerifyingKeyParams`, and `SignatureParams` are "downstream" of
//! `ParameterSet`; they define the derived sizes and encodings of signing keys, verifying keys,
//! and signatures, and `MlDsaParams` combines all three.
//!
//! While the primary purpose of these traits is to describe the sizes of objects, in order to
//! avoid leakage of complicated trait bounds, they also need to provide any logic that needs to
//! know any details about object sizes.  For example, `SigningKeyParams::split_sk` needs to know
//! the size of every component in order to split an encoded signing key into its parts.
12
13use core::fmt::Debug;
14use core::ops::{Add, Div, Mul, Rem, Sub};
15
16use crate::module_lattice::encode::{
17    ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize,
18};
19use hybrid_array::{
20    Array,
21    typenum::{
22        Diff, Len, Length, Prod, Shleft, Sum, U0, U1, U2, U4, U13, U23, U32, U64, U128, U320, U416,
23        Unsigned,
24    },
25};
26
27use crate::algebra::{Polynomial, Vector};
28use crate::encode::{
29    BitPack, RangeEncodedPolynomialSize, RangeEncodedVectorSize, RangeEncodingSize,
30};
31use crate::util::{B32, B64};
32
33/// Some useful compile-time constants
34pub(crate) type SpecQ = Sum<Diff<Shleft<U1, U23>, Shleft<U1, U13>>, U1>;
35pub(crate) type SpecD = U13;
36pub(crate) type QMinus1 = Diff<SpecQ, U1>;
37pub(crate) type BitlenQMinusD = Diff<Length<SpecQ>, SpecD>;
38pub(crate) type Pow2DMinus1 = Shleft<U1, Diff<SpecD, U1>>;
39pub(crate) type Pow2DMinus1Minus1 = Diff<Pow2DMinus1, U1>;
40
41/// An integer that describes a bit length to be used in sampling
42#[expect(unreachable_pub)]
43pub trait SamplingSize: ArraySize + Len {
44    const ETA: Eta;
45}
46
/// Runtime tag for the private-key coefficient bound `eta`, mirroring the type-level `U2` / `U4`
/// values that implement `SamplingSize`.
///
/// `Debug`, `PartialEq`, and `Eq` are derived so the tag can be compared and printed (e.g. in
/// assertions and diagnostics).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Eta {
    /// `eta = 2`
    Two,
    /// `eta = 4`
    Four,
}
52
53impl SamplingSize for U2 {
54    const ETA: Eta = Eta::Two;
55}
56
57impl SamplingSize for U4 {
58    const ETA: Eta = Eta::Four;
59}
60
61/// An integer that describes a mask sampling size
62#[expect(unreachable_pub)]
63pub trait MaskSamplingSize: Unsigned {
64    type SampleSize: ArraySize;
65
66    fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial;
67}
68
69impl<G> MaskSamplingSize for G
70where
71    G: Unsigned + Sub<U1>,
72    (Diff<G, U1>, G): RangeEncodingSize,
73{
74    type SampleSize = RangeEncodedPolynomialSize<Diff<G, U1>, G>;
75
76    fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial {
77        BitPack::<Diff<G, U1>, G>::unpack(v)
78    }
79}
80
81/// A `ParameterSet` captures the parameters that describe a particular instance of ML-DSA.  There
82/// are three variants, corresponding to three different security levels.
83pub trait ParameterSet {
84    /// Number of rows in the A matrix
85    type K: ArraySize;
86
87    /// Number of columns in the A matrix
88    type L: ArraySize;
89
90    /// Private key range
91    type Eta: SamplingSize;
92
93    /// Error size bound for y
94    type Gamma1: MaskSamplingSize;
95
96    /// Low-order rounding range
97    type Gamma2: Unsigned;
98
99    /// Low-order rounding range (2 * gamma2 in terms of the spec)
100    type TwoGamma2: Unsigned;
101
102    /// Encoding width of the W1 polynomial, namely bitlen((q - 1) / (2 * gamma2) - 1)
103    type W1Bits: EncodingSize;
104
105    /// Collision strength of `c_tilde`, in bytes (lambda / 4 in the spec)
106    type Lambda: ArraySize;
107
108    /// Max number of true values in the hint
109    type Omega: ArraySize;
110
111    /// Number of nonzero values in the polynomial c
112    const TAU: usize;
113
114    /// Beta = Tau * Eta
115    #[allow(clippy::as_conversions)]
116    #[allow(clippy::cast_possible_truncation)]
117    const BETA: u32 = (Self::TAU as u32) * Self::Eta::U32;
118}
119
120pub trait SigningKeyParams: ParameterSet {
121    type S1Size: ArraySize;
122    type S2Size: ArraySize;
123    type T0Size: ArraySize;
124    type SigningKeySize: ArraySize;
125
126    fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self>;
127    fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L>;
128
129    fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self>;
130    fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K>;
131
132    fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self>;
133    fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K>;
134
135    fn concat_sk(
136        rho: B32,
137        K: B32,
138        tr: B64,
139        s1: EncodedS1<Self>,
140        s2: EncodedS2<Self>,
141        t0: EncodedT0<Self>,
142    ) -> EncodedSigningKey<Self>;
143    fn split_sk(
144        enc: &EncodedSigningKey<Self>,
145    ) -> (
146        &B32,
147        &B32,
148        &B64,
149        &EncodedS1<Self>,
150        &EncodedS2<Self>,
151        &EncodedT0<Self>,
152    );
153}
154
155pub(crate) type EncodedS1<P> = Array<u8, <P as SigningKeyParams>::S1Size>;
156pub(crate) type EncodedS2<P> = Array<u8, <P as SigningKeyParams>::S2Size>;
157pub(crate) type EncodedT0<P> = Array<u8, <P as SigningKeyParams>::T0Size>;
158
159pub(crate) type SigningKeySize<P> = <P as SigningKeyParams>::SigningKeySize;
160
161/// A signing key encoded as a byte array
162pub type EncodedSigningKey<P> = Array<u8, SigningKeySize<P>>;
163
164impl<P> SigningKeyParams for P
165where
166    P: ParameterSet,
167    // General rules about Eta
168    P::Eta: Add<P::Eta>,
169    Sum<P::Eta, P::Eta>: Len,
170    Length<Sum<P::Eta, P::Eta>>: EncodingSize,
171    // S1 encoding with Eta (L-size)
172    EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::L>,
173    Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>: ArraySize
174        + Div<P::L, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
175        + Rem<P::L, Output = U0>,
176    // S2 encoding with Eta (K-size)
177    EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::K>,
178    Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>: ArraySize
179        + Div<P::K, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
180        + Rem<P::K, Output = U0>,
181    // T0 encoding in -2^{d-1}-1 .. 2^{d-1} (D bits) (416 = 32 * D)
182    U416: Mul<P::K>,
183    Prod<U416, P::K>: ArraySize + Div<P::K, Output = U416> + Rem<P::K, Output = U0>,
184    // Signing key encoding rules
185    U128: Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
186    Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>: ArraySize
187        + Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>>
188        + Sub<U128, Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
189    Sum<
190        Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
191        Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
192    >: ArraySize
193        + Add<Prod<U416, P::K>>
194        + Sub<
195            Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
196            Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
197        >,
198    Sum<
199        Sum<
200            Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
201            Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
202        >,
203        Prod<U416, P::K>,
204    >: ArraySize
205        + Sub<
206            Sum<
207                Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
208                Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
209            >,
210            Output = Prod<U416, P::K>,
211        >,
212{
213    type S1Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::L>;
214    type S2Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::K>;
215    type T0Size = RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>;
216    type SigningKeySize = Sum<
217        Sum<
218            Sum<U128, RangeEncodedVectorSize<P::Eta, P::Eta, P::L>>,
219            RangeEncodedVectorSize<P::Eta, P::Eta, P::K>,
220        >,
221        RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>,
222    >;
223
224    fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self> {
225        BitPack::<P::Eta, P::Eta>::pack(s1)
226    }
227
228    fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L> {
229        BitPack::<P::Eta, P::Eta>::unpack(enc)
230    }
231
232    fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self> {
233        BitPack::<P::Eta, P::Eta>::pack(s2)
234    }
235
236    fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K> {
237        BitPack::<P::Eta, P::Eta>::unpack(enc)
238    }
239
240    fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self> {
241        BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::pack(t0)
242    }
243
244    fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K> {
245        BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::unpack(enc)
246    }
247
248    fn concat_sk(
249        rho: B32,
250        K: B32,
251        tr: B64,
252        s1: EncodedS1<Self>,
253        s2: EncodedS2<Self>,
254        t0: EncodedT0<Self>,
255    ) -> EncodedSigningKey<Self> {
256        rho.concat(K).concat(tr).concat(s1).concat(s2).concat(t0)
257    }
258
259    fn split_sk(
260        enc: &EncodedSigningKey<Self>,
261    ) -> (
262        &B32,
263        &B32,
264        &B64,
265        &EncodedS1<Self>,
266        &EncodedS2<Self>,
267        &EncodedT0<Self>,
268    ) {
269        let (enc, t0) = enc.split_ref();
270        let (enc, s2) = enc.split_ref();
271        let (enc, s1) = enc.split_ref();
272        let (enc, tr) = enc.split_ref::<U64>();
273        let (rho, K) = enc.split_ref();
274        (rho, K, tr, s1, s2, t0)
275    }
276}
277
278pub trait VerifyingKeyParams: ParameterSet {
279    type T1Size: ArraySize;
280    type VerifyingKeySize: ArraySize;
281
282    fn encode_t1(t1: &Vector<Self::K>) -> EncodedT1<Self>;
283    fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K>;
284
285    fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self>;
286    fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>);
287}
288
289pub(crate) type VerifyingKeySize<P> = <P as VerifyingKeyParams>::VerifyingKeySize;
290
291pub(crate) type EncodedT1<P> = Array<u8, <P as VerifyingKeyParams>::T1Size>;
292
293/// A verifying key encoded as a byte array
294pub type EncodedVerifyingKey<P> = Array<u8, VerifyingKeySize<P>>;
295
296impl<P> VerifyingKeyParams for P
297where
298    P: ParameterSet,
299    // T1 encoding rules
300    U320: Mul<P::K>,
301    Prod<U320, P::K>: ArraySize + Div<P::K, Output = U320> + Rem<P::K, Output = U0>,
302    // Verifying key encoding rules
303    U32: Add<Prod<U320, P::K>>,
304    Sum<U32, U32>: ArraySize,
305    Sum<U32, Prod<U320, P::K>>: ArraySize + Sub<U32, Output = Prod<U320, P::K>>,
306{
307    type T1Size = EncodedVectorSize<BitlenQMinusD, P::K>;
308    type VerifyingKeySize = Sum<U32, Self::T1Size>;
309
310    fn encode_t1(t1: &Vector<P::K>) -> EncodedT1<Self> {
311        Encode::<BitlenQMinusD>::encode(t1)
312    }
313
314    fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K> {
315        Encode::<BitlenQMinusD>::decode(enc)
316    }
317
318    fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self> {
319        rho.concat(t1)
320    }
321
322    fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>) {
323        enc.split_ref()
324    }
325}
326
327pub trait SignatureParams: ParameterSet {
328    type W1Size: ArraySize;
329    type ZSize: ArraySize;
330    type HintSize: ArraySize;
331    type SignatureSize: ArraySize;
332
333    const GAMMA1_MINUS_BETA: u32;
334    const GAMMA2_MINUS_BETA: u32;
335
336    fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>);
337
338    fn encode_w1(t1: &Vector<Self::K>) -> EncodedW1<Self>;
339    fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K>;
340
341    fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self>;
342    fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L>;
343
344    fn concat_sig(
345        c_tilde: EncodedCTilde<Self>,
346        z: EncodedZ<Self>,
347        h: EncodedHint<Self>,
348    ) -> EncodedSignature<Self>;
349    fn split_sig(
350        enc: &EncodedSignature<Self>,
351    ) -> (&EncodedCTilde<Self>, &EncodedZ<Self>, &EncodedHint<Self>);
352}
353
354pub(crate) type SignatureSize<P> = <P as SignatureParams>::SignatureSize;
355
356pub(crate) type EncodedCTilde<P> = Array<u8, <P as ParameterSet>::Lambda>;
357pub(crate) type EncodedW1<P> = Array<u8, <P as SignatureParams>::W1Size>;
358pub(crate) type EncodedZ<P> = Array<u8, <P as SignatureParams>::ZSize>;
359pub(crate) type EncodedHintIndices<P> = Array<u8, <P as ParameterSet>::Omega>;
360pub(crate) type EncodedHintCuts<P> = Array<u8, <P as ParameterSet>::K>;
361pub(crate) type EncodedHint<P> = Array<u8, <P as SignatureParams>::HintSize>;
362
363/// A signature encoded as a byte array
364pub type EncodedSignature<P> = Array<u8, SignatureSize<P>>;
365
366impl<P> SignatureParams for P
367where
368    P: ParameterSet,
369    // W1
370    U32: Mul<P::W1Bits>,
371    EncodedPolynomialSize<P::W1Bits>: Mul<P::K>,
372    Prod<EncodedPolynomialSize<P::W1Bits>, P::K>:
373        ArraySize + Div<P::K, Output = EncodedPolynomialSize<P::W1Bits>> + Rem<P::K, Output = U0>,
374    // Z
375    P::Gamma1: Sub<U1>,
376    (Diff<P::Gamma1, U1>, P::Gamma1): RangeEncodingSize,
377    RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>: Mul<P::L>,
378    Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>: ArraySize
379        + Div<P::L, Output = RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>>
380        + Rem<P::L, Output = U0>,
381    // Hint
382    P::Omega: Add<P::K>,
383    Sum<P::Omega, P::K>: ArraySize + Sub<P::Omega, Output = P::K>,
384    // Signature
385    P::Lambda: Add<Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
386    Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>:
387        ArraySize
388            + Add<Sum<P::Omega, P::K>>
389            + Sub<
390                P::Lambda,
391                Output = Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>,
392            >,
393    Sum<
394        Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
395        Sum<P::Omega, P::K>,
396    >: ArraySize
397        + Sub<
398            Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
399            Output = Sum<P::Omega, P::K>,
400        >,
401{
402    type W1Size = EncodedVectorSize<Self::W1Bits, P::K>;
403    type ZSize = RangeEncodedVectorSize<Diff<P::Gamma1, U1>, P::Gamma1, P::L>;
404    type HintSize = Sum<P::Omega, P::K>;
405    type SignatureSize = Sum<Sum<P::Lambda, Self::ZSize>, Self::HintSize>;
406
407    const GAMMA1_MINUS_BETA: u32 = P::Gamma1::U32 - P::BETA;
408    const GAMMA2_MINUS_BETA: u32 = P::Gamma2::U32 - P::BETA;
409
410    fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>) {
411        y.split_ref()
412    }
413
414    fn encode_w1(w1: &Vector<Self::K>) -> EncodedW1<Self> {
415        Encode::<Self::W1Bits>::encode(w1)
416    }
417
418    fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K> {
419        Encode::<Self::W1Bits>::decode(enc)
420    }
421
422    fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self> {
423        BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::pack(z)
424    }
425
426    fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L> {
427        BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::unpack(enc)
428    }
429
430    fn concat_sig(
431        c_tilde: EncodedCTilde<P>,
432        z: EncodedZ<P>,
433        h: EncodedHint<P>,
434    ) -> EncodedSignature<P> {
435        c_tilde.concat(z).concat(h)
436    }
437
438    fn split_sig(enc: &EncodedSignature<P>) -> (&EncodedCTilde<P>, &EncodedZ<P>, &EncodedHint<P>) {
439        let (enc, h) = enc.split_ref();
440        let (c_tilde, z) = enc.split_ref();
441        (c_tilde, z, h)
442    }
443}
444
445/// An instance of `MlDsaParams` defines all of the parameters necessary for ML-DSA operations.
446/// Typically this is done by implementing `ParameterSet` with values that will fit into the
447/// blanket implementations of `SigningKeyParams`, `VerifyingKeyParams`, and `SignatureParams`.
448pub trait MlDsaParams:
449    SigningKeyParams + VerifyingKeyParams + SignatureParams + Debug + Default + PartialEq + Clone
450{
451}
452
453impl<T> MlDsaParams for T where
454    T: SigningKeyParams
455        + VerifyingKeyParams
456        + SignatureParams
457        + Debug
458        + Default
459        + PartialEq
460        + Clone
461{
462}