contest_algorithms/math/
num.rs

//! Rational and complex numbers, safe modular arithmetic, and linear algebra,
//! implemented minimally for contest use.
//! If you need more features, you might be interested in <https://crates.io/crates/num>.
4pub use std::f64::consts::PI;
5use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
6
/// Fast iterative version of Euclid's GCD algorithm.
///
/// Returns the non-negative greatest common divisor of `a` and `b`, with
/// `fast_gcd(0, 0) == 0`. Some inputs involving `i64::MIN` overflow `i64`
/// (e.g. `fast_gcd(i64::MIN, 0)` or `fast_gcd(i64::MIN, -1)`).
pub fn fast_gcd(mut a: i64, mut b: i64) -> i64 {
    while b != 0 {
        a %= b;
        std::mem::swap(&mut a, &mut b);
    }
    a.abs()
}

/// Represents a fraction reduced to lowest terms with a positive denominator.
///
/// `new` and every operator below maintain that invariant; writing the public
/// fields directly bypasses it.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)]
pub struct Rational {
    pub num: i64,
    pub den: i64,
}
impl Rational {
    /// Creates `num / den` reduced to lowest terms with `den > 0`.
    ///
    /// # Panics
    /// Panics if `den == 0`.
    pub fn new(num: i64, den: i64) -> Self {
        assert_ne!(den, 0, "Rational denominator must be nonzero");
        // Multiplying by den's sign moves any negative sign onto the numerator.
        let g = fast_gcd(num, den) * den.signum();
        Self {
            num: num / g,
            den: den / g,
        }
    }
    /// Returns the absolute value.
    pub fn abs(self) -> Self {
        Self {
            num: self.num.abs(),
            den: self.den,
        }
    }
    /// Returns `1 / self`, keeping the denominator positive.
    ///
    /// # Panics
    /// Panics if `self` is zero.
    pub fn recip(self) -> Self {
        assert_ne!(self.num, 0, "reciprocal of zero is undefined");
        let g = self.num.signum();
        Self {
            num: self.den / g,
            den: self.num / g,
        }
    }
}
impl From<i64> for Rational {
    fn from(num: i64) -> Self {
        Self { num, den: 1 }
    }
}
impl Neg for Rational {
    type Output = Self;
    fn neg(self) -> Self {
        Self {
            num: -self.num,
            den: self.den,
        }
    }
}
// Division is used to find the common denominator, not by mistake.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Add for Rational {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        // Scale to lcm(den) rather than the full product of the denominators,
        // keeping intermediate values (and the overflow risk) small.
        let g = fast_gcd(self.den, other.den);
        let (self_scale, other_scale) = (other.den / g, self.den / g);
        Self::new(
            self.num * self_scale + other.num * other_scale,
            self.den * self_scale,
        )
    }
}
// Division is used to find the common denominator, not by mistake.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Sub for Rational {
    type Output = Self;
    fn sub(self, other: Self) -> Self {
        // Same lcm-based scaling as `add`.
        let g = fast_gcd(self.den, other.den);
        let (self_scale, other_scale) = (other.den / g, self.den / g);
        Self::new(
            self.num * self_scale - other.num * other_scale,
            self.den * self_scale,
        )
    }
}
// Division cancels common factors before multiplying, not by mistake.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Mul for Rational {
    type Output = Self;
    fn mul(self, other: Self) -> Self {
        // Cancel cross factors first so the products stay as small as possible.
        // Denominators are positive, so both gcds are at least 1.
        let g1 = fast_gcd(self.num, other.den);
        let g2 = fast_gcd(other.num, self.den);
        Self::new(
            (self.num / g1) * (other.num / g2),
            (self.den / g2) * (other.den / g1),
        )
    }
}
// Division is defined as multiplication by the reciprocal.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Div for Rational {
    type Output = Self;
    fn div(self, other: Self) -> Self {
        self * other.recip()
    }
}
impl Ord for Rational {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Denominators are positive, so cross-multiplying preserves order.
        // Widen to i128, where the product of two i64 values cannot overflow.
        let lhs = i128::from(self.num) * i128::from(other.den);
        let rhs = i128::from(self.den) * i128::from(other.num);
        lhs.cmp(&rhs)
    }
}
impl PartialOrd for Rational {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// A complex number `real + imag * i` whose parts are `f64` values.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Complex {
    pub real: f64,
    pub imag: f64,
}
impl Complex {
    /// Builds a complex number from its real and imaginary parts.
    pub fn new(real: f64, imag: f64) -> Self {
        Complex { real, imag }
    }
    /// Builds the complex number with modulus `r` and argument `th` (radians).
    pub fn from_polar(r: f64, th: f64) -> Self {
        let x = r * th.cos();
        let y = r * th.sin();
        Complex::new(x, y)
    }
    /// Squared modulus `real^2 + imag^2` (no square root is taken).
    pub fn abs_square(self) -> f64 {
        let Complex { real, imag } = self;
        real * real + imag * imag
    }
    /// Phase angle in radians as computed by `f64::atan2`, i.e. within `[-PI, PI]`.
    pub fn argument(self) -> f64 {
        f64::atan2(self.imag, self.real)
    }
    /// Complex conjugate: same real part, negated imaginary part.
    pub fn conjugate(self) -> Self {
        Complex {
            imag: -self.imag,
            ..self
        }
    }
    /// Multiplicative inverse `conj(z) / |z|^2`; zero yields non-finite parts.
    pub fn recip(self) -> Self {
        let norm_sq = self.abs_square();
        Complex::new(self.real / norm_sq, -self.imag / norm_sq)
    }
}
impl From<f64> for Complex {
    fn from(real: f64) -> Self {
        Complex::new(real, 0.0)
    }
}
impl Neg for Complex {
    type Output = Self;
    fn neg(self) -> Self {
        Complex {
            real: -self.real,
            imag: -self.imag,
        }
    }
}
impl Add for Complex {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        Complex {
            real: self.real + other.real,
            imag: self.imag + other.imag,
        }
    }
}
impl Sub for Complex {
    type Output = Self;
    fn sub(self, other: Self) -> Self {
        Complex {
            real: self.real - other.real,
            imag: self.imag - other.imag,
        }
    }
}
impl Mul for Complex {
    type Output = Self;
    fn mul(self, other: Self) -> Self {
        let Complex { real: a, imag: b } = self;
        let Complex { real: c, imag: d } = other;
        // (a + bi)(c + di) = (ac - bd) + (bc + ad)i
        Complex::new(a * c - b * d, b * c + a * d)
    }
}
// Division is defined as multiplication by the reciprocal.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Div for Complex {
    type Output = Self;
    fn div(self, other: Self) -> Self {
        self * other.recip()
    }
}

/// Represents an element of the finite (Galois) field of prime order `M`, where
/// `1 <= M < 2^31.5` so that the product of two reduced values fits in an `i64`.
/// If `M` is not prime, ring operations are still valid but `recip()` and
/// division are not. Note that the latter operations are also the slowest, so
/// precompute any inverses that you intend to use frequently.
///
/// Every constructor and operator below keeps `val` in `0..M`; writing the
/// public field directly bypasses that.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)]
pub struct Modulo<const M: i64> {
    pub val: i64,
}
impl<const M: i64> Modulo<M> {
    /// Computes `self^n` by binary exponentiation, using O(log n) multiplications.
    ///
    /// `pow(0)` returns the multiplicative identity, which is `0` when `M == 1`.
    pub fn pow(mut self, mut n: u64) -> Self {
        // `from` rather than `from_small`: for M == 1 the literal 1 is outside
        // `from_small`'s accepted range and must reduce to 0.
        let mut result = Self::from(1);
        while n > 0 {
            if n % 2 == 1 {
                result = result * self;
            }
            self = self * self;
            n /= 2;
        }
        result
    }
    /// Computes the inverses of `1..=n` in O(n) total time.
    ///
    /// The returned vector has length `n + 1` and is indexed by the argument:
    /// entry `i` is `i^-1`, and entry 0 is a placeholder `0`. Non-positive `n`
    /// yields just the placeholder. Requires `M` prime and `n < M`.
    pub fn vec_of_recips(n: i64) -> Vec<Self> {
        let len = if n > 0 { n as usize + 1 } else { 1 };
        let mut recips = Vec::with_capacity(len.max(2));
        recips.push(Self::from(0));
        recips.push(Self::from(1));
        for i in 2..=n {
            // From M = (M / i) * i + M % i:  i^-1 = -(M / i) * (M % i)^-1 (mod M).
            // M % i < i, so that inverse is already in the vector.
            let (md, dv) = (M % i, M / i);
            recips.push(recips[md as usize] * Self::from_small(-dv));
        }
        // Drop the seeded entry 1 when n < 1.
        recips.truncate(len);
        recips
    }
    /// Computes `self^-1` in O(log M) time as `self^(M-2)` (Fermat's little
    /// theorem). Only meaningful for prime `M` and nonzero `self`.
    pub fn recip(self) -> Self {
        self.pow(M as u64 - 2)
    }
    /// Avoids the % operation but requires -M <= s < M.
    fn from_small(s: i64) -> Self {
        let val = if s < 0 { s + M } else { s };
        Self { val }
    }
}
impl<const M: i64> From<i64> for Modulo<M> {
    /// Reduces any `i64` into `0..M`.
    fn from(val: i64) -> Self {
        // `%` yields a value in (-M, M), exactly the range `from_small` accepts.
        Self::from_small(val % M)
    }
}
impl<const M: i64> Neg for Modulo<M> {
    type Output = Self;
    fn neg(self) -> Self {
        Self::from_small(-self.val)
    }
}
impl<const M: i64> Add for Modulo<M> {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        // The sum lies in [0, 2M - 2]; shifting by -M lands in from_small's range.
        Self::from_small(self.val + other.val - M)
    }
}
impl<const M: i64> Sub for Modulo<M> {
    type Output = Self;
    fn sub(self, other: Self) -> Self {
        Self::from_small(self.val - other.val)
    }
}
impl<const M: i64> Mul for Modulo<M> {
    type Output = Self;
    fn mul(self, other: Self) -> Self {
        // Fits in i64 because both operands are below M < 2^31.5.
        Self::from(self.val * other.val)
    }
}
// Division is defined as multiplication by the reciprocal.
#[allow(clippy::suspicious_arithmetic_impl)]
impl<const M: i64> Div for Modulo<M> {
    type Output = Self;
    fn div(self, other: Self) -> Self {
        self * other.recip()
    }
}

/// Prime modulus that's commonly used in programming competitions
pub const COMMON_PRIME: i64 = 998_244_353; // 2^23 * 7 * 17 + 1;
pub type CommonField = Modulo<COMMON_PRIME>;

/// A dense, row-major matrix of `f64` values.
///
/// Entries live in one flat buffer of `rows * cols` values; row `i` is the
/// slice `inner[i * cols..(i + 1) * cols]`, available as `matrix[i]`.
#[derive(Clone, PartialEq, Debug)]
pub struct Matrix {
    cols: usize,
    inner: Box<[f64]>,
}
impl Matrix {
    /// Returns the `rows x cols` matrix of zeros.
    pub fn zero(rows: usize, cols: usize) -> Self {
        let inner = vec![0.0; rows * cols].into_boxed_slice();
        Self { cols, inner }
    }
    /// Returns the `cols x cols` identity matrix.
    pub fn one(cols: usize) -> Self {
        let mut matrix = Self::zero(cols, cols);
        for i in 0..cols {
            matrix[i][i] = 1.0;
        }
        matrix
    }
    /// Wraps `vec` as a `1 x n` row vector if `as_row`, else as an `n x 1` column.
    pub fn vector(vec: &[f64], as_row: bool) -> Self {
        let cols = if as_row { vec.len() } else { 1 };
        let inner = vec.to_vec().into_boxed_slice();
        Self { cols, inner }
    }
    /// Computes `self^exp` by binary exponentiation (O(log exp) matrix products).
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    pub fn pow(&self, mut exp: u64) -> Self {
        assert_eq!(
            self.rows(),
            self.cols,
            "only square matrices can be raised to a power"
        );
        let mut base = self.clone();
        let mut result = Self::one(self.cols);
        while exp > 0 {
            if exp % 2 == 1 {
                result = &result * &base;
            }
            base = &base * &base;
            exp /= 2;
        }
        result
    }
    /// Returns the number of rows.
    ///
    /// The count is derived from the buffer length, so a matrix with zero
    /// columns (which holds no entries) reports 0 rows instead of dividing by zero.
    pub fn rows(&self) -> usize {
        if self.cols == 0 {
            0
        } else {
            self.inner.len() / self.cols
        }
    }
    /// Returns the `cols x rows` transpose.
    pub fn transpose(&self) -> Self {
        let mut matrix = Matrix::zero(self.cols, self.rows());
        for i in 0..self.rows() {
            for j in 0..self.cols {
                matrix[j][i] = self[i][j];
            }
        }
        matrix
    }
    /// Computes the inverse by Gauss-Jordan elimination with partial pivoting,
    /// in O(n^3) time.
    ///
    /// # Panics
    /// Panics if the matrix is not square, or if elimination meets a pivot
    /// column that is entirely zero (the matrix is exactly singular). Nearly
    /// singular inputs are not detected and can yield inaccurate or
    /// non-finite entries.
    pub fn recip(&self) -> Self {
        let n = self.rows();
        assert_eq!(n, self.cols, "only square matrices can be inverted");
        let mut work = self.clone();
        let mut inv = Self::one(n);
        for col in 0..n {
            // Partial pivoting: the largest-magnitude candidate limits error growth.
            let mut pivot = col;
            for row in col + 1..n {
                if work[row][col].abs() > work[pivot][col].abs() {
                    pivot = row;
                }
            }
            assert!(work[pivot][col] != 0.0, "matrix is singular");
            if pivot != col {
                work.swap_rows(pivot, col);
                inv.swap_rows(pivot, col);
            }
            // Normalize the pivot row so that the pivot becomes 1.
            let pivot_val = work[col][col];
            for j in 0..n {
                work[col][j] /= pivot_val;
                inv[col][j] /= pivot_val;
            }
            // Eliminate this column from every other row.
            for row in 0..n {
                let factor = work[row][col];
                if row == col || factor == 0.0 {
                    continue;
                }
                for j in 0..n {
                    let (w, v) = (work[col][j], inv[col][j]);
                    work[row][j] -= factor * w;
                    inv[row][j] -= factor * v;
                }
            }
        }
        inv
    }
    /// Swaps rows `a` and `b` in place.
    fn swap_rows(&mut self, a: usize, b: usize) {
        let cols = self.cols;
        for j in 0..cols {
            self.inner.swap(a * cols + j, b * cols + j);
        }
    }
    /// Combines two equally-shaped matrices entry by entry with `f`.
    ///
    /// # Panics
    /// Panics if the shapes differ (a plain `zip` would silently truncate).
    fn zip_with(&self, other: &Self, f: impl Fn(f64, f64) -> f64) -> Self {
        assert_eq!(
            (self.rows(), self.cols),
            (other.rows(), other.cols),
            "matrix dimensions must match"
        );
        let inner = self
            .inner
            .iter()
            .zip(other.inner.iter())
            .map(|(&u, &v)| f(u, v))
            .collect();
        Self {
            cols: self.cols,
            inner,
        }
    }
}
impl Index<usize> for Matrix {
    type Output = [f64];
    fn index(&self, row: usize) -> &Self::Output {
        let start = self.cols * row;
        &self.inner[start..start + self.cols]
    }
}
impl IndexMut<usize> for Matrix {
    fn index_mut(&mut self, row: usize) -> &mut Self::Output {
        let start = self.cols * row;
        &mut self.inner[start..start + self.cols]
    }
}
impl Neg for &Matrix {
    type Output = Matrix;
    fn neg(self) -> Matrix {
        let inner = self.inner.iter().map(|&v| -v).collect();
        Matrix {
            cols: self.cols,
            inner,
        }
    }
}
impl Add for &Matrix {
    type Output = Matrix;
    /// Entrywise sum; panics if the shapes differ.
    fn add(self, other: Self) -> Matrix {
        self.zip_with(other, |u, v| u + v)
    }
}
impl Sub for &Matrix {
    type Output = Matrix;
    /// Entrywise difference; panics if the shapes differ.
    fn sub(self, other: Self) -> Matrix {
        self.zip_with(other, |u, v| u - v)
    }
}
impl Mul<f64> for &Matrix {
    type Output = Matrix;
    fn mul(self, scalar: f64) -> Matrix {
        let inner = self.inner.iter().map(|&v| v * scalar).collect();
        Matrix {
            cols: self.cols,
            inner,
        }
    }
}
impl Mul for &Matrix {
    type Output = Matrix;
    fn mul(self, other: Self) -> Matrix {
        assert_eq!(self.cols, other.rows());
        let mut matrix = Matrix::zero(self.rows(), other.cols);
        // i-k-j loop order walks both operands row by row.
        for i in 0..self.rows() {
            for k in 0..self.cols {
                for j in 0..other.cols {
                    matrix[i][j] += self[i][k] * other[k][j];
                }
            }
        }
        matrix
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_rational() {
        let (three, six) = (Rational::from(3), Rational::from(6));
        let seven_halves = three + three / six;

        assert_eq!((seven_halves.num, seven_halves.den), (7, 2));
        assert_eq!(seven_halves, Rational::new(-35, -10));
        assert!(seven_halves > Rational::from(3));
        assert!(seven_halves < Rational::from(4));

        let minus_seven_halves = six - seven_halves + three / (-three / six);
        assert_eq!((minus_seven_halves.num, minus_seven_halves.den), (-7, 2));
        assert_eq!(seven_halves, -minus_seven_halves);

        let sum = seven_halves + minus_seven_halves;
        assert_eq!((sum.num, sum.den), (0, 1));
    }

    #[test]
    fn test_complex() {
        let four = Complex::new(4.0, 0.0);
        let two_i = Complex::new(0.0, 2.0);

        assert_eq!(four / two_i, -two_i);
        assert_eq!(two_i * -two_i, four);
        assert_eq!(two_i - two_i, Complex::from(0.0));

        for &(z, expected) in &[(four, 16.0), (two_i, 4.0)] {
            assert_eq!(z.abs_square(), expected);
        }

        let arguments = [
            (-four, -PI),
            (-two_i, -PI / 2.0),
            (four, 0.0),
            (two_i, PI / 2.0),
        ];
        for &(z, expected) in &arguments {
            assert_eq!(z.argument(), expected);
        }
    }

    #[test]
    fn test_field() {
        let x = CommonField::from(1234);
        let zero = x - x;
        let one = x.recip() * x;
        let two = CommonField::from(2 - 5 * COMMON_PRIME);

        assert_eq!((zero.val, one.val), (0, 1));
        assert_eq!(one + one, two);
        assert_eq!(one / x * (x * x) - x / one, zero);
    }

    #[test]
    fn test_vec_of_recips() {
        let recips = CommonField::vec_of_recips(20);

        assert_eq!(recips.len(), 21);
        for (i, &inverse) in recips.iter().enumerate().skip(1) {
            assert_eq!(inverse, CommonField::from(i as i64).recip());
        }
    }

    #[test]
    fn test_linalg() {
        let zero = Matrix::zero(2, 2);
        let identity = Matrix::one(2);
        let quarter_turn = Matrix {
            cols: 2,
            inner: Box::new([0.0, -1.0, 1.0, 0.0]),
        };
        let e_x = Matrix::vector(&[1.0, 0.0], false);
        let e_y = Matrix::vector(&[0.0, 1.0], false);

        let x_dot_x = &e_x.transpose() * &e_x;
        assert_eq!(x_dot_x, Matrix::one(1));
        assert_eq!(x_dot_x[0][0], 1.0);

        let x_dot_y = &e_x.transpose() * &e_y;
        assert_eq!(x_dot_y, Matrix::zero(1, 1));
        assert_eq!(x_dot_y[0][0], 0.0);

        assert_eq!(&identity - &identity, zero);
        assert_eq!(&identity * 0.0, zero);
        assert_eq!(&quarter_turn * &quarter_turn, -&identity);
        assert_eq!(&quarter_turn * &e_x, e_y);
        assert_eq!(&quarter_turn * &e_y, -&e_x);
        assert_eq!(&quarter_turn * &(&e_x + &e_y), &e_y - &e_x);
    }
}