1use core::fmt::Debug;
14use core::ops::{Add, Div, Mul, Rem, Sub};
15
16use crate::module_lattice::encode::{
17 ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize,
18};
19use hybrid_array::{
20 Array,
21 typenum::{
22 Diff, Len, Length, Prod, Shleft, Sum, U0, U1, U2, U4, U13, U23, U32, U64, U128, U320, U416,
23 Unsigned,
24 },
25};
26
27use crate::algebra::{Polynomial, Vector};
28use crate::encode::{
29 BitPack, RangeEncodedPolynomialSize, RangeEncodedVectorSize, RangeEncodingSize,
30};
31use crate::util::{B32, B64};
32
33pub type SpecQ = Sum<Diff<Shleft<U1, U23>, Shleft<U1, U13>>, U1>;
35pub type SpecD = U13;
36pub type QMinus1 = Diff<SpecQ, U1>;
37pub type BitlenQMinusD = Diff<Length<SpecQ>, SpecD>;
38pub type Pow2DMinus1 = Shleft<U1, Diff<SpecD, U1>>;
39pub type Pow2DMinus1Minus1 = Diff<Pow2DMinus1, U1>;
40
41pub trait SamplingSize: ArraySize + Len {
43 const ETA: Eta;
44}
45
/// Runtime tag for the supported values of `eta`.
///
/// `Debug`, `PartialEq` and `Eq` are derived (in addition to `Copy`/`Clone`) so the tag can be
/// compared and printed in diagnostics; it is a plain fieldless enum, so these are all trivial.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Eta {
    /// `eta = 2` (the tag used by the `SamplingSize` impl for `U2`).
    Two,
    /// `eta = 4` (the tag used by the `SamplingSize` impl for `U4`).
    Four,
}
51
52impl SamplingSize for U2 {
53 const ETA: Eta = Eta::Two;
54}
55
56impl SamplingSize for U4 {
57 const ETA: Eta = Eta::Four;
58}
59
60pub trait MaskSamplingSize: Unsigned {
62 type SampleSize: ArraySize;
63
64 fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial;
65}
66
67impl<G> MaskSamplingSize for G
68where
69 G: Unsigned + Sub<U1>,
70 (Diff<G, U1>, G): RangeEncodingSize,
71{
72 type SampleSize = RangeEncodedPolynomialSize<Diff<G, U1>, G>;
73
74 fn unpack(v: &Array<u8, Self::SampleSize>) -> Polynomial {
75 BitPack::<Diff<G, U1>, G>::unpack(v)
76 }
77}
78
79pub trait ParameterSet {
82 type K: ArraySize;
84
85 type L: ArraySize;
87
88 type Eta: SamplingSize;
90
91 type Gamma1: MaskSamplingSize;
93
94 type Gamma2: Unsigned;
96
97 type TwoGamma2: Unsigned;
99
100 type W1Bits: EncodingSize;
102
103 type Lambda: ArraySize;
105
106 type Omega: ArraySize;
108
109 const TAU: usize;
111
112 #[allow(clippy::as_conversions)]
114 #[allow(clippy::cast_possible_truncation)]
115 const BETA: u32 = (Self::TAU as u32) * Self::Eta::U32;
116}
117
118pub trait SigningKeyParams: ParameterSet {
119 type S1Size: ArraySize;
120 type S2Size: ArraySize;
121 type T0Size: ArraySize;
122 type SigningKeySize: ArraySize;
123
124 fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self>;
125 fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L>;
126
127 fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self>;
128 fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K>;
129
130 fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self>;
131 fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K>;
132
133 fn concat_sk(
134 rho: B32,
135 K: B32,
136 tr: B64,
137 s1: EncodedS1<Self>,
138 s2: EncodedS2<Self>,
139 t0: EncodedT0<Self>,
140 ) -> EncodedSigningKey<Self>;
141 fn split_sk(
142 enc: &EncodedSigningKey<Self>,
143 ) -> (
144 &B32,
145 &B32,
146 &B64,
147 &EncodedS1<Self>,
148 &EncodedS2<Self>,
149 &EncodedT0<Self>,
150 );
151}
152
153pub type EncodedS1<P> = Array<u8, <P as SigningKeyParams>::S1Size>;
154pub type EncodedS2<P> = Array<u8, <P as SigningKeyParams>::S2Size>;
155pub type EncodedT0<P> = Array<u8, <P as SigningKeyParams>::T0Size>;
156
157pub type SigningKeySize<P> = <P as SigningKeyParams>::SigningKeySize;
158
159pub type EncodedSigningKey<P> = Array<u8, SigningKeySize<P>>;
161
162impl<P> SigningKeyParams for P
163where
164 P: ParameterSet,
165 P::Eta: Add<P::Eta>,
167 Sum<P::Eta, P::Eta>: Len,
168 Length<Sum<P::Eta, P::Eta>>: EncodingSize,
169 EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::L>,
171 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>: ArraySize
172 + Div<P::L, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
173 + Rem<P::L, Output = U0>,
174 EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>: Mul<P::K>,
176 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>: ArraySize
177 + Div<P::K, Output = EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>>
178 + Rem<P::K, Output = U0>,
179 U416: Mul<P::K>,
181 Prod<U416, P::K>: ArraySize + Div<P::K, Output = U416> + Rem<P::K, Output = U0>,
182 U128: Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
184 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>: ArraySize
185 + Add<Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>>
186 + Sub<U128, Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
187 Sum<
188 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
189 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
190 >: ArraySize
191 + Add<Prod<U416, P::K>>
192 + Sub<
193 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
194 Output = Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
195 >,
196 Sum<
197 Sum<
198 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
199 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
200 >,
201 Prod<U416, P::K>,
202 >: ArraySize
203 + Sub<
204 Sum<
205 Sum<U128, Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::L>>,
206 Prod<EncodedPolynomialSize<Length<Sum<P::Eta, P::Eta>>>, P::K>,
207 >,
208 Output = Prod<U416, P::K>,
209 >,
210{
211 type S1Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::L>;
212 type S2Size = RangeEncodedVectorSize<P::Eta, P::Eta, P::K>;
213 type T0Size = RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>;
214 type SigningKeySize = Sum<
215 Sum<
216 Sum<U128, RangeEncodedVectorSize<P::Eta, P::Eta, P::L>>,
217 RangeEncodedVectorSize<P::Eta, P::Eta, P::K>,
218 >,
219 RangeEncodedVectorSize<Pow2DMinus1Minus1, Pow2DMinus1, P::K>,
220 >;
221
222 fn encode_s1(s1: &Vector<Self::L>) -> EncodedS1<Self> {
223 BitPack::<P::Eta, P::Eta>::pack(s1)
224 }
225
226 fn decode_s1(enc: &EncodedS1<Self>) -> Vector<Self::L> {
227 BitPack::<P::Eta, P::Eta>::unpack(enc)
228 }
229
230 fn encode_s2(s2: &Vector<Self::K>) -> EncodedS2<Self> {
231 BitPack::<P::Eta, P::Eta>::pack(s2)
232 }
233
234 fn decode_s2(enc: &EncodedS2<Self>) -> Vector<Self::K> {
235 BitPack::<P::Eta, P::Eta>::unpack(enc)
236 }
237
238 fn encode_t0(t0: &Vector<Self::K>) -> EncodedT0<Self> {
239 BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::pack(t0)
240 }
241
242 fn decode_t0(enc: &EncodedT0<Self>) -> Vector<Self::K> {
243 BitPack::<Pow2DMinus1Minus1, Pow2DMinus1>::unpack(enc)
244 }
245
246 fn concat_sk(
247 rho: B32,
248 K: B32,
249 tr: B64,
250 s1: EncodedS1<Self>,
251 s2: EncodedS2<Self>,
252 t0: EncodedT0<Self>,
253 ) -> EncodedSigningKey<Self> {
254 rho.concat(K).concat(tr).concat(s1).concat(s2).concat(t0)
255 }
256
257 fn split_sk(
258 enc: &EncodedSigningKey<Self>,
259 ) -> (
260 &B32,
261 &B32,
262 &B64,
263 &EncodedS1<Self>,
264 &EncodedS2<Self>,
265 &EncodedT0<Self>,
266 ) {
267 let (enc, t0) = enc.split_ref();
268 let (enc, s2) = enc.split_ref();
269 let (enc, s1) = enc.split_ref();
270 let (enc, tr) = enc.split_ref::<U64>();
271 let (rho, K) = enc.split_ref();
272 (rho, K, tr, s1, s2, t0)
273 }
274}
275
276pub trait VerifyingKeyParams: ParameterSet {
277 type T1Size: ArraySize;
278 type VerifyingKeySize: ArraySize;
279
280 fn encode_t1(t1: &Vector<Self::K>) -> EncodedT1<Self>;
281 fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K>;
282
283 fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self>;
284 fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>);
285}
286
287pub type VerifyingKeySize<P> = <P as VerifyingKeyParams>::VerifyingKeySize;
288
289pub type EncodedT1<P> = Array<u8, <P as VerifyingKeyParams>::T1Size>;
290
291pub type EncodedVerifyingKey<P> = Array<u8, VerifyingKeySize<P>>;
293
294impl<P> VerifyingKeyParams for P
295where
296 P: ParameterSet,
297 U320: Mul<P::K>,
299 Prod<U320, P::K>: ArraySize + Div<P::K, Output = U320> + Rem<P::K, Output = U0>,
300 U32: Add<Prod<U320, P::K>>,
302 Sum<U32, U32>: ArraySize,
303 Sum<U32, Prod<U320, P::K>>: ArraySize + Sub<U32, Output = Prod<U320, P::K>>,
304{
305 type T1Size = EncodedVectorSize<BitlenQMinusD, P::K>;
306 type VerifyingKeySize = Sum<U32, Self::T1Size>;
307
308 fn encode_t1(t1: &Vector<P::K>) -> EncodedT1<Self> {
309 Encode::<BitlenQMinusD>::encode(t1)
310 }
311
312 fn decode_t1(enc: &EncodedT1<Self>) -> Vector<Self::K> {
313 Encode::<BitlenQMinusD>::decode(enc)
314 }
315
316 fn concat_vk(rho: B32, t1: EncodedT1<Self>) -> EncodedVerifyingKey<Self> {
317 rho.concat(t1)
318 }
319
320 fn split_vk(enc: &EncodedVerifyingKey<Self>) -> (&B32, &EncodedT1<Self>) {
321 enc.split_ref()
322 }
323}
324
325pub trait SignatureParams: ParameterSet {
326 type W1Size: ArraySize;
327 type ZSize: ArraySize;
328 type HintSize: ArraySize;
329 type SignatureSize: ArraySize;
330
331 const GAMMA1_MINUS_BETA: u32;
332 const GAMMA2_MINUS_BETA: u32;
333
334 fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>);
335
336 fn encode_w1(t1: &Vector<Self::K>) -> EncodedW1<Self>;
337 fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K>;
338
339 fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self>;
340 fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L>;
341
342 fn concat_sig(
343 c_tilde: EncodedCTilde<Self>,
344 z: EncodedZ<Self>,
345 h: EncodedHint<Self>,
346 ) -> EncodedSignature<Self>;
347 fn split_sig(
348 enc: &EncodedSignature<Self>,
349 ) -> (&EncodedCTilde<Self>, &EncodedZ<Self>, &EncodedHint<Self>);
350}
351
352pub type SignatureSize<P> = <P as SignatureParams>::SignatureSize;
353
354pub type EncodedCTilde<P> = Array<u8, <P as ParameterSet>::Lambda>;
355pub type EncodedW1<P> = Array<u8, <P as SignatureParams>::W1Size>;
356pub type EncodedZ<P> = Array<u8, <P as SignatureParams>::ZSize>;
357pub type EncodedHintIndices<P> = Array<u8, <P as ParameterSet>::Omega>;
358pub type EncodedHintCuts<P> = Array<u8, <P as ParameterSet>::K>;
359pub type EncodedHint<P> = Array<u8, <P as SignatureParams>::HintSize>;
360
361pub type EncodedSignature<P> = Array<u8, SignatureSize<P>>;
363
364impl<P> SignatureParams for P
365where
366 P: ParameterSet,
367 U32: Mul<P::W1Bits>,
369 EncodedPolynomialSize<P::W1Bits>: Mul<P::K>,
370 Prod<EncodedPolynomialSize<P::W1Bits>, P::K>:
371 ArraySize + Div<P::K, Output = EncodedPolynomialSize<P::W1Bits>> + Rem<P::K, Output = U0>,
372 P::Gamma1: Sub<U1>,
374 (Diff<P::Gamma1, U1>, P::Gamma1): RangeEncodingSize,
375 RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>: Mul<P::L>,
376 Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>: ArraySize
377 + Div<P::L, Output = RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>>
378 + Rem<P::L, Output = U0>,
379 P::Omega: Add<P::K>,
381 Sum<P::Omega, P::K>: ArraySize + Sub<P::Omega, Output = P::K>,
382 P::Lambda: Add<Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
384 Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>:
385 ArraySize
386 + Add<Sum<P::Omega, P::K>>
387 + Sub<
388 P::Lambda,
389 Output = Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>,
390 >,
391 Sum<
392 Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
393 Sum<P::Omega, P::K>,
394 >: ArraySize
395 + Sub<
396 Sum<P::Lambda, Prod<RangeEncodedPolynomialSize<Diff<P::Gamma1, U1>, P::Gamma1>, P::L>>,
397 Output = Sum<P::Omega, P::K>,
398 >,
399{
400 type W1Size = EncodedVectorSize<Self::W1Bits, P::K>;
401 type ZSize = RangeEncodedVectorSize<Diff<P::Gamma1, U1>, P::Gamma1, P::L>;
402 type HintSize = Sum<P::Omega, P::K>;
403 type SignatureSize = Sum<Sum<P::Lambda, Self::ZSize>, Self::HintSize>;
404
405 const GAMMA1_MINUS_BETA: u32 = P::Gamma1::U32 - P::BETA;
406 const GAMMA2_MINUS_BETA: u32 = P::Gamma2::U32 - P::BETA;
407
408 fn split_hint(y: &EncodedHint<Self>) -> (&EncodedHintIndices<Self>, &EncodedHintCuts<Self>) {
409 y.split_ref()
410 }
411
412 fn encode_w1(w1: &Vector<Self::K>) -> EncodedW1<Self> {
413 Encode::<Self::W1Bits>::encode(w1)
414 }
415
416 fn decode_w1(enc: &EncodedW1<Self>) -> Vector<Self::K> {
417 Encode::<Self::W1Bits>::decode(enc)
418 }
419
420 fn encode_z(z: &Vector<Self::L>) -> EncodedZ<Self> {
421 BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::pack(z)
422 }
423
424 fn decode_z(enc: &EncodedZ<Self>) -> Vector<Self::L> {
425 BitPack::<Diff<P::Gamma1, U1>, P::Gamma1>::unpack(enc)
426 }
427
428 fn concat_sig(
429 c_tilde: EncodedCTilde<P>,
430 z: EncodedZ<P>,
431 h: EncodedHint<P>,
432 ) -> EncodedSignature<P> {
433 c_tilde.concat(z).concat(h)
434 }
435
436 fn split_sig(enc: &EncodedSignature<P>) -> (&EncodedCTilde<P>, &EncodedZ<P>, &EncodedHint<P>) {
437 let (enc, h) = enc.split_ref();
438 let (c_tilde, z) = enc.split_ref();
439 (c_tilde, z, h)
440 }
441}
442
443pub trait MlDsaParams:
447 SigningKeyParams + VerifyingKeyParams + SignatureParams + Debug + Default + PartialEq + Clone
448{
449}
450
451impl<T> MlDsaParams for T where
452 T: SigningKeyParams
453 + VerifyingKeyParams
454 + SignatureParams
455 + Debug
456 + Default
457 + PartialEq
458 + Clone
459{
460}