1use crate::{Matrix, Vector};
2use core::ops::{Index, IndexMut, Neg};
3use num_traits::{Num, One, Zero, Inv};
4
5impl<T, const N: usize> Matrix<T, N, N>
6where
7 T: Zero,
8{
9 pub fn diagonal(diag: Vector<T, N>) -> Self {
11 let mut iter = diag.into_iter();
12 Matrix::indices().map(|(i, j)| {
13 if i == j {
14 iter.next().unwrap()
15 } else {
16 T::zero()
17 }
18 })
19 }
20}
21impl<T, const N: usize> One for Matrix<T, N, N>
22where
23 T: One + Zero,
24{
25 fn one() -> Self {
27 Matrix::indices().map(|(i, j)| if i == j { T::one() } else { T::zero() })
28 }
29}
30
/// Fixed-size collection of `N` boolean flags, one per index.
///
/// A flag that is `true` marks its index as enabled; `false` marks it as
/// disabled. Flags are read and written through `mask[i]`.
struct IndexMask<const N: usize> {
    data: [bool; N],
}

impl<const N: usize> IndexMask<N> {
    /// Creates a mask in which every one of the `N` indices is enabled.
    pub fn new() -> Self {
        Self { data: [true; N] }
    }

    /// Returns the smallest enabled index that is greater than or equal to `i`.
    ///
    /// # Panics
    ///
    /// Panics if no index at or after `i` is enabled (this includes the case
    /// `i >= N`). Callers are expected to know that such an index exists.
    pub fn find(&self, i: usize) -> usize {
        self.data
            .iter()
            // `skip` never panics, even when `i` is past the end; the search
            // then simply finds nothing and the `expect` below reports it.
            .skip(i)
            .position(|&enabled| enabled)
            // `position` counts from the first non-skipped element.
            .map(|offset| i + offset)
            .expect("IndexMask::find: no enabled index at or after the starting position")
    }
}

impl<const N: usize> Index<usize> for IndexMask<N> {
    type Output = bool;

    /// Returns the flag stored for index `i`; panics if `i >= N`.
    fn index(&self, i: usize) -> &bool {
        &self.data[i]
    }
}

impl<const N: usize> IndexMut<usize> for IndexMask<N> {
    /// Returns a mutable reference to the flag for index `i`; panics if `i >= N`.
    fn index_mut(&mut self, i: usize) -> &mut bool {
        &mut self.data[i]
    }
}
76
77struct SubmatrixMask<const N: usize> {
78 pub col: IndexMask<N>,
79 pub row: IndexMask<N>,
80 pub deg: usize,
81}
82
83impl<const N: usize> SubmatrixMask<N> {
84 fn new() -> Self {
85 Self {
86 col: IndexMask::new(),
87 row: IndexMask::new(),
88 deg: N,
89 }
90 }
91 fn exclude(&mut self, i: usize, j: usize) {
92 self.col[i] = false;
93 self.row[j] = false;
94 self.deg -= 1;
95 }
96 fn include(&mut self, i: usize, j: usize) {
97 self.col[i] = true;
98 self.row[j] = true;
99 self.deg += 1;
100 }
101}
102
103struct Determinator<'a, T, const N: usize> {
104 matrix: &'a Matrix<T, N, N>,
105 mask: SubmatrixMask<N>,
106}
107
108impl<'a, T, const N: usize> Determinator<'a, T, N>
109where
110 T: Neg<Output = T> + Num + Copy,
111{
112 fn new(matrix: &'a Matrix<T, N, N>) -> Self {
113 Self {
114 matrix,
115 mask: SubmatrixMask::new(),
116 }
117 }
118 fn cofactor(&mut self, (i, ri): (usize, usize), (j, rj): (usize, usize)) -> T {
119 self.mask.exclude(i, j);
120 let mut a = self.det();
121 if (ri + rj) % 2 != 0 {
122 a = -a;
123 }
124 self.mask.include(i, j);
125 a
126 }
127 fn det(&mut self) -> T {
128 if self.mask.deg == 0 {
129 T::one()
130 } else {
131 let i = self.mask.col.find(0);
132 let mut j = 0;
133 let mut a = T::zero();
134 for rj in 0..self.mask.deg {
135 j = self.mask.row.find(j);
136 a = a + self.matrix[(i, j)] * self.cofactor((i, 0), (j, rj));
137 j += 1;
138 }
139 a
140 }
141 }
142}
143
144impl<T, const N: usize> Matrix<T, N, N>
145where
146 T: Neg<Output = T> + Num + Copy,
147{
148 pub fn cofactor(&self, i: usize, j: usize) -> T {
150 assert!(i < N && j < N);
151 Determinator::new(self).cofactor((i, i), (j, j))
152 }
153
154 pub fn det(&self) -> T {
156 Determinator::new(self).det()
157 }
158
159 pub fn adj(&self) -> Self {
161 Matrix::indices().map(|(i, j)| self.cofactor(j, i))
162 }
163
164 pub fn inv(&self) -> Self {
166 self.adj() / self.det()
167 }
168}
169
170impl<T, const N: usize> Inv for Matrix<T, N, N>
171where
172 T: Neg<Output = T> + Num + Copy,
173{
174 type Output = Self;
175
176 fn inv(self) -> Self::Output {
177 self.adj() / self.det()
178 }
179}