Skip to main content

alkahest_cas/flint/
integer.rs

1use super::ffi;
2use rug::Complete;
3use std::ffi::CString;
4use std::fmt;
5use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
6
7/// Safe wrapper over FLINT's `fmpz_t` — arbitrary-precision integer.
8///
9/// Memory is managed by FLINT's allocator. `Drop` calls `fmpz_clear`.
10/// All raw pointers are confined to this file; callers see only safe Rust.
11pub struct FlintInteger {
12    /// The `fmpz` storage. Either an inline i64 or a tagged pointer to GMP
13    /// memory managed by FLINT. Must never be aliased across two `FlintInteger`
14    /// values.
15    inner: ffi::fmpz,
16}
17
18// SAFETY: fmpz is either an i64 or owns its GMP memory. No shared state.
19unsafe impl Send for FlintInteger {}
20unsafe impl Sync for FlintInteger {}
21
22impl FlintInteger {
23    pub fn new() -> Self {
24        let mut inner: ffi::fmpz = 0;
25        unsafe { ffi::fmpz_init(&mut inner) };
26        FlintInteger { inner }
27    }
28
29    pub fn from_i64(val: i64) -> Self {
30        let mut f = Self::new();
31        unsafe { ffi::fmpz_set_si(&mut f.inner, val) };
32        f
33    }
34
35    /// Return as `i64`. For values that overflow i64 this wraps/truncates —
36    /// use `to_string()` for a lossless decimal representation.
37    pub fn to_i64(&self) -> i64 {
38        unsafe { ffi::fmpz_get_si(&self.inner) }
39    }
40
41    pub fn gcd(&self, other: &Self) -> Self {
42        let mut res = Self::new();
43        unsafe { ffi::fmpz_gcd(&mut res.inner, &self.inner, &other.inner) };
44        res
45    }
46
47    pub fn pow(&self, exp: u64) -> Self {
48        let mut res = Self::new();
49        unsafe { ffi::fmpz_pow_ui(&mut res.inner, &self.inner, exp) };
50        res
51    }
52
53    /// Construct from a `rug::Integer` via decimal string round-trip.
54    pub fn from_rug(n: &rug::Integer) -> Self {
55        let s = n.to_string();
56        let cstr = CString::new(s.as_str()).unwrap();
57        let mut f = Self::new();
58        unsafe { ffi::fmpz_set_str(&mut f.inner, cstr.as_ptr(), 10) };
59        f
60    }
61
62    /// Expose the raw inner `fmpz` for use by `FlintPoly` coefficient accessors.
63    pub(crate) fn inner_ptr(&self) -> *const ffi::fmpz {
64        &self.inner
65    }
66
67    pub(crate) fn inner_mut_ptr(&mut self) -> *mut ffi::fmpz {
68        &mut self.inner
69    }
70
71    /// Convert to a `rug::Integer` for cross-validation in tests.
72    pub fn to_rug(&self) -> rug::Integer {
73        rug::Integer::parse_radix(self.to_string().as_bytes(), 10)
74            .unwrap()
75            .complete()
76    }
77}
78
79impl Default for FlintInteger {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Drop for FlintInteger {
86    fn drop(&mut self) {
87        unsafe { ffi::fmpz_clear(&mut self.inner) };
88    }
89}
90
91impl Clone for FlintInteger {
92    fn clone(&self) -> Self {
93        let mut new = Self::new();
94        unsafe { ffi::fmpz_set(&mut new.inner, &self.inner) };
95        new
96    }
97}
98
99impl PartialEq for FlintInteger {
100    fn eq(&self, other: &Self) -> bool {
101        unsafe { ffi::fmpz_equal(&self.inner, &other.inner) != 0 }
102    }
103}
104impl Eq for FlintInteger {}
105
106// ---------------------------------------------------------------------------
107// Arithmetic — owned and reference variants
108// ---------------------------------------------------------------------------
109
110impl Add for FlintInteger {
111    type Output = Self;
112    fn add(self, rhs: Self) -> Self {
113        &self + &rhs
114    }
115}
116impl<'b> Add<&'b FlintInteger> for &FlintInteger {
117    type Output = FlintInteger;
118    fn add(self, rhs: &'b FlintInteger) -> FlintInteger {
119        let mut res = FlintInteger::new();
120        unsafe { ffi::fmpz_add(&mut res.inner, &self.inner, &rhs.inner) };
121        res
122    }
123}
124
125impl Sub for FlintInteger {
126    type Output = Self;
127    fn sub(self, rhs: Self) -> Self {
128        &self - &rhs
129    }
130}
131impl<'b> Sub<&'b FlintInteger> for &FlintInteger {
132    type Output = FlintInteger;
133    fn sub(self, rhs: &'b FlintInteger) -> FlintInteger {
134        let mut res = FlintInteger::new();
135        unsafe { ffi::fmpz_sub(&mut res.inner, &self.inner, &rhs.inner) };
136        res
137    }
138}
139
140impl Mul for FlintInteger {
141    type Output = Self;
142    fn mul(self, rhs: Self) -> Self {
143        &self * &rhs
144    }
145}
146impl<'b> Mul<&'b FlintInteger> for &FlintInteger {
147    type Output = FlintInteger;
148    fn mul(self, rhs: &'b FlintInteger) -> FlintInteger {
149        let mut res = FlintInteger::new();
150        unsafe { ffi::fmpz_mul(&mut res.inner, &self.inner, &rhs.inner) };
151        res
152    }
153}
154
155/// Truncated (toward-zero) division, matching Rust's built-in integer `/`.
156impl Div for FlintInteger {
157    type Output = Self;
158    fn div(self, rhs: Self) -> Self {
159        &self / &rhs
160    }
161}
162impl<'b> Div<&'b FlintInteger> for &FlintInteger {
163    type Output = FlintInteger;
164    fn div(self, rhs: &'b FlintInteger) -> FlintInteger {
165        let mut res = FlintInteger::new();
166        unsafe { ffi::fmpz_tdiv_q(&mut res.inner, &self.inner, &rhs.inner) };
167        res
168    }
169}
170
171/// Remainder after truncated division, matching Rust's built-in `%`.
172impl Rem for FlintInteger {
173    type Output = Self;
174    fn rem(self, rhs: Self) -> Self {
175        &self % &rhs
176    }
177}
178impl<'b> Rem<&'b FlintInteger> for &FlintInteger {
179    type Output = FlintInteger;
180    fn rem(self, rhs: &'b FlintInteger) -> FlintInteger {
181        let mut q = FlintInteger::new();
182        let mut r = FlintInteger::new();
183        unsafe { ffi::fmpz_tdiv_qr(&mut q.inner, &mut r.inner, &self.inner, &rhs.inner) };
184        r
185    }
186}
187
188impl Neg for FlintInteger {
189    type Output = Self;
190    fn neg(self) -> Self {
191        -&self
192    }
193}
194impl Neg for &FlintInteger {
195    type Output = FlintInteger;
196    fn neg(self) -> FlintInteger {
197        let mut res = FlintInteger::new();
198        unsafe { ffi::fmpz_neg(&mut res.inner, &self.inner) };
199        res
200    }
201}
202
203// ---------------------------------------------------------------------------
204// Display / Debug
205// ---------------------------------------------------------------------------
206
207impl fmt::Display for FlintInteger {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        // fmpz_get_str(NULL, base, f) allocates a new C string; caller frees
210        // with flint_free.
211        unsafe {
212            let ptr = ffi::fmpz_get_str(std::ptr::null_mut(), 10, &self.inner);
213            if ptr.is_null() {
214                return write!(f, "<err>");
215            }
216            let s = std::ffi::CStr::from_ptr(ptr)
217                .to_str()
218                .unwrap_or("<utf8-err>")
219                .to_owned();
220            ffi::flint_free(ptr as *mut std::ffi::c_void);
221            write!(f, "{}", s)
222        }
223    }
224}
225
226impl fmt::Debug for FlintInteger {
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        write!(f, "FlintInteger({})", self)
229    }
230}
231
232// ---------------------------------------------------------------------------
233// Unit tests
234// ---------------------------------------------------------------------------
235
#[cfg(test)]
mod tests {
    use super::*;

    // --- construction and equality ---

    #[test]
    fn zero() {
        let z = FlintInteger::new();
        assert_eq!(z, FlintInteger::from_i64(0));
    }

    #[test]
    fn from_i64_roundtrip() {
        // Includes the i64 extremes, which FLINT stores as GMP values rather than inline.
        for v in [i64::MIN, -1000, -1, 0, 1, 1000, i64::MAX / 2, i64::MAX] {
            let f = FlintInteger::from_i64(v);
            assert_eq!(f.to_i64(), v);
        }
    }

    #[test]
    fn clone_is_independent() {
        let a = FlintInteger::from_i64(42);
        let b = a.clone();
        assert_eq!(a, b);
        // modifying b via arithmetic should not affect a
        let c = &b + &FlintInteger::from_i64(1);
        assert_eq!(a, FlintInteger::from_i64(42));
        assert_eq!(c, FlintInteger::from_i64(43));
    }

    // --- arithmetic ---

    #[test]
    fn add() {
        let a = FlintInteger::from_i64(7);
        let b = FlintInteger::from_i64(5);
        assert_eq!((&a + &b).to_i64(), 12);
    }

    #[test]
    fn sub() {
        let a = FlintInteger::from_i64(7);
        let b = FlintInteger::from_i64(5);
        assert_eq!((&a - &b).to_i64(), 2);
    }

    #[test]
    fn mul() {
        let a = FlintInteger::from_i64(7);
        let b = FlintInteger::from_i64(5);
        assert_eq!((&a * &b).to_i64(), 35);
    }

    #[test]
    fn div_truncated() {
        let a = FlintInteger::from_i64(7);
        let b = FlintInteger::from_i64(3);
        assert_eq!((&a / &b).to_i64(), 2); // truncated toward zero
        let c = FlintInteger::from_i64(-7);
        assert_eq!((&c / &b).to_i64(), -2); // negative: truncates toward zero
    }

    #[test]
    fn rem() {
        let a = FlintInteger::from_i64(7);
        let b = FlintInteger::from_i64(3);
        assert_eq!((&a % &b).to_i64(), 1);
        // Truncated remainder takes the sign of the dividend, as in Rust's `%`.
        assert_eq!((&FlintInteger::from_i64(-7) % &b).to_i64(), -7 % 3);
        assert_eq!((&a % &FlintInteger::from_i64(-3)).to_i64(), 7 % -3);
    }

    #[test]
    fn neg() {
        let a = FlintInteger::from_i64(5);
        assert_eq!((-&a).to_i64(), -5);
        assert_eq!((-FlintInteger::from_i64(-3)).to_i64(), 3);
    }

    #[test]
    fn gcd() {
        let a = FlintInteger::from_i64(12);
        let b = FlintInteger::from_i64(8);
        assert_eq!(a.gcd(&b).to_i64(), 4);
        let p = FlintInteger::from_i64(17);
        let q = FlintInteger::from_i64(5);
        assert_eq!(p.gcd(&q).to_i64(), 1); // coprime
    }

    #[test]
    fn pow() {
        let a = FlintInteger::from_i64(2);
        assert_eq!(a.pow(10).to_i64(), 1024);
        assert_eq!(a.pow(0).to_i64(), 1);
    }

    // --- display ---

    #[test]
    fn display() {
        assert_eq!(FlintInteger::from_i64(0).to_string(), "0");
        assert_eq!(FlintInteger::from_i64(-42).to_string(), "-42");
        assert_eq!(FlintInteger::from_i64(1_000_000).to_string(), "1000000");
    }

    // --- cross-validation against rug ---

    #[test]
    fn roundtrip_vs_rug_small() {
        for v in [-999i64, -1, 0, 1, 999] {
            let flint = FlintInteger::from_i64(v);
            let rug_val = rug::Integer::from(v);
            assert_eq!(flint.to_string(), rug_val.to_string(), "mismatch for v={v}");
        }
    }

    #[test]
    fn rug_roundtrip_large() {
        use rug::ops::Pow;
        // Values beyond i64 exercise the GMP-backed path of both conversions.
        let big = rug::Integer::from(2i64).pow(100u32);
        let samples = [
            rug::Integer::from(0),
            rug::Integer::from(-7),
            -big.clone(),
            big,
        ];
        for r in &samples {
            let f = FlintInteger::from_rug(r);
            assert_eq!(f.to_string(), r.to_string(), "from_rug {r}");
            assert_eq!(&f.to_rug(), r, "to_rug {r}");
        }
    }

    #[test]
    fn arithmetic_vs_rug() {
        use rug::ops::{DivRounding, RemRounding};
        let pairs: &[(i64, i64)] = &[
            (0, 0),
            (7, 5),
            (-12, 4),
            (100, 7),
            (1000, 999),
            (-7, 3),
            (7, -3),
            (-7, -3),
        ];
        for &(a, b) in pairs {
            let fa = FlintInteger::from_i64(a);
            let fb = FlintInteger::from_i64(b);
            let ra = rug::Integer::from(a);
            let rb = rug::Integer::from(b);
            assert_eq!(
                (&fa + &fb).to_string(),
                rug::Integer::from(&ra + &rb).to_string(),
                "add {a}+{b}"
            );
            assert_eq!(
                (&fa - &fb).to_string(),
                rug::Integer::from(&ra - &rb).to_string(),
                "sub {a}-{b}"
            );
            assert_eq!(
                (&fa * &fb).to_string(),
                rug::Integer::from(&ra * &rb).to_string(),
                "mul {a}*{b}"
            );
            if b != 0 {
                let rug_div = ra.clone().div_trunc(rb.clone());
                assert_eq!((&fa / &fb).to_string(), rug_div.to_string(), "div {a}/{b}");
                let rug_rem = ra.clone().rem_trunc(rb.clone());
                assert_eq!((&fa % &fb).to_string(), rug_rem.to_string(), "rem {a}%{b}");
            }
        }
    }

    #[test]
    fn large_integer_vs_rug() {
        use rug::ops::Pow;
        // 2^100 — larger than i64, exercises GMP allocation path in fmpz
        let two = FlintInteger::from_i64(2);
        let big = two.pow(100);
        let rug_big = rug::Integer::from(2i64).pow(100u32);
        assert_eq!(big.to_string(), rug_big.to_string());
    }
}
390}