1pub use std::f64::consts::PI;
5use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
6
/// Returns the greatest common divisor of `a` and `b`, computed with
/// Euclid's algorithm.
///
/// The result is non-negative regardless of the signs of the inputs, and
/// `fast_gcd(0, 0)` is `0`.
pub fn fast_gcd(mut a: i64, mut b: i64) -> i64 {
    while b != 0 {
        let remainder = a % b;
        a = b;
        b = remainder;
    }
    a.abs()
}

/// A rational number `num / den`.
///
/// Values built through [`Rational::new`], `From<i64>` or the arithmetic
/// operators are kept in lowest terms with a strictly positive denominator.
/// The derived `Eq`/`Hash` and the hand-written `Ord` rely on that canonical
/// form; because the fields are public it can be bypassed by constructing the
/// struct directly.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)]
pub struct Rational {
    /// Numerator; carries the sign of the value.
    pub num: i64,
    /// Denominator; positive for canonical values.
    pub den: i64,
}

impl Rational {
    /// Builds the canonical representation of `num / den`.
    ///
    /// # Panics
    ///
    /// Panics if `den` is zero.
    pub fn new(num: i64, den: i64) -> Self {
        assert!(den != 0, "Rational::new: denominator must be non-zero");
        // Dividing by a gcd that carries the sign of `den` both reduces the
        // fraction and moves the sign into the numerator.
        let g = fast_gcd(num, den) * den.signum();
        Self {
            num: num / g,
            den: den / g,
        }
    }

    /// Returns the absolute value.
    pub fn abs(self) -> Self {
        Self {
            num: self.num.abs(),
            den: self.den,
        }
    }

    /// Returns the multiplicative inverse `den / num`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is zero.
    pub fn recip(self) -> Self {
        assert!(self.num != 0, "Rational::recip: cannot invert zero");
        // Divide both parts by the sign of the old numerator so that the new
        // denominator stays positive.
        let sign = self.num.signum();
        Self {
            num: self.den / sign,
            den: self.num / sign,
        }
    }
}

impl From<i64> for Rational {
    /// Converts the integer `num` into the fraction `num / 1`.
    fn from(num: i64) -> Self {
        Self { num, den: 1 }
    }
}

impl Neg for Rational {
    type Output = Self;

    /// Negates the numerator; the denominator is left untouched.
    fn neg(self) -> Self {
        Self {
            num: -self.num,
            den: self.den,
        }
    }
}

// Fraction addition legitimately multiplies inside `add`.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Add for Rational {
    type Output = Self;

    /// Adds over the common denominator `self.den * other.den`, then reduces.
    fn add(self, other: Self) -> Self {
        Self::new(
            self.num * other.den + self.den * other.num,
            self.den * other.den,
        )
    }
}

// Fraction subtraction legitimately multiplies inside `sub`.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Sub for Rational {
    type Output = Self;

    /// Subtracts over the common denominator `self.den * other.den`, then
    /// reduces.
    fn sub(self, other: Self) -> Self {
        Self::new(
            self.num * other.den - self.den * other.num,
            self.den * other.den,
        )
    }
}

impl Mul for Rational {
    type Output = Self;

    /// Multiplies numerators and denominators, then reduces.
    fn mul(self, other: Self) -> Self {
        Self::new(self.num * other.num, self.den * other.den)
    }
}

// Division is implemented as multiplication by the reciprocal.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Div for Rational {
    type Output = Self;

    /// Divides by multiplying with `other.recip()`; panics if `other` is zero.
    fn div(self, other: Self) -> Self {
        self * other.recip()
    }
}

impl Ord for Rational {
    /// Orders two fractions by comparing their cross products, which is a
    /// valid order as long as both denominators are positive.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // The product of two `i64` values always fits in `i128`, so the
        // comparison cannot overflow even for extreme operands.
        let lhs = i128::from(self.num) * i128::from(other.den);
        let rhs = i128::from(self.den) * i128::from(other.num);
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for Rational {
    /// Delegates to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
101
/// A complex number in rectangular form, stored as two `f64` components.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Complex {
    /// Real component.
    pub real: f64,
    /// Imaginary component.
    pub imag: f64,
}

impl Complex {
    /// Builds `real + imag * i`.
    pub fn new(real: f64, imag: f64) -> Self {
        Self { real, imag }
    }

    /// Builds the number with modulus `r` and angle `th` (in radians).
    pub fn from_polar(r: f64, th: f64) -> Self {
        let real = r * th.cos();
        let imag = r * th.sin();
        Self { real, imag }
    }

    /// Returns the squared modulus `real^2 + imag^2`.
    pub fn abs_square(self) -> f64 {
        let Self { real, imag } = self;
        real * real + imag * imag
    }

    /// Returns the angle of the number as computed by `atan2(imag, real)`.
    pub fn argument(self) -> f64 {
        f64::atan2(self.imag, self.real)
    }

    /// Returns the complex conjugate (imaginary part negated).
    pub fn conjugate(self) -> Self {
        Self {
            imag: -self.imag,
            ..self
        }
    }

    /// Returns the multiplicative inverse: the conjugate divided by the
    /// squared modulus.
    pub fn recip(self) -> Self {
        let scale = self.abs_square();
        let conj = self.conjugate();
        Self::new(conj.real / scale, conj.imag / scale)
    }
}

impl From<f64> for Complex {
    /// Embeds a real number as `real + 0i`.
    fn from(real: f64) -> Self {
        Self { real, imag: 0.0 }
    }
}

impl Neg for Complex {
    type Output = Self;

    /// Negates both components.
    fn neg(self) -> Self {
        Self {
            real: -self.real,
            imag: -self.imag,
        }
    }
}

impl Add for Complex {
    type Output = Self;

    /// Component-wise sum.
    fn add(self, rhs: Self) -> Self {
        Self {
            real: self.real + rhs.real,
            imag: self.imag + rhs.imag,
        }
    }
}

impl Sub for Complex {
    type Output = Self;

    /// Component-wise difference.
    fn sub(self, rhs: Self) -> Self {
        Self {
            real: self.real - rhs.real,
            imag: self.imag - rhs.imag,
        }
    }
}

impl Mul for Complex {
    type Output = Self;

    /// Complex product `(a + bi)(c + di) = (ac - bd) + (bc + ad)i`.
    fn mul(self, rhs: Self) -> Self {
        let Complex { real: a, imag: b } = self;
        let Complex { real: c, imag: d } = rhs;
        Self::new(a * c - b * d, b * c + a * d)
    }
}

// Division is implemented as multiplication by the reciprocal.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Div for Complex {
    type Output = Self;

    /// Divides by multiplying with `rhs.recip()`.
    fn div(self, rhs: Self) -> Self {
        self * rhs.recip()
    }
}
167
/// An integer residue modulo the compile-time constant `M`.
///
/// The constructors and operators below keep `val` in `0..M` for a positive
/// `M`. Multiplication forms the product of two residues in `i64`, so `M`
/// must be small enough that `(M - 1)^2` does not overflow. `recip` and
/// division additionally require `M` to be prime.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)]
pub struct Modulo<const M: i64> {
    /// The stored representative of the residue class.
    pub val: i64,
}

impl<const M: i64> Modulo<M> {
    /// Raises `self` to the power `n` by repeated squaring, using a number of
    /// multiplications logarithmic in `n`.
    pub fn pow(mut self, mut n: u64) -> Self {
        let mut acc = Self::from_small(1);
        while n != 0 {
            // Fold in the current square whenever the low exponent bit is set.
            if n & 1 == 1 {
                acc = acc * self;
            }
            self = self * self;
            n >>= 1;
        }
        acc
    }

    /// Returns a table whose entry `i` (for `1 <= i <= n`) is the inverse of
    /// `i`; entry `0` is a zero placeholder.
    ///
    /// The table always contains at least the entries for `0` and `1`.
    pub fn vec_of_recips(n: i64) -> Vec<Self> {
        let mut table = vec![Self::from(0), Self::from(1)];
        for i in 2..=n {
            // With M = q * i + r we get i^-1 = -q * r^-1 (mod M), and the
            // inverse of r < i has already been tabulated.
            let q = M / i;
            let r = (M % i) as usize;
            let inv_r = table[r];
            table.push(inv_r * Self::from_small(-q));
        }
        table
    }

    /// Returns the multiplicative inverse as `self^(M - 2)` (Fermat's little
    /// theorem), which is only an inverse when `M` is prime and `self` is
    /// non-zero.
    pub fn recip(self) -> Self {
        let exponent = M as u64 - 2;
        self.pow(exponent)
    }

    /// Normalises `s` without using `%`; the result is in `0..M` only when
    /// `-M <= s < M`.
    fn from_small(s: i64) -> Self {
        let val = if s >= 0 { s } else { s + M };
        Self { val }
    }
}

impl<const M: i64> From<i64> for Modulo<M> {
    /// Reduces an arbitrary integer, negative values included, into `0..M`.
    fn from(val: i64) -> Self {
        // `%` keeps the sign of the dividend, so the remainder lies strictly
        // between -M and M, which is the range `from_small` accepts.
        let remainder = val % M;
        Self::from_small(remainder)
    }
}

impl<const M: i64> Neg for Modulo<M> {
    type Output = Self;

    /// Additive inverse.
    fn neg(self) -> Self {
        Self::from_small(-self.val)
    }
}

impl<const M: i64> Add for Modulo<M> {
    type Output = Self;

    /// Modular sum; subtracting `M` first brings the raw sum into the range
    /// accepted by `from_small`.
    fn add(self, rhs: Self) -> Self {
        let shifted = self.val + rhs.val - M;
        Self::from_small(shifted)
    }
}

impl<const M: i64> Sub for Modulo<M> {
    type Output = Self;

    /// Modular difference.
    fn sub(self, rhs: Self) -> Self {
        let diff = self.val - rhs.val;
        Self::from_small(diff)
    }
}

impl<const M: i64> Mul for Modulo<M> {
    type Output = Self;

    /// Modular product, reduced with a single `%`.
    fn mul(self, rhs: Self) -> Self {
        let product = self.val * rhs.val;
        Self::from(product)
    }
}

// Division is implemented as multiplication by the reciprocal.
#[allow(clippy::suspicious_arithmetic_impl)]
impl<const M: i64> Div for Modulo<M> {
    type Output = Self;

    /// Divides by multiplying with `rhs.recip()`.
    fn div(self, rhs: Self) -> Self {
        self * rhs.recip()
    }
}

/// The prime `998_244_353 = 119 * 2^23 + 1`.
pub const COMMON_PRIME: i64 = 998_244_353;
/// Residues modulo [`COMMON_PRIME`].
pub type CommonField = Modulo<COMMON_PRIME>;
249
/// A dense matrix of `f64` entries stored in row-major order.
///
/// Only the column count is stored; the row count is derived from the length
/// of the backing slice. Indexing with `matrix[i]` yields row `i` as a slice,
/// so `matrix[i][j]` addresses a single entry.
#[derive(Clone, PartialEq, Debug)]
pub struct Matrix {
    cols: usize,
    inner: Box<[f64]>,
}

impl Matrix {
    /// Creates a `rows x cols` matrix filled with zeros.
    pub fn zero(rows: usize, cols: usize) -> Self {
        let inner = vec![0.0; rows * cols].into_boxed_slice();
        Self { cols, inner }
    }

    /// Creates the `cols x cols` identity matrix.
    pub fn one(cols: usize) -> Self {
        let mut matrix = Self::zero(cols, cols);
        for i in 0..cols {
            matrix[i][i] = 1.0;
        }
        matrix
    }

    /// Wraps `vec` as a `1 x n` row vector when `as_row` is true, otherwise
    /// as an `n x 1` column vector.
    pub fn vector(vec: &[f64], as_row: bool) -> Self {
        let cols = if as_row { vec.len() } else { 1 };
        let inner = vec.to_vec().into_boxed_slice();
        Self { cols, inner }
    }

    /// Raises a square matrix to the power `exp` by repeated squaring;
    /// `pow(0)` is the identity.
    ///
    /// # Panics
    ///
    /// Panics (through the multiplication shape check) if `exp > 0` and the
    /// matrix is not square.
    pub fn pow(&self, mut exp: u64) -> Self {
        let mut base = self.clone();
        let mut result = Self::one(self.cols);
        while exp > 0 {
            if exp % 2 == 1 {
                result = &result * &base;
            }
            exp /= 2;
            // Skip the final squaring: its result would never be used.
            if exp > 0 {
                base = &base * &base;
            }
        }
        result
    }

    /// Returns the number of rows.
    ///
    /// A matrix with zero columns stores no entries and reports zero rows.
    pub fn rows(&self) -> usize {
        self.inner.len().checked_div(self.cols).unwrap_or(0)
    }

    /// Returns the transposed matrix.
    pub fn transpose(&self) -> Self {
        let mut matrix = Matrix::zero(self.cols, self.rows());
        for i in 0..self.rows() {
            for j in 0..self.cols {
                matrix[j][i] = self[i][j];
            }
        }
        matrix
    }

    /// Returns the inverse, computed by Gauss-Jordan elimination with partial
    /// pivoting.
    ///
    /// Singularity is detected only through an exactly zero pivot; a nearly
    /// singular matrix yields a result with large rounding error instead of
    /// a panic.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square or if it is singular.
    pub fn recip(&self) -> Self {
        let n = self.cols;
        assert_eq!(self.rows(), n, "only square matrices can be inverted");
        let mut work = self.clone();
        let mut inverse = Self::one(n);
        for col in 0..n {
            // Partial pivoting: use the entry of largest magnitude at or
            // below the diagonal to limit rounding error.
            let mut pivot_row = col;
            for row in col + 1..n {
                if work[row][col].abs() > work[pivot_row][col].abs() {
                    pivot_row = row;
                }
            }
            let pivot = work[pivot_row][col];
            assert!(pivot != 0.0, "matrix is singular");
            if pivot_row != col {
                work.swap_rows(col, pivot_row);
                inverse.swap_rows(col, pivot_row);
            }
            // Scale the pivot row so that the pivot becomes 1.
            for entry in work[col].iter_mut() {
                *entry /= pivot;
            }
            for entry in inverse[col].iter_mut() {
                *entry /= pivot;
            }
            // Clear this column in every other row, mirroring each row
            // operation on the matrix that accumulates the inverse.
            for row in 0..n {
                if row == col {
                    continue;
                }
                let factor = work[row][col];
                if factor == 0.0 {
                    continue;
                }
                for j in 0..n {
                    let work_delta = factor * work[col][j];
                    let inverse_delta = factor * inverse[col][j];
                    work[row][j] -= work_delta;
                    inverse[row][j] -= inverse_delta;
                }
            }
        }
        inverse
    }

    /// Exchanges rows `a` and `b` in place.
    fn swap_rows(&mut self, a: usize, b: usize) {
        for j in 0..self.cols {
            self.inner.swap(a * self.cols + j, b * self.cols + j);
        }
    }

    /// Applies `op` to every entry, keeping the shape.
    fn map_entries(&self, op: impl Fn(f64) -> f64) -> Self {
        let inner = self.inner.iter().map(|&v| op(v)).collect();
        Self {
            cols: self.cols,
            inner,
        }
    }

    /// Combines two matrices of identical shape entry by entry.
    fn zip_entries(&self, other: &Self, op: impl Fn(f64, f64) -> f64) -> Self {
        // Without these checks `zip` would silently truncate to the shorter
        // operand and return a matrix of the wrong shape.
        assert_eq!(
            self.cols, other.cols,
            "matrix dimensions must match: column counts differ"
        );
        assert_eq!(
            self.inner.len(),
            other.inner.len(),
            "matrix dimensions must match: row counts differ"
        );
        let inner = self
            .inner
            .iter()
            .zip(other.inner.iter())
            .map(|(&u, &v)| op(u, v))
            .collect();
        Self {
            cols: self.cols,
            inner,
        }
    }
}

impl Index<usize> for Matrix {
    type Output = [f64];

    /// Returns row `row` as a slice; panics if the row is out of range.
    fn index(&self, row: usize) -> &Self::Output {
        let start = self.cols * row;
        &self.inner[start..start + self.cols]
    }
}

impl IndexMut<usize> for Matrix {
    /// Returns row `row` as a mutable slice; panics if the row is out of
    /// range.
    fn index_mut(&mut self, row: usize) -> &mut Self::Output {
        let start = self.cols * row;
        &mut self.inner[start..start + self.cols]
    }
}

impl Neg for &Matrix {
    type Output = Matrix;

    /// Negates every entry.
    fn neg(self) -> Matrix {
        self.map_entries(|v| -v)
    }
}

impl Add for &Matrix {
    type Output = Matrix;

    /// Entry-wise sum; panics if the shapes differ.
    fn add(self, other: Self) -> Matrix {
        self.zip_entries(other, |u, v| u + v)
    }
}

impl Sub for &Matrix {
    type Output = Matrix;

    /// Entry-wise difference; panics if the shapes differ.
    fn sub(self, other: Self) -> Matrix {
        self.zip_entries(other, |u, v| u - v)
    }
}

impl Mul<f64> for &Matrix {
    type Output = Matrix;

    /// Multiplies every entry by `scalar`.
    fn mul(self, scalar: f64) -> Matrix {
        self.map_entries(|v| v * scalar)
    }
}

impl Mul for &Matrix {
    type Output = Matrix;

    /// Matrix product; panics unless `self.cols == other.rows()`.
    fn mul(self, other: Self) -> Matrix {
        assert_eq!(self.cols, other.rows());
        let mut matrix = Matrix::zero(self.rows(), other.cols);
        // The i-k-j loop order walks both `other` and the output row by row.
        for i in 0..self.rows() {
            for k in 0..self.cols {
                for j in 0..other.cols {
                    matrix[i][j] += self[i][k] * other[k][j];
                }
            }
        }
        matrix
    }
}
376
#[cfg(test)]
mod test {
    use super::*;

    /// Checks normalisation, ordering and the four operators of `Rational`.
    #[test]
    fn test_rational() {
        let three = Rational::from(3);
        let six = Rational::from(6);

        // 3 + 3/6 must come out as 7/2 in lowest terms.
        let seven_halves = three + three / six;
        assert_eq!((seven_halves.num, seven_halves.den), (7, 2));
        assert_eq!(seven_halves, Rational::new(-35, -10));
        assert!(seven_halves > Rational::from(3));
        assert!(seven_halves < Rational::from(4));

        // 6 - 7/2 + 3 / (-3/6) = 5/2 - 6 = -7/2.
        let minus_seven_halves = six - seven_halves + three / (-three / six);
        assert_eq!((minus_seven_halves.num, minus_seven_halves.den), (-7, 2));
        assert_eq!(seven_halves, -minus_seven_halves);

        // Opposite values cancel to the canonical zero 0/1.
        let sum = seven_halves + minus_seven_halves;
        assert_eq!((sum.num, sum.den), (0, 1));
    }

    /// Checks arithmetic, squared modulus and argument of `Complex`.
    #[test]
    fn test_complex() {
        let four = Complex::new(4.0, 0.0);
        let two_i = Complex::new(0.0, 2.0);

        assert_eq!(four / two_i, -two_i);
        assert_eq!(two_i * -two_i, four);
        assert_eq!(two_i - two_i, Complex::from(0.0));

        assert_eq!(four.abs_square(), 16.0);
        assert_eq!(two_i.abs_square(), 4.0);

        // One sample on each half-axis.
        assert_eq!((-four).argument(), -PI);
        assert_eq!((-two_i).argument(), -PI / 2.0);
        assert_eq!(four.argument(), 0.0);
        assert_eq!(two_i.argument(), PI / 2.0);
    }

    /// Checks the field identities of `CommonField`.
    #[test]
    fn test_field() {
        let base = CommonField::from(1234);
        let zero = base - base;
        let one = base.recip() * base;
        // A negative input far below -M must still reduce to 2.
        let two = CommonField::from(2 - 5 * COMMON_PRIME);

        assert_eq!(zero.val, 0);
        assert_eq!(one.val, 1);
        assert_eq!(one + one, two);
        assert_eq!(one / base * (base * base) - base / one, zero);
    }

    /// Checks the inverse table against inverses computed one at a time.
    #[test]
    fn test_vec_of_recips() {
        let table = CommonField::vec_of_recips(20);

        assert_eq!(table.len(), 21);
        // Entry 0 is only a placeholder, so start at 1.
        for (i, &inverse) in table.iter().enumerate().skip(1) {
            assert_eq!(inverse, CommonField::from(i as i64).recip());
        }
    }

    /// Checks `Matrix` products, sums and scalar operations with small
    /// vectors and a quarter-turn rotation.
    #[test]
    fn test_linalg() {
        let zero = Matrix::zero(2, 2);
        let one = Matrix::one(2);
        let rotate_90 = Matrix {
            cols: 2,
            inner: Box::new([0.0, -1.0, 1.0, 0.0]),
        };
        let x_vec = Matrix::vector(&[1.0, 0.0], false);
        let y_vec = Matrix::vector(&[0.0, 1.0], false);

        // Dot products expressed as (row vector) * (column vector).
        let x_row = x_vec.transpose();
        let x_dot_x = &x_row * &x_vec;
        let x_dot_y = &x_row * &y_vec;
        assert_eq!(x_dot_x, Matrix::one(1));
        assert_eq!(x_dot_x[0][0], 1.0);
        assert_eq!(x_dot_y, Matrix::zero(1, 1));
        assert_eq!(x_dot_y[0][0], 0.0);

        assert_eq!(&one - &one, zero);
        assert_eq!(&one * 0.0, zero);

        // Two quarter turns make a half turn; x maps to y and y to -x.
        assert_eq!(&rotate_90 * &rotate_90, -&one);
        assert_eq!(&rotate_90 * &x_vec, y_vec);
        assert_eq!(&rotate_90 * &y_vec, -&x_vec);
        assert_eq!(&rotate_90 * &(&x_vec + &y_vec), &y_vec - &x_vec);
    }
}