1use crate::{
14 B32, Ciphertext, Kem,
15 algebra::{BaseField, Elem, NttVector},
16};
17use array::{
18 Array,
19 typenum::{
20 Const, ToUInt, U0, U2, U3, U4, U6, U12, U16, U32, U64, U384,
21 operator_aliases::{Prod, Sum},
22 },
23};
24use core::{
25 fmt::Debug,
26 ops::{Add, Div, Mul, Rem, Sub},
27};
28use module_lattice::{
29 ArraySize, Encode, EncodedPolynomialSize, EncodedVectorSize, EncodingSize, Field,
30 VectorEncodingSize,
31};
32
33#[cfg(doc)]
34use crate::Seed;
35
36#[allow(clippy::integer_division_remainder_used, reason = "constant")]
42const fn ones_array<const B: usize, const N: usize, U>() -> Array<Elem, U>
43where
44 U: ArraySize<ArrayType<Elem> = [Elem; N]>,
45 Const<N>: ToUInt<Output = U>,
46{
47 let max = 1 << B;
48 let mut out = [Elem::new(0); N];
49 let mut x = 0usize;
50 while x < max {
51 let mut y = 0usize;
52 while y < max {
53 let x_ones = (x.count_ones() & 0xFFFF) as u16;
54 let y_ones = (y.count_ones() & 0xFFFF) as u16;
55 let i = x + (y << B);
56 out[i] = Elem::new((x_ones + BaseField::Q - y_ones) % BaseField::Q);
57
58 y += 1;
59 }
60 x += 1;
61 }
62 Array(out)
63}
64
/// An array size that also carries the type-level data used for CBD sampling.
///
/// NOTE(review): "CBD" presumably refers to the centered binomial distribution of FIPS 203;
/// the sampler itself is not part of this file, so confirm against its implementation.
#[allow(unreachable_pub)]
pub trait CbdSamplingSize: ArraySize {
    /// Encoding size associated with one sample.  In the impls in this file it is twice the
    /// implementing size (`U4` for `U2`, `U6` for `U3`).
    type SampleSize: EncodingSize;
    /// Number of entries in [`Self::ONES`].
    type OnesSize: ArraySize;
    /// Precomputed table of field elements, one per entry of `OnesSize`.  The impls in this
    /// file fill it with `ones_array`, i.e. popcount differences reduced modulo `BaseField::Q`.
    const ONES: Array<Elem, Self::OnesSize>;
}
72
/// Sampling data for size 2: sample size `U4` and a 16-entry lookup table.
impl CbdSamplingSize for U2 {
    type SampleSize = U4;
    type OnesSize = U16;
    // Two 2-bit operands give 2^(2*2) = 16 (x, y) pairs, hence `B = 2`, `N = 16`.
    const ONES: Array<Elem, U16> = ones_array::<2, 16, U16>();
}
78
/// Sampling data for size 3: sample size `U6` and a 64-entry lookup table.
impl CbdSamplingSize for U3 {
    type SampleSize = U6;
    type OnesSize = U64;
    // Two 3-bit operands give 2^(2*3) = 64 (x, y) pairs, hence `B = 3`, `N = 64`.
    const ONES: Array<Elem, U64> = ones_array::<3, 64, U64>();
}
84
/// The type-level parameters that distinguish one instance of the scheme from another.
///
/// NOTE(review): the associated-type names match the ML-KEM parameters `k`, `eta1`, `eta2`,
/// `d_u` and `d_v` of FIPS 203.  The per-item notes below assume that correspondence where
/// they go beyond what the bounds in this file show; confirm against the specification.
pub trait ParameterSet: Default + Clone + Debug + PartialEq {
    /// Number of elements in a vector; used as the length parameter of the vector encodings
    /// (`Du: VectorEncodingSize<Self::K>`, `EncodedVectorSize<_, K>`).
    type K: ArraySize;

    /// First CBD sampling size (presumably `eta1`, used for key generation and encryption
    /// randomness; TODO confirm).
    type Eta1: CbdSamplingSize;

    /// Second CBD sampling size (presumably `eta2`, used for encryption noise; TODO confirm).
    type Eta2: CbdSamplingSize;

    /// Encoding size used for the vector (`u`) part of a ciphertext; see `EncodedUSize`.
    type Du: VectorEncodingSize<Self::K>;

    /// Encoding size used for the polynomial (`v`) part of a ciphertext; see `EncodedVSize`.
    type Dv: EncodingSize;
}
106
/// Encoded size of a `P::K`-element vector using encoding size `P::Du`.
pub(crate) type EncodedUSize<P> =
    EncodedVectorSize<<P as ParameterSet>::Du, <P as ParameterSet>::K>;
/// Encoded size of a single polynomial using encoding size `P::Dv`.
pub(crate) type EncodedVSize<P> = EncodedPolynomialSize<<P as ParameterSet>::Dv>;

/// Byte array holding the encoded `u` component of a ciphertext.
type EncodedU<P> = Array<u8, EncodedUSize<P>>;
/// Byte array holding the encoded `v` component of a ciphertext.
type EncodedV<P> = Array<u8, EncodedVSize<P>>;
113
/// Derived sizes and byte-level (de)serialization helpers for the PKE layer of a parameter set.
///
/// A blanket impl in this file provides this trait for every `Kem + ParameterSet` type whose
/// sizes satisfy the required arithmetic bounds.
pub trait PkeParams: Kem<SharedKeySize = U32> + ParameterSet {
    /// Length in bytes of an encoded `NttVector<Self::K>`.
    type NttVectorSize: ArraySize;
    /// Length in bytes of an encoded encryption key.
    type EncryptionKeySize: ArraySize;

    /// Serializes an NTT-domain vector with the `U12` encoding size.
    fn encode_u12(p: &NttVector<Self::K>) -> EncodedNttVector<Self>;
    /// Deserializes an NTT-domain vector that was written with the `U12` encoding size.
    fn decode_u12(v: &EncodedNttVector<Self>) -> NttVector<Self::K>;

    /// Joins the encoded `u` and `v` components (in that order) into a ciphertext.
    fn concat_ct(u: EncodedU<Self>, v: EncodedV<Self>) -> Ciphertext<Self>;
    /// Borrows the encoded `u` and `v` components out of a ciphertext.
    fn split_ct(ct: &Ciphertext<Self>) -> (&EncodedU<Self>, &EncodedV<Self>);

    /// Joins an encoded NTT vector and a 32-byte value `rho` (in that order) into an encoded
    /// encryption key.
    fn concat_ek(t_hat: EncodedNttVector<Self>, rho: B32) -> EncodedEncryptionKey<Self>;
    /// Borrows the encoded NTT vector and the trailing 32-byte value out of an encoded
    /// encryption key.
    fn split_ek(ek: &EncodedEncryptionKey<Self>) -> (&EncodedNttVector<Self>, &B32);
}
128
/// Byte array holding an encoded NTT-domain vector.
pub(crate) type EncodedNttVector<P> = Array<u8, <P as PkeParams>::NttVectorSize>;
/// Byte array holding an encoded decryption key; it has the same length as an encoded NTT
/// vector.
pub(crate) type EncodedDecryptionKey<P> = Array<u8, <P as PkeParams>::NttVectorSize>;
/// Byte array holding an encoded encryption key.
pub(crate) type EncodedEncryptionKey<P> = Array<u8, <P as PkeParams>::EncryptionKeySize>;
132
/// Blanket implementation of [`PkeParams`] for every parameter set whose ciphertext size is
/// the sum of the encoded `u` and `v` sizes.
///
/// The `where` clause spells out the type-level arithmetic that the method bodies rely on:
/// each `Add` bound supports a `concat`, and each `Sub<_, Output = _>` bound supports the
/// matching `split_ref`.
impl<P> PkeParams for P
where
    P: Kem<CiphertextSize = Sum<EncodedUSize<P>, EncodedVSize<P>>, SharedKeySize = U32>
        + ParameterSet,
    // NOTE(review): these two bounds presumably mirror what `EncodedVectorSize<U12, P::K>`
    // expands to inside `module_lattice` (384 * K) — confirm against that crate.
    U384: Mul<P::K>,
    Prod<U384, P::K>: ArraySize + Add<U32> + Div<P::K, Output = U384> + Rem<P::K, Output = U0>,
    // Ciphertext = u || v.
    EncodedUSize<P>: Add<EncodedVSize<P>>,
    Sum<EncodedUSize<P>, EncodedVSize<P>>:
        ArraySize + Sub<EncodedUSize<P>, Output = EncodedVSize<P>>,
    // Encryption key = encoded NTT vector || 32 bytes.
    EncodedVectorSize<U12, P::K>: Add<U32>,
    Sum<EncodedVectorSize<U12, P::K>, U32>:
        ArraySize + Sub<EncodedVectorSize<U12, P::K>, Output = U32>,
{
    type NttVectorSize = EncodedVectorSize<U12, P::K>;
    type EncryptionKeySize = Sum<Self::NttVectorSize, U32>;

    /// Encodes `p` with the `U12` encoding size.
    fn encode_u12(p: &NttVector<Self::K>) -> EncodedNttVector<Self> {
        Encode::<U12>::encode(p)
    }

    /// Decodes `v` with the `U12` encoding size.
    fn decode_u12(v: &EncodedNttVector<Self>) -> NttVector<Self::K> {
        Encode::<U12>::decode(v)
    }

    /// Returns `u` followed by `v`.
    fn concat_ct(u: EncodedU<Self>, v: EncodedV<Self>) -> Ciphertext<Self> {
        u.concat(v)
    }

    /// Splits a ciphertext into references to its `u` and `v` parts.
    fn split_ct(ct: &Ciphertext<Self>) -> (&EncodedU<Self>, &EncodedV<Self>) {
        ct.split_ref()
    }

    /// Returns `t_hat` followed by `rho`.
    fn concat_ek(t_hat: EncodedNttVector<Self>, rho: B32) -> EncodedEncryptionKey<Self> {
        t_hat.concat(rho)
    }

    /// Splits an encoded encryption key into references to its NTT-vector part and its
    /// trailing 32 bytes.
    fn split_ek(ek: &EncodedEncryptionKey<Self>) -> (&EncodedNttVector<Self>, &B32) {
        ek.split_ref()
    }
}
173
/// Derived size and byte-level (de)serialization helpers for the expanded decapsulation key.
///
/// A blanket impl in this file provides this trait for every [`PkeParams`] type whose sizes
/// satisfy the required arithmetic bounds.
pub trait KemParams: PkeParams {
    /// Length in bytes of an expanded decapsulation key.
    type DecapsulationKeySize: ArraySize;

    /// Joins the four components, in argument order, into one expanded decapsulation key.
    ///
    /// NOTE(review): `h` and `z` are presumably the encryption-key hash and the
    /// implicit-rejection value of FIPS 203; this file only shows that both are 32 bytes.
    fn concat_dk(
        dk: EncodedDecryptionKey<Self>,
        ek: EncodedEncryptionKey<Self>,
        h: B32,
        z: B32,
    ) -> ExpandedDecapsulationKey<Self>;

    /// Borrows the four components of an expanded decapsulation key, returned in the same
    /// order in which `concat_dk` accepts them.
    fn split_dk(
        enc: &ExpandedDecapsulationKey<Self>,
    ) -> (
        &EncodedDecryptionKey<Self>,
        &EncodedEncryptionKey<Self>,
        &B32,
        &B32,
    );
}
194
/// Shorthand for the expanded decapsulation key length of parameter set `P`.
pub(crate) type DecapsulationKeySize<P> = <P as KemParams>::DecapsulationKeySize;
/// Shorthand for the encapsulation key length of parameter set `P`; this is the same type
/// as the PKE encryption key length.
pub(crate) type EncapsulationKeySize<P> = <P as PkeParams>::EncryptionKeySize;

/// Byte array holding an expanded decapsulation key.
pub type ExpandedDecapsulationKey<P> = Array<u8, <P as KemParams>::DecapsulationKeySize>;
200
201impl<P> KemParams for P
202where
203 P: PkeParams,
204 P::NttVectorSize: Add<P::EncryptionKeySize>,
205 Sum<P::NttVectorSize, P::EncryptionKeySize>:
206 ArraySize + Add<U32> + Sub<P::NttVectorSize, Output = P::EncryptionKeySize>,
207 Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>:
208 ArraySize + Add<U32> + Sub<Sum<P::NttVectorSize, P::EncryptionKeySize>, Output = U32>,
209 Sum<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, U32>:
210 ArraySize + Sub<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, Output = U32>,
211{
212 type DecapsulationKeySize = Sum<Sum<Sum<P::NttVectorSize, P::EncryptionKeySize>, U32>, U32>;
213
214 fn concat_dk(
215 dk: EncodedDecryptionKey<Self>,
216 ek: EncodedEncryptionKey<Self>,
217 h: B32,
218 z: B32,
219 ) -> ExpandedDecapsulationKey<Self> {
220 dk.concat(ek).concat(h).concat(z)
221 }
222
223 #[allow(clippy::similar_names)] fn split_dk(
225 enc: &ExpandedDecapsulationKey<Self>,
226 ) -> (
227 &EncodedDecryptionKey<Self>,
228 &EncodedEncryptionKey<Self>,
229 &B32,
230 &B32,
231 ) {
232 let (enc, z) = enc.split_ref();
234 let (enc, h) = enc.split_ref();
235 let (dk_pke, ek_pke) = enc.split_ref();
236 (dk_pke, ek_pke, h, z)
237 }
238}