use std::alloc::{Allocator, Global};
use std::cmp::{min, max};
use std::marker::PhantomData;

use crate::algorithms::miller_rabin::is_prime;
use crate::homomorphism::*;
use crate::integer::*;
use crate::lazy::LazyVec;
use crate::primitive_int::StaticRing;
use crate::ring::*;
use crate::rings::zn::zn_64::{Zn, ZnBase, ZnFastmul, ZnFastmulBase};
use crate::rings::zn::*;
use crate::divisibility::*;
use crate::seq::*;

use super::ntt::NTTConvolution;
use super::ConvolutionAlgorithm;

#[stability::unstable(feature = "enable")]
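/// Convolution algorithm that computes convolutions over the integers (and, via
/// [`RNSConvolutionZn`], over `Z/nZ`) by reducing the inputs modulo a list of NTT-friendly
/// primes, computing the convolution modulo each prime with the underlying algorithm `C`,
/// and reconstructing the result via the Chinese Remainder Theorem and shortest lift.
/// New primes are sampled lazily, so that enough RNS factors are always available to rule
/// out overflow.
///
/// The following is an illustrative sketch only (hence marked `ignore`): imports are
/// omitted, and it assumes that `Vec<i64>` can be used as a vector view and that
/// `new(16)` provides large enough roots of unity for inputs of this length.
/// ```ignore
/// let convolution = RNSConvolution::new(16);
/// let ring = StaticRing::<i64>::RING;
/// // the convolution result is added onto `result`, so initialize it with zeros
/// let mut result = vec![0i64; 5];
/// convolution.compute_convolution(vec![1i64, 2, 3], vec![4i64, 5, 6], &mut result, ring);
/// assert_eq!(vec![4i64, 13, 28, 27, 18], result);
/// ```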
pub struct RNSConvolution<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    integer_ring: I,
    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
    convolutions: LazyVec<C>,
    create_convolution: CreateC,
    required_root_of_unity_log2: usize,
    allocator: A
}

#[stability::unstable(feature = "enable")]
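/// Convolution algorithm for rings implementing [`ZnRing`], based on [`RNSConvolution`]:
/// elements are lifted to the integers using the smallest lift, the convolution is
/// computed there, and the result is mapped back into the ring via the canonical
/// homomorphism from the integer ring.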
#[repr(transparent)]
pub struct RNSConvolutionZn<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    base: RNSConvolution<I, C, A, CreateC>
}

#[stability::unstable(feature = "enable")]
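/// A preprocessed operand for convolutions via [`RNSConvolution`] or [`RNSConvolutionZn`].
/// It stores the bitlength of the original data, an optional length hint, and the lazily
/// created prepared operands of the underlying convolution algorithm, one per RNS factor.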
pub struct PreparedConvolutionOperand<R, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>>
    where R: ?Sized + RingBase,
        C: ConvolutionAlgorithm<ZnBase>
{
    prepared: LazyVec<C::PreparedConvolutionOperand>,
    log2_data_size: usize,
    ring: PhantomData<R>,
    len_hint: Option<usize>
}

#[stability::unstable(feature = "enable")]
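/// Function object that creates an [`NTTConvolution`] over a given [`Zn`]; this is the
/// default for the `CreateC` parameter of [`RNSConvolution`].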
pub struct CreateNTTConvolution<A = Global>
    where A: Allocator + Clone
{
    allocator: A
}

impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self {
        value.base
    }
}

impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self {
        &value.base
    }
}

impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self {
        RNSConvolutionZn { base: value }
    }
}

impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self {
        unsafe { std::mem::transmute(value) }
    }
}

impl CreateNTTConvolution<Global> {

    #[stability::unstable(feature = "enable")]
    pub fn new() -> Self {
        Self { allocator: Global }
    }
}

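// Implementing the `Fn*` traits by hand (this requires the unstable `unboxed_closures` and
// `fn_traits` features) allows using `CreateNTTConvolution` wherever a `Fn(Zn) -> C` closure
// is expected, in particular as the default `CreateC` parameter of `RNSConvolution`.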
impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    type Output = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>, A>;

    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output {
        self.call(args)
    }
}

impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output {
        self.call(args)
    }
}

impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
        let ring = args.0;
        let ring_fastmul = ZnFastmul::new(ring).unwrap();
        let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
        NTTConvolution::new_with(hom, self.allocator.clone())
    }
}

impl RNSConvolution {

    #[stability::unstable(feature = "enable")]
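    /// Creates a new [`RNSConvolution`] with the default components, i.e. using
    /// [`BigIntRing`] for lifted values, the global allocator and an [`NTTConvolution`]
    /// modulo each prime. `max_log2_n` is the log2 of the order of the power-of-two roots
    /// of unity that every sampled prime supports, which (roughly) bounds the length of
    /// the convolutions that can be computed.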
    pub fn new(max_log2_n: usize) -> Self {
        Self::new_with(max_log2_n, usize::MAX, BigIntRing::RING, Global, CreateNTTConvolution { allocator: Global })
    }
}

impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    #[stability::unstable(feature = "enable")]
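    /// Creates a new [`RNSConvolution`] from the given components.
    ///
    /// Every sampled prime `p` satisfies `p = 1 mod 2^required_root_of_unity_log2`, so that
    /// `Z/pZ` contains a primitive `2^required_root_of_unity_log2`-th root of unity. Primes
    /// are sampled in decreasing order, starting just below `2^max_prime_size_log2` (this
    /// value is clamped to 57). For each new prime, `create_convolution` is called to build
    /// the convolution algorithm used modulo that prime.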
    pub fn new_with(required_root_of_unity_log2: usize, mut max_prime_size_log2: usize, integer_ring: I, allocator: A, create_convolution: CreateC) -> Self {
        max_prime_size_log2 = min(max_prime_size_log2, 57);
        let result = Self {
            integer_ring: integer_ring,
            create_convolution: create_convolution,
            convolutions: LazyVec::new(),
            rns_rings: LazyVec::new(),
            required_root_of_unity_log2: required_root_of_unity_log2,
            allocator: allocator
        };
        let initial_ring = zn_rns::Zn::new_with(vec![Zn::new(Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64)], result.integer_ring.clone(), result.allocator.clone());
        _ = result.rns_rings.get_or_init(0, || initial_ring);
        return result;
    }

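    /// Returns the largest prime `p < current` with `p = 1 mod 2^required_root_of_unity_log2`
    /// (primality is checked with Miller-Rabin), or `None` if there is no such prime.
    /// Expects `current = 1 mod 2^required_root_of_unity_log2`.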
    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
        let mut k = StaticRing::<i64>::RING.checked_div(&(current - 1), &(1 << required_root_of_unity_log2)).unwrap();
        while k > 0 {
            k -= 1;
            let candidate = (k << required_root_of_unity_log2) + 1;
            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
                return Some(candidate);
            }
        }
        return None;
    }

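    /// Returns the RNS ring over the first `moduli_count` primes, lazily sampling further
    /// primes and extending the stored chain of RNS rings as necessary.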
    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
        self.rns_rings.get_or_init_incremental(moduli_count - 1, |_, prev| zn_rns::Zn::new_with(
            prev.as_iter().cloned().chain([Zn::new(Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus()).unwrap() as u64)]).collect(),
            self.integer_ring.clone(),
            self.allocator.clone()
        ))
    }

    fn get_rns_factor(&self, i: usize) -> &Zn {
        let rns_ring = self.get_rns_ring(i + 1);
        return rns_ring.at(rns_ring.len() - 1);
    }

    fn get_convolution(&self, i: usize) -> &C {
        self.convolutions.get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
    }

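    /// Computes how many RNS factors are required so that the product of their moduli
    /// exceeds twice the largest possible absolute value of an output coefficient
    /// (determined by the input bitlength, the input lengths and the number of summed
    /// convolutions `inner_prod_len`); this ensures that the shortest lift of the CRT
    /// reconstruction gives the exact integer result.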
    fn compute_required_width(&self, input_size_log2: usize, lhs_len: usize, rhs_len: usize, inner_prod_len: usize) -> usize {
        let log2_output_size = input_size_log2 * 2 +
            StaticRing::<i64>::RING.abs_log2_ceil(&min(lhs_len, rhs_len).try_into().unwrap()).unwrap_or(0) +
            StaticRing::<i64>::RING.abs_log2_ceil(&inner_prod_len.try_into().unwrap()).unwrap_or(0) +
            1;
        let mut width = log2_output_size.div_ceil(57);
        while log2_output_size > self.integer_ring.abs_log2_floor(self.get_rns_ring(width).modulus()).unwrap() {
            width += 1;
        }
        return width;
    }

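    /// Determines the maximal bitlength of the input entries: taken from the ring if
    /// `ring_log2_el_size` is given, otherwise from the prepared operands, and otherwise
    /// by computing the maximum bitlength over all entries of `lhs` resp. `rhs`.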
    fn get_log2_input_size<R, V1, V2, ToInt>(
        &self,
        lhs: V1,
        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        rhs: V2,
        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        _ring: &R,
        mut to_int: ToInt,
        ring_log2_el_size: Option<usize>
    ) -> usize
        where R: ?Sized + RingBase,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            ToInt: FnMut(&R::Element) -> El<I>,
    {
        if let Some(log2_data_size) = ring_log2_el_size {
            assert!(lhs_prep.is_none() || lhs_prep.unwrap().log2_data_size == log2_data_size);
            assert!(rhs_prep.is_none() || rhs_prep.unwrap().log2_data_size == log2_data_size);
            log2_data_size
        } else {
            max(
                if let Some(lhs_prep) = lhs_prep {
                    lhs_prep.log2_data_size
                } else {
                    lhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
                },
                if let Some(rhs_prep) = rhs_prep {
                    rhs_prep.log2_data_size
                } else {
                    rhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
                }
            )
        }
    }

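    /// Returns the prepared operand of the underlying convolution algorithm w.r.t. the
    /// `rns_index`-th RNS factor, creating and caching it on first use.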
    fn get_prepared_operand<'a, R, V>(
        &self,
        data: V,
        data_prep: &'a PreparedConvolutionOperand<R, C>,
        rns_index: usize,
        _ring: &R
    ) -> &'a C::PreparedConvolutionOperand
        where R: ?Sized + RingBase,
            V: VectorView<El<Zn>> + Copy
    {
        data_prep.prepared.get_or_init(rns_index, || self.get_convolution(rns_index).prepare_convolution_operand(data, data_prep.len_hint, self.get_rns_factor(rns_index)))
    }

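    /// Computes a single convolution: both inputs are reduced modulo sufficiently many RNS
    /// factors, the underlying convolution algorithm is run modulo each prime, and the
    /// shortest lift of the CRT reconstruction is added to `dst`.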
    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
        &self,
        lhs: V1,
        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        rhs: V2,
        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
        dst: &mut [R::Element],
        ring: &R,
        mut to_int: ToInt,
        mut from_int: FromInt,
        ring_log2_el_size: Option<usize>
    )
        where R: ?Sized + RingBase,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            ToInt: FnMut(&R::Element) -> El<I>,
            FromInt: FnMut(El<I>) -> R::Element
    {
        if lhs.len() == 0 || rhs.len() == 0 {
            return;
        }

        let input_size_log2 = self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size);
        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
        let len = lhs.len() + rhs.len() - 1;

        let mut res_data = Vec::with_capacity_in(len * width, self.allocator.clone());
        for i in 0..width {
            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
        }
        let mut lhs_tmp = Vec::with_capacity_in(lhs.len(), self.allocator.clone());
        let mut rhs_tmp = Vec::with_capacity_in(rhs.len(), self.allocator.clone());
        for i in 0..width {
            let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
            lhs_tmp.clear();
            lhs_tmp.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
            rhs_tmp.clear();
            rhs_tmp.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
            self.get_convolution(i).compute_convolution_prepared(
                &lhs_tmp,
                lhs_prep.map(|lhs_prep| self.get_prepared_operand(&lhs_tmp, lhs_prep, i, ring)),
                &rhs_tmp,
                rhs_prep.map(|rhs_prep| self.get_prepared_operand(&rhs_tmp, rhs_prep, i, ring)),
                &mut res_data[(i * len)..((i + 1) * len)],
                self.get_rns_factor(i)
            );
        }
        for j in 0..len {
            let add = self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j])));
            ring.add_assign(&mut dst[j], from_int(add));
        }
    }

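    /// Computes a sum of convolutions. Input pairs are accumulated into a batch for as long
    /// as the current number of RNS factors suffices; as soon as a pair requires more
    /// factors, the pending batch is flushed into `dst` (see `merge_current`) and the width
    /// is increased. The final batch is flushed at the end.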
    fn compute_convolution_sum_impl<'a, R, J, V1, V2, ToInt, FromInt>(
        &self,
        values: J,
        dst: &mut [R::Element],
        ring: &R,
        mut to_int: ToInt,
        mut from_int: FromInt,
        ring_log2_el_size: Option<usize>
    )
        where R: ?Sized + RingBase,
            J: ExactSizeIterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, C>>, V2, Option<&'a PreparedConvolutionOperand<R, C>>)>,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            ToInt: FnMut(&R::Element) -> El<I>,
            FromInt: FnMut(El<I>) -> R::Element,
            Self: 'a,
            R: 'a
    {
        let out_len = dst.len();
        let inner_product_length = dst.len();

        let mut current_width = 0;
        let mut current_input_size_log2 = 0;
        let mut lhs_max_len = 0;
        let mut rhs_max_len = 0;
        let mut res_data = Vec::new_in(self.allocator.clone());
        let mut lhs_tmp = Vec::new_in(self.allocator.clone());
        let mut rhs_tmp = Vec::new_in(self.allocator.clone());

        let mut merge_current = |current_width: usize, lhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>, rhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>| {
            if current_width == 0 {
                lhs_tmp.clear();
                rhs_tmp.clear();
                return;
            }
            res_data.clear();
            for i in 0..current_width {
                res_data.extend((0..out_len).map(|_| self.get_rns_factor(i).zero()));
                self.get_convolution(i).compute_convolution_sum(
                    lhs_tmp.iter().zip(rhs_tmp.iter()).map(|((lhs, lhs_prep), (rhs, rhs_prep))| {
                        let lhs_data = &lhs[(i * lhs.len() / current_width)..((i + 1) * lhs.len() / current_width)];
                        let rhs_data = &rhs[(i * rhs.len() / current_width)..((i + 1) * rhs.len() / current_width)];
                        (
                            lhs_data,
                            lhs_prep.map(|lhs_prep| self.get_prepared_operand(lhs_data, lhs_prep, i, ring)),
                            rhs_data,
                            rhs_prep.map(|rhs_prep| self.get_prepared_operand(rhs_data, rhs_prep, i, ring)),
                        )
                    }),
                    &mut res_data[(i * out_len)..((i + 1) * out_len)],
                    self.get_rns_factor(i)
                );
            }
            lhs_tmp.clear();
            rhs_tmp.clear();
            for j in 0..out_len {
                let add = self.get_rns_ring(current_width).smallest_lift(self.get_rns_ring(current_width).from_congruence((0..current_width).map(|i| res_data[i * out_len + j])));
                ring.add_assign(&mut dst[j], from_int(add));
            }
        };

        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            if lhs.len() == 0 || rhs.len() == 0 {
                continue;
            }
            assert!(out_len >= lhs.len() + rhs.len() - 1);
            current_input_size_log2 = max(
                current_input_size_log2,
                self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size)
            );
            lhs_max_len = max(lhs_max_len, lhs.len());
            rhs_max_len = max(rhs_max_len, rhs.len());
            let required_width = self.compute_required_width(current_input_size_log2, lhs_max_len, rhs_max_len, inner_product_length);

            if required_width > current_width {
                merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
                current_width = required_width;
            }

            lhs_tmp.push((Vec::with_capacity_in(lhs.len() * current_width, self.allocator.clone()), lhs_prep));
            rhs_tmp.push((Vec::with_capacity_in(rhs.len() * current_width, self.allocator.clone()), rhs_prep));
            for i in 0..current_width {
                let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
                lhs_tmp.last_mut().unwrap().0.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
                rhs_tmp.last_mut().unwrap().0.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
            }
        }
        merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
    }

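    /// Creates a [`PreparedConvolutionOperand`] that records the bitlength of the data and
    /// an optional length hint; the per-RNS-factor prepared operands of the underlying
    /// convolution algorithm are only created lazily, once the operand is actually used.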
    fn prepare_convolution_impl<R, V, ToInt>(
        &self,
        data: V,
        _ring: &R,
        length_hint: Option<usize>,
        mut to_int: ToInt,
        ring_log2_el_size: Option<usize>
    ) -> PreparedConvolutionOperand<R, C>
        where R: ?Sized + RingBase,
            V: VectorView<R::Element>,
            ToInt: FnMut(&R::Element) -> El<I>
    {
        let input_size_log2 = if let Some(log2_data_size) = ring_log2_el_size {
            log2_data_size
        } else {
            data.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
        };
        return PreparedConvolutionOperand {
            ring: PhantomData,
            len_hint: length_hint,
            prepared: LazyVec::new(),
            log2_data_size: input_size_log2
        };
    }
}

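// For integer rings, elements are converted to and from the internal integer ring `I` via
// `int_cast`; since no a-priori bound on the element size is available from the ring, the
// input bitlength is computed from the data (or taken from the prepared operands).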
impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + IntegerRing
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;

    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        self.compute_convolution_impl(
            lhs,
            None,
            rhs,
            None,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            |x| int_cast(x, ring, &self.integer_ring),
            None
        )
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        self.prepare_convolution_impl(
            val,
            ring.get_ring(),
            len_hint,
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            None
        )
    }

    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            V1: VectorView<El<S>>,
            V2: VectorView<El<S>>
    {
        self.compute_convolution_impl(
            lhs,
            lhs_prep,
            rhs,
            rhs_prep,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            |x| int_cast(x, ring, &self.integer_ring),
            None
        )
    }

    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            Self: 'a,
            R: 'a
    {
        self.compute_convolution_sum_impl(
            values,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
            |x| int_cast(x, ring, &self.integer_ring),
            None
        )
    }
}

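// For rings `Z/nZ`, elements are lifted to the integers via the smallest lift, so their
// bitlength is bounded by the bitlength of the modulus, and results are mapped back into
// the ring through the canonical homomorphism from the integer ring `I`.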
impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + ZnRing + CanHomFrom<I::Type>
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;

    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        self.base.compute_convolution_impl(
            lhs,
            None,
            rhs,
            None,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
            |x| hom.map(x),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
        )
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        self.base.prepare_convolution_impl(
            val,
            ring.get_ring(),
            len_hint,
            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
        )
    }

    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            V1: VectorView<El<S>>,
            V2: VectorView<El<S>>
    {
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        self.base.compute_convolution_impl(
            lhs,
            lhs_prep,
            rhs,
            rhs_prep,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
            |x| hom.map(x),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
        )
    }

    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            Self: 'a,
            R: 'a
    {
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        self.base.compute_convolution_sum_impl(
            values,
            dst,
            ring.get_ring(),
            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
            |x| hom.map(x),
            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
        )
    }
}

#[cfg(test)]
use super::STANDARD_CONVOLUTION;

#[test]
fn test_convolution_integer() {
    let ring = StaticRing::<i128>::RING;
    let convolution = RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp.into_identity(), Global));

    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
}

#[test]
fn test_convolution_zn() {
    let ring = Zn::new((1 << 57) + 1);
    let convolution = RNSConvolutionZn::from(RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp.into_identity(), Global)));

    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
}

#[test]
fn test_convolution_sum() {
    let ring = StaticRing::<i128>::RING;
    let convolution = RNSConvolution::new_with(7, 20, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp.into_identity(), Global));

    let data = (0..40usize).map(|i| (
        (0..(5 + i % 5)).map(|x| (1 << i) * (x as i128 - 2)).collect::<Vec<_>>(),
        (0..(13 - i % 7)).map(|x| (1 << i) * (x as i128 + 1)).collect::<Vec<_>>(),
    ));
    let mut expected = (0..22).map(|_| 0).collect::<Vec<_>>();
    STANDARD_CONVOLUTION.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut expected, ring);

    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
    convolution.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut actual, ring);
    assert_eq!(&expected[..21], actual);

    let data_prep = data.clone().map(|(l, r)| {
        let l_prep = convolution.prepare_convolution_operand(&l, Some(21), ring);
        let r_prep = convolution.prepare_convolution_operand(&r, Some(21), ring);
        (l, l_prep, r, r_prep)
    }).collect::<Vec<_>>();
    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
    convolution.compute_convolution_sum(data_prep.iter().map(|(l, l_prep, r, r_prep)| (l, Some(l_prep), r, Some(r_prep))), &mut actual, ring);
    assert_eq!(&expected[..21], actual);

    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
    convolution.compute_convolution_sum(data_prep.iter().enumerate().map(|(i, (l, l_prep, r, r_prep))| match i % 4 {
        0 => (l, Some(l_prep), r, Some(r_prep)),
        1 => (l, None, r, Some(r_prep)),
        2 => (l, Some(l_prep), r, None),
        3 => (l, None, r, None),
        _ => unreachable!()
    }), &mut actual, ring);
    assert_eq!(&expected[..21], actual);
}