alkahest_cas/number_theory/
mod.rs1use crate::errors::AlkahestError;
8use crate::flint::ffi::{self as ffi, FmpzFactorStruct};
9use crate::flint::FlintInteger;
10use rug::Complete;
11use rug::Integer;
12use std::cmp::Ordering;
13use std::fmt;
14use std::str::FromStr;
15
/// Errors reported by the number-theory entry points in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NumberTheoryError {
    /// An argument could not be interpreted (for example, a string that is not
    /// a decimal integer); `msg` is the text shown to the user.
    InvalidInput { msg: &'static str },
    /// An argument parsed correctly but lies outside the accepted range;
    /// `msg` is the text shown to the user.
    Domain { msg: &'static str },
    /// The requested discrete logarithm or modular root does not exist.
    NoSolution,
    /// The operation needs a prime modulus and the supplied one is not prime.
    CompositeModulus,
    /// The requested root degree is neither 2 nor coprime to `p − 1`.
    UnsupportedNthRoot,
}
29
30impl fmt::Display for NumberTheoryError {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 match self {
33 NumberTheoryError::InvalidInput { msg } => write!(f, "{msg}"),
34 NumberTheoryError::Domain { msg } => write!(f, "{msg}"),
35 NumberTheoryError::NoSolution => {
36 write!(f, "no discrete logarithm or modular root exists")
37 }
38 NumberTheoryError::CompositeModulus => write!(f, "operation requires a prime modulus"),
39 NumberTheoryError::UnsupportedNthRoot => {
40 write!(f, "nth root modulo p requires k=2 or gcd(k,p−1)=1")
41 }
42 }
43 }
44}
45
46impl std::error::Error for NumberTheoryError {}
47
48impl AlkahestError for NumberTheoryError {
49 fn code(&self) -> &'static str {
50 match self {
51 NumberTheoryError::InvalidInput { .. } => "E-NT-001",
52 NumberTheoryError::Domain { .. } => "E-NT-002",
53 NumberTheoryError::NoSolution => "E-NT-003",
54 NumberTheoryError::CompositeModulus => "E-NT-004",
55 NumberTheoryError::UnsupportedNthRoot => "E-NT-005",
56 }
57 }
58
59 fn remediation(&self) -> Option<&'static str> {
60 match self {
61 NumberTheoryError::InvalidInput { .. } => {
62 Some("pass arbitrary-precision integers as decimal strings without spaces")
63 }
64 NumberTheoryError::Domain { .. } => {
65 Some("check parity, positivity, and defined ranges")
66 }
67 NumberTheoryError::NoSolution => {
68 Some("verify solvability: residue in ⟨base⟩, or quadratic residue for k=2")
69 }
70 NumberTheoryError::CompositeModulus => {
71 Some("use a prime field modulus where the FLINT primitives apply")
72 }
73 NumberTheoryError::UnsupportedNthRoot => Some(
74 "use sqrt (k=2) or primes with gcd(k,p−1)=1; Tonelli–Shanks chains are deferred",
75 ),
76 }
77 }
78}
79
80fn parse_int(s: &str) -> Result<Integer, NumberTheoryError> {
81 Integer::from_str(s.trim()).map_err(|_| NumberTheoryError::InvalidInput {
82 msg: "invalid decimal integer string",
83 })
84}
85
86fn parse_nonnegative(s: &str) -> Result<Integer, NumberTheoryError> {
87 let z = parse_int(s)?;
88 if z.cmp0() == Ordering::Less {
89 Err(NumberTheoryError::Domain {
90 msg: "expected a non-negative integer",
91 })
92 } else {
93 Ok(z)
94 }
95}
96
97fn mod_inverse(mut a: Integer, m: &Integer) -> Option<Integer> {
99 if m.cmp0() != Ordering::Greater {
100 return None;
101 }
102 if m == &Integer::from(1) {
103 return Some(Integer::from(0));
104 }
105 a %= m;
106 let (g, s, _) = a.extended_gcd(m.clone(), Integer::new());
107 if g != 1 && g != -1 {
108 return None;
109 }
110 let mut inv = if g == -1 { -s } else { s };
111 inv %= m;
112 if inv.cmp0() == Ordering::Less {
113 inv += m;
114 }
115 Some(inv)
116}
117
118fn integer_is_odd(n: &Integer) -> bool {
119 (n.clone() % Integer::from(2_u32)).cmp0() != Ordering::Equal
120}
121
122fn parse_positive(s: &str) -> Result<Integer, NumberTheoryError> {
124 let z = parse_nonnegative(s)?;
125 if z.is_zero() {
126 Err(NumberTheoryError::Domain {
127 msg: "expected a positive integer",
128 })
129 } else {
130 Ok(z)
131 }
132}
133
134pub fn isprime(n: &str) -> Result<bool, NumberTheoryError> {
136 let z = parse_int(n)?;
137 if z.cmp0() != Ordering::Greater || z < 2 {
138 return Ok(false);
139 }
140 let fz = FlintInteger::from_rug(&z);
141 let r = unsafe { ffi::fmpz_is_prime(fz.inner_ptr()) };
142 Ok(r != 0)
143}
144
145pub fn factorint(n: &str) -> Result<(i32, Vec<(String, u64)>), NumberTheoryError> {
147 let z = parse_int(n)?;
148 let fz = FlintInteger::from_rug(&z);
149 unsafe {
150 let mut fac = std::mem::MaybeUninit::<FmpzFactorStruct>::uninit();
151 ffi::fmpz_factor_init(fac.as_mut_ptr());
152 let mut fac = fac.assume_init();
153 ffi::fmpz_factor(&mut fac, fz.inner_ptr());
154 let mut out = Vec::with_capacity(fac.num.max(0) as usize);
155 for i in 0..fac.num {
156 let mut base = FlintInteger::new();
157 ffi::fmpz_set(base.inner_mut_ptr(), fac.p.add(i as usize));
158 let exp = *fac.exp.add(i as usize);
159 out.push((base.to_string(), exp));
160 }
161 let sign = fac.sign;
162 ffi::fmpz_factor_clear(&mut fac);
163 Ok((sign, out))
164 }
165}
166
167pub fn nextprime(n: &str, proved: bool) -> Result<String, NumberTheoryError> {
169 let z = parse_int(n)?;
170 let fz = FlintInteger::from_rug(&z);
171 let mut res = FlintInteger::new();
172 unsafe {
173 ffi::fmpz_nextprime(
174 res.inner_mut_ptr(),
175 fz.inner_ptr(),
176 if proved { 1 } else { 0 },
177 );
178 }
179 Ok(res.to_string())
180}
181
182pub fn totient(n: &str) -> Result<String, NumberTheoryError> {
184 let z = parse_positive(n)?;
185 let fz = FlintInteger::from_rug(&z);
186 let mut out = FlintInteger::new();
187 unsafe {
188 ffi::fmpz_euler_phi(out.inner_mut_ptr(), fz.inner_ptr());
189 }
190 Ok(out.to_string())
191}
192
193pub fn jacobi_symbol(a: &str, n: &str) -> Result<i32, NumberTheoryError> {
195 let na = parse_int(a)?;
196 let nn = parse_positive(n)?;
197 if nn <= 1 || !integer_is_odd(&nn) {
198 return Err(NumberTheoryError::Domain {
199 msg: "Jacobi denominator must be odd and greater than 1",
200 });
201 }
202 let fa = FlintInteger::from_rug(&na);
203 let fn_ = FlintInteger::from_rug(&nn);
204 let j = unsafe { ffi::fmpz_jacobi(fa.inner_ptr(), fn_.inner_ptr()) };
205 Ok(j as i32)
206}
207
208pub fn nthroot_mod(a: &str, k: u64, p: &str) -> Result<String, NumberTheoryError> {
212 if k == 0 {
213 return Err(NumberTheoryError::InvalidInput {
214 msg: "root degree must be ≥ 1",
215 });
216 }
217 let pm = parse_positive(p)?;
218 let fp = FlintInteger::from_rug(&pm);
219 if unsafe { ffi::fmpz_is_prime(fp.inner_ptr()) } == 0 {
220 return Err(NumberTheoryError::CompositeModulus);
221 }
222
223 let mut ared = parse_int(a)?;
224 ared %= ±
225
226 let mut out = FlintInteger::new();
227
228 if k == 2 {
229 let fa = FlintInteger::from_rug(&ared);
230 let ok = unsafe { ffi::fmpz_sqrtmod(out.inner_mut_ptr(), fa.inner_ptr(), fp.inner_ptr()) };
231 if ok == 0 {
232 return Err(NumberTheoryError::NoSolution);
233 }
234 return Ok(out.to_string());
235 }
236
237 let ord = pm.clone() - 1;
238 let kk = Integer::from(k);
239 if kk.clone().gcd(&ord) != 1 {
240 return Err(NumberTheoryError::UnsupportedNthRoot);
241 }
242 let mut inv_e = mod_inverse(kk.clone(), &ord).ok_or(NumberTheoryError::UnsupportedNthRoot)?;
243 inv_e %= ⩝
244 let fa = FlintInteger::from_rug(&ared);
245 let fe = FlintInteger::from_rug(&inv_e);
246 unsafe {
247 ffi::fmpz_powm(
248 out.inner_mut_ptr(),
249 fa.inner_ptr(),
250 fe.inner_ptr(),
251 fp.inner_ptr(),
252 );
253 }
254 Ok(out.to_string())
255}
256
257pub fn discrete_log(residue: &str, base: &str, p: &str) -> Result<String, NumberTheoryError> {
262 let pm = parse_positive(p)?;
263 if pm < 2 {
264 return Err(NumberTheoryError::Domain {
265 msg: "modulus must be at least 2",
266 });
267 }
268 let fp = FlintInteger::from_rug(&pm);
269 if unsafe { ffi::fmpz_is_prime(fp.inner_ptr()) } == 0 {
270 return Err(NumberTheoryError::CompositeModulus);
271 }
272
273 let ord = pm.clone() - Integer::from(1);
274 let mut b = parse_int(base)?;
275 let mut r = parse_int(residue)?;
276 r %= ±
277 b %= ±
278
279 if b.is_zero() {
280 return if r.is_zero() {
281 Ok("1".into())
282 } else {
283 Err(NumberTheoryError::NoSolution)
284 };
285 }
286
287 let mut cur = Integer::from(1);
288 let mut exp = Integer::from(0);
289 while exp < ord {
290 if cur == r {
291 return Ok(exp.to_string());
292 }
293 cur = (&cur * &b).complete();
294 cur %= ±
295 exp += 1;
296 }
297 Err(NumberTheoryError::NoSolution)
298}
299
300#[derive(Clone, Debug)]
302pub struct QuadraticDirichlet {
303 modulus: Integer,
304}
305
306impl QuadraticDirichlet {
307 pub fn new(conductor: &str) -> Result<Self, NumberTheoryError> {
308 let q = parse_positive(conductor)?;
309 if q <= 2 || !integer_is_odd(&q) {
310 return Err(NumberTheoryError::Domain {
311 msg: "quadratic Dirichlet conductor must be odd and ≥ 3",
312 });
313 }
314 let (_sign, fac) = factorint(conductor)?;
315 for (_, e) in &fac {
316 if *e != 1 {
317 return Err(NumberTheoryError::Domain {
318 msg: "conductor must be square-free",
319 });
320 }
321 }
322 Ok(QuadraticDirichlet { modulus: q })
323 }
324
325 pub fn conductor(&self) -> String {
326 self.modulus.to_string()
327 }
328
329 pub fn eval(&self, n: &str) -> Result<i32, NumberTheoryError> {
331 jacobi_symbol(n, &self.modulus.to_string())
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use rug::ops::Pow;
    use std::collections::HashMap;

    /// 2^127 − 1 is a Mersenne prime.
    #[test]
    fn mersenne_m127_prime() {
        let m = Integer::from(2u32).pow(127_u32) - 1_u32;
        assert!(isprime(&m.to_string()).unwrap());
    }

    /// 2^32 − 1 = 3 · 5 · 17 · 257 · 65537, so 65537 appears exactly once.
    #[test]
    fn factorint_f5() {
        let n = (1u128 << 32) - 1;
        let (sign, pairs) = factorint(&n.to_string()).unwrap();
        assert_eq!(sign, 1);
        let m: HashMap<_, _> = pairs.into_iter().collect();
        assert_eq!(m.get("65537").copied(), Some(1));
    }

    #[test]
    fn nextprime_gap() {
        assert_eq!(nextprime("13", true).unwrap(), "17");
    }

    #[test]
    fn totient_twelve() {
        assert_eq!(totient("12").unwrap(), "4");
    }

    #[test]
    fn jacobi_two_fifteen() {
        assert_eq!(jacobi_symbol("2", "15").unwrap(), 1);
    }

    /// Whichever square root is returned, squaring it must give 144 mod 401.
    #[test]
    fn sqrt_mod_prime() {
        let x_str = nthroot_mod("144", 2, "401").unwrap();
        let x: u64 = x_str.parse().unwrap();
        assert_eq!((x * x) % 401, 144);
    }

    /// With gcd(5, 10006) = 1 the fifth root is unique; check it by powering.
    #[test]
    fn nth_root_via_coprime_exponent() {
        let pm = Integer::from(10007);
        let a = Integer::from(42);
        let k = 5u64;
        let kk = Integer::from(k);
        let ord = pm.clone() - Integer::from(1);
        assert_eq!(kk.clone().gcd(&ord), Integer::from(1));

        let x_str = nthroot_mod(&a.to_string(), k, &pm.to_string()).unwrap();
        let x = Integer::from_str(&x_str).unwrap();
        let chk = x.pow(k as u32) % &pm;
        assert_eq!(chk, a % &pm);
    }

    /// 3^4 = 81 ≡ 13 (mod 17).
    #[test]
    fn discrete_log_three_mod_seventeen() {
        assert_eq!(discrete_log("13", "3", "17").unwrap(), "4");
    }

    /// (14/15) = (−1/15) = −1, and 3 shares a factor with 15 so the symbol is 0.
    #[test]
    fn dirichlet_phi_fifteen() {
        let chi = QuadraticDirichlet::new("15").unwrap();
        assert_eq!(chi.eval("14").unwrap(), -1);
        assert_eq!(chi.eval("3").unwrap(), 0);
    }
}