1use std::alloc::{Allocator, Global};
2use std::cmp::{min, max};
3use std::marker::PhantomData;
4
5use crate::algorithms::miller_rabin::is_prime;
6use crate::homomorphism::*;
7use crate::integer::*;
8use crate::lazy::LazyVec;
9use crate::primitive_int::StaticRing;
10use crate::ring::*;
11use crate::rings::zn::zn_64::{Zn, ZnBase, ZnFastmul, ZnFastmulBase};
12use crate::rings::zn::*;
13use crate::divisibility::*;
14use crate::seq::*;
15
16use super::ntt::NTTConvolution;
17use super::ConvolutionAlgorithm;
18
#[stability::unstable(feature = "enable")]
///
/// Convolution algorithm for integer rings, based on a residue number system (RNS).
///
/// The inputs are reduced modulo a sequence of primes `p_0 > p_1 > ...`, each of the form
/// `k * 2^required_root_of_unity_log2 + 1`. The convolution is then computed over every
/// `Z/p_iZ` using a convolution algorithm of type `C` (created by `CreateC`), and the result
/// is reconstructed via the Chinese Remainder Theorem. The number of primes used is chosen
/// per call, based on a bound on the size of the output. Primes, RNS rings and the per-prime
/// convolution algorithms are only created when first needed.
///
pub struct RNSConvolution<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    /// integer ring in which values are represented during CRT reconstruction
    integer_ring: I,
    /// `rns_rings[i]` is the RNS ring with moduli `p_0, ..., p_i`
    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
    /// `convolutions[i]` is the convolution algorithm used over `Z/p_iZ`
    convolutions: LazyVec<C>,
    /// creates the convolution algorithm for a given prime field `Z/p_iZ`
    create_convolution: CreateC,
    /// every prime `p_i` satisfies `p_i = 1 mod 2^required_root_of_unity_log2`
    required_root_of_unity_log2: usize,
    /// allocator for temporary buffers and the RNS rings
    allocator: A
}
43
#[stability::unstable(feature = "enable")]
///
/// Same as [`RNSConvolution`], but computes convolutions over rings `Z/nZ` (implementing
/// [`ZnRing`]) instead of over integer rings.
///
/// Elements are lifted to their smallest representatives in `Z`, convolved using the wrapped
/// [`RNSConvolution`], and mapped back to `Z/nZ`. The type is `#[repr(transparent)]` over
/// [`RNSConvolution`], so references can be converted between both types at no cost.
///
#[repr(transparent)]
pub struct RNSConvolutionZn<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    /// the wrapped integer convolution; must remain the only non-zero-sized field
    /// because of `#[repr(transparent)]`
    base: RNSConvolution<I, C, A, CreateC>
}
58
#[stability::unstable(feature = "enable")]
///
/// Prepared convolution operand for [`RNSConvolution`] and [`RNSConvolutionZn`].
///
/// Stores a bound on the bit size of the operand's (lifted) entries, and caches, for each
/// prime `p_i` of the RNS basis, the operand as prepared by the convolution algorithm
/// over `Z/p_iZ`. The per-prime data is only computed when first needed.
///
pub struct PreparedConvolutionOperand<R, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>>
    where R: ?Sized + RingBase,
        C: ConvolutionAlgorithm<ZnBase>
{
    /// `prepared[i]` is the operand (reduced modulo `p_i`) prepared by the `i`-th convolution algorithm
    prepared: LazyVec<C::PreparedConvolutionOperand>,
    /// upper bound on `log2` of the absolute values of the (lifted) entries
    log2_data_size: usize,
    ring: PhantomData<R>,
    /// length hint that is forwarded to the per-prime convolution algorithms
    len_hint: Option<usize>
}
72
#[stability::unstable(feature = "enable")]
///
/// Function object that creates an [`NTTConvolution`] over a given [`Zn`], using a clone
/// of the stored allocator.
///
/// This is the default way in which [`RNSConvolution`] creates the convolution algorithms
/// for the primes of its RNS basis.
///
pub struct CreateNTTConvolution<A = Global>
    where A: Allocator + Clone
{
    /// allocator passed on to every created [`NTTConvolution`]
    allocator: A
}
82
83impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
84 where I: RingStore + Clone,
85 I::Type: IntegerRing,
86 C: ConvolutionAlgorithm<ZnBase>,
87 A: Allocator + Clone,
88 CreateC: Fn(Zn) -> C
89{
90 fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self {
91 value.base
92 }
93}
94
95impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
96 where I: RingStore + Clone,
97 I::Type: IntegerRing,
98 C: ConvolutionAlgorithm<ZnBase>,
99 A: Allocator + Clone,
100 CreateC: Fn(Zn) -> C
101{
102 fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self {
103 &value.base
104 }
105}
106
107impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
108 where I: RingStore + Clone,
109 I::Type: IntegerRing,
110 C: ConvolutionAlgorithm<ZnBase>,
111 A: Allocator + Clone,
112 CreateC: Fn(Zn) -> C
113{
114 fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self {
115 RNSConvolutionZn { base: value }
116 }
117}
118
119impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
120 where I: RingStore + Clone,
121 I::Type: IntegerRing,
122 C: ConvolutionAlgorithm<ZnBase>,
123 A: Allocator + Clone,
124 CreateC: Fn(Zn) -> C
125{
126 fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self {
127 unsafe { std::mem::transmute(value) }
128 }
129}
130
131impl CreateNTTConvolution<Global> {
132
133 #[stability::unstable(feature = "enable")]
137 pub const fn new() -> Self {
138 Self { allocator: Global }
139 }
140}
141
142impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
143 where A: Allocator + Clone
144{
145 type Output = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>, A>;
146
147 extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output {
148 self.call(args)
149 }
150}
151
152impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
153 where A: Allocator + Clone
154{
155 extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output {
156 self.call(args)
157 }
158}
159
160impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
161 where A: Allocator + Clone
162{
163 extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
164 let ring = args.0;
165 let ring_fastmul = ZnFastmul::new(ring).unwrap();
166 let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
167 NTTConvolution::new_with_hom(hom, self.allocator.clone())
168 }
169}
170
171impl RNSConvolution {
172
173 #[stability::unstable(feature = "enable")]
178 pub fn new(max_log2_n: usize) -> Self {
179 Self::new_with_convolution(max_log2_n, usize::MAX, BigIntRing::RING, Global, CreateNTTConvolution { allocator: Global })
180 }
181}
182
183impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
184 where I: RingStore + Clone,
185 I::Type: IntegerRing,
186 C: ConvolutionAlgorithm<ZnBase>,
187 A: Allocator + Clone,
188 CreateC: Fn(Zn) -> C
189{
190 #[stability::unstable(feature = "enable")]
205 pub fn new_with_convolution(required_root_of_unity_log2: usize, mut max_prime_size_log2: usize, integer_ring: I, allocator: A, create_convolution: CreateC) -> Self {
206 max_prime_size_log2 = min(max_prime_size_log2, 57);
207 let result = Self {
208 integer_ring: integer_ring,
209 create_convolution: create_convolution,
210 convolutions: LazyVec::new(),
211 rns_rings: LazyVec::new(),
212 required_root_of_unity_log2: required_root_of_unity_log2,
213 allocator: allocator
214 };
215 let initial_ring = zn_rns::Zn::new_with_alloc(
216 vec![Zn::new(Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64)],
217 result.integer_ring.clone(),
218 result.allocator.clone()
219 );
220 _ = result.rns_rings.get_or_init(0, || initial_ring);
221 return result;
222 }
223
224 fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
225 let mut k = StaticRing::<i64>::RING.checked_div(&(current - 1), &(1 << required_root_of_unity_log2)).unwrap();
226 while k > 0 {
227 k -= 1;
228 let candidate = (k << required_root_of_unity_log2) + 1;
229 if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
230 return Some(candidate);
231 }
232 }
233 return None;
234 }
235
236 fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
237 self.rns_rings.get_or_init_incremental(moduli_count - 1, |_, prev| zn_rns::Zn::new_with_alloc(
238 prev.as_iter().cloned().chain([Zn::new(Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus()).unwrap() as u64)]).collect(),
239 self.integer_ring.clone(),
240 self.allocator.clone()
241 ))
242 }
243
244 fn get_rns_factor(&self, i: usize) -> &Zn {
245 let rns_ring = self.get_rns_ring(i + 1);
246 return rns_ring.at(rns_ring.len() - 1);
247 }
248
249 fn get_convolution(&self, i: usize) -> &C {
250 self.convolutions.get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
251 }
252
253 fn compute_required_width(&self, input_size_log2: usize, lhs_len: usize, rhs_len: usize, inner_prod_len: usize) -> usize {
257 let log2_output_size = input_size_log2 * 2 +
258 StaticRing::<i64>::RING.abs_log2_ceil(&min(lhs_len, rhs_len).try_into().unwrap()).unwrap_or(0) +
259 StaticRing::<i64>::RING.abs_log2_ceil(&inner_prod_len.try_into().unwrap()).unwrap_or(0) +
260 1;
261 let mut width = log2_output_size.div_ceil(57);
262 while log2_output_size > self.integer_ring.abs_log2_floor(self.get_rns_ring(width).modulus()).unwrap() {
263 width += 1;
264 }
265 return width;
266 }
267
268 fn get_log2_input_size<R, V1, V2, ToInt>(
269 &self,
270 lhs: V1,
271 lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
272 rhs: V2,
273 rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
274 _ring: &R,
275 mut to_int: ToInt,
276 ring_log2_el_size: Option<usize>
277 ) -> usize
278 where R: ?Sized + RingBase,
279 V1: VectorView<R::Element>,
280 V2: VectorView<R::Element>,
281 ToInt: FnMut(&R::Element) -> El<I>,
282 {
283 if let Some(log2_data_size) = ring_log2_el_size {
284 assert!(lhs_prep.is_none() || lhs_prep.unwrap().log2_data_size == log2_data_size);
285 assert!(rhs_prep.is_none() || rhs_prep.unwrap().log2_data_size == log2_data_size);
286 log2_data_size
287 } else {
288 max(
289 if let Some(lhs_prep) = lhs_prep {
290 lhs_prep.log2_data_size
291 } else {
292 lhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
293 },
294 if let Some(rhs_prep) = rhs_prep {
295 rhs_prep.log2_data_size
296 } else {
297 rhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
298 }
299 )
300 }
301 }
302
303 fn get_prepared_operand<'a, R, V>(
304 &self,
305 data: V,
306 data_prep: &'a PreparedConvolutionOperand<R, C>,
307 rns_index: usize,
308 _ring: &R
309 ) -> &'a C::PreparedConvolutionOperand
310 where R: ?Sized + RingBase,
311 V: VectorView<El<Zn>> + Copy
312 {
313 data_prep.prepared.get_or_init(rns_index, || self.get_convolution(rns_index).prepare_convolution_operand(data, data_prep.len_hint, self.get_rns_factor(rns_index)))
314 }
315
316 fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
317 &self,
318 lhs: V1,
319 lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
320 rhs: V2,
321 rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
322 dst: &mut [R::Element],
323 ring: &R,
324 mut to_int: ToInt,
325 mut from_int: FromInt,
326 ring_log2_el_size: Option<usize>
327 )
328 where R: ?Sized + RingBase,
329 V1: VectorView<R::Element>,
330 V2: VectorView<R::Element>,
331 ToInt: FnMut(&R::Element) -> El<I>,
332 FromInt: FnMut(El<I>) -> R::Element
333 {
334 if lhs.len() == 0 || rhs.len() == 0 {
335 return;
336 }
337
338 let input_size_log2 = self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size);
339 let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
340 let len = lhs.len() + rhs.len() - 1;
341
342 let mut res_data = Vec::with_capacity_in(len * width, self.allocator.clone());
343 for i in 0..width {
344 res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
345 }
346 let mut lhs_tmp = Vec::with_capacity_in(lhs.len(), self.allocator.clone());
347 let mut rhs_tmp = Vec::with_capacity_in(rhs.len(), self.allocator.clone());
348 for i in 0..width {
349 let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
350 lhs_tmp.clear();
351 lhs_tmp.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
352 rhs_tmp.clear();
353 rhs_tmp.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
354 self.get_convolution(i).compute_convolution_prepared(
355 &lhs_tmp,
356 lhs_prep.map(|lhs_prep| self.get_prepared_operand(&lhs_tmp, lhs_prep, i, ring)),
357 &rhs_tmp,
358 rhs_prep.map(|rhs_prep| self.get_prepared_operand(&rhs_tmp, rhs_prep, i, ring)),
359 &mut res_data[(i * len)..((i + 1) * len)],
360 self.get_rns_factor(i)
361 );
362 }
363 for j in 0..len {
364 let add = self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j])));
365 ring.add_assign(&mut dst[j], from_int(add));
366 }
367 }
368
369 fn compute_convolution_sum_impl<'a, R, J, V1, V2, ToInt, FromInt>(
370 &self,
371 values: J,
372 dst: &mut [R::Element],
373 ring: &R,
374 mut to_int: ToInt,
375 mut from_int: FromInt,
376 ring_log2_el_size: Option<usize>
377 )
378 where R: ?Sized + RingBase,
379 J: ExactSizeIterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, C>>, V2, Option<&'a PreparedConvolutionOperand<R, C>>)>,
380 V1: VectorView<R::Element>,
381 V2: VectorView<R::Element>,
382 ToInt: FnMut(&R::Element) -> El<I>,
383 FromInt: FnMut(El<I>) -> R::Element,
384 Self: 'a,
385 R: 'a
386 {
387 let out_len = dst.len();
388 let inner_product_length = dst.len();
389
390 let mut current_width = 0;
391 let mut current_input_size_log2 = 0;
392 let mut lhs_max_len = 0;
393 let mut rhs_max_len = 0;
394 let mut res_data = Vec::new_in(self.allocator.clone());
395 let mut lhs_tmp = Vec::new_in(self.allocator.clone());
396 let mut rhs_tmp = Vec::new_in(self.allocator.clone());
397
398 let mut merge_current = |
405 current_width: usize,
406 lhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>,
407 rhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>
408 | {
409 if current_width == 0 {
410 lhs_tmp.clear();
411 rhs_tmp.clear();
412 return;
413 }
414 res_data.clear();
415 for i in 0..current_width {
416 res_data.extend((0..out_len).map(|_| self.get_rns_factor(i).zero()));
417 self.get_convolution(i).compute_convolution_sum(
418 lhs_tmp.iter().zip(rhs_tmp.iter()).map(|((lhs, lhs_prep), (rhs, rhs_prep))| {
419 let lhs_data = &lhs[(i * lhs.len() / current_width)..((i + 1) * lhs.len() / current_width)];
420 let rhs_data = &rhs[(i * rhs.len() / current_width)..((i + 1) * rhs.len() / current_width)];
421 (
422 lhs_data,
423 lhs_prep.map(|lhs_prep| self.get_prepared_operand(lhs_data, lhs_prep, i, ring)),
424 rhs_data,
425 rhs_prep.map(|rhs_prep| self.get_prepared_operand(rhs_data, rhs_prep, i, ring)),
426 )
427 }),
428 &mut res_data[(i * out_len)..((i + 1) * out_len)],
429 self.get_rns_factor(i)
430 );
431 }
432 lhs_tmp.clear();
433 rhs_tmp.clear();
434 for j in 0..out_len {
435 let add = self.get_rns_ring(current_width).smallest_lift(self.get_rns_ring(current_width).from_congruence((0..current_width).map(|i| res_data[i * out_len + j])));
436 ring.add_assign(&mut dst[j], from_int(add));
437 }
438 };
439
440 for (lhs, lhs_prep, rhs, rhs_prep) in values {
441 if lhs.len() == 0 || rhs.len() == 0 {
442 continue;
443 }
444 assert!(out_len >= lhs.len() + rhs.len() - 1);
445 current_input_size_log2 = max(
446 current_input_size_log2,
447 self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size)
448 );
449 lhs_max_len = max(lhs_max_len, lhs.len());
450 rhs_max_len = max(rhs_max_len, rhs.len());
451 let required_width = self.compute_required_width(current_input_size_log2, lhs_max_len, rhs_max_len, inner_product_length);
452
453 if required_width > current_width {
454 merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
455 current_width = required_width;
456 }
457
458 lhs_tmp.push((Vec::with_capacity_in(lhs.len() * current_width, self.allocator.clone()), lhs_prep));
459 rhs_tmp.push((Vec::with_capacity_in(rhs.len() * current_width, self.allocator.clone()), rhs_prep));
460 for i in 0..current_width {
461 let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
462 lhs_tmp.last_mut().unwrap().0.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
463 rhs_tmp.last_mut().unwrap().0.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
464 }
465 }
466 merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
467 }
468
469 fn prepare_convolution_impl<R, V, ToInt>(
470 &self,
471 data: V,
472 _ring: &R,
473 length_hint: Option<usize>,
474 mut to_int: ToInt,
475 ring_log2_el_size: Option<usize>
476 ) -> PreparedConvolutionOperand<R, C>
477 where R: ?Sized + RingBase,
478 V: VectorView<R::Element>,
479 ToInt: FnMut(&R::Element) -> El<I>
480 {
481 let input_size_log2 = if let Some(log2_data_size) = ring_log2_el_size {
482 log2_data_size
483 } else {
484 data.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
485 };
486 return PreparedConvolutionOperand {
487 ring: PhantomData,
488 len_hint: length_hint,
489 prepared: LazyVec::new(),
490 log2_data_size: input_size_log2
491 };
492 }
493}
494
495impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
496 where I: RingStore + Clone,
497 I::Type: IntegerRing,
498 C: ConvolutionAlgorithm<ZnBase>,
499 A: Allocator + Clone,
500 CreateC: Fn(Zn) -> C,
501 R: ?Sized + IntegerRing
502{
503 type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
504
505 fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
506 self.compute_convolution_impl(
507 lhs,
508 None,
509 rhs,
510 None,
511 dst,
512 ring.get_ring(),
513 |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
514 |x| int_cast(x, ring, &self.integer_ring),
515 None
516 )
517 }
518
519 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
520 true
521 }
522
523 fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
524 where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
525 {
526 self.prepare_convolution_impl(
527 val,
528 ring.get_ring(),
529 len_hint,
530 |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
531 None
532 )
533 }
534
535 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
536 where S: RingStore<Type = R> + Copy,
537 V1: VectorView<El<S>>,
538 V2: VectorView<El<S>>
539 {
540 self.compute_convolution_impl(
541 lhs,
542 lhs_prep,
543 rhs,
544 rhs_prep,
545 dst,
546 ring.get_ring(),
547 |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
548 |x| int_cast(x, ring, &self.integer_ring),
549 None
550 )
551 }
552
553 fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
554 where S: RingStore<Type = R> + Copy,
555 J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
556 V1: VectorView<R::Element>,
557 V2: VectorView<R::Element>,
558 Self: 'a,
559 R: 'a
560 {
561 self.compute_convolution_sum_impl(
562 values,
563 dst,
564 ring.get_ring(),
565 |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
566 |x| int_cast(x, ring, &self.integer_ring),
567 None
568 )
569 }
570}
571
572impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
573 where I: RingStore + Clone,
574 I::Type: IntegerRing,
575 C: ConvolutionAlgorithm<ZnBase>,
576 A: Allocator + Clone,
577 CreateC: Fn(Zn) -> C,
578 R: ?Sized + ZnRing + CanHomFrom<I::Type>
579{
580 type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
581
582 fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
583 let hom = ring.can_hom(&self.base.integer_ring).unwrap();
584 self.base.compute_convolution_impl(
585 lhs,
586 None,
587 rhs,
588 None,
589 dst,
590 ring.get_ring(),
591 |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
592 |x| hom.map(x),
593 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
594 )
595 }
596
597 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
598 true
599 }
600
601 fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
602 where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
603 {
604 self.base.prepare_convolution_impl(
605 val,
606 ring.get_ring(),
607 len_hint,
608 |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
609 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
610 )
611 }
612
613 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
614 where S: RingStore<Type = R> + Copy,
615 V1: VectorView<El<S>>,
616 V2: VectorView<El<S>>
617 {
618 let hom = ring.can_hom(&self.base.integer_ring).unwrap();
619 self.base.compute_convolution_impl(
620 lhs,
621 lhs_prep,
622 rhs,
623 rhs_prep,
624 dst,
625 ring.get_ring(),
626 |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
627 |x| hom.map(x),
628 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
629 )
630 }
631
632 fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
633 where S: RingStore<Type = R> + Copy,
634 J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
635 V1: VectorView<R::Element>,
636 V2: VectorView<R::Element>,
637 Self: 'a,
638 R: 'a
639 {
640 let hom = ring.can_hom(&self.base.integer_ring).unwrap();
641 self.base.compute_convolution_sum_impl(
642 values,
643 dst,
644 ring.get_ring(),
645 |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
646 |x| hom.map(x),
647 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
648 )
649 }
650}
651
652#[cfg(test)]
653use super::STANDARD_CONVOLUTION;
654
655#[test]
656fn test_convolution_integer() {
657 let ring = StaticRing::<i128>::RING;
658 let convolution = RNSConvolution::new_with_convolution(7, usize::MAX, BigIntRing::RING, Global, NTTConvolution::new);
659
660 super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
661}
662
663#[test]
664fn test_convolution_zn() {
665 let ring = Zn::new((1 << 57) + 1);
666 let convolution = RNSConvolutionZn::from(RNSConvolution::new_with_convolution(7, usize::MAX, BigIntRing::RING, Global, NTTConvolution::new));
667
668 super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
669}
670
671#[test]
672fn test_convolution_sum() {
673 let ring = StaticRing::<i128>::RING;
674 let convolution = RNSConvolution::new_with_convolution(7, 20, BigIntRing::RING, Global, NTTConvolution::new);
675
676 let data = (0..40usize).map(|i| (
677 (0..(5 + i % 5)).map(|x| (1 << i) * (x as i128 - 2)).collect::<Vec<_>>(),
678 (0..(13 - i % 7)).map(|x| (1 << i) * (x as i128 + 1)).collect::<Vec<_>>(),
679 ));
680 let mut expected = (0..22).map(|_| 0).collect::<Vec<_>>();
681 STANDARD_CONVOLUTION.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut expected, ring);
682
683 let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
684 convolution.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut actual, ring);
685 assert_eq!(&expected[..21], actual);
686
687 let data_prep = data.clone().map(|(l, r)| {
688 let l_prep = convolution.prepare_convolution_operand(&l, Some(21), ring);
689 let r_prep = convolution.prepare_convolution_operand(&r, Some(21), ring);
690 (l, l_prep, r, r_prep)
691 }).collect::<Vec<_>>();
692 let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
693 convolution.compute_convolution_sum(data_prep.iter().map(|(l, l_prep, r, r_prep)| (l, Some(l_prep), r, Some(r_prep))), &mut actual, ring);
694 assert_eq!(&expected[..21], actual);
695
696 let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
697 convolution.compute_convolution_sum(data_prep.iter().enumerate().map(|(i, (l, l_prep, r, r_prep))| match i % 4 {
698 0 => (l, Some(l_prep), r, Some(r_prep)),
699 1 => (l, None, r, Some(r_prep)),
700 2 => (l, Some(l_prep), r, None),
701 3 => (l, None, r, None),
702 _ => unreachable!()
703 }), &mut actual, ring);
704 assert_eq!(&expected[..21], actual);
705}