Skip to main content

module_lattice/
encoding.rs

1use array::{
2    Array, Flatten, Unflatten,
3    typenum::{Gcd, Gcf, Prod, Quot, U0, U8, U32, U256, Unsigned},
4};
5use core::fmt::Debug;
6use core::ops::{Div, Mul, Rem};
7use num_traits::One;
8
9use super::algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector};
10use super::truncate::Truncate;
11
12/// An array length with other useful properties
13pub trait ArraySize: array::ArraySize + PartialEq + Debug {}
14
15impl<T> ArraySize for T where T: array::ArraySize + PartialEq + Debug {}
16
17/// An integer that can describe encoded polynomials.
18pub trait EncodingSize: ArraySize {
19    /// Size of an encoded polynomial.
20    type EncodedPolynomialSize: ArraySize;
21    /// Value step.
22    type ValueStep: ArraySize;
23    /// Byte step.
24    type ByteStep: ArraySize;
25}
26
27type EncodingUnit<D> = Quot<Prod<D, U8>, Gcf<D, U8>>;
28
29/// Size of an encoded polynomial.
30pub type EncodedPolynomialSize<D> = <D as EncodingSize>::EncodedPolynomialSize;
31/// Encoded polynomial.
32pub type EncodedPolynomial<D> = Array<u8, EncodedPolynomialSize<D>>;
33
34impl<D> EncodingSize for D
35where
36    D: ArraySize + Mul<U8> + Gcd<U8> + Mul<U32>,
37    Prod<D, U32>: ArraySize,
38    Prod<D, U8>: Div<Gcf<D, U8>>,
39    EncodingUnit<D>: Div<D> + Div<U8>,
40    Quot<EncodingUnit<D>, D>: ArraySize,
41    Quot<EncodingUnit<D>, U8>: ArraySize,
42{
43    type EncodedPolynomialSize = Prod<D, U32>;
44    type ValueStep = Quot<EncodingUnit<D>, D>;
45    type ByteStep = Quot<EncodingUnit<D>, U8>;
46}
47
48/// Decoded value.
49pub type DecodedValue<F> = Array<Elem<F>, U256>;
50
51/// An integer that can describe encoded vectors.
52pub trait VectorEncodingSize<K>: EncodingSize
53where
54    K: ArraySize,
55{
56    /// Size of an encoded vector.
57    type EncodedVectorSize: ArraySize;
58
59    /// Flatten encoded polynomial array into encoded vector.
60    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K>;
61    /// Unflatten encoded vector into encoded polynomial array.
62    fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K>;
63}
64
65/// Size of an encoded vector.
66pub type EncodedVectorSize<D, K> = <D as VectorEncodingSize<K>>::EncodedVectorSize;
67/// Encoded vector.
68pub type EncodedVector<D, K> = Array<u8, EncodedVectorSize<D, K>>;
69
70impl<D, K> VectorEncodingSize<K> for D
71where
72    D: EncodingSize,
73    K: ArraySize,
74    D::EncodedPolynomialSize: Mul<K>,
75    Prod<D::EncodedPolynomialSize, K>:
76        ArraySize + Div<K, Output = D::EncodedPolynomialSize> + Rem<K, Output = U0>,
77{
78    type EncodedVectorSize = Prod<D::EncodedPolynomialSize, K>;
79
80    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K> {
81        polys.flatten()
82    }
83
84    fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K> {
85        vec.unflatten()
86    }
87}
88
89/// FIPS 203: Algorithm 4 `ByteEncode_d`.
90/// FIPS 204: Algorithm 16 `SimpleBitPack`.
91pub fn byte_encode<F: Field, D: EncodingSize>(vals: &DecodedValue<F>) -> EncodedPolynomial<D> {
92    let val_step = D::ValueStep::USIZE;
93    let byte_step = D::ByteStep::USIZE;
94
95    let mut bytes = EncodedPolynomial::<D>::default();
96
97    let vc = vals.chunks(val_step);
98    let bc = bytes.chunks_mut(byte_step);
99    for (v, b) in vc.zip(bc) {
100        let mut x = 0u128;
101        for (j, vj) in v.iter().enumerate() {
102            let vj: u128 = vj.0.into();
103            x |= vj << (D::USIZE * j);
104        }
105
106        let xb = x.to_le_bytes();
107        b.copy_from_slice(&xb[..byte_step]);
108    }
109
110    bytes
111}
112
113/// FIPS 203: Algorithm 5 `ByteDecode_d(F)`
114/// FIPS 204: Algorithm 18 `SimpleBitUnpack`
115pub fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue<F> {
116    let val_step = D::ValueStep::USIZE;
117    let byte_step = D::ByteStep::USIZE;
118    let mask = (F::Int::one() << D::USIZE) - F::Int::one();
119
120    let mut vals = DecodedValue::default();
121
122    let vc = vals.chunks_mut(val_step);
123    let bc = bytes.chunks(byte_step);
124    for (v, b) in vc.zip(bc) {
125        let mut xb = [0u8; 16];
126        xb[..byte_step].copy_from_slice(b);
127
128        let x = u128::from_le_bytes(xb);
129        for (j, vj) in v.iter_mut().enumerate() {
130            let val = F::Int::truncate(x >> (D::USIZE * j));
131            vj.0 = val & mask;
132
133            // Special case for FIPS 203
134            if D::USIZE == 12 {
135                vj.0 = vj.0 % F::Q;
136            }
137        }
138    }
139
140    vals
141}
142
143/// Encoding trait.
144pub trait Encode<D: EncodingSize> {
145    /// Size of the encoded object.
146    type EncodedSize: ArraySize;
147    /// Encode object.
148    fn encode(&self) -> Array<u8, Self::EncodedSize>;
149    /// Decode object.
150    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
151}
152
153impl<F: Field, D: EncodingSize> Encode<D> for Polynomial<F> {
154    type EncodedSize = D::EncodedPolynomialSize;
155
156    fn encode(&self) -> Array<u8, Self::EncodedSize> {
157        byte_encode::<F, D>(&self.0)
158    }
159
160    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
161        Self(byte_decode::<F, D>(enc))
162    }
163}
164
165impl<F, D, K> Encode<D> for Vector<F, K>
166where
167    F: Field,
168    K: ArraySize,
169    D: VectorEncodingSize<K>,
170{
171    type EncodedSize = D::EncodedVectorSize;
172
173    fn encode(&self) -> Array<u8, Self::EncodedSize> {
174        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
175        <D as VectorEncodingSize<K>>::flatten(polys)
176    }
177
178    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
179        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
180        Self(
181            unfold
182                .iter()
183                .map(|&x| <Polynomial<F> as Encode<D>>::decode(x))
184                .collect(),
185        )
186    }
187}
188
189impl<F: Field, D: EncodingSize> Encode<D> for NttPolynomial<F> {
190    type EncodedSize = D::EncodedPolynomialSize;
191
192    fn encode(&self) -> Array<u8, Self::EncodedSize> {
193        byte_encode::<F, D>(&self.0)
194    }
195
196    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
197        Self(byte_decode::<F, D>(enc))
198    }
199}
200
201impl<F, D, K> Encode<D> for NttVector<F, K>
202where
203    F: Field,
204    D: VectorEncodingSize<K>,
205    K: ArraySize,
206{
207    type EncodedSize = D::EncodedVectorSize;
208
209    fn encode(&self) -> Array<u8, Self::EncodedSize> {
210        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
211        <D as VectorEncodingSize<K>>::flatten(polys)
212    }
213
214    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
215        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
216        Self(
217            unfold
218                .iter()
219                .map(|&x| <NttPolynomial<F> as Encode<D>>::decode(x))
220                .collect(),
221        )
222    }
223}