1use std::{ops::*, hash::Hash, fmt::{Display, Debug}};
2
3use crate::homomorphism::*;
4use crate::ring::*;
5use crate::field::*;
6
7pub struct RingElementWrapper<R>
41 where R: RingStore
42{
43 ring: R,
44 element: El<R>
45}
46
47impl<R: RingStore> RingElementWrapper<R> {
48
49 pub const fn new(ring: R, element: El<R>) -> Self {
53 Self { ring, element }
54 }
55
56 pub fn pow(self, power: usize) -> Self {
63 Self {
64 element: self.ring.pow(self.element, power),
65 ring: self.ring
66 }
67 }
68
69 pub fn pow_ref(&self, power: usize) -> Self
73 where R: Clone
74 {
75 Self {
76 element: self.ring.pow(self.ring.clone_el(&self.element), power),
77 ring: self.ring.clone()
78 }
79 }
80
81 pub fn unwrap(self) -> El<R> {
85 self.element
86 }
87
88 pub fn unwrap_ref(&self) -> &El<R> {
92 &self.element
93 }
94
95 pub fn parent(&self) -> &R {
99 &self.ring
100 }
101
102 pub fn is_zero(&self) -> bool {
108 self.parent().is_zero(self.unwrap_ref())
109 }
110
111 pub fn is_one(&self) -> bool {
117 self.parent().is_one(self.unwrap_ref())
118 }
119
120 pub fn is_neg_one(&self) -> bool {
126 self.parent().is_neg_one(self.unwrap_ref())
127 }
128}
129
/// Implements a compound-assignment operator trait (e.g. `AddAssign`) for
/// [`RingElementWrapper`], once with an owned and once with a borrowed
/// right-hand side.
///
/// Arguments: the operator trait, the name shared by the trait method and the
/// ring method taking the right-hand side by value, and the name of the ring
/// method taking the right-hand side by reference.
macro_rules! impl_xassign_trait {
    ($op_trait:ident, $method:ident, $method_ref:ident) => {

        impl<R> $op_trait for RingElementWrapper<R>
            where R: RingStore
        {
            fn $method(&mut self, rhs: Self) {
                // both operands must live in the same ring (debug builds only)
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let rhs_element = rhs.element;
                self.ring.$method(&mut self.element, rhs_element);
            }
        }

        impl<'a, R> $op_trait<&'a RingElementWrapper<R>> for RingElementWrapper<R>
            where R: RingStore
        {
            fn $method(&mut self, rhs: &'a RingElementWrapper<R>) {
                // both operands must live in the same ring (debug builds only)
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let rhs_element = &rhs.element;
                self.ring.$method_ref(&mut self.element, rhs_element);
            }
        }
    };
}
150
/// Implements a binary operator trait (e.g. `Add`) for every owned/borrowed
/// combination of two [`RingElementWrapper`]s.
///
/// Arguments: the operator trait, followed by the names of the ring methods
/// for (owned, owned), (borrowed, owned), (owned, borrowed) and
/// (borrowed, borrowed) operands. The first name is also the name of the
/// trait method.
///
/// Only the (borrowed, borrowed) variant has to clone the ring store and thus
/// requires `R: Clone`; the others reuse the store of an owned operand.
macro_rules! impl_trait {
    ($op_trait:ident, $method:ident, $method_ref_fst:ident, $method_ref_snd:ident, $method_ref:ident) => {

        impl<R> $op_trait for RingElementWrapper<R>
            where R: RingStore
        {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let element = rhs.ring.$method(self.element, rhs.element);
                RingElementWrapper { ring: self.ring, element }
            }
        }

        impl<'a, R> $op_trait<RingElementWrapper<R>> for &'a RingElementWrapper<R>
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                // the result takes over the ring store of the owned operand
                let RingElementWrapper { ring, element: rhs_element } = rhs;
                let element = self.ring.$method_ref_fst(&self.element, rhs_element);
                RingElementWrapper { ring, element }
            }
        }

        impl<'a, R> $op_trait<&'a RingElementWrapper<R>> for RingElementWrapper<R>
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let element = rhs.ring.$method_ref_snd(self.element, &rhs.element);
                RingElementWrapper { ring: self.ring, element }
            }
        }

        impl<'a, 'b, R> $op_trait<&'a RingElementWrapper<R>> for &'b RingElementWrapper<R>
            where R: RingStore + Clone
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let ring = self.ring.clone();
                let element = self.ring.$method_ref(&self.element, &rhs.element);
                RingElementWrapper { ring, element }
            }
        }
    };
}
191
192impl_xassign_trait!{ AddAssign, add_assign, add_assign_ref }
193impl_xassign_trait!{ MulAssign, mul_assign, mul_assign_ref }
194impl_xassign_trait!{ SubAssign, sub_assign, sub_assign_ref }
195impl_trait!{ Add, add, add_ref_fst, add_ref_snd, add_ref }
196impl_trait!{ Mul, mul, mul_ref_fst, mul_ref_snd, mul_ref }
197impl_trait!{ Sub, sub, sub_ref_fst, sub_ref_snd, sub_ref }
198
199impl<R: RingStore> Div<RingElementWrapper<R>> for RingElementWrapper<R>
200 where R::Type: Field
201{
202 type Output = Self;
203
204 fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
205 RingElementWrapper { element: self.ring.div(&self.element, &rhs.element), ring: self.ring }
206 }
207}
208
209impl<'a, 'b, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for &'b RingElementWrapper<R>
210 where R::Type: Field
211{
212 type Output = RingElementWrapper<R>;
213
214 fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
215 RingElementWrapper { element: self.ring.div(&self.element, &rhs.element), ring: self.ring.clone() }
216 }
217}
218
219impl<'a, R: RingStore + Clone> Div<RingElementWrapper<R>> for &'a RingElementWrapper<R>
220 where R::Type: Field
221{
222 type Output = RingElementWrapper<R>;
223
224 fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
225 RingElementWrapper { element: self.ring.div(&self.element, &rhs.element), ring: rhs.ring }
226 }
227}
228
229impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for RingElementWrapper<R>
230 where R::Type: Field
231{
232 type Output = RingElementWrapper<R>;
233
234 fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
235 RingElementWrapper { element: rhs.ring.div(&self.element, &rhs.element), ring: self.ring }
236 }
237}
238
/// Implements a compound-assignment operator trait (e.g. `AddAssign<i32>`)
/// for [`RingElementWrapper`] with an `i32` right-hand side.
///
/// The integer is first mapped into the ring through `int_hom()`, then the
/// ring method with the same name as the trait method is applied.
macro_rules! impl_xassign_trait_int {
    ($op_trait:ident, $method:ident) => {

        impl<R> $op_trait<i32> for RingElementWrapper<R>
            where R: RingStore
        {
            fn $method(&mut self, rhs: i32) {
                let rhs_element = self.ring.int_hom().map(rhs);
                self.ring.$method(&mut self.element, rhs_element);
            }
        }
    };
}
250
/// Implements a binary operator trait (e.g. `Add`) between a
/// [`RingElementWrapper`] (owned or borrowed) and an `i32`, in both operand
/// orders.
///
/// The integer is mapped into the ring through `int_hom()`; a borrowed
/// wrapper has its element copied with `clone_el()` and its ring store
/// cloned, which is why those variants require `R: Clone`.
macro_rules! impl_trait_int {
    ($op_trait:ident, $method:ident) => {

        impl<R> $op_trait<i32> for RingElementWrapper<R>
            where R: RingStore
        {
            type Output = Self;

            fn $method(self, rhs: i32) -> Self::Output {
                let rhs_element = self.ring.int_hom().map(rhs);
                let element = self.ring.$method(self.element, rhs_element);
                RingElementWrapper { element, ring: self.ring }
            }
        }

        impl<R> $op_trait<RingElementWrapper<R>> for i32
            where R: RingStore
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: RingElementWrapper<R>) -> Self::Output {
                let RingElementWrapper { ring, element: rhs_element } = rhs;
                let lhs_element = ring.int_hom().map(self);
                let element = ring.$method(lhs_element, rhs_element);
                RingElementWrapper { element, ring }
            }
        }

        impl<'a, R> $op_trait<i32> for &'a RingElementWrapper<R>
            where R: RingStore + Clone
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: i32) -> Self::Output {
                let lhs_element = self.ring.clone_el(&self.element);
                let rhs_element = self.ring.int_hom().map(rhs);
                let element = self.ring.$method(lhs_element, rhs_element);
                RingElementWrapper { element, ring: self.ring.clone() }
            }
        }

        impl<'a, R> $op_trait<&'a RingElementWrapper<R>> for i32
            where R: RingStore + Clone
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
                let lhs_element = rhs.ring.int_hom().map(self);
                let rhs_element = rhs.ring.clone_el(&rhs.element);
                let element = rhs.ring.$method(lhs_element, rhs_element);
                RingElementWrapper { element, ring: rhs.ring.clone() }
            }
        }
    };
}
287
288impl_xassign_trait_int!{ AddAssign, add_assign }
289impl_xassign_trait_int!{ MulAssign, mul_assign }
290impl_xassign_trait_int!{ SubAssign, sub_assign }
291impl_trait_int!{ Add, add }
292impl_trait_int!{ Mul, mul }
293impl_trait_int!{ Sub, sub }
294
295impl<R: RingStore> Div<i32> for RingElementWrapper<R>
296 where R::Type: Field
297{
298 type Output = Self;
299
300 fn div(self, rhs: i32) -> Self::Output {
301 RingElementWrapper { element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)), ring: self.ring }
302 }
303}
304
305impl<R: RingStore> Div<RingElementWrapper<R>> for i32
306 where R::Type: Field
307{
308 type Output = RingElementWrapper<R>;
309
310 fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
311 RingElementWrapper { element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element), ring: rhs.ring }
312 }
313}
314
315impl<'a, R: RingStore + Clone> Div<i32> for &'a RingElementWrapper<R>
316 where R::Type: Field
317{
318 type Output = RingElementWrapper<R>;
319
320 fn div(self, rhs: i32) -> Self::Output {
321 RingElementWrapper { element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)), ring: self.ring.clone() }
322 }
323}
324
325impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for i32
326 where R::Type: Field
327{
328 type Output = RingElementWrapper<R>;
329
330 fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
331 RingElementWrapper { element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element), ring: rhs.ring.clone() }
332 }
333}
334
335impl<R: RingStore + Copy> Copy for RingElementWrapper<R>
336 where El<R>: Copy
337{}
338
339impl<R: RingStore + Clone> Clone for RingElementWrapper<R> {
340
341 fn clone(&self) -> Self {
342 Self { ring: self.ring.clone(), element: self.ring.clone_el(&self.element) }
343 }
344}
345
346impl<R: RingStore> PartialEq for RingElementWrapper<R> {
347
348 fn eq(&self, other: &Self) -> bool {
349 debug_assert!(self.ring.get_ring() == other.ring.get_ring());
350 self.ring.eq_el(&self.element, &other.element)
351 }
352}
353
354impl<R: RingStore> Eq for RingElementWrapper<R> {}
355
356impl<R: RingStore> PartialEq<i32> for RingElementWrapper<R> {
357
358 fn eq(&self, other: &i32) -> bool {
359 match *other {
360 0 => self.is_zero(),
361 1 => self.is_one(),
362 -1 => self.is_neg_one(),
363 x => self.parent().eq_el(self.unwrap_ref(), &self.parent().int_hom().map(x))
364 }
365 }
366}
367
368impl<R: RingStore> PartialEq<RingElementWrapper<R>> for i32 {
369
370 fn eq(&self, other: &RingElementWrapper<R>) -> bool {
371 other == self
372 }
373}
374
375impl<R: RingStore> Hash for RingElementWrapper<R>
376 where R::Type: HashableElRing
377{
378 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
379 self.ring.hash(&self.element, state)
380 }
381}
382
383impl<R: RingStore> Display for RingElementWrapper<R> {
384
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 self.ring.get_ring().dbg(&self.element, f)
387 }
388}
389
390impl<R: RingStore> Debug for RingElementWrapper<R> {
391
392 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 self.ring.get_ring().dbg(&self.element, f)
394 }
395}
396
397impl<R: RingStore> Deref for RingElementWrapper<R> {
398 type Target = El<R>;
399
400 fn deref(&self) -> &Self::Target {
401 &self.element
402 }
403}
404
405#[cfg(test)]
406use crate::rings::finite::FiniteRingStore;
407#[cfg(test)]
408use crate::rings::zn::zn_64;
409
/// Checks `x * y + (x + z) * (y - z)` written with operators on owned
/// wrappers against the same value computed through the ring's methods,
/// for every triple of elements of Z/17Z.
#[test]
fn test_arithmetic_expression() {
    let ring = zn_64::Zn::new(17);

    for x in ring.elements() {
        for y in ring.elements() {
            for z in ring.elements() {
                let xy = ring.mul(x, y);
                let x_plus_z = ring.add(x, z);
                let y_minus_z = ring.sub(y, z);
                let expected = ring.add(xy, ring.mul(x_plus_z, y_minus_z));

                let x = RingElementWrapper::new(&ring, x);
                let y = RingElementWrapper::new(&ring, y);
                let z = RingElementWrapper::new(&ring, z);
                let actual = x * y + (x + z) * (y - z);
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}
426
/// Checks `8xy + (1 + x + 2z)(y - 2z) + 5` written with operators mixing
/// owned wrappers and `i32` literals against the same value computed through
/// the ring's methods, for every triple of elements of Z/17Z.
#[test]
fn test_arithmetic_expression_int() {
    let ring = zn_64::Zn::new(17);

    for x in ring.elements() {
        for y in ring.elements() {
            for z in ring.elements() {
                let eight_xy = ring.int_hom().mul_map(ring.mul(x, y), 8);
                let left = ring.add(ring.add(ring.one(), x), ring.int_hom().mul_map(z, 2));
                let right = ring.sub(y, ring.int_hom().mul_map(z, 2));
                let without_const = ring.add(eight_xy, ring.mul(left, right));
                let expected = ring.add(without_const, ring.int_hom().map(5));

                let x = RingElementWrapper::new(&ring, x);
                let y = RingElementWrapper::new(&ring, y);
                let z = RingElementWrapper::new(&ring, z);
                let actual = x * 8 * y + (1 + x + 2 * z) * (y - z * 2) + 5;
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}
443
/// Same expression as `test_arithmetic_expression`, but mixing owned and
/// borrowed wrappers to exercise the by-reference operator impls.
#[test]
fn test_arithmetic_expression_ref() {
    let ring = zn_64::Zn::new(17);

    for x in ring.elements() {
        for y in ring.elements() {
            for z in ring.elements() {
                let xy = ring.mul(x, y);
                let x_plus_z = ring.add(x, z);
                let y_minus_z = ring.sub(y, z);
                let expected = ring.add(xy, ring.mul(x_plus_z, y_minus_z));

                let x = RingElementWrapper::new(&ring, x);
                let y = RingElementWrapper::new(&ring, y);
                let z = RingElementWrapper::new(&ring, z);
                let actual = x * &y + (&x + &z) * (&y - z);
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}
460
/// Same expression as `test_arithmetic_expression_int`, but mixing owned and
/// borrowed wrappers with `i32` literals to exercise the by-reference
/// operator impls.
#[test]
fn test_arithmetic_expression_int_ref() {
    let ring = zn_64::Zn::new(17);

    for x in ring.elements() {
        for y in ring.elements() {
            for z in ring.elements() {
                let eight_xy = ring.int_hom().mul_map(ring.mul(x, y), 8);
                let left = ring.add(ring.add(ring.one(), x), ring.int_hom().mul_map(z, 2));
                let right = ring.sub(y, ring.int_hom().mul_map(z, 2));
                let without_const = ring.add(eight_xy, ring.mul(left, right));
                let expected = ring.add(without_const, ring.int_hom().map(5));

                let x = RingElementWrapper::new(&ring, x);
                let y = RingElementWrapper::new(&ring, y);
                let z = RingElementWrapper::new(&ring, z);
                let actual = x * 8 * &y + (1 + &x + 2 * &z) * (&y - z * 2) + 5;
                assert_el_eq!(ring, expected, actual.unwrap());
            }
        }
    }
}