jax_rs/ops/linalg.rs

//! Linear algebra operations.

use crate::{buffer::Buffer, Array, DType, Device, Shape};

impl Array {
    /// Transpose the array by reversing its axes.
    ///
    /// For 2D arrays, this swaps rows and columns.
    /// For 0D and 1D arrays, returns the array unchanged.
    /// For higher dimensions, reverses all axes.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let b = a.transpose();
    /// assert_eq!(b.shape().as_slice(), &[3, 2]);
    /// ```
    pub fn transpose(&self) -> Array {
        let shape = self.shape();
        let dims = shape.as_slice();

        if dims.len() <= 1 {
            // For scalars and 1D arrays, transpose is identity
            return self.clone();
        }

        // Reverse the dimensions
        let new_dims: Vec<usize> = dims.iter().rev().copied().collect();
        let new_shape = Shape::new(new_dims);

        // For 2D case, implement explicit transpose
        if dims.len() == 2 {
            let (rows, cols) = (dims[0], dims[1]);
            let data = self.to_vec();
            let mut transposed = vec![0.0; data.len()];

            for i in 0..rows {
                for j in 0..cols {
                    transposed[j * rows + i] = data[i * cols + j];
                }
            }

            let buffer = Buffer::from_f32(transposed, Device::Cpu);
            return Array::from_buffer(buffer, new_shape);
        }

        // For higher dimensions, use general algorithm
        transpose_nd(self, new_shape)
    }

    /// Transpose the array with a specified permutation of axes.
    ///
    /// # Arguments
    ///
    /// * `axes` - The new order of axes
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let b = a.transpose_axes(&[1, 0]);
    /// assert_eq!(b.shape().as_slice(), &[3, 2]);
    /// ```
    pub fn transpose_axes(&self, axes: &[usize]) -> Array {
        let shape = self.shape();
        let dims = shape.as_slice();
        let ndim = dims.len();

        assert_eq!(axes.len(), ndim, "axes must have same length as dimensions");

        // Verify axes is a valid permutation
        let mut seen = vec![false; ndim];
        for &axis in axes {
            assert!(axis < ndim, "axis {} out of bounds for {} dimensions", axis, ndim);
            assert!(!seen[axis], "duplicate axis in permutation");
            seen[axis] = true;
        }

        // If axes is identity permutation, return self
        if axes.iter().enumerate().all(|(i, &a)| i == a) {
            return self.clone();
        }

        // Compute new shape
        let new_dims: Vec<usize> = axes.iter().map(|&a| dims[a]).collect();
        let new_shape = Shape::new(new_dims.clone());

        // Simple 2D case
        if ndim == 2 && axes == [1, 0] {
            return self.transpose();
        }

        // General case: compute transposed data
        let data = self.to_vec();
        let size = data.len();
        let mut result = vec![0.0; size];

        // Compute strides for original array
        let mut old_strides = vec![1usize; ndim];
        for i in (0..ndim - 1).rev() {
            old_strides[i] = old_strides[i + 1] * dims[i + 1];
        }

        // Compute strides for new array
        let mut new_strides = vec![1usize; ndim];
        for i in (0..ndim - 1).rev() {
            new_strides[i] = new_strides[i + 1] * new_dims[i + 1];
        }

        // Map strides according to permutation
        let perm_strides: Vec<usize> = axes.iter().map(|&a| old_strides[a]).collect();

        // Copy data with transposition
        for new_idx in 0..size {
            // Convert flat index to multi-index in new array
            let mut remaining = new_idx;
            let mut old_idx = 0;
            for i in 0..ndim {
                let coord = remaining / new_strides[i];
                remaining %= new_strides[i];
                old_idx += coord * perm_strides[i];
            }
            result[new_idx] = data[old_idx];
        }

        let buffer = Buffer::from_f32(result, Device::Cpu);
        Array::from_buffer(buffer, new_shape)
    }

    /// Matrix multiplication.
    ///
    /// Supports 2D @ 2D, plus the 1D cases: (N,) @ (N, M) -> (M,) and
    /// (M, N) @ (N,) -> (M,).
    ///
    /// # Arguments
    ///
    /// * `other` - The right-hand array to multiply with
    ///
    /// # Panics
    ///
    /// Panics if shapes are incompatible, the arrays are not 1D/2D, or the
    /// operands live on different devices.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], Shape::new(vec![2, 2]));
    /// let c = a.matmul(&b);
    /// // [[1*5 + 2*7, 1*6 + 2*8],
    /// //  [3*5 + 4*7, 3*6 + 4*8]]
    /// assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    /// ```
    pub fn matmul(&self, other: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        // Handle vector-matrix and matrix-vector cases
        if a_shape.len() == 1 && b_shape.len() == 2 {
            // Vector-matrix: (N,) @ (N, M) -> (M,)
            assert_eq!(
                a_shape[0], b_shape[0],
                "Vector-matrix multiplication: incompatible shapes"
            );
            return self
                .reshape(Shape::new(vec![1, a_shape[0]]))
                .matmul(other)
                .reshape(Shape::new(vec![b_shape[1]]));
        }

        if a_shape.len() == 2 && b_shape.len() == 1 {
            // Matrix-vector: (M, N) @ (N,) -> (M,)
            assert_eq!(
                a_shape[1], b_shape[0],
                "Matrix-vector multiplication: incompatible shapes"
            );
            return self
                .matmul(&other.reshape(Shape::new(vec![b_shape[0], 1])))
                .reshape(Shape::new(vec![a_shape[0]]));
        }

        // Matrix-matrix multiplication
        assert_eq!(a_shape.len(), 2, "Left array must be 2D");
        assert_eq!(b_shape.len(), 2, "Right array must be 2D");
        assert_eq!(
            a_shape[1], b_shape[0],
            "Incompatible shapes for matmul: {:?} @ {:?}",
            a_shape, b_shape
        );

        let (m, k) = (a_shape[0], a_shape[1]);
        let n = b_shape[1];

        // Dispatch based on device
        match (self.device(), other.device()) {
            (Device::WebGpu, Device::WebGpu) => {
                // GPU path
                let output_buffer = Buffer::zeros(m * n, DType::Float32, Device::WebGpu);

                crate::backend::ops::gpu_matmul(
                    self.buffer(),
                    other.buffer(),
                    &output_buffer,
                    m,
                    n,
                    k,
                );

                Array::from_buffer(output_buffer, Shape::new(vec![m, n]))
            }
            (Device::Cpu, Device::Cpu) | (Device::Wasm, Device::Wasm) => {
                // CPU path - naive O(n^3) algorithm
                let a_data = self.to_vec();
                let b_data = other.to_vec();
                let mut result = vec![0.0; m * n];

                for i in 0..m {
                    for j in 0..n {
                        let mut sum = 0.0;
                        for p in 0..k {
                            sum += a_data[i * k + p] * b_data[p * n + j];
                        }
                        result[i * n + j] = sum;
                    }
                }

                let buffer = Buffer::from_f32(result, Device::Cpu);
                Array::from_buffer(buffer, Shape::new(vec![m, n]))
            }
            _ => {
                panic!("Mixed device operations not supported. Both arrays must be on the same device.");
            }
        }
    }

    /// Dot product of two arrays.
    ///
    /// For 1D arrays, computes the inner product.
    /// For 2D arrays, equivalent to matmul.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
    /// let c = a.dot(&b);
    /// assert_eq!(c.to_vec(), vec![32.0]); // 1*4 + 2*5 + 3*6 = 32
    /// ```
    pub fn dot(&self, other: &Array) -> Array {
        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        // 1D dot product (inner product)
        if a_shape.len() == 1 && b_shape.len() == 1 {
            assert_eq!(
                a_shape[0], b_shape[0],
                "Arrays must have same length for dot product"
            );

            let a_data = self.to_vec();
            let b_data = other.to_vec();
            let result: f32 =
                a_data.iter().zip(b_data.iter()).map(|(a, b)| a * b).sum();

            let buffer = Buffer::from_f32(vec![result], Device::Cpu);
            return Array::from_buffer(buffer, Shape::scalar());
        }

        // For higher dimensions, use matmul
        self.matmul(other)
    }

    /// Compute the norm of a vector or matrix.
    ///
    /// # Arguments
    ///
    /// * `ord` - Order of the norm. Common values:
    ///   - `1.0`: L1 norm (sum of absolute values)
    ///   - `2.0`: L2 norm (Euclidean norm)
    ///   - `f32::INFINITY`: L-infinity norm (maximum absolute value)
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
    /// let l2_norm = a.norm(2.0);
    /// assert_eq!(l2_norm, 5.0); // sqrt(3^2 + 4^2) = 5
    /// ```
    pub fn norm(&self, ord: f32) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();

        if ord == f32::INFINITY {
            // L-infinity norm: max absolute value
            data.iter().map(|x| x.abs()).fold(0.0, f32::max)
        } else if ord == 1.0 {
            // L1 norm: sum of absolute values
            data.iter().map(|x| x.abs()).sum()
        } else if ord == 2.0 {
            // L2 norm: Euclidean norm
            data.iter().map(|x| x * x).sum::<f32>().sqrt()
        } else {
            // General Lp norm: (sum |x|^p)^(1/p)
            data.iter()
                .map(|x| x.abs().powf(ord))
                .sum::<f32>()
                .powf(1.0 / ord)
        }
    }

    /// Compute the determinant of a square matrix.
    ///
    /// Uses LU decomposition for matrices larger than 3x3.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let det = a.det();
    /// assert_eq!(det, -2.0); // 1*4 - 2*3 = -2
    /// ```
    pub fn det(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Determinant requires 2D array");
        assert_eq!(shape[0], shape[1], "Determinant requires square matrix");

        let n = shape[0];
        let data = self.to_vec();

        match n {
            1 => data[0],
            2 => {
                // 2x2: ad - bc
                data[0] * data[3] - data[1] * data[2]
            }
            3 => {
                // 3x3: Sarrus rule
                let a = data[0];
                let b = data[1];
                let c = data[2];
                let d = data[3];
                let e = data[4];
                let f = data[5];
                let g = data[6];
                let h = data[7];
                let i = data[8];
                a * e * i + b * f * g + c * d * h - c * e * g - b * d * i - a * f * h
            }
            _ => {
                // For larger matrices, use LU decomposition
                let (_, u, p) = self.lu_decomposition();
                let u_data = u.to_vec();

                // det(A) = det(P) * det(L) * det(U)
                // det(L) = 1 (unit diagonal)
                // det(U) = product of diagonal elements
                // det(P) = (-1)^(number of swaps)
                let mut det_u = 1.0;
                for i in 0..n {
                    det_u *= u_data[i * n + i];
                }

                // Count permutation parity
                let mut swaps = 0;
                let mut visited = vec![false; n];
                for i in 0..n {
                    if !visited[i] {
                        let mut j = i;
                        let mut cycle_len = 0;
                        while !visited[j] {
                            visited[j] = true;
                            j = p[j];
                            cycle_len += 1;
                        }
                        if cycle_len > 1 {
                            swaps += cycle_len - 1;
                        }
                    }
                }

                if swaps % 2 == 0 {
                    det_u
                } else {
                    -det_u
                }
            }
        }
    }

    /// LU decomposition with partial pivoting.
    ///
    /// Returns (L, U, P) where:
    /// - L is lower triangular with unit diagonal
    /// - U is upper triangular
    /// - P is permutation array (row swaps)
    fn lu_decomposition(&self) -> (Array, Array, Vec<usize>) {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "LU decomposition requires 2D array");
        assert_eq!(shape[0], shape[1], "LU decomposition requires square matrix");

        let n = shape[0];
        let data = self.to_vec();

        // Initialize permutation
        let mut p: Vec<usize> = (0..n).collect();
        let mut a = data.clone();

        for k in 0..n {
            // Find pivot
            let mut pivot_row = k;
            let mut max_val = a[k * n + k].abs();
            for i in (k + 1)..n {
                let val = a[i * n + k].abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
            }

            // Swap rows if needed
            if pivot_row != k {
                p.swap(k, pivot_row);
                for j in 0..n {
                    a.swap(k * n + j, pivot_row * n + j);
                }
            }

            // Zero pivot: the matrix is singular in this column, so there is
            // nothing to eliminate (skipping avoids dividing by zero below).
            if a[k * n + k].abs() < 1e-12 {
                continue;
            }

            // Eliminate column
            for i in (k + 1)..n {
                let factor = a[i * n + k] / a[k * n + k];
                a[i * n + k] = factor; // Store L factor in lower triangle
                for j in (k + 1)..n {
                    a[i * n + j] -= factor * a[k * n + j];
                }
            }
        }

        // Extract L and U
        let mut l_data = vec![0.0; n * n];
        let mut u_data = vec![0.0; n * n];

        for i in 0..n {
            for j in 0..n {
                if i > j {
                    l_data[i * n + j] = a[i * n + j];
                } else if i == j {
                    l_data[i * n + j] = 1.0;
                    u_data[i * n + j] = a[i * n + j];
                } else {
                    u_data[i * n + j] = a[i * n + j];
                }
            }
        }

        let l = Array::from_vec(l_data, Shape::new(vec![n, n]));
        let u = Array::from_vec(u_data, Shape::new(vec![n, n]));

        (l, u, p)
    }

    /// Compute the matrix inverse.
    ///
    /// Uses Gauss-Jordan elimination.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![4.0, 7.0, 2.0, 6.0], Shape::new(vec![2, 2]));
    /// let inv_a = a.inv();
    /// // Verify A * A^-1 = I
    /// let identity = a.matmul(&inv_a);
    /// let expected = vec![1.0, 0.0, 0.0, 1.0];
    /// for (i, &val) in identity.to_vec().iter().enumerate() {
    ///     assert!((val - expected[i]).abs() < 1e-5);
    /// }
    /// ```
    pub fn inv(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Matrix inversion requires 2D array");
        assert_eq!(shape[0], shape[1], "Matrix inversion requires square matrix");

        let n = shape[0];
        let data = self.to_vec();

        // Create augmented matrix [A | I]
        let mut aug = vec![0.0; n * 2 * n];
        for i in 0..n {
            for j in 0..n {
                aug[i * 2 * n + j] = data[i * n + j];
            }
            aug[i * 2 * n + n + i] = 1.0; // Identity on the right
        }

        // Gauss-Jordan elimination
        for k in 0..n {
            // Find pivot
            let mut pivot_row = k;
            let mut max_val = aug[k * 2 * n + k].abs();
            for i in (k + 1)..n {
                let val = aug[i * 2 * n + k].abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
            }

            assert!(
                max_val > 1e-10,
                "Matrix is singular and cannot be inverted"
            );

            // Swap rows
            if pivot_row != k {
                for j in 0..(2 * n) {
                    aug.swap(k * 2 * n + j, pivot_row * 2 * n + j);
                }
            }

            // Scale pivot row
            let pivot = aug[k * 2 * n + k];
            for j in 0..(2 * n) {
                aug[k * 2 * n + j] /= pivot;
            }

            // Eliminate column
            for i in 0..n {
                if i != k {
                    let factor = aug[i * 2 * n + k];
                    for j in 0..(2 * n) {
                        aug[i * 2 * n + j] -= factor * aug[k * 2 * n + j];
                    }
                }
            }
        }

        // Extract inverse from right half
        let mut inv_data = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                inv_data[i * n + j] = aug[i * 2 * n + n + j];
            }
        }

        Array::from_vec(inv_data, Shape::new(vec![n, n]))
    }

    /// Solve a linear system Ax = b.
    ///
    /// Uses Gaussian elimination with partial pivoting.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// // Solve: 2x + y = 5, x + 3y = 6
    /// let a = Array::from_vec(vec![2.0, 1.0, 1.0, 3.0], Shape::new(vec![2, 2]));
    /// let b = Array::from_vec(vec![5.0, 6.0], Shape::new(vec![2]));
    /// let x = a.solve(&b);
    /// // Solution: x = [1.8, 1.4]
    /// assert!((x.to_vec()[0] - 1.8).abs() < 1e-5);
    /// assert!((x.to_vec()[1] - 1.4).abs() < 1e-5);
    /// ```
    pub fn solve(&self, b: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(b.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = b.shape().as_slice();

        assert_eq!(a_shape.len(), 2, "A must be 2D");
        assert_eq!(a_shape[0], a_shape[1], "A must be square");
        assert_eq!(b_shape.len(), 1, "b must be 1D");
        assert_eq!(a_shape[0], b_shape[0], "Incompatible dimensions");

        let n = a_shape[0];
        let a_data = self.to_vec();
        let b_data = b.to_vec();

        // Create augmented matrix [A | b]
        let mut aug = vec![0.0; n * (n + 1)];
        for i in 0..n {
            for j in 0..n {
                aug[i * (n + 1) + j] = a_data[i * n + j];
            }
            aug[i * (n + 1) + n] = b_data[i];
        }

        // Forward elimination with partial pivoting
        for k in 0..n {
            // Find pivot
            let mut pivot_row = k;
            let mut max_val = aug[k * (n + 1) + k].abs();
            for i in (k + 1)..n {
                let val = aug[i * (n + 1) + k].abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
            }

            assert!(
                max_val > 1e-10,
                "Matrix is singular, system has no unique solution"
            );

            // Swap rows
            if pivot_row != k {
                for j in 0..(n + 1) {
                    aug.swap(k * (n + 1) + j, pivot_row * (n + 1) + j);
                }
            }

            // Eliminate below
            for i in (k + 1)..n {
                let factor = aug[i * (n + 1) + k] / aug[k * (n + 1) + k];
                for j in k..(n + 1) {
                    aug[i * (n + 1) + j] -= factor * aug[k * (n + 1) + j];
                }
            }
        }

        // Back substitution
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut sum = aug[i * (n + 1) + n];
            for j in (i + 1)..n {
                sum -= aug[i * (n + 1) + j] * x[j];
            }
            x[i] = sum / aug[i * (n + 1) + i];
        }

        Array::from_vec(x, Shape::new(vec![n]))
    }

    /// Compute the outer product of two 1D arrays.
    ///
    /// Given two 1D arrays a and b, returns a 2D array of shape (a.len(), b.len())
    /// where result[i, j] = a[i] * b[j].
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let b = Array::from_vec(vec![4.0, 5.0], Shape::new(vec![2]));
    /// let c = a.outer(&b);
    /// assert_eq!(c.shape().as_slice(), &[3, 2]);
    /// // [[1*4, 1*5], [2*4, 2*5], [3*4, 3*5]]
    /// // = [[4, 5], [8, 10], [12, 15]]
    /// assert_eq!(c.to_vec(), vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
    /// ```
    pub fn outer(&self, other: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        assert_eq!(a_shape.len(), 1, "First array must be 1D");
        assert_eq!(b_shape.len(), 1, "Second array must be 1D");

        let a_data = self.to_vec();
        let b_data = other.to_vec();
        let m = a_shape[0];
        let n = b_shape[0];

        let mut result = Vec::with_capacity(m * n);
        for &a_val in a_data.iter() {
            for &b_val in b_data.iter() {
                result.push(a_val * b_val);
            }
        }

        Array::from_vec(result, Shape::new(vec![m, n]))
    }

    /// Compute the inner product of two 1D arrays.
    ///
    /// For 1D arrays, this is the same as dot product: sum of element-wise products.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
    /// let result = a.inner(&b);
    /// assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6 = 32
    /// ```
    pub fn inner(&self, other: &Array) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        assert_eq!(a_shape.len(), 1, "First array must be 1D");
        assert_eq!(b_shape.len(), 1, "Second array must be 1D");
        assert_eq!(
            a_shape[0], b_shape[0],
            "Arrays must have same length for inner product"
        );

        let a_data = self.to_vec();
        let b_data = other.to_vec();

        a_data.iter().zip(b_data.iter()).map(|(a, b)| a * b).sum()
    }

    /// Compute the cross product of two 3D vectors.
    ///
    /// Returns a vector perpendicular to both input vectors.
    /// Formula: a × b = [a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0]
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
    /// let b = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
    /// let c = a.cross(&b);
    /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 1.0]); // i × j = k
    /// ```
    pub fn cross(&self, other: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        assert_eq!(a_shape.len(), 1, "First array must be 1D");
        assert_eq!(b_shape.len(), 1, "Second array must be 1D");
        assert_eq!(a_shape[0], 3, "Cross product requires 3D vectors");
        assert_eq!(b_shape[0], 3, "Cross product requires 3D vectors");

        let a = self.to_vec();
        let b = other.to_vec();

        let result = vec![
            a[1] * b[2] - a[2] * b[1],  // i component
            a[2] * b[0] - a[0] * b[2],  // j component
            a[0] * b[1] - a[1] * b[0],  // k component
        ];

        Array::from_vec(result, Shape::new(vec![3]))
    }

    /// Compute the trace of a 2D array (sum of diagonal elements).
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let tr = a.trace();
    /// assert_eq!(tr, 5.0); // 1 + 4 = 5
    /// ```
    pub fn trace(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Trace requires 2D array");
        assert_eq!(shape[0], shape[1], "Trace requires square matrix");

        let n = shape[0];
        let data = self.to_vec();

        let mut sum = 0.0;
        for i in 0..n {
            sum += data[i * n + i];
        }
        sum
    }

    /// Extract the diagonal of a 2D array.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let diag = a.diagonal();
    /// assert_eq!(diag.to_vec(), vec![1.0, 5.0]); // Elements at (0,0) and (1,1)
    /// ```
    pub fn diagonal(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Diagonal requires 2D array");

        let (rows, cols) = (shape[0], shape[1]);
        let diag_len = rows.min(cols);
        let data = self.to_vec();

        let mut result = Vec::with_capacity(diag_len);
        for i in 0..diag_len {
            result.push(data[i * cols + i]);
        }

        Array::from_vec(result, Shape::new(vec![diag_len]))
    }

    /// Generate a Vandermonde matrix.
    ///
    /// Each row contains the successive powers 0..n of the corresponding input
    /// element, increasing from left to right.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let x = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let v = x.vander(4);
    /// // [[1, 1, 1, 1], [1, 2, 4, 8], [1, 3, 9, 27]]
    /// assert_eq!(&v.to_vec()[4..8], &[1.0, 2.0, 4.0, 8.0]);
    /// ```
    pub fn vander(&self, n: usize) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 1, "vander() only supports 1D arrays");

        let x = self.to_vec();
        let m = x.len();
        let mut result = Vec::with_capacity(m * n);

        for &val in x.iter() {
            for pow in 0..n {
                result.push(val.powi(pow as i32));
            }
        }

        Array::from_vec(result, Shape::new(vec![m, n]))
    }

    /// QR decomposition of a matrix.
    ///
    /// Decomposes matrix A into Q (orthogonal) and R (upper triangular) such that A = QR.
    /// Uses the Gram-Schmidt process.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let (q, r) = a.qr();
    /// // Q is orthogonal, R is upper triangular
    /// // Q * R ≈ A
    /// ```
    pub fn qr(&self) -> (Array, Array) {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "QR decomposition requires 2D array");

        let (m, n) = (shape[0], shape[1]);
        let data = self.to_vec();

        // Initialize Q and R
        let mut q = vec![0.0; m * n];
        let mut r = vec![0.0; n * n];

        // Modified Gram-Schmidt
        for j in 0..n {
            // Copy column j of A into v
            let mut v: Vec<f32> = (0..m).map(|i| data[i * n + j]).collect();

            // Orthogonalize against previous columns
            for i in 0..j {
                // r[i,j] = q[:,i] . v
                let mut dot = 0.0;
                for k in 0..m {
                    dot += q[k * n + i] * v[k];
                }
                r[i * n + j] = dot;

                // v = v - r[i,j] * q[:,i]
                for k in 0..m {
                    v[k] -= dot * q[k * n + i];
                }
            }

            // r[j,j] = ||v||
            let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
            r[j * n + j] = norm;

            // q[:,j] = v / norm
            if norm > 1e-10 {
                for k in 0..m {
                    q[k * n + j] = v[k] / norm;
                }
            }
        }

        let q_arr = Array::from_vec(q, Shape::new(vec![m, n]));
        let r_arr = Array::from_vec(r, Shape::new(vec![n, n]));

        (q_arr, r_arr)
    }

    /// Cholesky decomposition of a symmetric positive-definite matrix.
    ///
    /// Decomposes matrix A into L such that A = L * L^T, where L is lower triangular.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square or not positive-definite.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// // Symmetric positive-definite matrix
    /// let a = Array::from_vec(vec![4.0, 2.0, 2.0, 3.0], Shape::new(vec![2, 2]));
    /// let l = a.cholesky();
    /// // L * L^T = A
    /// ```
    pub fn cholesky(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Cholesky decomposition requires 2D array");
        assert_eq!(shape[0], shape[1], "Cholesky decomposition requires square matrix");

        let n = shape[0];
        let data = self.to_vec();
        let mut l = vec![0.0; n * n];

        for i in 0..n {
            for j in 0..=i {
                let mut sum = 0.0;

                if j == i {
                    // Diagonal element
                    for k in 0..j {
                        sum += l[j * n + k] * l[j * n + k];
                    }
                    let val = data[j * n + j] - sum;
                    assert!(val > 0.0, "Matrix is not positive-definite");
                    l[j * n + j] = val.sqrt();
                } else {
                    // Off-diagonal element
                    for k in 0..j {
                        sum += l[i * n + k] * l[j * n + k];
                    }
                    l[i * n + j] = (data[i * n + j] - sum) / l[j * n + j];
                }
            }
        }

        Array::from_vec(l, Shape::new(vec![n, n]))
    }

    /// Estimate the rank of a matrix.
    ///
    /// Computes a QR decomposition and counts the diagonal entries of R whose
    /// magnitude exceeds a tolerance.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 2.0, 4.0], Shape::new(vec![2, 2]));
    /// let rank = a.matrix_rank();
    /// assert_eq!(rank, 1); // Rows are linearly dependent
    /// ```
    pub fn matrix_rank(&self) -> usize {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "matrix_rank requires 2D array");

        let (m, n) = (shape[0], shape[1]);
        // Loose enough to absorb f32 rounding in the Gram-Schmidt QR step.
        let tolerance = 1e-5;

        // Use QR decomposition and count non-zero diagonal elements of R
        let (_, r) = self.qr();
        let r_data = r.to_vec();
        let min_dim = m.min(n);

        let mut rank = 0;
        for i in 0..min_dim {
            if r_data[i * n + i].abs() > tolerance {
                rank += 1;
            }
        }

        rank
    }

    /// Compute eigenvalues of a symmetric matrix using the power method.
    ///
    /// Returns approximate eigenvalues for symmetric matrices.
    /// For non-symmetric matrices, results may not be accurate.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![2.0, 1.0, 1.0, 2.0], Shape::new(vec![2, 2]));
    /// let eigvals = a.eigvalsh();
    /// // Eigenvalues of this symmetric matrix are 3 and 1, largest found first
    /// let v = eigvals.to_vec();
    /// assert!((v[0] - 3.0).abs() < 1e-3);
    /// assert!((v[1] - 1.0).abs() < 1e-3);
    /// ```
    pub fn eigvalsh(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "eigvalsh requires 2D array");
        assert_eq!(shape[0], shape[1], "eigvalsh requires square matrix");

        let n = shape[0];
        let mut eigenvalues = Vec::with_capacity(n);
        let mut a = self.clone();

        // Use deflation with power iteration
        for _ in 0..n {
            // Power iteration to find largest eigenvalue. Start from a unit
            // vector with unequal entries: an all-ones start can be exactly
            // orthogonal to the dominant eigenvector and stall the iteration.
            let mut v_data: Vec<f32> =
                (0..n).map(|i| ((i as f32 + 1.0) * 0.9).sin() + 0.1).collect();
            let seed_norm: f32 = v_data.iter().map(|x| x * x).sum::<f32>().sqrt();
            for x in v_data.iter_mut() {
                *x /= seed_norm;
            }

            for _ in 0..100 {
                // v = A * v
                let v = Array::from_vec(v_data.clone(), Shape::new(vec![n]));
                let av = a.matmul(&v.reshape(Shape::new(vec![n, 1])))
                    .reshape(Shape::new(vec![n]));
                let av_data = av.to_vec();

                // Normalize
                let norm: f32 = av_data.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm < 1e-10 {
                    break;
                }
                v_data = av_data.iter().map(|x| x / norm).collect();
            }

            // Rayleigh quotient: λ = (v^T A v) / (v^T v), with ||v|| = 1
            let v = Array::from_vec(v_data.clone(), Shape::new(vec![n]));
            let av = a.matmul(&v.reshape(Shape::new(vec![n, 1])))
                .reshape(Shape::new(vec![n]));
            let eigenvalue = v.inner(&av);
            eigenvalues.push(eigenvalue);

            // Deflate: A = A - λ * v * v^T
            let mut a_data = a.to_vec();
            for i in 0..n {
                for j in 0..n {
                    a_data[i * n + j] -= eigenvalue * v_data[i] * v_data[j];
                }
            }
            a = Array::from_vec(a_data, Shape::new(vec![n, n]));
        }

        Array::from_vec(eigenvalues, Shape::new(vec![n]))
    }

    /// Compute the Moore-Penrose pseudo-inverse of a matrix.
    ///
    /// For full-rank inputs, uses (A^T A)^-1 A^T when m >= n and
    /// A^T (A A^T)^-1 otherwise.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let pinv = a.pinv();
    /// assert_eq!(pinv.shape().as_slice(), &[3, 2]);
    /// ```
    pub fn pinv(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "pinv requires 2D array");

        let (m, n) = (shape[0], shape[1]);

        if m >= n {
            // A^+ = (A^T A)^-1 A^T
            let at = self.transpose();
            let ata = at.matmul(self);
            let ata_inv = ata.inv();
            ata_inv.matmul(&at)
        } else {
            // A^+ = A^T (A A^T)^-1
            let at = self.transpose();
            let aat = self.matmul(&at);
            let aat_inv = aat.inv();
            at.matmul(&aat_inv)
        }
    }

    /// Compute the condition number of a matrix.
    ///
    /// Uses the ratio of the largest to smallest singular value estimate.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0], Shape::new(vec![2, 2]));
    /// let cond = a.cond();
    /// // Identity matrix has condition number ~1 (within numerical tolerance)
    /// assert!((cond - 1.0).abs() < 0.1);
    /// ```
    pub fn cond(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "cond requires 2D array");

        // For small square matrices, use direct norm-based computation:
        // cond(A) = ||A|| * ||A^-1||. Non-square matrices fall through to the
        // eigenvalue-based estimate, since inv() requires a square matrix.
        let n = shape[0];
        if shape[0] == shape[1] && n <= 4 {
            // Use Frobenius norm for simplicity
            let data = self.to_vec();
            let norm_a: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();

            // Compute inverse
            let inv = self.inv();
            let inv_data = inv.to_vec();
            let norm_inv: f32 = inv_data.iter().map(|x| x * x).sum::<f32>().sqrt();

            return norm_a * norm_inv / (n as f32); // Normalize by matrix size
        }

        // Use A^T A eigenvalues to estimate singular values
        let at = self.transpose();
        let ata = at.matmul(self);
        let eigvals = ata.eigvalsh();
        let eigvals_data = eigvals.to_vec();

        let max_eigval = eigvals_data.iter().fold(0.0_f32, |a, &b| a.max(b.abs()));
        let min_eigval = eigvals_data.iter().fold(f32::INFINITY, |a, &b| {
            if b.abs() > 1e-10 { a.min(b.abs()) } else { a }
        });

        if min_eigval < 1e-10 {
            f32::INFINITY
        } else {
            (max_eigval / min_eigval).sqrt()
        }
    }

    /// Compute singular value decomposition (SVD).
    ///
    /// Returns (U, S, Vt) where A = U @ diag(S) @ Vt.
    /// Uses power iteration to find singular values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let (u, s, vt) = a.svd();
    /// assert_eq!(s.shape().as_slice(), &[2]);
    /// ```
    pub fn svd(&self) -> (Array, Array, Array) {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "svd requires 2D array");

        let m = shape[0];
        let n = shape[1];
        let k = m.min(n);

        // Compute A^T A for right singular vectors
        let at = self.transpose();
        let ata = at.matmul(self);

        // Power iteration to get eigenvectors of A^T A (right singular vectors V)
        let mut v_data: Vec<Vec<f32>> = Vec::with_capacity(k);
        let mut s_values: Vec<f32> = Vec::with_capacity(k);
        let ata_data = ata.to_vec();

        for _ in 0..k {
            // Initialize random vector
            let mut v: Vec<f32> = (0..n).map(|i| ((i as f32 + 1.0) * 0.1).sin()).collect();
            let mut norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            for x in v.iter_mut() { *x /= norm; }

            // Power iteration
            for _ in 0..50 {
                // Multiply by A^T A matrix
                let mut av = vec![0.0; n];
                for i in 0..n {
                    for j in 0..n {
                        av[i] += ata_data[i * n + j] * v[j];
                    }
                }

                // Orthogonalize against previous vectors
                for prev in &v_data {
                    let dot: f32 = av.iter().zip(prev.iter()).map(|(a, b)| a * b).sum();
                    for (a, p) in av.iter_mut().zip(prev.iter()) {
                        *a -= dot * p;
                    }
                }

                norm = av.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm < 1e-10 { break; }
                for (v_i, av_i) in v.iter_mut().zip(av.iter()) {
                    *v_i = av_i / norm;
                }
            }

            // Singular value is sqrt of eigenvalue
            let eigenvalue = norm;
            s_values.push(eigenvalue.sqrt());
            v_data.push(v);
        }

        // Compute U = A @ V @ S^-1
        let mut u_data = vec![0.0; m * k];
        let a_data = self.to_vec();
        for col in 0..k {
            if s_values[col] > 1e-10 {
                for row in 0..m {
                    let mut sum = 0.0;
                    for j in 0..n {
                        sum += a_data[row * n + j] * v_data[col][j];
                    }
                    u_data[row * k + col] = sum / s_values[col];
                }
            }
        }

        // Build output arrays
        let u = Array::from_vec(u_data, Shape::new(vec![m, k]));
        let s = Array::from_vec(s_values, Shape::new(vec![k]));
        let mut vt_data = vec![0.0; k * n];
        for i in 0..k {
            for j in 0..n {
                vt_data[i * n + j] = v_data[i][j];
            }
        }
        let vt = Array::from_vec(vt_data, Shape::new(vec![k, n]));

        (u, s, vt)
    }

    /// Solve the least squares problem: minimize ||Ax - b||^2.
    ///
    /// Solves the normal equations (A^T A) x = A^T b, so A must have full
    /// column rank.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0], Shape::new(vec![3, 2]));
    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let x = a.lstsq(&b);
    /// assert_eq!(x.shape().as_slice(), &[2]);
    /// // Fitting y = c0 + c1 * t through (1,1), (2,2), (3,3) gives c0 = 0, c1 = 1
    /// assert!(x.to_vec()[0].abs() < 1e-4);
    /// assert!((x.to_vec()[1] - 1.0).abs() < 1e-4);
    /// ```
    pub fn lstsq(&self, b: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "lstsq requires 2D matrix A");

        // Use normal equations: x = (A^T A)^-1 A^T b
        let at = self.transpose();
        let ata = at.matmul(self);
        let atb = at.matmul(b);
        ata.solve(&atb)
    }

    /// Compute eigenvalues and eigenvectors of a symmetric matrix.
    ///
    /// Returns (eigenvalues, eigenvectors) where each column of eigenvectors
    /// is an eigenvector corresponding to the eigenvalue at the same index.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![2.0, 1.0, 1.0, 2.0], Shape::new(vec![2, 2]));
    /// let (vals, vecs) = a.eigh();
    /// assert_eq!(vals.shape().as_slice(), &[2]);
    /// assert_eq!(vecs.shape().as_slice(), &[2, 2]);
    /// ```
    pub fn eigh(&self) -> (Array, Array) {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "eigh requires 2D array");
        assert_eq!(shape[0], shape[1], "eigh requires square matrix");

        let n = shape[0];
        let mut a_data = self.to_vec();
        let mut eigenvectors = vec![0.0; n * n];

        // Initialize eigenvectors as identity
        for i in 0..n {
            eigenvectors[i * n + i] = 1.0;
        }

        // Jacobi eigenvalue algorithm
        for _ in 0..100 {
            // Find largest off-diagonal element
            let mut max_val = 0.0_f32;
            let mut p = 0;
            let mut q = 1;
            for i in 0..n {
                for j in (i + 1)..n {
                    if a_data[i * n + j].abs() > max_val {
                        max_val = a_data[i * n + j].abs();
                        p = i;
                        q = j;
                    }
                }
            }

            if max_val < 1e-10 { break; }

            // Compute rotation angle
            let diff = a_data[q * n + q] - a_data[p * n + p];
            let t = if diff.abs() < 1e-10 {
                1.0
            } else {
                let phi = diff / (2.0 * a_data[p * n + q]);
                1.0 / (phi.abs() + (phi * phi + 1.0).sqrt()) * phi.signum()
            };
            let c = 1.0 / (1.0 + t * t).sqrt();
            let s = t * c;

            // Apply rotation to A
            let app = a_data[p * n + p];
            let aqq = a_data[q * n + q];
            let apq = a_data[p * n + q];

            a_data[p * n + p] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
            a_data[q * n + q] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
            a_data[p * n + q] = 0.0;
            a_data[q * n + p] = 0.0;

            for i in 0..n {
                if i != p && i != q {
                    let aip = a_data[i * n + p];
                    let aiq = a_data[i * n + q];
                    a_data[i * n + p] = c * aip - s * aiq;
                    a_data[p * n + i] = a_data[i * n + p];
                    a_data[i * n + q] = s * aip + c * aiq;
                    a_data[q * n + i] = a_data[i * n + q];
                }
            }

            // Update eigenvectors
            for i in 0..n {
                let vip = eigenvectors[i * n + p];
                let viq = eigenvectors[i * n + q];
                eigenvectors[i * n + p] = c * vip - s * viq;
                eigenvectors[i * n + q] = s * vip + c * viq;
            }
        }

        // Extract eigenvalues from diagonal
        let eigenvalues: Vec<f32> = (0..n).map(|i| a_data[i * n + i]).collect();

        (
            Array::from_vec(eigenvalues, Shape::new(vec![n])),
            Array::from_vec(eigenvectors, Shape::new(vec![n, n])),
        )
    }

    /// Compute eigenvalues of a general (non-symmetric) matrix.
    ///
    /// Uses unshifted QR iteration, which converges for matrices with real
    /// eigenvalues; complex eigenvalue pairs are not supported.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 0.0, 3.0], Shape::new(vec![2, 2]));
    /// let eigvals = a.eig();
    /// assert_eq!(eigvals.shape().as_slice(), &[2]);
    /// // Upper triangular input, so the eigenvalues are the diagonal: 1 and 3
    /// let v = eigvals.to_vec();
    /// assert!((v[0] - 1.0).abs() < 1e-4);
    /// assert!((v[1] - 3.0).abs() < 1e-4);
    /// ```
    pub fn eig(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "eig requires 2D array");
        assert_eq!(shape[0], shape[1], "eig requires square matrix");

        let n = shape[0];
        let mut a = self.clone();

        // QR iteration
        for _ in 0..100 {
            let (q, r) = a.qr();
            a = r.matmul(&q);
        }

        // Extract eigenvalues from diagonal
        let a_data = a.to_vec();
        let eigenvalues: Vec<f32> = (0..n).map(|i| a_data[i * n + i]).collect();

        Array::from_vec(eigenvalues, Shape::new(vec![n]))
    }

    /// Compute the tensor dot product along specified axes.
    ///
    /// Contracts the last `axes` dimensions of `self` with the first `axes`
    /// dimensions of `other`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let c = a.tensordot(&b, 1);
    /// assert_eq!(c.shape().as_slice(), &[2, 2]);
    /// ```
    pub fn tensordot(&self, other: &Array, axes: usize) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        // Contract last `axes` dimensions of self with first `axes` of other
        assert!(axes <= a_shape.len() && axes <= b_shape.len());

        // Reshape to 2D and use matmul
        let a_outer: usize = a_shape[..a_shape.len() - axes].iter().product();
        let a_inner: usize = a_shape[a_shape.len() - axes..].iter().product();
        let b_inner: usize = b_shape[..axes].iter().product();
        let b_outer: usize = b_shape[axes..].iter().product();

        assert_eq!(a_inner, b_inner, "Contracted dimensions must match");

        let a_2d = self.reshape(Shape::new(vec![a_outer, a_inner]));
        let b_2d = other.reshape(Shape::new(vec![b_inner, b_outer]));

        let result = a_2d.matmul(&b_2d);

        // Build output shape
        let mut out_shape = a_shape[..a_shape.len() - axes].to_vec();
        out_shape.extend_from_slice(&b_shape[axes..]);
        if out_shape.is_empty() {
            out_shape.push(1);
        }

        result.reshape(Shape::new(out_shape))
    }

    /// Compute the Kronecker product of two arrays.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let c = a.kron(&b);
    /// assert_eq!(c.shape().as_slice(), &[4, 4]);
    /// ```
    pub fn kron(&self, other: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        // For 2D arrays
        if a_shape.len() == 2 && b_shape.len() == 2 {
            let (m, n) = (a_shape[0], a_shape[1]);
            let (p, q) = (b_shape[0], b_shape[1]);
            let a_data = self.to_vec();
            let b_data = other.to_vec();

            let mut result = vec![0.0; m * p * n * q];
            for i in 0..m {
                for j in 0..n {
                    for k in 0..p {
                        for l in 0..q {
                            let out_row = i * p + k;
                            let out_col = j * q + l;
                            result[out_row * (n * q) + out_col] =
                                a_data[i * n + j] * b_data[k * q + l];
                        }
                    }
                }
            }

            Array::from_vec(result, Shape::new(vec![m * p, n * q]))
        } else {
            // 1D case (mixed or higher ranks are not supported)
            assert_eq!(a_shape.len(), 1, "kron supports only 1D or 2D arrays");
            assert_eq!(b_shape.len(), 1, "kron supports only 1D or 2D arrays");
            let a_data = self.to_vec();
            let b_data = other.to_vec();
            let mut result = Vec::with_capacity(a_data.len() * b_data.len());
            for &a in &a_data {
                for &b in &b_data {
                    result.push(a * b);
                }
            }
            Array::from_vec(result, Shape::new(vec![a_data.len() * b_data.len()]))
        }
    }
}

/// Helper for n-dimensional transpose: the element at multi-index
/// `(i_0, ..., i_k)` moves to `(i_k, ..., i_0)` in `new_shape`.
fn transpose_nd(array: &Array, new_shape: Shape) -> Array {
    let old_dims = array.shape().as_slice();
    let data = array.to_vec();

    let size = array.size();
    let mut result = vec![0.0; size];

    // Compute strides for old and new layouts
    let old_strides = array.shape().default_strides();
    let new_strides = new_shape.default_strides();

    for flat_idx in 0..size {
        // Convert flat index to multi-dimensional indices in the old layout
        let mut old_multi = vec![0; old_dims.len()];
        let mut idx = flat_idx;
        for (i, &stride) in old_strides.iter().enumerate() {
            old_multi[i] = idx / stride;
            idx %= stride;
        }

        // Reverse to get the multi-dimensional indices in the new layout
        let new_multi: Vec<usize> = old_multi.iter().rev().copied().collect();

        // Convert the new multi-dimensional indices back to a flat index
        let new_flat: usize = new_multi
            .iter()
            .zip(new_strides.iter())
            .map(|(idx, stride)| idx * stride)
            .sum();

        result[new_flat] = data[flat_idx];
    }

    let buffer = Buffer::from_f32(result, Device::Cpu);
    Array::from_buffer(buffer, new_shape)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_2d() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let b = a.transpose();
        assert_eq!(b.shape().as_slice(), &[3, 2]);
        assert_eq!(b.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = a.transpose();
        assert_eq!(b.shape().as_slice(), &[3]);
        assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
    }
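
    // A minimal sketch of a test for the n-dimensional path (transpose_nd),
    // assuming the row-major layout implied by default_strides: on a 2x2x2
    // cube, the element at multi-index (i, j, k) should move to (k, j, i).
    #[test]
    fn test_transpose_3d() {
        let a = Array::from_vec(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            Shape::new(vec![2, 2, 2]),
        );
        let b = a.transpose();
        assert_eq!(b.shape().as_slice(), &[2, 2, 2]);
        // Flat index 4k + 2j + i of the result picks up element 4i + 2j + k.
        assert_eq!(b.to_vec(), vec![0.0, 4.0, 2.0, 6.0, 1.0, 5.0, 3.0, 7.0]);
    }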

    #[test]
    fn test_matmul_2d() {
        let a =
            Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
        let b =
            Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], Shape::new(vec![2, 2]));
        let c = a.matmul(&b);
        assert_eq!(c.shape().as_slice(), &[2, 2]);
        // [[1, 2], [3, 4]] @ [[5, 6], [7, 8]]
        // = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        // = [[19, 22], [43, 50]]
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matmul_non_square() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let b = Array::from_vec(
            vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
            Shape::new(vec![3, 2]),
        );
        let c = a.matmul(&b);
        assert_eq!(c.shape().as_slice(), &[2, 2]);
        // [[1, 2, 3], [4, 5, 6]] @ [[7, 8], [9, 10], [11, 12]]
        // = [[1*7+2*9+3*11, 1*8+2*10+3*12], [4*7+5*9+6*11, 4*8+5*10+6*12]]
        // = [[58, 64], [139, 154]]
        assert_eq!(c.to_vec(), vec![58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_matmul_vector() {
        // Matrix-vector multiplication
        let a =
            Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
        let v = Array::from_vec(vec![5.0, 6.0], Shape::new(vec![2]));
        let c = a.matmul(&v);
        assert_eq!(c.shape().as_slice(), &[2]);
        // [[1, 2], [3, 4]] @ [5, 6] = [1*5+2*6, 3*5+4*6] = [17, 39]
        assert_eq!(c.to_vec(), vec![17.0, 39.0]);
    }

    #[test]
    fn test_dot_1d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
        let c = a.dot(&b);
        assert!(c.is_scalar());
        assert_eq!(c.to_vec(), vec![32.0]); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dot_2d() {
        let a =
            Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
        let b =
            Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], Shape::new(vec![2, 2]));
        let c = a.dot(&b);
        // For 2D, dot is the same as matmul
        assert_eq!(c.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    #[should_panic(expected = "Incompatible shapes")]
    fn test_matmul_incompatible() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
        let b = Array::from_vec(vec![4.0, 5.0], Shape::new(vec![2, 1]));
        let _c = a.matmul(&b);
    }

    #[test]
    fn test_outer() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![4.0, 5.0], Shape::new(vec![2]));
        let c = a.outer(&b);
        assert_eq!(c.shape().as_slice(), &[3, 2]);
        // [[1*4, 1*5], [2*4, 2*5], [3*4, 3*5]]
        assert_eq!(c.to_vec(), vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
    }

    #[test]
    fn test_outer_square() {
        let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
        let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
        let c = a.outer(&b);
        assert_eq!(c.shape().as_slice(), &[2, 2]);
        // [[1*3, 1*4], [2*3, 2*4]] = [[3, 4], [6, 8]]
        assert_eq!(c.to_vec(), vec![3.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_inner() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
        let result = a.inner(&b);
        assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_inner_zeros() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![0.0, 0.0, 0.0], Shape::new(vec![3]));
        let result = a.inner(&b);
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_cross_basic() {
        // i × j = k
        let i = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
        let j = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
        let k = i.cross(&j);
        assert_eq!(k.to_vec(), vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_cross_general() {
        let a = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![5.0, 6.0, 7.0], Shape::new(vec![3]));
        let c = a.cross(&b);
        // [3*7 - 4*6, 4*5 - 2*7, 2*6 - 3*5]
        // = [21 - 24, 20 - 14, 12 - 15]
        // = [-3, 6, -3]
        assert_eq!(c.to_vec(), vec![-3.0, 6.0, -3.0]);
    }

    #[test]
    fn test_cross_anticommutative() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![4.0, 5.0, 6.0], Shape::new(vec![3]));
        let c1 = a.cross(&b);
        let c2 = b.cross(&a);
        // a × b = -(b × a)
        let c2_neg = c2.neg();
        assert_eq!(c1.to_vec(), c2_neg.to_vec());
    }

    #[test]
    fn test_trace() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
        let tr = a.trace();
        assert_eq!(tr, 5.0); // 1 + 4

        let b = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            Shape::new(vec![3, 3]),
        );
        let tr_b = b.trace();
        assert_eq!(tr_b, 15.0); // 1 + 5 + 9
    }

    #[test]
    fn test_diagonal_square() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
        let diag = a.diagonal();
        assert_eq!(diag.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_diagonal_rectangular() {
        // 2x3 matrix
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let diag = a.diagonal();
        assert_eq!(diag.to_vec(), vec![1.0, 5.0]); // min(2, 3) = 2 elements

        // 3x2 matrix
        let b = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![3, 2]),
        );
        let diag_b = b.diagonal();
        assert_eq!(diag_b.to_vec(), vec![1.0, 4.0]); // min(3, 2) = 2 elements
    }
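
    // Hedged sketch of kron regression tests, following the block layout
    // documented on `kron`: 2D x 2D builds the block matrix [a[i, j] * b];
    // anything else falls back to flattened 1D vectors.
    #[test]
    fn test_kron_2d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
        let b = Array::from_vec(vec![0.0, 1.0, 1.0, 0.0], Shape::new(vec![2, 2]));
        let c = a.kron(&b);
        assert_eq!(c.shape().as_slice(), &[4, 4]);
        // Block (i, j) of the result is a[i, j] * b:
        // [[0, 1, 0, 2], [1, 0, 2, 0], [0, 3, 0, 4], [3, 0, 4, 0]]
        assert_eq!(
            c.to_vec(),
            vec![
                0.0, 1.0, 0.0, 2.0, 1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 4.0,
                3.0, 0.0, 4.0, 0.0,
            ]
        );
    }

    #[test]
    fn test_kron_1d() {
        let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
        let b = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
        let c = a.kron(&b);
        // [1*3, 1*4, 2*3, 2*4]
        assert_eq!(c.to_vec(), vec![3.0, 4.0, 6.0, 8.0]);
    }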
}