Skip to main content

feanor_math/
wrapper.rs

1use std::fmt::{Debug, Display};
2use std::hash::Hash;
3use std::ops::*;
4
5use crate::field::*;
6use crate::homomorphism::*;
7use crate::ring::*;
8
9/// Stores a ring element together with its ring, so that ring operations do
10/// not require explicit mention of the ring object. This can be used both for
11/// convenience of notation (i.e. use `a + b` instead of `ring.add(a, b)`) and
12/// might also be necessary when e.g. storing elements in a set.
13///
14/// # Examples
15/// ```rust
16/// # use feanor_math::ring::*;
17/// # use feanor_math::rings::poly::*;
18/// # use feanor_math::rings::poly::dense_poly::*;
19/// # use feanor_math::wrapper::*;
20/// # use feanor_math::primitive_int::*;
21/// let ring = DensePolyRing::new(StaticRing::<i64>::RING, "X");
22/// let x = RingElementWrapper::new(&ring, ring.indeterminate());
23/// println!("The result is: {}", x.clone() + x.clone() * x);
24/// // instead of
25/// let x = ring.indeterminate();
26/// println!(
27///     "The result is: {}",
28///     ring.format(&ring.add(
29///         ring.mul(ring.clone_el(&x), ring.clone_el(&x)),
30///         ring.clone_el(&x)
31///     ))
32/// );
33/// ```
34/// You can also retrieve the wrapped element
35/// ```rust
36/// # use feanor_math::assert_el_eq;
37/// # use feanor_math::ring::*;
38/// # use feanor_math::rings::poly::*;
39/// # use feanor_math::rings::poly::dense_poly::*;
40/// # use feanor_math::wrapper::*;
41/// # use feanor_math::primitive_int::*;
42/// let ring = DensePolyRing::new(StaticRing::<i64>::RING, "X");
43/// let x = RingElementWrapper::new(&ring, ring.indeterminate());
44/// assert_el_eq!(
45///     &ring,
46///     ring.add(
47///         ring.mul(ring.clone_el(&x), ring.clone_el(&x)),
48///         ring.clone_el(&x)
49///     ),
50///     (x.clone() + x.clone() * x).unwrap()
51/// );
52/// ```
53pub struct RingElementWrapper<R>
54where
55    R: RingStore,
56{
57    ring: R,
58    element: El<R>,
59}
60
61impl<R: RingStore> RingElementWrapper<R> {
62    /// Creates a new [`RingElementWrapper`] wrapping the given element of the given ring.
63    pub const fn new(ring: R, element: El<R>) -> Self { Self { ring, element } }
64
65    /// Raises the stored element to the given power.
66    ///
67    /// Consider using [`RingElementWrapper::pow_ref()`] if you don't want to
68    /// move the element.
69    pub fn pow(self, power: usize) -> Self {
70        Self {
71            element: self.ring.pow(self.element, power),
72            ring: self.ring,
73        }
74    }
75
76    /// Raises the stored element to the given power.
77    pub fn pow_ref(&self, power: usize) -> Self
78    where
79        R: Clone,
80    {
81        Self {
82            element: self.ring.pow(self.ring.clone_el(&self.element), power),
83            ring: self.ring.clone(),
84        }
85    }
86
87    /// Returns the stored element.
88    pub fn unwrap(self) -> El<R> { self.element }
89
90    /// Returns the stored element as reference.
91    pub fn unwrap_ref(&self) -> &El<R> { &self.element }
92
93    /// Returns a reference to the ring that this element belongs to.
94    pub fn parent(&self) -> &R { &self.ring }
95
96    /// Returns `true` if this element is zero.
97    ///
98    /// Equivalent to `self.parent().is_zero(self.unwrap_ref())`.
99    pub fn is_zero(&self) -> bool { self.parent().is_zero(self.unwrap_ref()) }
100
101    /// Returns `true` if this element is one.
102    ///
103    /// Equivalent to `self.parent().is_one(self.unwrap_ref())`.
104    pub fn is_one(&self) -> bool { self.parent().is_one(self.unwrap_ref()) }
105
106    /// Returns `true` if this element is negative one.
107    ///
108    /// Equivalent to `self.parent().is_neg_one(self.unwrap_ref())`.
109    pub fn is_neg_one(&self) -> bool { self.parent().is_neg_one(self.unwrap_ref()) }
110}
111
/// Implements a compound-assignment operator trait (e.g. `AddAssign`) for
/// `RingElementWrapper`, once with an owned and once with a borrowed
/// right-hand side.
///
/// Arguments: the trait name, the ring method used for an owned right-hand
/// side, and the ring method used for a borrowed right-hand side.
macro_rules! impl_xassign_trait {
    ($op_trait:ident, $op_method:ident, $op_method_ref:ident) => {
        // `wrapper op= wrapper`: the right-hand side is consumed
        impl<R: RingStore> $op_trait<RingElementWrapper<R>> for RingElementWrapper<R> {
            fn $op_method(&mut self, other: RingElementWrapper<R>) {
                debug_assert!(self.ring.get_ring() == other.ring.get_ring());
                let operand = other.element;
                self.ring.$op_method(&mut self.element, operand);
            }
        }

        // `wrapper op= &wrapper`: the right-hand side is only borrowed
        impl<'a, R: RingStore> $op_trait<&'a RingElementWrapper<R>> for RingElementWrapper<R> {
            fn $op_method(&mut self, other: &'a RingElementWrapper<R>) {
                debug_assert!(self.ring.get_ring() == other.ring.get_ring());
                let operand = &other.element;
                self.ring.$op_method_ref(&mut self.element, operand);
            }
        }
    };
}
129
/// Implements a binary operator trait (e.g. `Add`) for every combination of
/// owned and borrowed `RingElementWrapper` operands.
///
/// Arguments: the trait name, then the ring methods for (owned, owned),
/// (borrowed, owned), (owned, borrowed) and (borrowed, borrowed) operands.
/// Operands are expected to share a ring; this is only checked in debug builds.
macro_rules! impl_trait {
    ($op_trait:ident, $op:ident, $op_ref_fst:ident, $op_ref_snd:ident, $op_ref:ident) => {
        // owned `op` owned: keeps the ring handle of the left operand
        impl<R: RingStore> $op_trait<RingElementWrapper<R>> for RingElementWrapper<R> {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == other.ring.get_ring());
                let RingElementWrapper { ring, element: lhs } = self;
                let element = other.ring.$op(lhs, other.element);
                RingElementWrapper { ring, element }
            }
        }

        // borrowed `op` owned: keeps the ring handle of the right operand
        impl<'a, R: RingStore> $op_trait<RingElementWrapper<R>> for &'a RingElementWrapper<R> {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == other.ring.get_ring());
                let RingElementWrapper { ring, element: rhs } = other;
                let element = self.ring.$op_ref_fst(&self.element, rhs);
                RingElementWrapper { ring, element }
            }
        }

        // owned `op` borrowed: keeps the ring handle of the left operand
        impl<'a, R: RingStore> $op_trait<&'a RingElementWrapper<R>> for RingElementWrapper<R> {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: &'a RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == other.ring.get_ring());
                let RingElementWrapper { ring, element: lhs } = self;
                let element = other.ring.$op_ref_snd(lhs, &other.element);
                RingElementWrapper { ring, element }
            }
        }

        // borrowed `op` borrowed: nothing is owned, so the ring handle is cloned
        impl<'a, 'b, R: RingStore + Clone> $op_trait<&'a RingElementWrapper<R>> for &'b RingElementWrapper<R> {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: &'a RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == other.ring.get_ring());
                let ring = self.ring.clone();
                let element = self.ring.$op_ref(&self.element, &other.element);
                RingElementWrapper { ring, element }
            }
        }
    };
}
181
182impl_xassign_trait! { AddAssign, add_assign, add_assign_ref }
183impl_xassign_trait! { MulAssign, mul_assign, mul_assign_ref }
184impl_xassign_trait! { SubAssign, sub_assign, sub_assign_ref }
185impl_trait! { Add, add, add_ref_fst, add_ref_snd, add_ref }
186impl_trait! { Mul, mul, mul_ref_fst, mul_ref_snd, mul_ref }
187impl_trait! { Sub, sub, sub_ref_fst, sub_ref_snd, sub_ref }
188
189impl<R: RingStore> Div<RingElementWrapper<R>> for RingElementWrapper<R>
190where
191    R::Type: Field,
192{
193    type Output = Self;
194
195    fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
196        RingElementWrapper {
197            element: self.ring.div(&self.element, &rhs.element),
198            ring: self.ring,
199        }
200    }
201}
202
203impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for &RingElementWrapper<R>
204where
205    R::Type: Field,
206{
207    type Output = RingElementWrapper<R>;
208
209    fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
210        RingElementWrapper {
211            element: self.ring.div(&self.element, &rhs.element),
212            ring: self.ring.clone(),
213        }
214    }
215}
216
217impl<R: RingStore + Clone> Div<RingElementWrapper<R>> for &RingElementWrapper<R>
218where
219    R::Type: Field,
220{
221    type Output = RingElementWrapper<R>;
222
223    fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
224        RingElementWrapper {
225            element: self.ring.div(&self.element, &rhs.element),
226            ring: rhs.ring,
227        }
228    }
229}
230
231impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for RingElementWrapper<R>
232where
233    R::Type: Field,
234{
235    type Output = RingElementWrapper<R>;
236
237    fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
238        RingElementWrapper {
239            element: rhs.ring.div(&self.element, &rhs.element),
240            ring: self.ring,
241        }
242    }
243}
244
/// Implements a compound-assignment operator trait (e.g. `AddAssign`) for
/// `RingElementWrapper` with an `i32` right-hand side. The integer is first
/// mapped into the ring via `int_hom()`.
macro_rules! impl_xassign_trait_int {
    ($op_trait:ident, $op_method:ident) => {
        impl<R: RingStore> $op_trait<i32> for RingElementWrapper<R> {
            fn $op_method(&mut self, other: i32) {
                let operand = self.ring.int_hom().map(other);
                self.ring.$op_method(&mut self.element, operand);
            }
        }
    };
}
255
/// Implements a binary operator trait (e.g. `Add`) between `RingElementWrapper`
/// (owned or borrowed) and `i32`, with the integer on either side.
///
/// The integer is mapped into the ring via `int_hom()`, and the ring's
/// by-value method is then applied. Borrowed wrappers are copied with
/// `clone_el()` first, which is why those variants need `R: Clone`.
macro_rules! impl_trait_int {
    ($op_trait:ident, $op:ident) => {
        // wrapper `op` i32
        impl<R: RingStore> $op_trait<i32> for RingElementWrapper<R> {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: i32) -> RingElementWrapper<R> {
                let RingElementWrapper { ring, element } = self;
                let other = ring.int_hom().map(other);
                let element = ring.$op(element, other);
                RingElementWrapper { ring, element }
            }
        }

        // i32 `op` wrapper
        impl<R: RingStore> $op_trait<RingElementWrapper<R>> for i32 {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: RingElementWrapper<R>) -> RingElementWrapper<R> {
                let RingElementWrapper { ring, element } = other;
                let lhs = ring.int_hom().map(self);
                let element = ring.$op(lhs, element);
                RingElementWrapper { ring, element }
            }
        }

        // &wrapper `op` i32
        impl<'a, R: RingStore + Clone> $op_trait<i32> for &'a RingElementWrapper<R> {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: i32) -> RingElementWrapper<R> {
                let lhs = self.ring.clone_el(&self.element);
                let rhs = self.ring.int_hom().map(other);
                RingElementWrapper {
                    element: self.ring.$op(lhs, rhs),
                    ring: self.ring.clone(),
                }
            }
        }

        // i32 `op` &wrapper
        impl<'a, R: RingStore + Clone> $op_trait<&'a RingElementWrapper<R>> for i32 {
            type Output = RingElementWrapper<R>;

            fn $op(self, other: &'a RingElementWrapper<R>) -> RingElementWrapper<R> {
                let lhs = other.ring.int_hom().map(self);
                let rhs = other.ring.clone_el(&other.element);
                RingElementWrapper {
                    element: other.ring.$op(lhs, rhs),
                    ring: other.ring.clone(),
                }
            }
        }
    };
}
307
308impl_xassign_trait_int! { AddAssign, add_assign }
309impl_xassign_trait_int! { MulAssign, mul_assign }
310impl_xassign_trait_int! { SubAssign, sub_assign }
311impl_trait_int! { Add, add }
312impl_trait_int! { Mul, mul }
313impl_trait_int! { Sub, sub }
314
315impl<R: RingStore> Div<i32> for RingElementWrapper<R>
316where
317    R::Type: Field,
318{
319    type Output = Self;
320
321    fn div(self, rhs: i32) -> Self::Output {
322        RingElementWrapper {
323            element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)),
324            ring: self.ring,
325        }
326    }
327}
328
329impl<R: RingStore> Div<RingElementWrapper<R>> for i32
330where
331    R::Type: Field,
332{
333    type Output = RingElementWrapper<R>;
334
335    fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
336        RingElementWrapper {
337            element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element),
338            ring: rhs.ring,
339        }
340    }
341}
342
343impl<R: RingStore + Clone> Div<i32> for &RingElementWrapper<R>
344where
345    R::Type: Field,
346{
347    type Output = RingElementWrapper<R>;
348
349    fn div(self, rhs: i32) -> Self::Output {
350        RingElementWrapper {
351            element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)),
352            ring: self.ring.clone(),
353        }
354    }
355}
356
357impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for i32
358where
359    R::Type: Field,
360{
361    type Output = RingElementWrapper<R>;
362
363    fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
364        RingElementWrapper {
365            element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element),
366            ring: rhs.ring.clone(),
367        }
368    }
369}
370
371impl<R: RingStore + Copy> Copy for RingElementWrapper<R> where El<R>: Copy {}
372
373impl<R: RingStore + Clone> Clone for RingElementWrapper<R> {
374    fn clone(&self) -> Self {
375        Self {
376            ring: self.ring.clone(),
377            element: self.ring.clone_el(&self.element),
378        }
379    }
380}
381
382impl<R: RingStore> PartialEq for RingElementWrapper<R> {
383    fn eq(&self, other: &Self) -> bool {
384        debug_assert!(self.ring.get_ring() == other.ring.get_ring());
385        self.ring.eq_el(&self.element, &other.element)
386    }
387}
388
389impl<R: RingStore> Eq for RingElementWrapper<R> {}
390
391impl<R: RingStore> PartialEq<i32> for RingElementWrapper<R> {
392    fn eq(&self, other: &i32) -> bool {
393        match *other {
394            0 => self.is_zero(),
395            1 => self.is_one(),
396            -1 => self.is_neg_one(),
397            x => self.parent().eq_el(self.unwrap_ref(), &self.parent().int_hom().map(x)),
398        }
399    }
400}
401
402impl<R: RingStore> PartialEq<RingElementWrapper<R>> for i32 {
403    fn eq(&self, other: &RingElementWrapper<R>) -> bool { other == self }
404}
405
406impl<R: RingStore> Hash for RingElementWrapper<R>
407where
408    R::Type: HashableElRing,
409{
410    fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.ring.hash(&self.element, state) }
411}
412
413impl<R: RingStore> Display for RingElementWrapper<R> {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.ring.get_ring().dbg(&self.element, f) }
415}
416
417impl<R: RingStore> Debug for RingElementWrapper<R> {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.ring.get_ring().dbg(&self.element, f) }
419}
420
421impl<R: RingStore> Deref for RingElementWrapper<R> {
422    type Target = El<R>;
423
424    fn deref(&self) -> &Self::Target { &self.element }
425}
426
427#[cfg(test)]
428use crate::rings::finite::FiniteRingStore;
429#[cfg(test)]
430use crate::rings::zn::zn_64;
431
#[test]
fn test_arithmetic_expression() {
    let ring = zn_64::Zn::new(17);
    let wrap = |value| RingElementWrapper::new(&ring, value);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // x * y + (x + z) * (y - z), evaluated directly in the ring
                let product = ring.mul(a, b);
                let correction = ring.mul(ring.add(a, c), ring.sub(b, c));
                let expected = ring.add(product, correction);

                let (x, y, z) = (wrap(a), wrap(b), wrap(c));
                assert_el_eq!(ring, expected, (x * y + (x + z) * (y - z)).unwrap());
            }
        }
    }
}
448
#[test]
fn test_arithmetic_expression_int() {
    let ring = zn_64::Zn::new(17);
    let wrap = |value| RingElementWrapper::new(&ring, value);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // 8 * x * y + (1 + x + 2 * z) * (y - 2 * z) + 5, evaluated directly in the ring
                let two_c = ring.int_hom().mul_map(c, 2);
                let first = ring.int_hom().mul_map(ring.mul(a, b), 8);
                let second = ring.mul(ring.add(ring.add(ring.one(), a), two_c), ring.sub(b, two_c));
                let expected = ring.add(ring.add(first, second), ring.int_hom().map(5));

                let (x, y, z) = (wrap(a), wrap(b), wrap(c));
                assert_el_eq!(ring, expected, (x * 8 * y + (1 + x + 2 * z) * (y - z * 2) + 5).unwrap());
            }
        }
    }
}
474
#[test]
fn test_arithmetic_expression_ref() {
    let ring = zn_64::Zn::new(17);
    let wrap = |value| RingElementWrapper::new(&ring, value);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // x * y + (x + z) * (y - z), evaluated directly in the ring
                let product = ring.mul(a, b);
                let correction = ring.mul(ring.add(a, c), ring.sub(b, c));
                let expected = ring.add(product, correction);

                let (x, y, z) = (wrap(a), wrap(b), wrap(c));
                // mixes owned and borrowed operands so that every operator impl is exercised
                assert_el_eq!(ring, expected, (x * &y + (&x + &z) * (&y - z)).unwrap());
            }
        }
    }
}
491
#[test]
fn test_arithmetic_expression_int_ref() {
    let ring = zn_64::Zn::new(17);
    let wrap = |value| RingElementWrapper::new(&ring, value);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                // 8 * x * y + (1 + x + 2 * z) * (y - 2 * z) + 5, evaluated directly in the ring
                let two_c = ring.int_hom().mul_map(c, 2);
                let first = ring.int_hom().mul_map(ring.mul(a, b), 8);
                let second = ring.mul(ring.add(ring.add(ring.one(), a), two_c), ring.sub(b, two_c));
                let expected = ring.add(ring.add(first, second), ring.int_hom().map(5));

                let (x, y, z) = (wrap(a), wrap(b), wrap(c));
                // integer operators applied to both owned and borrowed wrappers
                assert_el_eq!(
                    ring,
                    expected,
                    (x * 8 * &y + (1 + &x + 2 * &z) * (&y - z * 2) + 5).unwrap()
                );
            }
        }
    }
}