Skip to main content

module_lattice/
encoding.rs

1use array::{
2    Array, Flatten, Unflatten,
3    typenum::{Gcd, Gcf, Prod, Quot, U0, U8, U32, U256, Unsigned},
4};
5use core::fmt::Debug;
6use core::ops::{Div, Mul, Rem};
7use num_traits::One;
8
9use super::algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector};
10use super::truncate::Truncate;
11
12/// An array length with other useful properties
13pub trait ArraySize: array::ArraySize + PartialEq + Debug {}
14
15impl<T> ArraySize for T where T: array::ArraySize + PartialEq + Debug {}
16
17/// An integer that can describe encoded polynomials.
18pub trait EncodingSize: ArraySize {
19    /// Size of an encoded polynomial.
20    type EncodedPolynomialSize: ArraySize;
21    /// Value step.
22    type ValueStep: ArraySize;
23    /// Byte step.
24    type ByteStep: ArraySize;
25}
26
27type EncodingUnit<D> = Quot<Prod<D, U8>, Gcf<D, U8>>;
28
29/// Size of an encoded polynomial.
30pub type EncodedPolynomialSize<D> = <D as EncodingSize>::EncodedPolynomialSize;
31/// Encoded polynomial.
32pub type EncodedPolynomial<D> = Array<u8, EncodedPolynomialSize<D>>;
33
34impl<D> EncodingSize for D
35where
36    D: ArraySize + Mul<U8> + Gcd<U8> + Mul<U32>,
37    Prod<D, U32>: ArraySize,
38    Prod<D, U8>: Div<Gcf<D, U8>>,
39    EncodingUnit<D>: Div<D> + Div<U8>,
40    Quot<EncodingUnit<D>, D>: ArraySize,
41    Quot<EncodingUnit<D>, U8>: ArraySize,
42{
43    type EncodedPolynomialSize = Prod<D, U32>;
44    type ValueStep = Quot<EncodingUnit<D>, D>;
45    type ByteStep = Quot<EncodingUnit<D>, U8>;
46}
47
48/// Decoded value.
49pub type DecodedValue<F> = Array<Elem<F>, U256>;
50
51/// An integer that can describe encoded vectors.
52pub trait VectorEncodingSize<K>: EncodingSize
53where
54    K: ArraySize,
55{
56    /// Size of an encoded vector.
57    type EncodedVectorSize: ArraySize;
58
59    /// Flatten encoded polynomial array into encoded vector.
60    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K>;
61    /// Unflatten encoded vector into encoded polynomial array.
62    fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K>;
63}
64
65/// Size of an encoded vector.
66pub type EncodedVectorSize<D, K> = <D as VectorEncodingSize<K>>::EncodedVectorSize;
67/// Encoded vector.
68pub type EncodedVector<D, K> = Array<u8, EncodedVectorSize<D, K>>;
69
70impl<D, K> VectorEncodingSize<K> for D
71where
72    D: EncodingSize,
73    K: ArraySize,
74    D::EncodedPolynomialSize: Mul<K>,
75    Prod<D::EncodedPolynomialSize, K>:
76        ArraySize + Div<K, Output = D::EncodedPolynomialSize> + Rem<K, Output = U0>,
77{
78    type EncodedVectorSize = Prod<D::EncodedPolynomialSize, K>;
79
80    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K> {
81        polys.flatten()
82    }
83
84    fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K> {
85        vec.unflatten()
86    }
87}
88
89/// FIPS 203: Algorithm 4 `ByteEncode_d`.
90/// FIPS 204: Algorithm 16 `SimpleBitPack`.
91pub fn byte_encode<F: Field, D: EncodingSize>(vals: &DecodedValue<F>) -> EncodedPolynomial<D> {
92    let val_step = D::ValueStep::USIZE;
93    let byte_step = D::ByteStep::USIZE;
94
95    let mut bytes = EncodedPolynomial::<D>::default();
96
97    let vc = vals.chunks(val_step);
98    let bc = bytes.chunks_mut(byte_step);
99    for (v, b) in vc.zip(bc) {
100        let mut x = 0u128;
101        for (j, vj) in v.iter().enumerate() {
102            let vj: u128 = vj.0.into();
103            x |= vj << (D::USIZE * j);
104        }
105
106        let xb = x.to_le_bytes();
107        b.copy_from_slice(&xb[..byte_step]);
108    }
109
110    bytes
111}
112
113/// FIPS 203: Algorithm 5 `ByteDecode_d(F)`
114/// FIPS 204: Algorithm 18 `SimpleBitUnpack`
115pub fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue<F> {
116    let val_step = D::ValueStep::USIZE;
117    let byte_step = D::ByteStep::USIZE;
118    let mask = (F::Int::one() << D::USIZE) - F::Int::one();
119
120    let mut vals = DecodedValue::default();
121
122    let vc = vals.chunks_mut(val_step);
123    let bc = bytes.chunks(byte_step);
124    for (v, b) in vc.zip(bc) {
125        let mut xb = [0u8; 16];
126        xb[..byte_step].copy_from_slice(b);
127
128        let x = u128::from_le_bytes(xb);
129        for (j, vj) in v.iter_mut().enumerate() {
130            let val = F::Int::truncate(x >> (D::USIZE * j));
131            vj.0 = val & mask;
132
133            // Special case for FIPS 203. For 12-bit values (max 4095) with Q = 3329,
134            // the masked value is always in [0, 2Q), so `small_reduce` is exact and
135            // avoids the hardware UDIV that `% F::Q` would emit.
136            if D::USIZE == 12 {
137                vj.0 = F::small_reduce(vj.0);
138            }
139        }
140    }
141
142    vals
143}
144
145/// Encoding trait.
146pub trait Encode<D: EncodingSize> {
147    /// Size of the encoded object.
148    type EncodedSize: ArraySize;
149    /// Encode object.
150    fn encode(&self) -> Array<u8, Self::EncodedSize>;
151    /// Decode object.
152    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
153}
154
155impl<F: Field, D: EncodingSize> Encode<D> for Polynomial<F> {
156    type EncodedSize = D::EncodedPolynomialSize;
157
158    fn encode(&self) -> Array<u8, Self::EncodedSize> {
159        byte_encode::<F, D>(&self.0)
160    }
161
162    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
163        Self(byte_decode::<F, D>(enc))
164    }
165}
166
167impl<F, D, K> Encode<D> for Vector<F, K>
168where
169    F: Field,
170    K: ArraySize,
171    D: VectorEncodingSize<K>,
172{
173    type EncodedSize = D::EncodedVectorSize;
174
175    fn encode(&self) -> Array<u8, Self::EncodedSize> {
176        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
177        <D as VectorEncodingSize<K>>::flatten(polys)
178    }
179
180    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
181        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
182        Self(
183            unfold
184                .iter()
185                .map(|&x| <Polynomial<F> as Encode<D>>::decode(x))
186                .collect(),
187        )
188    }
189}
190
191impl<F: Field, D: EncodingSize> Encode<D> for NttPolynomial<F> {
192    type EncodedSize = D::EncodedPolynomialSize;
193
194    fn encode(&self) -> Array<u8, Self::EncodedSize> {
195        byte_encode::<F, D>(&self.0)
196    }
197
198    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
199        Self(byte_decode::<F, D>(enc))
200    }
201}
202
203impl<F, D, K> Encode<D> for NttVector<F, K>
204where
205    F: Field,
206    D: VectorEncodingSize<K>,
207    K: ArraySize,
208{
209    type EncodedSize = D::EncodedVectorSize;
210
211    fn encode(&self) -> Array<u8, Self::EncodedSize> {
212        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
213        <D as VectorEncodingSize<K>>::flatten(polys)
214    }
215
216    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
217        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
218        Self(
219            unfold
220                .iter()
221                .map(|&x| <NttPolynomial<F> as Encode<D>>::decode(x))
222                .collect(),
223        )
224    }
225}