1use core::fmt::Debug;
2use core::ops::{Div, Mul, Rem};
3use hybrid_array::{
4 Array,
5 typenum::{Gcd, Gcf, Prod, Quot, U0, U8, U32, U256, Unsigned},
6};
7use num_traits::One;
8
9use super::algebra::{Elem, Field, NttPolynomial, NttVector, Polynomial, Vector};
10use super::util::{Flatten, Truncate, Unflatten};
11
12pub trait ArraySize: hybrid_array::ArraySize + PartialEq + Debug {}
14
15impl<T> ArraySize for T where T: hybrid_array::ArraySize + PartialEq + Debug {}
16
17pub trait EncodingSize: ArraySize {
19 type EncodedPolynomialSize: ArraySize;
20 type ValueStep: ArraySize;
21 type ByteStep: ArraySize;
22}
23
24type EncodingUnit<D> = Quot<Prod<D, U8>, Gcf<D, U8>>;
25
26pub type EncodedPolynomialSize<D> = <D as EncodingSize>::EncodedPolynomialSize;
27pub type EncodedPolynomial<D> = Array<u8, EncodedPolynomialSize<D>>;
28
29impl<D> EncodingSize for D
30where
31 D: ArraySize + Mul<U8> + Gcd<U8> + Mul<U32>,
32 Prod<D, U32>: ArraySize,
33 Prod<D, U8>: Div<Gcf<D, U8>>,
34 EncodingUnit<D>: Div<D> + Div<U8>,
35 Quot<EncodingUnit<D>, D>: ArraySize,
36 Quot<EncodingUnit<D>, U8>: ArraySize,
37{
38 type EncodedPolynomialSize = Prod<D, U32>;
39 type ValueStep = Quot<EncodingUnit<D>, D>;
40 type ByteStep = Quot<EncodingUnit<D>, U8>;
41}
42
43type DecodedValue<F> = Array<Elem<F>, U256>;
44
45pub trait VectorEncodingSize<K>: EncodingSize
47where
48 K: ArraySize,
49{
50 type EncodedVectorSize: ArraySize;
51
52 fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K>;
53 fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K>;
54}
55
56pub type EncodedVectorSize<D, K> = <D as VectorEncodingSize<K>>::EncodedVectorSize;
57pub type EncodedVector<D, K> = Array<u8, EncodedVectorSize<D, K>>;
58
59impl<D, K> VectorEncodingSize<K> for D
60where
61 D: EncodingSize,
62 K: ArraySize,
63 D::EncodedPolynomialSize: Mul<K>,
64 Prod<D::EncodedPolynomialSize, K>:
65 ArraySize + Div<K, Output = D::EncodedPolynomialSize> + Rem<K, Output = U0>,
66{
67 type EncodedVectorSize = Prod<D::EncodedPolynomialSize, K>;
68
69 fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedVector<Self, K> {
70 polys.flatten()
71 }
72
73 fn unflatten(vec: &EncodedVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K> {
74 vec.unflatten()
75 }
76}
77
78fn byte_encode<F: Field, D: EncodingSize>(vals: &DecodedValue<F>) -> EncodedPolynomial<D> {
81 let val_step = D::ValueStep::USIZE;
82 let byte_step = D::ByteStep::USIZE;
83
84 let mut bytes = EncodedPolynomial::<D>::default();
85
86 let vc = vals.chunks(val_step);
87 let bc = bytes.chunks_mut(byte_step);
88 for (v, b) in vc.zip(bc) {
89 let mut x = 0u128;
90 for (j, vj) in v.iter().enumerate() {
91 let vj: u128 = vj.0.into();
92 x |= vj << (D::USIZE * j);
93 }
94
95 let xb = x.to_le_bytes();
96 b.copy_from_slice(&xb[..byte_step]);
97 }
98
99 bytes
100}
101
102fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue<F> {
105 let val_step = D::ValueStep::USIZE;
106 let byte_step = D::ByteStep::USIZE;
107 let mask = (F::Int::one() << D::USIZE) - F::Int::one();
108
109 let mut vals = DecodedValue::default();
110
111 let vc = vals.chunks_mut(val_step);
112 let bc = bytes.chunks(byte_step);
113 for (v, b) in vc.zip(bc) {
114 let mut xb = [0u8; 16];
115 xb[..byte_step].copy_from_slice(b);
116
117 let x = u128::from_le_bytes(xb);
118 for (j, vj) in v.iter_mut().enumerate() {
119 let val = F::Int::truncate(x >> (D::USIZE * j));
120 vj.0 = val & mask;
121
122 if D::USIZE == 12 {
124 vj.0 = vj.0 % F::Q;
125 }
126 }
127 }
128
129 vals
130}
131
132pub trait Encode<D: EncodingSize> {
133 type EncodedSize: ArraySize;
134 fn encode(&self) -> Array<u8, Self::EncodedSize>;
135 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
136}
137
138impl<F: Field, D: EncodingSize> Encode<D> for Polynomial<F> {
139 type EncodedSize = D::EncodedPolynomialSize;
140
141 fn encode(&self) -> Array<u8, Self::EncodedSize> {
142 byte_encode::<F, D>(&self.0)
143 }
144
145 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
146 Self(byte_decode::<F, D>(enc))
147 }
148}
149
150impl<F, D, K> Encode<D> for Vector<F, K>
151where
152 F: Field,
153 K: ArraySize,
154 D: VectorEncodingSize<K>,
155{
156 type EncodedSize = D::EncodedVectorSize;
157
158 fn encode(&self) -> Array<u8, Self::EncodedSize> {
159 let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
160 <D as VectorEncodingSize<K>>::flatten(polys)
161 }
162
163 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
164 let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
165 Self(
166 unfold
167 .iter()
168 .map(|&x| <Polynomial<F> as Encode<D>>::decode(x))
169 .collect(),
170 )
171 }
172}
173
174impl<F: Field, D: EncodingSize> Encode<D> for NttPolynomial<F> {
175 type EncodedSize = D::EncodedPolynomialSize;
176
177 fn encode(&self) -> Array<u8, Self::EncodedSize> {
178 byte_encode::<F, D>(&self.0)
179 }
180
181 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
182 Self(byte_decode::<F, D>(enc))
183 }
184}
185
186impl<F, D, K> Encode<D> for NttVector<F, K>
187where
188 F: Field,
189 D: VectorEncodingSize<K>,
190 K: ArraySize,
191{
192 type EncodedSize = D::EncodedVectorSize;
193
194 fn encode(&self) -> Array<u8, Self::EncodedSize> {
195 let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
196 <D as VectorEncodingSize<K>>::flatten(polys)
197 }
198
199 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
200 let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
201 Self(
202 unfold
203 .iter()
204 .map(|&x| <NttPolynomial<F> as Encode<D>>::decode(x))
205 .collect(),
206 )
207 }
208}