1use std::alloc::{Allocator, Global};
2use std::cmp::{min, max};
3use std::marker::PhantomData;
4
5use crate::algorithms::miller_rabin::is_prime;
6use crate::homomorphism::*;
7use crate::integer::*;
8use crate::lazy::LazyVec;
9use crate::primitive_int::StaticRing;
10use crate::ring::*;
11use crate::rings::zn::zn_64::{Zn, ZnBase};
12use crate::rings::zn::*;
13use crate::divisibility::*;
14use crate::seq::*;
15
16use super::ntt::NTTConvolution;
17use super::{ConvolutionAlgorithm, PreparedConvolutionAlgorithm, PreparedConvolutionOperation};
18
#[stability::unstable(feature = "enable")]
/// Convolution algorithm that reduces the inputs modulo a list of NTT-friendly primes,
/// computes each convolution with the base algorithm `C`, and reconstructs the result
/// via the CRT rings stored in `rns_rings`. The list of primes is extended lazily
/// whenever larger inputs require more precision (see `compute_required_width`).
pub struct RNSConvolution<I = BigIntRing, C = NTTConvolution<Zn>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    // ring in which CRT-reconstructed (lifted) results are produced
    integer_ring: I,
    // `rns_rings[i]` is the CRT ring whose moduli are the first `i + 1` sampled primes
    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
    // `convolutions[i]` is the base convolution for the `i`-th sampled prime
    convolutions: LazyVec<C>,
    // factory invoked to create a base convolution for each newly sampled prime
    create_convolution: CreateC,
    // every sampled prime `p` satisfies `p = 1 mod 2^required_root_of_unity_log2`
    required_root_of_unity_log2: usize,
    allocator: A
}
43
#[stability::unstable(feature = "enable")]
/// [`RNSConvolution`], exposed as a convolution algorithm over rings `Z/nZ` instead of
/// the integers; operands are lifted with `smallest_lift()` before convolving and the
/// result is mapped back into the ring.
///
/// The `#[repr(transparent)]` layout is relied upon by the `From<&RNSConvolution>`
/// impl, which transmutes between references to the two types.
#[repr(transparent)]
pub struct RNSConvolutionZn<I = BigIntRing, C = NTTConvolution<Zn>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    base: RNSConvolution<I, C, A, CreateC>
}
58
#[stability::unstable(feature = "enable")]
/// Operand of a convolution over a ring of type `R`, prepared for faster repeated
/// convolutions with [`RNSConvolution`].
pub struct PreparedConvolutionOperand<R, C = NTTConvolution<Zn>>
    where R: ?Sized + RingBase,
        C: PreparedConvolutionAlgorithm<ZnBase>
{
    // the original, unreduced operand data
    data: Vec<R::Element>,
    // `prepared[i]` is `data` reduced modulo the `i`-th RNS prime and prepared by the
    // base convolution; filled lazily up to the width each convolution requires
    prepared: LazyVec<C::PreparedConvolutionOperand>,
    // bound on the bitlength of the entries of `data`, recorded on creation
    log2_data_size: usize
}
71
#[stability::unstable(feature = "enable")]
/// Operand of a convolution over a ring `Z/nZ` of type `R`, prepared for faster
/// convolutions with [`RNSConvolutionZn`]. Wraps the prepared operand over the integer
/// ring of `R`, since operands are lifted to the integers before convolving.
pub struct PreparedConvolutionOperandZn<R, C = NTTConvolution<Zn>>(PreparedConvolutionOperand<R::IntegerRingBase, C>)
    where R: ?Sized + ZnRing,
        C: PreparedConvolutionAlgorithm<ZnBase>;
79
#[stability::unstable(feature = "enable")]
/// Function object that creates an [`NTTConvolution`] for a given [`Zn`], cloning the
/// stored allocator; used as the default `CreateC` parameter of [`RNSConvolution`].
pub struct CreateNTTConvolution<A = Global>
    where A: Allocator + Clone
{
    allocator: A
}
89
90impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
91 where I: RingStore + Clone,
92 I::Type: IntegerRing,
93 C: ConvolutionAlgorithm<ZnBase>,
94 A: Allocator + Clone,
95 CreateC: Fn(Zn) -> C
96{
97 fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self {
98 value.base
99 }
100}
101
impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    /// Borrows the underlying integer-based convolution.
    fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self {
        &value.base
    }
}
113
114impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
115 where I: RingStore + Clone,
116 I::Type: IntegerRing,
117 C: ConvolutionAlgorithm<ZnBase>,
118 A: Allocator + Clone,
119 CreateC: Fn(Zn) -> C
120{
121 fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self {
122 RNSConvolutionZn { base: value }
123 }
124}
125
impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    /// Reinterprets a reference to the integer-based convolution as a reference to the
    /// `Z/nZ` wrapper, without moving the value.
    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self {
        // SAFETY: `RNSConvolutionZn` is a `#[repr(transparent)]` wrapper around
        // `RNSConvolution` with identical generic arguments, so both references have
        // the same layout and the transmute is a no-op reinterpretation.
        unsafe { std::mem::transmute(value) }
    }
}
137
impl CreateNTTConvolution<Global> {
    #[stability::unstable(feature = "enable")]
    /// Creates the factory with the global allocator.
    pub fn new() -> Self {
        Self { allocator: Global }
    }
}
144
impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    type Output = NTTConvolution<Zn, A>;

    // delegates to `call()`; the unstable fn-traits feature requires all three impls
    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output {
        self.call(args)
    }
}
154
impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    // delegates to `call()`; see the `Fn` impl for the actual logic
    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output {
        self.call(args)
    }
}
162
impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    // creates an `NTTConvolution` over the given ring, cloning the stored allocator
    extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
        NTTConvolution::new_with(args.0, self.allocator.clone())
    }
}
170
impl RNSConvolution {

    #[stability::unstable(feature = "enable")]
    /// Creates an [`RNSConvolution`] with default parameters; `max_log2_n` is passed as
    /// `required_root_of_unity_log2`, i.e. every sampled prime `p` satisfies
    /// `p = 1 mod 2^max_log2_n` (as required by the underlying NTTs).
    pub fn new(max_log2_n: usize) -> Self {
        Self::new_with(max_log2_n, usize::MAX, BigIntRing::RING, Global, CreateNTTConvolution { allocator: Global })
    }
}
182
impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    #[stability::unstable(feature = "enable")]
    /// Creates a new [`RNSConvolution`].
    ///
    /// All sampled primes `p` satisfy `p = 1 mod 2^required_root_of_unity_log2` and
    /// `p <= 2^max_prime_size_log2`; `max_prime_size_log2` is clamped to 57. The first
    /// (largest) prime is sampled eagerly, further primes are sampled on demand.
    pub fn new_with(required_root_of_unity_log2: usize, mut max_prime_size_log2: usize, integer_ring: I, allocator: A, create_convolution: CreateC) -> Self {
        // primes are stored in 64-bit `Zn` rings, so cap their bitlength
        max_prime_size_log2 = min(max_prime_size_log2, 57);
        let result = Self {
            integer_ring: integer_ring,
            create_convolution: create_convolution,
            convolutions: LazyVec::new(),
            rns_rings: LazyVec::new(),
            required_root_of_unity_log2: required_root_of_unity_log2,
            allocator: allocator
        };
        // seed the incremental construction with the largest suitable prime `< 2^max_prime_size_log2 + 1`;
        // note that `2^max_prime_size_log2 + 1 = 1 mod 2^required_root_of_unity_log2`, as
        // `sample_next_prime` requires
        let initial_ring = zn_rns::Zn::new_with(vec![Zn::new(Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64)], result.integer_ring.clone(), result.allocator.clone());
        _ = result.rns_rings.get_or_init(0, || initial_ring);
        return result;
    }

    /// Returns the largest prime `p < current` with `p = 1 mod 2^required_root_of_unity_log2`,
    /// or `None` if there is no such positive prime. `current` itself must be
    /// `= 1 mod 2^required_root_of_unity_log2`, otherwise the `checked_div(..).unwrap()` panics.
    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
        let mut k = StaticRing::<i64>::RING.checked_div(&(current - 1), &(1 << required_root_of_unity_log2)).unwrap();
        while k > 0 {
            k -= 1;
            // candidates of the form `k * 2^r + 1` are exactly those supporting the required root of unity
            let candidate = (k << required_root_of_unity_log2) + 1;
            // Miller-Rabin with 10 rounds
            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
                return Some(candidate);
            }
        }
        return None;
    }

    /// Returns the CRT ring whose moduli are the first `moduli_count` sampled primes,
    /// sampling additional (strictly decreasing) primes as necessary.
    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
        self.rns_rings.get_or_init_incremental(moduli_count, |_, prev| zn_rns::Zn::new_with(
            // extend the previous ring by the next smaller suitable prime
            prev.as_iter().cloned().chain([Zn::new(Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus()).unwrap() as u64)]).collect(),
            self.integer_ring.clone(),
            self.allocator.clone()
        ));
        return self.rns_rings.get(moduli_count - 1).unwrap();
    }

    /// Returns the ring `Z/pZ` for the `i`-th sampled prime `p`.
    fn get_rns_factor(&self, i: usize) -> &Zn {
        let rns_ring = self.get_rns_ring(i + 1);
        return rns_ring.at(rns_ring.len() - 1);
    }

    /// Returns the base convolution for the `i`-th sampled prime, creating it via
    /// `create_convolution` on first use.
    fn get_convolution(&self, i: usize) -> &C {
        self.convolutions.get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
    }

    /// Ensures `operand.prepared[i]` is initialized for all `i < target_width`.
    /// `mod_part(x, i)` must return the reduction of `x` modulo the `i`-th prime.
    fn extend_operand<R, F>(&self, operand: &PreparedConvolutionOperand<R, C>, target_width: usize, mut mod_part: F)
        where R: ?Sized + RingBase,
            C: PreparedConvolutionAlgorithm<ZnBase>,
            F: FnMut(&R::Element, usize) -> El<Zn>
    {
        // scratch buffer, reused for the reduction modulo each prime
        let mut tmp = Vec::new();
        tmp.resize_with(operand.data.len(), || self.get_rns_factor(0).zero());
        for i in 0..target_width {
            _ = operand.prepared.get_or_init(i, || {
                for j in 0..operand.data.len() {
                    tmp[j] = mod_part(&operand.data[j], i);
                }
                self.get_convolution(i).prepare_convolution_operand(&tmp, self.get_rns_factor(i))
            });
        }
    }

    /// Computes how many RNS primes are required so that their product exceeds every
    /// possible output value: an inner product of `inner_prod_len` convolutions of
    /// sequences of lengths `lhs_len`/`rhs_len` with entries of at most `input_size_log2`
    /// bits.
    fn compute_required_width(&self, input_size_log2: usize, lhs_len: usize, rhs_len: usize, inner_prod_len: usize) -> usize {
        // each output entry is a sum of at most `min(lhs_len, rhs_len) * inner_prod_len`
        // products of two inputs; the final `+ 1` accounts for the sign
        let log2_output_size = input_size_log2 * 2 +
            StaticRing::<i64>::RING.abs_log2_ceil(&(min(lhs_len, rhs_len) as i64)).unwrap_or(0) +
            StaticRing::<i64>::RING.abs_log2_ceil(&(inner_prod_len as i64)).unwrap_or(0) +
            1;
        // primes have at most 57 bits, so this is a lower bound for the required width;
        // grow until the product of moduli is provably large enough
        let mut width = (log2_output_size - 1) / 57 + 1;
        while log2_output_size > self.integer_ring.abs_log2_floor(self.get_rns_ring(width).modulus()).unwrap() {
            width += 1;
        }
        return width;
    }
}
279
impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    /// Computes the convolution of `lhs` and `rhs` (sequences over the integer ring
    /// `ring`, entries of at most `input_size_log2` bits) and passes every result entry
    /// `(index, value)` to `dst`, where `value` is the smallest lift in `self.integer_ring`.
    fn compute_convolution_impl<S, V1, V2, D>(&self, input_size_log2: usize, lhs: V1, rhs: V2, mut dst: D, ring: S)
        where S: RingStore,
            S::Type: RingBase + IntegerRing,
            D: FnMut(usize, El<I>),
            V1: VectorFn<El<S>>,
            V2: VectorFn<El<S>>
    {
        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
        let len = lhs.len() + rhs.len();
        // `res_data` is a row-major (width x len) matrix; row `i` holds the result modulo the `i`-th prime
        let mut res_data = Vec::with_capacity(len * width);
        for i in 0..width {
            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
        }

        // scratch buffers for the reductions of the inputs modulo the current prime
        let mut lhs_tmp = Vec::with_capacity(lhs.len());
        lhs_tmp.resize_with(lhs.len(), || self.get_rns_factor(0).zero());
        let mut rhs_tmp = Vec::with_capacity(rhs.len());
        rhs_tmp.resize_with(rhs.len(), || self.get_rns_factor(0).zero());
        for i in 0..width {
            // reduce both inputs modulo the `i`-th prime and convolve there
            let hom = self.get_rns_factor(i).can_hom(&ring).unwrap();
            for j in 0..lhs.len() {
                lhs_tmp[j] = hom.map(lhs.at(j));
            }
            for j in 0..rhs.len() {
                rhs_tmp[j] = hom.map(rhs.at(j));
            }
            self.get_convolution(i).compute_convolution(&lhs_tmp, &rhs_tmp, &mut res_data[(i * len)..((i + 1) * len)], self.get_rns_factor(i));
        }

        // CRT-reconstruct each output entry from its residues and hand out the smallest lift
        for j in 0..(len - 1) {
            dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j]))));
        }
    }
}
321
322impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
323 where I: RingStore + Clone,
324 I::Type: IntegerRing,
325 C: PreparedConvolutionAlgorithm<ZnBase>,
326 A: Allocator + Clone,
327 CreateC: Fn(Zn) -> C
328{
329 fn prepare_convolution_operand_impl<S, V>(&self, input_size_log2: usize, val: V, _ring: S) -> PreparedConvolutionOperand<S::Type, C>
330 where S: RingStore + Copy,
331 S::Type: IntegerRing,
332 V: VectorFn<El<S>>
333 {
334 let mut data = Vec::with_capacity(val.len());
335 data.extend(val.iter());
336 return PreparedConvolutionOperand {
337 data: data,
338 prepared: LazyVec::new(),
339 log2_data_size: input_size_log2
340 };
341 }
342
343 fn compute_convolution_inner_product_lhs_prepared_impl<'a, S, V, D>(&self, rhs_input_size_log2: usize, values: &[(&'a PreparedConvolutionOperand<S::Type, C>, V)], mut dst: D, ring: S)
344 where S: RingStore + Copy,
345 S::Type: IntegerRing,
346 D: FnMut(usize, El<I>),
347 V: VectorFn<El<S>>,
348 S: 'a,
349 Self: 'a,
350 PreparedConvolutionOperand<S::Type, C>: 'a
351 {
352 let max_len = values.iter().map(|(lhs, rhs)| lhs.data.len() + rhs.len()).max().unwrap_or(0);
353 let input_size_log2 = max(rhs_input_size_log2, values.iter().map(|(lhs, _)| lhs.log2_data_size).max().unwrap_or(0));
354 let width = self.compute_required_width(input_size_log2, (max_len - 1) / 2 + 1, (max_len - 1) / 2 + 1, values.len());
355 let mut res_data = Vec::with_capacity(max_len * width);
356 for i in 0..width {
357 res_data.extend((0..max_len).map(|_| self.get_rns_factor(i).zero()));
358 }
359
360 let mut rhs_tmp = Vec::with_capacity(max_len * values.len());
361 rhs_tmp.resize_with(max_len * values.len(), || self.get_rns_factor(0).zero());
362
363 let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
364 for j in 0..values.len() {
365 self.extend_operand(values[j].0, width, |x, i| homs[i].map_ref(x));
366 }
367
368 for i in 0..width {
369 for j in 0..values.len() {
370 for k in 0..values[j].1.len() {
371 rhs_tmp[j * max_len + k] = homs[i].map(values[j].1.at(k));
372 }
373 }
374 self.get_convolution(i).compute_convolution_inner_product_lhs_prepared(
375 values.iter().enumerate().map(|(j, (lhs, _))| (lhs.prepared.get(i).unwrap(), &rhs_tmp[(j * max_len)..(j * max_len + values[j].1.len())])),
376 &mut res_data[(i * max_len)..((i + 1) * max_len)],
377 self.get_rns_factor(i)
378 );
379 }
380
381 for j in 0..(max_len - 1) {
382 dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * max_len + j]))));
383 }
384 }
385
386 fn compute_convolution_inner_product_prepared_impl<'a, S, D>(&self, values: &[(&'a PreparedConvolutionOperand<S::Type, C>, &'a PreparedConvolutionOperand<S::Type, C>)], mut dst: D, ring: S)
387 where S: RingStore + Copy,
388 S::Type: IntegerRing,
389 D: FnMut(usize, El<I>),
390 Self: 'a,
391 S: 'a,
392 PreparedConvolutionOperand<S::Type, C>: 'a
393 {
394 let max_len = values.iter().map(|(lhs, rhs)| lhs.data.len() + rhs.data.len()).max().unwrap_or(0);
395 let input_size_log2 = values.iter().map(|(lhs, rhs)| max(lhs.log2_data_size, rhs.log2_data_size)).max().unwrap_or(0);
396 let width = self.compute_required_width(input_size_log2, (max_len - 1) / 2 + 1, (max_len - 1) / 2 + 1, values.len());
397 let mut res_data = Vec::with_capacity(max_len * width);
398 for i in 0..width {
399 res_data.extend((0..max_len).map(|_| self.get_rns_factor(i).zero()));
400 }
401
402 let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
403 for j in 0..values.len() {
404 self.extend_operand(values[j].0, width, |x, i| homs[i].map_ref(x));
405 self.extend_operand(values[j].1, width, |x, i| homs[i].map_ref(x));
406 }
407
408 for i in 0..width {
409 self.get_convolution(i).compute_convolution_inner_product_prepared(
410 values.iter().map(|(lhs, rhs)| (lhs.prepared.get(i).unwrap(), rhs.prepared.get(i).unwrap())),
411 &mut res_data[(i * max_len)..((i + 1) * max_len)],
412 self.get_rns_factor(i)
413 );
414 }
415
416 for j in 0..(max_len - 1) {
417 dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * max_len + j]))));
418 }
419 }
420
421 fn compute_convolution_lhs_prepared_impl<S, V, D>(&self, rhs_input_size_log2: usize, lhs: &PreparedConvolutionOperand<S::Type, C>, rhs: V, mut dst: D, ring: S)
422 where S: RingStore + Copy,
423 S::Type: IntegerRing,
424 D: FnMut(usize, El<I>),
425 V: VectorFn<El<S>>
426 {
427 let width = self.compute_required_width(max(rhs_input_size_log2, lhs.log2_data_size), lhs.data.len(), rhs.len(), 1);
428 let len = lhs.data.len() + rhs.len();
429 let mut res_data = Vec::with_capacity(len * width);
430 for i in 0..width {
431 res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
432 }
433
434 let mut rhs_tmp = Vec::with_capacity(rhs.len());
435 rhs_tmp.resize_with(rhs.len(), || self.get_rns_factor(0).zero());
436
437 let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
438 self.extend_operand(lhs, width, |x, i| homs[i].map_ref(x));
439
440 for i in 0..width {
441 for j in 0..rhs.len() {
442 rhs_tmp[j] = homs[i].map(rhs.at(j));
443 }
444 self.get_convolution(i).compute_convolution_lhs_prepared(lhs.prepared.get(i).unwrap(), &rhs_tmp, &mut res_data[(i * len)..((i + 1) * len)], self.get_rns_factor(i));
445 }
446
447 for j in 0..(len - 1) {
448 dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j]))));
449 }
450 }
451
452 fn compute_convolution_prepared_impl<S, D>(&self, lhs: &PreparedConvolutionOperand<S::Type, C>, rhs: &PreparedConvolutionOperand<S::Type, C>, mut dst: D, ring: S)
453 where S: RingStore + Copy,
454 S::Type: IntegerRing,
455 D: FnMut(usize, El<I>),
456 {
457 let width = self.compute_required_width(max(lhs.log2_data_size, rhs.log2_data_size), lhs.data.len(), rhs.data.len(), 1);
458 let len = lhs.data.len() + rhs.data.len();
459 let mut res_data = Vec::with_capacity(len * width);
460 for i in 0..width {
461 res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
462 }
463
464 let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
465 self.extend_operand(lhs, width, |x, i| homs[i].map_ref(x));
466 self.extend_operand(rhs, width, |x, i| homs[i].map_ref(x));
467
468 for i in 0..width {
469 self.get_convolution(i).compute_convolution_prepared(lhs.prepared.get(i).unwrap(), rhs.prepared.get(i).unwrap(), &mut res_data[(i * len)..((i + 1) * len)], self.get_rns_factor(i));
470 }
471
472 for j in 0..(len - 1) {
473 dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j]))));
474 }
475 }
476}
477
478impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
479 where I: RingStore + Clone,
480 I::Type: IntegerRing,
481 C: ConvolutionAlgorithm<ZnBase>,
482 A: Allocator + Clone,
483 CreateC: Fn(Zn) -> C,
484 R: ?Sized + IntegerRing
485{
486 fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
487 assert!(dst.len() >= lhs.len() + rhs.len() - 1);
488 let log2_input_size = lhs.as_iter().chain(rhs.as_iter()).map(|x| ring.abs_log2_ceil(x).unwrap_or(0)).max().unwrap_or(0);
489 println!("{}", log2_input_size);
490 let hom = ring.can_hom(&self.integer_ring).unwrap();
491 return self.compute_convolution_impl(
492 log2_input_size,
493 lhs.clone_ring_els(ring),
494 rhs.clone_ring_els(ring),
495 |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
496 ring
497 );
498 }
499
500 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
501 true
502 }
503
504 fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
505 where F: PreparedConvolutionOperation<Self, R>
506 {
507 struct CallFunction<F, R, I, C, A, CreateC>
508 where I: RingStore + Clone,
509 I::Type: IntegerRing,
510 C: ConvolutionAlgorithm<ZnBase>,
511 A: Allocator + Clone,
512 CreateC: Fn(Zn) -> C,
513 R: ?Sized + IntegerRing,
514 F: PreparedConvolutionOperation<RNSConvolution<I, C, A, CreateC>, R>
515 {
516 ring: PhantomData<Box<R>>,
517 convolution: PhantomData<RNSConvolution<I, C, A, CreateC>>,
518 function: F
519 }
520 impl<F, R, I, C, A, CreateC> PreparedConvolutionOperation<C, ZnBase> for CallFunction<F, R, I, C, A, CreateC>
521 where I: RingStore + Clone,
522 I::Type: IntegerRing,
523 C: ConvolutionAlgorithm<ZnBase>,
524 A: Allocator + Clone,
525 CreateC: Fn(Zn) -> C,
526 R: ?Sized + IntegerRing,
527 F: PreparedConvolutionOperation<RNSConvolution<I, C, A, CreateC>, R>
528 {
529 type Output = F::Output;
530
531 fn execute(self) -> Self::Output
532 where C: PreparedConvolutionAlgorithm<ZnBase>
533 {
534 self.function.execute()
535 }
536 }
537 return <C as ConvolutionAlgorithm<ZnBase>>::specialize_prepared_convolution::<CallFunction<F, R, I, C, A, CreateC>>(CallFunction {
538 function: function,
539 ring: PhantomData,
540 convolution: PhantomData
541 }).map_err(|f| f.function);
542 }
543}
544
impl<R, I, C, A, CreateC> PreparedConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: PreparedConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + IntegerRing
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;

    /// Prepares `val` for repeated convolutions; the bitlength bound of its entries is
    /// recorded so later convolutions know how many RNS primes they need.
    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        let log2_input_size = val.as_iter().map(|x| ring.abs_log2_ceil(x).unwrap_or(0)).max().unwrap_or(0);
        return self.prepare_convolution_operand_impl(log2_input_size, val.clone_ring_els(ring), ring);
    }

    /// Computes the convolution of the prepared `lhs` with `rhs` and adds the result to `dst`.
    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(dst.len() >= lhs.data.len() + rhs.len() - 1);
        let rhs_log2_input_size = rhs.as_iter().map(|x| ring.abs_log2_ceil(x).unwrap_or(0)).max().unwrap_or(0);
        // results are computed in `self.integer_ring` and mapped back into `ring`
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_lhs_prepared_impl(
            rhs_log2_input_size,
            lhs,
            rhs.clone_ring_els(ring),
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }

    /// Computes the convolution of two prepared operands and adds the result to `dst`.
    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        assert!(dst.len() >= lhs.data.len() + rhs.data.len() - 1);
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_prepared_impl(
            lhs,
            rhs,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }

    /// Computes the sum of the convolutions of all pairs in `values` (left operands
    /// prepared) and adds the result to `dst`.
    fn compute_convolution_inner_product_lhs_prepared<'a, S, J, V>(&self, values: J, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'a,
            R: 'a,
            Self::PreparedConvolutionOperand: 'a
    {
        let values = values.map(|(lhs, rhs)| (lhs, rhs.into_clone_ring_els(ring))).collect::<Vec<_>>();
        let rhs_log2_input_size = values.iter().flat_map(|(_, rhs)| rhs.iter()).map(|x| ring.abs_log2_ceil(&x).unwrap_or(0)).max().unwrap_or(0);
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_inner_product_lhs_prepared_impl(
            rhs_log2_input_size,
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }

    /// Computes the sum of the convolutions of all pairs in `values` (both operands
    /// prepared) and adds the result to `dst`.
    fn compute_convolution_inner_product_prepared<'a, S, J>(&self, values: J, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a,
            Self: 'a,
            R: 'a,
    {
        let values = values.collect::<Vec<_>>();
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_inner_product_prepared_impl(
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }
}
625
impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + ZnRing + CanHomFrom<I::Type>
{
    /// Computes the convolution of `lhs` and `rhs` over the ring `Z/nZ` given by `ring`
    /// and adds the result to `dst`; operands are lifted to the integers via
    /// `smallest_lift()` first.
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        assert!(dst.len() >= lhs.len() + rhs.len() - 1);
        // smallest lifts have bitlength strictly below that of the modulus
        let log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_impl(
            log2_input_size,
            lhs.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)),
            rhs.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)),
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    /// Delegates the specialization query to the base convolution `C`: the RNS
    /// convolution supports prepared convolutions exactly if `C` does.
    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        // adapter that forwards `F` (an operation on `RNSConvolutionZn`) as an operation on `C`
        struct CallFunction<F, R, I, C, A, CreateC>
            where I: RingStore + Clone,
                I::Type: IntegerRing,
                C: ConvolutionAlgorithm<ZnBase>,
                A: Allocator + Clone,
                CreateC: Fn(Zn) -> C,
                R: ?Sized + ZnRing + CanHomFrom<I::Type>,
                F: PreparedConvolutionOperation<RNSConvolutionZn<I, C, A, CreateC>, R>
        {
            ring: PhantomData<Box<R>>,
            convolution: PhantomData<RNSConvolution<I, C, A, CreateC>>,
            function: F
        }
        impl<F, R, I, C, A, CreateC> PreparedConvolutionOperation<C, ZnBase> for CallFunction<F, R, I, C, A, CreateC>
            where I: RingStore + Clone,
                I::Type: IntegerRing,
                C: ConvolutionAlgorithm<ZnBase>,
                A: Allocator + Clone,
                CreateC: Fn(Zn) -> C,
                R: ?Sized + ZnRing + CanHomFrom<I::Type>,
                F: PreparedConvolutionOperation<RNSConvolutionZn<I, C, A, CreateC>, R>
        {
            type Output = F::Output;

            fn execute(self) -> Self::Output
                where C: PreparedConvolutionAlgorithm<ZnBase>
            {
                self.function.execute()
            }
        }
        return <C as ConvolutionAlgorithm<ZnBase>>::specialize_prepared_convolution::<CallFunction<F, R, I, C, A, CreateC>>(CallFunction {
            function: function,
            ring: PhantomData,
            convolution: PhantomData
        }).map_err(|f| f.function);
    }
}
691
impl<R, I, C, A, CreateC> PreparedConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: PreparedConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + ZnRing + CanHomFrom<I::Type>
{
    type PreparedConvolutionOperand = PreparedConvolutionOperandZn<R, C>;

    /// Prepares `val` for repeated convolutions by lifting it to the integer ring of
    /// `ring` via `smallest_lift()` and delegating to the integer-based preparation.
    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        // smallest lifts have bitlength strictly below that of the modulus
        let log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        return PreparedConvolutionOperandZn(self.base.prepare_convolution_operand_impl(log2_input_size, val.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)), ring.integer_ring()));
    }

    /// Computes the convolution of the prepared `lhs` with `rhs` and adds the result to `dst`.
    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(dst.len() >= lhs.0.data.len() + rhs.len() - 1);
        let rhs_log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        // integer results are mapped back into the ring `Z/nZ`
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_lhs_prepared_impl(
            rhs_log2_input_size,
            &lhs.0,
            rhs.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)),
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    /// Computes the convolution of two prepared operands and adds the result to `dst`.
    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        assert!(dst.len() >= lhs.0.data.len() + rhs.0.data.len() - 1);
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_prepared_impl(
            &lhs.0,
            &rhs.0,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    /// Computes the sum of the convolutions of all pairs in `values` (left operands
    /// prepared) and adds the result to `dst`.
    fn compute_convolution_inner_product_lhs_prepared<'a, S, J, V>(&self, values: J, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'a,
            R: 'a,
            Self::PreparedConvolutionOperand: 'a
    {
        let values = values.map(|(lhs, rhs)| (&lhs.0, rhs.into_clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)))).collect::<Vec<_>>();
        let rhs_log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_inner_product_lhs_prepared_impl(
            rhs_log2_input_size,
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    /// Computes the sum of the convolutions of all pairs in `values` (both operands
    /// prepared) and adds the result to `dst`.
    fn compute_convolution_inner_product_prepared<'a, S, J>(&self, values: J, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a,
            Self: 'a,
            R: 'a,
    {
        let values = values.map(|(lhs, rhs)| (&lhs.0, &rhs.0)).collect::<Vec<_>>();
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_inner_product_prepared_impl(
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }
}
772
#[test]
fn test_convolution_integer() {
    // run the generic (prepared and unprepared) convolution test suites over the
    // integers, with inputs of roughly 30 bits
    let ring = StaticRing::<i128>::RING;
    let convolution = RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp, Global));

    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
}
781
#[test]
fn test_convolution_zn() {
    // run the generic (prepared and unprepared) convolution test suites over `Z/nZ`
    // with a 58-bit modulus
    let ring = Zn::new((1 << 57) + 1);
    let convolution = RNSConvolutionZn::from(RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp, Global)));

    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
}
790
#[test]
fn test_specialize_prepared() {
    let ring = Zn::new((1 << 57) + 1);
    let convolution = RNSConvolutionZn::from(RNSConvolution::new(7));

    // operation that, when executed, runs the prepared-convolution test suite; the
    // specialization must succeed (`is_ok()`), since the default base convolution
    // `NTTConvolution` supports prepared convolutions
    struct CheckIsPrepared(RNSConvolutionZn, Zn);
    impl PreparedConvolutionOperation<RNSConvolutionZn, ZnBase> for CheckIsPrepared {
        type Output = ();
        fn execute(self) -> Self::Output {
            super::generic_tests::test_prepared_convolution(&self.0, &self.1, self.1.int_hom().map(1 << 30));
        }
    }
    assert!(RNSConvolutionZn::specialize_prepared_convolution(CheckIsPrepared(convolution, ring)).is_ok());
}