feanor_math/algorithms/convolution/
rns.rs

1use std::alloc::{Allocator, Global};
2use std::cmp::{min, max};
3use std::marker::PhantomData;
4
5use crate::algorithms::miller_rabin::is_prime;
6use crate::homomorphism::*;
7use crate::integer::*;
8use crate::lazy::LazyVec;
9use crate::primitive_int::StaticRing;
10use crate::ring::*;
11use crate::rings::zn::zn_64::{Zn, ZnBase, ZnFastmul, ZnFastmulBase};
12use crate::rings::zn::*;
13use crate::divisibility::*;
14use crate::seq::*;
15
16use super::ntt::NTTConvolution;
17use super::ConvolutionAlgorithm;
18
19///
20/// A [`ConvolutionAlgorithm`] that computes convolutions by computing them modulo a
21/// suitable composite modulus `q`, whose factors are of a certain shape, usually such
22/// as to allow for NTT-based convolutions.
23/// 
24/// Due to overlapping blanket impls, this type can only be used to compute convolutions
25/// over [`IntegerRing`]s. For computing convolutions over [`ZnRing`]s, wrap it in a
26/// [`RNSConvolutionZn`].
27/// 
28#[stability::unstable(feature = "enable")]
29pub struct RNSConvolution<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
30    where I: RingStore + Clone,
31        I::Type: IntegerRing,
32        C: ConvolutionAlgorithm<ZnBase>,
33        A: Allocator + Clone,
34        CreateC: Fn(Zn) -> C
35{
36    integer_ring: I,
37    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
38    convolutions: LazyVec<C>,
39    create_convolution: CreateC,
40    required_root_of_unity_log2: usize,
41    allocator: A
42}
43
44///
45/// Same as [`RNSConvolution`], but computes convolutions over [`ZnRing`]s.
46/// 
47#[stability::unstable(feature = "enable")]
48#[repr(transparent)]
49pub struct RNSConvolutionZn<I = BigIntRing, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>, A = Global, CreateC = CreateNTTConvolution>
50    where I: RingStore + Clone,
51        I::Type: IntegerRing,
52        C: ConvolutionAlgorithm<ZnBase>,
53        A: Allocator + Clone,
54        CreateC: Fn(Zn) -> C
55{
56    base: RNSConvolution<I, C, A, CreateC>
57}
58
59///
60/// A prepared convolution operand for a [`RNSConvolution`].
61/// 
62#[stability::unstable(feature = "enable")]
63pub struct PreparedConvolutionOperand<R, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>>
64    where R: ?Sized + RingBase,
65        C: ConvolutionAlgorithm<ZnBase>
66{
67    prepared: LazyVec<C::PreparedConvolutionOperand>,
68    log2_data_size: usize,
69    ring: PhantomData<R>,
70    len_hint: Option<usize>
71}
72
73///
74/// Function that creates an [`NTTConvolution`] when given a suitable modulus.
75/// 
76#[stability::unstable(feature = "enable")]
77pub struct CreateNTTConvolution<A = Global>
78    where A: Allocator + Clone
79{
80    allocator: A
81}
82
83impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
84    where I: RingStore + Clone,
85        I::Type: IntegerRing,
86        C: ConvolutionAlgorithm<ZnBase>,
87        A: Allocator + Clone,
88        CreateC: Fn(Zn) -> C
89{
90    fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self {
91        value.base
92    }
93}
94
95impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
96    where I: RingStore + Clone,
97        I::Type: IntegerRing,
98        C: ConvolutionAlgorithm<ZnBase>,
99        A: Allocator + Clone,
100        CreateC: Fn(Zn) -> C
101{
102    fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self {
103        &value.base
104    }
105}
106
107impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
108    where I: RingStore + Clone,
109        I::Type: IntegerRing,
110        C: ConvolutionAlgorithm<ZnBase>,
111        A: Allocator + Clone,
112        CreateC: Fn(Zn) -> C
113{
114    fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self {
115        RNSConvolutionZn { base: value }
116    }
117}
118
119impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
120    where I: RingStore + Clone,
121        I::Type: IntegerRing,
122        C: ConvolutionAlgorithm<ZnBase>,
123        A: Allocator + Clone,
124        CreateC: Fn(Zn) -> C
125{
126    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self {
127        unsafe { std::mem::transmute(value) }
128    }
129}
130
131impl CreateNTTConvolution<Global> {
132
133    ///
134    /// Creates a new [`CreateNTTConvolution`].
135    /// 
136    #[stability::unstable(feature = "enable")]
137    pub const fn new() -> Self {
138        Self { allocator: Global }
139    }
140}
141
142impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
143    where A: Allocator + Clone
144{
145    type Output = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>, A>;
146
147    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output {
148        self.call(args)
149    }
150}
151
152impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
153    where A: Allocator + Clone
154{
155    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output {
156        self.call(args)
157    }
158}
159
160impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
161    where A: Allocator + Clone
162{
163    extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
164        let ring = args.0;
165        let ring_fastmul = ZnFastmul::new(ring).unwrap();
166        let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
167        NTTConvolution::new_with_hom(hom, self.allocator.clone())
168    }
169}
170
171impl RNSConvolution {
172
173    ///
174    /// Creates a new [`RNSConvolution`] that can compute convolutions of sequences with output
175    /// length `<= 2^max_log2_n`. As base convolution, the [`NTTConvolution`] is used.
176    /// 
177    #[stability::unstable(feature = "enable")]
178    pub fn new(max_log2_n: usize) -> Self {
179        Self::new_with_convolution(max_log2_n, usize::MAX, BigIntRing::RING, Global, CreateNTTConvolution { allocator: Global })
180    }
181}
182
183impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
184    where I: RingStore + Clone,
185        I::Type: IntegerRing,
186        C: ConvolutionAlgorithm<ZnBase>,
187        A: Allocator + Clone,
188        CreateC: Fn(Zn) -> C
189{
190    ///
191    /// Creates a new [`RNSConvolution`] with all the given configuration parameters.
192    /// 
193    /// In particular
194    ///  - `required_root_of_unity_log2` and `max_prime_size_log2` control which prime factors
195    ///    are used for the underlying composite modulus; Only primes `<= 2^max_prime_size_log2` and
196    ///    `= 1` mod `required_root_of_unity_log2` are sampled
197    ///  - `integer_ring` is the ring to store intermediate lifts in; this probably has to be [`BigIntRing`],
198    ///    unless inputs are pretty small
199    ///  - `allocator` is used to allocate elements modulo the internal modulus, as elements of [`zn_rns::Zn`]
200    ///  - `create_convolution` is called whenever a new convolution algorithm for a new prime has to be
201    ///    created; the modulus of the given [`Zn`] always satisfies the constraints defined by `max_prime_size_log2`
202    ///    and `required_root_of_unity_log2`
203    /// 
204    #[stability::unstable(feature = "enable")]
205    pub fn new_with_convolution(required_root_of_unity_log2: usize, mut max_prime_size_log2: usize, integer_ring: I, allocator: A, create_convolution: CreateC) -> Self {
206        max_prime_size_log2 = min(max_prime_size_log2, 57);
207        let result = Self {
208            integer_ring: integer_ring,
209            create_convolution: create_convolution,
210            convolutions: LazyVec::new(),
211            rns_rings: LazyVec::new(),
212            required_root_of_unity_log2: required_root_of_unity_log2,
213            allocator: allocator
214        };
215        let initial_ring = zn_rns::Zn::new_with_alloc(
216            vec![Zn::new(Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64)],
217            result.integer_ring.clone(), 
218            result.allocator.clone()
219        );
220        _ = result.rns_rings.get_or_init(0, || initial_ring);
221        return result;
222    }
223
224    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
225        let mut k = StaticRing::<i64>::RING.checked_div(&(current - 1), &(1 << required_root_of_unity_log2)).unwrap();
226        while k > 0 {
227            k -= 1;
228            let candidate = (k << required_root_of_unity_log2) + 1;
229            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
230                return Some(candidate);
231            }
232        }
233        return None;
234    }
235
236    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
237        self.rns_rings.get_or_init_incremental(moduli_count - 1, |_, prev| zn_rns::Zn::new_with_alloc(
238            prev.as_iter().cloned().chain([Zn::new(Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus()).unwrap() as u64)]).collect(),
239            self.integer_ring.clone(),
240            self.allocator.clone()
241        ))
242    }
243
244    fn get_rns_factor(&self, i: usize) -> &Zn {
245        let rns_ring = self.get_rns_ring(i + 1);
246        return rns_ring.at(rns_ring.len() - 1);
247    }
248
249    fn get_convolution(&self, i: usize) -> &C {
250        self.convolutions.get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
251    }
252
253    ///
254    /// "width" refers to the number of RNS factors we need
255    /// 
256    fn compute_required_width(&self, input_size_log2: usize, lhs_len: usize, rhs_len: usize, inner_prod_len: usize) -> usize {
257        let log2_output_size = input_size_log2 * 2 + 
258            StaticRing::<i64>::RING.abs_log2_ceil(&min(lhs_len, rhs_len).try_into().unwrap()).unwrap_or(0) +
259            StaticRing::<i64>::RING.abs_log2_ceil(&inner_prod_len.try_into().unwrap()).unwrap_or(0) +
260            1;
261        let mut width = log2_output_size.div_ceil(57);
262        while log2_output_size > self.integer_ring.abs_log2_floor(self.get_rns_ring(width).modulus()).unwrap() {
263            width += 1;
264        }
265        return width;
266    }
267
268    fn get_log2_input_size<R, V1, V2, ToInt>(
269        &self,
270        lhs: V1,
271        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
272        rhs: V2,
273        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
274        _ring: &R,
275        mut to_int: ToInt,
276        ring_log2_el_size: Option<usize>
277    ) -> usize
278        where R: ?Sized + RingBase,
279            V1: VectorView<R::Element>,
280            V2: VectorView<R::Element>,
281            ToInt: FnMut(&R::Element) -> El<I>,
282    {
283        if let Some(log2_data_size) = ring_log2_el_size {
284            assert!(lhs_prep.is_none() || lhs_prep.unwrap().log2_data_size == log2_data_size);
285            assert!(rhs_prep.is_none() || rhs_prep.unwrap().log2_data_size == log2_data_size);
286            log2_data_size
287        } else { 
288            max(
289                if let Some(lhs_prep) = lhs_prep {
290                    lhs_prep.log2_data_size
291                } else {
292                    lhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
293                },
294                if let Some(rhs_prep) = rhs_prep {
295                    rhs_prep.log2_data_size
296                } else {
297                    rhs.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
298                }
299            )
300        }
301    }
302    
303    fn get_prepared_operand<'a, R, V>(
304        &self,
305        data: V,
306        data_prep: &'a PreparedConvolutionOperand<R, C>,
307        rns_index: usize,
308        _ring: &R
309    ) -> &'a C::PreparedConvolutionOperand
310        where R: ?Sized + RingBase,
311            V: VectorView<El<Zn>> + Copy
312    {
313        data_prep.prepared.get_or_init(rns_index, || self.get_convolution(rns_index).prepare_convolution_operand(data, data_prep.len_hint, self.get_rns_factor(rns_index)))
314    }
315
316    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
317        &self,
318        lhs: V1,
319        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
320        rhs: V2,
321        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
322        dst: &mut [R::Element],
323        ring: &R,
324        mut to_int: ToInt,
325        mut from_int: FromInt,
326        ring_log2_el_size: Option<usize>
327    )
328        where R: ?Sized + RingBase,
329            V1: VectorView<R::Element>,
330            V2: VectorView<R::Element>,
331            ToInt: FnMut(&R::Element) -> El<I>,
332            FromInt: FnMut(El<I>) -> R::Element
333    {
334        if lhs.len() == 0 || rhs.len() == 0 {
335            return;
336        }
337
338        let input_size_log2 = self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size);
339        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
340        let len = lhs.len() + rhs.len() - 1;
341
342        let mut res_data = Vec::with_capacity_in(len * width, self.allocator.clone());
343        for i in 0..width {
344            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
345        }
346        let mut lhs_tmp = Vec::with_capacity_in(lhs.len(), self.allocator.clone());
347        let mut rhs_tmp = Vec::with_capacity_in(rhs.len(), self.allocator.clone());
348        for i in 0..width {
349            let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
350            lhs_tmp.clear();
351            lhs_tmp.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
352            rhs_tmp.clear();
353            rhs_tmp.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
354            self.get_convolution(i).compute_convolution_prepared(
355                &lhs_tmp, 
356                lhs_prep.map(|lhs_prep| self.get_prepared_operand(&lhs_tmp, lhs_prep, i, ring)),
357                &rhs_tmp, 
358                rhs_prep.map(|rhs_prep| self.get_prepared_operand(&rhs_tmp, rhs_prep, i, ring)),
359                &mut res_data[(i * len)..((i + 1) * len)], 
360                self.get_rns_factor(i)
361            );
362        }
363        for j in 0..len {
364            let add = self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j])));
365            ring.add_assign(&mut dst[j], from_int(add));
366        }
367    }
368
369    fn compute_convolution_sum_impl<'a, R, J, V1, V2, ToInt, FromInt>(
370        &self, 
371        values: J, 
372        dst: &mut [R::Element], 
373        ring: &R,
374        mut to_int: ToInt,
375        mut from_int: FromInt,
376        ring_log2_el_size: Option<usize>
377    ) 
378        where R: ?Sized + RingBase, 
379            J: ExactSizeIterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, C>>, V2, Option<&'a PreparedConvolutionOperand<R, C>>)>,
380            V1: VectorView<R::Element>,
381            V2: VectorView<R::Element>,
382            ToInt: FnMut(&R::Element) -> El<I>,
383            FromInt: FnMut(El<I>) -> R::Element,
384            Self: 'a,
385            R: 'a
386    {
387        let out_len = dst.len();
388        let inner_product_length = dst.len();
389
390        let mut current_width = 0;
391        let mut current_input_size_log2 = 0;
392        let mut lhs_max_len = 0;
393        let mut rhs_max_len = 0;
394        let mut res_data = Vec::new_in(self.allocator.clone());
395        let mut lhs_tmp = Vec::new_in(self.allocator.clone());
396        let mut rhs_tmp = Vec::new_in(self.allocator.clone());
397        
398        // the algorithm is as follows:
399        //  - we keep track of the current "width" (i.e. number of RNS factors) to represent the result
400        //  - we collect iterator elements, until the current width is not sufficient anymore
401        //  - then we do `merge_current()`: forward all collected samples to the child convolutions,
402        //    add the result to `dst`, and clear the buffers; continue with updated width
403
404        let mut merge_current = |
405            current_width: usize, 
406            lhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>, 
407            rhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>
408        | {
409            if current_width == 0 {
410                lhs_tmp.clear();
411                rhs_tmp.clear();
412                return;
413            }
414            res_data.clear();
415            for i in 0..current_width {
416                res_data.extend((0..out_len).map(|_| self.get_rns_factor(i).zero()));
417                self.get_convolution(i).compute_convolution_sum(
418                    lhs_tmp.iter().zip(rhs_tmp.iter()).map(|((lhs, lhs_prep), (rhs, rhs_prep))| {
419                        let lhs_data = &lhs[(i * lhs.len() / current_width)..((i + 1) * lhs.len() / current_width)];
420                        let rhs_data = &rhs[(i * rhs.len() / current_width)..((i + 1) * rhs.len() / current_width)];
421                        (
422                            lhs_data,
423                            lhs_prep.map(|lhs_prep| self.get_prepared_operand(lhs_data, lhs_prep, i, ring)),
424                            rhs_data,
425                            rhs_prep.map(|rhs_prep| self.get_prepared_operand(rhs_data, rhs_prep, i, ring)),
426                    )
427                    }),
428                    &mut res_data[(i * out_len)..((i + 1) * out_len)],
429                    self.get_rns_factor(i)
430                );
431            }
432            lhs_tmp.clear();
433            rhs_tmp.clear();
434            for j in 0..out_len {
435                let add = self.get_rns_ring(current_width).smallest_lift(self.get_rns_ring(current_width).from_congruence((0..current_width).map(|i| res_data[i * out_len + j])));
436                ring.add_assign(&mut dst[j], from_int(add));
437            }
438        };
439
440        for (lhs, lhs_prep, rhs, rhs_prep) in values {
441            if lhs.len() == 0 || rhs.len() == 0 {
442                continue;
443            }
444            assert!(out_len >= lhs.len() + rhs.len() - 1);
445            current_input_size_log2 = max(
446                current_input_size_log2,
447                self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size)
448            );
449            lhs_max_len = max(lhs_max_len, lhs.len());
450            rhs_max_len = max(rhs_max_len, rhs.len());
451            let required_width = self.compute_required_width(current_input_size_log2, lhs_max_len, rhs_max_len, inner_product_length);
452
453            if required_width > current_width {
454                merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
455                current_width = required_width;
456            }
457
458            lhs_tmp.push((Vec::with_capacity_in(lhs.len() * current_width, self.allocator.clone()), lhs_prep));
459            rhs_tmp.push((Vec::with_capacity_in(rhs.len() * current_width, self.allocator.clone()), rhs_prep));
460            for i in 0..current_width {
461                let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
462                lhs_tmp.last_mut().unwrap().0.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
463                rhs_tmp.last_mut().unwrap().0.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
464            }
465        }
466        merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
467    }
468    
469    fn prepare_convolution_impl<R, V, ToInt>(
470        &self,
471        data: V,
472        _ring: &R,
473        length_hint: Option<usize>,
474        mut to_int: ToInt,
475        ring_log2_el_size: Option<usize>
476    ) -> PreparedConvolutionOperand<R, C>
477        where R: ?Sized + RingBase,
478            V: VectorView<R::Element>,
479            ToInt: FnMut(&R::Element) -> El<I>
480    {
481        let input_size_log2 = if let Some(log2_data_size) = ring_log2_el_size {
482            log2_data_size
483        } else { 
484            data.as_iter().map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
485        };
486        return PreparedConvolutionOperand {
487            ring: PhantomData,
488            len_hint: length_hint,
489            prepared: LazyVec::new(),
490            log2_data_size: input_size_log2
491        };
492    }
493}
494
495impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
496    where I: RingStore + Clone,
497        I::Type: IntegerRing,
498        C: ConvolutionAlgorithm<ZnBase>,
499        A: Allocator + Clone,
500        CreateC: Fn(Zn) -> C,
501        R: ?Sized + IntegerRing
502{
503    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
504
505    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
506        self.compute_convolution_impl(
507            lhs,
508            None,
509            rhs,
510            None,
511            dst,
512            ring.get_ring(),
513            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
514            |x| int_cast(x, ring, &self.integer_ring),
515            None
516        )
517    }
518
519    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
520        true
521    }
522
523    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
524        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
525    {
526        self.prepare_convolution_impl(
527            val,
528            ring.get_ring(),
529            len_hint,
530            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
531            None
532        )
533    }
534    
535    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
536        where S: RingStore<Type = R> + Copy,
537            V1: VectorView<El<S>>,
538            V2: VectorView<El<S>>
539    {
540        self.compute_convolution_impl(
541            lhs,
542            lhs_prep,
543            rhs,
544            rhs_prep,
545            dst,
546            ring.get_ring(),
547            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
548            |x| int_cast(x, ring, &self.integer_ring),
549            None
550        )
551    }
552
553    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S) 
554        where S: RingStore<Type = R> + Copy, 
555            J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
556            V1: VectorView<R::Element>,
557            V2: VectorView<R::Element>,
558            Self: 'a,
559            R: 'a
560    {
561        self.compute_convolution_sum_impl(
562            values,
563            dst,
564            ring.get_ring(),
565            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
566            |x| int_cast(x, ring, &self.integer_ring),
567            None
568        )
569    }
570}
571
572impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
573    where I: RingStore + Clone,
574        I::Type: IntegerRing,
575        C: ConvolutionAlgorithm<ZnBase>,
576        A: Allocator + Clone,
577        CreateC: Fn(Zn) -> C,
578        R: ?Sized + ZnRing + CanHomFrom<I::Type>
579{
580    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
581
582    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
583        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
584        self.base.compute_convolution_impl(
585            lhs,
586            None,
587            rhs,
588            None,
589            dst,
590            ring.get_ring(),
591            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
592            |x| hom.map(x),
593            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
594        )
595    }
596
597    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
598        true
599    }
600
601    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
602        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
603    {
604        self.base.prepare_convolution_impl(
605            val,
606            ring.get_ring(),
607            len_hint,
608            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
609            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
610        )
611    }
612
613    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
614        where S: RingStore<Type = R> + Copy,
615            V1: VectorView<El<S>>,
616            V2: VectorView<El<S>>
617    {
618        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
619        self.base.compute_convolution_impl(
620            lhs,
621            lhs_prep,
622            rhs,
623            rhs_prep,
624            dst,
625            ring.get_ring(),
626            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
627            |x| hom.map(x),
628            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
629        )
630    }
631
632    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S) 
633        where S: RingStore<Type = R> + Copy, 
634            J: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
635            V1: VectorView<R::Element>,
636            V2: VectorView<R::Element>,
637            Self: 'a,
638            R: 'a
639    {
640        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
641        self.base.compute_convolution_sum_impl(
642            values,
643            dst,
644            ring.get_ring(),
645            |x| int_cast(ring.smallest_lift(ring.clone_el(x)), &self.base.integer_ring, ring.integer_ring()),
646            |x| hom.map(x),
647            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
648        )
649    }
650}
651
652#[cfg(test)]
653use super::STANDARD_CONVOLUTION;
654
655#[test]
656fn test_convolution_integer() {
657    let ring = StaticRing::<i128>::RING;
658    let convolution = RNSConvolution::new_with_convolution(7, usize::MAX, BigIntRing::RING, Global, NTTConvolution::new);
659
660    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
661}
662
663#[test]
664fn test_convolution_zn() {
665    let ring = Zn::new((1 << 57) + 1);
666    let convolution = RNSConvolutionZn::from(RNSConvolution::new_with_convolution(7, usize::MAX, BigIntRing::RING, Global, NTTConvolution::new));
667
668    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
669}
670
671#[test]
672fn test_convolution_sum() {
673    let ring = StaticRing::<i128>::RING;
674    let convolution = RNSConvolution::new_with_convolution(7, 20, BigIntRing::RING, Global, NTTConvolution::new);
675    
676    let data = (0..40usize).map(|i| (
677        (0..(5 + i % 5)).map(|x| (1 << i) * (x as i128 - 2)).collect::<Vec<_>>(),
678        (0..(13 - i % 7)).map(|x| (1 << i) * (x as i128 + 1)).collect::<Vec<_>>(),
679    ));
680    let mut expected = (0..22).map(|_| 0).collect::<Vec<_>>();
681    STANDARD_CONVOLUTION.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut expected, ring);
682
683    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
684    convolution.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut actual, ring);
685    assert_eq!(&expected[..21], actual);
686    
687    let data_prep = data.clone().map(|(l, r)| {
688        let l_prep = convolution.prepare_convolution_operand(&l, Some(21), ring);
689        let r_prep = convolution.prepare_convolution_operand(&r, Some(21), ring);
690        (l, l_prep, r, r_prep)
691    }).collect::<Vec<_>>();
692    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
693    convolution.compute_convolution_sum(data_prep.iter().map(|(l, l_prep, r, r_prep)| (l, Some(l_prep), r, Some(r_prep))), &mut actual, ring);
694    assert_eq!(&expected[..21], actual);
695    
696    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
697    convolution.compute_convolution_sum(data_prep.iter().enumerate().map(|(i, (l, l_prep, r, r_prep))| match i % 4 {
698        0 => (l, Some(l_prep), r, Some(r_prep)),
699        1 => (l, None, r, Some(r_prep)),
700        2 => (l, Some(l_prep), r, None),
701        3 => (l, None, r, None),
702        _ => unreachable!()
703    }), &mut actual, ring);
704    assert_eq!(&expected[..21], actual);
705}