1use core::fmt::Debug;
14use core::ops::{Add, Div, Mul, Rem, Sub};
15
16use crate::module_lattice::encode::{
17 ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize,
18};
19use hybrid_array::{
20 Array,
21 typenum::{
22 Diff, Len, Length, Prod, Shleft, Sum, U0, U1, U2, U4, U13, U23, U32, U64, U128, U320, U416,
23 Unsigned,
24 },
25};
26
27use crate::algebra::{Polynomial, Vector};
28use crate::encode::{
29 BitPack, RangeEncodedPolynomialSize, RangeEncodedVectorSize, RangeEncodingSize,
30};
31use crate::util::{B32, B64};
32
33pub(crate) type SpecQ = Sum<Diff<Shleft<U1, U23>, Shleft<U1, U13>>, U1>;
35pub(crate) type SpecD = U13;
36pub(crate) type QMinus1 = Diff<SpecQ, U1>;
37pub(crate) type BitlenQMinusD = Diff<Length<SpecQ>, SpecD>;
38pub(crate) type Pow2DMinus1 = Shleft<U1, Diff<SpecD, U1>>;
39pub(crate) type Pow2DMinus1Minus1 = Diff<Pow2DMinus1, U1>;
40
41#[expect(unreachable_pub)]
43pub trait SamplingSize: ArraySize + Len {
44 const ETA: Eta;
45}
46
47#[derive(Copy, Clone)]
48pub(crate) enum Eta {
49 Two,
50 Four,
51}
52
53impl SamplingSize for U2 {
54 const ETA: Eta = Eta::Two;
55}
56
57impl SamplingSize for U4 {
58 const ETA: Eta = Eta::Four;
59}
60
61#[expect(unreachable_pub)]
63pub trait MaskSamplingSize: Unsigned {
64 type SampleSize: ArraySize;
65
66 fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial;
67}
68
69impl<G> MaskSamplingSize for G
70where
71 G: Unsigned + Sub<U1>,
72 (Diff<G, U1>, G): RangeEncodingSize,
73{
74 type SampleSize = RangeEncodedPolynomialSize<Diff<G, U1>, G>;
75
76 fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial {
77 BitPack::<Diff<G, U1>, G>::unpack(v)
78 }
79}
80
81pub trait ParameterSet {
84 type K: ArraySize;
86
87 type L: ArraySize;
89
90 type Eta: SamplingSize;
92
93 type Gamma1: MaskSamplingSize;
95
96 type Gamma2: Unsigned;
98
99 type TwoGamma2: Unsigned;
101
102 type W1Bits: EncodingSize;
104
105 type Lambda: ArraySize;
107
108 type Omega: ArraySize;
110
111 const TAU: usize;
113
114 #[allow(clippy::as_conversions)]
116 #[allow(clippy::cast_possible_truncation)]
117 const BETA: u32 = (Self::TAU as u32) * Self::Eta::U32;
118}
119
120pub trait SigningKeyParams: ParameterSet {
121 type S1Size: ArraySize;
122 type S2Size: ArraySize;
123 type T0Size: ArraySize;
124 type SigningKeySize: ArraySize;
125
126 fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self>;
127 fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L>;
128
129 fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self>;
130 fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K>;
131
132 fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self>;
133 fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K>;
134
135 fn concat_sk(
136 rho: B32,
137 K: B32,
138 tr: B64,
139 s1: EncodedS1<Self>,
140 s2: EncodedS2<Self>,
141 t0: EncodedT0<Self>,
142 ) -> EncodedSigningKey<Self>;
143 fn split_sk(
144 enc: &EncodedSigningKey<Self>,
145 ) -> (
146 &B32,
147 &B32,
148 &B64,
149 &EncodedS1<Self>,
150 &EncodedS2<Self>,
151 &EncodedT0<Self>,
152 );
153}
154
155pub(crate) type EncodedS1<P> = Array<u8, <P as SigningKeyParams>::S1Size>;
156pub(crate) type EncodedS2<P> = Array<u8, <P as SigningKeyParams>::S2Size>;
157pub(crate) type EncodedT0<P> = Array<u8, <P as SigningKeyParams>::T0Size>;
158
159pub(crate) type SigningKeySize<P> = <P as SigningKeyParams>::SigningKeySize;
160
161pub type EncodedSigningKey<P> = Array<u8, SigningKeySize<P>>;
163
164impl<P> SigningKeyParams for P
165where
166 P: ParameterSet,
167 P::Eta: Add<P::Eta>,
169 Sum<P::Eta, P::Eta>: Len,
170 Length<Sum<P::Eta, P::Eta>>: EncodingSize,
171 EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::L>,
173 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>: ArraySize
174 + Div<P::L, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
175 + Rem<P::L, Output = U0>,
176 EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::K>,
178 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>: ArraySize
179 + Div<P::K, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
180 + Rem<P::K, Output = U0>,
181 U416: Mul<P::K>,
183 Prod<U416, P::K>: ArraySize + Div<P::K, Output = U416> + Rem<P::K, Output = U0>,
184 U128: Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
186 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>: ArraySize
187 + Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>>
188 + Sub<U128, Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
189 Sum<
190 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
191 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
192 >: ArraySize
193 + Add<Prod<U416, P::K>>
194 + Sub<
195 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
196 Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
197 >,
198 Sum<
199 Sum<
200 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
201 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
202 >,
203 Prod<U416, P::K>,
204 >: ArraySize
205 + Sub<
206 Sum<
207 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
208 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
209 >,
210 Output = Prod<U416, P::K>,
211 >,
212{
213 type S1Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::L>;
214 type S2Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::K>;
215 type T0Size = RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>;
216 type SigningKeySize = Sum<
217 Sum<
218 Sum<U128, RangeEncodedVectorSize<P::Eta, P::Eta, P::L>>,
219 RangeEncodedVectorSize<P::Eta, P::Eta, P::K>,
220 >,
221 RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>,
222 >;
223
224 fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self> {
225 BitPack::<P::Eta, P::Eta>::pack(s1)
226 }
227
228 fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L> {
229 BitPack::<P::Eta, P::Eta>::unpack(enc)
230 }
231
232 fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self> {
233 BitPack::<P::Eta, P::Eta>::pack(s2)
234 }
235
236 fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K> {
237 BitPack::<P::Eta, P::Eta>::unpack(enc)
238 }
239
240 fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self> {
241 BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::pack(t0)
242 }
243
244 fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K> {
245 BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::unpack(enc)
246 }
247
248 fn concat_sk(
249 rho: B32,
250 K: B32,
251 tr: B64,
252 s1: EncodedS1<Self>,
253 s2: EncodedS2<Self>,
254 t0: EncodedT0<Self>,
255 ) -> EncodedSigningKey<Self> {
256 rho.concat(K).concat(tr).concat(s1).concat(s2).concat(t0)
257 }
258
259 fn split_sk(
260 enc: &EncodedSigningKey<Self>,
261 ) -> (
262 &B32,
263 &B32,
264 &B64,
265 &EncodedS1<Self>,
266 &EncodedS2<Self>,
267 &EncodedT0<Self>,
268 ) {
269 let (enc, t0) = enc.split_ref();
270 let (enc, s2) = enc.split_ref();
271 let (enc, s1) = enc.split_ref();
272 let (enc, tr) = enc.split_ref::<U64>();
273 let (rho, K) = enc.split_ref();
274 (rho, K, tr, s1, s2, t0)
275 }
276}
277
278pub trait VerifyingKeyParams: ParameterSet {
279 type T1Size: ArraySize;
280 type VerifyingKeySize: ArraySize;
281
282 fn encode_t1(t1: &Vector<Self::K>) -> EncodedT1<Self>;
283 fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K>;
284
285 fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self>;
286 fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>);
287}
288
289pub(crate) type VerifyingKeySize<P> = <P as VerifyingKeyParams>::VerifyingKeySize;
290
291pub(crate) type EncodedT1<P> = Array<u8, <P as VerifyingKeyParams>::T1Size>;
292
293pub type EncodedVerifyingKey<P> = Array<u8, VerifyingKeySize<P>>;
295
296impl<P> VerifyingKeyParams for P
297where
298 P: ParameterSet,
299 U320: Mul<P::K>,
301 Prod<U320, P::K>: ArraySize + Div<P::K, Output = U320> + Rem<P::K, Output = U0>,
302 U32: Add<Prod<U320, P::K>>,
304 Sum<U32, U32>: ArraySize,
305 Sum<U32, Prod<U320, P::K>>: ArraySize + Sub<U32, Output = Prod<U320, P::K>>,
306{
307 type T1Size = EncodedVectorSize<BitlenQMinusD, P::K>;
308 type VerifyingKeySize = Sum<U32, Self::T1Size>;
309
310 fn encode_t1(t1: &Vector<P::K>) -> EncodedT1<Self> {
311 Encode::<BitlenQMinusD>::encode(t1)
312 }
313
314 fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K> {
315 Encode::<BitlenQMinusD>::decode(enc)
316 }
317
318 fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self> {
319 rho.concat(t1)
320 }
321
322 fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>) {
323 enc.split_ref()
324 }
325}
326
327pub trait SignatureParams: ParameterSet {
328 type W1Size: ArraySize;
329 type ZSize: ArraySize;
330 type HintSize: ArraySize;
331 type SignatureSize: ArraySize;
332
333 const GAMMA1_MINUS_BETA: u32;
334 const GAMMA2_MINUS_BETA: u32;
335
336 fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>);
337
338 fn encode_w1(t1: &Vector<Self::K>) -> EncodedW1<Self>;
339 fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K>;
340
341 fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self>;
342 fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L>;
343
344 fn concat_sig(
345 c_tilde: EncodedCTilde<Self>,
346 z: EncodedZ<Self>,
347 h: EncodedHint<Self>,
348 ) -> EncodedSignature<Self>;
349 fn split_sig(
350 enc: &EncodedSignature<Self>,
351 ) -> (&EncodedCTilde<Self>, &EncodedZ<Self>, &EncodedHint<Self>);
352}
353
354pub(crate) type SignatureSize<P> = <P as SignatureParams>::SignatureSize;
355
356pub(crate) type EncodedCTilde<P> = Array<u8, <P as ParameterSet>::Lambda>;
357pub(crate) type EncodedW1<P> = Array<u8, <P as SignatureParams>::W1Size>;
358pub(crate) type EncodedZ<P> = Array<u8, <P as SignatureParams>::ZSize>;
359pub(crate) type EncodedHintIndices<P> = Array<u8, <P as ParameterSet>::Omega>;
360pub(crate) type EncodedHintCuts<P> = Array<u8, <P as ParameterSet>::K>;
361pub(crate) type EncodedHint<P> = Array<u8, <P as SignatureParams>::HintSize>;
362
363pub type EncodedSignature<P> = Array<u8, SignatureSize<P>>;
365
366impl<P> SignatureParams for P
367where
368 P: ParameterSet,
369 U32: Mul<P::W1Bits>,
371 EncodedPolynomialSize<P::W1Bits>: Mul<P::K>,
372 Prod<EncodedPolynomialSize<P::W1Bits>, P::K>:
373 ArraySize + Div<P::K, Output = EncodedPolynomialSize<P::W1Bits>> + Rem<P::K, Output = U0>,
374 P::Gamma1: Sub<U1>,
376 (Diff<P::Gamma1, U1>, P::Gamma1): RangeEncodingSize,
377 RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>: Mul<P::L>,
378 Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>: ArraySize
379 + Div<P::L, Output = RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>>
380 + Rem<P::L, Output = U0>,
381 P::Omega: Add<P::K>,
383 Sum<P::Omega, P::K>: ArraySize + Sub<P::Omega, Output = P::K>,
384 P::Lambda: Add<Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
386 Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>:
387 ArraySize
388 + Add<Sum<P::Omega, P::K>>
389 + Sub<
390 P::Lambda,
391 Output = Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>,
392 >,
393 Sum<
394 Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
395 Sum<P::Omega, P::K>,
396 >: ArraySize
397 + Sub<
398 Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
399 Output = Sum<P::Omega, P::K>,
400 >,
401{
402 type W1Size = EncodedVectorSize<Self::W1Bits, P::K>;
403 type ZSize = RangeEncodedVectorSize<Diff<P::Gamma1, U1>, P::Gamma1, P::L>;
404 type HintSize = Sum<P::Omega, P::K>;
405 type SignatureSize = Sum<Sum<P::Lambda, Self::ZSize>, Self::HintSize>;
406
407 const GAMMA1_MINUS_BETA: u32 = P::Gamma1::U32 - P::BETA;
408 const GAMMA2_MINUS_BETA: u32 = P::Gamma2::U32 - P::BETA;
409
410 fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>) {
411 y.split_ref()
412 }
413
414 fn encode_w1(w1: &Vector<Self::K>) -> EncodedW1<Self> {
415 Encode::<Self::W1Bits>::encode(w1)
416 }
417
418 fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K> {
419 Encode::<Self::W1Bits>::decode(enc)
420 }
421
422 fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self> {
423 BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::pack(z)
424 }
425
426 fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L> {
427 BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::unpack(enc)
428 }
429
430 fn concat_sig(
431 c_tilde: EncodedCTilde<P>,
432 z: EncodedZ<P>,
433 h: EncodedHint<P>,
434 ) -> EncodedSignature<P> {
435 c_tilde.concat(z).concat(h)
436 }
437
438 fn split_sig(enc: &EncodedSignature<P>) -> (&EncodedCTilde<P>, &EncodedZ<P>, &EncodedHint<P>) {
439 let (enc, h) = enc.split_ref();
440 let (c_tilde, z) = enc.split_ref();
441 (c_tilde, z, h)
442 }
443}
444
445pub trait MlDsaParams:
449 SigningKeyParams + VerifyingKeyParams + SignatureParams + Debug + Default + PartialEq + Clone
450{
451}
452
453impl<T> MlDsaParams for T where
454 T: SigningKeyParams
455 + VerifyingKeyParams
456 + SignatureParams
457 + Debug
458 + Default
459 + PartialEq
460 + Clone
461{
462}