Skip to main content

proof_engine/number_theory/
modular.rs

//! Modular arithmetic, the Chinese Remainder Theorem (CRT), discrete logarithms, and visualization.
2
3use glam::{Vec2, Vec3, Vec4};
4use std::collections::HashMap;
5
/// Computes `base^exp mod modulus` by square-and-multiply exponentiation.
///
/// Products are formed in `u128`, so any `u64` modulus is safe from overflow.
/// A modulus of 1 always yields 0.
///
/// # Panics
/// Panics if `modulus` is 0 (remainder by zero).
pub fn mod_pow(mut base: u64, mut exp: u64, modulus: u64) -> u64 {
    if modulus == 1 {
        return 0;
    }
    let modulus_wide = u128::from(modulus);
    base %= modulus;
    let mut square = u128::from(base);
    let mut acc: u128 = 1;
    while exp != 0 {
        if exp % 2 == 1 {
            acc = acc * square % modulus_wide;
        }
        square = square * square % modulus_wide;
        exp /= 2;
    }
    acc as u64
}
24
/// Extended Euclidean algorithm: returns `(g, x, y)` with `a*x + b*y = g`,
/// where `g = gcd(a, b)` for non-negative inputs.
///
/// Descends through the Euclidean remainder sequence while recording each
/// quotient, then unwinds the quotients from the deepest level outward to
/// rebuild the Bézout coefficients.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    let mut quotients: Vec<i64> = Vec::new();
    let (mut lhs, mut rhs) = (a, b);
    while lhs != 0 {
        let remainder = rhs % lhs;
        quotients.push(rhs / lhs);
        rhs = lhs;
        lhs = remainder;
    }
    // Base case: gcd(0, rhs) = rhs = 0*0 + rhs*1.
    let (mut x, mut y) = (0i64, 1i64);
    for q in quotients.into_iter().rev() {
        let next_x = y - q * x;
        y = x;
        x = next_x;
    }
    (rhs, x, y)
}
33
/// Modular inverse of `a` modulo `m`: the unique `x` in `0..m` with `a * x ≡ 1 (mod m)`.
///
/// Returns `None` when no inverse exists, i.e. when `gcd(a, m) != 1` or `m == 0`.
/// For `m == 1` every value is congruent to 0, so `Some(0)` is returned.
///
/// The extended Euclidean algorithm runs in `i128`, so the full `u64` range of
/// both `a` and `m` is supported without truncation.
pub fn mod_inverse(a: u64, m: u64) -> Option<u64> {
    if m == 0 {
        return None;
    }
    let modulus = i128::from(m);
    // Iterative extended Euclid on (a mod m, m), tracking only the coefficient
    // of `a`. Invariant for each (r, s) pair: a * s ≡ r (mod m).
    let (mut old_r, mut r) = (i128::from(a % m), modulus);
    let (mut old_s, mut s) = (1i128, 0i128);
    while r != 0 {
        let q = old_r / r;
        let next_r = old_r - q * r;
        old_r = r;
        r = next_r;
        let next_s = old_s - q * s;
        old_s = s;
        s = next_s;
    }
    // old_r is now gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1 {
        return None;
    }
    Some(old_s.rem_euclid(modulus) as u64)
}
42
/// Chinese Remainder Theorem: given `residues[i]` and `moduli[i]`, finds the
/// smallest non-negative `x` such that `x ≡ residues[i] (mod moduli[i])` for all `i`.
///
/// Moduli need not be pairwise coprime: congruences are merged one at a time
/// using the generalized CRT, and the result lies in `0..lcm(moduli)`.
///
/// Returns `None` if:
/// - `residues` and `moduli` differ in length, or are empty;
/// - any modulus is 0;
/// - the system is inconsistent (no solution exists);
/// - the lcm of the moduli does not fit in a `u64`.
pub fn chinese_remainder_theorem(residues: &[u64], moduli: &[u64]) -> Option<u64> {
    if residues.len() != moduli.len() || residues.is_empty() || moduli.contains(&0) {
        return None;
    }

    /// Extended Euclid on non-negative `a`, `b`: returns `(g, s)` with
    /// `g = gcd(a, b)` and `a * s ≡ g (mod b)`.
    fn gcd_and_coefficient(a: i128, b: i128) -> (i128, i128) {
        let (mut old_r, mut r) = (a, b);
        let (mut old_s, mut s) = (1i128, 0i128);
        while r != 0 {
            let q = old_r / r;
            let next_r = old_r - q * r;
            old_r = r;
            r = next_r;
            let next_s = old_s - q * s;
            old_s = s;
            s = next_s;
        }
        (old_r, old_s)
    }

    // Invariant: `x < m` and `x` satisfies every congruence merged so far.
    // `m` never exceeds `u64::MAX`, so every product below fits in `u128`.
    let mut x: u128 = 0;
    let mut m: u128 = 1;
    for (&residue, &modulus) in residues.iter().zip(moduli) {
        let m2 = u128::from(modulus);
        let r2 = u128::from(residue % modulus);
        let (g, p) = gcd_and_coefficient(m as i128, m2 as i128);
        let diff = r2 as i128 - x as i128;
        if diff % g != 0 {
            return None; // inconsistent
        }
        let g_wide = g as u128;
        let lcm = m / g_wide * m2;
        if lcm > u128::from(u64::MAX) {
            return None; // combined modulus does not fit in u64
        }
        // Solve (m/g)·k ≡ diff/g (mod m2/g); `p` is the inverse of m/g modulo m2/g.
        let step = m2 / g_wide;
        let inv = p.rem_euclid(step as i128) as u128;
        // Both factors are < step < 2^64, so the product fits in u128.
        let k = (diff / g).rem_euclid(step as i128) as u128 * inv % step;
        // x + m*k < m*step = lcm <= u64::MAX.
        x += m * k;
        m = lcm;
    }
    Some(x as u64)
}
69
/// Euler's totient φ(n): how many integers in `1..=n` are coprime to `n`.
///
/// Computed by trial division over the prime factors of `n` using
/// `φ(n) = n · Π (1 - 1/p)`. Returns 0 for `n == 0` and 1 for `n == 1`.
/// Used by [`primitive_roots`].
fn euler_totient(mut n: u64) -> u64 {
    let mut result = n;
    let mut p = 2u64;
    // `p <= n / p` is `p * p <= n` without the risk of overflowing `p * p`.
    while p <= n / p {
        if n % p == 0 {
            while n % p == 0 {
                n /= p;
            }
            // Exact: `result` is still divisible by `p` at this point.
            result -= result / p;
        }
        // After 2, only odd candidates can be prime.
        p += if p == 2 { 1 } else { 2 };
    }
    // Whatever remains above 1 is a single prime factor larger than √n.
    if n > 1 {
        result -= result / n;
    }
    result
}
88
89/// Find all primitive roots modulo n (if they exist).
90/// Primitive roots exist iff n is 1, 2, 4, p^k, or 2*p^k for odd prime p.
91pub fn primitive_roots(n: u64) -> Vec<u64> {
92    if n <= 1 {
93        return vec![];
94    }
95    if n == 2 {
96        return vec![1];
97    }
98    let phi = euler_totient(n);
99    // Factor phi to check orders
100    let phi_factors = factor_small(phi);
101
102    let mut roots = Vec::new();
103    for g in 2..n {
104        if gcd(g, n) != 1 {
105            continue;
106        }
107        let mut is_root = true;
108        for &(p, _) in &phi_factors {
109            if mod_pow(g, phi / p, n) == 1 {
110                is_root = false;
111                break;
112            }
113        }
114        if is_root {
115            roots.push(g);
116        }
117    }
118    roots
119}
120
/// Greatest common divisor via the iterative Euclidean algorithm; `gcd(0, 0) == 0`.
fn gcd(a: u64, b: u64) -> u64 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let remainder = x % y;
        x = y;
        y = remainder;
    }
    x
}
124
/// Factors `n` by trial division into `(prime, exponent)` pairs, ordered by
/// increasing prime.
///
/// Returns an empty vector for `n == 0` and `n == 1`. The worst case is a
/// prime `n`, which needs about √n / 2 trial divisions.
fn factor_small(mut n: u64) -> Vec<(u64, u32)> {
    let mut factors = Vec::new();
    let mut d = 2u64;
    // `d <= n / d` is `d * d <= n` without the risk of overflowing `d * d`.
    while d <= n / d {
        let mut count = 0u32;
        while n % d == 0 {
            n /= d;
            count += 1;
        }
        if count > 0 {
            factors.push((d, count));
        }
        // After 2, only odd candidates can be prime.
        d += if d == 2 { 1 } else { 2 };
    }
    // Any cofactor left above 1 is a single prime larger than √n.
    if n > 1 {
        factors.push((n, 1));
    }
    factors
}
144
145/// Baby-step giant-step discrete logarithm:
146/// find x such that base^x ≡ target (mod modulus).
147pub fn discrete_log(base: u64, target: u64, modulus: u64) -> Option<u64> {
148    if modulus <= 1 {
149        return None;
150    }
151    let m = (modulus as f64).sqrt().ceil() as u64;
152
153    // Baby step: base^j for j = 0..m
154    let mut table: HashMap<u64, u64> = HashMap::new();
155    let mut power = 1u64;
156    for j in 0..m {
157        table.insert(power, j);
158        power = (power as u128 * base as u128 % modulus as u128) as u64;
159    }
160
161    // Giant step factor: base^{-m} mod modulus
162    let base_inv_m = match mod_inverse(mod_pow(base, m, modulus), modulus) {
163        Some(v) => v,
164        None => return None,
165    };
166
167    let mut gamma = target;
168    for i in 0..m {
169        if let Some(&j) = table.get(&gamma) {
170            return Some(i * m + j);
171        }
172        gamma = (gamma as u128 * base_inv_m as u128 % modulus as u128) as u64;
173    }
174    None
175}
176
// ─── Visualization ──────────────────────────────────────────────────────────
178
179/// Renders residue classes as positions on a circle (clock face).
180pub struct ClockVisualization {
181    pub modulus: u64,
182    pub center: Vec2,
183    pub radius: f32,
184}
185
186pub struct ClockGlyph {
187    pub residue: u64,
188    pub position: Vec2,
189    pub color: Vec4,
190    pub character: char,
191}
192
193impl ClockVisualization {
194    pub fn new(modulus: u64, center: Vec2, radius: f32) -> Self {
195        Self { modulus, center, radius }
196    }
197
198    /// Place residues 0..modulus around a circle.
199    pub fn generate(&self) -> Vec<ClockGlyph> {
200        let n = self.modulus;
201        (0..n)
202            .map(|r| {
203                let angle =
204                    std::f32::consts::FRAC_PI_2 - (r as f32 / n as f32) * std::f32::consts::TAU;
205                let pos = self.center + Vec2::new(angle.cos(), angle.sin()) * self.radius;
206                let hue = r as f32 / n as f32;
207                ClockGlyph {
208                    residue: r,
209                    position: pos,
210                    color: Vec4::new(hue, 1.0 - hue, 0.5, 1.0),
211                    character: std::char::from_digit(r as u32 % 36, 36).unwrap_or('?'),
212                }
213            })
214            .collect()
215    }
216
217    /// Highlight a specific residue class with connections.
218    pub fn residue_class(&self, value: u64) -> Vec<Vec2> {
219        let r = value % self.modulus;
220        let angle = std::f32::consts::FRAC_PI_2
221            - (r as f32 / self.modulus as f32) * std::f32::consts::TAU;
222        vec![
223            self.center,
224            self.center + Vec2::new(angle.cos(), angle.sin()) * self.radius,
225        ]
226    }
227}
228
229/// Render multiplication tables mod n as glyph grids.
230pub struct ResiduePattern {
231    pub modulus: u64,
232    pub cell_size: f32,
233}
234
235pub struct PatternCell {
236    pub row: u64,
237    pub col: u64,
238    pub value: u64,
239    pub position: Vec2,
240    pub color: Vec4,
241}
242
243impl ResiduePattern {
244    pub fn new(modulus: u64, cell_size: f32) -> Self {
245        Self { modulus, cell_size }
246    }
247
248    /// Generate the full multiplication table mod n.
249    pub fn multiplication_table(&self) -> Vec<PatternCell> {
250        let n = self.modulus;
251        let mut cells = Vec::with_capacity((n * n) as usize);
252        for r in 0..n {
253            for c in 0..n {
254                let val = (r * c) % n;
255                let brightness = val as f32 / n as f32;
256                cells.push(PatternCell {
257                    row: r,
258                    col: c,
259                    value: val,
260                    position: Vec2::new(c as f32 * self.cell_size, r as f32 * self.cell_size),
261                    color: Vec4::new(brightness, 0.2, 1.0 - brightness, 1.0),
262                });
263            }
264        }
265        cells
266    }
267
268    /// Generate the addition table mod n.
269    pub fn addition_table(&self) -> Vec<PatternCell> {
270        let n = self.modulus;
271        let mut cells = Vec::with_capacity((n * n) as usize);
272        for r in 0..n {
273            for c in 0..n {
274                let val = (r + c) % n;
275                let brightness = val as f32 / n as f32;
276                cells.push(PatternCell {
277                    row: r,
278                    col: c,
279                    value: val,
280                    position: Vec2::new(c as f32 * self.cell_size, r as f32 * self.cell_size),
281                    color: Vec4::new(0.2, brightness, 1.0 - brightness, 1.0),
282                });
283            }
284        }
285        cells
286    }
287}
288
// ─── Tests ──────────────────────────────────────────────────────────────────
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mod_pow() {
        assert_eq!(mod_pow(2, 10, 1000), 24);
        assert_eq!(mod_pow(3, 0, 7), 1);
        assert_eq!(mod_pow(2, 10, 1024), 0);
        assert_eq!(mod_pow(5, 3, 13), 8); // 125 % 13 = 8
        assert_eq!(mod_pow(9, 4, 1), 0); // everything is 0 mod 1
    }

    #[test]
    fn test_mod_pow_large() {
        // 2^64 = 18446744073709551616, mod 10^9+7 = 582344008
        assert_eq!(mod_pow(2, 64, 1_000_000_007), 582344008);
    }

    #[test]
    fn test_mod_inverse() {
        assert_eq!(mod_inverse(3, 7), Some(5)); // 3*5 = 15 ≡ 1 (mod 7)
        assert_eq!(mod_inverse(2, 4), None); // gcd(2,4) = 2
        assert_eq!(mod_inverse(1, 5), Some(1));
        assert_eq!(mod_inverse(0, 5), None);
        let inv = mod_inverse(17, 43).unwrap();
        assert_eq!((17 * inv) % 43, 1);
    }

    #[test]
    fn test_crt() {
        // x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7)
        let result = chinese_remainder_theorem(&[2, 3, 2], &[3, 5, 7]).unwrap();
        assert_eq!(result % 3, 2);
        assert_eq!(result % 5, 3);
        assert_eq!(result % 7, 2);
        assert_eq!(result, 23);
    }

    #[test]
    fn test_crt_non_coprime() {
        // x ≡ 2 (mod 4), x ≡ 4 (mod 6) has the unique solution 10 modulo lcm = 12.
        assert_eq!(chinese_remainder_theorem(&[2, 4], &[4, 6]), Some(10));
    }

    #[test]
    fn test_crt_inconsistent() {
        // x ≡ 1 (mod 2), x ≡ 0 (mod 2) — impossible
        assert!(chinese_remainder_theorem(&[1, 0], &[2, 2]).is_none());
        // Malformed input: mismatched lengths or no congruences at all.
        assert!(chinese_remainder_theorem(&[1], &[2, 3]).is_none());
        assert!(chinese_remainder_theorem(&[], &[]).is_none());
    }

    #[test]
    fn test_totient_and_factorization() {
        assert_eq!(euler_totient(1), 1);
        assert_eq!(euler_totient(36), 12);
        assert_eq!(euler_totient(97), 96);
        assert_eq!(factor_small(360), vec![(2, 3), (3, 2), (5, 1)]);
        assert!(factor_small(1).is_empty());
    }

    #[test]
    fn test_primitive_roots() {
        let roots = primitive_roots(7);
        assert!(roots.contains(&3));
        assert!(roots.contains(&5));
        // Verify they generate all residues
        for &g in &roots {
            let mut seen = std::collections::HashSet::new();
            let mut val = 1u64;
            for _ in 0..6 {
                seen.insert(val);
                val = val * g % 7;
            }
            assert_eq!(seen.len(), 6);
        }
    }

    #[test]
    fn test_primitive_roots_11() {
        let roots = primitive_roots(11);
        assert!(roots.contains(&2));
        // phi(11) = 10, number of primitive roots = phi(phi(11)) = phi(10) = 4
        assert_eq!(roots.len(), 4);
    }

    #[test]
    fn test_primitive_roots_small_and_missing() {
        assert_eq!(primitive_roots(2), vec![1]);
        assert_eq!(primitive_roots(4), vec![3]);
        // 8 is not of the form 2, 4, p^k or 2p^k, so it has no primitive roots.
        assert!(primitive_roots(8).is_empty());
    }

    #[test]
    fn test_discrete_log() {
        // 2^x ≡ 8 (mod 13) => x = 3
        assert_eq!(discrete_log(2, 8, 13), Some(3));
        // 3^x ≡ 1 (mod 7) => x = 0 (or 6)
        let x = discrete_log(3, 1, 7).unwrap();
        assert_eq!(mod_pow(3, x, 7), 1);
        // Powers of 2 mod 7 are {1, 2, 4}, so 3 has no logarithm.
        assert_eq!(discrete_log(2, 3, 7), None);
    }

    #[test]
    fn test_discrete_log_larger() {
        // 5 is a primitive root mod 23, so every nonzero residue has a logarithm;
        // a `None` here must fail the test rather than be skipped.
        let x = discrete_log(5, 12, 23).expect("12 is a power of 5 mod 23");
        assert_eq!(mod_pow(5, x, 23), 12);
        assert_eq!(x, 20);
    }

    #[test]
    fn test_clock_visualization() {
        let clock = ClockVisualization::new(12, Vec2::ZERO, 5.0);
        let glyphs = clock.generate();
        assert_eq!(glyphs.len(), 12);
        // residue 0 should be at top (angle = pi/2)
        let top = &glyphs[0];
        assert!(top.position.x.abs() < 0.01);
        assert!((top.position.y - 5.0).abs() < 0.01);
        // residue 3 of 12 is a quarter turn away, at angle 0
        let quarter = &glyphs[3];
        assert!((quarter.position.x - 5.0).abs() < 0.01);
        assert!(quarter.position.y.abs() < 0.01);
        assert_eq!(glyphs[11].character, 'b');
    }

    #[test]
    fn test_residue_pattern() {
        let pat = ResiduePattern::new(5, 1.0);
        let table = pat.multiplication_table();
        assert_eq!(table.len(), 25);
        // 3 * 4 mod 5 = 2
        let cell = &table[3 * 5 + 4]; // row 3, col 4
        assert_eq!(cell.value, 2);
        assert_eq!(cell.position, Vec2::new(4.0, 3.0));
    }

    #[test]
    fn test_addition_table() {
        let pat = ResiduePattern::new(5, 2.0);
        let table = pat.addition_table();
        assert_eq!(table.len(), 25);
        // 3 + 4 mod 5 = 2
        let cell = &table[3 * 5 + 4]; // row 3, col 4
        assert_eq!(cell.value, 2);
        assert_eq!(cell.position, Vec2::new(8.0, 6.0));
    }
}