1use std::alloc::{Allocator, Global};
2use std::cmp::{max, min};
3use std::marker::PhantomData;
4
5use super::ConvolutionAlgorithm;
6use super::ntt::NTTConvolution;
7use crate::algorithms::miller_rabin::is_prime;
8use crate::divisibility::*;
9use crate::homomorphism::*;
10use crate::integer::*;
11use crate::lazy::LazyVec;
12use crate::primitive_int::StaticRing;
13use crate::ring::*;
14use crate::rings::zn::zn_64::{Zn, ZnBase, ZnFastmul, ZnFastmulBase};
15use crate::rings::zn::*;
16use crate::seq::*;
17
/// A [`ConvolutionAlgorithm`] for arbitrary (large) integers, based on residue number
/// systems (RNS): inputs are reduced modulo a list of NTT-friendly primes, convolved
/// modulo each prime using the base convolution `C`, and the result is reconstructed
/// via the Chinese Remainder Theorem.
///
/// The list of primes grows lazily on demand, so a single `RNSConvolution` can serve
/// inputs of very different sizes. Each prime `p` satisfies
/// `p = 1 mod 2^required_root_of_unity_log2`, so `Z/pZ` contains the power-of-two
/// roots of unity needed by NTT-based base convolutions.
#[stability::unstable(feature = "enable")]
pub struct RNSConvolution<
    I = BigIntRing,
    C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>,
    A = Global,
    CreateC = CreateNTTConvolution,
> where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
{
    /// The integer ring used for lifting and CRT reconstruction.
    integer_ring: I,
    /// Lazily-grown list of RNS rings; the ring at index `i` has `i + 1` prime factors,
    /// and each ring extends the previous one by one further (smaller) prime.
    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
    /// Lazily-created base convolution for each RNS prime factor.
    convolutions: LazyVec<C>,
    /// Factory used to create the base convolution for a newly sampled prime.
    create_convolution: CreateC,
    /// All primes `p` are sampled with `p = 1 mod 2^required_root_of_unity_log2`.
    required_root_of_unity_log2: usize,
    /// Allocator for temporary and result buffers.
    allocator: A,
}
45
/// Wrapper around [`RNSConvolution`] that provides a [`ConvolutionAlgorithm`] for
/// rings `Z/nZ` (via [`ZnRing`]) instead of integer rings: elements are lifted to
/// the integers, convolved there, and reduced back.
///
/// Note: the `#[repr(transparent)]` is load-bearing — the `From<&RNSConvolution>`
/// impl below transmutes references between the two types.
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct RNSConvolutionZn<
    I = BigIntRing,
    C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>,
    A = Global,
    CreateC = CreateNTTConvolution,
> where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
{
    base: RNSConvolution<I, C, A, CreateC>,
}
63
/// A prepared operand of an [`RNSConvolution`] (resp. [`RNSConvolutionZn`]).
///
/// Stores, for each RNS prime factor, the prepared operand of the base convolution;
/// these are created lazily on first use, since the number of required primes is
/// only known once the actual convolution is computed.
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>>
where
    R: ?Sized + RingBase,
    C: ConvolutionAlgorithm<ZnBase>,
{
    /// Per-prime prepared operands of the base convolution, filled lazily.
    prepared: LazyVec<C::PreparedConvolutionOperand>,
    /// Upper bound (in bits) on the absolute value of the operand's entries,
    /// recorded at preparation time.
    log2_data_size: usize,
    /// The ring the operand belongs to; only used to tie the type parameter.
    ring: PhantomData<R>,
    /// Optional hint on the expected output length, forwarded to the base convolution.
    len_hint: Option<usize>,
}
76
/// Function object implementing `Fn(Zn) -> NTTConvolution<..>`; used as the default
/// convolution factory of [`RNSConvolution`] (a named type is required so it can
/// appear as the `CreateC` default type parameter).
#[stability::unstable(feature = "enable")]
pub struct CreateNTTConvolution<A = Global>
where
    A: Allocator + Clone,
{
    /// Allocator handed to every created [`NTTConvolution`].
    allocator: A,
}
85
86impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
87where
88 I: RingStore + Clone,
89 I::Type: IntegerRing,
90 C: ConvolutionAlgorithm<ZnBase>,
91 A: Allocator + Clone,
92 CreateC: Fn(Zn) -> C,
93{
94 fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self { value.base }
95}
96
97impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
98where
99 I: RingStore + Clone,
100 I::Type: IntegerRing,
101 C: ConvolutionAlgorithm<ZnBase>,
102 A: Allocator + Clone,
103 CreateC: Fn(Zn) -> C,
104{
105 fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self { &value.base }
106}
107
108impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
109where
110 I: RingStore + Clone,
111 I::Type: IntegerRing,
112 C: ConvolutionAlgorithm<ZnBase>,
113 A: Allocator + Clone,
114 CreateC: Fn(Zn) -> C,
115{
116 fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self { RNSConvolutionZn { base: value } }
117}
118
impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
{
    /// Reinterprets a reference to an [`RNSConvolution`] as a reference to its
    /// newtype wrapper [`RNSConvolutionZn`].
    // SAFETY: `RNSConvolutionZn` is `#[repr(transparent)]` over its single field
    // `base: RNSConvolution<I, C, A, CreateC>`, so both types have identical layout
    // and a shared reference to one may be reinterpreted as a shared reference to
    // the other.
    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self { unsafe { std::mem::transmute(value) } }
}
129
130impl CreateNTTConvolution<Global> {
131 #[stability::unstable(feature = "enable")]
133 pub const fn new() -> Self { Self { allocator: Global } }
134}
135
/// Unboxed-closure impl so that `CreateNTTConvolution` satisfies the
/// `CreateC: Fn(Zn) -> C` bound of [`RNSConvolution`]; delegates to [`Fn::call`].
impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
where
    A: Allocator + Clone,
{
    type Output = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>, A>;

    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output { self.call(args) }
}
144
/// See the [`FnOnce`] impl above; delegates to [`Fn::call`].
impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
where
    A: Allocator + Clone,
{
    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output { self.call(args) }
}
151
152impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
153where
154 A: Allocator + Clone,
155{
156 extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
157 let ring = args.0;
158 let ring_fastmul = ZnFastmul::new(ring).unwrap();
159 let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
160 NTTConvolution::new_with_hom(hom, self.allocator.clone())
161 }
162}
163
164impl RNSConvolution {
165 #[stability::unstable(feature = "enable")]
168 pub fn new(max_log2_n: usize) -> Self {
169 Self::new_with_convolution(
170 max_log2_n,
171 usize::MAX,
172 BigIntRing::RING,
173 Global,
174 CreateNTTConvolution { allocator: Global },
175 )
176 }
177}
178
impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
{
    /// Creates a new [`RNSConvolution`] from the given components.
    ///
    /// All sampled primes `p` satisfy `p = 1 mod 2^required_root_of_unity_log2`
    /// and `p < 2^max_prime_size_log2` (the latter is clamped to 57, since primes
    /// are sampled as `i64`s). `create_convolution` is invoked once per prime to
    /// build the base convolution modulo that prime.
    #[stability::unstable(feature = "enable")]
    pub fn new_with_convolution(
        required_root_of_unity_log2: usize,
        mut max_prime_size_log2: usize,
        integer_ring: I,
        allocator: A,
        create_convolution: CreateC,
    ) -> Self {
        // primes are sampled within `i64`, so cap the prime size at 57 bits
        max_prime_size_log2 = min(max_prime_size_log2, 57);
        let result = Self {
            integer_ring,
            create_convolution,
            convolutions: LazyVec::new(),
            rns_rings: LazyVec::new(),
            required_root_of_unity_log2,
            allocator,
        };
        // seed the lazy list of RNS rings with a single-prime ring; larger rings
        // are derived from their predecessor in `get_rns_ring()`
        let initial_ring = zn_rns::Zn::new_with_alloc(
            vec![Zn::new(
                Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64,
            )],
            result.integer_ring.clone(),
            result.allocator.clone(),
        );
        _ = result.rns_rings.get_or_init(0, || initial_ring);
        return result;
    }

    /// Returns the largest prime `p < current` with
    /// `p = 1 mod 2^required_root_of_unity_log2`, or `None` if none exists.
    ///
    /// Note: `current - 1` must be divisible by `2^required_root_of_unity_log2`,
    /// otherwise the `checked_div(..).unwrap()` panics; this holds for all calls
    /// in this file, since `current` is either `2^max_prime_size_log2 + 1` or a
    /// previously sampled prime.
    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
        let mut k = StaticRing::<i64>::RING
            .checked_div(&(current - 1), &(1 << required_root_of_unity_log2))
            .unwrap();
        // search downwards through candidates of the form `k * 2^log2 + 1`
        while k > 0 {
            k -= 1;
            let candidate = (k << required_root_of_unity_log2) + 1;
            // probabilistic primality check with 10 rounds of Miller-Rabin
            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
                return Some(candidate);
            }
        }
        return None;
    }

    /// Returns the RNS ring with `moduli_count` prime factors, lazily creating
    /// it (and all smaller ones); each new ring consists of the previous ring's
    /// primes plus the next smaller suitable prime.
    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
        self.rns_rings.get_or_init_incremental(moduli_count - 1, |_, prev| {
            zn_rns::Zn::new_with_alloc(
                prev.as_iter()
                    .cloned()
                    .chain([Zn::new(
                        // sample the next prime below the smallest prime used so far
                        Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus())
                            .unwrap() as u64,
                    )])
                    .collect(),
                self.integer_ring.clone(),
                self.allocator.clone(),
            )
        })
    }

    /// Returns the `i`-th RNS prime factor `Z/p_i Z`.
    fn get_rns_factor(&self, i: usize) -> &Zn {
        let rns_ring = self.get_rns_ring(i + 1);
        return rns_ring.at(rns_ring.len() - 1);
    }

    /// Returns the base convolution modulo the `i`-th prime, creating it on first use.
    fn get_convolution(&self, i: usize) -> &C {
        self.convolutions
            .get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
    }

    /// Computes how many RNS primes are needed so that the product of their moduli
    /// exceeds the largest possible result value: each output coefficient is a sum
    /// of at most `min(lhs_len, rhs_len)` products of inputs (each bounded by
    /// `2^(2 * input_size_log2)`), accumulated over at most `inner_prod_len`
    /// convolutions; the final `+ 1` accounts for the sign.
    fn compute_required_width(
        &self,
        input_size_log2: usize,
        lhs_len: usize,
        rhs_len: usize,
        inner_prod_len: usize,
    ) -> usize {
        let log2_output_size = input_size_log2 * 2
            + StaticRing::<i64>::RING
                .abs_log2_ceil(&min(lhs_len, rhs_len).try_into().unwrap())
                .unwrap_or(0)
            + StaticRing::<i64>::RING
                .abs_log2_ceil(&inner_prod_len.try_into().unwrap())
                .unwrap_or(0)
            + 1;
        // initial guess assuming maximal-size (57 bit) primes, then grow until the
        // RNS modulus is actually large enough
        let mut width = log2_output_size.div_ceil(57);
        while log2_output_size
            > self
                .integer_ring
                .abs_log2_floor(self.get_rns_ring(width).modulus())
                .unwrap()
        {
            width += 1;
        }
        return width;
    }

    /// Determines an upper bound (in bits) on the absolute value of the input
    /// entries. If the ring provides a global bound (`ring_log2_el_size`), that is
    /// used (and checked for consistency against any prepared operands); otherwise
    /// the bound is the maximum over both operands, taking the size recorded at
    /// preparation time where available.
    fn get_log2_input_size<R, V1, V2, ToInt>(
        &self,
        lhs: V1,
        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        rhs: V2,
        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        _ring: &R,
        mut to_int: ToInt,
        ring_log2_el_size: Option<usize>,
    ) -> usize
    where
        R: ?Sized + RingBase,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        ToInt: FnMut(&R::Element) -> El<I>,
    {
        if let Some(log2_data_size) = ring_log2_el_size {
            assert!(lhs_prep.is_none() || lhs_prep.unwrap().log2_data_size == log2_data_size);
            assert!(rhs_prep.is_none() || rhs_prep.unwrap().log2_data_size == log2_data_size);
            log2_data_size
        } else {
            max(
                if let Some(lhs_prep) = lhs_prep {
                    lhs_prep.log2_data_size
                } else {
                    lhs.as_iter()
                        .map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0))
                        .max()
                        .unwrap()
                },
                if let Some(rhs_prep) = rhs_prep {
                    rhs_prep.log2_data_size
                } else {
                    rhs.as_iter()
                        .map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0))
                        .max()
                        .unwrap()
                },
            )
        }
    }

    /// Returns the base convolution's prepared operand for the `rns_index`-th prime,
    /// preparing it on first use from the already-reduced data `data`.
    fn get_prepared_operand<'a, R, V>(
        &self,
        data: V,
        data_prep: &'a PreparedConvolutionOperand<R, C>,
        rns_index: usize,
        _ring: &R,
    ) -> &'a C::PreparedConvolutionOperand
    where
        R: ?Sized + RingBase,
        V: VectorView<El<Zn>> + Copy,
    {
        data_prep.prepared.get_or_init(rns_index, || {
            self.get_convolution(rns_index).prepare_convolution_operand(
                data,
                data_prep.len_hint,
                self.get_rns_factor(rns_index),
            )
        })
    }

    /// Computes the convolution of `lhs` and `rhs` and adds the result to `dst`.
    ///
    /// Elements are mapped to integers via `to_int`, reduced modulo each RNS prime,
    /// convolved per prime, CRT-reconstructed (`smallest_lift`) and mapped back via
    /// `from_int`.
    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
        &self,
        lhs: V1,
        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        rhs: V2,
        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        dst: &mut [R::Element],
        ring: &R,
        mut to_int: ToInt,
        mut from_int: FromInt,
        ring_log2_el_size: Option<usize>,
    ) where
        R: ?Sized + RingBase,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        ToInt: FnMut(&R::Element) -> El<I>,
        FromInt: FnMut(El<I>) -> R::Element,
    {
        // a convolution with an empty operand is zero, i.e. `dst` stays unchanged
        if lhs.len() == 0 || rhs.len() == 0 {
            return;
        }

        let input_size_log2 =
            self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size);
        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
        let len = lhs.len() + rhs.len() - 1;

        // `res_data` holds `width` consecutive rows of length `len`, one per prime
        let mut res_data = Vec::with_capacity_in(len * width, self.allocator.clone());
        for i in 0..width {
            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
        }
        let mut lhs_tmp = Vec::with_capacity_in(lhs.len(), self.allocator.clone());
        let mut rhs_tmp = Vec::with_capacity_in(rhs.len(), self.allocator.clone());
        for i in 0..width {
            // reduce both operands modulo the i-th prime, then convolve there
            let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
            lhs_tmp.clear();
            lhs_tmp.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
            rhs_tmp.clear();
            rhs_tmp.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
            self.get_convolution(i).compute_convolution_prepared(
                &lhs_tmp,
                lhs_prep.map(|lhs_prep| self.get_prepared_operand(&lhs_tmp, lhs_prep, i, ring)),
                &rhs_tmp,
                rhs_prep.map(|rhs_prep| self.get_prepared_operand(&rhs_tmp, rhs_prep, i, ring)),
                &mut res_data[(i * len)..((i + 1) * len)],
                self.get_rns_factor(i),
            );
        }
        // CRT-reconstruct each output coefficient and add it into `dst`
        for j in 0..len {
            let add = self.get_rns_ring(width).smallest_lift(
                self.get_rns_ring(width)
                    .from_congruence((0..width).map(|i| res_data[i * len + j])),
            );
            ring.add_assign(&mut dst[j], from_int(add));
        }
    }

    /// Computes the sum of the convolutions of all pairs in `values` and adds the
    /// result to `dst`.
    ///
    /// Pairs are accumulated at the RNS width currently required; whenever a later
    /// pair forces a larger width, the pairs collected so far are flushed via
    /// `merge_current` (computed, CRT-lifted and added into `dst`) and accumulation
    /// restarts at the new width. Each entry of `lhs_tmp`/`rhs_tmp` stores the
    /// residues of one operand for all current primes, concatenated segment-wise.
    fn compute_convolution_sum_impl<'a, R, J, V1, V2, ToInt, FromInt>(
        &self,
        values: J,
        dst: &mut [R::Element],
        ring: &R,
        mut to_int: ToInt,
        mut from_int: FromInt,
        ring_log2_el_size: Option<usize>,
    ) where
        R: ?Sized + RingBase,
        J: ExactSizeIterator<
            Item = (
                V1,
                Option<&'a PreparedConvolutionOperand<R, C>>,
                V2,
                Option<&'a PreparedConvolutionOperand<R, C>>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        ToInt: FnMut(&R::Element) -> El<I>,
        FromInt: FnMut(El<I>) -> R::Element,
        Self: 'a,
        R: 'a,
    {
        let out_len = dst.len();
        // bound on the number of accumulated convolutions, used for width estimation
        let inner_product_length = dst.len();

        let mut current_width = 0;
        let mut current_input_size_log2 = 0;
        let mut lhs_max_len = 0;
        let mut rhs_max_len = 0;
        let mut res_data = Vec::new_in(self.allocator.clone());
        let mut lhs_tmp = Vec::new_in(self.allocator.clone());
        let mut rhs_tmp = Vec::new_in(self.allocator.clone());

        // flushes all pairs accumulated so far: computes their summed convolution
        // modulo each of the `current_width` primes, CRT-lifts the result and adds
        // it to `dst`, then clears the accumulation buffers
        let mut merge_current =
            |current_width: usize,
             lhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>,
             rhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>| {
                if current_width == 0 {
                    lhs_tmp.clear();
                    rhs_tmp.clear();
                    return;
                }
                res_data.clear();
                for i in 0..current_width {
                    res_data.extend((0..out_len).map(|_| self.get_rns_factor(i).zero()));
                    self.get_convolution(i).compute_convolution_sum(
                        lhs_tmp
                            .iter()
                            .zip(rhs_tmp.iter())
                            .map(|((lhs, lhs_prep), (rhs, rhs_prep))| {
                                // each stored vector is the concatenation of
                                // `current_width` equally long residue segments;
                                // take the segment belonging to the i-th prime
                                let lhs_data =
                                    &lhs[(i * lhs.len() / current_width)..((i + 1) * lhs.len() / current_width)];
                                let rhs_data =
                                    &rhs[(i * rhs.len() / current_width)..((i + 1) * rhs.len() / current_width)];
                                (
                                    lhs_data,
                                    lhs_prep.map(|lhs_prep| self.get_prepared_operand(lhs_data, lhs_prep, i, ring)),
                                    rhs_data,
                                    rhs_prep.map(|rhs_prep| self.get_prepared_operand(rhs_data, rhs_prep, i, ring)),
                                )
                            }),
                        &mut res_data[(i * out_len)..((i + 1) * out_len)],
                        self.get_rns_factor(i),
                    );
                }
                lhs_tmp.clear();
                rhs_tmp.clear();
                // CRT-reconstruct each output coefficient and add it into `dst`
                for j in 0..out_len {
                    let add = self.get_rns_ring(current_width).smallest_lift(
                        self.get_rns_ring(current_width)
                            .from_congruence((0..current_width).map(|i| res_data[i * out_len + j])),
                    );
                    ring.add_assign(&mut dst[j], from_int(add));
                }
            };

        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            // empty pairs contribute nothing
            if lhs.len() == 0 || rhs.len() == 0 {
                continue;
            }
            assert!(out_len >= lhs.len() + rhs.len() - 1);
            current_input_size_log2 = max(
                current_input_size_log2,
                self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size),
            );
            lhs_max_len = max(lhs_max_len, lhs.len());
            rhs_max_len = max(rhs_max_len, rhs.len());
            let required_width =
                self.compute_required_width(current_input_size_log2, lhs_max_len, rhs_max_len, inner_product_length);

            // flush everything accumulated at the old (smaller) width before switching
            if required_width > current_width {
                merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
                current_width = required_width;
            }

            // store the residues of this pair for all `current_width` primes
            lhs_tmp.push((
                Vec::with_capacity_in(lhs.len() * current_width, self.allocator.clone()),
                lhs_prep,
            ));
            rhs_tmp.push((
                Vec::with_capacity_in(rhs.len() * current_width, self.allocator.clone()),
                rhs_prep,
            ));
            for i in 0..current_width {
                let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
                lhs_tmp
                    .last_mut()
                    .unwrap()
                    .0
                    .extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
                rhs_tmp
                    .last_mut()
                    .unwrap()
                    .0
                    .extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
            }
        }
        // flush the remaining accumulated pairs
        merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
    }

    /// Creates a [`PreparedConvolutionOperand`] for `data`: records an upper bound
    /// on the entries' bit size and sets up the lazily-filled per-prime storage;
    /// the actual per-prime preparation happens on first use in
    /// `get_prepared_operand()`.
    fn prepare_convolution_impl<R, V, ToInt>(
        &self,
        data: V,
        _ring: &R,
        length_hint: Option<usize>,
        mut to_int: ToInt,
        ring_log2_el_size: Option<usize>,
    ) -> PreparedConvolutionOperand<R, C>
    where
        R: ?Sized + RingBase,
        V: VectorView<R::Element>,
        ToInt: FnMut(&R::Element) -> El<I>,
    {
        let input_size_log2 = if let Some(log2_data_size) = ring_log2_el_size {
            log2_data_size
        } else {
            data.as_iter()
                .map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0))
                .max()
                .unwrap()
        };
        return PreparedConvolutionOperand {
            ring: PhantomData,
            len_hint: length_hint,
            prepared: LazyVec::new(),
            log2_data_size: input_size_log2,
        };
    }
}
573
/// [`ConvolutionAlgorithm`] over arbitrary integer rings: elements are converted to
/// and from the internal integer ring via [`int_cast()`]; `ring_log2_el_size` is
/// always `None`, since integers carry no a-priori size bound.
impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
    R: ?Sized + IntegerRing,
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;

    fn compute_convolution<
        S: RingStore<Type = R> + Copy,
        V1: VectorView<<R as RingBase>::Element>,
        V2: VectorView<<R as RingBase>::Element>,
    >(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [<R as RingBase>::Element],
        ring: S,
    ) {
        // unprepared convolution: delegate with no prepared operands
        self.compute_convolution_impl(
            lhs,
            None,
            rhs,
            None,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            |x| int_cast(x, ring, &self.integer_ring),
            None,
        )
    }

    // any integer ring is supported
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }

    fn prepare_convolution_operand<S, V>(
        &self,
        val: V,
        len_hint: Option<usize>,
        ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>,
    {
        self.prepare_convolution_impl(
            val,
            ring.get_ring(),
            len_hint,
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            None,
        )
    }

    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R::Element],
        ring: S,
    ) where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<El<S>>,
        V2: VectorView<El<S>>,
    {
        self.compute_convolution_impl(
            lhs,
            lhs_prep,
            rhs,
            rhs_prep,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            |x| int_cast(x, ring, &self.integer_ring),
            None,
        )
    }

    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        J: ExactSizeIterator<
            Item = (
                V1,
                Option<&'a Self::PreparedConvolutionOperand>,
                V2,
                Option<&'a Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        Self: 'a,
        R: 'a,
    {
        self.compute_convolution_sum_impl(
            values,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            |x| int_cast(x, ring, &self.integer_ring),
            None,
        )
    }
}
682
/// [`ConvolutionAlgorithm`] over rings `Z/nZ`: elements are lifted to the integers
/// via [`ZnRing::smallest_lift()`], convolved there by the wrapped [`RNSConvolution`],
/// and mapped back through the canonical homomorphism. The bit length of the ring's
/// modulus is passed as the uniform element-size bound (`ring_log2_el_size`), so the
/// inputs never need to be scanned for their size.
impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
    R: ?Sized + ZnRing + CanHomFrom<I::Type>,
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;

    fn compute_convolution<
        S: RingStore<Type = R> + Copy,
        V1: VectorView<<R as RingBase>::Element>,
        V2: VectorView<<R as RingBase>::Element>,
    >(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [<R as RingBase>::Element],
        ring: S,
    ) {
        // maps lifted integer results back into `Z/nZ`
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        self.base.compute_convolution_impl(
            lhs,
            None,
            rhs,
            None,
            dst,
            ring.get_ring(),
            |x| {
                int_cast(
                    ring.smallest_lift(ring.clone_el(x)),
                    &self.base.integer_ring,
                    ring.integer_ring(),
                )
            },
            |x| hom.map(x),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
        )
    }

    // any `Z/nZ` ring with a homomorphism from the integer ring is supported
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }

    fn prepare_convolution_operand<S, V>(
        &self,
        val: V,
        len_hint: Option<usize>,
        ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>,
    {
        self.base.prepare_convolution_impl(
            val,
            ring.get_ring(),
            len_hint,
            |x| {
                int_cast(
                    ring.smallest_lift(ring.clone_el(x)),
                    &self.base.integer_ring,
                    ring.integer_ring(),
                )
            },
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
        )
    }

    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R::Element],
        ring: S,
    ) where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<El<S>>,
        V2: VectorView<El<S>>,
    {
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        self.base.compute_convolution_impl(
            lhs,
            lhs_prep,
            rhs,
            rhs_prep,
            dst,
            ring.get_ring(),
            |x| {
                int_cast(
                    ring.smallest_lift(ring.clone_el(x)),
                    &self.base.integer_ring,
                    ring.integer_ring(),
                )
            },
            |x| hom.map(x),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
        )
    }

    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        J: ExactSizeIterator<
            Item = (
                V1,
                Option<&'a Self::PreparedConvolutionOperand>,
                V2,
                Option<&'a Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        Self: 'a,
        R: 'a,
    {
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        self.base.compute_convolution_sum_impl(
            values,
            dst,
            ring.get_ring(),
            |x| {
                int_cast(
                    ring.smallest_lift(ring.clone_el(x)),
                    &self.base.integer_ring,
                    ring.integer_ring(),
                )
            },
            |x| hom.map(x),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
        )
    }
}
818
819#[cfg(test)]
820use super::STANDARD_CONVOLUTION;
821
822#[test]
823fn test_convolution_integer() {
824 let ring = StaticRing::<i128>::RING;
825 let convolution =
826 RNSConvolution::new_with_convolution(7, usize::MAX, BigIntRing::RING, Global, NTTConvolution::new);
827
828 super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
829}
830
831#[test]
832fn test_convolution_zn() {
833 let ring = Zn::new((1 << 57) + 1);
834 let convolution = RNSConvolutionZn::from(RNSConvolution::new_with_convolution(
835 7,
836 usize::MAX,
837 BigIntRing::RING,
838 Global,
839 NTTConvolution::new,
840 ));
841
842 super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
843}
844
#[test]
fn test_convolution_sum() {
    let ring = StaticRing::<i128>::RING;
    // small max prime size (20 bits) forces the RNS width to grow while pairs are
    // accumulated, exercising the incremental merge logic
    let convolution = RNSConvolution::new_with_convolution(7, 20, BigIntRing::RING, Global, NTTConvolution::new);

    // 40 pairs with lengths 5..=9 and 7..=13 and entries scaled by 2^i, so input
    // sizes (and hence the required width) increase over the iteration;
    // max output length is 9 + 13 - 1 = 21
    let data = (0..40usize).map(|i| {
        (
            (0..(5 + i % 5)).map(|x| (1 << i) * (x as i128 - 2)).collect::<Vec<_>>(),
            (0..(13 - i % 7))
                .map(|x| (1 << i) * (x as i128 + 1))
                .collect::<Vec<_>>(),
        )
    });
    // reference result from the standard convolution; one spare slot, only the
    // first 21 entries are compared below
    let mut expected = (0..22).map(|_| 0).collect::<Vec<_>>();
    STANDARD_CONVOLUTION.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut expected, ring);

    // without prepared operands
    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
    convolution.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut actual, ring);
    assert_eq!(&expected[..21], actual);

    // with all operands prepared
    let data_prep = data
        .clone()
        .map(|(l, r)| {
            let l_prep = convolution.prepare_convolution_operand(&l, Some(21), ring);
            let r_prep = convolution.prepare_convolution_operand(&r, Some(21), ring);
            (l, l_prep, r, r_prep)
        })
        .collect::<Vec<_>>();
    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
    convolution.compute_convolution_sum(
        data_prep
            .iter()
            .map(|(l, l_prep, r, r_prep)| (l, Some(l_prep), r, Some(r_prep))),
        &mut actual,
        ring,
    );
    assert_eq!(&expected[..21], actual);

    // with a mix of prepared and unprepared operands
    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
    convolution.compute_convolution_sum(
        data_prep
            .iter()
            .enumerate()
            .map(|(i, (l, l_prep, r, r_prep))| match i % 4 {
                0 => (l, Some(l_prep), r, Some(r_prep)),
                1 => (l, None, r, Some(r_prep)),
                2 => (l, Some(l_prep), r, None),
                3 => (l, None, r, None),
                _ => unreachable!(),
            }),
            &mut actual,
            ring,
    );
    assert_eq!(&expected[..21], actual);
}