ml_dsa/
encode.rs

1use crate::module_lattice::encode::{ArraySize, Encode, EncodingSize, VectorEncodingSize};
2use core::ops::Add;
3use hybrid_array::{
4    Array,
5    typenum::{Len, Length, Sum, Unsigned},
6};
7
8use crate::algebra::{Elem, Polynomial, Vector};
9
10/// A pair of integers that describes a range
11pub trait RangeEncodingSize {
12    type Min: Unsigned;
13    type Max: Unsigned;
14    type EncodingSize: EncodingSize;
15}
16
17impl<A, B> RangeEncodingSize for (A, B)
18where
19    A: Unsigned + Add<B>,
20    B: Unsigned,
21    Sum<A, B>: Len,
22    Length<Sum<A, B>>: EncodingSize,
23{
24    type Min = A;
25    type Max = B;
26    type EncodingSize = Length<Sum<A, B>>;
27}
28
29pub type RangeMin<A, B> = <(A, B) as RangeEncodingSize>::Min;
30pub type RangeMax<A, B> = <(A, B) as RangeEncodingSize>::Max;
31pub type RangeEncodingBits<A, B> = <(A, B) as RangeEncodingSize>::EncodingSize;
32pub type RangeEncodedPolynomialSize<A, B> =
33    <RangeEncodingBits<A, B> as EncodingSize>::EncodedPolynomialSize;
34pub type RangeEncodedPolynomial<A, B> = Array<u8, RangeEncodedPolynomialSize<A, B>>;
35pub type RangeEncodedVectorSize<A, B, K> =
36    <RangeEncodingBits<A, B> as VectorEncodingSize<K>>::EncodedVectorSize;
37pub type RangeEncodedVector<A, B, K> = Array<u8, RangeEncodedVectorSize<A, B, K>>;
38
39/// `BitPack` represents range-encoding logic
40pub trait BitPack<A, B> {
41    type PackedSize: ArraySize;
42    fn pack(&self) -> Array<u8, Self::PackedSize>;
43    fn unpack(enc: &Array<u8, Self::PackedSize>) -> Self;
44}
45
46impl<A, B> BitPack<A, B> for Polynomial
47where
48    (A, B): RangeEncodingSize,
49{
50    type PackedSize = RangeEncodedPolynomialSize<A, B>;
51
52    // Algorithm 17 BitPack
53    fn pack(&self) -> RangeEncodedPolynomial<A, B> {
54        let a = Elem::new(RangeMin::<A, B>::U32);
55        let b = Elem::new(RangeMax::<A, B>::U32);
56
57        let to_encode = Self::new(
58            self.0
59                .iter()
60                .map(|w| {
61                    assert!(w.0 <= b.0 || w.0 >= (-a).0);
62                    b - *w
63                })
64                .collect(),
65        );
66        Encode::<RangeEncodingBits<A, B>>::encode(&to_encode)
67    }
68
69    // Algorithm 17 BitUnPack
70    fn unpack(enc: &RangeEncodedPolynomial<A, B>) -> Self {
71        let a = Elem::new(RangeMin::<A, B>::U32);
72        let b = Elem::new(RangeMax::<A, B>::U32);
73        let mut decoded: Self = Encode::<RangeEncodingBits<A, B>>::decode(enc);
74
75        for z in &mut decoded.0 {
76            assert!(z.0 <= (a + b).0);
77            *z = b - *z;
78        }
79
80        decoded
81    }
82}
83
84impl<K, A, B> BitPack<A, B> for Vector<K>
85where
86    K: ArraySize,
87    (A, B): RangeEncodingSize,
88    RangeEncodingBits<A, B>: VectorEncodingSize<K>,
89{
90    type PackedSize = RangeEncodedVectorSize<A, B, K>;
91
92    fn pack(&self) -> RangeEncodedVector<A, B, K> {
93        let polys = self.0.iter().map(|x| BitPack::<A, B>::pack(x)).collect();
94        RangeEncodingBits::<A, B>::flatten(polys)
95    }
96
97    fn unpack(enc: &RangeEncodedVector<A, B, K>) -> Self {
98        let unfold = RangeEncodingBits::<A, B>::unflatten(enc);
99        Self(
100            unfold
101                .into_iter()
102                .map(|x| <Polynomial as BitPack<A, B>>::unpack(x))
103                .collect(),
104        )
105    }
106}
107
108#[cfg(test)]
109pub(crate) mod test {
110    use super::*;
111    use crate::module_lattice::encode::*;
112    use core::ops::Rem;
113    use hybrid_array::typenum::{
114        U1, U2, U3, U4, U6, U7, U8, U9, U10, U13, U17, U19,
115        marker_traits::Zero,
116        operator_aliases::{Diff, Mod, Shleft},
117    };
118    use rand::Rng;
119
120    use crate::algebra::*;
121
122    // A helper trait to construct larger arrays by repeating smaller ones
123    trait Repeat<T: Clone, D: ArraySize> {
124        fn repeat(&self) -> Array<T, D>;
125    }
126
127    impl<T, N, D> Repeat<T, D> for Array<T, N>
128    where
129        N: ArraySize,
130        T: Clone,
131        D: ArraySize + Rem<N>,
132        Mod<D, N>: Zero,
133    {
134        #[allow(clippy::integer_division_remainder_used)]
135        fn repeat(&self) -> Array<T, D> {
136            Array::from_fn(|i| self[i % N::USIZE].clone())
137        }
138    }
139
140    #[allow(clippy::integer_division_remainder_used)]
141    fn simple_bit_pack_test<D>(b: u32, decoded: &Polynomial, encoded: &EncodedPolynomial<D>)
142    where
143        D: EncodingSize,
144    {
145        // Test known answer
146        let actual_encoded = Encode::<D>::encode(decoded);
147        assert_eq!(actual_encoded, *encoded);
148
149        let actual_decoded: Polynomial = Encode::<D>::decode(encoded);
150        assert_eq!(actual_decoded, *decoded);
151
152        // Test random decode/encode and encode/decode round trips
153        let mut rng = rand::rngs::OsRng;
154        let decoded = Polynomial::new(Array::from_fn(|_| {
155            let x: u32 = rng.r#gen();
156            Elem::new(x % (b + 1))
157        }));
158
159        let actual_encoded = Encode::<D>::encode(&decoded);
160        let actual_decoded: Polynomial = Encode::<D>::decode(&actual_encoded);
161        assert_eq!(actual_decoded, decoded);
162
163        let actual_reencoded = Encode::<D>::encode(&decoded);
164        assert_eq!(actual_reencoded, actual_encoded);
165    }
166
167    #[test]
168    fn simple_bit_pack() {
169        // Use a standard test pattern across all the cases
170        let decoded = Polynomial::new(
171            Array::<_, U8>([
172                Elem::new(0),
173                Elem::new(1),
174                Elem::new(2),
175                Elem::new(3),
176                Elem::new(4),
177                Elem::new(5),
178                Elem::new(6),
179                Elem::new(7),
180            ])
181            .repeat(),
182        );
183
184        // 10 bits
185        // <-> b = 2^{bitlen(q-1) - d} - 1 = 2^10 - 1
186        let b = (1 << 10) - 1;
187        let encoded: EncodedPolynomial<U10> =
188            Array::<_, U10>([0x00, 0x04, 0x20, 0xc0, 0x00, 0x04, 0x14, 0x60, 0xc0, 0x01]).repeat();
189        simple_bit_pack_test::<U10>(b, &decoded, &encoded);
190
191        // 8 bits
192        // gamma2 = (q - 1) / 88
193        // b = (q - 1) / (2 gamma2) - 1 = 175 = 2^8 - 81
194        let b = (1 << 8) - 81;
195        let encoded: EncodedPolynomial<U8> =
196            Array::<_, U8>([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]).repeat();
197        simple_bit_pack_test::<U8>(b, &decoded, &encoded);
198
199        // 6 bits
200        // gamma2 = (q - 1) / 32
201        // b = (q - 1) / (2 gamma2) - 1 = 63 = 2^6 - 1
202        let b = (1 << 6) - 1;
203        let encoded: EncodedPolynomial<U6> =
204            Array::<_, U6>([0x40, 0x20, 0x0c, 0x44, 0x61, 0x1c]).repeat();
205        simple_bit_pack_test::<U6>(b, &decoded, &encoded);
206    }
207
208    #[allow(clippy::integer_division_remainder_used)]
209    fn bit_pack_test<A, B>(decoded: &Polynomial, encoded: &RangeEncodedPolynomial<A, B>)
210    where
211        A: Unsigned,
212        B: Unsigned,
213        (A, B): RangeEncodingSize,
214    {
215        let a = Elem::new(A::U32);
216        let b = Elem::new(B::U32);
217
218        // Test known answer
219        let actual_encoded = BitPack::<A, B>::pack(decoded);
220        assert_eq!(actual_encoded, *encoded);
221
222        let actual_decoded: Polynomial = BitPack::<A, B>::unpack(encoded);
223        assert_eq!(actual_decoded, *decoded);
224
225        // Test random decode/encode and encode/decode round trips
226        let mut rng = rand::rngs::OsRng;
227        let decoded = Polynomial::new(Array::from_fn(|_| {
228            let mut x: u32 = rng.r#gen();
229            x %= a.0 + b.0;
230            b - Elem::new(x)
231        }));
232
233        let actual_encoded = BitPack::<A, B>::pack(&decoded);
234        let actual_decoded: Polynomial = BitPack::<A, B>::unpack(&actual_encoded);
235        assert_eq!(actual_decoded, decoded);
236
237        let actual_reencoded = BitPack::<A, B>::pack(&decoded);
238        assert_eq!(actual_reencoded, actual_encoded);
239    }
240
241    #[test]
242    fn bit_pack() {
243        type D = U13;
244        type Pow2D = Shleft<U1, D>;
245        type Pow2DMin = Diff<Pow2D, U1>;
246
247        type Gamma1Lo = Shleft<U1, U17>;
248        type Gamma1LoMin = Diff<Gamma1Lo, U1>;
249
250        type Gamma1Hi = Shleft<U1, U19>;
251        type Gamma1HiMin = Diff<Gamma1Hi, U1>;
252
253        // Use a standard test pattern across all the cases
254        // (We can't use -2 because the eta=2 case doesn't actually cover -2)
255        let decoded = Polynomial::new(
256            Array::<_, U4>([
257                Elem::new(BaseField::Q - 1),
258                Elem::new(0),
259                Elem::new(1),
260                Elem::new(2),
261            ])
262            .repeat(),
263        );
264
265        // BitPack(_, eta, eta), eta = 2, 4
266        let encoded: RangeEncodedPolynomial<U2, U2> = Array::<_, U3>([0x53, 0x30, 0x05]).repeat();
267        bit_pack_test::<U2, U2>(&decoded, &encoded);
268
269        let encoded: RangeEncodedPolynomial<U4, U4> = Array::<_, U2>([0x45, 0x23]).repeat();
270        bit_pack_test::<U4, U4>(&decoded, &encoded);
271
272        // BitPack(_, 2^d - 1, 2^d), d = 13
273        let encoded: RangeEncodedPolynomial<Pow2DMin, Pow2D> =
274            Array::<_, U7>([0x01, 0x20, 0x00, 0xf8, 0xff, 0xf9, 0x7f]).repeat();
275        bit_pack_test::<Pow2DMin, Pow2D>(&decoded, &encoded);
276
277        // BitPack(_, gamma1 - 1, gamma1), gamma1 = 2^17, 2^19
278        let encoded: RangeEncodedPolynomial<Gamma1LoMin, Gamma1Lo> =
279            Array::<_, U9>([0x01, 0x00, 0x02, 0x00, 0xf8, 0xff, 0x9f, 0xff, 0x7f]).repeat();
280        bit_pack_test::<Gamma1LoMin, Gamma1Lo>(&decoded, &encoded);
281
282        let encoded: RangeEncodedPolynomial<Gamma1HiMin, Gamma1Hi> =
283            Array::<_, U10>([0x00, 0x00, 0xf8, 0xff, 0x7f, 0xfe, 0xff, 0xd7, 0xff, 0x7f]).repeat();
284        bit_pack_test::<Gamma1Hi, Gamma1HiMin>(&decoded, &encoded);
285    }
286}