he_ring/ntt/
dyn_convolution.rs

1
2use std::marker::PhantomData;
3use std::ops::Deref;
4
5use feanor_math::algorithms::convolution::ConvolutionAlgorithm;
6use feanor_math::ring::*;
7use feanor_math::seq::VectorView;
8
9///
10/// Trait for algorithms that compute convolutions. This mirrors
11/// [`ConvolutionAlgorithm`], but is dyn-compatible. This is useful 
12/// if you want to create a ring but only know the type of the 
13/// convolution algorithm at runtime.
14/// 
15/// Wrap a `dyn DynConvolutionAlgorithm<R>` in [`DynConvolutionAlgorithmConvolution`]
16/// to use it as a [`ConvolutionAlgorithm`].
17/// 
18pub trait DynConvolutionAlgorithm<R>
19    where R: ?Sized + RingBase
20{
21    ///
22    /// Computes `dst[i] += sum_j lhs[j] * rhs[i - j]`, where the sum runs over
23    /// these indices that do not cause an out-of-bounds. 
24    /// 
25    /// For implementation purposes, we requrie `dst.len() >= lhs.len() + rhs.len()` 
26    /// (not only `dst.len() >= lhs.len() + rhs.len() - 1`, which would be enough to
27    /// include `lhs[lhs.len() - 1] * rhs[rhs.len() - 1]`).
28    /// 
29    fn compute_convolution_dyn(&self, lhs: &[R::Element], rhs: &[R::Element], dst: &mut [R::Element], ring: &R);
30    fn supports_ring_dyn(&self, ring: &R) -> bool;
31}
32
33impl<R, C> DynConvolutionAlgorithm<R> for C
34    where R: ?Sized + RingBase,
35        C: ConvolutionAlgorithm<R>
36{
37    fn compute_convolution_dyn(&self, lhs: &[R::Element], rhs: &[R::Element], dst: &mut [R::Element], ring: &R) {
38        self.compute_convolution(lhs, rhs, dst, RingRef::new(ring));
39    }
40
41    fn supports_ring_dyn(&self, ring: &R) -> bool {
42        self.supports_ring(RingRef::new(ring))
43    }
44}
45
46///
47/// Wraps a [`DynConvolutionAlgorithm`] trait object to use it as a 
48/// [`ConvolutionAlgorithm`].
49/// 
50pub struct DynConvolutionAlgorithmConvolution<R, C = Box<dyn DynConvolutionAlgorithm<R>>>
51    where C: Deref,
52        C::Target: DynConvolutionAlgorithm<R>,
53        R: ?Sized + RingBase
54{
55    ring: PhantomData<R>,
56    conv: C
57}
58
59impl<C, R> Clone for DynConvolutionAlgorithmConvolution<R, C>
60    where C: Deref + Clone,
61        C::Target: DynConvolutionAlgorithm<R>,
62        R: ?Sized + RingBase
63{
64    fn clone(&self) -> Self {
65        Self {
66            ring: self.ring,
67            conv: self.conv.clone()
68        }
69    }
70}
71
72impl<C, R> DynConvolutionAlgorithmConvolution<R, C>
73    where C: Deref,
74        C::Target: DynConvolutionAlgorithm<R>,
75        R: ?Sized + RingBase
76{
77    pub fn new(conv: C) -> Self {
78        Self {
79            ring: PhantomData,
80            conv: conv
81        }
82    }
83}
84
85impl<C, R> ConvolutionAlgorithm<R> for DynConvolutionAlgorithmConvolution<R, C>
86    where C: Deref,
87        C::Target: DynConvolutionAlgorithm<R>,
88        R: ?Sized + RingBase
89{
90    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<El<S>>, V2: VectorView<El<S>>>(&self, lhs: V1, rhs: V2, dst: &mut [El<S>], ring: S) {
91        let copy_lhs = lhs.as_iter().map(|x| ring.clone_el(x)).collect::<Vec<_>>();
92        let copy_rhs = rhs.as_iter().map(|x| ring.clone_el(x)).collect::<Vec<_>>();
93        self.conv.compute_convolution_dyn(&copy_lhs, &copy_rhs, dst, ring.get_ring());
94    }
95
96    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
97        self.conv.supports_ring_dyn(ring.get_ring())
98    }
99}
100
101#[cfg(test)]
102use feanor_math::primitive_int::StaticRing;
103#[cfg(test)]
104use feanor_math::rings::zn::zn_64::{Zn, ZnBase};
105#[cfg(test)]
106use std::alloc::Global;
107#[cfg(test)]
108use feanor_math::algorithms::convolution::STANDARD_CONVOLUTION;
109#[cfg(test)]
110use feanor_math::rings::extension::extension_impl::FreeAlgebraImpl;
111#[cfg(test)]
112use feanor_math::rings::extension::FreeAlgebraStore;
113#[cfg(test)]
114use feanor_math::assert_el_eq;
115
#[test]
fn test_dyn_convolution_is_dyn_compatible() {
    // Compile-time check only: this signature fails to type-check unless
    // `DynConvolutionAlgorithm` is dyn-compatible. The trait parameter must be a
    // `RingBase`, so name the base type behind the `StaticRing<i64>` store
    // instead of the `RingStore` wrapper itself.
    #[allow(unused)] // never called; only the signature matters
    fn test(_: &dyn DynConvolutionAlgorithm<<StaticRing<i64> as RingStore>::Type>) {}
}
121
#[test]
fn test_dyn_convolution_convolution_use_build_ring() {
    // Builds F_2[a] with a^3 = 1 + a (the coefficients `[1, 1]`), i.e. the field
    // with 8 elements, whose multiplication goes through the type-erased
    // convolution. Every nonzero element there satisfies x^7 = 1.
    fn check_with(conv: Box<dyn DynConvolutionAlgorithm<ZnBase>>) {
        let wrapped_conv = DynConvolutionAlgorithmConvolution::<ZnBase>::new(conv);
        let f2 = Zn::new(2);
        let x_pow_rank = [f2.one(), f2.one()];
        let ring = FreeAlgebraImpl::new_with(f2, 3, x_pow_rank, "a", Global, wrapped_conv);
        let a = ring.canonical_gen();
        assert_el_eq!(&ring, ring.one(), ring.pow(a, 7));
    }
    check_with(Box::new(STANDARD_CONVOLUTION));
}