// ml_dsa/module_lattice/encode.rs

1use core::fmt::Debug;
2use core::ops::{Div, Mul, Rem};
3use hybrid_array::{
4    Array,
5    typenum::{Gcd, Gcf, Prod, Quot, U0, U8, U32, U256, Unsigned},
6};
7use num_traits::One;
8
9use super::algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector};
10use super::util::{Flatten, Truncate, Unflatten};
11
12/// An array length with other useful properties
13pub trait ArraySize: hybrid_array::ArraySize + PartialEq + Debug {}
14
15impl<T> ArraySize for T where T: hybrid_array::ArraySize + PartialEq + Debug {}
16
17/// An integer that can describe encoded polynomials.
18pub trait EncodingSize: ArraySize {
19    type EncodedPolynomialSize: ArraySize;
20    type ValueStep: ArraySize;
21    type ByteStep: ArraySize;
22}
23
24type EncodingUnit<D> = Quot<Prod<D, U8>, Gcf<D, U8>>;
25
26pub type EncodedPolynomialSize<D> = <D as EncodingSize>::EncodedPolynomialSize;
27pub type EncodedPolynomial<D> = Array<u8, EncodedPolynomialSize<D>>;
28
29impl<D> EncodingSize for D
30where
31    D: ArraySize + Mul<U8> + Gcd<U8> + Mul<U32>,
32    Prod<D, U32>: ArraySize,
33    Prod<D, U8>: Div<Gcf<D, U8>>,
34    EncodingUnit<D>: Div<D> + Div<U8>,
35    Quot<EncodingUnit<D>, D>: ArraySize,
36    Quot<EncodingUnit<D>, U8>: ArraySize,
37{
38    type EncodedPolynomialSize = Prod<D, U32>;
39    type ValueStep = Quot<EncodingUnit<D>, D>;
40    type ByteStep = Quot<EncodingUnit<D>, U8>;
41}
42
43type DecodedValue<F> = Array<Elem<F>, U256>;
44
45/// An integer that can describe encoded vectors.
46pub trait VectorEncodingSize<K>: EncodingSize
47where
48    K: ArraySize,
49{
50    type EncodedVectorSize: ArraySize;
51
52    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K>;
53    fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K>;
54}
55
56pub type EncodedVectorSize<D, K> = <D as VectorEncodingSize<K>>::EncodedVectorSize;
57pub type EncodedVector<D, K> = Array<u8, EncodedVectorSize<D, K>>;
58
59impl<D, K> VectorEncodingSize<K> for D
60where
61    D: EncodingSize,
62    K: ArraySize,
63    D::EncodedPolynomialSize: Mul<K>,
64    Prod<D::EncodedPolynomialSize, K>:
65        ArraySize + Div<K, Output = D::EncodedPolynomialSize> + Rem<K, Output = U0>,
66{
67    type EncodedVectorSize = Prod<D::EncodedPolynomialSize, K>;
68
69    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K> {
70        polys.flatten()
71    }
72
73    fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K> {
74        vec.unflatten()
75    }
76}
77
78// FIPS 203: Algorithm 4 ByteEncode_d
79// FIPS 204: Algorithm 16 SimpleBitPack
80fn byte_encode<F: Field, D: EncodingSize>(vals: &DecodedValue<F>) -> EncodedPolynomial<D> {
81    let val_step = D::ValueStep::USIZE;
82    let byte_step = D::ByteStep::USIZE;
83
84    let mut bytes = EncodedPolynomial::<D>::default();
85
86    let vc = vals.chunks(val_step);
87    let bc = bytes.chunks_mut(byte_step);
88    for (v, b) in vc.zip(bc) {
89        let mut x = 0u128;
90        for (j, vj) in v.iter().enumerate() {
91            let vj: u128 = vj.0.into();
92            x |= vj << (D::USIZE * j);
93        }
94
95        let xb = x.to_le_bytes();
96        b.copy_from_slice(&xb[..byte_step]);
97    }
98
99    bytes
100}
101
102// FIPS 203: Algorithm 5 ByteDecode_d(F)
103// FIPS 204: Algorithm 18 SimpleBitUnpack
104fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue<F> {
105    let val_step = D::ValueStep::USIZE;
106    let byte_step = D::ByteStep::USIZE;
107    let mask = (F::Int::one() << D::USIZE) - F::Int::one();
108
109    let mut vals = DecodedValue::default();
110
111    let vc = vals.chunks_mut(val_step);
112    let bc = bytes.chunks(byte_step);
113    for (v, b) in vc.zip(bc) {
114        let mut xb = [0u8; 16];
115        xb[..byte_step].copy_from_slice(b);
116
117        let x = u128::from_le_bytes(xb);
118        for (j, vj) in v.iter_mut().enumerate() {
119            let val = F::Int::truncate(x >> (D::USIZE * j));
120            vj.0 = val & mask;
121
122            // Special case for FIPS 203
123            if D::USIZE == 12 {
124                vj.0 = vj.0 % F::Q;
125            }
126        }
127    }
128
129    vals
130}
131
132pub trait Encode<D: EncodingSize> {
133    type EncodedSize: ArraySize;
134    fn encode(&self) -> Array<u8, Self::EncodedSize>;
135    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
136}
137
138impl<F: Field, D: EncodingSize> Encode<D> for Polynomial<F> {
139    type EncodedSize = D::EncodedPolynomialSize;
140
141    fn encode(&self) -> Array<u8, Self::EncodedSize> {
142        byte_encode::<F, D>(&self.0)
143    }
144
145    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
146        Self(byte_decode::<F, D>(enc))
147    }
148}
149
150impl<F, D, K> Encode<D> for Vector<F, K>
151where
152    F: Field,
153    K: ArraySize,
154    D: VectorEncodingSize<K>,
155{
156    type EncodedSize = D::EncodedVectorSize;
157
158    fn encode(&self) -> Array<u8, Self::EncodedSize> {
159        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
160        <D as VectorEncodingSize<K>>::flatten(polys)
161    }
162
163    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
164        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
165        Self(
166            unfold
167                .iter()
168                .map(|&x| <Polynomial<F> as Encode<D>>::decode(x))
169                .collect(),
170        )
171    }
172}
173
174impl<F: Field, D: EncodingSize> Encode<D> for NttPolynomial<F> {
175    type EncodedSize = D::EncodedPolynomialSize;
176
177    fn encode(&self) -> Array<u8, Self::EncodedSize> {
178        byte_encode::<F, D>(&self.0)
179    }
180
181    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
182        Self(byte_decode::<F, D>(enc))
183    }
184}
185
186impl<F, D, K> Encode<D> for NttVector<F, K>
187where
188    F: Field,
189    D: VectorEncodingSize<K>,
190    K: ArraySize,
191{
192    type EncodedSize = D::EncodedVectorSize;
193
194    fn encode(&self) -> Array<u8, Self::EncodedSize> {
195        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
196        <D as VectorEncodingSize<K>>::flatten(polys)
197    }
198
199    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
200        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
201        Self(
202            unfold
203                .iter()
204                .map(|&x| <NttPolynomial<F> as Encode<D>>::decode(x))
205                .collect(),
206        )
207    }
208}