Skip to main content

ml_dsa/
param.rs

//! This module encapsulates all of the compile-time logic related to parameter-set dependent sizes
//! of objects.  `ParameterSet` captures the parameters in the form described by the ML-DSA
//! specification (FIPS 204).  `SamplingSize` and `MaskSamplingSize` (together with
//! `EncodingSize` and `RangeEncodingSize`) are "upstream" of `ParameterSet`; they provide basic
//! logic about the size of encoded objects.
//!
//! While the primary purpose of these traits is to describe the sizes of objects, in order to
//! avoid leakage of complicated trait bounds, they also need to provide any logic that needs to
//! know any details about object sizes.  For example, `SigningKeyParams::split_sk` needs to know
//! the size of every component in order to split an encoded signing key back into `rho`, `K`,
//! `tr`, `s1`, `s2`, and `t0`.
10
11use crate::{
12    B32, B64,
13    algebra::{Polynomial, Vector},
14    encode::{BitPack, RangeEncodedPolynomialSize, RangeEncodedVectorSize, RangeEncodingSize},
15};
16use core::{
17    fmt::Debug,
18    ops::{Add, Div, Mul, Rem, Sub},
19};
20use hybrid_array::{
21    Array,
22    typenum::{
23        Diff, Len, Length, Prod, Shleft, Sum, U0, U1, U2, U4, U13, U23, U32, U64, U128, U320, U416,
24        Unsigned,
25    },
26};
27use module_lattice::{ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize};
28
29/// Some useful compile-time constants
30pub(crate) type SpecQ = Sum<Diff<Shleft<U1, U23>, Shleft<U1, U13>>, U1>;
31pub(crate) type SpecD = U13;
32pub(crate) type QMinus1 = Diff<SpecQ, U1>;
33pub(crate) type BitlenQMinusD = Diff<Length<SpecQ>, SpecD>;
34pub(crate) type Pow2DMinus1 = Shleft<U1, Diff<SpecD, U1>>;
35pub(crate) type Pow2DMinus1Minus1 = Diff<Pow2DMinus1, U1>;
36
37/// An integer that describes a bit length to be used in sampling
38#[expect(unreachable_pub)]
39pub trait SamplingSize: ArraySize + Len {
40    const ETA: Eta;
41}
42
/// Runtime tag for the secret-coefficient bound `eta`, mirroring `SamplingSize::ETA`.
#[derive(Clone, Copy)]
pub(crate) enum Eta {
    /// `eta = 2`
    Two,
    /// `eta = 4`
    Four,
}
48
49impl SamplingSize for U2 {
50    const ETA: Eta = Eta::Two;
51}
52
53impl SamplingSize for U4 {
54    const ETA: Eta = Eta::Four;
55}
56
57/// An integer that describes a mask sampling size
58#[expect(unreachable_pub)]
59pub trait MaskSamplingSize: Unsigned {
60    type SampleSize: ArraySize;
61
62    fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial;
63}
64
65impl<G> MaskSamplingSize for G
66where
67    G: Unsigned + Sub<U1>,
68    (Diff<G, U1>, G): RangeEncodingSize,
69{
70    type SampleSize = RangeEncodedPolynomialSize<Diff<G, U1>, G>;
71
72    fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial {
73        BitPack::<Diff<G, U1>, G>::unpack(v)
74    }
75}
76
77/// A `ParameterSet` captures the parameters that describe a particular instance of ML-DSA.  There
78/// are three variants, corresponding to three different security levels.
79pub trait ParameterSet {
80    /// Number of rows in the A matrix
81    type K: ArraySize;
82
83    /// Number of columns in the A matrix
84    type L: ArraySize;
85
86    /// Private key range
87    type Eta: SamplingSize;
88
89    /// Error size bound for y
90    type Gamma1: MaskSamplingSize;
91
92    /// Low-order rounding range
93    type Gamma2: Unsigned;
94
95    /// Low-order rounding range (2 * gamma2 in terms of the spec)
96    type TwoGamma2: Unsigned;
97
98    /// Encoding width of the W1 polynomial, namely bitlen((q - 1) / (2 * gamma2) - 1)
99    type W1Bits: EncodingSize;
100
101    /// Collision strength of `c_tilde`, in bytes (lambda / 4 in the spec)
102    type Lambda: ArraySize;
103
104    /// Max number of true values in the hint
105    type Omega: ArraySize;
106
107    /// Number of nonzero values in the polynomial c
108    const TAU: usize;
109
110    /// Beta = Tau * Eta
111    #[allow(clippy::as_conversions)]
112    #[allow(clippy::cast_possible_truncation)]
113    const BETA: u32 = (Self::TAU as u32) * Self::Eta::U32;
114}
115
116pub trait SigningKeyParams: ParameterSet {
117    type S1Size: ArraySize;
118    type S2Size: ArraySize;
119    type T0Size: ArraySize;
120    type SigningKeySize: ArraySize;
121
122    fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self>;
123    fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L>;
124
125    fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self>;
126    fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K>;
127
128    fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self>;
129    fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K>;
130
131    fn concat_sk(
132        rho: B32,
133        K: B32,
134        tr: B64,
135        s1: EncodedS1<Self>,
136        s2: EncodedS2<Self>,
137        t0: EncodedT0<Self>,
138    ) -> ExpandedSigningKey<Self>;
139    fn split_sk(
140        enc: &ExpandedSigningKey<Self>,
141    ) -> (
142        &B32,
143        &B32,
144        &B64,
145        &EncodedS1<Self>,
146        &EncodedS2<Self>,
147        &EncodedT0<Self>,
148    );
149}
150
151pub(crate) type EncodedS1<P> = Array<u8, <P as SigningKeyParams>::S1Size>;
152pub(crate) type EncodedS2<P> = Array<u8, <P as SigningKeyParams>::S2Size>;
153pub(crate) type EncodedT0<P> = Array<u8, <P as SigningKeyParams>::T0Size>;
154
155pub(crate) type SigningKeySize<P> = <P as SigningKeyParams>::SigningKeySize;
156
157/// A signing key encoded as a byte array
158pub type ExpandedSigningKey<P> = Array<u8, SigningKeySize<P>>;
159
/// Blanket implementation of [`SigningKeyParams`] for every [`ParameterSet`] whose type-level
/// sizes satisfy the bounds below.
///
/// The `where` clause restates, in `typenum` arithmetic, every size relationship the method
/// bodies rely on: the `Mul` products that size the encoded vectors (with matching `Div`/`Rem`
/// bounds), and the `Add`/`Sub` bounds that `concat` and `split_ref` need in order to join and
/// split the encoded key.  The key is laid out as `rho || K || tr || s1 || s2 || t0`.
impl<P> SigningKeyParams for P
where
    P: ParameterSet,
    // General rules about Eta: s1/s2 coefficients are bit-packed over -eta ..= eta, using
    // bitlen(2 * eta) bits per coefficient
    P::Eta: Add<P::Eta>,
    Sum<P::Eta, P::Eta>: Len,
    Length<Sum<P::Eta, P::Eta>>: EncodingSize,
    // S1 encoding with Eta (L-size)
    EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::L>,
    Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>: ArraySize
        + Div<P::L, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
        + Rem<P::L, Output = U0>,
    // S2 encoding with Eta (K-size)
    EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::K>,
    Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>: ArraySize
        + Div<P::K, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
        + Rem<P::K, Output = U0>,
    // T0 encoding: coefficients lie in -(2^{d-1} - 1) ..= 2^{d-1} and are packed in D = 13 bits
    // each, i.e. 416 = 32 * D bytes per polynomial
    U416: Mul<P::K>,
    Prod<U416, P::K>: ArraySize + Div<P::K, Output = U416> + Rem<P::K, Output = U0>,
    // Signing key encoding rules.  U128 is the fixed-size prefix rho (32) || K (32) || tr (64).
    // Each `Add` bound lets `concat` append the next component; each `Sub` bound lets
    // `split_ref` peel that component back off in `split_sk`.
    U128: Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
    Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>: ArraySize
        + Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>>
        + Sub<U128, Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
    Sum<
        Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
        Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
    >: ArraySize
        + Add<Prod<U416, P::K>>
        + Sub<
            Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
            Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
        >,
    Sum<
        Sum<
            Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
            Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
        >,
        Prod<U416, P::K>,
    >: ArraySize
        + Sub<
            Sum<
                Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
                Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
            >,
            Output = Prod<U416, P::K>,
        >,
{
    type S1Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::L>;
    type S2Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::K>;
    type T0Size = RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>;
    // 128 + |s1| + |s2| + |t0|
    type SigningKeySize = Sum<
        Sum<
            Sum<U128, RangeEncodedVectorSize<P::Eta, P::Eta, P::L>>,
            RangeEncodedVectorSize<P::Eta, P::Eta, P::K>,
        >,
        RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>,
    >;

    // BitPack(s1, eta, eta)
    fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self> {
        BitPack::<P::Eta, P::Eta>::pack(s1)
    }

    fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L> {
        BitPack::<P::Eta, P::Eta>::unpack(enc)
    }

    // BitPack(s2, eta, eta)
    fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self> {
        BitPack::<P::Eta, P::Eta>::pack(s2)
    }

    fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K> {
        BitPack::<P::Eta, P::Eta>::unpack(enc)
    }

    // BitPack(t0, 2^{d-1} - 1, 2^{d-1})
    fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self> {
        BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::pack(t0)
    }

    fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K> {
        BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::unpack(enc)
    }

    // rho || K || tr || s1 || s2 || t0
    fn concat_sk(
        rho: B32,
        K: B32,
        tr: B64,
        s1: EncodedS1<Self>,
        s2: EncodedS2<Self>,
        t0: EncodedT0<Self>,
    ) -> ExpandedSigningKey<Self> {
        rho.concat(K).concat(tr).concat(s1).concat(s2).concat(t0)
    }

    fn split_sk(
        enc: &ExpandedSigningKey<Self>,
    ) -> (
        &B32,
        &B32,
        &B64,
        &EncodedS1<Self>,
        &EncodedS2<Self>,
        &EncodedT0<Self>,
    ) {
        // Peel components off the end, last-encoded first: t0, s2, s1, then tr.  Only the `tr`
        // split names its size explicitly (`U64`); the other split points are left to inference
        // (presumably resolved through the `Sub` bounds above).
        let (enc, t0) = enc.split_ref();
        let (enc, s2) = enc.split_ref();
        let (enc, s1) = enc.split_ref();
        let (enc, tr) = enc.split_ref::<U64>();
        // What remains is rho || K.
        let (rho, K) = enc.split_ref();
        (rho, K, tr, s1, s2, t0)
    }
}
273
274pub trait VerifyingKeyParams: ParameterSet {
275    type T1Size: ArraySize;
276    type VerifyingKeySize: ArraySize;
277
278    fn encode_t1(t1: &Vector<Self::K>) -> EncodedT1<Self>;
279    fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K>;
280
281    fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self>;
282    fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>);
283}
284
285pub(crate) type VerifyingKeySize<P> = <P as VerifyingKeyParams>::VerifyingKeySize;
286
287pub(crate) type EncodedT1<P> = Array<u8, <P as VerifyingKeyParams>::T1Size>;
288
289/// A verifying key encoded as a byte array
290pub type EncodedVerifyingKey<P> = Array<u8, VerifyingKeySize<P>>;
291
292impl<P> VerifyingKeyParams for P
293where
294    P: ParameterSet,
295    // T1 encoding rules
296    U320: Mul<P::K>,
297    Prod<U320, P::K>: ArraySize + Div<P::K, Output = U320> + Rem<P::K, Output = U0>,
298    // Verifying key encoding rules
299    U32: Add<Prod<U320, P::K>>,
300    Sum<U32, U32>: ArraySize,
301    Sum<U32, Prod<U320, P::K>>: ArraySize + Sub<U32, Output = Prod<U320, P::K>>,
302{
303    type T1Size = EncodedVectorSize<BitlenQMinusD, P::K>;
304    type VerifyingKeySize = Sum<U32, Self::T1Size>;
305
306    fn encode_t1(t1: &Vector<P::K>) -> EncodedT1<Self> {
307        Encode::<BitlenQMinusD>::encode(t1)
308    }
309
310    fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K> {
311        Encode::<BitlenQMinusD>::decode(enc)
312    }
313
314    fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self> {
315        rho.concat(t1)
316    }
317
318    fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>) {
319        enc.split_ref()
320    }
321}
322
323pub trait SignatureParams: ParameterSet {
324    type W1Size: ArraySize;
325    type ZSize: ArraySize;
326    type HintSize: ArraySize;
327    type SignatureSize: ArraySize;
328
329    const GAMMA1_MINUS_BETA: u32;
330    const GAMMA2_MINUS_BETA: u32;
331
332    fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>);
333
334    fn encode_w1(t1: &Vector<Self::K>) -> EncodedW1<Self>;
335    fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K>;
336
337    fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self>;
338    fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L>;
339
340    fn concat_sig(
341        c_tilde: EncodedCTilde<Self>,
342        z: EncodedZ<Self>,
343        h: EncodedHint<Self>,
344    ) -> EncodedSignature<Self>;
345    fn split_sig(
346        enc: &EncodedSignature<Self>,
347    ) -> (&EncodedCTilde<Self>, &EncodedZ<Self>, &EncodedHint<Self>);
348}
349
350pub(crate) type SignatureSize<P> = <P as SignatureParams>::SignatureSize;
351
352pub(crate) type EncodedCTilde<P> = Array<u8, <P as ParameterSet>::Lambda>;
353pub(crate) type EncodedW1<P> = Array<u8, <P as SignatureParams>::W1Size>;
354pub(crate) type EncodedZ<P> = Array<u8, <P as SignatureParams>::ZSize>;
355pub(crate) type EncodedHintIndices<P> = Array<u8, <P as ParameterSet>::Omega>;
356pub(crate) type EncodedHintCuts<P> = Array<u8, <P as ParameterSet>::K>;
357pub(crate) type EncodedHint<P> = Array<u8, <P as SignatureParams>::HintSize>;
358
359/// A signature encoded as a byte array
360pub type EncodedSignature<P> = Array<u8, SignatureSize<P>>;
361
/// Blanket implementation of [`SignatureParams`] for every [`ParameterSet`] whose type-level
/// sizes satisfy the bounds below.
///
/// As with the signing-key impl, the `where` clause restates every size relationship the method
/// bodies rely on, including the `Add`/`Sub` bounds that `concat` and `split_ref` need.  A
/// signature is laid out as `c_tilde || z || h`, and the hint `h` as `Omega` index bytes
/// followed by `K` cut bytes.
impl<P> SignatureParams for P
where
    P: ParameterSet,
    // W1: each polynomial takes 32 * W1Bits bytes
    U32: Mul<P::W1Bits>,
    EncodedPolynomialSize<P::W1Bits>: Mul<P::K>,
    Prod<EncodedPolynomialSize<P::W1Bits>, P::K>:
        ArraySize + Div<P::K, Output = EncodedPolynomialSize<P::W1Bits>> + Rem<P::K, Output = U0>,
    // Z: coefficients are bit-packed over -(gamma1 - 1) ..= gamma1
    P::Gamma1: Sub<U1>,
    (Diff<P::Gamma1, U1>, P::Gamma1): RangeEncodingSize,
    RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>: Mul<P::L>,
    Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>: ArraySize
        + Div<P::L, Output = RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>>
        + Rem<P::L, Output = U0>,
    // Hint: Omega index bytes followed by K cut bytes
    P::Omega: Add<P::K>,
    Sum<P::Omega, P::K>: ArraySize + Sub<P::Omega, Output = P::K>,
    // Signature: c_tilde || z || h; each `Sub` bound undoes the matching `Add` for `split_ref`
    P::Lambda: Add<Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
    Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>:
        ArraySize
            + Add<Sum<P::Omega, P::K>>
            + Sub<
                P::Lambda,
                Output = Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>,
            >,
    Sum<
        Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
        Sum<P::Omega, P::K>,
    >: ArraySize
        + Sub<
            Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
            Output = Sum<P::Omega, P::K>,
        >,
{
    type W1Size = EncodedVectorSize<Self::W1Bits, P::K>;
    type ZSize = RangeEncodedVectorSize<Diff<P::Gamma1, U1>, P::Gamma1, P::L>;
    type HintSize = Sum<P::Omega, P::K>;
    type SignatureSize = Sum<Sum<P::Lambda, Self::ZSize>, Self::HintSize>;

    // Plain `u32` subtraction: in const evaluation an underflow is reported as an error rather
    // than wrapping.
    const GAMMA1_MINUS_BETA: u32 = P::Gamma1::U32 - P::BETA;
    const GAMMA2_MINUS_BETA: u32 = P::Gamma2::U32 - P::BETA;

    // First Omega bytes are the indices, the remaining K bytes are the cuts.
    fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>) {
        y.split_ref()
    }

    fn encode_w1(w1: &Vector<Self::K>) -> EncodedW1<Self> {
        Encode::<Self::W1Bits>::encode(w1)
    }

    fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K> {
        Encode::<Self::W1Bits>::decode(enc)
    }

    // BitPack(z, gamma1 - 1, gamma1)
    fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self> {
        BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::pack(z)
    }

    fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L> {
        BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::unpack(enc)
    }

    // c_tilde || z || h
    fn concat_sig(
        c_tilde: EncodedCTilde<P>,
        z: EncodedZ<P>,
        h: EncodedHint<P>,
    ) -> EncodedSignature<P> {
        c_tilde.concat(z).concat(h)
    }

    // Peel the hint off the end first, then split the remainder into c_tilde and z.
    fn split_sig(enc: &EncodedSignature<P>) -> (&EncodedCTilde<P>, &EncodedZ<P>, &EncodedHint<P>) {
        let (enc, h) = enc.split_ref();
        let (c_tilde, z) = enc.split_ref();
        (c_tilde, z, h)
    }
}
440
441/// An instance of `MlDsaParams` defines all of the parameters necessary for ML-DSA operations.
442/// Typically this is done by implementing `ParameterSet` with values that will fit into the
443/// blanket implementations of `SigningKeyParams`, `VerifyingKeyParams`, and `SignatureParams`.
444pub trait MlDsaParams:
445    SigningKeyParams + VerifyingKeyParams + SignatureParams + Debug + Default + PartialEq + Clone
446{
447}
448
449impl<T> MlDsaParams for T where
450    T: SigningKeyParams
451        + VerifyingKeyParams
452        + SignatureParams
453        + Debug
454        + Default
455        + PartialEq
456        + Clone
457{
458}