1use core::fmt::Debug;
14use core::ops::{Add, Div, Mul, Rem, Sub};
15
16use hybrid_array::{
17 typenum::{
18 operator_aliases::{Gcf, Prod, Quot, Sum},
19 type_operators::Gcd,
20 Const, ToUInt, U0, U12, U16, U2, U3, U32, U384, U4, U6, U64, U8,
21 },
22 Array,
23};
24
25use crate::algebra::{FieldElement, NttVector};
26use crate::encode::Encode;
27use crate::util::{Flatten, Unflatten, B32};
28
29pub trait ArraySize: hybrid_array::ArraySize + PartialEq + Debug {}
31
32impl<T> ArraySize for T where T: hybrid_array::ArraySize + PartialEq + Debug {}
33
/// Type-level description of how values of width `Self` (written `D` elsewhere in this
/// file) are packed into bytes.
pub trait EncodingSize: ArraySize {
    /// Byte length of one encoded polynomial; `EncodedPolynomial<D>` is
    /// `Array<u8, Self::EncodedPolynomialSize>`.
    type EncodedPolynomialSize: ArraySize;
    /// Number of `D`-wide values in one packing unit (the blanket impl sets this to
    /// `lcm(D, 8) / D`).
    type ValueStep: ArraySize;
    /// Number of bytes in one packing unit (the blanket impl sets this to
    /// `lcm(D, 8) / 8`).
    type ByteStep: ArraySize;
}
40
/// `D * 8 / gcd(D, 8)`, i.e. `lcm(D, 8)`: the smallest bit count that is both a whole
/// number of `D`-wide values and a whole number of bytes (treating `D` as a bit width).
type EncodingUnit<D> = Quot<Prod<D, U8>, Gcf<D, U8>>;

/// Shorthand for the [`EncodingSize::EncodedPolynomialSize`] of width `D`.
pub type EncodedPolynomialSize<D> = <D as EncodingSize>::EncodedPolynomialSize;
/// Byte array holding one polynomial encoded at width `D`.
pub type EncodedPolynomial<D> = Array<u8, EncodedPolynomialSize<D>>;
45
/// Blanket implementation of [`EncodingSize`] for every width `D` whose typenum
/// arithmetic below is defined:
///
/// - `EncodedPolynomialSize = D * 32`
/// - `ValueStep = lcm(D, 8) / D` (values per packing unit)
/// - `ByteStep = lcm(D, 8) / 8` (bytes per packing unit)
///
/// Example: `D = 12` gives `lcm(12, 8) = 24`, so two values pack into three bytes, and
/// `EncodedPolynomialSize = 384`.
impl<D> EncodingSize for D
where
    D: ArraySize + Mul<U8> + Gcd<U8> + Mul<U32>,
    Prod<D, U32>: ArraySize,
    // `EncodingUnit<D>` is `D * 8 / gcd(D, 8)`; both of its quotients must be array sizes.
    Prod<D, U8>: Div<Gcf<D, U8>>,
    EncodingUnit<D>: Div<D> + Div<U8>,
    Quot<EncodingUnit<D>, D>: ArraySize,
    Quot<EncodingUnit<D>, U8>: ArraySize,
{
    type EncodedPolynomialSize = Prod<D, U32>;
    type ValueStep = Quot<EncodingUnit<D>, D>;
    type ByteStep = Quot<EncodingUnit<D>, U8>;
}
59
/// Extends [`EncodingSize`] to a vector of `K` polynomials, each encoded at width `Self`,
/// stored in a single flat byte array.
pub trait VectorEncodingSize<K>: EncodingSize
where
    K: ArraySize,
{
    /// Byte length of the flat encoding of `K` polynomials.
    type EncodedPolynomialVectorSize: ArraySize;

    /// Combines `K` encoded polynomials into one flat byte array.
    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedPolynomialVector<Self, K>;
    /// Borrows a flat byte array as `K` references to its encoded polynomials.
    fn unflatten(vec: &EncodedPolynomialVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K>;
}
70
/// Shorthand for [`VectorEncodingSize::EncodedPolynomialVectorSize`] of `K` polynomials at
/// width `D`.
pub type EncodedPolynomialVectorSize<D, K> =
    <D as VectorEncodingSize<K>>::EncodedPolynomialVectorSize;
/// Flat byte array holding `K` polynomials encoded at width `D`.
pub type EncodedPolynomialVector<D, K> = Array<u8, EncodedPolynomialVectorSize<D, K>>;
74
/// Blanket implementation: `K` polynomials encoded at width `D` occupy
/// `D::EncodedPolynomialSize * K` bytes.
///
/// The `Div`/`Rem` bounds state that this product divides back into exactly `K` chunks of
/// `D::EncodedPolynomialSize` bytes with no remainder. Presumably the `crate::util`
/// `Flatten`/`Unflatten` helpers require them (TODO confirm against `crate::util`).
impl<D, K> VectorEncodingSize<K> for D
where
    D: EncodingSize,
    K: ArraySize,
    D::EncodedPolynomialSize: Mul<K>,
    Prod<D::EncodedPolynomialSize, K>:
        ArraySize + Div<K, Output = D::EncodedPolynomialSize> + Rem<K, Output = U0>,
{
    type EncodedPolynomialVectorSize = Prod<D::EncodedPolynomialSize, K>;

    fn flatten(polys: Array<EncodedPolynomial<Self>, K>) -> EncodedPolynomialVector<Self, K> {
        // Delegates to `crate::util::Flatten`.
        polys.flatten()
    }

    fn unflatten(vec: &EncodedPolynomialVector<Self, K>) -> Array<&EncodedPolynomial<Self>, K> {
        // Delegates to `crate::util::Unflatten`; returns borrows into `vec`, not copies.
        vec.unflatten()
    }
}
93
/// Sizes and lookup table for centered-binomial (CBD) sampling with parameter `Self`
/// (called `B` in `ones_array`).
pub trait CbdSamplingSize: ArraySize {
    /// Bit width of one sample chunk (the impls in this file use twice the parameter:
    /// `U4` for `U2`, `U6` for `U3`).
    type SampleSize: EncodingSize;
    /// Number of entries in [`Self::ONES`] (`2^(2 * parameter)` in the impls here).
    type OnesSize: ArraySize;
    /// Lookup table built by `ones_array`: entry `x + (y << B)` holds
    /// `popcount(x) - popcount(y)` reduced modulo `FieldElement::Q`.
    const ONES: Array<FieldElement, Self::OnesSize>;
}
100
101#[allow(clippy::cast_possible_truncation)]
110const fn ones_array<const B: usize, const N: usize, U>() -> Array<FieldElement, U>
111where
112 U: ArraySize<ArrayType<FieldElement> = [FieldElement; N]>,
113 Const<N>: ToUInt<Output = U>,
114{
115 let max = 1 << B;
116 let mut out = [FieldElement(0); N];
117 let mut x = 0usize;
118 while x < max {
119 let mut y = 0usize;
120 #[allow(clippy::integer_division_remainder_used)]
121 while y < max {
122 let x_ones = x.count_ones() as u16;
123 let y_ones = y.count_ones() as u16;
124 let i = x + (y << B);
125 out[i] = FieldElement((x_ones + FieldElement::Q - y_ones) % FieldElement::Q);
126
127 y += 1;
128 }
129 x += 1;
130 }
131 Array(out)
132}
133
134impl CbdSamplingSize for U2 {
135 type SampleSize = U4;
136 type OnesSize = U16;
137 const ONES: Array<FieldElement, U16> = ones_array::<2, 16, U16>();
138}
139
140impl CbdSamplingSize for U3 {
141 type SampleSize = U6;
142 type OnesSize = U64;
143 const ONES: Array<FieldElement, U64> = ones_array::<3, 64, U64>();
144}
145
/// Type-level description of one parameter set (presumably one of the ML-KEM parameter
/// sets; the names below follow that convention).
pub trait ParameterSet: Default + Clone + Debug + PartialEq {
    /// Number of polynomials in each vector (presumably the module rank `k`).
    type K: ArraySize;

    /// CBD sampling parameter for the first noise distribution (presumably `eta_1`).
    type Eta1: CbdSamplingSize;

    /// CBD sampling parameter for the second noise distribution (presumably `eta_2`).
    type Eta2: CbdSamplingSize;

    /// Encoding width for the ciphertext's `u` part, a vector of `K` polynomials.
    type Du: VectorEncodingSize<Self::K>;

    /// Encoding width for the ciphertext's `v` part, a single polynomial.
    type Dv: EncodingSize;
}
166
/// Byte length of a ciphertext's encoded `u` part: `K` polynomials at width `Du`.
type EncodedUSize<P> = EncodedPolynomialVectorSize<<P as ParameterSet>::Du, <P as ParameterSet>::K>;
/// Byte length of a ciphertext's encoded `v` part: one polynomial at width `Dv`.
type EncodedVSize<P> = EncodedPolynomialSize<<P as ParameterSet>::Dv>;

/// Encoded `u` part of a ciphertext.
type EncodedU<P> = Array<u8, EncodedUSize<P>>;
/// Encoded `v` part of a ciphertext.
type EncodedV<P> = Array<u8, EncodedVSize<P>>;
172
/// Sizes and byte-level (de)serialization helpers for the public-key encryption layer,
/// derived from a [`ParameterSet`].
pub trait PkeParams: ParameterSet {
    /// Byte length of `K` NTT-domain polynomials encoded at 12 bits each.
    type NttVectorSize: ArraySize;
    /// Byte length of an encryption key (the blanket impl uses `NttVectorSize + 32`).
    type EncryptionKeySize: ArraySize;
    /// Byte length of a ciphertext (the blanket impl uses the encoded `u` size plus the
    /// encoded `v` size).
    type CiphertextSize: ArraySize;

    /// Encodes an NTT vector with 12 bits per coefficient.
    fn encode_u12(p: &NttVector<Self::K>) -> EncodedNttVector<Self>;
    /// Decodes a 12-bit-encoded NTT vector (counterpart of [`Self::encode_u12`]).
    fn decode_u12(v: &EncodedNttVector<Self>) -> NttVector<Self::K>;

    /// Builds a ciphertext as `u || v`.
    fn concat_ct(u: EncodedU<Self>, v: EncodedV<Self>) -> EncodedCiphertext<Self>;
    /// Borrows the `u` and `v` parts of a ciphertext without copying.
    fn split_ct(ct: &EncodedCiphertext<Self>) -> (&EncodedU<Self>, &EncodedV<Self>);

    /// Builds an encryption key as `t_hat || rho`.
    fn concat_ek(t_hat: EncodedNttVector<Self>, rho: B32) -> EncodedEncryptionKey<Self>;
    /// Borrows the `t_hat` and `rho` parts of an encryption key without copying.
    fn split_ek(ek: &EncodedEncryptionKey<Self>) -> (&EncodedNttVector<Self>, &B32);
}
188
/// `K` NTT-domain polynomials encoded at 12 bits per coefficient.
pub type EncodedNttVector<P> = Array<u8, <P as PkeParams>::NttVectorSize>;
/// Encoded PKE decryption key; it has the same size as an encoded NTT vector.
pub type EncodedDecryptionKey<P> = Array<u8, <P as PkeParams>::NttVectorSize>;
/// Encoded PKE encryption key (`t_hat || rho`).
pub type EncodedEncryptionKey<P> = Array<u8, <P as PkeParams>::EncryptionKeySize>;
/// Encoded ciphertext (`u || v`).
pub type EncodedCiphertext<P> = Array<u8, <P as PkeParams>::CiphertextSize>;
193
/// Blanket implementation of [`PkeParams`] for every [`ParameterSet`] whose sizes satisfy
/// the typenum bounds below:
///
/// - `NttVectorSize = 384 * K`, since one polynomial at 12 bits takes `12 * 32 = 384` bytes
/// - `EncryptionKeySize = NttVectorSize + 32`
/// - `CiphertextSize = EncodedUSize + EncodedVSize`
impl<P> PkeParams for P
where
    P: ParameterSet,
    // These make the blanket `VectorEncodingSize<K>` impl apply to `U12`, whose
    // `EncodedPolynomialSize` is `U384`.
    U384: Mul<P::K>,
    Prod<U384, P::K>: ArraySize + Add<U32> + Div<P::K, Output = U384> + Rem<P::K, Output = U0>,
    // Ciphertext = u || v. `Add` sizes the concatenation; the `Sub` lets `split_ref`
    // recover the size of `v` (bounds as required by hybrid_array's concat/split).
    EncodedUSize<P>: Add<EncodedVSize<P>>,
    Sum<EncodedUSize<P>, EncodedVSize<P>>:
        ArraySize + Sub<EncodedUSize<P>, Output = EncodedVSize<P>>,
    // Encryption key = t_hat || rho (32 bytes); same pattern as the ciphertext.
    EncodedPolynomialVectorSize<U12, P::K>: Add<U32>,
    Sum<EncodedPolynomialVectorSize<U12, P::K>, U32>:
        ArraySize + Sub<EncodedPolynomialVectorSize<U12, P::K>, Output = U32>,
{
    type NttVectorSize = EncodedPolynomialVectorSize<U12, P::K>;
    type EncryptionKeySize = Sum<Self::NttVectorSize, U32>;
    type CiphertextSize = Sum<EncodedUSize<P>, EncodedVSize<P>>;

    fn encode_u12(p: &NttVector<Self::K>) -> EncodedNttVector<Self> {
        Encode::<U12>::encode(p)
    }

    fn decode_u12(v: &EncodedNttVector<Self>) -> NttVector<Self::K> {
        Encode::<U12>::decode(v)
    }

    fn concat_ct(u: EncodedU<Self>, v: EncodedV<Self>) -> EncodedCiphertext<Self> {
        u.concat(v)
    }

    fn split_ct(ct: &EncodedCiphertext<Self>) -> (&EncodedU<Self>, &EncodedV<Self>) {
        // The split point is inferred from the declared return type.
        ct.split_ref()
    }

    fn concat_ek(t_hat: EncodedNttVector<Self>, rho: B32) -> EncodedEncryptionKey<Self> {
        t_hat.concat(rho)
    }

    fn split_ek(ek: &EncodedEncryptionKey<Self>) -> (&EncodedNttVector<Self>, &B32) {
        // The split point is inferred from the declared return type.
        ek.split_ref()
    }
}
234
/// Sizes and byte-level (de)serialization helpers for the KEM layer, built on top of
/// [`PkeParams`].
pub trait KemParams: PkeParams {
    /// Byte length of an encoded decapsulation key. The blanket impl uses
    /// `NttVectorSize + EncryptionKeySize + 32 + 32`.
    type DecapsulationKeySize: ArraySize;

    /// Builds a decapsulation key as `dk || ek || h || z`.
    fn concat_dk(
        dk: EncodedDecryptionKey<Self>,
        ek: EncodedEncryptionKey<Self>,
        h: B32,
        z: B32,
    ) -> EncodedDecapsulationKey<Self>;

    /// Borrows the `dk`, `ek`, `h` and `z` parts of a decapsulation key without copying.
    fn split_dk(
        enc: &EncodedDecapsulationKey<Self>,
    ) -> (
        &EncodedDecryptionKey<Self>,
        &EncodedEncryptionKey<Self>,
        &B32,
        &B32,
    );
}
255
/// Shorthand for [`KemParams::DecapsulationKeySize`].
pub type DecapsulationKeySize<P> = <P as KemParams>::DecapsulationKeySize;
/// The KEM encapsulation key has the same size as the PKE encryption key.
pub type EncapsulationKeySize<P> = <P as PkeParams>::EncryptionKeySize;

/// Encoded KEM decapsulation key (`dk || ek || h || z`).
pub type EncodedDecapsulationKey<P> = Array<u8, <P as KemParams>::DecapsulationKeySize>;
260
/// Blanket implementation of [`KemParams`] for every [`PkeParams`] whose sizes satisfy the
/// typenum bounds below.
///
/// The decapsulation key is laid out as `dk_pke || ek_pke || h || z`, so
/// `DecapsulationKeySize = NttVectorSize + EncryptionKeySize + 32 + 32`.
impl<P> KemParams for P
where
    P: PkeParams,
    // Each `Add` sizes the next prefix of the layout. Each `Sub` lets `split_ref` in
    // `split_dk` peel the matching suffix back off.
    P::NttVectorSize: Add<P::EncryptionKeySize>,
    Sum<P::NttVectorSize, P::EncryptionKeySize>:
        ArraySize + Add<U32> + Sub<P::NttVectorSize, Output = P::EncryptionKeySize>,
    Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>:
        ArraySize + Add<U32> + Sub<Sum<P::NttVectorSize, P::EncryptionKeySize>, Output = U32>,
    Sum<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, U32>:
        ArraySize + Sub<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, Output = U32>,
{
    type DecapsulationKeySize = Sum<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, U32>;

    fn concat_dk(
        dk: EncodedDecryptionKey<Self>,
        ek: EncodedEncryptionKey<Self>,
        h: B32,
        z: B32,
    ) -> EncodedDecapsulationKey<Self> {
        dk.concat(ek).concat(h).concat(z)
    }

    // Presumably silences clippy's `similar_names` lint for the `dk_pke` / `ek_pke` pair.
    #[allow(clippy::similar_names)]
    fn split_dk(
        enc: &EncodedDecapsulationKey<Self>,
    ) -> (
        &EncodedDecryptionKey<Self>,
        &EncodedEncryptionKey<Self>,
        &B32,
        &B32,
    ) {
        // Undo `concat_dk` from the end: take off `z`, then `h`, then split the remaining
        // prefix into `dk_pke` and `ek_pke`. Split points are inferred from the return type.
        let (enc, z) = enc.split_ref();
        let (enc, h) = enc.split_ref();
        let (dk_pke, ek_pke) = enc.split_ref();
        (dk_pke, ek_pke, h, z)
    }
}