//! rustframes/array/linalg.rs — linear-algebra routines for `Array<f64>`.

1use super::Array;
2
3impl Array<f64> {
4    /// Matrix multiplication (dot product)
5    pub fn dot(&self, other: &Array<f64>) -> Array<f64> {
6        match (self.ndim(), other.ndim()) {
7            (2, 2) => self.matmul_2d(other),
8            (1, 1) => self.dot_1d(other),
9            (2, 1) => self.matvec(other),
10            (1, 2) => self.vecmat(other),
11            _ => panic!("Unsupported dimensions for dot product"),
12        }
13    }
14
15    /// 1D vector dot product
16    fn dot_1d(&self, other: &Array<f64>) -> Array<f64> {
17        assert_eq!(self.shape[0], other.shape[0], "Vector lengths must match");
18        let result = self
19            .data
20            .iter()
21            .zip(other.data.iter())
22            .map(|(&a, &b)| a * b)
23            .sum::<f64>();
24        Array::from_vec(vec![result], vec![1])
25    }
26
27    /// 2D matrix multiplication
28    fn matmul_2d(&self, other: &Array<f64>) -> Array<f64> {
29        assert_eq!(
30            self.shape[1], other.shape[0],
31            "Matrix dimensions incompatible for multiplication"
32        );
33
34        let (m, k) = (self.shape[0], self.shape[1]);
35        let n = other.shape[1];
36        let mut result = Array::zeros(vec![m, n]);
37
38        for i in 0..m {
39            for j in 0..n {
40                let mut sum = 0.0;
41                for l in 0..k {
42                    sum += self[(i, l)] * other[(l, j)];
43                }
44                result[(i, j)] = sum;
45            }
46        }
47        result
48    }
49
50    /// Matrix-vector multiplication
51    fn matvec(&self, other: &Array<f64>) -> Array<f64> {
52        assert_eq!(
53            self.shape[1], other.shape[0],
54            "Matrix-vector dimensions incompatible"
55        );
56
57        let m = self.shape[0];
58        let mut result = Array::zeros(vec![m]);
59
60        for i in 0..m {
61            let mut sum = 0.0;
62            for j in 0..self.shape[1] {
63                sum += self[(i, j)] * other.data[j];
64            }
65            result.data[i] = sum;
66        }
67        result
68    }
69
70    /// Vector-matrix multiplication
71    fn vecmat(&self, other: &Array<f64>) -> Array<f64> {
72        assert_eq!(
73            self.shape[0], other.shape[0],
74            "Vector-matrix dimensions incompatible"
75        );
76
77        let n = other.shape[1];
78        let mut result = Array::zeros(vec![n]);
79
80        for j in 0..n {
81            let mut sum = 0.0;
82            for i in 0..other.shape[0] {
83                sum += self[&[{ i }][..]] * other[&[{ i }, { j }][..]];
84            }
85            result[&[{ j }][..]] = sum;
86        }
87        result
88    }
89
90    /// Matrix determinant (2x2 and 3x3 only)
91    pub fn det(&self) -> f64 {
92        assert_eq!(self.ndim(), 2, "Determinant requires 2D matrix");
93        assert_eq!(self.shape[0], self.shape[1], "Matrix must be square");
94
95        match self.shape[0] {
96            1 => self.data[0],
97            2 => self[(0, 0)] * self[(1, 1)] - self[(0, 1)] * self[(1, 0)],
98            3 => {
99                let a = self[(0, 0)];
100                let b = self[(0, 1)];
101                let c = self[(0, 2)];
102                let d = self[(1, 0)];
103                let e = self[(1, 1)];
104                let f = self[(1, 2)];
105                let g = self[(2, 0)];
106                let h = self[(2, 1)];
107                let i = self[(2, 2)];
108
109                a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
110            }
111            _ => panic!("Determinant only implemented for 1x1, 2x2, and 3x3 matrices"),
112        }
113    }
114
115    /// Matrix trace (sum of diagonal elements)
116    pub fn trace(&self) -> f64 {
117        assert_eq!(self.ndim(), 2, "Trace requires 2D matrix");
118        let min_dim = self.shape[0].min(self.shape[1]);
119
120        (0..min_dim).map(|i| self[(i, i)]).sum()
121    }
122
123    /// Matrix inverse (2x2 only for now)
124    pub fn inv(&self) -> Option<Array<f64>> {
125        assert_eq!(self.ndim(), 2, "Inverse requires 2D matrix");
126        assert_eq!(self.shape[0], self.shape[1], "Matrix must be square");
127
128        match self.shape[0] {
129            1 => {
130                let elem = self[(0, 0)];
131                if elem.abs() < 1e-10 {
132                    None
133                } else {
134                    Some(Array::from_vec(vec![1.0 / elem], vec![1, 1]))
135                }
136            }
137            2 => {
138                let det = self.det();
139                if det.abs() < 1e-10 {
140                    return None; // Matrix is singular
141                }
142
143                let a = self[(0, 0)];
144                let b = self[(0, 1)];
145                let c = self[(1, 0)];
146                let d = self[(1, 1)];
147
148                let inv_det = 1.0 / det;
149                let data = vec![d * inv_det, -b * inv_det, -c * inv_det, a * inv_det];
150
151                Some(Array::from_vec(data, vec![2, 2]))
152            }
153            _ => {
154                panic!("Matrix inverse only implemented for 1x1 and 2x2 matrices");
155            }
156        }
157    }
158
159    /// QR decomposition (Gram-Schmidt process)
160    pub fn qr(&self) -> (Array<f64>, Array<f64>) {
161        assert_eq!(self.ndim(), 2, "QR decomposition requires 2D matrix");
162        let (m, n) = (self.shape[0], self.shape[1]);
163
164        let mut q = Array::zeros(vec![m, n]);
165        let mut r = Array::zeros(vec![n, n]);
166
167        // Gram-Schmidt process
168        for j in 0..n {
169            // Get j-th column of A
170            let mut v = Vec::new();
171            for i in 0..m {
172                v.push(self[(i, j)]);
173            }
174
175            // Subtract projections of previous columns
176            for k in 0..j {
177                let mut dot_product = 0.0;
178                for i in 0..m {
179                    dot_product += v[i] * q[(i, k)];
180                }
181                r[(k, j)] = dot_product;
182
183                for i in 0..m {
184                    v[i] -= dot_product * q[(i, k)];
185                }
186            }
187
188            // Normalize
189            let norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
190            r[(j, j)] = norm;
191
192            if norm > 1e-10 {
193                for i in 0..m {
194                    q[(i, j)] = v[i] / norm;
195                }
196            }
197        }
198
199        (q, r)
200    }
201
202    /// Solve linear system Ax = b using QR decomposition
203    pub fn solve(&self, b: &Array<f64>) -> Option<Array<f64>> {
204        assert_eq!(self.ndim(), 2, "A must be 2D matrix");
205        assert_eq!(b.ndim(), 1, "b must be 1D vector");
206        assert_eq!(self.shape[0], b.shape[0], "Dimensions incompatible");
207        assert_eq!(self.shape[0], self.shape[1], "A must be square");
208
209        let n = self.shape[0];
210        let (q, r) = self.qr();
211
212        // Solve Rx = Q^T * b
213        let mut qtb = Array::zeros(vec![n]);
214        for i in 0..n {
215            let mut sum = 0.0;
216            for j in 0..n {
217                sum += q[(j, i)] * b[&[j][..]];
218            }
219            qtb[&[i][..]] = sum;
220        }
221
222        // Back substitution
223        let mut x = Array::zeros(vec![n]);
224        for i in (0..n).rev() {
225            let mut sum = qtb[&[i][..]];
226            for j in (i + 1)..n {
227                sum -= r[(i, j)] * x[&[j][..]];
228            }
229
230            if r[(i, i)].abs() < 1e-10 {
231                return None; // Singular matrix
232            }
233
234            x[&[i][..]] = sum / r[(i, i)];
235        }
236
237        Some(x)
238    }
239
240    /// Compute eigenvalues and eigenvectors (2x2 only)
241    pub fn eig(&self) -> Option<(Array<f64>, Array<f64>)> {
242        assert_eq!(self.ndim(), 2, "Eigendecomposition requires 2D matrix");
243        assert_eq!(self.shape[0], self.shape[1], "Matrix must be square");
244
245        if self.shape[0] == 2 {
246            let a = self[(0, 0)];
247            let b = self[(0, 1)];
248            let c = self[(1, 0)];
249            let d = self[(1, 1)];
250
251            // Characteristic polynomial: λ² - (a+d)λ + (ad-bc) = 0
252            let trace = a + d;
253            let det = a * d - b * c;
254            let discriminant = trace * trace - 4.0 * det;
255
256            if discriminant < 0.0 {
257                return None; // Complex eigenvalues
258            }
259
260            let sqrt_disc = discriminant.sqrt();
261            let lambda1 = (trace + sqrt_disc) / 2.0;
262            let lambda2 = (trace - sqrt_disc) / 2.0;
263
264            let eigenvalues = Array::from_vec(vec![lambda1, lambda2], vec![2]);
265
266            // Compute eigenvectors
267            let mut eigenvectors = Array::zeros(vec![2, 2]);
268
269            // Eigenvector for λ₁
270            if b.abs() > 1e-10 {
271                eigenvectors[(0, 0)] = 1.0;
272                eigenvectors[(1, 0)] = (lambda1 - a) / b;
273            } else if c.abs() > 1e-10 {
274                eigenvectors[(0, 0)] = (lambda1 - d) / c;
275                eigenvectors[(1, 0)] = 1.0;
276            } else {
277                eigenvectors[(0, 0)] = 1.0;
278                eigenvectors[(1, 0)] = 0.0;
279            }
280
281            // Eigenvector for λ₂
282            if b.abs() > 1e-10 {
283                eigenvectors[(0, 1)] = 1.0;
284                eigenvectors[(1, 1)] = (lambda2 - a) / b;
285            } else if c.abs() > 1e-10 {
286                eigenvectors[(0, 1)] = (lambda2 - d) / c;
287                eigenvectors[(1, 1)] = 1.0;
288            } else {
289                eigenvectors[(0, 1)] = 0.0;
290                eigenvectors[(1, 1)] = 1.0;
291            }
292
293            // Normalize eigenvectors
294            for j in 0..2 {
295                let norm = (eigenvectors[(0, j)].powi(2) + eigenvectors[(1, j)].powi(2)).sqrt();
296                if norm > 1e-10 {
297                    eigenvectors[(0, j)] /= norm;
298                    eigenvectors[(1, j)] /= norm;
299                }
300            }
301
302            Some((eigenvalues, eigenvectors))
303        } else {
304            panic!("Eigendecomposition only implemented for 2x2 matrices");
305        }
306    }
307
308    /// Singular Value Decomposition (SVD) - simplified version for 2x2 matrices
309    pub fn svd(&self) -> (Array<f64>, Array<f64>, Array<f64>) {
310        assert_eq!(self.ndim(), 2, "SVD requires 2D matrix");
311
312        if self.shape[0] == 2 && self.shape[1] == 2 {
313            // For 2x2 matrix A, compute A^T * A and A * A^T
314            let at = self.transpose();
315            let ata = at.dot(self);
316            let aat = self.dot(&at);
317
318            // Get eigenvalues and eigenvectors
319            let (s_squared, v) = ata.eig().expect("Failed to compute eigenvalues");
320            let (_, u) = aat.eig().expect("Failed to compute eigenvalues");
321
322            // Singular values are square roots of eigenvalues
323            let s1 = s_squared[&[0][..]].max(0.0).sqrt();
324            let s2 = s_squared[&[1][..]].max(0.0).sqrt();
325            let s = Array::from_vec(vec![s1, s2], vec![2]);
326
327            (u, s, v.transpose())
328        } else {
329            panic!("SVD only implemented for 2x2 matrices");
330        }
331    }
332
333    /// Compute matrix norm (Frobenius norm)
334    pub fn norm(&self) -> f64 {
335        self.data.iter().map(|&x| x * x).sum::<f64>().sqrt()
336    }
337
338    /// Compute matrix rank (approximate, using SVD)
339    pub fn rank(&self, tolerance: Option<f64>) -> usize {
340        let tol = tolerance.unwrap_or(1e-10);
341
342        if self.shape[0] == 2 && self.shape[1] == 2 {
343            let (_, s, _) = self.svd();
344            s.data.iter().filter(|&&x| x > tol).count()
345        } else {
346            // For non-square matrices, use determinant-based approach
347            let min_dim = self.shape[0].min(self.shape[1]);
348            if min_dim == 1 {
349                if self.data.iter().any(|&x| x.abs() > tol) {
350                    1
351                } else {
352                    0
353                }
354            } else {
355                // Simplified: just check if determinant is non-zero for square submatrices
356                if self.shape[0] == self.shape[1] {
357                    if self.det().abs() > tol {
358                        self.shape[0]
359                    } else {
360                        self.shape[0] - 1
361                    }
362                } else {
363                    min_dim // Conservative estimate
364                }
365            }
366        }
367    }
368
369    /// Check if matrix is symmetric
370    pub fn is_symmetric(&self, tolerance: Option<f64>) -> bool {
371        assert_eq!(self.ndim(), 2, "Symmetry check requires 2D matrix");
372        assert_eq!(self.shape[0], self.shape[1], "Matrix must be square");
373
374        let tol = tolerance.unwrap_or(1e-10);
375        let n = self.shape[0];
376
377        for i in 0..n {
378            for j in 0..n {
379                if (self[(i, j)] - self[(j, i)]).abs() > tol {
380                    return false;
381                }
382            }
383        }
384        true
385    }
386
387    /// Check if matrix is orthogonal
388    pub fn is_orthogonal(&self, tolerance: Option<f64>) -> bool {
389        assert_eq!(self.ndim(), 2, "Orthogonality check requires 2D matrix");
390        assert_eq!(self.shape[0], self.shape[1], "Matrix must be square");
391
392        let tol = tolerance.unwrap_or(1e-10);
393        let at = self.transpose();
394        let should_be_identity = self.dot(&at);
395        let n = self.shape[0];
396
397        // Check if A * A^T = I
398        for i in 0..n {
399            for j in 0..n {
400                let expected = if i == j { 1.0 } else { 0.0 };
401                if (should_be_identity[(i, j)] - expected).abs() > tol {
402                    return false;
403                }
404            }
405        }
406        true
407    }
408}