1use std::fmt::{Debug, Display};
2use std::hash::Hash;
3use std::ops::*;
4
5use crate::field::*;
6use crate::homomorphism::*;
7use crate::ring::*;
8
9pub struct RingElementWrapper<R>
54where
55 R: RingStore,
56{
57 ring: R,
58 element: El<R>,
59}
60
61impl<R: RingStore> RingElementWrapper<R> {
62 pub const fn new(ring: R, element: El<R>) -> Self { Self { ring, element } }
64
65 pub fn pow(self, power: usize) -> Self {
70 Self {
71 element: self.ring.pow(self.element, power),
72 ring: self.ring,
73 }
74 }
75
76 pub fn pow_ref(&self, power: usize) -> Self
78 where
79 R: Clone,
80 {
81 Self {
82 element: self.ring.pow(self.ring.clone_el(&self.element), power),
83 ring: self.ring.clone(),
84 }
85 }
86
87 pub fn unwrap(self) -> El<R> { self.element }
89
90 pub fn unwrap_ref(&self) -> &El<R> { &self.element }
92
93 pub fn parent(&self) -> &R { &self.ring }
95
96 pub fn is_zero(&self) -> bool { self.parent().is_zero(self.unwrap_ref()) }
100
101 pub fn is_one(&self) -> bool { self.parent().is_one(self.unwrap_ref()) }
105
106 pub fn is_neg_one(&self) -> bool { self.parent().is_neg_one(self.unwrap_ref()) }
110}
111
/// Implements an in-place operator trait (e.g. `AddAssign`) for [`RingElementWrapper`],
/// once with an owned and once with a borrowed right-hand side.
///
/// Arguments, in order:
///  - the operator trait to implement,
///  - the trait's method name; the ring method of the same name handles an owned right-hand side,
///  - the ring method that handles a borrowed right-hand side.
macro_rules! impl_xassign_trait {
    ($op_trait:ident, $method:ident, $method_ref:ident) => {
        impl<R> $op_trait for RingElementWrapper<R>
        where
            R: RingStore,
        {
            fn $method(&mut self, rhs: Self) {
                // Both operands must belong to the same ring; only verified in debug builds.
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let Self { element: rhs_element, .. } = rhs;
                self.ring.$method(&mut self.element, rhs_element);
            }
        }

        impl<'rhs, R> $op_trait<&'rhs Self> for RingElementWrapper<R>
        where
            R: RingStore,
        {
            fn $method(&mut self, rhs: &'rhs Self) {
                // Both operands must belong to the same ring; only verified in debug builds.
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let rhs_element = &rhs.element;
                self.ring.$method_ref(&mut self.element, rhs_element);
            }
        }
    };
}
129
/// Implements a binary operator trait (e.g. `Add`) for every owned/borrowed combination
/// of two [`RingElementWrapper`] operands.
///
/// Arguments, in order:
///  - the operator trait to implement,
///  - the trait's method name; the ring method of the same name handles two owned operands,
///  - the ring method taking the first operand by reference,
///  - the ring method taking the second operand by reference,
///  - the ring method taking both operands by reference.
///
/// Each implementation checks in debug builds that both operands refer to the same ring.
macro_rules! impl_trait {
    ($op_trait:ident, $method:ident, $method_ref_fst:ident, $method_ref_snd:ident, $method_ref:ident) => {
        // owned <op> owned: the result keeps the ring of the left operand
        impl<R> $op_trait for RingElementWrapper<R>
        where
            R: RingStore,
        {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let element = rhs.ring.$method(self.element, rhs.element);
                Self {
                    ring: self.ring,
                    element,
                }
            }
        }

        // borrowed <op> owned: the result takes over the ring of the owned right operand
        impl<'lhs, R> $op_trait<RingElementWrapper<R>> for &'lhs RingElementWrapper<R>
        where
            R: RingStore,
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let element = self.ring.$method_ref_fst(&self.element, rhs.element);
                RingElementWrapper {
                    ring: rhs.ring,
                    element,
                }
            }
        }

        // owned <op> borrowed: the result takes over the ring of the owned left operand
        impl<'rhs, R> $op_trait<&'rhs RingElementWrapper<R>> for RingElementWrapper<R>
        where
            R: RingStore,
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: &'rhs RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let element = rhs.ring.$method_ref_snd(self.element, &rhs.element);
                RingElementWrapper {
                    ring: self.ring,
                    element,
                }
            }
        }

        // borrowed <op> borrowed: no ring can be moved, so the ring of the left operand is cloned
        impl<'rhs, 'lhs, R> $op_trait<&'rhs RingElementWrapper<R>> for &'lhs RingElementWrapper<R>
        where
            R: RingStore + Clone,
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: &'rhs RingElementWrapper<R>) -> Self::Output {
                debug_assert!(self.ring.get_ring() == rhs.ring.get_ring());
                let ring = self.ring.clone();
                let element = self.ring.$method_ref(&self.element, &rhs.element);
                RingElementWrapper { ring, element }
            }
        }
    };
}
181
182impl_xassign_trait! { AddAssign, add_assign, add_assign_ref }
183impl_xassign_trait! { MulAssign, mul_assign, mul_assign_ref }
184impl_xassign_trait! { SubAssign, sub_assign, sub_assign_ref }
185impl_trait! { Add, add, add_ref_fst, add_ref_snd, add_ref }
186impl_trait! { Mul, mul, mul_ref_fst, mul_ref_snd, mul_ref }
187impl_trait! { Sub, sub, sub_ref_fst, sub_ref_snd, sub_ref }
188
189impl<R: RingStore> Div<RingElementWrapper<R>> for RingElementWrapper<R>
190where
191 R::Type: Field,
192{
193 type Output = Self;
194
195 fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
196 RingElementWrapper {
197 element: self.ring.div(&self.element, &rhs.element),
198 ring: self.ring,
199 }
200 }
201}
202
203impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for &RingElementWrapper<R>
204where
205 R::Type: Field,
206{
207 type Output = RingElementWrapper<R>;
208
209 fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
210 RingElementWrapper {
211 element: self.ring.div(&self.element, &rhs.element),
212 ring: self.ring.clone(),
213 }
214 }
215}
216
217impl<R: RingStore + Clone> Div<RingElementWrapper<R>> for &RingElementWrapper<R>
218where
219 R::Type: Field,
220{
221 type Output = RingElementWrapper<R>;
222
223 fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
224 RingElementWrapper {
225 element: self.ring.div(&self.element, &rhs.element),
226 ring: rhs.ring,
227 }
228 }
229}
230
231impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for RingElementWrapper<R>
232where
233 R::Type: Field,
234{
235 type Output = RingElementWrapper<R>;
236
237 fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
238 RingElementWrapper {
239 element: rhs.ring.div(&self.element, &rhs.element),
240 ring: self.ring,
241 }
242 }
243}
244
/// Implements an in-place operator trait (e.g. `AddAssign<i32>`) between a
/// [`RingElementWrapper`] and an `i32`.
///
/// The integer is first mapped into the ring via `int_hom()`, then the ring method named
/// like the trait method is applied.
macro_rules! impl_xassign_trait_int {
    ($op_trait:ident, $method:ident) => {
        impl<R> $op_trait<i32> for RingElementWrapper<R>
        where
            R: RingStore,
        {
            fn $method(&mut self, rhs: i32) {
                let rhs_element = self.ring.int_hom().map(rhs);
                self.ring.$method(&mut self.element, rhs_element);
            }
        }
    };
}
255
/// Implements a binary operator trait (e.g. `Add`) between a [`RingElementWrapper`] and an
/// `i32`, with the integer on either side and the wrapper either owned or borrowed.
///
/// The integer operand is mapped into the ring via `int_hom()`; afterwards the ring method
/// named like the trait method is applied to the two ring elements.
macro_rules! impl_trait_int {
    ($op_trait:ident, $method:ident) => {
        // wrapper <op> i32
        impl<R> $op_trait<i32> for RingElementWrapper<R>
        where
            R: RingStore,
        {
            type Output = Self;

            fn $method(self, rhs: i32) -> Self::Output {
                let rhs_element = self.ring.int_hom().map(rhs);
                let element = self.ring.$method(self.element, rhs_element);
                RingElementWrapper {
                    ring: self.ring,
                    element,
                }
            }
        }

        // i32 <op> wrapper
        impl<R> $op_trait<RingElementWrapper<R>> for i32
        where
            R: RingStore,
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: RingElementWrapper<R>) -> Self::Output {
                let lhs_element = rhs.ring.int_hom().map(self);
                let element = rhs.ring.$method(lhs_element, rhs.element);
                RingElementWrapper {
                    ring: rhs.ring,
                    element,
                }
            }
        }

        // &wrapper <op> i32: element and ring are cloned since nothing can be moved out
        impl<'lhs, R> $op_trait<i32> for &'lhs RingElementWrapper<R>
        where
            R: RingStore + Clone,
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: i32) -> Self::Output {
                let lhs_element = self.ring.clone_el(&self.element);
                let rhs_element = self.ring.int_hom().map(rhs);
                let element = self.ring.$method(lhs_element, rhs_element);
                RingElementWrapper {
                    ring: self.ring.clone(),
                    element,
                }
            }
        }

        // i32 <op> &wrapper: element and ring are cloned since nothing can be moved out
        impl<'rhs, R> $op_trait<&'rhs RingElementWrapper<R>> for i32
        where
            R: RingStore + Clone,
        {
            type Output = RingElementWrapper<R>;

            fn $method(self, rhs: &'rhs RingElementWrapper<R>) -> Self::Output {
                let lhs_element = rhs.ring.int_hom().map(self);
                let rhs_element = rhs.ring.clone_el(&rhs.element);
                let element = rhs.ring.$method(lhs_element, rhs_element);
                RingElementWrapper {
                    ring: rhs.ring.clone(),
                    element,
                }
            }
        }
    };
}
307
308impl_xassign_trait_int! { AddAssign, add_assign }
309impl_xassign_trait_int! { MulAssign, mul_assign }
310impl_xassign_trait_int! { SubAssign, sub_assign }
311impl_trait_int! { Add, add }
312impl_trait_int! { Mul, mul }
313impl_trait_int! { Sub, sub }
314
315impl<R: RingStore> Div<i32> for RingElementWrapper<R>
316where
317 R::Type: Field,
318{
319 type Output = Self;
320
321 fn div(self, rhs: i32) -> Self::Output {
322 RingElementWrapper {
323 element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)),
324 ring: self.ring,
325 }
326 }
327}
328
329impl<R: RingStore> Div<RingElementWrapper<R>> for i32
330where
331 R::Type: Field,
332{
333 type Output = RingElementWrapper<R>;
334
335 fn div(self, rhs: RingElementWrapper<R>) -> Self::Output {
336 RingElementWrapper {
337 element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element),
338 ring: rhs.ring,
339 }
340 }
341}
342
343impl<R: RingStore + Clone> Div<i32> for &RingElementWrapper<R>
344where
345 R::Type: Field,
346{
347 type Output = RingElementWrapper<R>;
348
349 fn div(self, rhs: i32) -> Self::Output {
350 RingElementWrapper {
351 element: self.ring.div(&self.element, &self.ring.int_hom().map(rhs)),
352 ring: self.ring.clone(),
353 }
354 }
355}
356
357impl<'a, R: RingStore + Clone> Div<&'a RingElementWrapper<R>> for i32
358where
359 R::Type: Field,
360{
361 type Output = RingElementWrapper<R>;
362
363 fn div(self, rhs: &'a RingElementWrapper<R>) -> Self::Output {
364 RingElementWrapper {
365 element: rhs.ring.div(&rhs.ring.int_hom().map(self), &rhs.element),
366 ring: rhs.ring.clone(),
367 }
368 }
369}
370
371impl<R: RingStore + Copy> Copy for RingElementWrapper<R> where El<R>: Copy {}
372
373impl<R: RingStore + Clone> Clone for RingElementWrapper<R> {
374 fn clone(&self) -> Self {
375 Self {
376 ring: self.ring.clone(),
377 element: self.ring.clone_el(&self.element),
378 }
379 }
380}
381
382impl<R: RingStore> PartialEq for RingElementWrapper<R> {
383 fn eq(&self, other: &Self) -> bool {
384 debug_assert!(self.ring.get_ring() == other.ring.get_ring());
385 self.ring.eq_el(&self.element, &other.element)
386 }
387}
388
389impl<R: RingStore> Eq for RingElementWrapper<R> {}
390
391impl<R: RingStore> PartialEq<i32> for RingElementWrapper<R> {
392 fn eq(&self, other: &i32) -> bool {
393 match *other {
394 0 => self.is_zero(),
395 1 => self.is_one(),
396 -1 => self.is_neg_one(),
397 x => self.parent().eq_el(self.unwrap_ref(), &self.parent().int_hom().map(x)),
398 }
399 }
400}
401
402impl<R: RingStore> PartialEq<RingElementWrapper<R>> for i32 {
403 fn eq(&self, other: &RingElementWrapper<R>) -> bool { other == self }
404}
405
406impl<R: RingStore> Hash for RingElementWrapper<R>
407where
408 R::Type: HashableElRing,
409{
410 fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.ring.hash(&self.element, state) }
411}
412
413impl<R: RingStore> Display for RingElementWrapper<R> {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.ring.get_ring().dbg(&self.element, f) }
415}
416
417impl<R: RingStore> Debug for RingElementWrapper<R> {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.ring.get_ring().dbg(&self.element, f) }
419}
420
421impl<R: RingStore> Deref for RingElementWrapper<R> {
422 type Target = El<R>;
423
424 fn deref(&self) -> &Self::Target { &self.element }
425}
426
427#[cfg(test)]
428use crate::rings::finite::FiniteRingStore;
429#[cfg(test)]
430use crate::rings::zn::zn_64;
431
// Checks `x * y + (x + z) * (y - z)` written with owned wrappers against the same value
// computed through direct ring method calls, for all triples of elements of `Zn::new(17)`.
#[test]
fn test_arithmetic_expression() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                let product = ring.mul(a, b);
                let sum = ring.add(a, c);
                let difference = ring.sub(b, c);
                let expected = ring.add(product, ring.mul(sum, difference));

                let x = RingElementWrapper::new(&ring, a);
                let y = RingElementWrapper::new(&ring, b);
                let z = RingElementWrapper::new(&ring, c);
                let actual = (x * y + (x + z) * (y - z)).unwrap();
                assert_el_eq!(ring, expected, actual);
            }
        }
    }
}
448
// Checks `x * 8 * y + (1 + x + 2 * z) * (y - z * 2) + 5`, mixing owned wrappers and `i32`
// operands, against the same value computed through direct ring method calls, for all
// triples of elements of `Zn::new(17)`.
#[test]
fn test_arithmetic_expression_int() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                let scaled_product = ring.int_hom().mul_map(ring.mul(a, b), 8);
                let left_factor = ring.add(ring.add(ring.one(), a), ring.int_hom().mul_map(c, 2));
                let right_factor = ring.sub(b, ring.int_hom().mul_map(c, 2));
                let without_constant = ring.add(scaled_product, ring.mul(left_factor, right_factor));
                let expected = ring.add(without_constant, ring.int_hom().map(5));

                let x = RingElementWrapper::new(&ring, a);
                let y = RingElementWrapper::new(&ring, b);
                let z = RingElementWrapper::new(&ring, c);
                let actual = (x * 8 * y + (1 + x + 2 * z) * (y - z * 2) + 5).unwrap();
                assert_el_eq!(ring, expected, actual);
            }
        }
    }
}
474
// Same expression as in `test_arithmetic_expression`, but with a mix of owned and borrowed
// wrapper operands, so the reference-taking operator implementations are exercised.
#[test]
fn test_arithmetic_expression_ref() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                let product = ring.mul(a, b);
                let sum = ring.add(a, c);
                let difference = ring.sub(b, c);
                let expected = ring.add(product, ring.mul(sum, difference));

                let x = RingElementWrapper::new(&ring, a);
                let y = RingElementWrapper::new(&ring, b);
                let z = RingElementWrapper::new(&ring, c);
                let actual = (x * &y + (&x + &z) * (&y - z)).unwrap();
                assert_el_eq!(ring, expected, actual);
            }
        }
    }
}
491
// Same expression as in `test_arithmetic_expression_int`, but with a mix of owned and
// borrowed wrapper operands next to the `i32` operands.
#[test]
fn test_arithmetic_expression_int_ref() {
    let ring = zn_64::Zn::new(17);

    for a in ring.elements() {
        for b in ring.elements() {
            for c in ring.elements() {
                let scaled_product = ring.int_hom().mul_map(ring.mul(a, b), 8);
                let left_factor = ring.add(ring.add(ring.one(), a), ring.int_hom().mul_map(c, 2));
                let right_factor = ring.sub(b, ring.int_hom().mul_map(c, 2));
                let without_constant = ring.add(scaled_product, ring.mul(left_factor, right_factor));
                let expected = ring.add(without_constant, ring.int_hom().map(5));

                let x = RingElementWrapper::new(&ring, a);
                let y = RingElementWrapper::new(&ring, b);
                let z = RingElementWrapper::new(&ring, c);
                let actual = (x * 8 * &y + (1 + &x + 2 * &z) * (&y - z * 2) + 5).unwrap();
                assert_el_eq!(ring, expected, actual);
            }
        }
    }
}