proof_engine/number_theory/
modular.rs1use glam::{Vec2, Vec3, Vec4};
4use std::collections::HashMap;
5
/// Raises `base` to the power `exp` modulo `modulus` by square-and-multiply.
///
/// Intermediate products are formed in `u128`, so every `u64` modulus is
/// handled without overflow. A modulus of 1 always yields 0; otherwise
/// `exp == 0` yields 1.
///
/// # Panics
///
/// Panics if `modulus` is 0 (remainder by zero).
pub fn mod_pow(base: u64, exp: u64, modulus: u64) -> u64 {
    if modulus == 1 {
        return 0;
    }
    let wide_mod = u128::from(modulus);
    // Running square: base^(2^k) mod modulus for the k-th bit of `exp`.
    let mut square = u128::from(base % modulus);
    let mut acc: u128 = 1;
    let mut bits = exp;
    while bits != 0 {
        if bits & 1 != 0 {
            acc = acc * square % wide_mod;
        }
        square = square * square % wide_mod;
        bits >>= 1;
    }
    acc as u64
}
24
/// Recursive extended Euclidean algorithm.
///
/// Returns `(g, x, y)` with `a * x + b * y == g`, where `g` is the gcd reached
/// by repeated remainders (non-negative whenever both inputs are).
/// `extended_gcd(0, b)` is `(b, 0, 1)`.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    match a {
        0 => (b, 0, 1),
        _ => {
            let quotient = b / a;
            // Solve for (b mod a, a), then lift the coefficients back to (a, b).
            let (g, s, t) = extended_gcd(b % a, a);
            (g, t - quotient * s, s)
        }
    }
}
33
/// Returns the multiplicative inverse of `a` modulo `m`: the unique `x` in
/// `[0, m)` with `a * x ≡ 1 (mod m)`.
///
/// Returns `None` when no inverse exists, i.e. `gcd(a, m) != 1`, or when
/// `m == 0`. For `m == 1` every value is congruent to 0, so `Some(0)` is
/// returned. The full `u64` range is supported: the extended Euclidean
/// algorithm runs in `i128`, so moduli above `i64::MAX` neither wrap nor
/// overflow.
pub fn mod_inverse(a: u64, m: u64) -> Option<u64> {
    if m == 0 {
        return None;
    }
    let modulus = i128::from(m);
    // Iterative extended Euclid on (a mod m, m), tracking only the coefficient
    // of `a`. Invariant for both pairs: `coef * a ≡ rem (mod m)`.
    let (mut rem, mut next_rem) = (i128::from(a % m), modulus);
    let (mut coef, mut next_coef) = (1i128, 0i128);
    while next_rem != 0 {
        let q = rem / next_rem;
        let r = rem - q * next_rem;
        rem = next_rem;
        next_rem = r;
        let c = coef - q * next_coef;
        coef = next_coef;
        next_coef = c;
    }
    // `rem` is now gcd(a, m); an inverse exists only when it is 1.
    if rem != 1 {
        return None;
    }
    // `coef` may be negative; `rem_euclid` maps it into [0, m) without the
    // overflow that `coef % m + m` risks for very large moduli.
    Some(coef.rem_euclid(modulus) as u64)
}
42
/// Solves the simultaneous congruences `x ≡ residues[i] (mod moduli[i])`.
///
/// Moduli need not be pairwise coprime. Each step merges the running solution
/// with the next congruence via the extended Euclidean algorithm. The system
/// is solvable exactly when every merge agrees modulo the gcd of the moduli
/// involved.
///
/// Returns the smallest non-negative solution, or `None` when:
/// - the slices differ in length or are empty,
/// - any modulus is 0,
/// - the congruences are inconsistent, or
/// - the combined modulus (lcm of all moduli) exceeds `u64::MAX`, so the
///   solution cannot be represented uniquely as a `u64`.
///
/// All intermediate arithmetic is 128-bit, so any `u64` moduli and residues
/// are handled without overflow.
pub fn chinese_remainder_theorem(residues: &[u64], moduli: &[u64]) -> Option<u64> {
    if residues.len() != moduli.len() || residues.is_empty() || moduli.contains(&0) {
        return None;
    }
    // Running solution: x ≡ r (mod m), with 0 <= r < m <= u64::MAX.
    let mut r = u128::from(residues[0] % moduli[0]);
    let mut m = u128::from(moduli[0]);

    for (&res, &modulus) in residues.iter().zip(moduli).skip(1) {
        let m2 = u128::from(modulus);
        let r2 = u128::from(res % modulus);
        // g = gcd(m, m2) >= 1 and m * p ≡ g (mod m2).
        let (g, p) = crt_gcd_and_coefficient(m as i128, m2 as i128);
        let diff = r2 as i128 - r as i128;
        if diff % g != 0 {
            return None;
        }
        let g_wide = g as u128;
        let step = m2 / g_wide;
        let lcm = m / g_wide * m2;
        if lcm > u128::from(u64::MAX) {
            return None;
        }
        // Solve m * t ≡ diff (mod m2)  <=>  t ≡ (diff / g) * p (mod m2 / g).
        // Both factors are reduced below `step` (<= u64::MAX), so the product fits in u128.
        let d = (diff / g).rem_euclid(step as i128) as u128;
        let p = p.rem_euclid(step as i128) as u128;
        let t = d * p % step;
        // m * t <= lcm - m and r < m, so the sum stays below lcm <= u64::MAX.
        r = (r + m * t) % lcm;
        m = lcm;
    }
    Some(r as u64)
}

/// Iterative extended Euclid used by [`chinese_remainder_theorem`]: returns
/// `(g, x)` with `g = gcd(a, b)` and `a * x ≡ g (mod b)`. Expects `a, b >= 0`.
fn crt_gcd_and_coefficient(a: i128, b: i128) -> (i128, i128) {
    let (mut rem, mut next_rem) = (a, b);
    let (mut coef, mut next_coef) = (1i128, 0i128);
    while next_rem != 0 {
        let q = rem / next_rem;
        let r = rem - q * next_rem;
        rem = next_rem;
        next_rem = r;
        let c = coef - q * next_coef;
        coef = next_coef;
        next_coef = c;
    }
    (rem, coef)
}
69
/// Euler's totient φ(n): the number of integers in `1..=n` coprime to `n`.
///
/// Finds the distinct prime factors by trial division and applies
/// φ(n) = n · Π(1 − 1/p). `euler_totient(0)` returns 0.
fn euler_totient(mut n: u64) -> u64 {
    let mut result = n;
    let mut p = 2u64;
    // `p <= n / p` is `p * p <= n` without the overflow `p * p` hits once p reaches 2^32.
    while p <= n / p {
        if n % p == 0 {
            while n % p == 0 {
                n /= p;
            }
            result -= result / p;
        }
        // After 2, only odd candidates can be prime.
        p += if p == 2 { 1 } else { 2 };
    }
    // A remainder above 1 is one prime factor larger than every trial divisor.
    if n > 1 {
        result -= result / n;
    }
    result
}
88
89pub fn primitive_roots(n: u64) -> Vec<u64> {
92 if n <= 1 {
93 return vec![];
94 }
95 if n == 2 {
96 return vec![1];
97 }
98 let phi = euler_totient(n);
99 let phi_factors = factor_small(phi);
101
102 let mut roots = Vec::new();
103 for g in 2..n {
104 if gcd(g, n) != 1 {
105 continue;
106 }
107 let mut is_root = true;
108 for &(p, _) in &phi_factors {
109 if mod_pow(g, phi / p, n) == 1 {
110 is_root = false;
111 break;
112 }
113 }
114 if is_root {
115 roots.push(g);
116 }
117 }
118 roots
119}
120
/// Greatest common divisor by the Euclidean algorithm; `gcd(0, 0) == 0`.
fn gcd(a: u64, b: u64) -> u64 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        let r = x % y;
        x = y;
        y = r;
    }
    x
}
124
/// Factors `n` by trial division into `(prime, exponent)` pairs in increasing
/// prime order.
///
/// Intended for modest inputs: trial division runs up to the square root of
/// the still-unfactored part of `n`. `factor_small(0)` and `factor_small(1)`
/// return an empty vector.
fn factor_small(mut n: u64) -> Vec<(u64, u32)> {
    let mut factors = Vec::new();
    let mut d = 2u64;
    // `d <= n / d` is the overflow-free form of `d * d <= n`.
    while d <= n / d {
        let mut count = 0u32;
        while n % d == 0 {
            n /= d;
            count += 1;
        }
        if count > 0 {
            factors.push((d, count));
        }
        // After 2, only odd candidates can be prime.
        d += if d == 2 { 1 } else { 2 };
    }
    // A remainder above 1 is a prime larger than every factor found so far.
    if n > 1 {
        factors.push((n, 1));
    }
    factors
}
144
145pub fn discrete_log(base: u64, target: u64, modulus: u64) -> Option<u64> {
148 if modulus <= 1 {
149 return None;
150 }
151 let m = (modulus as f64).sqrt().ceil() as u64;
152
153 let mut table: HashMap<u64, u64> = HashMap::new();
155 let mut power = 1u64;
156 for j in 0..m {
157 table.insert(power, j);
158 power = (power as u128 * base as u128 % modulus as u128) as u64;
159 }
160
161 let base_inv_m = match mod_inverse(mod_pow(base, m, modulus), modulus) {
163 Some(v) => v,
164 None => return None,
165 };
166
167 let mut gamma = target;
168 for i in 0..m {
169 if let Some(&j) = table.get(&gamma) {
170 return Some(i * m + j);
171 }
172 gamma = (gamma as u128 * base_inv_m as u128 % modulus as u128) as u64;
173 }
174 None
175}
176
177pub struct ClockVisualization {
181 pub modulus: u64,
182 pub center: Vec2,
183 pub radius: f32,
184}
185
186pub struct ClockGlyph {
187 pub residue: u64,
188 pub position: Vec2,
189 pub color: Vec4,
190 pub character: char,
191}
192
193impl ClockVisualization {
194 pub fn new(modulus: u64, center: Vec2, radius: f32) -> Self {
195 Self { modulus, center, radius }
196 }
197
198 pub fn generate(&self) -> Vec<ClockGlyph> {
200 let n = self.modulus;
201 (0..n)
202 .map(|r| {
203 let angle =
204 std::f32::consts::FRAC_PI_2 - (r as f32 / n as f32) * std::f32::consts::TAU;
205 let pos = self.center + Vec2::new(angle.cos(), angle.sin()) * self.radius;
206 let hue = r as f32 / n as f32;
207 ClockGlyph {
208 residue: r,
209 position: pos,
210 color: Vec4::new(hue, 1.0 - hue, 0.5, 1.0),
211 character: std::char::from_digit(r as u32 % 36, 36).unwrap_or('?'),
212 }
213 })
214 .collect()
215 }
216
217 pub fn residue_class(&self, value: u64) -> Vec<Vec2> {
219 let r = value % self.modulus;
220 let angle = std::f32::consts::FRAC_PI_2
221 - (r as f32 / self.modulus as f32) * std::f32::consts::TAU;
222 vec![
223 self.center,
224 self.center + Vec2::new(angle.cos(), angle.sin()) * self.radius,
225 ]
226 }
227}
228
/// Generates grid layouts of the addition and multiplication tables modulo
/// `modulus`.
#[derive(Debug, Clone, PartialEq)]
pub struct ResiduePattern {
    /// Modulus of the arithmetic; tables are `modulus × modulus`.
    pub modulus: u64,
    /// Side length of one grid cell.
    pub cell_size: f32,
}
234
235pub struct PatternCell {
236 pub row: u64,
237 pub col: u64,
238 pub value: u64,
239 pub position: Vec2,
240 pub color: Vec4,
241}
242
243impl ResiduePattern {
244 pub fn new(modulus: u64, cell_size: f32) -> Self {
245 Self { modulus, cell_size }
246 }
247
248 pub fn multiplication_table(&self) -> Vec<PatternCell> {
250 let n = self.modulus;
251 let mut cells = Vec::with_capacity((n * n) as usize);
252 for r in 0..n {
253 for c in 0..n {
254 let val = (r * c) % n;
255 let brightness = val as f32 / n as f32;
256 cells.push(PatternCell {
257 row: r,
258 col: c,
259 value: val,
260 position: Vec2::new(c as f32 * self.cell_size, r as f32 * self.cell_size),
261 color: Vec4::new(brightness, 0.2, 1.0 - brightness, 1.0),
262 });
263 }
264 }
265 cells
266 }
267
268 pub fn addition_table(&self) -> Vec<PatternCell> {
270 let n = self.modulus;
271 let mut cells = Vec::with_capacity((n * n) as usize);
272 for r in 0..n {
273 for c in 0..n {
274 let val = (r + c) % n;
275 let brightness = val as f32 / n as f32;
276 cells.push(PatternCell {
277 row: r,
278 col: c,
279 value: val,
280 position: Vec2::new(c as f32 * self.cell_size, r as f32 * self.cell_size),
281 color: Vec4::new(0.2, brightness, 1.0 - brightness, 1.0),
282 });
283 }
284 }
285 cells
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mod_pow() {
        assert_eq!(mod_pow(2, 10, 1000), 24);
        assert_eq!(mod_pow(3, 0, 7), 1);
        assert_eq!(mod_pow(2, 10, 1024), 0);
        assert_eq!(mod_pow(5, 3, 13), 8);
    }

    #[test]
    fn test_mod_pow_modulus_one() {
        // Every integer is congruent to 0 modulo 1.
        assert_eq!(mod_pow(7, 100, 1), 0);
        assert_eq!(mod_pow(0, 0, 1), 0);
    }

    #[test]
    fn test_mod_inverse() {
        assert_eq!(mod_inverse(3, 7), Some(5));
        assert_eq!(mod_inverse(2, 4), None);
        assert_eq!(mod_inverse(1, 5), Some(1));
        let inv = mod_inverse(17, 43).unwrap();
        assert_eq!((17 * inv) % 43, 1);
    }

    #[test]
    fn test_crt() {
        let result = chinese_remainder_theorem(&[2, 3, 2], &[3, 5, 7]).unwrap();
        assert_eq!(result % 3, 2);
        assert_eq!(result % 5, 3);
        assert_eq!(result % 7, 2);
        assert_eq!(result, 23);
    }

    #[test]
    fn test_crt_inconsistent() {
        assert!(chinese_remainder_theorem(&[1, 0], &[2, 2]).is_none());
    }

    #[test]
    fn test_crt_non_coprime_moduli() {
        // Consistent because 2 ≡ 4 (mod gcd(6, 8) = 2); the lcm is 24.
        assert_eq!(chinese_remainder_theorem(&[2, 4], &[6, 8]), Some(20));
    }

    #[test]
    fn test_crt_invalid_input() {
        assert!(chinese_remainder_theorem(&[], &[]).is_none());
        assert!(chinese_remainder_theorem(&[1, 2], &[3]).is_none());
    }

    #[test]
    fn test_primitive_roots() {
        let roots = primitive_roots(7);
        assert!(roots.contains(&3));
        assert!(roots.contains(&5));
        // Each root must generate all 6 units mod 7.
        for &g in &roots {
            let mut seen = std::collections::HashSet::new();
            let mut val = 1u64;
            for _ in 0..6 {
                seen.insert(val);
                val = val * g % 7;
            }
            assert_eq!(seen.len(), 6);
        }
    }

    #[test]
    fn test_primitive_roots_11() {
        let roots = primitive_roots(11);
        assert!(roots.contains(&2));
        // φ(φ(11)) = φ(10) = 4 primitive roots.
        assert_eq!(roots.len(), 4);
    }

    #[test]
    fn test_primitive_roots_edge_cases() {
        assert!(primitive_roots(0).is_empty());
        assert!(primitive_roots(1).is_empty());
        assert_eq!(primitive_roots(2), vec![1]);
        // The unit group mod 8 is not cyclic, so there are no primitive roots.
        assert!(primitive_roots(8).is_empty());
    }

    #[test]
    fn test_discrete_log() {
        assert_eq!(discrete_log(2, 8, 13), Some(3));
        let x = discrete_log(3, 1, 7).unwrap();
        assert_eq!(mod_pow(3, x, 7), 1);
    }

    #[test]
    fn test_discrete_log_larger() {
        // 5 is a primitive root mod 23, so a solution must exist; an `if let`
        // here would let a `None` regression pass silently.
        let x = discrete_log(5, 12, 23).expect("5 generates the units mod 23");
        assert_eq!(mod_pow(5, x, 23), 12);
    }

    #[test]
    fn test_discrete_log_no_solution() {
        // Powers of 2 mod 7 are {1, 2, 4}; 3 is never reached.
        assert_eq!(discrete_log(2, 3, 7), None);
        assert_eq!(discrete_log(2, 1, 1), None);
    }

    #[test]
    fn test_number_theory_helpers() {
        assert_eq!(gcd(48, 18), 6);
        assert_eq!(euler_totient(36), 12);
        assert_eq!(factor_small(360), vec![(2, 3), (3, 2), (5, 1)]);
    }

    #[test]
    fn test_clock_visualization() {
        let clock = ClockVisualization::new(12, Vec2::ZERO, 5.0);
        let glyphs = clock.generate();
        assert_eq!(glyphs.len(), 12);
        // Residue 0 is placed at angle π/2, i.e. (0, radius).
        let top = &glyphs[0];
        assert!((top.position.y - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_residue_pattern() {
        let pat = ResiduePattern::new(5, 1.0);
        let table = pat.multiplication_table();
        assert_eq!(table.len(), 25);
        // Row 3, column 4: 3 * 4 = 12 ≡ 2 (mod 5).
        let cell = &table[3 * 5 + 4];
        assert_eq!(cell.value, 2);
    }

    #[test]
    fn test_mod_pow_large() {
        let result = mod_pow(2, 64, 1_000_000_007);
        assert_eq!(result, 582344008);
    }
}