1use itertools::{izip, Itertools};
2use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub};
3use std::fmt::{Debug, Display};
4
5use crate::{
6 backend::ArithmeticOps,
7 parameters::{
8 DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams,
9 },
10 utils::log2,
11};
12
13fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
14 assert!(logq >= (logb * d));
15 let ignored_bits = logq - (logb * d);
16
17 (0..d)
18 .into_iter()
19 .map(|i| T::one() << (logb * i + ignored_bits))
20 .collect_vec()
21}
22
23pub trait RlweDecomposer {
24 type Element;
25 type D: Decomposer<Element = Self::Element>;
26
27 fn a(&self) -> &Self::D;
29 fn b(&self) -> &Self::D;
31}
32
33impl<D> RlweDecomposer for (D, D)
34where
35 D: Decomposer,
36{
37 type D = D;
38 type Element = D::Element;
39 fn a(&self) -> &Self::D {
40 &self.0
41 }
42 fn b(&self) -> &Self::D {
43 &self.1
44 }
45}
46
47impl<D> DoubleDecomposerParams for D
48where
49 D: RlweDecomposer,
50{
51 type Base = DecompostionLogBase;
52 type Count = DecompositionCount;
53
54 fn decomposition_base(&self) -> Self::Base {
55 assert!(
56 Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b())
57 );
58 Decomposer::decomposition_base(self.a())
59 }
60 fn decomposition_count_a(&self) -> Self::Count {
61 Decomposer::decomposition_count(self.a())
62 }
63 fn decomposition_count_b(&self) -> Self::Count {
64 Decomposer::decomposition_count(self.b())
65 }
66}
67
68impl<D> SingleDecomposerParams for D
69where
70 D: Decomposer,
71{
72 type Base = DecompostionLogBase;
73 type Count = DecompositionCount;
74 fn decomposition_base(&self) -> Self::Base {
75 Decomposer::decomposition_base(self)
76 }
77 fn decomposition_count(&self) -> Self::Count {
78 Decomposer::decomposition_count(self)
79 }
80}
81
82pub trait Decomposer {
83 type Element;
84 type Iter: Iterator<Item = Self::Element>;
85 fn new(q: Self::Element, logb: usize, d: usize) -> Self;
86
87 fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
88 fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
89 fn decomposition_count(&self) -> DecompositionCount;
90 fn decomposition_base(&self) -> DecompostionLogBase;
91 fn gadget_vector(&self) -> Vec<Self::Element>;
92}
93
/// Decomposer state: the modulus, the decomposition parameters and the
/// constants derived from them at construction time.
pub struct DefaultDecomposer<T> {
    /// Modulus of the values being decomposed.
    q: T,
    /// Bit-size of `q` as reported by `log2(&q)` at construction.
    logq: usize,
    /// Number of limbs produced per decomposition.
    d: usize,
    /// Low bits rounded away before decomposing: `logq - logb * d`.
    ignore_bits: usize,
    /// Base-2 logarithm of the decomposition base.
    logb: usize,
    /// Decomposition base, `1 << logb`.
    b: T,
    /// Mask selecting the low `logb` bits, `b - 1`.
    b_mask: T,
    /// Half of the base, `1 << (logb - 1)`.
    bby2: T,
}
113
/// Compile-time bit width of an integer type.
pub trait NumInfo {
    /// Number of bits in the type.
    const BITS: u32;
}

/// Implements [`NumInfo`] for primitive integers by forwarding to their
/// inherent `BITS` constant.
macro_rules! impl_num_info {
    ($($int:ty),+ $(,)?) => {
        $(
            impl NumInfo for $int {
                const BITS: u32 = <$int>::BITS;
            }
        )+
    };
}

impl_num_info!(u64, u32, u128);
127
128impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
129 fn recompose<Op>(&self, limbs: &[T], modq_op: &Op) -> T
130 where
131 Op: ArithmeticOps<Element = T>,
132 {
133 let mut value = T::zero();
134 let gadget_vector = gadget_vector(self.logq, self.logb, self.d);
135 assert!(limbs.len() == gadget_vector.len());
136 izip!(limbs.iter(), gadget_vector.iter())
137 .for_each(|(d_el, beta)| value = modq_op.add(&value, &modq_op.mul(d_el, beta)));
138
139 value
140 }
141}
142
143impl<
144 T: PrimInt
145 + ToPrimitive
146 + FromPrimitive
147 + WrappingSub
148 + WrappingAdd
149 + NumInfo
150 + From<bool>
151 + Display
152 + Debug,
153 > Decomposer for DefaultDecomposer<T>
154{
155 type Element = T;
156 type Iter = DecomposerIter<T>;
157
158 fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
159 let logq = log2(&q);
161 assert!(
162 logq >= (logb * d),
163 "Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}"
164 );
165
166 let ignore_bits = logq - (logb * d);
167
168 DefaultDecomposer {
169 q,
170 logq,
171 logb,
172 b: T::one() << logb,
173 b_mask: (T::one() << logb) - T::one(),
174 bby2: T::one() << (logb - 1),
175 d,
176 ignore_bits,
177 }
178 }
179
180 fn decompose_to_vec(&self, value: &T) -> Vec<T> {
181 let q = self.q;
182 let logb = self.logb;
183 let b = T::one() << logb;
184 let full_mask = b - T::one();
185 let bby2 = b >> 1;
186
187 let mut value = *value;
188 if value >= (q >> 1) {
189 value = !(q - value) + T::one()
190 }
191 value = round_value(value, self.ignore_bits);
192 let mut out = Vec::with_capacity(self.d);
193 for _ in 0..(self.d) {
194 let k_i = value & full_mask;
195
196 value = (value - k_i) >> logb;
197
198 if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) {
199 out.push(q - (b - k_i));
200 value = value + T::one();
201 } else {
202 out.push(k_i);
203 }
204 }
205
206 return out;
207 }
208
209 fn decomposition_count(&self) -> DecompositionCount {
210 DecompositionCount(self.d)
211 }
212
213 fn decomposition_base(&self) -> DecompostionLogBase {
214 DecompostionLogBase(self.logb)
215 }
216
217 fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
218 let mut value = *value;
219 if value >= (self.q >> 1) {
220 value = !(self.q - value) + T::one()
221 }
222 value = round_value(value, self.ignore_bits);
223
224 DecomposerIter {
225 value,
226 q: self.q,
227 logq: self.logq,
228 logb: self.logb,
229 b: self.b,
230 bby2: self.bby2,
231 b_mask: self.b_mask,
232 steps_left: self.d,
233 }
234 }
235
236 fn gadget_vector(&self) -> Vec<T> {
237 return gadget_vector(self.logq, self.logb, self.d);
238 }
239}
240
241impl<T: PrimInt> DefaultDecomposer<T> {}
242
/// Lazy decomposition state: each call to `next` peels one limb off `value`.
pub struct DecomposerIter<T> {
    /// Part of the input that has not been turned into limbs yet.
    value: T,
    /// Number of limbs still to be produced.
    steps_left: usize,
    /// Modulus; a negative limb is returned as its residue modulo `q`.
    q: T,
    /// Bit-size of `q`, copied from the decomposer; not read by `next`.
    logq: usize,
    /// Base-2 logarithm of the decomposition base.
    logb: usize,
    /// Decomposition base, copied from the decomposer; not read by `next`.
    b: T,
    /// Mask selecting the low `logb` bits of `value`.
    b_mask: T,
    /// Half of the base; limbs above it are made negative and carry.
    bby2: T,
}
260
261impl<T: PrimInt + From<bool> + WrappingSub + Display> Iterator for DecomposerIter<T> {
262 type Item = T;
263
264 fn next(&mut self) -> Option<Self::Item> {
265 if self.steps_left != 0 {
266 self.steps_left -= 1;
267 let k_i = self.value & self.b_mask;
268
269 self.value = (self.value - k_i) >> self.logb;
270
271 let carry_bool =
285 k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one()));
286 let carry = <T as From<bool>>::from(carry_bool);
287 let neg_carry = (T::zero().wrapping_sub(&carry));
288 self.value = self.value + carry;
289 Some((neg_carry & self.q) + k_i - (carry << self.logb))
290
291 } else {
297 None
298 }
299 }
300}
301
302fn round_value<T: PrimInt + WrappingAdd>(value: T, ignore_bits: usize) -> T {
303 if ignore_bits == 0 {
304 return value;
305 }
306
307 let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1);
308 (value >> ignore_bits).wrapping_add(&ignored_msb)
309}
310
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{thread_rng, Rng};

    use super::{round_value, Decomposer, DefaultDecomposer};
    use crate::{
        backend::{ModInit, ModularOpsU64},
        utils::generate_prime,
    };

    /// For random values under both a prime and a power-of-two modulus, the
    /// eager and lazy decompositions must agree, and recomposing the limbs
    /// must land within one unit of the rounded input.
    #[test]
    fn decomposition_works() {
        let ring_size = 1 << 11;
        let logb = 11;
        let d = 3;

        let mut rng = thread_rng();

        for logq in [37, 55] {
            for use_prime in [true, false] {
                let q = match use_prime {
                    true => generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap(),
                    false => 1u64 << logq,
                };
                let decomposer = DefaultDecomposer::new(q, logb, d);
                let modq_op = ModularOpsU64::new(q);

                for _ in 0..1_000_000 {
                    let value = rng.gen_range(0..q);

                    let limbs = decomposer.decompose_to_vec(&value);
                    let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec();
                    assert_eq!(limbs, limbs_from_iter);

                    // NOTE(review): the comparison below is on plain integers,
                    // not modulo `q`; a sample that rounds across `q` (e.g.
                    // `q - 1`) looks like it would recompose to 0 and exceed
                    // the tolerance. Such samples are very unlikely; confirm.
                    let value_back = round_value(
                        decomposer.recompose(&limbs, &modq_op),
                        decomposer.ignore_bits,
                    );
                    let rounded_value = round_value(value, decomposer.ignore_bits);
                    assert!((rounded_value as i64 - value_back as i64).abs() <= 1);
                }
            }
        }
    }
}