atlas_embeddings/arithmetic/
matrix.rs1use super::Rational;
7use num_traits::{One, Zero};
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
11#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct RationalVector<const N: usize> {
17 coords: [Rational; N],
19}
20
21impl<const N: usize> RationalVector<N> {
22 #[must_use]
24 pub const fn new(coords: [Rational; N]) -> Self {
25 Self { coords }
26 }
27
28 #[must_use]
30 pub fn zero() -> Self {
31 Self { coords: [Rational::zero(); N] }
32 }
33
34 #[must_use]
36 pub const fn get(&self, i: usize) -> Rational {
37 self.coords[i]
38 }
39
40 #[must_use]
42 pub const fn coords(&self) -> &[Rational; N] {
43 &self.coords
44 }
45
46 #[must_use]
48 pub fn dot(&self, other: &Self) -> Rational {
49 let mut sum = Rational::zero();
50 for i in 0..N {
51 sum += self.coords[i] * other.coords[i];
52 }
53 sum
54 }
55
56 #[must_use]
58 pub fn norm_squared(&self) -> Rational {
59 self.dot(self)
60 }
61
62 #[must_use]
64 pub fn sub(&self, other: &Self) -> Self {
65 let mut result = [Rational::zero(); N];
66 for (i, item) in result.iter_mut().enumerate().take(N) {
67 *item = self.coords[i] - other.coords[i];
68 }
69 Self { coords: result }
70 }
71
72 #[must_use]
74 pub fn scale(&self, scalar: Rational) -> Self {
75 let mut result = [Rational::zero(); N];
76 for (i, item) in result.iter_mut().enumerate().take(N) {
77 *item = self.coords[i] * scalar;
78 }
79 Self { coords: result }
80 }
81}
82
83impl<const N: usize> Hash for RationalVector<N> {
84 fn hash<H: Hasher>(&self, state: &mut H) {
85 for coord in &self.coords {
86 coord.numer().hash(state);
87 coord.denom().hash(state);
88 }
89 }
90}
91
92#[derive(Debug, Clone, PartialEq, Eq)]
99pub struct RationalMatrix<const N: usize> {
100 data: [[Rational; N]; N],
102}
103
104impl<const N: usize> RationalMatrix<N> {
105 #[must_use]
107 pub const fn new(data: [[Rational; N]; N]) -> Self {
108 Self { data }
109 }
110
111 #[must_use]
123 pub fn identity() -> Self {
124 let mut data = [[Rational::zero(); N]; N];
125 for (i, row) in data.iter_mut().enumerate().take(N) {
126 row[i] = Rational::one();
127 }
128 Self { data }
129 }
130
131 #[must_use]
144 pub fn reflection(root: &RationalVector<N>) -> Self {
145 let root_norm_sq = root.norm_squared();
146 assert!(!root_norm_sq.is_zero(), "Cannot create reflection from zero root");
147
148 let mut data = [[Rational::zero(); N]; N];
149
150 for (i, row) in data.iter_mut().enumerate().take(N) {
152 #[allow(clippy::needless_range_loop)]
153 for j in 0..N {
154 let delta = if i == j {
156 Rational::one()
157 } else {
158 Rational::zero()
159 };
160
161 let outer_product = root.get(i) * root.get(j);
163
164 row[j] = delta - Rational::new(2, 1) * outer_product / root_norm_sq;
166 }
167 }
168
169 Self { data }
170 }
171
172 #[must_use]
174 pub const fn get(&self, i: usize, j: usize) -> Rational {
175 self.data[i][j]
176 }
177
178 #[must_use]
180 pub const fn get_ref(&self, i: usize, j: usize) -> &Rational {
181 &self.data[i][j]
182 }
183
184 #[must_use]
186 pub const fn data(&self) -> &[[Rational; N]; N] {
187 &self.data
188 }
189
190 #[must_use]
195 pub fn multiply(&self, other: &Self) -> Self {
196 let mut result = [[Rational::zero(); N]; N];
197
198 for (i, row) in result.iter_mut().enumerate().take(N) {
199 #[allow(clippy::needless_range_loop)]
200 for j in 0..N {
201 let mut sum = Rational::zero();
202 for k in 0..N {
203 sum += self.data[i][k] * other.data[k][j];
204 }
205 row[j] = sum;
206 }
207 }
208
209 Self { data: result }
210 }
211
212 #[must_use]
214 pub fn trace(&self) -> Rational {
215 let mut sum = Rational::zero();
216 for i in 0..N {
217 sum += self.data[i][i];
218 }
219 sum
220 }
221
222 #[must_use]
224 pub fn is_identity(&self) -> bool {
225 for i in 0..N {
226 for j in 0..N {
227 let expected = if i == j {
228 Rational::one()
229 } else {
230 Rational::zero()
231 };
232 if self.data[i][j] != expected {
233 return false;
234 }
235 }
236 }
237 true
238 }
239}
240
241impl<const N: usize> Hash for RationalMatrix<N> {
242 fn hash<H: Hasher>(&self, state: &mut H) {
243 for row in &self.data {
245 for entry in row {
246 entry.numer().hash(state);
247 entry.denom().hash(state);
248 }
249 }
250 }
251}
252
253impl<const N: usize> fmt::Display for RationalMatrix<N> {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 writeln!(f, "[")?;
256 for row in &self.data {
257 write!(f, " [")?;
258 for (j, entry) in row.iter().enumerate() {
259 if j > 0 {
260 write!(f, ", ")?;
261 }
262 write!(f, "{}/{}", entry.numer(), entry.denom())?;
263 }
264 writeln!(f, "]")?;
265 }
266 write!(f, "]")
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Hashes `value` with the standard library's default hasher.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_identity_matrix() {
        let id = RationalMatrix::<3>::identity();
        assert!(id.is_identity());
        assert_eq!(id.trace(), Rational::new(3, 1));
    }

    #[test]
    fn test_matrix_multiply_identity() {
        let id = RationalMatrix::<2>::identity();
        let a = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(3, 4)],
            [Rational::new(5, 6), Rational::new(7, 8)],
        ]);

        let result = a.multiply(&id);
        assert_eq!(result, a);

        let result2 = id.multiply(&a);
        assert_eq!(result2, a);
    }

    #[test]
    fn test_matrix_multiply_exact() {
        let a = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(1, 3)],
            [Rational::new(1, 4), Rational::new(1, 5)],
        ]);

        let b = RationalMatrix::new([
            [Rational::new(2, 1), Rational::new(0, 1)],
            [Rational::new(0, 1), Rational::new(3, 1)],
        ]);

        let result = a.multiply(&b);

        assert_eq!(result.get(0, 0), Rational::new(1, 1));
        assert_eq!(result.get(0, 1), Rational::new(1, 1));
        assert_eq!(result.get(1, 0), Rational::new(1, 2));
        assert_eq!(result.get(1, 1), Rational::new(3, 5));
    }

    #[test]
    fn test_matrix_multiply_general_is_not_commutative() {
        let a = RationalMatrix::new([
            [Rational::new(1, 1), Rational::new(2, 1)],
            [Rational::new(3, 1), Rational::new(4, 1)],
        ]);
        let swap = RationalMatrix::new([
            [Rational::new(0, 1), Rational::new(1, 1)],
            [Rational::new(1, 1), Rational::new(0, 1)],
        ]);

        // Right-multiplying by the swap permutes columns...
        assert_eq!(
            a.multiply(&swap),
            RationalMatrix::new([
                [Rational::new(2, 1), Rational::new(1, 1)],
                [Rational::new(4, 1), Rational::new(3, 1)],
            ])
        );
        // ...left-multiplying permutes rows.
        assert_eq!(
            swap.multiply(&a),
            RationalMatrix::new([
                [Rational::new(3, 1), Rational::new(4, 1)],
                [Rational::new(1, 1), Rational::new(2, 1)],
            ])
        );
    }

    #[test]
    fn test_matrix_equality() {
        let a = RationalMatrix::<2>::identity();
        let b = RationalMatrix::<2>::identity();
        assert_eq!(a, b);

        let c = RationalMatrix::new([
            [Rational::new(1, 1), Rational::new(1, 1)],
            [Rational::new(0, 1), Rational::new(1, 1)],
        ]);
        assert_ne!(a, c);
        assert!(!c.is_identity());
    }

    #[test]
    fn test_matrix_trace() {
        let m = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(3, 4)],
            [Rational::new(5, 6), Rational::new(7, 8)],
        ]);

        assert_eq!(m.trace(), Rational::new(11, 8));
    }

    #[test]
    fn test_reflection_swaps_coordinates() {
        // α = e₀ − e₁, ⟨α, α⟩ = 2, so s_α = I − α αᵀ swaps the first two axes.
        let root = RationalVector::new([
            Rational::new(1, 1),
            Rational::new(-1, 1),
            Rational::new(0, 1),
        ]);
        let s = RationalMatrix::reflection(&root);

        let expected = RationalMatrix::new([
            [Rational::new(0, 1), Rational::new(1, 1), Rational::new(0, 1)],
            [Rational::new(1, 1), Rational::new(0, 1), Rational::new(0, 1)],
            [Rational::new(0, 1), Rational::new(0, 1), Rational::new(1, 1)],
        ]);
        assert_eq!(s, expected);
        assert!(s.multiply(&s).is_identity());
        // A reflection has eigenvalues {−1, 1, …, 1}, so its trace is N − 2.
        assert_eq!(s.trace(), Rational::new(1, 1));
    }

    #[test]
    fn test_reflection_fractional_entries() {
        // α = (1, 2), ⟨α, α⟩ = 5: s_α = I − (2/5)·[[1, 2], [2, 4]].
        let root = RationalVector::new([Rational::new(1, 1), Rational::new(2, 1)]);
        let s = RationalMatrix::reflection(&root);

        assert_eq!(s.get(0, 0), Rational::new(3, 5));
        assert_eq!(s.get(0, 1), Rational::new(-4, 5));
        assert_eq!(s.get(1, 0), Rational::new(-4, 5));
        assert_eq!(s.get(1, 1), Rational::new(-3, 5));
        assert!(s.multiply(&s).is_identity());
        assert_eq!(s.trace(), Rational::new(0, 1));
    }

    #[test]
    #[should_panic(expected = "Cannot create reflection from zero root")]
    fn test_reflection_zero_root_panics() {
        let _ = RationalMatrix::<2>::reflection(&RationalVector::zero());
    }

    #[test]
    fn test_vector_operations() {
        let u = RationalVector::new([Rational::new(1, 2), Rational::new(-1, 3)]);
        let v = RationalVector::new([Rational::new(2, 1), Rational::new(3, 1)]);

        assert_eq!(u.get(1), Rational::new(-1, 3));
        assert_eq!(u.coords(), &[Rational::new(1, 2), Rational::new(-1, 3)]);
        // 1/2·2 + (−1/3)·3 = 0
        assert_eq!(u.dot(&v), Rational::new(0, 1));
        // 1/4 + 1/9 = 13/36
        assert_eq!(u.norm_squared(), Rational::new(13, 36));
        assert_eq!(
            v.sub(&u),
            RationalVector::new([Rational::new(3, 2), Rational::new(10, 3)])
        );
        assert_eq!(
            u.scale(Rational::new(6, 1)),
            RationalVector::new([Rational::new(3, 1), Rational::new(-2, 1)])
        );
        assert_eq!(RationalVector::<2>::zero().norm_squared(), Rational::new(0, 1));
    }

    #[test]
    fn test_equal_values_hash_equal() {
        let a = RationalMatrix::<2>::identity();
        let b = RationalMatrix::new([
            [Rational::new(1, 1), Rational::new(0, 1)],
            [Rational::new(0, 1), Rational::new(1, 1)],
        ]);
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));

        let u = RationalVector::new([Rational::new(2, 4), Rational::new(-3, 6)]);
        let v = RationalVector::new([Rational::new(1, 2), Rational::new(-1, 2)]);
        assert_eq!(u, v);
        assert_eq!(hash_of(&u), hash_of(&v));
    }

    #[test]
    fn test_display_format() {
        let m = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(-3, 4)],
            [Rational::new(0, 1), Rational::new(2, 1)],
        ]);
        assert_eq!(m.to_string(), "[\n [1/2, -3/4]\n [0/1, 2/1]\n]");
    }
}