feanor_math/algorithms/convolution/rns.rs

use std::alloc::{Allocator, Global};
use std::cmp::{min, max};
use std::marker::PhantomData;

use crate::algorithms::miller_rabin::is_prime;
use crate::homomorphism::*;
use crate::integer::*;
use crate::lazy::LazyVec;
use crate::primitive_int::StaticRing;
use crate::ring::*;
use crate::rings::zn::zn_64::{Zn, ZnBase};
use crate::rings::zn::*;
use crate::divisibility::*;
use crate::seq::*;

use super::ntt::NTTConvolution;
use super::{ConvolutionAlgorithm, PreparedConvolutionAlgorithm, PreparedConvolutionOperation};

///
/// A [`ConvolutionAlgorithm`] that computes convolutions by reducing them modulo a
/// suitable composite modulus `q`, whose prime factors are chosen of a shape that
/// allows for NTT-based convolutions.
/// 
/// Due to overlapping blanket impls, this type can only be used to compute convolutions
/// over [`IntegerRing`]s. For computing convolutions over [`ZnRing`]s, wrap it in a
/// [`RNSConvolutionZn`].
/// 
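/// # Example
/// 
/// A minimal usage sketch, assuming slice references implement [`VectorView`] (as used
/// elsewhere in this module) and that this module is reachable as
/// `feanor_math::algorithms::convolution::rns`. Note that the result of the convolution
/// is *added* to `dst`, so `dst` should be zero-initialized.
/// ```
/// use feanor_math::algorithms::convolution::ConvolutionAlgorithm;
/// use feanor_math::algorithms::convolution::rns::RNSConvolution;
/// use feanor_math::primitive_int::StaticRing;
/// 
/// let ring = StaticRing::<i64>::RING;
/// // supports convolutions with output length up to 2^10
/// let convolution = RNSConvolution::new(10);
/// let lhs: [i64; 3] = [1, 2, 3];
/// let rhs: [i64; 2] = [4, 5];
/// // dst must have length >= lhs.len() + rhs.len() - 1
/// let mut dst = [0; 4];
/// convolution.compute_convolution(&lhs[..], &rhs[..], &mut dst, ring);
/// assert_eq!([4, 13, 22, 15], dst);
/// ```
/// 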
#[stability::unstable(feature = "enable")]
pub struct RNSConvolution<I = BigIntRing, C = NTTConvolution<Zn>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    integer_ring: I,
    rns_rings: LazyVec<zn_rns::Zn<Zn, I, A>>,
    convolutions: LazyVec<C>,
    create_convolution: CreateC,
    required_root_of_unity_log2: usize,
    allocator: A
}

///
/// Same as [`RNSConvolution`], but computes convolutions over [`ZnRing`]s.
/// 
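/// # Example
/// 
/// A sketch of wrapping an [`RNSConvolution`] via the [`From`] conversions further down in
/// this file, and using it over [`Zn`]; the assumptions are the same as in the example on
/// [`RNSConvolution`].
/// ```
/// use feanor_math::algorithms::convolution::ConvolutionAlgorithm;
/// use feanor_math::algorithms::convolution::rns::{RNSConvolution, RNSConvolutionZn};
/// use feanor_math::homomorphism::*;
/// use feanor_math::ring::*;
/// use feanor_math::rings::zn::zn_64::Zn;
/// 
/// let ring = Zn::new(257);
/// let convolution = RNSConvolutionZn::from(RNSConvolution::new(10));
/// let hom = ring.int_hom();
/// let lhs = [hom.map(1), hom.map(2), hom.map(3)];
/// let rhs = [hom.map(4), hom.map(5)];
/// let mut dst = [ring.zero(); 4];
/// convolution.compute_convolution(&lhs[..], &rhs[..], &mut dst, ring);
/// // (1 + 2 X + 3 X^2) * (4 + 5 X) has coefficient 22 at X^2
/// assert!(ring.eq_el(&dst[2], &hom.map(22)));
/// ```
/// 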
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct RNSConvolutionZn<I = BigIntRing, C = NTTConvolution<Zn>, A = Global, CreateC = CreateNTTConvolution>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    base: RNSConvolution<I, C, A, CreateC>
}

///
/// A prepared convolution operand for a [`RNSConvolution`].
/// 
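/// # Example
/// 
/// A sketch of how operands might be prepared once and reused, assuming the
/// [`PreparedConvolutionAlgorithm`] implementation further down in this file:
/// ```
/// use feanor_math::algorithms::convolution::PreparedConvolutionAlgorithm;
/// use feanor_math::algorithms::convolution::rns::RNSConvolution;
/// use feanor_math::primitive_int::StaticRing;
/// 
/// let ring = StaticRing::<i64>::RING;
/// let convolution = RNSConvolution::new(10);
/// let lhs: [i64; 3] = [1, 2, 3];
/// let rhs: [i64; 2] = [4, 5];
/// // prepare both operands once, then reuse them for many convolutions
/// let lhs_prep = convolution.prepare_convolution_operand(&lhs[..], ring);
/// let rhs_prep = convolution.prepare_convolution_operand(&rhs[..], ring);
/// let mut dst = [0; 4];
/// convolution.compute_convolution_prepared(&lhs_prep, &rhs_prep, &mut dst, ring);
/// assert_eq!([4, 13, 22, 15], dst);
/// ```
/// 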
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, C = NTTConvolution<Zn>>
    where R: ?Sized + RingBase,
        C: PreparedConvolutionAlgorithm<ZnBase>
{
    data: Vec<R::Element>,
    prepared: LazyVec<C::PreparedConvolutionOperand>,
    log2_data_size: usize
}

///
/// A prepared convolution operand for a [`RNSConvolutionZn`].
/// 
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperandZn<R, C = NTTConvolution<Zn>>(PreparedConvolutionOperand<R::IntegerRingBase, C>)
    where R: ?Sized + ZnRing,
        C: PreparedConvolutionAlgorithm<ZnBase>;

///
/// Function that creates an [`NTTConvolution`] when given a suitable modulus.
/// 
#[stability::unstable(feature = "enable")]
pub struct CreateNTTConvolution<A = Global>
    where A: Allocator + Clone
{
    allocator: A
}

impl<I, C, A, CreateC> From<RNSConvolutionZn<I, C, A, CreateC>> for RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: RNSConvolutionZn<I, C, A, CreateC>) -> Self {
        value.base
    }
}

impl<'a, I, C, A, CreateC> From<&'a RNSConvolutionZn<I, C, A, CreateC>> for &'a RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: &'a RNSConvolutionZn<I, C, A, CreateC>) -> Self {
        &value.base
    }
}

impl<I, C, A, CreateC> From<RNSConvolution<I, C, A, CreateC>> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: RNSConvolution<I, C, A, CreateC>) -> Self {
        RNSConvolutionZn { base: value }
    }
}

impl<'a, I, C, A, CreateC> From<&'a RNSConvolution<I, C, A, CreateC>> for &'a RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn from(value: &'a RNSConvolution<I, C, A, CreateC>) -> Self {
        // sound, since `RNSConvolutionZn` is a `#[repr(transparent)]` wrapper around `RNSConvolution`
        unsafe { std::mem::transmute(value) }
    }
}

impl CreateNTTConvolution<Global> {
    #[stability::unstable(feature = "enable")]
    pub fn new() -> Self {
        Self { allocator: Global }
    }
}

impl<A> FnOnce<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    type Output = NTTConvolution<Zn, A>;

    extern "rust-call" fn call_once(self, args: (Zn,)) -> Self::Output {
        self.call(args)
    }
}

impl<A> FnMut<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    extern "rust-call" fn call_mut(&mut self, args: (Zn,)) -> Self::Output {
        self.call(args)
    }
}

impl<A> Fn<(Zn,)> for CreateNTTConvolution<A>
    where A: Allocator + Clone
{
    extern "rust-call" fn call(&self, args: (Zn,)) -> Self::Output {
        NTTConvolution::new_with(args.0, self.allocator.clone())
    }
}

impl RNSConvolution {

    ///
    /// Creates a new [`RNSConvolution`] that can compute convolutions of sequences with output
    /// length `<= 2^max_log2_n`. As base convolution, the [`NTTConvolution`] is used.
    /// 
    #[stability::unstable(feature = "enable")]
    pub fn new(max_log2_n: usize) -> Self {
        Self::new_with(max_log2_n, usize::MAX, BigIntRing::RING, Global, CreateNTTConvolution { allocator: Global })
    }
}

impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    ///
    /// Creates a new [`RNSConvolution`] with all the given configuration parameters.
    /// 
    /// In particular
    ///  - `required_root_of_unity_log2` and `max_prime_size_log2` control which prime factors
    ///    are used for the underlying composite modulus; only primes `<= 2^max_prime_size_log2`
    ///    that are `= 1` modulo `2^required_root_of_unity_log2` are sampled
    ///  - `integer_ring` is the ring used to store intermediate lifts; this usually has to be
    ///    [`BigIntRing`], unless the inputs are very small
    ///  - `allocator` is used to allocate elements modulo the internal modulus, as elements of [`zn_rns::Zn`]
    ///  - `create_convolution` is called whenever a convolution algorithm for a new prime has to be
    ///    created; the modulus of the given [`Zn`] always satisfies the constraints defined by `max_prime_size_log2`
    ///    and `required_root_of_unity_log2`
    /// 
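    /// # Example
    /// 
    /// A sketch mirroring the call used in the tests at the end of this file:
    /// ```
    /// # #![feature(allocator_api)]
    /// use std::alloc::Global;
    /// use feanor_math::algorithms::convolution::ntt::NTTConvolution;
    /// use feanor_math::algorithms::convolution::rns::RNSConvolution;
    /// use feanor_math::integer::BigIntRing;
    /// 
    /// let convolution = RNSConvolution::new_with(
    ///     7,                    // primes are sampled congruent to 1 modulo 2^7
    ///     usize::MAX,           // no extra bound on prime size (internally capped at 57 bits)
    ///     BigIntRing::RING,     // ring used for intermediate lifts
    ///     Global,
    ///     |Fp| NTTConvolution::new_with(Fp, Global)
    /// );
    /// ```
    /// 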
    #[stability::unstable(feature = "enable")]
    pub fn new_with(required_root_of_unity_log2: usize, mut max_prime_size_log2: usize, integer_ring: I, allocator: A, create_convolution: CreateC) -> Self {
        max_prime_size_log2 = min(max_prime_size_log2, 57);
        let result = Self {
            integer_ring: integer_ring,
            create_convolution: create_convolution,
            convolutions: LazyVec::new(),
            rns_rings: LazyVec::new(),
            required_root_of_unity_log2: required_root_of_unity_log2,
            allocator: allocator
        };
        let initial_ring = zn_rns::Zn::new_with(vec![Zn::new(Self::sample_next_prime(required_root_of_unity_log2, (1 << max_prime_size_log2) + 1).unwrap() as u64)], result.integer_ring.clone(), result.allocator.clone());
        _ = result.rns_rings.get_or_init(0, || initial_ring);
        return result;
    }

    fn sample_next_prime(required_root_of_unity_log2: usize, current: i64) -> Option<i64> {
        let mut k = StaticRing::<i64>::RING.checked_div(&(current - 1), &(1 << required_root_of_unity_log2)).unwrap();
        while k > 0 {
            k -= 1;
            let candidate = (k << required_root_of_unity_log2) + 1;
            if is_prime(StaticRing::<i64>::RING, &candidate, 10) {
                return Some(candidate);
            }
        }
        return None;
    }

    fn get_rns_ring(&self, moduli_count: usize) -> &zn_rns::Zn<Zn, I, A> {
        self.rns_rings.get_or_init_incremental(moduli_count, |_, prev| zn_rns::Zn::new_with(
            prev.as_iter().cloned().chain([Zn::new(Self::sample_next_prime(self.required_root_of_unity_log2, *prev.at(prev.len() - 1).modulus()).unwrap() as u64)]).collect(),
            self.integer_ring.clone(),
            self.allocator.clone()
        ));
        return self.rns_rings.get(moduli_count - 1).unwrap();
    }

    fn get_rns_factor(&self, i: usize) -> &Zn {
        let rns_ring = self.get_rns_ring(i + 1);
        return rns_ring.at(rns_ring.len() - 1);
    }

    fn get_convolution(&self, i: usize) -> &C {
        self.convolutions.get_or_init(i, || (self.create_convolution)(*self.get_rns_factor(i)))
    }

    fn extend_operand<R, F>(&self, operand: &PreparedConvolutionOperand<R, C>, target_width: usize, mut mod_part: F)
        where R: ?Sized + RingBase,
            C: PreparedConvolutionAlgorithm<ZnBase>,
            F: FnMut(&R::Element, usize) -> El<Zn>
    {
        let mut tmp = Vec::new();
        tmp.resize_with(operand.data.len(), || self.get_rns_factor(0).zero());
        for i in 0..target_width {
            _ = operand.prepared.get_or_init(i, || {
                for j in 0..operand.data.len() {
                    tmp[j] = mod_part(&operand.data[j], i);
                }
                self.get_convolution(i).prepare_convolution_operand(&tmp, self.get_rns_factor(i))
            });
        }
    }

    fn compute_required_width(&self, input_size_log2: usize, lhs_len: usize, rhs_len: usize, inner_prod_len: usize) -> usize {
        let log2_output_size = input_size_log2 * 2 + 
            StaticRing::<i64>::RING.abs_log2_ceil(&(min(lhs_len, rhs_len) as i64)).unwrap_or(0) +
            StaticRing::<i64>::RING.abs_log2_ceil(&(inner_prod_len as i64)).unwrap_or(0) +
            1;
        let mut width = (log2_output_size - 1) / 57 + 1;
        while log2_output_size > self.integer_ring.abs_log2_floor(self.get_rns_ring(width).modulus()).unwrap() {
            width += 1;
        }
        return width;
    }
}

impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn compute_convolution_impl<S, V1, V2, D>(&self, input_size_log2: usize, lhs: V1, rhs: V2, mut dst: D, ring: S)
        where S: RingStore,
            S::Type: RingBase + IntegerRing,
            D: FnMut(usize, El<I>),
            V1: VectorFn<El<S>>,
            V2: VectorFn<El<S>>
    {
        let width = self.compute_required_width(input_size_log2, lhs.len(), rhs.len(), 1);
        let len = lhs.len() + rhs.len();
        let mut res_data = Vec::with_capacity(len * width);
        for i in 0..width {
            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
        }

        let mut lhs_tmp = Vec::with_capacity(lhs.len());
        lhs_tmp.resize_with(lhs.len(), || self.get_rns_factor(0).zero());
        let mut rhs_tmp = Vec::with_capacity(rhs.len());
        rhs_tmp.resize_with(rhs.len(), || self.get_rns_factor(0).zero());
        for i in 0..width {
            let hom = self.get_rns_factor(i).can_hom(&ring).unwrap();
            for j in 0..lhs.len() {
                lhs_tmp[j] = hom.map(lhs.at(j));
            }
            for j in 0..rhs.len() {
                rhs_tmp[j] = hom.map(rhs.at(j));
            }
            self.get_convolution(i).compute_convolution(&lhs_tmp, &rhs_tmp, &mut res_data[(i * len)..((i + 1) * len)], self.get_rns_factor(i));
        }

        for j in 0..(len - 1) {
            dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j]))));
        }
    }
}

impl<I, C, A, CreateC> RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: PreparedConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C
{
    fn prepare_convolution_operand_impl<S, V>(&self, input_size_log2: usize, val: V, _ring: S) -> PreparedConvolutionOperand<S::Type, C>
        where S: RingStore + Copy,
            S::Type: IntegerRing, 
            V: VectorFn<El<S>>
    {
        let mut data = Vec::with_capacity(val.len());
        data.extend(val.iter());
        return PreparedConvolutionOperand {
            data: data,
            prepared: LazyVec::new(),
            log2_data_size: input_size_log2
        };
    }

    fn compute_convolution_inner_product_lhs_prepared_impl<'a, S, V, D>(&self, rhs_input_size_log2: usize, values: &[(&'a PreparedConvolutionOperand<S::Type, C>, V)], mut dst: D, ring: S)
        where S: RingStore + Copy,
            S::Type: IntegerRing, 
            D: FnMut(usize, El<I>),
            V: VectorFn<El<S>>,
            S: 'a,
            Self: 'a,
            PreparedConvolutionOperand<S::Type, C>: 'a
    {
        let max_len = values.iter().map(|(lhs, rhs)| lhs.data.len() + rhs.len()).max().unwrap_or(0);
        let input_size_log2 = max(rhs_input_size_log2, values.iter().map(|(lhs, _)| lhs.log2_data_size).max().unwrap_or(0));
        let width = self.compute_required_width(input_size_log2, (max_len - 1) / 2 + 1, (max_len - 1) / 2 + 1, values.len());
        let mut res_data = Vec::with_capacity(max_len * width);
        for i in 0..width {
            res_data.extend((0..max_len).map(|_| self.get_rns_factor(i).zero()));
        }

        let mut rhs_tmp = Vec::with_capacity(max_len * values.len());
        rhs_tmp.resize_with(max_len * values.len(), || self.get_rns_factor(0).zero());

        let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
        for j in 0..values.len() {
            self.extend_operand(values[j].0, width, |x, i| homs[i].map_ref(x));
        }

        for i in 0..width {
            for j in 0..values.len() {
                for k in 0..values[j].1.len() {
                    rhs_tmp[j * max_len + k] = homs[i].map(values[j].1.at(k));
                }
            }
            self.get_convolution(i).compute_convolution_inner_product_lhs_prepared(
                values.iter().enumerate().map(|(j, (lhs, _))| (lhs.prepared.get(i).unwrap(), &rhs_tmp[(j * max_len)..(j * max_len + values[j].1.len())])), 
                &mut res_data[(i * max_len)..((i + 1) * max_len)], 
                self.get_rns_factor(i)
            );
        }
        
        for j in 0..(max_len - 1) {
            dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * max_len + j]))));
        }
    }

    fn compute_convolution_inner_product_prepared_impl<'a, S, D>(&self, values: &[(&'a PreparedConvolutionOperand<S::Type, C>, &'a PreparedConvolutionOperand<S::Type, C>)], mut dst: D, ring: S)
        where S: RingStore + Copy,
            S::Type: IntegerRing, 
            D: FnMut(usize, El<I>),
            Self: 'a,
            S: 'a,
            PreparedConvolutionOperand<S::Type, C>: 'a
    {
        let max_len = values.iter().map(|(lhs, rhs)| lhs.data.len() + rhs.data.len()).max().unwrap_or(0);
        let input_size_log2 = values.iter().map(|(lhs, rhs)| max(lhs.log2_data_size, rhs.log2_data_size)).max().unwrap_or(0);
        let width = self.compute_required_width(input_size_log2, (max_len - 1) / 2 + 1, (max_len - 1) / 2 + 1, values.len());
        let mut res_data = Vec::with_capacity(max_len * width);
        for i in 0..width {
            res_data.extend((0..max_len).map(|_| self.get_rns_factor(i).zero()));
        }

        let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
        for j in 0..values.len() {
            self.extend_operand(values[j].0, width, |x, i| homs[i].map_ref(x));
            self.extend_operand(values[j].1, width, |x, i| homs[i].map_ref(x));
        }

        for i in 0..width {
            self.get_convolution(i).compute_convolution_inner_product_prepared(
                values.iter().map(|(lhs, rhs)| (lhs.prepared.get(i).unwrap(), rhs.prepared.get(i).unwrap())), 
                &mut res_data[(i * max_len)..((i + 1) * max_len)], 
                self.get_rns_factor(i)
            );
        }
        
        for j in 0..(max_len - 1) {
            dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * max_len + j]))));
        }
    }

    fn compute_convolution_lhs_prepared_impl<S, V, D>(&self, rhs_input_size_log2: usize, lhs: &PreparedConvolutionOperand<S::Type, C>, rhs: V, mut dst: D, ring: S)
        where S: RingStore + Copy,
            S::Type: IntegerRing, 
            D: FnMut(usize, El<I>),
            V: VectorFn<El<S>>
    {
        let width = self.compute_required_width(max(rhs_input_size_log2, lhs.log2_data_size), lhs.data.len(), rhs.len(), 1);
        let len = lhs.data.len() + rhs.len();
        let mut res_data = Vec::with_capacity(len * width);
        for i in 0..width {
            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
        }

        let mut rhs_tmp = Vec::with_capacity(rhs.len());
        rhs_tmp.resize_with(rhs.len(), || self.get_rns_factor(0).zero());

        let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
        self.extend_operand(lhs, width, |x, i| homs[i].map_ref(x));

        for i in 0..width {
            for j in 0..rhs.len() {
                rhs_tmp[j] = homs[i].map(rhs.at(j));
            }
            self.get_convolution(i).compute_convolution_lhs_prepared(lhs.prepared.get(i).unwrap(), &rhs_tmp, &mut res_data[(i * len)..((i + 1) * len)], self.get_rns_factor(i));
        }
        
        for j in 0..(len - 1) {
            dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j]))));
        }
    }

    fn compute_convolution_prepared_impl<S, D>(&self, lhs: &PreparedConvolutionOperand<S::Type, C>, rhs: &PreparedConvolutionOperand<S::Type, C>, mut dst: D, ring: S)
        where S: RingStore + Copy,
            S::Type: IntegerRing, 
            D: FnMut(usize, El<I>),
    {
        let width = self.compute_required_width(max(lhs.log2_data_size, rhs.log2_data_size), lhs.data.len(), rhs.data.len(), 1);
        let len = lhs.data.len() + rhs.data.len();
        let mut res_data = Vec::with_capacity(len * width);
        for i in 0..width {
            res_data.extend((0..len).map(|_| self.get_rns_factor(i).zero()));
        }

        let homs = (0..width).map(|i| self.get_rns_factor(i).can_hom(&ring).unwrap()).collect::<Vec<_>>();
        self.extend_operand(lhs, width, |x, i| homs[i].map_ref(x));
        self.extend_operand(rhs, width, |x, i| homs[i].map_ref(x));

        for i in 0..width {
            self.get_convolution(i).compute_convolution_prepared(lhs.prepared.get(i).unwrap(), rhs.prepared.get(i).unwrap(), &mut res_data[(i * len)..((i + 1) * len)], self.get_rns_factor(i));
        }
        
        for j in 0..(len - 1) {
            dst(j, self.get_rns_ring(width).smallest_lift(self.get_rns_ring(width).from_congruence((0..width).map(|i| res_data[i * len + j]))));
        }
    }
}

impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + IntegerRing
{
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        assert!(dst.len() >= lhs.len() + rhs.len() - 1);
        let log2_input_size = lhs.as_iter().chain(rhs.as_iter()).map(|x| ring.abs_log2_ceil(x).unwrap_or(0)).max().unwrap_or(0);
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_impl(
            log2_input_size,
            lhs.clone_ring_els(ring),
            rhs.clone_ring_els(ring),
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        struct CallFunction<F, R, I, C, A, CreateC>
            where I: RingStore + Clone,
                I::Type: IntegerRing,
                C: ConvolutionAlgorithm<ZnBase>,
                A: Allocator + Clone,
                CreateC: Fn(Zn) -> C,
                R: ?Sized + IntegerRing,
                F: PreparedConvolutionOperation<RNSConvolution<I, C, A, CreateC>, R>
        {
            ring: PhantomData<Box<R>>,
            convolution: PhantomData<RNSConvolution<I, C, A, CreateC>>,
            function: F
        }
        impl<F, R, I, C, A, CreateC> PreparedConvolutionOperation<C, ZnBase> for CallFunction<F, R, I, C, A, CreateC>
            where I: RingStore + Clone,
                I::Type: IntegerRing,
                C: ConvolutionAlgorithm<ZnBase>,
                A: Allocator + Clone,
                CreateC: Fn(Zn) -> C,
                R: ?Sized + IntegerRing,
                F: PreparedConvolutionOperation<RNSConvolution<I, C, A, CreateC>, R>
        {
            type Output = F::Output;

            fn execute(self) -> Self::Output
                where C: PreparedConvolutionAlgorithm<ZnBase>
            {
                self.function.execute()
            }
        }
        return <C as ConvolutionAlgorithm<ZnBase>>::specialize_prepared_convolution::<CallFunction<F, R, I, C, A, CreateC>>(CallFunction {
            function: function,
            ring: PhantomData,
            convolution: PhantomData
        }).map_err(|f| f.function);
    }
}

impl<R, I, C, A, CreateC> PreparedConvolutionAlgorithm<R> for RNSConvolution<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: PreparedConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + IntegerRing
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, C>;
    
    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        let log2_input_size = val.as_iter().map(|x| ring.abs_log2_ceil(x).unwrap_or(0)).max().unwrap_or(0);
        return self.prepare_convolution_operand_impl(log2_input_size, val.clone_ring_els(ring), ring);
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(dst.len() >= lhs.data.len() + rhs.len() - 1);
        let rhs_log2_input_size = rhs.as_iter().map(|x| ring.abs_log2_ceil(x).unwrap_or(0)).max().unwrap_or(0);
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_lhs_prepared_impl(
            rhs_log2_input_size, 
            lhs, 
            rhs.clone_ring_els(ring), 
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)), 
            ring
        );
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        assert!(dst.len() >= lhs.data.len() + rhs.data.len() - 1);
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_prepared_impl(
            lhs, 
            rhs, 
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)), 
            ring
        );
    }

    fn compute_convolution_inner_product_lhs_prepared<'a, S, J, V>(&self, values: J, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'a,
            R: 'a,
            Self::PreparedConvolutionOperand: 'a
    {
        let values = values.map(|(lhs, rhs)| (lhs, rhs.into_clone_ring_els(ring))).collect::<Vec<_>>();
        let rhs_log2_input_size = values.iter().flat_map(|(_, rhs)| rhs.iter()).map(|x| ring.abs_log2_ceil(&x).unwrap_or(0)).max().unwrap_or(0);
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_inner_product_lhs_prepared_impl(
            rhs_log2_input_size,
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }

    fn compute_convolution_inner_product_prepared<'a, S, J>(&self, values: J, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a,
            Self: 'a,
            R: 'a,
    {
        let values = values.collect::<Vec<_>>();
        let hom = ring.can_hom(&self.integer_ring).unwrap();
        return self.compute_convolution_inner_product_prepared_impl(
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring
        );
    }
}

impl<R, I, C, A, CreateC> ConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: ConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + ZnRing + CanHomFrom<I::Type>
{
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S) {
        assert!(dst.len() >= lhs.len() + rhs.len() - 1);
        let log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_impl(
            log2_input_size,
            lhs.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)),
            rhs.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)),
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        struct CallFunction<F, R, I, C, A, CreateC>
            where I: RingStore + Clone,
                I::Type: IntegerRing,
                C: ConvolutionAlgorithm<ZnBase>,
                A: Allocator + Clone,
                CreateC: Fn(Zn) -> C,
                R: ?Sized + ZnRing + CanHomFrom<I::Type>,
                F: PreparedConvolutionOperation<RNSConvolutionZn<I, C, A, CreateC>, R>
        {
            ring: PhantomData<Box<R>>,
            convolution: PhantomData<RNSConvolution<I, C, A, CreateC>>,
            function: F
        }
        impl<F, R, I, C, A, CreateC> PreparedConvolutionOperation<C, ZnBase> for CallFunction<F, R, I, C, A, CreateC>
            where I: RingStore + Clone,
                I::Type: IntegerRing,
                C: ConvolutionAlgorithm<ZnBase>,
                A: Allocator + Clone,
                CreateC: Fn(Zn) -> C,
                R: ?Sized + ZnRing + CanHomFrom<I::Type>,
                F: PreparedConvolutionOperation<RNSConvolutionZn<I, C, A, CreateC>, R>
        {
            type Output = F::Output;

            fn execute(self) -> Self::Output
                where C: PreparedConvolutionAlgorithm<ZnBase>
            {
                self.function.execute()
            }
        }
        return <C as ConvolutionAlgorithm<ZnBase>>::specialize_prepared_convolution::<CallFunction<F, R, I, C, A, CreateC>>(CallFunction {
            function: function,
            ring: PhantomData,
            convolution: PhantomData
        }).map_err(|f| f.function);
    }
}

impl<R, I, C, A, CreateC> PreparedConvolutionAlgorithm<R> for RNSConvolutionZn<I, C, A, CreateC>
    where I: RingStore + Clone,
        I::Type: IntegerRing,
        C: PreparedConvolutionAlgorithm<ZnBase>,
        A: Allocator + Clone,
        CreateC: Fn(Zn) -> C,
        R: ?Sized + ZnRing + CanHomFrom<I::Type>
{
    type PreparedConvolutionOperand = PreparedConvolutionOperandZn<R, C>;
    
    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        let log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        return PreparedConvolutionOperandZn(self.base.prepare_convolution_operand_impl(log2_input_size, val.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)), ring.integer_ring()));
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(dst.len() >= lhs.0.data.len() + rhs.len() - 1);
        let rhs_log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_lhs_prepared_impl(
            rhs_log2_input_size, 
            &lhs.0, 
            rhs.clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)), 
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        assert!(dst.len() >= lhs.0.data.len() + rhs.0.data.len() - 1);
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_prepared_impl(
            &lhs.0, 
            &rhs.0, 
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    fn compute_convolution_inner_product_lhs_prepared<'a, S, J, V>(&self, values: J, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'a,
            R: 'a,
            Self::PreparedConvolutionOperand: 'a
    {
        let values = values.map(|(lhs, rhs)| (&lhs.0, rhs.into_clone_ring_els(ring).map_fn(|x| ring.smallest_lift(x)))).collect::<Vec<_>>();
        let rhs_log2_input_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap() - 1;
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_inner_product_lhs_prepared_impl(
            rhs_log2_input_size,
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }

    fn compute_convolution_inner_product_prepared<'a, S, J>(&self, values: J, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            J: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a,
            Self: 'a,
            R: 'a,
    {
        let values = values.map(|(lhs, rhs)| (&lhs.0, &rhs.0)).collect::<Vec<_>>();
        let hom = ring.can_hom(&self.base.integer_ring).unwrap();
        return self.base.compute_convolution_inner_product_prepared_impl(
            &values,
            |i, x| ring.add_assign(&mut dst[i], hom.map(x)),
            ring.integer_ring()
        );
    }
}

#[test]
fn test_convolution_integer() {
    let ring = StaticRing::<i128>::RING;
    let convolution = RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp, Global));

    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
}

#[test]
fn test_convolution_zn() {
    let ring = Zn::new((1 << 57) + 1);
    let convolution = RNSConvolutionZn::from(RNSConvolution::new_with(7, usize::MAX, BigIntRing::RING, Global, |Fp| NTTConvolution::new_with(Fp, Global)));

    super::generic_tests::test_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.int_hom().map(1 << 30));
}

#[test]
fn test_specialize_prepared() {
    let ring = Zn::new((1 << 57) + 1);
    let convolution = RNSConvolutionZn::from(RNSConvolution::new(7));

    struct CheckIsPrepared(RNSConvolutionZn, Zn);
    impl PreparedConvolutionOperation<RNSConvolutionZn, ZnBase> for CheckIsPrepared {
        type Output = ();
        fn execute(self) -> Self::Output {
            super::generic_tests::test_prepared_convolution(&self.0, &self.1, self.1.int_hom().map(1 << 30));
        }
    }
    assert!(RNSConvolutionZn::specialize_prepared_convolution(CheckIsPrepared(convolution, ring)).is_ok());
}