Skip to main content

feanor_math/algorithms/convolution/
rns.rs

1use std::alloc::{Allocator, Global};
2use std::cmp::{max, min};
3use std::marker::PhantomData;
4
5use super::ConvolutionAlgorithm;
6use super::ntt::NTTConvolution;
7use crate::algorithms::miller_rabin::is_prime;
8use crate::divisibility::*;
9use crate::homomorphism::*;
10use crate::integer::*;
11use crate::lazy::LazyVec;
12use crate::primitive_int::StaticRing;
13use crate::ring::*;
14use crate::rings::zn::zn_64::{Zn, ZnBase, ZnFastmul, ZnFastmulBase};
15use crate::rings::zn::*;
16use crate::seq::*;
17
/// A [`ConvolutionAlgorithm`] that computes convolutions by computing them modulo a
/// suitable composite modulus `q`, whose factors are of a certain shape, usually such
/// as to allow for NTT-based convolutions.
///
/// Due to overlapping blanket impls, this type can only be used to compute convolutions
/// over [`IntegerRing`]s. For computing convolutions over [`ZnRing`]s, wrap it in a
/// [`RNSConvolutionZn`].
#[stability::unstable(feature = "enable")]
pub struct RNSConvolution<
    I = BigIntRing,
    C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>,
    A = Global,
    CreateC = CreateNTTConvolution,
> where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
{
    /// Integer ring in which inputs are lifted and outputs are reconstructed.
    integer_ring: I,
    /// Entry `k` is the RNS ring whose modulus is the product of the first `k + 1` sampled
    /// primes; entries are created lazily, each one extending the previous by one prime.
    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
    /// Entry `i` is the convolution algorithm used modulo the `i`-th sampled prime; created
    /// lazily by calling `create_convolution`.
    convolutions: LazyVec<C>,
    /// Factory for the per-prime convolution algorithms.
    create_convolution: CreateC,
    /// Every sampled prime `p` satisfies `p = 1 mod 2^required_root_of_unity_log2`.
    required_root_of_unity_log2: usize,
    /// Allocator for the RNS rings and for temporary buffers.
    allocator: A,
}
45
/// Same as [`RNSConvolution`], but computes convolutions over [`ZnRing`]s.
///
/// Inputs are lifted to integers via `smallest_lift`, convolved as integers by the wrapped
/// [`RNSConvolution`], and the result is mapped back into the ring.
// NOTE: `#[repr(transparent)]` is load-bearing: the conversion
// `&RNSConvolution -> &RNSConvolutionZn` in this file reinterprets the reference, which is
// only sound because both types have identical layout.
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct RNSConvolutionZn<
    I = BigIntRing,
    C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>,
    A = Global,
    CreateC = CreateNTTConvolution,
> where
    I: RingStore + Clone,
    I::Type: IntegerRing,
    C: ConvolutionAlgorithm<ZnBase>,
    A: Allocator + Clone,
    CreateC: Fn(Zn) -> C,
{
    /// The integer convolution doing all the actual work.
    base: RNSConvolution<I, C, A, CreateC>,
}
63
/// A prepared convolution operand for an [`RNSConvolution`] or an [`RNSConvolutionZn`].
///
/// Preparation data for the individual prime factors is computed lazily, the first time the
/// operand is used modulo that prime.
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, C = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>>>
where
    R: ?Sized + RingBase,
    C: ConvolutionAlgorithm<ZnBase>,
{
    /// Entry `i` is the operand prepared by the convolution algorithm of the `i`-th prime.
    prepared: LazyVec<C::PreparedConvolutionOperand>,
    /// Upper bound on `log2 |x|` over all entries `x` of the (lifted) operand; used to choose
    /// the number of RNS factors.
    log2_data_size: usize,
    /// Ties the prepared operand to the ring type it was prepared for.
    ring: PhantomData<R>,
    /// Length hint forwarded to the per-prime `prepare_convolution_operand` calls.
    len_hint: Option<usize>,
}
76
/// Function that creates an [`NTTConvolution`] when given a suitable modulus.
///
/// Calling it with a [`Zn`] builds the corresponding [`ZnFastmul`] ring and an
/// [`NTTConvolution`] over the canonical homomorphism between the two.
#[stability::unstable(feature = "enable")]
pub struct CreateNTTConvolution<A = Global>
where
    A: Allocator + Clone,
{
    /// Allocator handed (as a clone) to every created [`NTTConvolution`].
    allocator: A,
}
85
86impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
87where
88    I: RingStore + Clone,
89    I::Type: IntegerRing,
90    C: ConvolutionAlgorithm<ZnBase>,
91    A: Allocator + Clone,
92    CreateC: Fn(Zn) -> C,
93{
94    fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self { value.base }
95}
96
97impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
98where
99    I: RingStore + Clone,
100    I::Type: IntegerRing,
101    C: ConvolutionAlgorithm<ZnBase>,
102    A: Allocator + Clone,
103    CreateC: Fn(Zn) -> C,
104{
105    fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self { &value.base }
106}
107
108impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
109where
110    I: RingStore + Clone,
111    I::Type: IntegerRing,
112    C: ConvolutionAlgorithm<ZnBase>,
113    A: Allocator + Clone,
114    CreateC: Fn(Zn) -> C,
115{
116    fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self { RNSConvolutionZn { base: value } }
117}
118
119impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
120where
121    I: RingStore + Clone,
122    I::Type: IntegerRing,
123    C: ConvolutionAlgorithm<ZnBase>,
124    A: Allocator + Clone,
125    CreateC: Fn(Zn) -> C,
126{
127    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self { unsafe { std::mem::transmute(value) } }
128}
129
130impl CreateNTTConvolution<Global> {
131    /// Creates a new [`CreateNTTConvolution`].
132    #[stability::unstable(feature = "enable")]
133    pub const fn new() -> Self { Self { allocator: Global } }
134}
135
136impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
137where
138    A: Allocator + Clone,
139{
140    type Output = NTTConvolution<ZnBase, ZnFastmulBase, CanHom<ZnFastmul, Zn>, A>;
141
142    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output { self.call(args) }
143}
144
145impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
146where
147    A: Allocator + Clone,
148{
149    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output { self.call(args) }
150}
151
152impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
153where
154    A: Allocator + Clone,
155{
156    extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
157        let ring = args.0;
158        let ring_fastmul = ZnFastmul::new(ring).unwrap();
159        let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
160        NTTConvolution::new_with_hom(hom, self.allocator.clone())
161    }
162}
163
164impl RNSConvolution {
165    /// Creates a new [`RNSConvolution`] that can compute convolutions of sequences with output
166    /// length `<= 2^max_log2_n`. As base convolution, the [`NTTConvolution`] is used.
167    #[stability::unstable(feature = "enable")]
168    pub fn new(max_log2_n: usize) -> Self {
169        Self::new_with_convolution(
170            max_log2_n,
171            usize::MAX,
172            BigIntRing::RING,
173            Global,
174            CreateNTTConvolution { allocator: Global },
175        )
176    }
177}
178
179impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
180where
181    I: RingStore + Clone,
182    I::Type: IntegerRing,
183    C: ConvolutionAlgorithm<ZnBase>,
184    A: Allocator + Clone,
185    CreateC: Fn(Zn) -> C,
186{
187    /// Creates a new [`RNSConvolution`] with all the given configuration parameters.
188    ///
189    /// In particular
190    ///  - `required_root_of_unity_log2` and `max_prime_size_log2` control which prime factors are
191    ///    used for the underlying composite modulus; Only primes `<= 2^max_prime_size_log2` and `=
192    ///    1` mod `required_root_of_unity_log2` are sampled
193    ///  - `integer_ring` is the ring to store intermediate lifts in; this probably has to be
194    ///    [`BigIntRing`], unless inputs are pretty small
195    ///  - `allocator` is used to allocate elements modulo the internal modulus, as elements of
196    ///    [`zn_rns::Zn`]
197    ///  - `create_convolution` is called whenever a new convolution algorithm for a new prime has
198    ///    to be created; the modulus of the given [`Zn`] always satisfies the constraints defined
199    ///    by `max_prime_size_log2` and `required_root_of_unity_log2`
200    #[stability::unstable(feature = "enable")]
201    pub fn new_with_convolution(
202        required_root_of_unity_log2: usize,
203        mut max_prime_size_log2: usize,
204        integer_ring: I,
205        allocator: A,
206        create_convolution: CreateC,
207    ) -> Self {
208        max_prime_size_log2 = min(max_prime_size_log2, 57);
209        let result = Self {
210            integer_ring,
211            create_convolution,
212            convolutions: LazyVec::new(),
213            rns_rings: LazyVec::new(),
214            required_root_of_unity_log2,
215            allocator,
216        };
217        let initial_ring = zn_rns::Zn::new_with_alloc(
218            vec![Zn::new(
219                Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64,
220            )],
221            result.integer_ring.clone(),
222            result.allocator.clone(),
223        );
224        _ = result.rns_rings.get_or_init(0, || initial_ring);
225        return result;
226    }
227
228    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
229        let mut k = StaticRing::<i64>::RING
230            .checked_div(&(current - 1), &(1 << required_root_of_unity_log2))
231            .unwrap();
232        while k > 0 {
233            k -= 1;
234            let candidate = (k << required_root_of_unity_log2) + 1;
235            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
236                return Some(candidate);
237            }
238        }
239        return None;
240    }
241
242    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
243        self.rns_rings.get_or_init_incremental(moduli_count - 1, |_, prev| {
244            zn_rns::Zn::new_with_alloc(
245                prev.as_iter()
246                    .cloned()
247                    .chain([Zn::new(
248                        Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus())
249                            .unwrap() as u64,
250                    )])
251                    .collect(),
252                self.integer_ring.clone(),
253                self.allocator.clone(),
254            )
255        })
256    }
257
258    fn get_rns_factor(&self, i: usize) -> &Zn {
259        let rns_ring = self.get_rns_ring(i + 1);
260        return rns_ring.at(rns_ring.len() - 1);
261    }
262
263    fn get_convolution(&self, i: usize) -> &C {
264        self.convolutions
265            .get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
266    }
267
268    /// "width" refers to the number of RNS factors we need
269    fn compute_required_width(
270        &self,
271        input_size_log2: usize,
272        lhs_len: usize,
273        rhs_len: usize,
274        inner_prod_len: usize,
275    ) -> usize {
276        let log2_output_size = input_size_log2 * 2
277            + StaticRing::<i64>::RING
278                .abs_log2_ceil(&min(lhs_len, rhs_len).try_into().unwrap())
279                .unwrap_or(0)
280            + StaticRing::<i64>::RING
281                .abs_log2_ceil(&inner_prod_len.try_into().unwrap())
282                .unwrap_or(0)
283            + 1;
284        let mut width = log2_output_size.div_ceil(57);
285        while log2_output_size
286            > self
287                .integer_ring
288                .abs_log2_floor(self.get_rns_ring(width).modulus())
289                .unwrap()
290        {
291            width += 1;
292        }
293        return width;
294    }
295
296    fn get_log2_input_size<R, V1, V2, ToInt>(
297        &self,
298        lhs: V1,
299        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
300        rhs: V2,
301        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
302        _ring: &R,
303        mut to_int: ToInt,
304        ring_log2_el_size: Option<usize>,
305    ) -> usize
306    where
307        R: ?Sized + RingBase,
308        V1: VectorView<R::Element>,
309        V2: VectorView<R::Element>,
310        ToInt: FnMut(&R::Element) -> El<I>,
311    {
312        if let Some(log2_data_size) = ring_log2_el_size {
313            assert!(lhs_prep.is_none() || lhs_prep.unwrap().log2_data_size == log2_data_size);
314            assert!(rhs_prep.is_none() || rhs_prep.unwrap().log2_data_size == log2_data_size);
315            log2_data_size
316        } else {
317            max(
318                if let Some(lhs_prep) = lhs_prep {
319                    lhs_prep.log2_data_size
320                } else {
321                    lhs.as_iter()
322                        .map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0))
323                        .max()
324                        .unwrap()
325                },
326                if let Some(rhs_prep) = rhs_prep {
327                    rhs_prep.log2_data_size
328                } else {
329                    rhs.as_iter()
330                        .map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0))
331                        .max()
332                        .unwrap()
333                },
334            )
335        }
336    }
337
338    fn get_prepared_operand<'a, R, V>(
339        &self,
340        data: V,
341        data_prep: &'a PreparedConvolutionOperand<R, C>,
342        rns_index: usize,
343        _ring: &R,
344    ) -> &'a C::PreparedConvolutionOperand
345    where
346        R: ?Sized + RingBase,
347        V: VectorView<El<Zn>> + Copy,
348    {
349        data_prep.prepared.get_or_init(rns_index, || {
350            self.get_convolution(rns_index).prepare_convolution_operand(
351                data,
352                data_prep.len_hint,
353                self.get_rns_factor(rns_index),
354            )
355        })
356    }
357
358    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
359        &self,
360        lhs: V1,
361        lhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
362        rhs: V2,
363        rhs_prep: Option<&PreparedConvolutionOperand<R, C>>,
364        dst: &mut [R::Element],
365        ring: &R,
366        mut to_int: ToInt,
367        mut from_int: FromInt,
368        ring_log2_el_size: Option<usize>,
369    ) where
370        R: ?Sized + RingBase,
371        V1: VectorView<R::Element>,
372        V2: VectorView<R::Element>,
373        ToInt: FnMut(&R::Element) -> El<I>,
374        FromInt: FnMut(El<I>) -> R::Element,
375    {
376        if lhs.len() == 0 || rhs.len() == 0 {
377            return;
378        }
379
380        let input_size_log2 =
381            self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size);
382        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
383        let len = lhs.len() + rhs.len() - 1;
384
385        let mut res_data = Vec::with_capacity_in(len * width, self.allocator.clone());
386        for i in 0..width {
387            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
388        }
389        let mut lhs_tmp = Vec::with_capacity_in(lhs.len(), self.allocator.clone());
390        let mut rhs_tmp = Vec::with_capacity_in(rhs.len(), self.allocator.clone());
391        for i in 0..width {
392            let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
393            lhs_tmp.clear();
394            lhs_tmp.extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
395            rhs_tmp.clear();
396            rhs_tmp.extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
397            self.get_convolution(i).compute_convolution_prepared(
398                &lhs_tmp,
399                lhs_prep.map(|lhs_prep| self.get_prepared_operand(&lhs_tmp, lhs_prep, i, ring)),
400                &rhs_tmp,
401                rhs_prep.map(|rhs_prep| self.get_prepared_operand(&rhs_tmp, rhs_prep, i, ring)),
402                &mut res_data[(i * len)..((i + 1) * len)],
403                self.get_rns_factor(i),
404            );
405        }
406        for j in 0..len {
407            let add = self.get_rns_ring(width).smallest_lift(
408                self.get_rns_ring(width)
409                    .from_congruence((0..width).map(|i| res_data[i * len + j])),
410            );
411            ring.add_assign(&mut dst[j], from_int(add));
412        }
413    }
414
415    fn compute_convolution_sum_impl<'a, R, J, V1, V2, ToInt, FromInt>(
416        &self,
417        values: J,
418        dst: &mut [R::Element],
419        ring: &R,
420        mut to_int: ToInt,
421        mut from_int: FromInt,
422        ring_log2_el_size: Option<usize>,
423    ) where
424        R: ?Sized + RingBase,
425        J: ExactSizeIterator<
426            Item = (
427                V1,
428                Option<&'a PreparedConvolutionOperand<R, C>>,
429                V2,
430                Option<&'a PreparedConvolutionOperand<R, C>>,
431            ),
432        >,
433        V1: VectorView<R::Element>,
434        V2: VectorView<R::Element>,
435        ToInt: FnMut(&R::Element) -> El<I>,
436        FromInt: FnMut(El<I>) -> R::Element,
437        Self: 'a,
438        R: 'a,
439    {
440        let out_len = dst.len();
441        let inner_product_length = dst.len();
442
443        let mut current_width = 0;
444        let mut current_input_size_log2 = 0;
445        let mut lhs_max_len = 0;
446        let mut rhs_max_len = 0;
447        let mut res_data = Vec::new_in(self.allocator.clone());
448        let mut lhs_tmp = Vec::new_in(self.allocator.clone());
449        let mut rhs_tmp = Vec::new_in(self.allocator.clone());
450
451        // the algorithm is as follows:
452        //  - we keep track of the current "width" (i.e. number of RNS factors) to represent the result
453        //  - we collect iterator elements, until the current width is not sufficient anymore
454        //  - then we do `merge_current()`: forward all collected samples to the child convolutions, add the
455        //    result to `dst`, and clear the buffers; continue with updated width
456
457        let mut merge_current =
458            |current_width: usize,
459             lhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>,
460             rhs_tmp: &mut Vec<(Vec<El<Zn>, _>, Option<&'a PreparedConvolutionOperand<R, C>>), _>| {
461                if current_width == 0 {
462                    lhs_tmp.clear();
463                    rhs_tmp.clear();
464                    return;
465                }
466                res_data.clear();
467                for i in 0..current_width {
468                    res_data.extend((0..out_len).map(|_| self.get_rns_factor(i).zero()));
469                    self.get_convolution(i).compute_convolution_sum(
470                        lhs_tmp
471                            .iter()
472                            .zip(rhs_tmp.iter())
473                            .map(|((lhs, lhs_prep), (rhs, rhs_prep))| {
474                                let lhs_data =
475                                    &lhs[(i * lhs.len() / current_width)..((i + 1) * lhs.len() / current_width)];
476                                let rhs_data =
477                                    &rhs[(i * rhs.len() / current_width)..((i + 1) * rhs.len() / current_width)];
478                                (
479                                    lhs_data,
480                                    lhs_prep.map(|lhs_prep| self.get_prepared_operand(lhs_data, lhs_prep, i, ring)),
481                                    rhs_data,
482                                    rhs_prep.map(|rhs_prep| self.get_prepared_operand(rhs_data, rhs_prep, i, ring)),
483                                )
484                            }),
485                        &mut res_data[(i * out_len)..((i + 1) * out_len)],
486                        self.get_rns_factor(i),
487                    );
488                }
489                lhs_tmp.clear();
490                rhs_tmp.clear();
491                for j in 0..out_len {
492                    let add = self.get_rns_ring(current_width).smallest_lift(
493                        self.get_rns_ring(current_width)
494                            .from_congruence((0..current_width).map(|i| res_data[i * out_len + j])),
495                    );
496                    ring.add_assign(&mut dst[j], from_int(add));
497                }
498            };
499
500        for (lhs, lhs_prep, rhs, rhs_prep) in values {
501            if lhs.len() == 0 || rhs.len() == 0 {
502                continue;
503            }
504            assert!(out_len >= lhs.len() + rhs.len() - 1);
505            current_input_size_log2 = max(
506                current_input_size_log2,
507                self.get_log2_input_size(&lhs, lhs_prep, &rhs, rhs_prep, ring, &mut to_int, ring_log2_el_size),
508            );
509            lhs_max_len = max(lhs_max_len, lhs.len());
510            rhs_max_len = max(rhs_max_len, rhs.len());
511            let required_width =
512                self.compute_required_width(current_input_size_log2, lhs_max_len, rhs_max_len, inner_product_length);
513
514            if required_width > current_width {
515                merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
516                current_width = required_width;
517            }
518
519            lhs_tmp.push((
520                Vec::with_capacity_in(lhs.len() * current_width, self.allocator.clone()),
521                lhs_prep,
522            ));
523            rhs_tmp.push((
524                Vec::with_capacity_in(rhs.len() * current_width, self.allocator.clone()),
525                rhs_prep,
526            ));
527            for i in 0..current_width {
528                let hom = self.get_rns_factor(i).into_can_hom(&self.integer_ring).ok().unwrap();
529                lhs_tmp
530                    .last_mut()
531                    .unwrap()
532                    .0
533                    .extend(lhs.as_iter().map(|x| hom.map(to_int(x))));
534                rhs_tmp
535                    .last_mut()
536                    .unwrap()
537                    .0
538                    .extend(rhs.as_iter().map(|x| hom.map(to_int(x))));
539            }
540        }
541        merge_current(current_width, &mut lhs_tmp, &mut rhs_tmp);
542    }
543
544    fn prepare_convolution_impl<R, V, ToInt>(
545        &self,
546        data: V,
547        _ring: &R,
548        length_hint: Option<usize>,
549        mut to_int: ToInt,
550        ring_log2_el_size: Option<usize>,
551    ) -> PreparedConvolutionOperand<R, C>
552    where
553        R: ?Sized + RingBase,
554        V: VectorView<R::Element>,
555        ToInt: FnMut(&R::Element) -> El<I>,
556    {
557        let input_size_log2 = if let Some(log2_data_size) = ring_log2_el_size {
558            log2_data_size
559        } else {
560            data.as_iter()
561                .map(|x| self.integer_ring.abs_log2_ceil(&to_int(x)).unwrap_or(0))
562                .max()
563                .unwrap()
564        };
565        return PreparedConvolutionOperand {
566            ring: PhantomData,
567            len_hint: length_hint,
568            prepared: LazyVec::new(),
569            log2_data_size: input_size_log2,
570        };
571    }
572}
573
574impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
575where
576    I: RingStore + Clone,
577    I::Type: IntegerRing,
578    C: ConvolutionAlgorithm<ZnBase>,
579    A: Allocator + Clone,
580    CreateC: Fn(Zn) -> C,
581    R: ?Sized + IntegerRing,
582{
583    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
584
585    fn compute_convolution<
586        S: RingStore<Type = R> + Copy,
587        V1: VectorView<<R as RingBase>::Element>,
588        V2: VectorView<<R as RingBase>::Element>,
589    >(
590        &self,
591        lhs: V1,
592        rhs: V2,
593        dst: &mut [<R as RingBase>::Element],
594        ring: S,
595    ) {
596        self.compute_convolution_impl(
597            lhs,
598            None,
599            rhs,
600            None,
601            dst,
602            ring.get_ring(),
603            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
604            |x| int_cast(x, ring, &self.integer_ring),
605            None,
606        )
607    }
608
609    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
610
611    fn prepare_convolution_operand<S, V>(
612        &self,
613        val: V,
614        len_hint: Option<usize>,
615        ring: S,
616    ) -> Self::PreparedConvolutionOperand
617    where
618        S: RingStore<Type = R> + Copy,
619        V: VectorView<R::Element>,
620    {
621        self.prepare_convolution_impl(
622            val,
623            ring.get_ring(),
624            len_hint,
625            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
626            None,
627        )
628    }
629
630    fn compute_convolution_prepared<S, V1, V2>(
631        &self,
632        lhs: V1,
633        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
634        rhs: V2,
635        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
636        dst: &mut [R::Element],
637        ring: S,
638    ) where
639        S: RingStore<Type = R> + Copy,
640        V1: VectorView<El<S>>,
641        V2: VectorView<El<S>>,
642    {
643        self.compute_convolution_impl(
644            lhs,
645            lhs_prep,
646            rhs,
647            rhs_prep,
648            dst,
649            ring.get_ring(),
650            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
651            |x| int_cast(x, ring, &self.integer_ring),
652            None,
653        )
654    }
655
656    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
657    where
658        S: RingStore<Type = R> + Copy,
659        J: ExactSizeIterator<
660            Item = (
661                V1,
662                Option<&'a Self::PreparedConvolutionOperand>,
663                V2,
664                Option<&'a Self::PreparedConvolutionOperand>,
665            ),
666        >,
667        V1: VectorView<R::Element>,
668        V2: VectorView<R::Element>,
669        Self: 'a,
670        R: 'a,
671    {
672        self.compute_convolution_sum_impl(
673            values,
674            dst,
675            ring.get_ring(),
676            |x| int_cast(ring.clone_el(x), &self.integer_ring, ring),
677            |x| int_cast(x, ring, &self.integer_ring),
678            None,
679        )
680    }
681}
682
683impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
684where
685    I: RingStore + Clone,
686    I::Type: IntegerRing,
687    C: ConvolutionAlgorithm<ZnBase>,
688    A: Allocator + Clone,
689    CreateC: Fn(Zn) -> C,
690    R: ?Sized + ZnRing + CanHomFrom<I::Type>,
691{
692    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
693
694    fn compute_convolution<
695        S: RingStore<Type = R> + Copy,
696        V1: VectorView<<R as RingBase>::Element>,
697        V2: VectorView<<R as RingBase>::Element>,
698    >(
699        &self,
700        lhs: V1,
701        rhs: V2,
702        dst: &mut [<R as RingBase>::Element],
703        ring: S,
704    ) {
705        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
706        self.base.compute_convolution_impl(
707            lhs,
708            None,
709            rhs,
710            None,
711            dst,
712            ring.get_ring(),
713            |x| {
714                int_cast(
715                    ring.smallest_lift(ring.clone_el(x)),
716                    &self.base.integer_ring,
717                    ring.integer_ring(),
718                )
719            },
720            |x| hom.map(x),
721            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
722        )
723    }
724
725    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
726
727    fn prepare_convolution_operand<S, V>(
728        &self,
729        val: V,
730        len_hint: Option<usize>,
731        ring: S,
732    ) -> Self::PreparedConvolutionOperand
733    where
734        S: RingStore<Type = R> + Copy,
735        V: VectorView<R::Element>,
736    {
737        self.base.prepare_convolution_impl(
738            val,
739            ring.get_ring(),
740            len_hint,
741            |x| {
742                int_cast(
743                    ring.smallest_lift(ring.clone_el(x)),
744                    &self.base.integer_ring,
745                    ring.integer_ring(),
746                )
747            },
748            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
749        )
750    }
751
752    fn compute_convolution_prepared<S, V1, V2>(
753        &self,
754        lhs: V1,
755        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
756        rhs: V2,
757        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
758        dst: &mut [R::Element],
759        ring: S,
760    ) where
761        S: RingStore<Type = R> + Copy,
762        V1: VectorView<El<S>>,
763        V2: VectorView<El<S>>,
764    {
765        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
766        self.base.compute_convolution_impl(
767            lhs,
768            lhs_prep,
769            rhs,
770            rhs_prep,
771            dst,
772            ring.get_ring(),
773            |x| {
774                int_cast(
775                    ring.smallest_lift(ring.clone_el(x)),
776                    &self.base.integer_ring,
777                    ring.integer_ring(),
778                )
779            },
780            |x| hom.map(x),
781            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
782        )
783    }
784
785    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [R::Element], ring: S)
786    where
787        S: RingStore<Type = R> + Copy,
788        J: ExactSizeIterator<
789            Item = (
790                V1,
791                Option<&'a Self::PreparedConvolutionOperand>,
792                V2,
793                Option<&'a Self::PreparedConvolutionOperand>,
794            ),
795        >,
796        V1: VectorView<R::Element>,
797        V2: VectorView<R::Element>,
798        Self: 'a,
799        R: 'a,
800    {
801        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
802        self.base.compute_convolution_sum_impl(
803            values,
804            dst,
805            ring.get_ring(),
806            |x| {
807                int_cast(
808                    ring.smallest_lift(ring.clone_el(x)),
809                    &self.base.integer_ring,
810                    ring.integer_ring(),
811                )
812            },
813            |x| hom.map(x),
814            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
815        )
816    }
817}
818
819#[cfg(test)]
820use super::STANDARD_CONVOLUTION;
821
822#[test]
823fn test_convolution_integer() {
824    let ring = StaticRing::<i128>::RING;
825    let convolution =
826        RNSConvolution::new_with_convolution(7, usize::MAX, BigIntRing::RING, Global, NTTConvolution::new);
827
828    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
829}
830
831#[test]
832fn test_convolution_zn() {
833    let ring = Zn::new((1 << 57) + 1);
834    let convolution = RNSConvolutionZn::from(RNSConvolution::new_with_convolution(
835        7,
836        usize::MAX,
837        BigIntRing::RING,
838        Global,
839        NTTConvolution::new,
840    ));
841
842    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
843}
844
845#[test]
846fn test_convolution_sum() {
847    let ring = StaticRing::<i128>::RING;
848    let convolution = RNSConvolution::new_with_convolution(7, 20, BigIntRing::RING, Global, NTTConvolution::new);
849
850    let data = (0..40usize).map(|i| {
851        (
852            (0..(5 + i % 5)).map(|x| (1 << i) * (x as i128 - 2)).collect::<Vec<_>>(),
853            (0..(13 - i % 7))
854                .map(|x| (1 << i) * (x as i128 + 1))
855                .collect::<Vec<_>>(),
856        )
857    });
858    let mut expected = (0..22).map(|_| 0).collect::<Vec<_>>();
859    STANDARD_CONVOLUTION.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut expected, ring);
860
861    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
862    convolution.compute_convolution_sum(data.clone().map(|(l, r)| (l, None, r, None)), &mut actual, ring);
863    assert_eq!(&expected[..21], actual);
864
865    let data_prep = data
866        .clone()
867        .map(|(l, r)| {
868            let l_prep = convolution.prepare_convolution_operand(&l, Some(21), ring);
869            let r_prep = convolution.prepare_convolution_operand(&r, Some(21), ring);
870            (l, l_prep, r, r_prep)
871        })
872        .collect::<Vec<_>>();
873    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
874    convolution.compute_convolution_sum(
875        data_prep
876            .iter()
877            .map(|(l, l_prep, r, r_prep)| (l, Some(l_prep), r, Some(r_prep))),
878        &mut actual,
879        ring,
880    );
881    assert_eq!(&expected[..21], actual);
882
883    let mut actual = (0..21).map(|_| 0).collect::<Vec<_>>();
884    convolution.compute_convolution_sum(
885        data_prep
886            .iter()
887            .enumerate()
888            .map(|(i, (l, l_prep, r, r_prep))| match i % 4 {
889                0 => (l, Some(l_prep), r, Some(r_prep)),
890                1 => (l, None, r, Some(r_prep)),
891                2 => (l, Some(l_prep), r, None),
892                3 => (l, None, r, None),
893                _ => unreachable!(),
894            }),
895        &mut actual,
896        ring,
897    );
898    assert_eq!(&expected[..21], actual);
899}