feanor_math/
wrapper.rs

1use std::{ops::*, hash::Hash, fmt::{Display, Debug}};
2
3use crate::homomorphism::*;
4use crate::ring::*;
5use crate::field::*;
6
7///
8/// Stores a ring element together with its ring, so that ring operations do
9/// not require explicit mention of the ring object. This can be used both for
10/// convenience of notation (i.e. use `a + b` instead of `ring.add(a, b)`) and
11/// might also be necessary when e.g. storing elements in a set.
12/// 
13/// # Examples
14/// ```rust
15/// # use feanor_math::ring::*;
16/// # use feanor_math::rings::poly::*;
17/// # use feanor_math::rings::poly::dense_poly::*;
18/// # use feanor_math::wrapper::*;
19/// # use feanor_math::primitive_int::*;
20/// let ring = DensePolyRing::new(StaticRing::<i64>::RING, "X");
21/// let x = RingElementWrapper::new(&ring, ring.indeterminate());
22/// println!("The result is: {}", x.clone() + x.clone() * x);
23/// // instead of
24/// let x = ring.indeterminate();
25/// println!("The result is: {}", ring.format(&ring.add(ring.mul(ring.clone_el(&x), ring.clone_el(&x)), ring.clone_el(&x))));
26/// ```
27/// You can also retrieve the wrapped element
28/// ```rust
29/// # use feanor_math::assert_el_eq;
30/// # use feanor_math::ring::*;
31/// # use feanor_math::rings::poly::*;
32/// # use feanor_math::rings::poly::dense_poly::*;
33/// # use feanor_math::wrapper::*;
34/// # use feanor_math::primitive_int::*;
35/// let ring = DensePolyRing::new(StaticRing::<i64>::RING, "X");
36/// let x = RingElementWrapper::new(&ring, ring.indeterminate());
37/// assert_el_eq!(&ring, ring.add(ring.mul(ring.clone_el(&x), ring.clone_el(&x)), ring.clone_el(&x)), (x.clone() + x.clone() * x).unwrap());
38/// ```
39/// 
40pub struct RingElementWrapper<R>
41    where R: RingStore
42{
43    ring: R,
44    element: El<R>
45}
46
47impl<R: RingStore> RingElementWrapper<R> {
48
49    ///
50    /// Creates a new [`RingElementWrapper`] wrapping the given element of the given ring.
51    /// 
52    pub const fn new(ring: R, element: El<R>) -> Self {
53        Self { ring, element }
54    }
55
56    ///
57    /// Raises the stored element to the given power.
58    /// 
59    /// Consider using [`RingElementWrapper::pow_ref()`] if you don't want to 
60    /// move the element.
61    /// 
62    pub fn pow(self, power: usize) -> Self {
63        Self {
64            element: self.ring.pow(self.element, power),
65            ring: self.ring
66        }
67    }
68
69    ///
70    /// Raises the stored element to the given power.
71    /// 
72    pub fn pow_ref(&self, power: usize) -> Self
73        where R: Clone
74    {
75        Self {
76            element: self.ring.pow(self.ring.clone_el(&self.element), power),
77            ring: self.ring.clone()
78        }
79    }
80
81    ///
82    /// Returns the stored element.
83    /// 
84    pub fn unwrap(self) -> El<R> {
85        self.element
86    }
87
88    ///
89    /// Returns the stored element as reference.
90    /// 
91    pub fn unwrap_ref(&self) -> &El<R> {
92        &self.element
93    }
94
95    ///
96    /// Returns a reference to the ring that this element belongs to.
97    /// 
98    pub fn parent(&self) -> &R {
99        &self.ring
100    }
101
102    ///
103    /// Returns `true` if this element is zero.
104    /// 
105    /// Equivalent to `self.parent().is_zero(self.unwrap_ref())`.
106    /// 
107    pub fn is_zero(&self) -> bool {
108        self.parent().is_zero(self.unwrap_ref())
109    }
110
111    ///
112    /// Returns `true` if this element is one.
113    /// 
114    /// Equivalent to `self.parent().is_one(self.unwrap_ref())`.
115    /// 
116    pub fn is_one(&self) -> bool {
117        self.parent().is_one(self.unwrap_ref())
118    }
119    
120    ///
121    /// Returns `true` if this element is negative one.
122    /// 
123    /// Equivalent to `self.parent().is_neg_one(self.unwrap_ref())`.
124    /// 
125    pub fn is_neg_one(&self) -> bool {
126        self.parent().is_neg_one(self.unwrap_ref())
127    }
128}
129
// Generates `AssignTrait<Wrapper>` and `AssignTrait<&Wrapper>` for `RingElementWrapper`.
// `$assign_fn` is the ring method used for an owned right-hand side and
// `$assign_ref_fn` the one used for a borrowed right-hand side.
// In debug builds, both impls check that the two operands share the same ring.
macro_rules! impl_xassign_trait {
    ($assign_trait:ident, $assign_fn:ident, $assign_ref_fn:ident) => {

        // `lhs op= rhs` with an owned `rhs`
        impl<R> $assign_trait<RingElementWrapper<R>> for RingElementWrapper<R>
            where R: RingStore
        {
            fn $assign_fn(&mut self, rhs: RingElementWrapper<R>) {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let rhs_element = rhs.element;
                self.ring.$assign_fn(&mut self.element, rhs_element);
            }
        }

        // `lhs op= &rhs` with a borrowed `rhs`
        impl<'b, R> $assign_trait<&'b RingElementWrapper<R>> for RingElementWrapper<R>
            where R: RingStore
        {
            fn $assign_fn(&mut self, rhs: &'b RingElementWrapper<R>) {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                self.ring.$assign_ref_fn(&mut self.element, &rhs.element);
            }
        }
    };
}
150
// Generates the four owned/borrowed combinations of a binary operator trait
// between two `RingElementWrapper`s. The four ring methods passed in handle,
// respectively: both operands owned; only the left one borrowed; only the right
// one borrowed; both borrowed. Only the all-borrowed variant needs to clone the
// ring object, so only it requires `R: Clone`.
macro_rules! impl_trait {
    ($op_trait:ident, $op:ident, $op_ref_fst:ident, $op_ref_snd:ident, $op_ref:ident) => {

        // owned `op` owned
        impl<R> $op_trait<RingElementWrapper<R>> for RingElementWrapper<R>
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: RingElementWrapper<R>) -> RingElementWrapper<R> {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let result = rhs.ring.$op(self.element, rhs.element);
                RingElementWrapper { ring: self.ring, element: result }
            }
        }

        // borrowed `op` owned
        impl<'l, R> $op_trait<RingElementWrapper<R>> for &'l RingElementWrapper<R>
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: RingElementWrapper<R>) -> RingElementWrapper<R> {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let result = self.ring.$op_ref_fst(&self.element, rhs.element);
                RingElementWrapper { ring: rhs.ring, element: result }
            }
        }

        // owned `op` borrowed
        impl<'r, R> $op_trait<&'r RingElementWrapper<R>> for RingElementWrapper<R>
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: &'r RingElementWrapper<R>) -> RingElementWrapper<R> {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let result = rhs.ring.$op_ref_snd(self.element, &rhs.element);
                RingElementWrapper { ring: self.ring, element: result }
            }
        }

        // borrowed `op` borrowed
        impl<'l, 'r, R> $op_trait<&'r RingElementWrapper<R>> for &'l RingElementWrapper<R>
            where R: RingStore + Clone
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: &'r RingElementWrapper<R>) -> RingElementWrapper<R> {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let result = self.ring.$op_ref(&self.element, &rhs.element);
                RingElementWrapper { ring: self.ring.clone(), element: result }
            }
        }
    };
}
191
192impl_xassign_trait!{ AddAssign, add_assign, add_assign_ref }
193impl_xassign_trait!{ MulAssign, mul_assign, mul_assign_ref }
194impl_xassign_trait!{ SubAssign, sub_assign, sub_assign_ref }
195impl_trait!{ Add, add, add_ref_fst, add_ref_snd, add_ref }
196impl_trait!{ Mul, mul, mul_ref_fst, mul_ref_snd, mul_ref }
197impl_trait!{ Sub, sub, sub_ref_fst, sub_ref_snd, sub_ref }
198
199impl<R: RingStore> Div<RingElementWrapper<R>> for RingElementWrapper<R>
200    where R::Type: Field
201{
202    type Output = Self;
203
204    fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
205        RingElementWrapper { element: self.ring.div(&self.element, &rhs.element), ring: self.ring }
206    }
207}
208
209impl<'a, 'b, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for &'b RingElementWrapper<R>
210    where R::Type: Field
211{
212    type Output = RingElementWrapper<R>;
213
214    fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
215        RingElementWrapper { element: self.ring.div(&self.element, &rhs.element), ring: self.ring.clone() }
216    }
217}
218
219impl<'a, R: RingStore + Clone> Div<RingElementWrapper<R>> for &'a RingElementWrapper<R>
220    where R::Type: Field
221{
222    type Output = RingElementWrapper<R>;
223
224    fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
225        RingElementWrapper { element: self.ring.div(&self.element, &rhs.element), ring: rhs.ring }
226    }
227}
228
229impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for RingElementWrapper<R>
230    where R::Type: Field
231{
232    type Output = RingElementWrapper<R>;
233
234    fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
235        RingElementWrapper { element: rhs.ring.div(&self.element, &rhs.element), ring: self.ring }
236    }
237}
238
// Generates `AssignTrait<i32>` for `RingElementWrapper`. The integer is first
// mapped into the ring through the ring's `int_hom()`, then combined with the
// stored element by the ring method `$assign_fn`.
macro_rules! impl_xassign_trait_int {
    ($assign_trait:ident, $assign_fn:ident) => {

        impl<R> $assign_trait<i32> for RingElementWrapper<R>
            where R: RingStore
        {
            fn $assign_fn(&mut self, rhs: i32) {
                let rhs_element = self.ring.int_hom().map(rhs);
                self.ring.$assign_fn(&mut self.element, rhs_element);
            }
        }
    };
}
250
// Generates a binary operator trait between `RingElementWrapper` (owned or
// borrowed) and `i32`, in both operand orders. The integer is mapped into the
// ring through `int_hom()` and combined with the element via the ring method `$op`.
// The borrowed variants clone both the element and the ring object, hence `R: Clone`.
macro_rules! impl_trait_int {
    ($op_trait:ident, $op:ident) => {

        // owned wrapper `op` i32
        impl<R> $op_trait<i32> for RingElementWrapper<R>
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: i32) -> RingElementWrapper<R> {
                let rhs_element = self.ring.int_hom().map(rhs);
                let result = self.ring.$op(self.element, rhs_element);
                RingElementWrapper { ring: self.ring, element: result }
            }
        }

        // i32 `op` owned wrapper
        impl<R> $op_trait<RingElementWrapper<R>> for i32
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: RingElementWrapper<R>) -> RingElementWrapper<R> {
                let lhs_element = rhs.ring.int_hom().map(self);
                let result = rhs.ring.$op(lhs_element, rhs.element);
                RingElementWrapper { ring: rhs.ring, element: result }
            }
        }

        // borrowed wrapper `op` i32
        impl<'l, R> $op_trait<i32> for &'l RingElementWrapper<R>
            where R: RingStore + Clone
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: i32) -> RingElementWrapper<R> {
                let lhs_element = self.ring.clone_el(&self.element);
                let rhs_element = self.ring.int_hom().map(rhs);
                let result = self.ring.$op(lhs_element, rhs_element);
                RingElementWrapper { element: result, ring: self.ring.clone() }
            }
        }

        // i32 `op` borrowed wrapper
        impl<'r, R> $op_trait<&'r RingElementWrapper<R>> for i32
            where R: RingStore + Clone
        {
            type Output = RingElementWrapper<R>;

            fn $op(self, rhs: &'r RingElementWrapper<R>) -> RingElementWrapper<R> {
                let lhs_element = rhs.ring.int_hom().map(self);
                let rhs_element = rhs.ring.clone_el(&rhs.element);
                let result = rhs.ring.$op(lhs_element, rhs_element);
                RingElementWrapper { element: result, ring: rhs.ring.clone() }
            }
        }
    };
}
287
288impl_xassign_trait_int!{ AddAssign, add_assign }
289impl_xassign_trait_int!{ MulAssign, mul_assign }
290impl_xassign_trait_int!{ SubAssign, sub_assign }
291impl_trait_int!{ Add, add }
292impl_trait_int!{ Mul, mul }
293impl_trait_int!{ Sub, sub }
294
295impl<R: RingStore> Div<i32> for RingElementWrapper<R>
296    where R::Type: Field
297{
298    type Output = Self;
299
300    fn div(self, rhs: i32) -> Self::Output {
301        RingElementWrapper { element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)), ring: self.ring }
302    }
303}
304
305impl<R: RingStore> Div<RingElementWrapper<R>> for i32
306    where R::Type: Field
307{
308    type Output = RingElementWrapper<R>;
309
310    fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
311        RingElementWrapper { element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element), ring: rhs.ring }
312    }
313}
314
315impl<'a, R: RingStore + Clone> Div<i32> for &'a RingElementWrapper<R>
316    where R::Type: Field
317{
318    type Output = RingElementWrapper<R>;
319
320    fn div(self, rhs: i32) -> Self::Output {
321        RingElementWrapper { element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)), ring: self.ring.clone() }
322    }
323}
324
325impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for i32
326    where R::Type: Field
327{
328    type Output = RingElementWrapper<R>;
329
330    fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
331        RingElementWrapper { element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element), ring: rhs.ring.clone() }
332    }
333}
334
335impl<R: RingStore + Copy> Copy for RingElementWrapper<R> 
336    where El<R>: Copy
337{}
338
339impl<R: RingStore + Clone> Clone for RingElementWrapper<R> {
340
341    fn clone(&self) -> Self {
342        Self { ring: self.ring.clone(), element: self.ring.clone_el(&self.element) }
343    }
344}
345
346impl<R: RingStore> PartialEq for RingElementWrapper<R> {
347
348    fn eq(&self, other: &Self) -> bool {
349        debug_assert!(self.ring.get_ring() == other.ring.get_ring());
350        self.ring.eq_el(&self.element, &other.element)
351    }
352}
353
354impl<R: RingStore> Eq for RingElementWrapper<R> {}
355
356impl<R: RingStore> PartialEq<i32> for RingElementWrapper<R> {
357
358    fn eq(&self, other: &i32) -> bool {
359        match *other {
360            0 => self.is_zero(),
361            1 => self.is_one(),
362            -1 => self.is_neg_one(),
363            x => self.parent().eq_el(self.unwrap_ref(), &self.parent().int_hom().map(x))
364        }
365    }
366}
367
368impl<R: RingStore> PartialEq<RingElementWrapper<R>> for i32 {
369
370    fn eq(&self, other: &RingElementWrapper<R>) -> bool {
371        other == self
372    }
373}
374
375impl<R: RingStore> Hash for RingElementWrapper<R> 
376    where R::Type: HashableElRing
377{
378    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
379        self.ring.hash(&self.element, state)
380    }
381}
382
383impl<R: RingStore> Display for RingElementWrapper<R> {
384
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        self.ring.get_ring().dbg(&self.element, f)
387    }
388}
389
390impl<R: RingStore> Debug for RingElementWrapper<R> {
391
392    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393        self.ring.get_ring().dbg(&self.element, f)
394    }
395}
396
397impl<R: RingStore> Deref for RingElementWrapper<R> {
398    type Target = El<R>;
399
400    fn deref(&self) -> &Self::Target {
401        &self.element
402    }
403}
404
405#[cfg(test)]
406use crate::rings::finite::FiniteRingStore;
407#[cfg(test)]
408use crate::rings::zn::zn_64;
409
#[test]
fn test_arithmetic_expression() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // reference value of a * b + (a + c) * (b - c), computed directly in the ring
                let product = ring.mul(a, b);
                let mixed = ring.mul(ring.add(a, c), ring.sub(b, c));
                let expected = ring.add(product, mixed);

                // the same expression through the owned-operand operator overloads
                let wa = RingElementWrapper::new(&ring, a);
                let wb = RingElementWrapper::new(&ring, b);
                let wc = RingElementWrapper::new(&ring, c);
                let actual = wa * wb + (wa + wc) * (wb - wc);
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}
426
#[test]
fn test_arithmetic_expression_int() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // reference value of 8 * a * b + (1 + a + 2 * c) * (b - 2 * c) + 5
                let eight_ab = ring.int_hom().mul_map(ring.mul(a, b), 8);
                let left = ring.add(ring.add(ring.one(), a), ring.int_hom().mul_map(c, 2));
                let right = ring.sub(b, ring.int_hom().mul_map(c, 2));
                let five = ring.int_hom().map(5);
                let expected = ring.add(ring.add(eight_ab, ring.mul(left, right)), five);

                // the same expression mixing owned wrappers and i32 operands
                let wa = RingElementWrapper::new(&ring, a);
                let wb = RingElementWrapper::new(&ring, b);
                let wc = RingElementWrapper::new(&ring, c);
                let actual = wa * 8 * wb + (1 + wa + 2 * wc) * (wb - wc * 2) + 5;
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}
443
#[test]
fn test_arithmetic_expression_ref() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // reference value of a * b + (a + c) * (b - c), computed directly in the ring
                let product = ring.mul(a, b);
                let mixed = ring.mul(ring.add(a, c), ring.sub(b, c));
                let expected = ring.add(product, mixed);

                // the same expression mixing owned and borrowed wrapper operands
                let wa = RingElementWrapper::new(&ring, a);
                let wb = RingElementWrapper::new(&ring, b);
                let wc = RingElementWrapper::new(&ring, c);
                let actual = wa * &wb + (&wa + &wc) * (&wb - wc);
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}
460
#[test]
fn test_arithmetic_expression_int_ref() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // reference value of 8 * a * b + (1 + a + 2 * c) * (b - 2 * c) + 5
                let eight_ab = ring.int_hom().mul_map(ring.mul(a, b), 8);
                let left = ring.add(ring.add(ring.one(), a), ring.int_hom().mul_map(c, 2));
                let right = ring.sub(b, ring.int_hom().mul_map(c, 2));
                let five = ring.int_hom().map(5);
                let expected = ring.add(ring.add(eight_ab, ring.mul(left, right)), five);

                // the same expression mixing owned/borrowed wrappers and i32 operands
                let wa = RingElementWrapper::new(&ring, a);
                let wb = RingElementWrapper::new(&ring, b);
                let wc = RingElementWrapper::new(&ring, c);
                let actual = wa * 8 * &wb + (1 + &wa + 2 * &wc) * (&wb - wc * 2) + 5;
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}