he_ring/ntt/
dyn_convolution.rs1
2use std::marker::PhantomData;
3use std::ops::Deref;
4
5use feanor_math::algorithms::convolution::ConvolutionAlgorithm;
6use feanor_math::ring::*;
7use feanor_math::seq::VectorView;
8
9pub trait DynConvolutionAlgorithm<R>
19 where R: ?Sized + RingBase
20{
21 fn compute_convolution_dyn(&self, lhs: &[R::Element], rhs: &[R::Element], dst: &mut [R::Element], ring: &R);
30 fn supports_ring_dyn(&self, ring: &R) -> bool;
31}
32
33impl<R, C> DynConvolutionAlgorithm<R> for C
34 where R: ?Sized + RingBase,
35 C: ConvolutionAlgorithm<R>
36{
37 fn compute_convolution_dyn(&self, lhs: &[R::Element], rhs: &[R::Element], dst: &mut [R::Element], ring: &R) {
38 self.compute_convolution(lhs, rhs, dst, RingRef::new(ring));
39 }
40
41 fn supports_ring_dyn(&self, ring: &R) -> bool {
42 self.supports_ring(RingRef::new(ring))
43 }
44}
45
46pub struct DynConvolutionAlgorithmConvolution<R, C = Box<dyn DynConvolutionAlgorithm<R>>>
51 where C: Deref,
52 C::Target: DynConvolutionAlgorithm<R>,
53 R: ?Sized + RingBase
54{
55 ring: PhantomData<R>,
56 conv: C
57}
58
59impl<C, R> Clone for DynConvolutionAlgorithmConvolution<R, C>
60 where C: Deref + Clone,
61 C::Target: DynConvolutionAlgorithm<R>,
62 R: ?Sized + RingBase
63{
64 fn clone(&self) -> Self {
65 Self {
66 ring: self.ring,
67 conv: self.conv.clone()
68 }
69 }
70}
71
72impl<C, R> DynConvolutionAlgorithmConvolution<R, C>
73 where C: Deref,
74 C::Target: DynConvolutionAlgorithm<R>,
75 R: ?Sized + RingBase
76{
77 pub fn new(conv: C) -> Self {
78 Self {
79 ring: PhantomData,
80 conv: conv
81 }
82 }
83}
84
85impl<C, R> ConvolutionAlgorithm<R> for DynConvolutionAlgorithmConvolution<R, C>
86 where C: Deref,
87 C::Target: DynConvolutionAlgorithm<R>,
88 R: ?Sized + RingBase
89{
90 fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<El<S>>, V2: VectorView<El<S>>>(&self, lhs: V1, rhs: V2, dst: &mut [El<S>], ring: S) {
91 let copy_lhs = lhs.as_iter().map(|x| ring.clone_el(x)).collect::<Vec<_>>();
92 let copy_rhs = rhs.as_iter().map(|x| ring.clone_el(x)).collect::<Vec<_>>();
93 self.conv.compute_convolution_dyn(©_lhs, ©_rhs, dst, ring.get_ring());
94 }
95
96 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
97 self.conv.supports_ring_dyn(ring.get_ring())
98 }
99}
100
101#[cfg(test)]
102use feanor_math::primitive_int::StaticRing;
103#[cfg(test)]
104use feanor_math::rings::zn::zn_64::{Zn, ZnBase};
105#[cfg(test)]
106use std::alloc::Global;
107#[cfg(test)]
108use feanor_math::algorithms::convolution::STANDARD_CONVOLUTION;
109#[cfg(test)]
110use feanor_math::rings::extension::extension_impl::FreeAlgebraImpl;
111#[cfg(test)]
112use feanor_math::rings::extension::FreeAlgebraStore;
113#[cfg(test)]
114use feanor_math::assert_el_eq;
115
#[test]
fn test_dyn_convolution_is_dyn_compatible() {
    // Compile-time check only: this helper's signature is accepted iff
    // `DynConvolutionAlgorithm` is dyn-compatible (object safe). The trait's ring parameter
    // is bounded by `RingBase`, so name the base type behind the `StaticRing<i64>` store
    // instead of the store itself.
    #[allow(unused)] // never called; it exists only so its signature is type-checked
    fn test(_: &dyn DynConvolutionAlgorithm<<StaticRing<i64> as RingStore>::Type>) {}
}
121
122#[test]
123fn test_dyn_convolution_convolution_use_build_ring() {
124 fn do_test(conv: Box<dyn DynConvolutionAlgorithm<ZnBase>>) {
125 let base_ring = Zn::new(2);
126 let ring = FreeAlgebraImpl::new_with(base_ring, 3, [base_ring.one(), base_ring.one()], "a", Global, DynConvolutionAlgorithmConvolution::<ZnBase>::new(conv));
127 assert_el_eq!(&ring, ring.one(), ring.pow(ring.canonical_gen(), 7));
128 }
129 do_test(Box::new(STANDARD_CONVOLUTION));
130}