feanor_math/algorithms/convolution/rns.rs

1use std::alloc::{Allocator, Global};
2use std::cmp::{min, max};
3use std::marker::PhantomData;
4
5use crate::algorithms::miller_rabin::is_prime;
6use crate::homomorphism::*;
7use crate::integer::*;
8use crate::lazy::LazyVec;
9use crate::primitive_int::StaticRing;
10use crate::ring::*;
11use crate::rings::zn::zn_64::{Zn, ZnBase, ZnFastmul, ZnFastmulBase};
12use crate::rings::zn::*;
13use crate::divisibility::*;
14use crate::seq::*;
15
16use super::ntt::NTTConvolution;
17use super::ConvolutionAlgorithm;
18
19///
20/// A [`ConvolutionAlgorithm`] that computes convolutions by computing them modulo a
21/// suitable composite modulus `q`, whose factors are of a certain shape, usually such
22/// as to allow for NTT-based convolutions.
23/// 
24/// Due to overlapping blanket impls, this type can only be used to compute convolutions
25/// over [`IntegerRing`]s. For computing convolutions over [`ZnRing`]s, wrap it in a
26/// [`RNSConvolutionZn`].
27/// 
28#[stability::unstable(feature = "enable")]
29pub struct RNSConvolution<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
30    where I: RingStore + Clone,
31        I::Type: IntegerRing,
32        C: ConvolutionAlgorithm<ZnBase>,
33        A: Allocator + Clone,
34        CreateC: Fn(Zn) -> C
35{
36    integer_ring: I,
37    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
38    convolutions: LazyVec<C>,
39    create_convolution: CreateC,
40    required_root_of_unity_log2: usize,
41    allocator: A
42}
43
44///
45/// Same as [`RNSConvolution`], but computes convolutions over [`ZnRing`]s.
46/// 
47#[stability::unstable(feature = "enable")]
48#[repr(transparent)]
49pub struct RNSConvolutionZn<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
50    where I: RingStore + Clone,
51        I::Type: IntegerRing,
52        C: ConvolutionAlgorithm<ZnBase>,
53        A: Allocator + Clone,
54        CreateC: Fn(Zn) -> C
55{
56    base: RNSConvolution<I, C, A, CreateC>
57}
58
59///
60/// A prepared convolution operand for a [`RNSConvolution`].
61/// 
62#[stability::unstable(feature = "enable")]
63pub struct PreparedConvolutionOperand<R, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>>
64    where R: ?Sized + RingBase,
65        C: ConvolutionAlgorithm<ZnBase>
66{
67    prepared: LazyVec<C::PreparedConvolutionOperand>,
68    log2_data_size: usize,
69    ring: PhantomData<R>,
70    len_hint: Option<usize>
71}
72
73///
74/// Function that creates an [`NTTConvolution`] when given a suitable modulus.
75/// 
76#[stability::unstable(feature = "enable")]
77pub struct CreateNTTConvolution<A = Global>
78    where A: Allocator + Clone
79{
80    allocator: A
81}
82
83impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
84    where I: RingStore + Clone,
85        I::Type: IntegerRing,
86        C: ConvolutionAlgorithm<ZnBase>,
87        A: Allocator + Clone,
88        CreateC: Fn(Zn) -> C
89{
90    fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self {
91        value.base
92    }
93}
94
95impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
96    where I: RingStore + Clone,
97        I::Type: IntegerRing,
98        C: ConvolutionAlgorithm<ZnBase>,
99        A: Allocator + Clone,
100        CreateC: Fn(Zn) -> C
101{
102    fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self {
103        &value.base
104    }
105}
106
107impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
108    where I: RingStore + Clone,
109        I::Type: IntegerRing,
110        C: ConvolutionAlgorithm<ZnBase>,
111        A: Allocator + Clone,
112        CreateC: Fn(Zn) -> C
113{
114    fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self {
115        RNSConvolutionZn { base: value }
116    }
117}
118
119impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
120    where I: RingStore + Clone,
121        I::Type: IntegerRing,
122        C: ConvolutionAlgorithm<ZnBase>,
123        A: Allocator + Clone,
124        CreateC: Fn(Zn) -> C
125{
126    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self {
127        unsafe { std::mem::transmute(value) }
128    }
129}
130
131impl CreateNTTConvolution<Global> {
132
133    #[stability::unstable(feature = "enable")]
134    pub fn new() -> Self {
135        Self { allocator: Global }
136    }
137}
138
139impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
140    where A: Allocator + Clone
141{
142    type Output = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>, A>;
143
144    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output {
145        self.call(args)
146    }
147}
148
149impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
150    where A: Allocator + Clone
151{
152    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output {
153        self.call(args)
154    }
155}
156
157impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
158    where A: Allocator + Clone
159{
160    extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
161        let ring = args.0;
162        let ring_fastmul = ZnFastmul::new(ring).unwrap();
163        let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
164        NTTConvolution::new_with(hom, self.allocator.clone())
165    }
166}
167
168impl RNSConvolution {
169
170    ///
171    /// Creates a new [`RNSConvolution`] that can compute convolutions of sequences with output
172    /// length `<= 2^max_log2_n`. As base convolution, the [`NTTConvolution`] is used.
173    /// 
174    #[stability::unstable(feature = "enable")]
175    pub fn new(max_log2_n: usize) -> Self {
176        Self::new_with(max_log2_n, usize::MAX, BigIntRing::RING, Global, CreateNTTConvolution { allocator: Global })
177    }
178}
179
180impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
181    where I: RingStore + Clone,
182        I::Type: IntegerRing,
183        C: ConvolutionAlgorithm<ZnBase>,
184        A: Allocator + Clone,
185        CreateC: Fn(Zn) -> C
186{
187    ///
188    /// Creates a new [`RNSConvolution`] with all the given configuration parameters.
189    /// 
190    /// In particular
191    ///  - `required_root_of_unity_log2` and `max_prime_size_log2` control which prime factors
192    ///    are used for the underlying composite modulus; Only primes `<= 2^max_prime_size_log2` and
193    ///    `= 1` mod `required_root_of_unity_log2` are sampled
194    ///  - `integer_ring` is the ring to store intermediate lifts in; this probably has to be [`BigIntRing`],
195    ///    unless inputs are pretty small
196    ///  - `allocator` is used to allocate elements modulo the internal modulus, as elements of [`zn_rns::Zn`]
197    ///  - `create_convolution` is called whenever a new convolution algorithm for a new prime has to be
198    ///    created; the modulus of the given [`Zn`] always satisfies the constraints defined by `max_prime_size_log2`
199    ///    and `required_root_of_unity_log2`
200    /// 
201    #[stability::unstable(feature = "enable")]
202    pub fn new_with(required_root_of_unity_log2: usize, mut max_prime_size_log2: usize, integer_ring: I, allocator: A, create_convolution: CreateC) -> Self {
203        max_prime_size_log2 = min(max_prime_size_log2, 57);
204        let result = Self {
205            integer_ring: integer_ring,
206            create_convolution: create_convolution,
207            convolutions: LazyVec::new(),
208            rns_rings: LazyVec::new(),
209            required_root_of_unity_log2: required_root_of_unity_log2,
210            allocator: allocator
211        };
212        let initial_ring = zn_rns::Zn::new_with(vec![Zn::new(Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64)], result.integer_ring.clone(), result.allocator.clone());
213        _ = result.rns_rings.get_or_init(0, || initial_ring);
214        return result;
215    }
216
217    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
218        let mut k = StaticRing::<i64>::RING.checked_div(&(current - 1), &(1 << required_root_of_unity_log2)).unwrap();
219        while k > 0 {
220            k -= 1;
221            let candidate = (k << required_root_of_unity_log2) + 1;
222            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
223                return Some(candidate);
224            }
225        }
226        return None;
227    }
228
229    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
230        self.rns_rings.get_or_init_incremental(moduli_count - 1, |_, prev| zn_rns::Zn::new_with(
231            prev.as_iter().cloned().chain([Zn::new(Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus()).unwrap() as u64)]).collect(),
232            self.integer_ring.clone(),
233            self.allocator.clone()
234        ))
235    }
236
237    fn get_rns_factor(&self, i: usize) -> &Zn {
238        let rns_ring = self.get_rns_ring(i + 1);
239        return rns_ring.at(rns_ring.len() - 1);
240    }
241
242    fn get_convolution(&self, i: usize) -> &C {
243        self.convolutions.get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
244    }
245
246    fn compute_required_width(&self, input_size_log2: usize, lhs_len: usize, rhs_len: usize, inner_prod_len: usize) -> usize {
247        let log2_output_size = input_size_log2 * 2 + 
248            StaticRing::<i64>::RING.abs_log2_ceil(&min(lhs_len, rhs_len).try_into().unwrap()).unwrap_or(0) +
249            StaticRing::<i64>::RING.abs_log2_ceil(&inner_prod_len.try_into().unwrap()).unwrap_or(0) +
250            1;
251        let mut width = log2_output_size.div_ceil(57);
252        while log2_output_size > self.integer_ring.abs_log2_floor(self.get_rns_ring(width).modulus()).unwrap() {
253            width += 1;
254        }
255        return width;
256    }
257
258    fn get_log2_input_size<R, V1, V2, ToInt>(
259        &self,
260        lhs: V1,
261        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
262        rhs: V2,
263        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
264        _ring: &R,
265        mut to_int: ToInt,
266        ring_log2_el_size: Option<usize>
267    ) -> usize
268        where R: ?Sized + RingBase,
269            V1: VectorView<R::Element>,
270            V2: VectorView<R::Element>,
271            ToInt: FnMut(&R::Element) -> El<I>,
272    {
273        if let Some(log2_data_size) = ring_log2_el_size {
274            assert!(lhs_prep.is_none() || lhs_prep.unwrap().log2_data_size == log2_data_size);
275            assert!(rhs_prep.is_none() || rhs_prep.unwrap().log2_data_size == log2_data_size);
276            log2_data_size
277        } else { 
278            max(
279                if let Some(lhs_prep) = lhs_prep {
280                    lhs_prep.log2_data_size
281                } else {
282                    lhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
283                },
284                if let Some(rhs_prep) = rhs_prep {
285                    rhs_prep.log2_data_size
286                } else {
287                    rhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
288                }
289            )
290        }
291    }
292    
293    fn get_prepared_operand<'a, R, V>(
294        &self,
295        data: V,
296        data_prep: &'a PreparedConvolutionOperand<R, C>,
297        rns_index: usize,
298        _ring: &R
299    ) -> &'a C::PreparedConvolutionOperand
300        where R: ?Sized + RingBase,
301            V: VectorView<El<Zn>> + Copy
302    {
303        data_prep.prepared.get_or_init(rns_index, || self.get_convolution(rns_index).prepare_convolution_operand(data, data_prep.len_hint, self.get_rns_factor(rns_index)))
304    }
305
306    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
307        &self,
308        lhs: V1,
309        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
310        rhs: V2,
311        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
312        dst: &mut [R::Element],
313        ring: &R,
314        mut to_int: ToInt,
315        mut from_int: FromInt,
316        ring_log2_el_size: Option<usize>
317    )
318        where R: ?Sized + RingBase,
319            V1: VectorView<R::Element>,
320            V2: VectorView<R::Element>,
321            ToInt: FnMut(&R::Element) -> El<I>,
322            FromInt: FnMut(El<I>) -> R::Element
323    {
324        if lhs.len() == 0 || rhs.len() == 0 {
325            return;
326        }
327
328        let input_size_log2 = self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size);
329        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
330        let len = lhs.len() + rhs.len() - 1;
331
332        let mut res_data = Vec::with_capacity_in(len * width, self.allocator.clone());
333        for i in 0..width {
334            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
335        }
336        let mut lhs_tmp = Vec::with_capacity_in(lhs.len(), self.allocator.clone());
337        let mut rhs_tmp = Vec::with_capacity_in(rhs.len(), self.allocator.clone());
338        for i in 0..width {
339            let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
340            lhs_tmp.clear();
341            lhs_tmp.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
342            rhs_tmp.clear();
343            rhs_tmp.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
344            self.get_convolution(i).compute_convolution_prepared(
345                &lhs_tmp, 
346                lhs_prep.map(|lhs_prep| self.get_prepared_operand(&lhs_tmp, lhs_prep, i, ring)),
347                &rhs_tmp, 
348                rhs_prep.map(|rhs_prep| self.get_prepared_operand(&rhs_tmp, rhs_prep, i, ring)),
349                &mut res_data[(i * len)..((i + 1) * len)], 
350                self.get_rns_factor(i)
351            );
352        }
353        for j in 0..len {
354            let add = self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j])));
355            ring.add_assign(&mut dst[j], from_int(add));
356        }
357    }
358
359    fn compute_convolution_sum_impl<'a, R, J, V1, V2, ToInt, FromInt>(
360        &self, 
361        values: J, 
362        dst: &mut [R::Element], 
363        ring: &R,
364        mut to_int: ToInt,
365        mut from_int: FromInt,
366        ring_log2_el_size: Option<usize>
367    ) 
368        where R: ?Sized + RingBase, 
369            J: ExactSizeIterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, C>>, V2, Option<&'a PreparedConvolutionOperand<R, C>>)>,
370            V1: VectorView<R::Element>,
371            V2: VectorView<R::Element>,
372            ToInt: FnMut(&R::Element) -> El<I>,
373            FromInt: FnMut(El<I>) -> R::Element,
374            Self: 'a,
375            R: 'a
376    {
377        let out_len = dst.len();
378        let inner_product_length = dst.len();
379
380        let mut current_width = 0;
381        let mut current_input_size_log2 = 0;
382        let mut lhs_max_len = 0;
383        let mut rhs_max_len = 0;
384        let mut res_data = Vec::new_in(self.allocator.clone());
385        let mut lhs_tmp = Vec::new_in(self.allocator.clone());
386        let mut rhs_tmp = Vec::new_in(self.allocator.clone());
387        
388        let mut merge_current = |current_width: usize, lhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>, rhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>| {
389            if current_width == 0 {
390                lhs_tmp.clear();
391                rhs_tmp.clear();
392                return;
393            }
394            res_data.clear();
395            for i in 0..current_width {
396                res_data.extend((0..out_len).map(|_| self.get_rns_factor(i).zero()));
397                self.get_convolution(i).compute_convolution_sum(
398                    lhs_tmp.iter().zip(rhs_tmp.iter()).map(|((lhs, lhs_prep), (rhs, rhs_prep))| {
399                        let lhs_data = &lhs[(i * lhs.len() / current_width)..((i + 1) * lhs.len() / current_width)];
400                        let rhs_data = &rhs[(i * rhs.len() / current_width)..((i + 1) * rhs.len() / current_width)];
401                        (
402                            lhs_data,
403                            lhs_prep.map(|lhs_prep| self.get_prepared_operand(lhs_data, lhs_prep, i, ring)),
404                            rhs_data,
405                            rhs_prep.map(|rhs_prep| self.get_prepared_operand(rhs_data, rhs_prep, i, ring)),
406                    )
407                    }),
408                    &mut res_data[(i * out_len)..((i + 1) * out_len)],
409                    self.get_rns_factor(i)
410                );
411            }
412            lhs_tmp.clear();
413            rhs_tmp.clear();
414            for j in 0..out_len {
415                let add = self.get_rns_ring(current_width).smallest_lift(self.get_rns_ring(current_width).from_congruence((0..current_width).map(|i| res_data[i * out_len + j])));
416                ring.add_assign(&mut dst[j], from_int(add));
417            }
418        };
419
420        for (lhs, lhs_prep, rhs, rhs_prep) in values {
421            if lhs.len() == 0 || rhs.len() == 0 {
422                continue;
423            }
424            assert!(out_len >= lhs.len() + rhs.len() - 1);
425            current_input_size_log2 = max(
426                current_input_size_log2,
427                self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size)
428            );
429            lhs_max_len = max(lhs_max_len, lhs.len());
430            rhs_max_len = max(rhs_max_len, rhs.len());
431            let required_width = self.compute_required_width(current_input_size_log2, lhs_max_len, rhs_max_len, inner_product_length);
432
433            if required_width > current_width {
434                merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
435                current_width = required_width;
436            }
437
438            lhs_tmp.push((Vec::with_capacity_in(lhs.len() * current_width, self.allocator.clone()), lhs_prep));
439            rhs_tmp.push((Vec::with_capacity_in(rhs.len() * current_width, self.allocator.clone()), rhs_prep));
440            for i in 0..current_width {
441                let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
442                lhs_tmp.last_mut().unwrap().0.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
443                rhs_tmp.last_mut().unwrap().0.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
444            }
445        }
446        merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
447    }
448    
449    fn prepare_convolution_impl<R, V, ToInt>(
450        &self,
451        data: V,
452        _ring: &R,
453        length_hint: Option<usize>,
454        mut to_int: ToInt,
455        ring_log2_el_size: Option<usize>
456    ) -> PreparedConvolutionOperand<R, C>
457        where R: ?Sized + RingBase,
458            V: VectorView<R::Element>,
459            ToInt: FnMut(&R::Element) -> El<I>
460    {
461        let input_size_log2 = if let Some(log2_data_size) = ring_log2_el_size {
462            log2_data_size
463        } else { 
464            data.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
465        };
466        return PreparedConvolutionOperand {
467            ring: PhantomData,
468            len_hint: length_hint,
469            prepared: LazyVec::new(),
470            log2_data_size: input_size_log2
471        };
472    }
473}
474
475impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
476    where I: RingStore + Clone,
477        I::Type: IntegerRing,
478        C: ConvolutionAlgorithm<ZnBase>,
479        A: Allocator + Clone,
480        CreateC: Fn(Zn) -> C,
481        R: ?Sized + IntegerRing
482{
483    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
484
485    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
486        self.compute_convolution_impl(
487            lhs,
488            None,
489            rhs,
490            None,
491            dst,
492            ring.get_ring(),
493            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
494            |x| int_cast(x, ring, &self.integer_ring),
495            None
496        )
497    }
498
499    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
500        true
501    }
502
503    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
504        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
505    {
506        self.prepare_convolution_impl(
507            val,
508            ring.get_ring(),
509            len_hint,
510            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
511            None
512        )
513    }
514    
515    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
516        where S: RingStore<Type = R> + Copy,
517            V1: VectorView<El<S>>,
518            V2: VectorView<El<S>>
519    {
520        self.compute_convolution_impl(
521            lhs,
522            lhs_prep,
523            rhs,
524            rhs_prep,
525            dst,
526            ring.get_ring(),
527            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
528            |x| int_cast(x, ring, &self.integer_ring),
529            None
530        )
531    }
532
533    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S) 
534        where S: RingStore<Type = R> + Copy, 
535            J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
536            V1: VectorView<R::Element>,
537            V2: VectorView<R::Element>,
538            Self: 'a,
539            R: 'a
540    {
541        self.compute_convolution_sum_impl(
542            values,
543            dst,
544            ring.get_ring(),
545            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
546            |x| int_cast(x, ring, &self.integer_ring),
547            None
548        )
549    }
550}
551
552impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
553    where I: RingStore + Clone,
554        I::Type: IntegerRing,
555        C: ConvolutionAlgorithm<ZnBase>,
556        A: Allocator + Clone,
557        CreateC: Fn(Zn) -> C,
558        R: ?Sized + ZnRing + CanHomFrom<I::Type>
559{
560    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
561
562    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
563        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
564        self.base.compute_convolution_impl(
565            lhs,
566            None,
567            rhs,
568            None,
569            dst,
570            ring.get_ring(),
571            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
572            |x| hom.map(x),
573            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
574        )
575    }
576
577    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
578        true
579    }
580
581    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
582        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
583    {
584        self.base.prepare_convolution_impl(
585            val,
586            ring.get_ring(),
587            len_hint,
588            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
589            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
590        )
591    }
592
593    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
594        where S: RingStore<Type = R> + Copy,
595            V1: VectorView<El<S>>,
596            V2: VectorView<El<S>>
597    {
598        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
599        self.base.compute_convolution_impl(
600            lhs,
601            lhs_prep,
602            rhs,
603            rhs_prep,
604            dst,
605            ring.get_ring(),
606            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
607            |x| hom.map(x),
608            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
609        )
610    }
611
612    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S) 
613        where S: RingStore<Type = R> + Copy, 
614            J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
615            V1: VectorView<R::Element>,
616            V2: VectorView<R::Element>,
617            Self: 'a,
618            R: 'a
619    {
620        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
621        self.base.compute_convolution_sum_impl(
622            values,
623            dst,
624            ring.get_ring(),
625            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
626            |x| hom.map(x),
627            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
628        )
629    }
630}
631
632#[cfg(test)]
633use super::STANDARD_CONVOLUTION;
634
635#[test]
636fn test_convolution_integer() {
637    let ring = StaticRing::<i128>::RING;
638    let convolution = RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp.into_identity(), Global));
639
640    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
641}
642
643#[test]
644fn test_convolution_zn() {
645    let ring = Zn::new((1 << 57) + 1);
646    let convolution = RNSConvolutionZn::from(RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp.into_identity(), Global)));
647
648    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
649}
650
651#[test]
652fn test_convolution_sum() {
653    let ring = StaticRing::<i128>::RING;
654    let convolution = RNSConvolution::new_with(7, 20, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp.into_identity(), Global));
655    
656    let data = (0..40usize).map(|i| (
657        (0..(5 + i % 5)).map(|x| (1 << i) * (x as i128 - 2)).collect::<Vec<_>>(),
658        (0..(13 - i % 7)).map(|x| (1 << i) * (x as i128 + 1)).collect::<Vec<_>>(),
659    ));
660    let mut expected = (0..22).map(|_| 0).collect::<Vec<_>>();
661    STANDARD_CONVOLUTION.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut expected, ring);
662
663    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
664    convolution.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut actual, ring);
665    assert_eq!(&expected[..21], actual);
666    
667    let data_prep = data.clone().map(|(l, r)| {
668        let l_prep = convolution.prepare_convolution_operand(&l, Some(21), ring);
669        let r_prep = convolution.prepare_convolution_operand(&r, Some(21), ring);
670        (l, l_prep, r, r_prep)
671    }).collect::<Vec<_>>();
672    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
673    convolution.compute_convolution_sum(data_prep.iter().map(|(l, l_prep, r, r_prep)| (l, Some(l_prep), r, Some(r_prep))), &mut actual, ring);
674    assert_eq!(&expected[..21], actual);
675    
676    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
677    convolution.compute_convolution_sum(data_prep.iter().enumerate().map(|(i, (l, l_prep, r, r_prep))| match i % 4 {
678        0 => (l, Some(l_prep), r, Some(r_prep)),
679        1 => (l, None, r, Some(r_prep)),
680        2 => (l, Some(l_prep), r, None),
681        3 => (l, None, r, None),
682        _ => unreachable!()
683    }), &mut actual, ring);
684    assert_eq!(&expected[..21], actual);
685}