Skip to main content

bare_metal_modulo/
lib.rs

1#![cfg_attr(not(test), no_std)]
2//! # Overview
3//! The bare_metal_modulo crate includes the following structs:
4//! - `ModNum` is a highly ergonomic modular arithmetic struct intended for `no_std` use.
5//! - `ModNumC` is similar to `ModNum`, but uses
6//!   [const generics](https://rust-lang.github.io/rfcs/2000-const-generics.html) to specify the
7//!   modulo.
8//! - `ModNumIterator` is a double-ended iterator that starts with the `ModNum` or `ModNumC` upon
9//!   which it is invoked, making a complete traversal of the elements in that object's ring.
10//! - `WrapCountNum` is similar to `ModNum`, but additionally tracks the amount of "wraparound"
11//!   that produced the modulo value. `WrapCountNumC` corresponds to `ModNumC`.
12//! - `OffsetNum` is similar to `ModNum`, but instead of a zero-based range,
13//!   it can start at any integer value. `OffsetNumC` corresponds to `ModNumC`.
14//!
15//! `ModNum` objects represent a value modulo **m**. The value and modulo can be of any
16//! primitive integer type.  Arithmetic operators include `+`, `-` (both unary and binary),
17//! `*`, `/`, `pow()`, `==`, `<`, `>`, `<=`, `>=`, and `!=`. Additional capabilities include
18//! computing multiplicative inverses and solving modular equations. Division, multiplicative
19//! inverse, and solving modular equations are only defined for signed types.
20//!
21//! `ModNumC` objects likewise represent a value modulo **M**, where **M** is a generic constant of
22//! the `usize` type. `ModNumC` objects support the same arithmetic operators as `ModNum` objects.
23//!
24//! Modular numbers represent the remainder of an integer when divided by the modulo. If we also
25//! store the quotient in addition to the remainder, we have a count of the number of times a
26//! value had to "wrap around" during the calculation.
27//!
28//! For example, if we start with **8 (mod 17)** and add **42**, the result is **16 (mod 17)** with
29//! a wraparound of **2**.
30//!
31//! `WrapCountNum`/`WrapCountNumC` objects store this wraparound value and make it available. They
32//! only support subtraction and iteration with `Signed` values such as `isize` and `i64`.
33//!
34//! This library was originally developed to facilitate bidirectional navigation through fixed-size
35//! arrays at arbitrary starting points. This is facilitated by a double-ended iterator that
36//! traverses the entire ring starting at any desired value. The iterator supports
37//! `ModNum`, `ModNumC`, `OffsetNum`, and `OffsetNumC`. It also supports `WrapCountNum` and
38//! `WrapCountNumC` for signed values only.
39//!
40//! Note that `ModNum`, `ModNumC`, `WrapCountNum`, `WrapCountNumC`, `OffsetNum`, and `OffsetNumC`
41//! are not designed to work with arbitrary-length integers, as they require their integer type
42//! to implement the `Copy` trait.
43//!
44//! For the [2020 Advent of Code](https://adventofcode.com/2020)
45//! ([Day 13](https://adventofcode.com/2020/day/13) part 2),
46//! I extended `ModNum` to solve systems of modular equations, provided that each modular equation
47//! is represented using signed integers. My implementation is based on this
48//! [lucid](https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/)
49//! [explanation](https://byorgey.wordpress.com/2020/03/03/competitive-programming-in-haskell-modular-arithmetic-part-2/)
50//! by [Brent Yorgey](http://ozark.hendrix.edu/~yorgey/).
51//!
52//! # Accessing Values
53//! Each `ModNum`/`ModNumC`/`WrapCountNum`/`WrapCountNumC` represents an integer **a (mod m)**. To
54//! access these values, use the corresponding **a()** and **m()** methods. Note that **a()** will
55//! always return a fully reduced value, regardless of how it was initialized.
56//!
57//! Each `WrapCountNum`/`WrapCountNumC` tracks accumulated wraparounds. Use the **.wraps()** method
58//! to access this tracked count.
59//!
60//!```
61//! use bare_metal_modulo::*;
62//!
63//! let m = ModNum::new(7, 10);
64//! assert_eq!(m.a(), 7);
65//! assert_eq!(m.m(), 10);
66//!
67//! let n = ModNum::new(23, 17);
68//! assert_eq!(n.a(), 6);
69//! assert_eq!(n.m(), 17);
70//!
71//! let p = ModNum::new(-4, 3);
72//! assert_eq!(p.a(), 2);
73//! assert_eq!(p.m(), 3);
74//!
75//! let f = format!("{}", p);
76//! assert_eq!(f, "2 (mod 3)");
77//!
78//! // ModNumC variables indicate the modulo using a type annotation.
79//! let q: ModNumC<i32, 17> = ModNumC::new(23);
80//! assert_eq!(q, 6);
81//!
82//! // Create a new ModNum with the same modulo as an existing one.
83//! let r = p.with(8);
84//! assert_eq!(r.a(), 2);
85//! assert_eq!(r.m(), 3);
86//!
87//! let s = q.with(35);
88//! assert_eq!(s.a(), 1);
89//! assert_eq!(s.m(), 17);
90//!
91//! // Replace the value of the ModNum with a new value while keeping the modulo.
92//! let mut t = ModNum::new(3, 5);
93//! t.replace(12);
94//! assert_eq!(t.a(), 2);
95//!
96//! let mut v: ModNumC<usize,5> = ModNumC::new(3);
97//! v.replace(12);
98//! assert_eq!(v.a(), 2);
99//! ```
100//!
101//! # Arithmetic
102//! Addition, subtraction, and multiplication are all fully supported for both
103//! signed and unsigned integer types. The right-hand side may either be an integer of the
104//! corresponding type or another `ModNum`/`ModNumC` object.
105//!
106//! Unary negation and subtraction are supported for both signed and unsigned integers.
107//!
108//! ```
109//! use bare_metal_modulo::*;
110//!
111//! let mut m = ModNum::new(2, 5);
112//! m += 2;
113//! assert_eq!(m, ModNum::new(4, 5));
114//! m += 2;
115//! assert_eq!(m, ModNum::new(1, 5));
116//! m -= 3;
117//! assert_eq!(m, ModNum::new(3, 5));
118//! m *= 2;
119//! assert_eq!(m, ModNum::new(1, 5));
120//! m *= 2;
121//! assert_eq!(m, ModNum::new(2, 5));
122//! m *= 2;
123//! assert_eq!(m, ModNum::new(4, 5));
124//! m *= 2;
125//! assert_eq!(m, ModNum::new(3, 5));
126//! m = -m;
127//! assert_eq!(m, ModNum::new(2, 5));
128//!
129//! assert_eq!(m + ModNum::new(4, 5), ModNum::new(1, 5));
130//! m += ModNum::new(4, 5);
131//! assert_eq!(m, ModNum::new(1, 5));
132//!
133//! assert_eq!(m - ModNum::new(4, 5), ModNum::new(2, 5));
134//! m -= ModNum::new(4, 5);
135//! assert_eq!(m, ModNum::new(2, 5));
136//!
137//! assert_eq!(m * ModNum::new(3, 5), ModNum::new(1, 5));
138//! m *= ModNum::new(3, 5);
139//! assert_eq!(m, ModNum::new(1, 5));
140//!
141//! let mut m: ModNumC<isize,5> = ModNumC::new(2);
142//! m *= 3;
143//! assert_eq!(m, ModNumC::new(1));
144//!
145//! m += 1;
146//! assert_eq!(m, ModNumC::new(2));
147//!
148//! m += 3;
149//! assert_eq!(m, ModNumC::new(0));
150//!
151//! let m: ModNumC<isize,5> = ModNumC::default();
152//! assert_eq!(m, ModNumC::new(0));
153//! ```
154//!
//! Saturating addition and subtraction, which clamp the result at **0** or **m - 1** instead
//! of wrapping around, are often useful, so the
156//! [`num::traits::SaturatingAdd`](https://docs.rs/num-traits/0.2.14/num_traits/ops/saturating/trait.SaturatingAdd.html)
157//! and
158//! [`num::traits::SaturatingSub`](https://docs.rs/num-traits/0.2.14/num_traits/ops/saturating/trait.SaturatingSub.html)
159//! traits are implemented as well for `ModNum` and `ModNumC`.
160//!
161//! ```
162//! use bare_metal_modulo::*;
163//! use num::traits::SaturatingAdd;
164//! use num::traits::SaturatingSub;
165//!
166//! let m = ModNum::new(2, 5);
167//! assert_eq!(m.saturating_add(&ModNum::new(1, 5)), ModNum::new(3, 5));
168//! assert_eq!(m.saturating_add(&ModNum::new(2, 5)), ModNum::new(4, 5));
169//! assert_eq!(m.saturating_add(&ModNum::new(3, 5)), ModNum::new(4, 5));
170//! assert_eq!(m.saturating_sub(&ModNum::new(1, 5)), ModNum::new(1, 5));
171//! assert_eq!(m.saturating_sub(&ModNum::new(2, 5)), ModNum::new(0, 5));
172//! assert_eq!(m.saturating_sub(&ModNum::new(3, 5)), ModNum::new(0, 5));
173//!
174//! let m: ModNumC<i32, 5> = ModNumC::new(2);
175//! assert_eq!(m.saturating_add(&ModNumC::new(1)), ModNumC::new(3));
176//! assert_eq!(m.saturating_add(&ModNumC::new(2)), ModNumC::new(4));
177//! assert_eq!(m.saturating_add(&ModNumC::new(3)), ModNumC::new(4));
178//! assert_eq!(m.saturating_sub(&ModNumC::new(1)), ModNumC::new(1));
179//! assert_eq!(m.saturating_sub(&ModNumC::new(2)), ModNumC::new(0));
180//! assert_eq!(m.saturating_sub(&ModNumC::new(3)), ModNumC::new(0));
181//! ```
182//!
183//! Multiplicative inverse (using the **.inverse()** method) is supported for signed integers only.
184//! As inverses are only defined when **a** and **m** are relatively prime, **.inverse()** will
185//! return **None** when it is not possible to calculate.
186//!
187//! Division is defined in terms of the multiplicative inverse, so it is likewise only supported
//! for signed integers, and will return **None** when the quotient does not exist. Compound
//! assignment division (`/=`) will **panic** if the quotient does not exist.
190//!
191//! The **.pow()** method is fully supported for unsigned integer types. It also works for signed
192//! integer types, but it will **panic** if given a negative exponent. If negative exponents are
193//! possible, use **.pow_signed()**, which will return **None** if the result does not exist.
194//!
195//! ```
196//! use bare_metal_modulo::*;
197//! use num::traits::Pow;
198//!
199//! let m = ModNum::new(2, 5);
200//! assert_eq!(m.pow(2), ModNum::new(4, 5));
201//! assert_eq!(m.pow(3), ModNum::new(3, 5));
202//! assert_eq!(m.pow(4), ModNum::new(1, 5));
203//! assert_eq!(m.pow(5), ModNum::new(2, 5));
204//! assert_eq!(m.pow(6), ModNum::new(4, 5));
205//!
206//! assert_eq!(m.pow(ModNum::new(4, 5)), ModNum::new(1, 5));
207//!
208//! // Note: Very different result from m.pow(6)!
209//! assert_eq!(m.pow(ModNum::new(6, 5)), ModNum::new(2, 5));
210//!
211//! let i = m.inverse().unwrap();
212//! assert_eq!(m * i.a(), 1);
213//!
214//! assert_eq!(m.pow_signed(-2).unwrap(), ModNum::new(4, 5));
215//! assert_eq!(m.pow_signed(-3).unwrap(), ModNum::new(2, 5));
216//! assert_eq!(m.pow_signed(-4).unwrap(), ModNum::new(1, 5));
217//! assert_eq!(m.pow_signed(-5).unwrap(), ModNum::new(3, 5));
218//! assert_eq!(m.pow_signed(-6).unwrap(), ModNum::new(4, 5));
219//!
220//! let m = ModNum::new(6, 11);
221//! assert_eq!((m / 2).unwrap().a(), 3);
222//! assert_eq!((m / 4).unwrap().a(), 7);
223//! assert_eq!((m / 5).unwrap().a(), 10);
224//! assert_eq!((m / 6).unwrap().a(), 1);
225//! assert_eq!((m / 8).unwrap().a(), 9);
226//! assert_eq!(m / 0, None);
227//!
228//! assert_eq!((m / ModNum::new(2, 11)).unwrap(), ModNum::new(3, 11));
229//! assert_eq!((m / ModNum::new(4, 11)).unwrap(), ModNum::new(7, 11));
230//!
231//! let m: ModNumC<i32,5> = ModNumC::new(2);
232//!
233//! let i = m.inverse().unwrap();
234//! assert_eq!(m * i.a(), 1);
235//!
236//! assert_eq!(m.pow(2), ModNumC::new(4));
237//! assert_eq!(m.pow(3), ModNumC::new(3));
238//! assert_eq!(m.pow_signed(-2).unwrap(), ModNumC::new(4));
239//! assert_eq!(m.pow_signed(-3).unwrap(), ModNumC::new(2));
240//!
241//! let m: ModNumC<i32, 11> = ModNumC::new(6);
242//! assert_eq!((m / 2).unwrap().a(), 3);
243//! assert_eq!((m / 4).unwrap().a(), 7);
244//! assert_eq!(m / 0, None);
245//! ```
246//!
//! The **==** operator can be used to compare two `ModNum`s, two `ModNumC`s, or a
//! `ModNum`/`ModNumC` and an integer of the corresponding type. In all cases, it represents
//! congruence rather than strict equality.
250//!
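//! Inequalities are also defined over those same pairs of types. Unlike **==**, they are not
//! based on congruence: the reduced value **a()** is compared with the integer exactly as given,
//! so `ModNum::new(2, 5) < 6` holds even though 6 is congruent to 1 (mod 5).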
252//!
253//!```
254//! use bare_metal_modulo::*;
255//!
256//! let m = ModNum::new(2, 5);
257//! assert!(m == 2);
258//! assert!(m == 7);
259//! assert!(m == -3);
260//! assert!(m != 3);
261//!
262//! assert_eq!(m, ModNum::new(2, 5));
263//! assert_eq!(m, ModNum::new(7, 5));
264//! assert_eq!(m, ModNum::new(-3, 5));
265//!
266//! assert!(m < 4);
267//! assert!(m < 8);
268//! assert!(m > 1);
269//!
270//! let n = ModNum::new(6, 5);
271//! assert!(m > n);
272//! assert!(n < 2);
273//! assert!(n > 0);
274//!
275//! let m: ModNumC<i32,5> = ModNumC::new(2);
276//! assert!(m == 2);
277//! assert!(m == 7);
278//! assert!(m == -3);
279//! assert!(m != 3);
280//!
281//! assert!(m < 4);
282//! assert!(m < 8);
283//! assert!(m > 1);
284//!
285//! let n: ModNumC<i32,5> = ModNumC::new(6);
286//! assert!(m > n);
287//! assert!(n < 2);
288//! assert!(n > 0);
289//! ```
290//!
291//! # Iteration
292//! I originally created `ModNum` to facilitate cyclic iteration through a fixed-size array from an
293//! arbitrary starting point in a `no_std` environment. Its double-ended iterator facilitates
294//! both forward and backward iteration.
295//!
296//! ```
297//! use bare_metal_modulo::*;
298//!
299//! let forward: Vec<usize> = ModNum::new(2, 5).iter().map(|mn| mn.a()).collect();
300//! assert_eq!(forward, vec![2, 3, 4, 0, 1]);
301//!
302//! let reverse: Vec<usize> = ModNum::new(2, 5).iter().rev().map(|mn| mn.a()).collect();
303//! assert_eq!(reverse, vec![1, 0, 4, 3, 2]);
304//!
305//! let m: ModNumC<usize,5> = ModNumC::new(2);
306//! let forward: Vec<usize> = m.iter().map(|mn| mn.a()).collect();
307//! assert_eq!(forward, vec![2, 3, 4, 0, 1]);
308//!
309//! let m: ModNumC<usize,5> = ModNumC::new(2);
310//! let reverse: Vec<usize> = m.iter().rev().map(|mn| mn.a()).collect();
311//! assert_eq!(reverse, vec![1, 0, 4, 3, 2]);
312//! ```
313//!
314//! # Counting Wraps
315//!
316//! Modular numbers represent the remainder of an integer when divided by the modulo. If we also
317//! store the quotient in addition to the remainder, we have a count of the number of times a
318//! value had to "wrap around" during the calculation.
319//!
320//! For example, if we start with **8 (mod 17)** and add **42**, the result is **16 (mod 17)** with
321//! a wraparound of **2**.
322//!
323//! `WrapCountNum` objects store this wraparound value and make it available. It is tracked through
324//! both `+=` and `*=` for all supported numeric types.
325//!
326//! ```
327//! use bare_metal_modulo::*;
328//!
329//! let mut value = WrapCountNum::new(8, 17);
330//! value += 42;
331//! assert_eq!(value, 16);
332//! assert_eq!(value.wraps(), 2);
333//!
334//! value += 18;
335//! assert_eq!(value, 0);
336//! assert_eq!(value.wraps(), 4);
337//!
338//! value += 11;
339//! assert_eq!(value, 11);
340//! assert_eq!(value.wraps(), 4);
341//!
342//! value *= 5;
343//! assert_eq!(value, 4);
344//! assert_eq!(value.wraps(), 7);
345//! ```
346//!
//! Because regular operations produce a new `WrapCountNum` value every time, `value = value + x`
//! only tracks wraps for a single operation, unlike `value += x`.
349//!
350//! ```
351//! use bare_metal_modulo::*;
352//! use num::traits::Pow;
353//!
354//! let mut value = WrapCountNum::new(8, 17);
355//! value = value + 42;
356//! assert_eq!(value, 16);
357//! assert_eq!(value.wraps(), 2);
358//!
359//! value = value + 18;
360//! assert_eq!(value, 0);
361//! assert_eq!(value.wraps(), 2);
362//!
363//! value = value + 11;
364//! assert_eq!(value, 11);
365//! assert_eq!(value.wraps(), 0);
366//!
367//! value = value * 5;
368//! assert_eq!(value, 4);
369//! assert_eq!(value.wraps(), 3);
370//!
371//! value = value.pow(3);
372//! assert_eq!(value, 13);
373//! assert_eq!(value.wraps(), 3);
374//! ```
375//!
376//! The `.pow_assign()` method enables wrap tracking when exponentiating.
377//!
378//! ```
379//! use bare_metal_modulo::*;
380//!
381//! let mut value = WrapCountNum::new(4, 17);
382//! value.pow_assign(3);
383//! assert_eq!(value, 13);
384//! assert_eq!(value.wraps(), 3);
385//!
386//! value += 6;
387//! assert_eq!(value, 2);
388//! assert_eq!(value.wraps(), 4);
389//!
390//! value.pow_assign(5);
391//! assert_eq!(value, 15);
392//! assert_eq!(value.wraps(), 5);
393//! ```
394//!
395//! Subtraction causes the wrap count to decrease. To simplify the implementation, `WrapCountNum`
396//! only defines subtraction on `Signed` numerical types.
397//!
398//! ```
399//! use bare_metal_modulo::*;
400//!
401//! let mut value = WrapCountNum::new(2, 5);
402//! value -= 8;
403//! assert_eq!(value, 4);
404//! assert_eq!(value.wraps(), -2);
405//!
406//! value -= 3;
407//! assert_eq!(value, 1);
408//! assert_eq!(value.wraps(), -2);
409//!
410//! value -= 13;
411//! assert_eq!(value, 3);
412//! assert_eq!(value.wraps(), -5);
413//!
414//! value += 8;
415//! assert_eq!(value, 1);
416//! assert_eq!(value.wraps(), -3);
417//! ```
418//!
419//! There is a `const generic` version as well, `WrapCountNumC`:
420//! ```
421//! use bare_metal_modulo::*;
422//!
423//! let mut value = WrapCountNumC::<isize,17>::new(8);
424//! value += 42;
425//! assert_eq!(value, 16);
426//! assert_eq!(value.wraps(), 2);
427//!
428//! value += 18;
429//! assert_eq!(value, 0);
430//! assert_eq!(value.wraps(), 4);
431//!
432//! value += 11;
433//! assert_eq!(value, 11);
434//! assert_eq!(value.wraps(), 4);
435//!
436//! value *= 5;
437//! assert_eq!(value, 4);
438//! assert_eq!(value.wraps(), 7);
439//!
440//! let mut value = WrapCountNumC::<isize, 8>::default();
441//! value += 15;
442//! assert_eq!(value, 7);
443//! assert_eq!(value.wraps(), 1);
444//! ```
445//!
446//! # Hash/BTree keys
447//! `ModNum` and `ModNumC` objects implement the `Ord` and `Hash` traits, and therefore can
448//! be included in `HashSet` and `BTreeSet` collections and serve
449//! as keys for `HashMap` and `BTreeMap` dictionaries.
450//!
451//! ```
452//! use bare_metal_modulo::*;
453//! use std::collections::HashSet;
454//!
455//! let m1: ModNumC<usize, 3> = ModNumC::new(2);
456//! let m2: ModNumC<usize, 3> = ModNumC::new(4);
457//! let m3: ModNumC<usize, 3> = ModNumC::new(5);
458//! assert_eq!(m1, m3);
459//! assert_eq!(m1 + 2, m2);
460//!
461//! let mut set = HashSet::new();
462//! set.insert(m1);
463//! set.insert(m2);
464//! assert_eq!(set.len(), 2);
465//! for m in [m1, m2, m3].iter() {
466//!     assert!(set.contains(m));
467//! }
468//! ```
469//!
470//! # Modular ranges rooted elsewhere than zero
471//!
//! It is occasionally useful to represent a modular range of values that starts somewhere other
//! than zero. `OffsetNum` and `OffsetNumC` address this situation. Both support addition,
//! subtraction, and iteration in a manner similar to the other types.
475//!
476//! `OffsetNum` objects can be created directly or from a `Range` or `RangeInclusive`:
477//!
478//! ```
479//! use bare_metal_modulo::*;
480//!
481//! let mut off = OffsetNum::<usize>::from(1..=10);
482//! assert_eq!(off.a(), 1);
483//! assert_eq!(off, 1);
484//! assert_eq!(off, 11); // Congruence equality with basic integer type
485//! assert_eq!(off.min_max(), (1, 10));
486//!
487//! for i in 1..=10 {
488//!     assert_eq!(off.a(), i);    
489//!     off += 1;
490//! }
491//! assert_eq!(off.a(), 1);
492//!
493//! for (i, n) in off.iter().enumerate() {
494//!     assert_eq!(n.a(), i + 1);
495//! }
496//!
497//! off -= 1;
498//! for i in (1..=10).rev() {
499//!     assert_eq!(off.a(), i);
500//!     off -= 1;
501//! }
502//! assert_eq!(off.a(), 10);
503//!
504//! let off_inclusive = OffsetNum::<usize>::from(1..=5);
505//! let off_exclusive = OffsetNum::<usize>::from(1..6);
506//! assert_eq!(off_inclusive, off_exclusive);
507//!
508//! ```
509//!
510//! Negative offsets are fine:
511//!
512//! ```
513//! use bare_metal_modulo::*;
514//!
515//! let mut off = OffsetNum::<isize>::new(-4, 3, -6);
516//! assert_eq!(off.a(), -4);
517//! off += 1;
518//! assert_eq!(off.a(), -6);
519//! ```
520//!
521//! Subtraction is subtle for `OffsetNum`. The subtrahend is normalized
522//! to the size of the `OffsetNum`'s range, but zero-based. It is then
523//! subtracted from the modulus and added to the minuend.
524//!
525//! ```
526//! use bare_metal_modulo::*;
527//!
528//! let mut off = OffsetNum::<usize>::from(3..=6);
529//! assert_eq!((off - 1).a(), 6);
530//! assert_eq!((off - 2).a(), 5);
531//! assert_eq!((off - 3).a(), 4);
532//! assert_eq!((off - 4).a(), 3);
533//!
534//! off += 3;
535//! assert_eq!((off - 1).a(), 5);
536//! assert_eq!((off - 2).a(), 4);
537//! assert_eq!((off - 3).a(), 3);
538//! assert_eq!((off - 4).a(), 6);
539//! ```
540//!
541//! `OffsetNumC` has three generic parameters:
542//! * Underlying integer type
543//! * Number of values in the range
544//! * Starting offset for the range
545//!
546//! ```
547//! use bare_metal_modulo::*;
548//!
549//! let mut off = OffsetNumC::<i16, 7, 5>::new(5);
550//! assert_eq!(off.a(), 5);
551//! assert_eq!(off.min_max(), (5, 11));
552//!
553//! for i in 0..7 {
554//!     assert_eq!(off.a(), 5 + i);
555//!     off += 1;
556//! }
557//! assert_eq!(off.a(), 5);
558//!
559//! let off_at_start = OffsetNumC::<i16, 7, 5>::default();
560//! assert_eq!(off_at_start, off);
561//!
562//! for (i, n) in off.iter().enumerate() {
563//!     assert_eq!(i + 5, n.a() as usize);
564//! }
565//! ```
566//!
567//! # Solving Modular Equations with the Chinese Remainder Theorem
568//! For the [2020 Advent of Code](https://adventofcode.com/2020)
569//! ([Day 13](https://adventofcode.com/2020/day/13) part 2),
570//! I extended `ModNum` to solve systems of modular equations, provided that each modular equation
571//! is represented using signed integers. My implementation is based on this
572//! [lucid](https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/)
573//! [explanation](https://byorgey.wordpress.com/2020/03/03/competitive-programming-in-haskell-modular-arithmetic-part-2/)
574//! by [Brent Yorgey](http://ozark.hendrix.edu/~yorgey/).
575//!
576//! The solver works directly with an iterator containing the `ModNum` objects corresponding to the
577//! modular equations to be solved, facilitating space-efficient solutions of a sequence coming
578//! from a stream. The examples below show two variants of the same system. The first example uses
579//! an iterator, and the second example retrieves the system from a `Vec`.
580//!
581//! Note that the solution value can rapidly become large, as shown in the third example. I
582//! recommend using **i128**, so as to maximize the set of solvable equations given this crate's
583//! constraint of using **Copy** integers only. While the solution to the third example is
584//! representable as an **i64**, some of the intermediate multiplications will overflow unless
585//! **i128** is used.
586//!
587//! ```
588//! use bare_metal_modulo::*;
589//!
590//! let mut values = (2..).zip((5..).step_by(2)).map(|(a, m)| ModNum::new(a, m)).take(3);
591//! let solution = ModNum::<i128>::chinese_remainder_system(values);
592//! assert_eq!(solution.unwrap().a(), 157);
593//!
594//! let values = vec![ModNum::new(2, 5), ModNum::new(3, 7), ModNum::new(4, 9)];
595//! let solution = ModNum::<i128>::chinese_remainder_system(values.iter().copied());
596//! assert_eq!(solution.unwrap().a(), 157);
597//!
598//!let mut values = [(0, 23), (28, 41), (20, 37), (398, 421), (11, 17), (15, 19), (6, 29),
599//!    (433, 487), (11, 13), (5, 137), (19, 49)]
600//!    .iter().copied().map(|(a, m)| ModNum::new(a, m));
601//! let solution = ModNum::<i128>::chinese_remainder_system(values);
602//! assert_eq!(solution.unwrap().a(), 762009420388013796);
603//! ```
604
605use core::cmp::Ordering;
606use core::fmt::{Debug, Display, Formatter};
607use core::mem;
608use core::ops::{
609    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Range, RangeInclusive, Sub, SubAssign,
610};
611use num::traits::{Pow, SaturatingAdd, SaturatingSub};
612use num::{FromPrimitive, Integer, NumCast, One, Signed, Zero};
613
614pub trait NumType:
615    Default
616    + Copy
617    + Clone
618    + Integer
619    + Display
620    + Debug
621    + NumCast
622    + FromPrimitive
623    + AddAssign
624    + SubAssign
625    + MulAssign
626    + DivAssign
627{
628}
629impl<
630        T: Default
631            + Copy
632            + Clone
633            + Integer
634            + Display
635            + Debug
636            + NumCast
637            + FromPrimitive
638            + AddAssign
639            + SubAssign
640            + MulAssign
641            + DivAssign,
642    > NumType for T
643{
644}
645
646pub trait MNum: Copy + Eq + PartialEq {
647    type Num: NumType;
648
649    fn a(&self) -> Self::Num;
650
651    fn m(&self) -> Self::Num;
652
653    fn with(&self, new_a: Self::Num) -> Self;
654
655    fn replace(&mut self, new_a: Self::Num) {
656        *self = self.with(new_a);
657    }
658
659    /// [Extended Euclidean Algorithm for Greatest Common Divisor](https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/) for GCD.
660    ///
661    /// This is my translation into Rust of [Brent Yorgey's Haskell implementation](https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/).
662    ///
663    /// Given two integers **a** and **b**, it returns three integer values:
664    /// - Greatest Common Divisor (**g**) of **a** and **b**
665    /// - Two additional values **x** and **y**, where **ax + by = g**
666    fn egcd(a: Self::Num, b: Self::Num) -> (Self::Num, Self::Num, Self::Num)
667    where
668        Self::Num: Signed,
669    {
670        if b == Self::Num::zero() {
671            (a.signum() * a, a.signum(), Self::Num::zero())
672        } else {
673            let (g, x, y) = Self::egcd(b, a.mod_floor(&b));
674            (g, y, x - (a / b) * y)
675        }
676    }
677
678    /// Returns the modular inverse, if it exists. Returns **None** if it does not exist.
679    ///
680    /// This is my translation into Rust of [Brent Yorgey's Haskell implementation](https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/).
681    ///
682    /// Let **m = ModNum::new(a, m)**, where **a** and **m** are relatively prime.
683    /// Then **m * m.inverse().unwrap().a()** is congruent to **1 (mod m)**.
684    ///
685    /// Returns None if **a** and **m** are not relatively prime.
686    fn inverse(&self) -> Option<Self>
687    where
688        Self::Num: Signed,
689    {
690        let (g, _, inv) = Self::egcd(self.m(), self.a());
691        if g == Self::Num::one() {
692            Some(self.with(inv))
693        } else {
694            None
695        }
696    }
697}
698
/// A residue class **a (mod m)**, holding the value **a** together with its modulus **m**.
///
/// `ModNum::new` stores **a** already reduced (via `mod_floor`), so the derived equality,
/// ordering, and hashing operate on reduced values.
#[derive(Debug, Copy, Clone, Hash)]
#[derive(Eq, PartialEq, Ord, PartialOrd)]
pub struct ModNum<N> {
    /// The value **a**. Declared first, so the derived ordering compares values before moduli.
    num: N,
    /// The modulus **m**.
    modulo: N,
}
705
706impl<N: NumType> MNum for ModNum<N> {
707    type Num = N;
708
709    fn a(&self) -> N {
710        self.num
711    }
712
713    fn m(&self) -> N {
714        self.modulo
715    }
716
717    fn with(&self, new_a: Self::Num) -> Self {
718        Self::new(new_a, self.m())
719    }
720}
721
722impl<N: NumType> ModNum<N> {
723    /// Creates a new integer **a (mod m)**
724    pub fn new(a: N, m: N) -> Self {
725        ModNum {
726            num: a.mod_floor(&m),
727            modulo: m,
728        }
729    }
730
731    /// Returns an iterator starting at **a (mod m)** and ending at **a - 1 (mod m)**
732    pub fn iter(&self) -> ModNumIterator<N, Self> {
733        ModNumIterator::new(*self)
734    }
735}
736
737impl<N: NumType + Signed> ModNum<N> {
738    /// Solves a pair of modular equations using the [Chinese Remainder Theorem](https://byorgey.wordpress.com/2020/03/03/competitive-programming-in-haskell-modular-arithmetic-part-2/).
739    ///
740    /// This is my translation into Rust of [Brent Yorgey's Haskell implementation](https://byorgey.wordpress.com/2020/03/03/competitive-programming-in-haskell-modular-arithmetic-part-2/).
741    ///
742    /// - `self` represents the modular equation **x = a (mod m)**
743    /// - `other` represents the modular equation **x = b (mod n)**
744    /// - It returns a `ModNum` corresponding to the equation **x = c (mod mn)** where
745    ///   **c** is congruent both to **a (mod m)** and **b (mod n)**
746    pub fn chinese_remainder(&self, other: ModNum<N>) -> ModNum<N> {
747        let (g, u, v) = ModNum::egcd(self.m(), other.m());
748        let c = (self.a() * other.m() * v + other.a() * self.m() * u).div_floor(&g);
749        ModNum::new(c, self.m() * other.m())
750    }
751
752    /// Solves a system of modular equations using `ModMum::chinese_remainder()`.
753    ///
754    /// Each equation in the system is an element of the **modnums** iterator parameter.
755    /// - Returns **None** if the iterator is empty.
756    /// - Returns **Some(element)** if the iterator has only one element.
757    /// - Returns **Some(solution)** if the iterator has two or more elements, where the solution is
758    ///   found by repeatedly calling **ModNum::chinese_remainder()**.
759    pub fn chinese_remainder_system<I: Iterator<Item = ModNum<N>>>(
760        mut modnums: I,
761    ) -> Option<ModNum<N>> {
762        modnums
763            .next()
764            .map(|start_num| modnums.fold(start_num, |a, b| a.chinese_remainder(b)))
765    }
766}
767
/// Generates a compound-assignment operator impl for `$name` in terms of the matching binary
/// operator.
///
/// Arguments, in order: the target type, the assignment trait to implement (e.g.
/// `AddAssign<N>`), the right-hand-side type, the trait method name, the binary operator token
/// in braces, extra generic parameters in braces, an optional extra bound on `N` in braces, and
/// tokens appended to the binary result in braces (e.g. `.unwrap()` when the operator returns
/// an `Option`).
macro_rules! derive_assign {
    ($name:ty, $implname:ty, $rhs_type:ty, $methodname:ident {$symbol:tt} {$($generic:tt)*} {$($num_type_suffix:ident)?} {$($unwrap:tt)*}) => {
        impl <N: NumType + $($num_type_suffix)?,$($generic)*> $implname for $name {
            fn $methodname(&mut self, rhs: $rhs_type) {
                // Evaluate on a copy of the current value, then overwrite in place.
                let updated = (*self $symbol rhs)$($unwrap)*;
                *self = updated;
            }
        }
    }
}
777
/// Generates the operations shared by every modular type `$name`: congruence equality with a
/// bare integer (`PartialEq<N>`), ordering against a bare integer (`PartialOrd<N>`), and
/// addition of either an integer or another `$name`.
///
/// The generated code calls `a()` and `with()` on `$name`, and names `Ordering` and `Add`
/// unqualified, so those must be in scope where the macro is invoked. The braced token list
/// supplies any extra generic parameters `$name` needs beyond `N`.
macro_rules! derive_basic_modulo_arithmetic {
    ($name:ty {$($generic:tt)*}) => {
        /// Returns **true** if **other** is congruent to **self.a() (mod self.m())**
        impl <N:NumType,$($generic)*> PartialEq<N> for $name {
            fn eq(&self, other: &N) -> bool {
                // Reduce `other` with this value's modulus, then compare the residues.
                let other_reduced = self.with(*other).a();
                other_reduced == self.a()
            }
        }

        impl <N:NumType,$($generic)*> PartialOrd<N> for $name {
            fn partial_cmp(&self, other: &N) -> Option<Ordering> {
                // The residue is compared with `other` exactly as given; `other` is not reduced.
                let residue = self.a();
                residue.partial_cmp(other)
            }
        }

        impl <N: NumType,$($generic)*> Add<N> for $name {
            type Output = Self;

            fn add(self, rhs: N) -> Self::Output {
                let sum = self.a() + rhs;
                self.with(sum)
            }
        }

        impl <N: NumType,$($generic)*> Add<$name> for $name {
            type Output = Self;

            /// Adds the residue of `rhs`. The moduli of `self` and `rhs` are not compared.
            fn add(self, rhs: Self) -> Self::Output {
                let addend = rhs.a();
                self + addend
            }
        }
    }
}
810
/// Generates the full arithmetic surface for a modular type `$name`: everything from
/// `derive_basic_modulo_arithmetic!`, plus multiplication, division (signed `N` only),
/// exponentiation through `Pow`, and an inherent `pow_signed` method.
///
/// The generated code calls `a()`, `m()`, `with()`, and `inverse()` on `$name` (as provided by
/// `MNum`). It also relies on `$name: MulAssign<$name>`, which the generated `pow` uses. The
/// braced token list supplies any extra generic parameters `$name` needs beyond `N`.
macro_rules! derive_core_modulo_arithmetic {
    ($name:ty {$($generic:tt)*}) => {
        // Equality, ordering, and addition come from the shared basic-arithmetic macro.
        derive_basic_modulo_arithmetic! {
            $name
            {$($generic)*}
        }

        impl <N: NumType,$($generic)*> Mul<N> for $name {
            type Output = Self;

            /// Multiplies the value by `rhs`, then re-reduces through `with()`.
            fn mul(self, rhs: N) -> Self::Output {
                let product = self.a() * rhs;
                self.with(product)
            }
        }

        impl <N: NumType,$($generic)*> Mul<$name> for $name {
            type Output = Self;

            /// Multiplies two values; panics (via `assert_eq!`) if their moduli differ.
            fn mul(self, rhs: Self) -> Self::Output {
                assert_eq!(self.m(), rhs.m());
                let factor = rhs.a();
                self * factor
            }
        }

        impl <N: NumType + Signed,$($generic)*> Div<N> for $name {
            type Output = Option<Self>;

            /// Multiplies by the modular inverse of `rhs`; `None` when that inverse is undefined.
            fn div(self, rhs: N) -> Self::Output {
                let divisor = self.with(rhs);
                divisor.inverse().map(|inv| self * inv.a())
            }
        }

        impl <N: NumType + Signed,$($generic)*> Div<$name> for $name {
            type Output = Option<Self>;

            /// Divides by the value of `rhs`; `None` when its inverse is undefined.
            fn div(self, rhs: Self) -> Self::Output {
                let divisor = rhs.a();
                self / divisor
            }
        }

        impl <N: NumType,$($generic)*> Pow<N> for $name {
            type Output = Self;

            /// Returns a^rhs (mod m), for rhs >= 0.
            /// Implements efficient modular exponentiation by [repeated squaring](https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/).
            ///
            /// Panics if rhs < 0. If negative exponents are possible, use .pow_signed()
            fn pow(self, rhs: N) -> Self::Output {
                if rhs < N::zero() {
                    panic!("Negative exponentiation undefined for ModNum.pow(). Try .pow_signed() instead.");
                }
                if rhs == N::zero() {
                    return self.with(N::one());
                }
                // Square the result for floor(rhs / 2), then multiply once more if rhs is odd.
                let two = N::one() + N::one();
                let mut acc = self.pow(rhs.div_floor(&two));
                acc *= acc;
                if rhs.is_odd() {
                    acc *= self;
                }
                acc
            }
        }

        impl <N: NumType,$($generic)*> Pow<$name> for $name {
            type Output = Self;

            /// Raises `self` to the value of `rhs`, treating that value as a plain integer.
            fn pow(self, rhs: Self) -> Self::Output {
                let exponent = rhs.a();
                self.pow(exponent)
            }
        }

        impl <N: NumType + Signed,$($generic)*> $name {
            /// Exponentiates safely with negative exponents - if the inverse is undefined, it returns
            /// `None`, otherwise it returns `Some(self.pow(rhs))`.
            pub fn pow_signed(&self, rhs: N) -> Option<Self> {
                if rhs >= N::zero() {
                    Some(self.pow(rhs))
                } else {
                    // a^(-k) is the inverse of a^k.
                    self.pow(-rhs).inverse()
                }
            }
        }
    }
}
896
macro_rules! derive_add_assign_sub {
    ($name:ty {$($generic:tt)*}) => {
        derive_assign! {
            $name, AddAssign<N>, N, add_assign {+} {$($generic)*} {} {}
        }

        derive_assign! {
            $name, AddAssign<$name>, $name, add_assign {+} {$($generic)*} {} {}
        }

        impl <N: NumType,$($generic)*> Neg for $name {
            type Output = Self;

            /// Returns the additive inverse: the value `x` for which `self + x` is congruent
            /// to zero modulo `m`.
            fn neg(self) -> Self::Output {
                // Work from the public value `a()` rather than the stored field. For the offset
                // types (`OffsetNum`, `OffsetNumC`) the field holds `a()` minus the offset, and
                // negating that does not yield an additive inverse. Reducing `a()` modulo `m`
                // first keeps `m - ...` from underflowing on unsigned types whose offset pushes
                // `a()` above `m`.
                self.with(self.m() - self.a().mod_floor(&self.m()))
            }
        }

        impl <N: NumType,$($generic)*> Sub<N> for $name {
            type Output = Self;

            /// Subtracts `rhs` by adding its complement `m - (rhs mod m)`, so no intermediate
            /// value goes below zero even for unsigned `N`.
            fn sub(self, rhs: N) -> Self::Output {
                let offset = rhs.mod_floor(&self.m());
                let negated_offset = self.m() - offset;
                self + negated_offset
            }
        }

        impl <N: NumType,$($generic)*> Sub<$name> for $name {
            type Output = Self;

            /// Subtracts the value held by `rhs`.
            fn sub(self, rhs: Self) -> Self::Output {
                self - rhs.a()
            }
        }

        derive_assign! {
            $name, SubAssign<N>, N, sub_assign {-} {$($generic)*} {} {}
        }

        derive_assign! {
            $name, SubAssign<$name>, $name, sub_assign {-} {$($generic)*} {} {}
        }
    }
}
942
943macro_rules! derive_modulo_arithmetic {
944    ($name:ty {$($generic:tt)*}) => {
945
946        derive_core_modulo_arithmetic! {
947            $name
948            {$($generic)*}
949        }
950
951        impl <N:NumType,$($generic)*> Display for $name {
952            fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
953                write!(f, "{} (mod {})", self.a(), self.m())
954            }
955        }
956
957        derive_add_assign_sub! {
958            $name
959            {$($generic)*}
960        }
961
962        derive_assign! {
963            $name, MulAssign<N>, N, mul_assign {*} {$($generic)*} {} {}
964        }
965
966        derive_assign! {
967            $name, MulAssign<$name>, $name, mul_assign {*} {$($generic)*} {} {}
968        }
969
970        derive_assign! {
971            $name, DivAssign<N>, N, div_assign {/} {$($generic)*} {Signed} {.unwrap()}
972        }
973
974        derive_assign! {
975            $name, DivAssign<$name>, $name, div_assign {/} {$($generic)*} {Signed} {.unwrap()}
976        }
977
978        impl <N: NumType, $($generic)*> SaturatingAdd for $name {
979            fn saturating_add(&self, v: &Self) -> Self {
980                if self.a() + v.a() >= self.m() {
981                    self.with(self.m() - N::one())
982                } else {
983                    *self + *v
984                }
985            }
986        }
987
988        impl <N: NumType, $($generic)*> SaturatingSub for $name {
989            fn saturating_sub(&self, v: &Self) -> Self {
990                if self.a() < v.a() {
991                    self.with(N::zero())
992                } else {
993                    *self - *v
994                }
995            }
996        }
997    }
998}
999
1000derive_modulo_arithmetic! {
1001    ModNum<N> {}
1002}
1003
1004/// A double-ended iterator that starts with the ModNum upon which it is invoked,
1005/// making a complete traversal of the elements in that ModNum's ring.
1006#[derive(Debug)]
1007pub struct ModNumIterator<N: NumType, M: MNum<Num = N> + Add<N, Output = M> + Sub<N, Output = M>> {
1008    next: M,
1009    next_back: M,
1010    finished: bool,
1011}
1012
1013impl<N: NumType, M: MNum<Num = N> + Add<N, Output = M> + Sub<N, Output = M>> ModNumIterator<N, M> {
1014    pub fn new(mn: M) -> Self {
1015        ModNumIterator {
1016            next: mn,
1017            next_back: mn - N::one(),
1018            finished: false,
1019        }
1020    }
1021}
1022
1023fn update<
1024    N: NumType,
1025    M: MNum<Num = N> + Add<N, Output = M> + Sub<N, Output = M>,
1026    F: Fn(&M, N) -> M,
1027>(
1028    finished: &mut bool,
1029    update: &mut M,
1030    updater: F,
1031    target: M,
1032) -> Option<<ModNumIterator<N, M> as Iterator>::Item> {
1033    if *finished {
1034        None
1035    } else {
1036        let mut future = updater(update, N::one());
1037        if future == updater(&target, N::one()) {
1038            *finished = true;
1039        }
1040        mem::swap(&mut future, update);
1041        Some(future)
1042    }
1043}
1044
1045impl<N: NumType, M: MNum<Num = N> + Add<N, Output = M> + Sub<N, Output = M>> Iterator
1046    for ModNumIterator<N, M>
1047{
1048    type Item = M;
1049
1050    fn next(&mut self) -> Option<Self::Item> {
1051        update(
1052            &mut self.finished,
1053            &mut self.next,
1054            |m, u| *m + u,
1055            self.next_back,
1056        )
1057    }
1058}
1059
1060impl<N: NumType, M: MNum<Num = N> + Add<N, Output = M> + Sub<N, Output = M>> DoubleEndedIterator
1061    for ModNumIterator<N, M>
1062{
1063    fn next_back(&mut self) -> Option<Self::Item> {
1064        update(
1065            &mut self.finished,
1066            &mut self.next_back,
1067            |m, u| *m - u,
1068            self.next,
1069        )
1070    }
1071}
1072
1073/// Represents an integer **a (mod M)**
1074#[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
1075pub struct ModNumC<N: FromPrimitive, const M: usize> {
1076    num: N,
1077}
1078
1079impl<N: NumType, const M: usize> MNum for ModNumC<N, M> {
1080    type Num = N;
1081
1082    fn a(&self) -> Self::Num {
1083        self.num
1084    }
1085
1086    fn m(&self) -> Self::Num {
1087        N::from_usize(M).unwrap()
1088    }
1089
1090    fn with(&self, new_a: Self::Num) -> Self {
1091        Self::new(new_a)
1092    }
1093}
1094
1095impl<N: NumType, const M: usize> ModNumC<N, M> {
1096    pub fn new(num: N) -> Self {
1097        let mut result = ModNumC { num };
1098        result.num = result.num.mod_floor(&result.m());
1099        result
1100    }
1101
1102    /// Returns an iterator starting at **a (mod m)** and ending at **a - 1 (mod m)**
1103    pub fn iter(&self) -> ModNumIterator<N, Self> {
1104        ModNumIterator::new(*self)
1105    }
1106}
1107
1108derive_modulo_arithmetic! {
1109    ModNumC<N,M> {const M: usize}
1110}
1111
1112/// Represents an integer **a (mod m)**, storing the number of wraparounds of **a** as well.
1113#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
1114pub struct WrapCountNum<N: NumType> {
1115    num: N,
1116    modulo: N,
1117    wraps: N,
1118}
1119
1120impl<N: NumType> MNum for WrapCountNum<N> {
1121    type Num = N;
1122
1123    fn a(&self) -> Self::Num {
1124        self.num
1125    }
1126
1127    fn m(&self) -> Self::Num {
1128        self.modulo
1129    }
1130
1131    fn with(&self, new_a: Self::Num) -> Self {
1132        Self::new(new_a, self.m())
1133    }
1134}
1135
1136impl<N: NumType> WrapCountNum<N> {
1137    /// Creates a new integer **a (mod m)**, storing the number of wraparounds
1138    /// of **a** as well.
1139    pub fn new(a: N, modulo: N) -> Self {
1140        let (wraps, num) = a.div_mod_floor(&modulo);
1141        WrapCountNum { num, modulo, wraps }
1142    }
1143
1144    pub fn with_wraps(&self, a: N, wraps: N) -> Self {
1145        WrapCountNum {
1146            num: a,
1147            modulo: self.modulo,
1148            wraps,
1149        }
1150    }
1151}
1152
macro_rules! derive_wrap_assign {
    ($name:ty, $implname:ty, $rhs_type:ty, $methodname:ident {$symbol:tt} {$($generic:tt)*} {$($num_type_suffix:ident)?} {$($unwrap:tt)*}) => {
        impl <N: NumType + $($num_type_suffix)?,$($generic)*> $implname for $name {
            // Applies the binary operator, keeps its residue, and adds the wraps that the
            // operation produced to the running total stored in `self`.
            fn $methodname(&mut self, rhs: $rhs_type) {
                let outcome = (*self $symbol rhs)$($unwrap)*;
                self.wraps += outcome.wraps;
                self.num = outcome.num;
            }
        }
    }
}
1164
1165macro_rules! derive_wrap_modulo_arithmetic {
1166    ($name:ty {$($generic:tt)*}) => {
1167        derive_core_modulo_arithmetic! {$name {$($generic)*}}
1168
1169        impl <N: NumType,$($generic)*> $name {
1170            /// Returns the total number of wraparounds counted when calculating this value.
1171            pub fn wraps(&self) -> N {
1172                self.wraps
1173            }
1174
1175            pub fn pow_assign(&mut self, rhs: N) {
1176                let result = self.pow(rhs);
1177                self.num = result.num;
1178                self.wraps += result.wraps;
1179            }
1180        }
1181
1182        impl <N: NumType + Signed,$($generic)*> $name {
1183            /// Returns an iterator starting at **a (mod m)** and ending at **a - 1 (mod m)**
1184            pub fn iter(&self) -> ModNumIterator<N,Self> {
1185                ModNumIterator::new(*self)
1186            }
1187
1188            /// Computes and assigns to `self` the value `self.pow_signed(rhs)`. If the result
1189            /// is undefined due to `rhs` not having a defined inverse, `self` will be unchanged.
1190            pub fn pow_assign_signed(&mut self, rhs: N) {
1191                let result = self.pow_signed(rhs);
1192                if let Some(result) = result {
1193                    self.num = result.num;
1194                    self.wraps += result.wraps;
1195                }
1196            }
1197        }
1198
1199        impl <N: NumType,$($generic)*> Display for $name {
1200            fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
1201                write!(f, "{} (mod {}) (wrap {})", self.a(), self.m(), self.wraps)
1202            }
1203        }
1204
1205        impl <N: NumType + Signed,$($generic)*> Neg for $name {
1206            type Output = Self;
1207
1208            fn neg(self) -> Self::Output {
1209                self.with_wraps(-self.num, -self.wraps)
1210            }
1211        }
1212
1213        impl <N: NumType + Signed,$($generic)*> Sub<N> for $name {
1214            type Output = Self;
1215
1216            fn sub(self, rhs: N) -> Self::Output {
1217                self.with(self.num - rhs)
1218            }
1219        }
1220
1221        impl <N: NumType + Signed,$($generic)*> Sub<$name> for $name {
1222            type Output = Self;
1223
1224            fn sub(self, rhs: $name) -> Self::Output {
1225                self - rhs.a()
1226            }
1227        }
1228
1229        derive_wrap_assign! {
1230            $name, AddAssign<N>, N, add_assign {+} {$($generic)*} {} {}
1231        }
1232
1233        derive_wrap_assign! {
1234            $name, AddAssign<$name>, $name, add_assign {+} {$($generic)*} {} {}
1235        }
1236
1237        derive_wrap_assign! {
1238            $name, SubAssign<N>, N, sub_assign {-} {$($generic)*} {Signed} {}
1239        }
1240
1241        derive_wrap_assign! {
1242            $name, SubAssign<$name>, $name, sub_assign {-} {$($generic)*} {Signed} {}
1243        }
1244
1245        derive_wrap_assign! {
1246            $name, MulAssign<N>, N, mul_assign {*} {$($generic)*} {} {}
1247        }
1248
1249        derive_wrap_assign! {
1250            $name, MulAssign<$name>, $name, mul_assign {*} {$($generic)*} {} {}
1251        }
1252
1253        derive_wrap_assign! {
1254            $name, DivAssign<N>, N, div_assign {/} {$($generic)*} {Signed} {.unwrap()}
1255        }
1256
1257        derive_wrap_assign! {
1258            $name, DivAssign<$name>, $name, div_assign {/} {$($generic)*} {Signed} {.unwrap()}
1259        }
1260    }
1261}
1262
1263derive_wrap_modulo_arithmetic! {
1264    WrapCountNum<N> {}
1265}
1266
1267/// Represents an integer **a (mod M)**, storing the number of wraparounds of **a** as well.
1268#[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
1269pub struct WrapCountNumC<N: FromPrimitive, const M: usize> {
1270    num: N,
1271    wraps: N,
1272}
1273
1274impl<N: NumType, const M: usize> MNum for WrapCountNumC<N, M> {
1275    type Num = N;
1276
1277    fn a(&self) -> Self::Num {
1278        self.num
1279    }
1280
1281    fn m(&self) -> Self::Num {
1282        N::from_usize(M).unwrap()
1283    }
1284
1285    fn with(&self, new_a: Self::Num) -> Self {
1286        Self::new(new_a)
1287    }
1288}
1289
1290impl<N: NumType, const M: usize> WrapCountNumC<N, M> {
1291    /// Creates a new integer **a (mod m)**, storing the number of wraparounds
1292    /// of **a** as well.
1293    pub fn new(a: N) -> Self {
1294        let mut result = WrapCountNumC {
1295            num: a,
1296            wraps: N::zero(),
1297        };
1298        let (wraps, num) = a.div_mod_floor(&result.m());
1299        result.num = num;
1300        result.wraps = wraps;
1301        result
1302    }
1303
1304    pub fn with_wraps(&self, a: N, wraps: N) -> Self {
1305        WrapCountNumC { num: a, wraps }
1306    }
1307}
1308
1309derive_wrap_modulo_arithmetic! {
1310    WrapCountNumC<N,M> {const M: usize}
1311}
1312
/// Represents an integer bounded between two values.
///
/// The represented value `a()` is `num + offset`, and `min_max()` reports the range
/// `offset ..= offset + modulo - 1`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct OffsetNum<N: FromPrimitive> {
    // Distance of the represented value above `offset`, reduced modulo `modulo` by `new`.
    num: N,
    // Number of distinct values in the range.
    modulo: N,
    // Smallest representable value.
    offset: N,
}
1320
1321fn offset_init<N: NumType>(num: N, modulo: N, offset: N) -> N {
1322    let mut num = num;
1323    while num < offset {
1324        num += modulo;
1325    }
1326    num - offset
1327}
1328
1329impl<N: NumType> OffsetNum<N> {
1330    pub fn new(num: N, modulo: N, offset: N) -> Self {
1331        let num = offset_init(num, modulo, offset);
1332        let mut result = OffsetNum {
1333            num,
1334            modulo,
1335            offset,
1336        };
1337        result.num = result.num.mod_floor(&result.m());
1338        result
1339    }
1340
1341    /// Returns an iterator starting at **a (mod m)** and ending at **a - 1 (mod m)**
1342    pub fn iter(&self) -> ModNumIterator<N, Self> {
1343        ModNumIterator::new(*self)
1344    }
1345
1346    /// Returns the minimum and maximum possible values.
1347    pub fn min_max(&self) -> (N, N) {
1348        (self.offset, self.offset + self.modulo - N::one())
1349    }
1350}
1351
1352impl<N: NumType> From<RangeInclusive<N>> for OffsetNum<N> {
1353    fn from(r: RangeInclusive<N>) -> Self {
1354        Self::new(*r.start(), *r.end() - *r.start() + N::one(), *r.start())
1355    }
1356}
1357
1358impl<N: NumType> From<Range<N>> for OffsetNum<N> {
1359    fn from(r: Range<N>) -> Self {
1360        Self::new(r.start, r.end - r.start, r.start)
1361    }
1362}
1363
1364impl<N: NumType> MNum for OffsetNum<N> {
1365    type Num = N;
1366
1367    fn a(&self) -> N {
1368        self.num + self.offset
1369    }
1370
1371    fn m(&self) -> N {
1372        self.modulo
1373    }
1374
1375    fn with(&self, new_a: Self::Num) -> Self {
1376        Self::new(new_a, self.m(), self.offset)
1377    }
1378}
1379
1380derive_basic_modulo_arithmetic! {
1381    OffsetNum<N> {}
1382}
1383
1384derive_add_assign_sub! {
1385    OffsetNum<N> {}
1386}
1387
1388/// Represents an integer bounded between two values.
1389#[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
1390pub struct OffsetNumC<N: FromPrimitive, const M: usize, const O: isize> {
1391    num: N,
1392}
1393
1394impl<N: NumType, const M: usize, const O: isize> MNum for OffsetNumC<N, M, O> {
1395    type Num = N;
1396
1397    fn a(&self) -> Self::Num {
1398        self.num + N::from_isize(O).unwrap()
1399    }
1400
1401    fn m(&self) -> Self::Num {
1402        N::from_usize(M).unwrap()
1403    }
1404
1405    fn with(&self, new_a: Self::Num) -> Self {
1406        Self::new(new_a)
1407    }
1408}
1409
1410impl<N: NumType, const M: usize, const O: isize> OffsetNumC<N, M, O> {
1411    pub fn new(num: N) -> Self {
1412        let num = offset_init(num, N::from_usize(M).unwrap(), N::from_isize(O).unwrap());
1413        let mut result = OffsetNumC { num };
1414        result.num = result.num.mod_floor(&result.m());
1415        result
1416    }
1417
1418    /// Returns an iterator starting at **a (mod m)** and ending at **a - 1 (mod m)**
1419    pub fn iter(&self) -> ModNumIterator<N, Self> {
1420        ModNumIterator::new(*self)
1421    }
1422
1423    /// Returns the minimum and maximum possible values.
1424    pub fn min_max(&self) -> (N, N) {
1425        (
1426            N::from_isize(O).unwrap(),
1427            N::from_isize(O + isize::from_usize(M).unwrap() - 1).unwrap(),
1428        )
1429    }
1430}
1431
1432derive_basic_modulo_arithmetic! {
1433    OffsetNumC<N,M,O> {const M: usize, const O: isize}
1434}
1435
1436derive_add_assign_sub! {
1437    OffsetNumC<N,M,O> {const M: usize, const O: isize}
1438}
1439
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    #[test]
    fn test_neg() {
        let m = ModNum::new(-2, 5);
        assert_eq!(m, ModNum::new(3, 5));
    }

    #[test]
    fn test_negation() {
        for n in 0..5 {
            let m = ModNum::new(n, 5);
            let n = -m;
            assert_eq!(m + n.a(), 0);
        }
    }

    #[test]
    fn test_sub() {
        // Table of (value, modulo, subtrahend, expected residue).
        for (n, m, sub, target) in [(1, 5, 2, 4)] {
            assert_eq!(ModNum::new(n, m) - sub, ModNum::new(target, m));
        }
    }

    #[test]
    fn test_neg_c() {
        let m: ModNumC<i32, 5> = ModNumC::new(-2);
        assert_eq!(m, ModNumC::new(3));
    }

    #[test]
    fn test_negation_c() {
        let s: ModNumC<i32, 5> = ModNumC::new(0);
        for m in s.iter() {
            let n = -m;
            assert_eq!(m + n, 0);
        }
    }

    #[test]
    fn test_sub_c() {
        // Table of (value, subtrahend, expected residue) modulo 5.
        for (n, sub, target) in [(1, 2, 4), (2, 1, 1), (4, 1, 3), (4, 4, 0), (2, 5, 2)] {
            let n: ModNumC<i32, 5> = ModNumC::new(n);
            assert_eq!(n - sub, target);
        }
    }

    #[test]
    fn test_congruence_c() {
        let m: ModNumC<i32, 5> = ModNumC::new(2);
        for c in (-13..13).step_by(5) {
            assert_eq!(m, c);
            for i in -2..=2 {
                if i == 0 {
                    assert_eq!(m, c);
                } else {
                    assert_ne!(m, c + i);
                }
            }
        }
    }

    #[test]
    fn test_iter_up() {
        assert_eq!(
            vec![2, 3, 4, 0, 1],
            ModNum::new(2, 5)
                .iter()
                .map(|m: ModNum<usize>| m.a())
                .collect::<Vec<usize>>()
        )
    }

    #[test]
    fn test_iter_down() {
        assert_eq!(
            vec![1, 0, 4, 3, 2],
            ModNum::new(2, 5)
                .iter()
                .rev()
                .map(|m: ModNum<usize>| m.a())
                .collect::<Vec<usize>>()
        )
    }

    #[test]
    fn test_iter_up_w() {
        assert_eq!(
            vec![2, 3, 4, 0, 1],
            WrapCountNumC::<isize, 5>::new(2)
                .iter()
                .map(|m: WrapCountNumC<isize, 5>| m.a())
                .collect::<Vec<isize>>()
        )
    }

    #[test]
    fn test_iter_down_w() {
        assert_eq!(
            vec![1, 0, 4, 3, 2],
            WrapCountNumC::<isize, 5>::new(2)
                .iter()
                .rev()
                .map(|m: WrapCountNumC<isize, 5>| m.a())
                .collect::<Vec<isize>>()
        )
    }

    #[test]
    fn test_inverse() {
        for a in 0..13 {
            let m = ModNum::new(a, 13);
            let inv = m.inverse();
            if a == 0 {
                assert!(inv.is_none());
            } else {
                assert_eq!(m * inv.unwrap().a(), 1);
            }
        }
    }

    #[test]
    fn test_assign() {
        let mut m = ModNum::new(2, 5);
        m += 2;
        assert_eq!(m, ModNum::new(4, 5));
        m += 2;
        assert_eq!(m, ModNum::new(1, 5));
        m -= 3;
        assert_eq!(m, ModNum::new(3, 5));
        m *= 2;
        assert_eq!(m, ModNum::new(1, 5));
        m *= 2;
        assert_eq!(m, ModNum::new(2, 5));
        m *= 2;
        assert_eq!(m, ModNum::new(4, 5));
        m *= 2;
        assert_eq!(m, ModNum::new(3, 5));
    }

    #[test]
    fn test_assign_2() {
        let mut m = ModNum::new(2, 5);
        m += ModNum::new(2, 5);
        assert_eq!(m, ModNum::new(4, 5));
        m += ModNum::new(2, 5);
        assert_eq!(m, ModNum::new(1, 5));
        m -= ModNum::new(3, 5);
        assert_eq!(m, ModNum::new(3, 5));
        m *= ModNum::new(2, 5);
        assert_eq!(m, ModNum::new(1, 5));
        m *= ModNum::new(2, 5);
        assert_eq!(m, ModNum::new(2, 5));
        m *= ModNum::new(2, 5);
        assert_eq!(m, ModNum::new(4, 5));
        m *= ModNum::new(2, 5);
        assert_eq!(m, ModNum::new(3, 5));
    }

    #[test]
    fn test_chinese_remainder() {
        let x = ModNum::new(2, 5);
        let y = ModNum::new(3, 7);
        assert_eq!(x.chinese_remainder(y), ModNum::new(17, 35));
    }

    #[test]
    fn test_chinese_systems() {
        // Examples from 2020 Advent of Code, Day 13 Puzzle 2.
        let systems = vec![
            (vec![(2, 5), (3, 7), (4, 9)], 157),
            (vec![(0, 17), (-2, 13), (-3, 19)], 3417),
            (vec![(0, 67), (-1, 7), (-2, 59), (-3, 61)], 754018),
            (vec![(0, 67), (-2, 7), (-3, 59), (-4, 61)], 779210),
            (vec![(0, 67), (-1, 7), (-3, 59), (-4, 61)], 1261476),
            (vec![(0, 1789), (-1, 37), (-2, 47), (-3, 1889)], 1202161486),
        ];
        for (system, goal) in systems {
            let mut equations = system
                .into_iter()
                .map(|(a, m)| ModNum::<i128>::new(a, m));
            assert_eq!(
                ModNum::chinese_remainder_system(&mut equations)
                    .unwrap()
                    .a(),
                goal
            );
        }
    }

    #[test]
    fn test_congruence() {
        let m = ModNum::new(2, 5);
        for c in (-13..13).step_by(5) {
            assert_eq!(m, c);
            for i in -2..=2 {
                if i == 0 {
                    assert_eq!(m, c);
                } else {
                    assert_ne!(m, c + i);
                }
            }
        }
    }

    #[test]
    fn test_division() {
        let m = ModNum::new(6, 11);
        // 0 and 11 are both congruent to 0 (mod 11) and so have no inverse.
        for undefined in [0, 11] {
            assert_eq!(m / undefined, None);
        }
        for (divisor, quotient) in [(1, 6), (2, 3), (4, 7), (5, 10), (8, 9)] {
            // Each pair works in both directions: 6 / d = q implies 6 / q = d.
            for (d, q) in [(divisor, quotient), (quotient, divisor)] {
                let result = (m / d).unwrap();
                assert_eq!(result * d, m);
                assert_eq!(result.a(), q);
            }
        }
    }

    #[test]
    fn test_pow() {
        let m = ModNum::new(2, 5);
        for (exp, result) in (2..).zip([4, 3, 1, 2].iter().cycle()).take(20) {
            assert_eq!(m.pow(exp).a(), *result);
        }
    }

    #[test]
    fn test_big() {
        let mut values = [
            (0, 23),
            (28, 41),
            (20, 37),
            (398, 421),
            (11, 17),
            (15, 19),
            (6, 29),
            (433, 487),
            (11, 13),
            (5, 137),
            (19, 49),
        ]
        .iter()
        .copied()
        .map(|(a, m)| ModNum::new(a, m));
        let solution = ModNum::<i128>::chinese_remainder_system(&mut values)
            .unwrap()
            .a();
        assert_eq!(solution, 762009420388013796);
    }

    #[test]
    fn test_negative_exp() {
        let m = ModNum::new(2, 5);
        for (exp, result) in (2..).map(|n| -n).zip([4, 2, 1, 3].iter().cycle()).take(20) {
            assert_eq!(m.pow_signed(exp).unwrap().a(), *result);
        }
    }

    #[test]
    fn test_wrapping() {
        let mut w: WrapCountNumC<usize, 5> = WrapCountNumC::new(4);
        w *= 4;
        assert_eq!(w, 1);
        assert_eq!(w.wraps(), 3);
        w += 9;
        assert_eq!(w, 0);
        assert_eq!(w.wraps(), 5);
    }

    #[test]
    fn test_offset() {
        let mut off = OffsetNumC::<i16, 7, 5>::new(3);
        assert_eq!(off.a(), 10);
        off += 1;
        assert_eq!(off.a(), 11);
        off += 1;
        assert_eq!(off.a(), 5);
        off += 1;
        assert_eq!(off.a(), 6);
    }

    #[test]
    fn test_offset_2() {
        let off = OffsetNumC::<usize, 5, 2>::new(1);
        assert_eq!(off.a(), 6);
        for test in [507, 512, 502, 497, 22] {
            let bigoff = OffsetNumC::<usize, 5, 502>::new(test);
            assert_eq!(bigoff.a(), 502);
        }
    }

    #[test]
    fn test_offset_3() {
        let mut off = OffsetNum::<usize>::from(1..=10);
        for i in 1..=10 {
            assert_eq!(off.a(), i);
            off += 1;
        }
        assert_eq!(off.a(), 1);
    }

    #[test]
    fn test_offset_4() {
        let mut off = OffsetNumC::<usize, 10, 1>::new(1);
        for i in 1..=10 {
            assert_eq!(off.a(), i);
            off += 1;
        }
        assert_eq!(off.a(), 1);
    }

    #[test]
    fn test_offset_5() {
        let mut off = OffsetNum::<usize>::from(1..=10);
        assert_eq!(off.a(), 1);
        assert_eq!(off, 1);
        assert_eq!(off, 11); // Congruence equality with basic integer type
        assert_eq!(off.min_max(), (1, 10));

        for i in 1..=10 {
            assert_eq!(off.a(), i);
            off += 1;
        }
        assert_eq!(off.a(), 1);

        for (i, n) in off.iter().enumerate() {
            assert_eq!(n.a(), i + 1);
        }

        off -= 1;
        for i in (1..=10).rev() {
            assert_eq!(off.a(), i);
            off -= 1;
        }
        assert_eq!(off.a(), 10);
    }

    #[test]
    fn test_offset_6() {
        // From https://github.com/gjf2a/bare_metal_modulo/issues/10
        let x = OffsetNum::new(5, 9, 1);
        let one = OffsetNum::new(1, 9, 1);

        let y = x + one;
        let z = x - one;
        assert_eq!(x.a() + 1, y.a());
        assert_eq!(x.a() - 1, z.a());
    }
}