// atlas_embeddings/arithmetic/matrix.rs

//! Exact rational matrices and vectors for Weyl group elements
//!
//! This module provides matrix and vector representations for Weyl group elements,
//! using exact rational arithmetic to preserve mathematical correctness.

6use super::Rational;
7use num_traits::{One, Zero};
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
11/// N-dimensional vector with exact rational coordinates
12///
13/// Used for simple roots and Weyl group operations in rank-N space.
14/// All coordinates are rational numbers for exact arithmetic.
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct RationalVector<const N: usize> {
17    /// Vector coordinates
18    coords: [Rational; N],
19}
20
21impl<const N: usize> RationalVector<N> {
22    /// Create vector from array of rationals
23    #[must_use]
24    pub const fn new(coords: [Rational; N]) -> Self {
25        Self { coords }
26    }
27
28    /// Create zero vector
29    #[must_use]
30    pub fn zero() -> Self {
31        Self { coords: [Rational::zero(); N] }
32    }
33
34    /// Get coordinate at index
35    #[must_use]
36    pub const fn get(&self, i: usize) -> Rational {
37        self.coords[i]
38    }
39
40    /// Get all coordinates
41    #[must_use]
42    pub const fn coords(&self) -> &[Rational; N] {
43        &self.coords
44    }
45
46    /// Inner product (exact rational arithmetic)
47    #[must_use]
48    pub fn dot(&self, other: &Self) -> Rational {
49        let mut sum = Rational::zero();
50        for i in 0..N {
51            sum += self.coords[i] * other.coords[i];
52        }
53        sum
54    }
55
56    /// Norm squared: ⟨v, v⟩
57    #[must_use]
58    pub fn norm_squared(&self) -> Rational {
59        self.dot(self)
60    }
61
62    /// Vector subtraction
63    #[must_use]
64    pub fn sub(&self, other: &Self) -> Self {
65        let mut result = [Rational::zero(); N];
66        for (i, item) in result.iter_mut().enumerate().take(N) {
67            *item = self.coords[i] - other.coords[i];
68        }
69        Self { coords: result }
70    }
71
72    /// Scalar multiplication
73    #[must_use]
74    pub fn scale(&self, scalar: Rational) -> Self {
75        let mut result = [Rational::zero(); N];
76        for (i, item) in result.iter_mut().enumerate().take(N) {
77            *item = self.coords[i] * scalar;
78        }
79        Self { coords: result }
80    }
81}
82
83impl<const N: usize> Hash for RationalVector<N> {
84    fn hash<H: Hasher>(&self, state: &mut H) {
85        for coord in &self.coords {
86            coord.numer().hash(state);
87            coord.denom().hash(state);
88        }
89    }
90}
91
92/// Matrix with exact rational entries
93///
94/// Used to represent Weyl group elements as matrices. All operations
95/// use exact rational arithmetic (no floating point).
96///
97/// From certified Python implementation: `ExactMatrix` class
98#[derive(Debug, Clone, PartialEq, Eq)]
99pub struct RationalMatrix<const N: usize> {
100    /// Matrix data: N×N array of rational numbers
101    data: [[Rational; N]; N],
102}
103
104impl<const N: usize> RationalMatrix<N> {
105    /// Create matrix from 2D array
106    #[must_use]
107    pub const fn new(data: [[Rational; N]; N]) -> Self {
108        Self { data }
109    }
110
111    /// Create identity matrix
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use atlas_embeddings::arithmetic::{RationalMatrix, Rational};
117    ///
118    /// let id = RationalMatrix::<2>::identity();
119    /// assert_eq!(id.get(0, 0), Rational::new(1, 1));
120    /// assert_eq!(id.get(0, 1), Rational::new(0, 1));
121    /// ```
122    #[must_use]
123    pub fn identity() -> Self {
124        let mut data = [[Rational::zero(); N]; N];
125        for (i, row) in data.iter_mut().enumerate().take(N) {
126            row[i] = Rational::one();
127        }
128        Self { data }
129    }
130
131    /// Create reflection matrix from root vector
132    ///
133    /// Implements: `R_α = I - 2(α ⊗ α)/⟨α,α⟩`
134    ///
135    /// This is the matrix representation of the reflection through the
136    /// hyperplane perpendicular to α. Uses exact rational arithmetic.
137    ///
138    /// From certified Python implementation: `simple_reflection()` method
139    ///
140    /// # Panics
141    ///
142    /// Panics if root has norm² = 0
143    #[must_use]
144    pub fn reflection(root: &RationalVector<N>) -> Self {
145        let root_norm_sq = root.norm_squared();
146        assert!(!root_norm_sq.is_zero(), "Cannot create reflection from zero root");
147
148        let mut data = [[Rational::zero(); N]; N];
149
150        // Compute I - 2(α ⊗ α)/⟨α,α⟩
151        for (i, row) in data.iter_mut().enumerate().take(N) {
152            #[allow(clippy::needless_range_loop)]
153            for j in 0..N {
154                // Identity matrix entry
155                let delta = if i == j {
156                    Rational::one()
157                } else {
158                    Rational::zero()
159                };
160
161                // Outer product entry: α_i * α_j
162                let outer_product = root.get(i) * root.get(j);
163
164                // Matrix entry: δ_ij - 2 * α_i * α_j / ⟨α,α⟩
165                row[j] = delta - Rational::new(2, 1) * outer_product / root_norm_sq;
166            }
167        }
168
169        Self { data }
170    }
171
172    /// Get entry at (i, j)
173    #[must_use]
174    pub const fn get(&self, i: usize, j: usize) -> Rational {
175        self.data[i][j]
176    }
177
178    /// Get reference to entry at (i, j)
179    #[must_use]
180    pub const fn get_ref(&self, i: usize, j: usize) -> &Rational {
181        &self.data[i][j]
182    }
183
184    /// Get all data as reference
185    #[must_use]
186    pub const fn data(&self) -> &[[Rational; N]; N] {
187        &self.data
188    }
189
190    /// Matrix multiplication (exact rational arithmetic)
191    ///
192    /// Computes C = A × B where all operations are exact.
193    /// This is the composition operation for Weyl group elements.
194    #[must_use]
195    pub fn multiply(&self, other: &Self) -> Self {
196        let mut result = [[Rational::zero(); N]; N];
197
198        for (i, row) in result.iter_mut().enumerate().take(N) {
199            #[allow(clippy::needless_range_loop)]
200            for j in 0..N {
201                let mut sum = Rational::zero();
202                for k in 0..N {
203                    sum += self.data[i][k] * other.data[k][j];
204                }
205                row[j] = sum;
206            }
207        }
208
209        Self { data: result }
210    }
211
212    /// Compute trace (sum of diagonal elements)
213    #[must_use]
214    pub fn trace(&self) -> Rational {
215        let mut sum = Rational::zero();
216        for i in 0..N {
217            sum += self.data[i][i];
218        }
219        sum
220    }
221
222    /// Check if this is the identity matrix
223    #[must_use]
224    pub fn is_identity(&self) -> bool {
225        for i in 0..N {
226            for j in 0..N {
227                let expected = if i == j {
228                    Rational::one()
229                } else {
230                    Rational::zero()
231                };
232                if self.data[i][j] != expected {
233                    return false;
234                }
235            }
236        }
237        true
238    }
239}
240
241impl<const N: usize> Hash for RationalMatrix<N> {
242    fn hash<H: Hasher>(&self, state: &mut H) {
243        // Hash each entry (numerator and denominator)
244        for row in &self.data {
245            for entry in row {
246                entry.numer().hash(state);
247                entry.denom().hash(state);
248            }
249        }
250    }
251}
252
253impl<const N: usize> fmt::Display for RationalMatrix<N> {
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        writeln!(f, "[")?;
256        for row in &self.data {
257            write!(f, "  [")?;
258            for (j, entry) in row.iter().enumerate() {
259                if j > 0 {
260                    write!(f, ", ")?;
261                }
262                write!(f, "{}/{}", entry.numer(), entry.denom())?;
263            }
264            writeln!(f, "]")?;
265        }
266        write!(f, "]")
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Hash a value with the standard library's default hasher.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_identity_matrix() {
        let id = RationalMatrix::<3>::identity();
        assert!(id.is_identity());
        assert_eq!(id.trace(), Rational::new(3, 1));
    }

    #[test]
    fn test_matrix_multiply_identity() {
        let id = RationalMatrix::<2>::identity();
        let a = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(3, 4)],
            [Rational::new(5, 6), Rational::new(7, 8)],
        ]);

        let result = a.multiply(&id);
        assert_eq!(result, a);

        let result2 = id.multiply(&a);
        assert_eq!(result2, a);
    }

    #[test]
    fn test_matrix_multiply_exact() {
        // Simple 2x2 multiplication
        let a = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(1, 3)],
            [Rational::new(1, 4), Rational::new(1, 5)],
        ]);

        let b = RationalMatrix::new([
            [Rational::new(2, 1), Rational::new(0, 1)],
            [Rational::new(0, 1), Rational::new(3, 1)],
        ]);

        let result = a.multiply(&b);

        // Expected: [1/2*2 + 1/3*0,  1/2*0 + 1/3*3] = [1, 1]
        //           [1/4*2 + 1/5*0,  1/4*0 + 1/5*3] = [1/2, 3/5]
        assert_eq!(result.get(0, 0), Rational::new(1, 1));
        assert_eq!(result.get(0, 1), Rational::new(1, 1));
        assert_eq!(result.get(1, 0), Rational::new(1, 2));
        assert_eq!(result.get(1, 1), Rational::new(3, 5));
    }

    #[test]
    fn test_matrix_multiply_order() {
        // Matrix multiplication is not commutative; this pins the row-by-column convention.
        let a = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(1, 3)],
            [Rational::new(1, 4), Rational::new(1, 5)],
        ]);
        let b = RationalMatrix::new([
            [Rational::new(2, 1), Rational::new(1, 1)],
            [Rational::new(0, 1), Rational::new(3, 1)],
        ]);

        // A × B = [[1, 3/2], [1/2, 17/20]]
        let ab = RationalMatrix::new([
            [Rational::new(1, 1), Rational::new(3, 2)],
            [Rational::new(1, 2), Rational::new(17, 20)],
        ]);
        assert_eq!(a.multiply(&b), ab);

        // B × A = [[5/4, 13/15], [3/4, 3/5]]
        let ba = RationalMatrix::new([
            [Rational::new(5, 4), Rational::new(13, 15)],
            [Rational::new(3, 4), Rational::new(3, 5)],
        ]);
        assert_eq!(b.multiply(&a), ba);
    }

    #[test]
    fn test_matrix_equality() {
        let a = RationalMatrix::<2>::identity();
        let b = RationalMatrix::<2>::identity();
        assert_eq!(a, b);

        let c = RationalMatrix::new([
            [Rational::new(1, 1), Rational::new(1, 1)],
            [Rational::new(0, 1), Rational::new(1, 1)],
        ]);
        assert_ne!(a, c);
    }

    #[test]
    fn test_matrix_trace() {
        let m = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(3, 4)],
            [Rational::new(5, 6), Rational::new(7, 8)],
        ]);

        // Trace = 1/2 + 7/8 = 4/8 + 7/8 = 11/8
        assert_eq!(m.trace(), Rational::new(11, 8));
    }

    #[test]
    fn test_reflection_swaps_coordinates() {
        // α = e₀ - e₁ has ⟨α,α⟩ = 2, so R_α = I - αα^T exchanges the first two coordinates.
        let root = RationalVector::new([
            Rational::new(1, 1),
            Rational::new(-1, 1),
            Rational::new(0, 1),
        ]);
        let r = RationalMatrix::reflection(&root);

        let (zero, one) = (Rational::new(0, 1), Rational::new(1, 1));
        let expected = RationalMatrix::new([
            [zero, one, zero],
            [one, zero, zero],
            [zero, zero, one],
        ]);
        assert_eq!(r, expected);

        // A reflection has eigenvalues (-1, 1, ..., 1), so its trace is N - 2.
        assert_eq!(r.trace(), Rational::new(1, 1));
    }

    #[test]
    fn test_reflection_is_involution_and_negates_root() {
        let root = RationalVector::new([
            Rational::new(1, 2),
            Rational::new(1, 3),
            Rational::new(-2, 1),
        ]);
        let r = RationalMatrix::reflection(&root);

        assert!(!r.is_identity());
        assert!(r.multiply(&r).is_identity());
        assert_eq!(r.trace(), Rational::new(1, 1));

        // (R_α α)_i = Σ_j R_ij α_j must equal -α_i.
        for (i, row) in r.data().iter().enumerate() {
            let mut image = Rational::new(0, 1);
            for (entry, coord) in row.iter().zip(root.coords()) {
                image += *entry * *coord;
            }
            assert_eq!(image, Rational::new(0, 1) - root.get(i));
        }
    }

    #[test]
    #[should_panic(expected = "Cannot create reflection from zero root")]
    fn test_reflection_zero_root_panics() {
        let _ = RationalMatrix::reflection(&RationalVector::<2>::zero());
    }

    #[test]
    fn test_vector_operations() {
        let u = RationalVector::new([
            Rational::new(1, 2),
            Rational::new(-3, 1),
            Rational::new(2, 3),
        ]);
        let v = RationalVector::new([
            Rational::new(1, 4),
            Rational::new(1, 1),
            Rational::new(1, 3),
        ]);

        // 1/8 - 3 + 2/9 = -191/72
        assert_eq!(u.dot(&v), Rational::new(-191, 72));
        // 1/4 + 9 + 4/9 = 349/36
        assert_eq!(u.norm_squared(), Rational::new(349, 36));
        assert_eq!(
            u.sub(&v),
            RationalVector::new([
                Rational::new(1, 4),
                Rational::new(-4, 1),
                Rational::new(1, 3),
            ])
        );
        assert_eq!(
            u.scale(Rational::new(2, 1)),
            RationalVector::new([
                Rational::new(1, 1),
                Rational::new(-6, 1),
                Rational::new(4, 3),
            ])
        );
        assert_eq!(u.sub(&u), RationalVector::zero());
    }

    #[test]
    fn test_hash_matches_equality() {
        // The manual `Hash` impls hash numerator/denominator, so equal values built
        // from unreduced fractions must still hash identically.
        let a = RationalMatrix::<2>::identity();
        let b = RationalMatrix::new([
            [Rational::new(2, 2), Rational::new(0, 5)],
            [Rational::new(0, 3), Rational::new(3, 3)],
        ]);
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));

        let u = RationalVector::new([Rational::new(2, 4), Rational::new(-3, 6)]);
        let v = RationalVector::new([Rational::new(1, 2), Rational::new(-1, 2)]);
        assert_eq!(u, v);
        assert_eq!(hash_of(&u), hash_of(&v));
    }

    #[test]
    fn test_display_format() {
        let m = RationalMatrix::new([
            [Rational::new(1, 2), Rational::new(-3, 4)],
            [Rational::new(0, 1), Rational::new(5, 1)],
        ]);
        assert_eq!(m.to_string(), "[\n  [1/2, -3/4]\n  [0/1, 5/1]\n]");
    }
}