1use crate::{buffer::Buffer, Array, DType, Device, Shape};
4
5impl Array {
6 pub fn transpose(&self) -> Array {
21 let shape = self.shape();
22 let dims = shape.as_slice();
23
24 if dims.len() <= 1 {
25 return self.clone();
27 }
28
29 let new_dims: Vec<usize> = dims.iter().rev().copied().collect();
31 let new_shape = Shape::new(new_dims);
32
33 if dims.len() == 2 {
35 let (rows, cols) = (dims[0], dims[1]);
36 let data = self.to_vec();
37 let mut transposed = vec![0.0; data.len()];
38
39 for i in 0..rows {
40 for j in 0..cols {
41 transposed[j * rows + i] = data[i * cols + j];
42 }
43 }
44
45 let buffer = Buffer::from_f32(transposed, Device::Cpu);
46 return Array::from_buffer(buffer, new_shape);
47 }
48
49 transpose_nd(self, new_shape)
51 }
52
    /// Permutes the axes of the array according to `axes`.
    ///
    /// `axes[i]` names which original axis becomes output axis `i` (the same
    /// contract as NumPy's `transpose(axes)`).
    ///
    /// # Panics
    /// Panics if `axes` is not a permutation of `0..ndim`.
    pub fn transpose_axes(&self, axes: &[usize]) -> Array {
        let shape = self.shape();
        let dims = shape.as_slice();
        let ndim = dims.len();

        assert_eq!(axes.len(), ndim, "axes must have same length as dimensions");

        // Validate that `axes` is a permutation: in-bounds and duplicate-free.
        let mut seen = vec![false; ndim];
        for &axis in axes {
            assert!(axis < ndim, "axis {} out of bounds for {} dimensions", axis, ndim);
            assert!(!seen[axis], "duplicate axis in permutation");
            seen[axis] = true;
        }

        // Identity permutation: nothing to move. This also covers ndim == 0
        // and ndim == 1, so the stride math below can assume ndim >= 2.
        if axes.iter().enumerate().all(|(i, &a)| i == a) {
            return self.clone();
        }

        let new_dims: Vec<usize> = axes.iter().map(|&a| dims[a]).collect();
        let new_shape = Shape::new(new_dims.clone());

        // Fast path: a plain 2-D transpose.
        if ndim == 2 && axes == [1, 0] {
            return self.transpose();
        }

        let data = self.to_vec();
        let size = data.len();
        let mut result = vec![0.0; size];

        // Row-major strides of the source layout.
        let mut old_strides = vec![1usize; ndim];
        for i in (0..ndim - 1).rev() {
            old_strides[i] = old_strides[i + 1] * dims[i + 1];
        }

        // Row-major strides of the destination layout.
        let mut new_strides = vec![1usize; ndim];
        for i in (0..ndim - 1).rev() {
            new_strides[i] = new_strides[i + 1] * new_dims[i + 1];
        }

        // Stride of output axis i inside the *source* buffer.
        let perm_strides: Vec<usize> = axes.iter().map(|&a| old_strides[a]).collect();

        // For each destination index: decompose it into output coordinates,
        // then re-linearize those coordinates with the permuted source strides.
        for new_idx in 0..size {
            let mut remaining = new_idx;
            let mut old_idx = 0;
            for i in 0..ndim {
                let coord = remaining / new_strides[i];
                remaining %= new_strides[i];
                old_idx += coord * perm_strides[i];
            }
            result[new_idx] = data[old_idx];
        }

        // NOTE(review): the result is always materialized on the CPU regardless
        // of the source device — confirm this matches the library's device policy.
        let buffer = Buffer::from_f32(result, Device::Cpu);
        Array::from_buffer(buffer, new_shape)
    }
132
    /// Matrix multiplication, with NumPy-style promotion of 1-D operands:
    ///
    /// * `(k,) @ (k, n)` -> `(n,)` — the vector is treated as a `(1, k)` row
    /// * `(m, k) @ (k,)` -> `(m,)` — the vector is treated as a `(k, 1)` column
    /// * `(m, k) @ (k, n)` -> `(m, n)`
    ///
    /// # Panics
    /// Panics if either dtype is not Float32, if the inner dimensions
    /// disagree, or if the two arrays live on different devices.
    pub fn matmul(&self, other: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = other.shape().as_slice();

        // Vector @ matrix: lift the vector to a 1-row matrix, multiply, then
        // drop the singleton dimension again.
        if a_shape.len() == 1 && b_shape.len() == 2 {
            assert_eq!(
                a_shape[0], b_shape[0],
                "Vector-matrix multiplication: incompatible shapes"
            );
            return self
                .reshape(Shape::new(vec![1, a_shape[0]]))
                .matmul(other)
                .reshape(Shape::new(vec![b_shape[1]]));
        }

        // Matrix @ vector: lift the vector to a 1-column matrix.
        if a_shape.len() == 2 && b_shape.len() == 1 {
            assert_eq!(
                a_shape[1], b_shape[0],
                "Matrix-vector multiplication: incompatible shapes"
            );
            return self
                .matmul(&other.reshape(Shape::new(vec![b_shape[0], 1])))
                .reshape(Shape::new(vec![a_shape[0]]));
        }

        assert_eq!(a_shape.len(), 2, "Left array must be 2D");
        assert_eq!(b_shape.len(), 2, "Right array must be 2D");
        assert_eq!(
            a_shape[1], b_shape[0],
            "Incompatible shapes for matmul: {:?} @ {:?}",
            a_shape, b_shape
        );

        let (m, k) = (a_shape[0], a_shape[1]);
        let n = b_shape[1];

        match (self.device(), other.device()) {
            (Device::WebGpu, Device::WebGpu) => {
                // GPU path: dispatch to the backend kernel; the result stays
                // in a WebGpu buffer.
                let output_buffer = Buffer::zeros(m * n, DType::Float32, Device::WebGpu);

                crate::backend::ops::gpu_matmul(
                    self.buffer(),
                    other.buffer(),
                    &output_buffer,
                    m,
                    n,
                    k,
                );

                Array::from_buffer(output_buffer, Shape::new(vec![m, n]))
            }
            (Device::Cpu, Device::Cpu) | (Device::Wasm, Device::Wasm) => {
                // Naive triple loop, O(m·n·k).
                // NOTE(review): Wasm inputs produce a Device::Cpu result here —
                // confirm whether the output should stay on Device::Wasm.
                let a_data = self.to_vec();
                let b_data = other.to_vec();
                let mut result = vec![0.0; m * n];

                for i in 0..m {
                    for j in 0..n {
                        let mut sum = 0.0;
                        for p in 0..k {
                            sum += a_data[i * k + p] * b_data[p * n + j];
                        }
                        result[i * n + j] = sum;
                    }
                }

                let buffer = Buffer::from_f32(result, Device::Cpu);
                Array::from_buffer(buffer, Shape::new(vec![m, n]))
            }
            _ => {
                panic!("Mixed device operations not supported. Both arrays must be on the same device.");
            }
        }
    }
238
239 pub fn dot(&self, other: &Array) -> Array {
254 let a_shape = self.shape().as_slice();
255 let b_shape = other.shape().as_slice();
256
257 if a_shape.len() == 1 && b_shape.len() == 1 {
259 assert_eq!(
260 a_shape[0], b_shape[0],
261 "Arrays must have same length for dot product"
262 );
263
264 let a_data = self.to_vec();
265 let b_data = other.to_vec();
266 let result: f32 =
267 a_data.iter().zip(b_data.iter()).map(|(a, b)| a * b).sum();
268
269 let buffer = Buffer::from_f32(vec![result], Device::Cpu);
270 return Array::from_buffer(buffer, Shape::scalar());
271 }
272
273 self.matmul(other)
275 }
276
277 pub fn norm(&self, ord: f32) -> f32 {
295 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
296 let data = self.to_vec();
297
298 if ord == f32::INFINITY {
299 data.iter().map(|x| x.abs()).fold(0.0, f32::max)
301 } else if ord == 1.0 {
302 data.iter().map(|x| x.abs()).sum()
304 } else if ord == 2.0 {
305 data.iter().map(|x| x * x).sum::<f32>().sqrt()
307 } else {
308 data.iter()
310 .map(|x| x.abs().powf(ord))
311 .sum::<f32>()
312 .powf(1.0 / ord)
313 }
314 }
315
    /// Determinant of a square matrix.
    ///
    /// Sizes 1–3 use closed-form expansions; larger matrices use LU
    /// decomposition with partial pivoting: det(A) = ±Π diag(U), with the
    /// sign given by the parity of the row permutation.
    pub fn det(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Determinant requires 2D array");
        assert_eq!(shape[0], shape[1], "Determinant requires square matrix");

        let n = shape[0];
        let data = self.to_vec();

        match n {
            1 => data[0],
            2 => {
                // ad - bc
                data[0] * data[3] - data[1] * data[2]
            }
            3 => {
                // Rule of Sarrus on the row-major 3x3 layout.
                let a = data[0];
                let b = data[1];
                let c = data[2];
                let d = data[3];
                let e = data[4];
                let f = data[5];
                let g = data[6];
                let h = data[7];
                let i = data[8];
                a * e * i + b * f * g + c * d * h - c * e * g - b * d * i - a * f * h
            }
            _ => {
                // det(A) = sign(P) · det(U); L has a unit diagonal so det(L) = 1.
                let (_, u, p) = self.lu_decomposition();
                let u_data = u.to_vec();

                let mut det_u = 1.0;
                for i in 0..n {
                    det_u *= u_data[i * n + i];
                }

                // Parity of the permutation p via cycle decomposition: every
                // cycle of length c contributes c - 1 transpositions.
                let mut swaps = 0;
                let mut visited = vec![false; n];
                for i in 0..n {
                    if !visited[i] {
                        let mut j = i;
                        let mut cycle_len = 0;
                        while !visited[j] {
                            visited[j] = true;
                            j = p[j];
                            cycle_len += 1;
                        }
                        if cycle_len > 1 {
                            swaps += cycle_len - 1;
                        }
                    }
                }

                // An even number of row swaps keeps the sign; odd flips it.
                if swaps % 2 == 0 {
                    det_u
                } else {
                    -det_u
                }
            }
        }
    }
396
397 fn lu_decomposition(&self) -> (Array, Array, Vec<usize>) {
404 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
405 let shape = self.shape().as_slice();
406 assert_eq!(shape.len(), 2, "LU decomposition requires 2D array");
407 assert_eq!(shape[0], shape[1], "LU decomposition requires square matrix");
408
409 let n = shape[0];
410 let data = self.to_vec();
411
412 let mut p: Vec<usize> = (0..n).collect();
414 let mut a = data.clone();
415
416 for k in 0..n {
417 let mut pivot_row = k;
419 let mut max_val = a[k * n + k].abs();
420 for i in (k + 1)..n {
421 let val = a[i * n + k].abs();
422 if val > max_val {
423 max_val = val;
424 pivot_row = i;
425 }
426 }
427
428 if pivot_row != k {
430 p.swap(k, pivot_row);
431 for j in 0..n {
432 a.swap(k * n + j, pivot_row * n + j);
433 }
434 }
435
436 for i in (k + 1)..n {
438 let factor = a[i * n + k] / a[k * n + k];
439 a[i * n + k] = factor; for j in (k + 1)..n {
441 a[i * n + j] -= factor * a[k * n + j];
442 }
443 }
444 }
445
446 let mut l_data = vec![0.0; n * n];
448 let mut u_data = vec![0.0; n * n];
449
450 for i in 0..n {
451 for j in 0..n {
452 if i > j {
453 l_data[i * n + j] = a[i * n + j];
454 } else if i == j {
455 l_data[i * n + j] = 1.0;
456 u_data[i * n + j] = a[i * n + j];
457 } else {
458 u_data[i * n + j] = a[i * n + j];
459 }
460 }
461 }
462
463 let l = Array::from_vec(l_data, Shape::new(vec![n, n]));
464 let u = Array::from_vec(u_data, Shape::new(vec![n, n]));
465
466 (l, u, p)
467 }
468
    /// Matrix inverse via Gauss–Jordan elimination on the augmented matrix
    /// `[A | I]`, with partial pivoting.
    ///
    /// # Panics
    /// Panics if the array is not a square Float32 matrix, or if the best
    /// pivot in some column is below 1e-10 (matrix treated as singular).
    pub fn inv(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "Matrix inversion requires 2D array");
        assert_eq!(shape[0], shape[1], "Matrix inversion requires square matrix");

        let n = shape[0];
        let data = self.to_vec();

        // Build [A | I]: each row has 2n entries — A on the left, the
        // identity on the right.
        let mut aug = vec![0.0; n * 2 * n];
        for i in 0..n {
            for j in 0..n {
                aug[i * 2 * n + j] = data[i * n + j];
            }
            aug[i * 2 * n + n + i] = 1.0;
        }

        for k in 0..n {
            // Partial pivoting: choose the largest |entry| in column k.
            let mut pivot_row = k;
            let mut max_val = aug[k * 2 * n + k].abs();
            for i in (k + 1)..n {
                let val = aug[i * 2 * n + k].abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
            }

            assert!(
                max_val > 1e-10,
                "Matrix is singular and cannot be inverted"
            );

            if pivot_row != k {
                for j in 0..(2 * n) {
                    aug.swap(k * 2 * n + j, pivot_row * 2 * n + j);
                }
            }

            // Scale the pivot row so the pivot becomes 1.
            let pivot = aug[k * 2 * n + k];
            for j in 0..(2 * n) {
                aug[k * 2 * n + j] /= pivot;
            }

            // Eliminate column k from every other row (full Gauss–Jordan, so
            // no separate back-substitution pass is needed).
            for i in 0..n {
                if i != k {
                    let factor = aug[i * 2 * n + k];
                    for j in 0..(2 * n) {
                        aug[i * 2 * n + j] -= factor * aug[k * 2 * n + j];
                    }
                }
            }
        }

        // The right half of the augmented matrix now holds A⁻¹.
        let mut inv_data = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                inv_data[i * n + j] = aug[i * 2 * n + n + j];
            }
        }

        Array::from_vec(inv_data, Shape::new(vec![n, n]))
    }
556
    /// Solves the linear system `A·x = b` for a square `A` and 1-D `b`,
    /// using Gaussian elimination with partial pivoting on the augmented
    /// matrix `[A | b]`, followed by back-substitution.
    ///
    /// # Panics
    /// Panics on dtype/shape mismatches, or when the best pivot in some
    /// column is below 1e-10 (singular system).
    pub fn solve(&self, b: &Array) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(b.dtype(), DType::Float32, "Only Float32 supported");

        let a_shape = self.shape().as_slice();
        let b_shape = b.shape().as_slice();

        assert_eq!(a_shape.len(), 2, "A must be 2D");
        assert_eq!(a_shape[0], a_shape[1], "A must be square");
        assert_eq!(b_shape.len(), 1, "b must be 1D");
        assert_eq!(a_shape[0], b_shape[0], "Incompatible dimensions");

        let n = a_shape[0];
        let a_data = self.to_vec();
        let b_data = b.to_vec();

        // Augmented matrix [A | b]: n rows of n + 1 entries each.
        let mut aug = vec![0.0; n * (n + 1)];
        for i in 0..n {
            for j in 0..n {
                aug[i * (n + 1) + j] = a_data[i * n + j];
            }
            aug[i * (n + 1) + n] = b_data[i];
        }

        // Forward elimination with partial pivoting.
        for k in 0..n {
            let mut pivot_row = k;
            let mut max_val = aug[k * (n + 1) + k].abs();
            for i in (k + 1)..n {
                let val = aug[i * (n + 1) + k].abs();
                if val > max_val {
                    max_val = val;
                    pivot_row = i;
                }
            }

            assert!(
                max_val > 1e-10,
                "Matrix is singular, system has no unique solution"
            );

            if pivot_row != k {
                for j in 0..(n + 1) {
                    aug.swap(k * (n + 1) + j, pivot_row * (n + 1) + j);
                }
            }

            // Zero out column k below the pivot; only columns >= k can change.
            for i in (k + 1)..n {
                let factor = aug[i * (n + 1) + k] / aug[k * (n + 1) + k];
                for j in k..(n + 1) {
                    aug[i * (n + 1) + j] -= factor * aug[k * (n + 1) + j];
                }
            }
        }

        // Back-substitution on the now upper-triangular system.
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut sum = aug[i * (n + 1) + n];
            for j in (i + 1)..n {
                sum -= aug[i * (n + 1) + j] * x[j];
            }
            x[i] = sum / aug[i * (n + 1) + i];
        }

        Array::from_vec(x, Shape::new(vec![n]))
    }
644
645 pub fn outer(&self, other: &Array) -> Array {
663 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
664 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
665
666 let a_shape = self.shape().as_slice();
667 let b_shape = other.shape().as_slice();
668
669 assert_eq!(a_shape.len(), 1, "First array must be 1D");
670 assert_eq!(b_shape.len(), 1, "Second array must be 1D");
671
672 let a_data = self.to_vec();
673 let b_data = other.to_vec();
674 let m = a_shape[0];
675 let n = b_shape[0];
676
677 let mut result = Vec::with_capacity(m * n);
678 for &a_val in a_data.iter() {
679 for &b_val in b_data.iter() {
680 result.push(a_val * b_val);
681 }
682 }
683
684 Array::from_vec(result, Shape::new(vec![m, n]))
685 }
686
687 pub fn inner(&self, other: &Array) -> f32 {
701 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
702 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
703
704 let a_shape = self.shape().as_slice();
705 let b_shape = other.shape().as_slice();
706
707 assert_eq!(a_shape.len(), 1, "First array must be 1D");
708 assert_eq!(b_shape.len(), 1, "Second array must be 1D");
709 assert_eq!(
710 a_shape[0], b_shape[0],
711 "Arrays must have same length for inner product"
712 );
713
714 let a_data = self.to_vec();
715 let b_data = other.to_vec();
716
717 a_data.iter().zip(b_data.iter()).map(|(a, b)| a * b).sum()
718 }
719
720 pub fn cross(&self, other: &Array) -> Array {
735 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
736 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
737
738 let a_shape = self.shape().as_slice();
739 let b_shape = other.shape().as_slice();
740
741 assert_eq!(a_shape.len(), 1, "First array must be 1D");
742 assert_eq!(b_shape.len(), 1, "Second array must be 1D");
743 assert_eq!(a_shape[0], 3, "Cross product requires 3D vectors");
744 assert_eq!(b_shape[0], 3, "Cross product requires 3D vectors");
745
746 let a = self.to_vec();
747 let b = other.to_vec();
748
749 let result = vec![
750 a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0], ];
754
755 Array::from_vec(result, Shape::new(vec![3]))
756 }
757
758 pub fn trace(&self) -> f32 {
769 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
770 let shape = self.shape().as_slice();
771 assert_eq!(shape.len(), 2, "Trace requires 2D array");
772 assert_eq!(shape[0], shape[1], "Trace requires square matrix");
773
774 let n = shape[0];
775 let data = self.to_vec();
776
777 let mut sum = 0.0;
778 for i in 0..n {
779 sum += data[i * n + i];
780 }
781 sum
782 }
783
784 pub fn diagonal(&self) -> Array {
795 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
796 let shape = self.shape().as_slice();
797 assert_eq!(shape.len(), 2, "Diagonal requires 2D array");
798
799 let (rows, cols) = (shape[0], shape[1]);
800 let diag_len = rows.min(cols);
801 let data = self.to_vec();
802
803 let mut result = Vec::with_capacity(diag_len);
804 for i in 0..diag_len {
805 result.push(data[i * cols + i]);
806 }
807
808 Array::from_vec(result, Shape::new(vec![diag_len]))
809 }
810
811 pub fn vander(&self, n: usize) -> Array {
824 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
825 let shape = self.shape().as_slice();
826 assert_eq!(shape.len(), 1, "vander() only supports 1D arrays");
827
828 let x = self.to_vec();
829 let m = x.len();
830 let mut result = Vec::with_capacity(m * n);
831
832 for &val in x.iter() {
833 for pow in 0..n {
834 result.push(val.powi(pow as i32));
835 }
836 }
837
838 Array::from_vec(result, Shape::new(vec![m, n]))
839 }
840
841 pub fn qr(&self) -> (Array, Array) {
856 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
857 let shape = self.shape().as_slice();
858 assert_eq!(shape.len(), 2, "QR decomposition requires 2D array");
859
860 let (m, n) = (shape[0], shape[1]);
861 let data = self.to_vec();
862
863 let mut q = vec![0.0; m * n];
865 let mut r = vec![0.0; n * n];
866
867 for j in 0..n {
869 let mut v: Vec<f32> = (0..m).map(|i| data[i * n + j]).collect();
871
872 for i in 0..j {
874 let mut dot = 0.0;
876 for k in 0..m {
877 dot += q[k * n + i] * v[k];
878 }
879 r[i * n + j] = dot;
880
881 for k in 0..m {
883 v[k] -= dot * q[k * n + i];
884 }
885 }
886
887 let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
889 r[j * n + j] = norm;
890
891 if norm > 1e-10 {
893 for k in 0..m {
894 q[k * n + j] = v[k] / norm;
895 }
896 }
897 }
898
899 let q_arr = Array::from_vec(q, Shape::new(vec![m, n]));
900 let r_arr = Array::from_vec(r, Shape::new(vec![n, n]));
901
902 (q_arr, r_arr)
903 }
904
905 pub fn cholesky(&self) -> Array {
923 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
924 let shape = self.shape().as_slice();
925 assert_eq!(shape.len(), 2, "Cholesky decomposition requires 2D array");
926 assert_eq!(shape[0], shape[1], "Cholesky decomposition requires square matrix");
927
928 let n = shape[0];
929 let data = self.to_vec();
930 let mut l = vec![0.0; n * n];
931
932 for i in 0..n {
933 for j in 0..=i {
934 let mut sum = 0.0;
935
936 if j == i {
937 for k in 0..j {
939 sum += l[j * n + k] * l[j * n + k];
940 }
941 let val = data[j * n + j] - sum;
942 assert!(val > 0.0, "Matrix is not positive-definite");
943 l[j * n + j] = val.sqrt();
944 } else {
945 for k in 0..j {
947 sum += l[i * n + k] * l[j * n + k];
948 }
949 l[i * n + j] = (data[i * n + j] - sum) / l[j * n + j];
950 }
951 }
952 }
953
954 Array::from_vec(l, Shape::new(vec![n, n]))
955 }
956
957 pub fn matrix_rank(&self) -> usize {
970 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
971 let shape = self.shape().as_slice();
972 assert_eq!(shape.len(), 2, "matrix_rank requires 2D array");
973
974 let (m, n) = (shape[0], shape[1]);
975 let tolerance = 1e-10;
976
977 let (_, r) = self.qr();
979 let r_data = r.to_vec();
980 let min_dim = m.min(n);
981
982 let mut rank = 0;
983 for i in 0..min_dim {
984 if r_data[i * n + i].abs() > tolerance {
985 rank += 1;
986 }
987 }
988
989 rank
990 }
991
    /// Eigenvalues of a symmetric matrix via power iteration with Hotelling
    /// deflation.
    ///
    /// For each of the n eigenvalues: run up to 100 power-iteration steps to
    /// find the currently-dominant eigenvector, take the Rayleigh quotient as
    /// the eigenvalue, then subtract λ·v·vᵀ so the next round converges to
    /// the next eigenvector.
    ///
    /// NOTE(review): every round starts from the all-ones vector; if that
    /// vector is (numerically) orthogonal to the dominant eigenvector, the
    /// estimate for that eigenvalue can be wrong — confirm acceptable.
    pub fn eigvalsh(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "eigvalsh requires 2D array");
        assert_eq!(shape[0], shape[1], "eigvalsh requires square matrix");

        let n = shape[0];
        let mut eigenvalues = Vec::with_capacity(n);
        let mut a = self.clone();

        for _ in 0..n {
            // Fixed, deterministic starting vector.
            let mut v_data = vec![1.0; n];

            // Power iteration: v <- A·v / ||A·v||.
            for _ in 0..100 {
                let v = Array::from_vec(v_data.clone(), Shape::new(vec![n]));
                let av = a.matmul(&v.reshape(Shape::new(vec![n, 1])))
                    .reshape(Shape::new(vec![n]));
                let av_data = av.to_vec();

                let norm: f32 = av_data.iter().map(|x| x * x).sum::<f32>().sqrt();
                // A·v ≈ 0 means v is (numerically) in the null space; stop.
                if norm < 1e-10 {
                    break;
                }
                v_data = av_data.iter().map(|x| x / norm).collect();
            }

            // Rayleigh quotient λ = vᵀ·A·v (v is unit length after iterating).
            let v = Array::from_vec(v_data.clone(), Shape::new(vec![n]));
            let av = a.matmul(&v.reshape(Shape::new(vec![n, 1])))
                .reshape(Shape::new(vec![n]));
            let eigenvalue = v.inner(&av);
            eigenvalues.push(eigenvalue);

            // Deflate: A <- A - λ·v·vᵀ removes the component just found.
            let mut a_data = a.to_vec();
            for i in 0..n {
                for j in 0..n {
                    a_data[i * n + j] -= eigenvalue * v_data[i] * v_data[j];
                }
            }
            a = Array::from_vec(a_data, Shape::new(vec![n, n]));
        }

        Array::from_vec(eigenvalues, Shape::new(vec![n]))
    }
1054
1055 pub fn pinv(&self) -> Array {
1066 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1067 let shape = self.shape().as_slice();
1068 assert_eq!(shape.len(), 2, "pinv requires 2D array");
1069
1070 let (m, n) = (shape[0], shape[1]);
1071
1072 if m >= n {
1073 let at = self.transpose();
1075 let ata = at.matmul(self);
1076 let ata_inv = ata.inv();
1077 ata_inv.matmul(&at)
1078 } else {
1079 let at = self.transpose();
1081 let aat = self.matmul(&at);
1082 let aat_inv = aat.inv();
1083 at.matmul(&aat_inv)
1084 }
1085 }
1086
    /// Condition number estimate.
    ///
    /// * `n <= 4`: Frobenius-norm based estimate `‖A‖_F · ‖A⁻¹‖_F / n`.
    /// * larger: `sqrt(λ_max / λ_min)` of AᵀA, with eigenvalues from the
    ///   iterative `eigvalsh`.
    ///
    /// NOTE(review): `n` is taken from shape[0] without asserting squareness
    /// here; non-square input panics later inside inv()/matmul — confirm this
    /// is the intended failure mode.
    pub fn cond(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "cond requires 2D array");

        let n = shape[0];
        if n <= 4 {
            // Small matrices: invert directly and compare Frobenius norms.
            let data = self.to_vec();
            let norm_a: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();

            let inv = self.inv();
            let inv_data = inv.to_vec();
            let norm_inv: f32 = inv_data.iter().map(|x| x * x).sum::<f32>().sqrt();

            // NOTE(review): the division by n is an ad-hoc normalization —
            // confirm against the intended definition of cond().
            return norm_a * norm_inv / (n as f32);
        }

        // Large matrices: singular values squared are eigenvalues of AᵀA.
        let at = self.transpose();
        let ata = at.matmul(self);
        let eigvals = ata.eigvalsh();
        let eigvals_data = eigvals.to_vec();

        let max_eigval = eigvals_data.iter().fold(0.0_f32, |a, &b| a.max(b.abs()));
        // Smallest eigenvalue above the noise floor; tiny ones are skipped.
        let min_eigval = eigvals_data.iter().fold(f32::INFINITY, |a, &b| {
            if b.abs() > 1e-10 { a.min(b.abs()) } else { a }
        });

        if min_eigval < 1e-10 {
            f32::INFINITY
        } else {
            (max_eigval / min_eigval).sqrt()
        }
    }
1138
    /// Thin singular value decomposition `A ≈ U·diag(S)·Vᵀ` computed by power
    /// iteration on `AᵀA`, re-orthogonalizing against the singular vectors
    /// already found.
    ///
    /// Returns `(U, S, Vᵀ)` with shapes `(m, k)`, `(k,)`, `(k, n)` where
    /// `k = min(m, n)`.
    ///
    /// NOTE(review): 50 fixed iterations per singular vector; accuracy for
    /// clustered singular values is limited — confirm acceptable.
    pub fn svd(&self) -> (Array, Array, Array) {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "svd requires 2D array");

        let m = shape[0];
        let n = shape[1];
        let k = m.min(n);

        // Right singular vectors are eigenvectors of AᵀA.
        let at = self.transpose();
        let ata = at.matmul(self);

        let mut v_data: Vec<Vec<f32>> = Vec::with_capacity(k);
        let mut s_values: Vec<f32> = Vec::with_capacity(k);
        let ata_data = ata.to_vec();

        for _ in 0..k {
            // Deterministic, non-degenerate starting vector, normalized.
            let mut v: Vec<f32> = (0..n).map(|i| ((i as f32 + 1.0) * 0.1).sin()).collect();
            let mut norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            for x in v.iter_mut() { *x /= norm; }

            for _ in 0..50 {
                // av = (AᵀA)·v
                let mut av = vec![0.0; n];
                for i in 0..n {
                    for j in 0..n {
                        av[i] += ata_data[i * n + j] * v[j];
                    }
                }

                // Project out the vectors already found so the iteration
                // converges to the next-largest singular direction.
                for prev in &v_data {
                    let dot: f32 = av.iter().zip(prev.iter()).map(|(a, b)| a * b).sum();
                    for (a, p) in av.iter_mut().zip(prev.iter()) {
                        *a -= dot * p;
                    }
                }

                norm = av.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm < 1e-10 { break; }
                for (v_i, av_i) in v.iter_mut().zip(av.iter()) {
                    *v_i = av_i / norm;
                }
            }

            // After iterating, ||(AᵀA)·v|| (last computed norm) estimates the
            // eigenvalue σ²; the singular value is its square root.
            let eigenvalue = norm;
            s_values.push(eigenvalue.sqrt());
            v_data.push(v);
        }

        // Left singular vectors: u_i = A·v_i / σ_i. Columns with σ ≈ 0 are
        // left as zeros.
        let mut u_data = vec![0.0; m * k];
        let a_data = self.to_vec();
        for col in 0..k {
            if s_values[col] > 1e-10 {
                for row in 0..m {
                    let mut sum = 0.0;
                    for j in 0..n {
                        sum += a_data[row * n + j] * v_data[col][j];
                    }
                    u_data[row * k + col] = sum / s_values[col];
                }
            }
        }

        let u = Array::from_vec(u_data, Shape::new(vec![m, k]));
        let s = Array::from_vec(s_values, Shape::new(vec![k]));
        // Vᵀ: row i is the i-th right singular vector.
        let mut vt_data = vec![0.0; k * n];
        for i in 0..k {
            for j in 0..n {
                vt_data[i * n + j] = v_data[i][j];
            }
        }
        let vt = Array::from_vec(vt_data, Shape::new(vec![k, n]));

        (u, s, vt)
    }
1235
1236 pub fn lstsq(&self, b: &Array) -> Array {
1250 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1251 let shape = self.shape().as_slice();
1252 assert_eq!(shape.len(), 2, "lstsq requires 2D matrix A");
1253
1254 let at = self.transpose();
1256 let ata = at.matmul(self);
1257 let atb = at.matmul(b);
1258 ata.solve(&atb)
1259 }
1260
    /// Eigendecomposition of a symmetric matrix via the classical Jacobi
    /// rotation method (up to 100 rotations, each annihilating the currently
    /// largest off-diagonal element).
    ///
    /// Returns `(eigenvalues, eigenvectors)`; eigenvector `i` is stored in
    /// column `i` of the returned `(n, n)` matrix. Eigenvalues are unsorted.
    pub fn eigh(&self) -> (Array, Array) {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let shape = self.shape().as_slice();
        assert_eq!(shape.len(), 2, "eigh requires 2D array");
        assert_eq!(shape[0], shape[1], "eigh requires square matrix");

        let n = shape[0];
        let mut a_data = self.to_vec();
        let mut eigenvectors = vec![0.0; n * n];

        // Rotations accumulate starting from the identity matrix.
        for i in 0..n {
            eigenvectors[i * n + i] = 1.0;
        }

        for _ in 0..100 {
            // Locate the largest off-diagonal element (classical Jacobi
            // pivoting strategy).
            let mut max_val = 0.0_f32;
            let mut p = 0;
            let mut q = 1;
            for i in 0..n {
                for j in (i + 1)..n {
                    if a_data[i * n + j].abs() > max_val {
                        max_val = a_data[i * n + j].abs();
                        p = i;
                        q = j;
                    }
                }
            }

            // Off-diagonal mass is negligible: the matrix is diagonal enough.
            if max_val < 1e-10 { break; }

            // Rotation parameters: t = tan(θ) chosen to zero the (p, q)
            // entry; c = cos(θ), s = sin(θ).
            let diff = a_data[q * n + q] - a_data[p * n + p];
            let t = if diff.abs() < 1e-10 {
                1.0
            } else {
                let phi = diff / (2.0 * a_data[p * n + q]);
                1.0 / (phi.abs() + (phi * phi + 1.0).sqrt()) * phi.signum()
            };
            let c = 1.0 / (1.0 + t * t).sqrt();
            let s = t * c;

            let app = a_data[p * n + p];
            let aqq = a_data[q * n + q];
            let apq = a_data[p * n + q];

            // Update the 2x2 block at (p, q); the rotated off-diagonal
            // entries are zero by construction.
            a_data[p * n + p] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
            a_data[q * n + q] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
            a_data[p * n + q] = 0.0;
            a_data[q * n + p] = 0.0;

            // Rotate the remaining entries of rows/columns p and q, mirroring
            // each update to keep the matrix symmetric.
            for i in 0..n {
                if i != p && i != q {
                    let aip = a_data[i * n + p];
                    let aiq = a_data[i * n + q];
                    a_data[i * n + p] = c * aip - s * aiq;
                    a_data[p * n + i] = a_data[i * n + p];
                    a_data[i * n + q] = s * aip + c * aiq;
                    a_data[q * n + i] = a_data[i * n + q];
                }
            }

            // Accumulate the same rotation into the eigenvector matrix.
            for i in 0..n {
                let vip = eigenvectors[i * n + p];
                let viq = eigenvectors[i * n + q];
                eigenvectors[i * n + p] = c * vip - s * viq;
                eigenvectors[i * n + q] = s * vip + c * viq;
            }
        }

        // After convergence the diagonal holds the eigenvalues.
        let eigenvalues: Vec<f32> = (0..n).map(|i| a_data[i * n + i]).collect();

        (
            Array::from_vec(eigenvalues, Shape::new(vec![n])),
            Array::from_vec(eigenvectors, Shape::new(vec![n, n])),
        )
    }
1357
1358 pub fn eig(&self) -> Array {
1371 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1372 let shape = self.shape().as_slice();
1373 assert_eq!(shape.len(), 2, "eig requires 2D array");
1374 assert_eq!(shape[0], shape[1], "eig requires square matrix");
1375
1376 let n = shape[0];
1377 let mut a = self.clone();
1378
1379 for _ in 0..100 {
1381 let (q, r) = a.qr();
1382 a = r.matmul(&q);
1383 }
1384
1385 let a_data = a.to_vec();
1387 let eigenvalues: Vec<f32> = (0..n).map(|i| a_data[i * n + i]).collect();
1388
1389 Array::from_vec(eigenvalues, Shape::new(vec![n]))
1390 }
1391
1392 pub fn tensordot(&self, other: &Array, axes: usize) -> Array {
1404 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1405 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1406
1407 let a_shape = self.shape().as_slice();
1408 let b_shape = other.shape().as_slice();
1409
1410 assert!(axes <= a_shape.len() && axes <= b_shape.len());
1412
1413 let a_outer: usize = a_shape[..a_shape.len() - axes].iter().product();
1415 let a_inner: usize = a_shape[a_shape.len() - axes..].iter().product();
1416 let b_inner: usize = b_shape[..axes].iter().product();
1417 let b_outer: usize = b_shape[axes..].iter().product();
1418
1419 assert_eq!(a_inner, b_inner, "Contracted dimensions must match");
1420
1421 let a_2d = self.reshape(Shape::new(vec![a_outer, a_inner]));
1422 let b_2d = other.reshape(Shape::new(vec![b_inner, b_outer]));
1423
1424 let result = a_2d.matmul(&b_2d);
1425
1426 let mut out_shape = a_shape[..a_shape.len() - axes].to_vec();
1428 out_shape.extend_from_slice(&b_shape[axes..]);
1429 if out_shape.is_empty() {
1430 out_shape.push(1);
1431 }
1432
1433 result.reshape(Shape::new(out_shape))
1434 }
1435
1436 pub fn kron(&self, other: &Array) -> Array {
1448 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1449 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1450
1451 let a_shape = self.shape().as_slice();
1452 let b_shape = other.shape().as_slice();
1453
1454 if a_shape.len() == 2 && b_shape.len() == 2 {
1456 let (m, n) = (a_shape[0], a_shape[1]);
1457 let (p, q) = (b_shape[0], b_shape[1]);
1458 let a_data = self.to_vec();
1459 let b_data = other.to_vec();
1460
1461 let mut result = vec![0.0; m * p * n * q];
1462 for i in 0..m {
1463 for j in 0..n {
1464 for k in 0..p {
1465 for l in 0..q {
1466 let out_row = i * p + k;
1467 let out_col = j * q + l;
1468 result[out_row * (n * q) + out_col] =
1469 a_data[i * n + j] * b_data[k * q + l];
1470 }
1471 }
1472 }
1473 }
1474
1475 Array::from_vec(result, Shape::new(vec![m * p, n * q]))
1476 } else {
1477 let a_data = self.to_vec();
1479 let b_data = other.to_vec();
1480 let mut result = Vec::with_capacity(a_data.len() * b_data.len());
1481 for &a in &a_data {
1482 for &b in &b_data {
1483 result.push(a * b);
1484 }
1485 }
1486 Array::from_vec(result, Shape::new(vec![a_data.len() * b_data.len()]))
1487 }
1488 }
1489}
1490
1491fn transpose_nd(array: &Array, new_shape: Shape) -> Array {
1493 let old_dims = array.shape().as_slice();
1494 let data = array.to_vec();
1495
1496 let size = array.size();
1497 let mut result = vec![0.0; size];
1498
1499 let old_strides = array.shape().default_strides();
1501 let new_strides = new_shape.default_strides();
1502
1503 for flat_idx in 0..size {
1504 let mut old_multi = vec![0; old_dims.len()];
1506 let mut idx = flat_idx;
1507 for (i, &stride) in old_strides.iter().enumerate() {
1508 old_multi[i] = idx / stride;
1509 idx %= stride;
1510 }
1511
1512 let new_multi: Vec<usize> = old_multi.iter().rev().copied().collect();
1514
1515 let new_flat: usize = new_multi
1517 .iter()
1518 .zip(new_strides.iter())
1519 .map(|(idx, stride)| idx * stride)
1520 .sum();
1521
1522 result[new_flat] = data[flat_idx];
1523 }
1524
1525 let buffer = Buffer::from_f32(result, Device::Cpu);
1526 Array::from_buffer(buffer, new_shape)
1527}
1528
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a Float32 array in tests.
    fn arr(data: Vec<f32>, dims: Vec<usize>) -> Array {
        Array::from_vec(data, Shape::new(dims))
    }

    #[test]
    fn test_transpose_2d() {
        let t = arr(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).transpose();
        assert_eq!(t.shape().as_slice(), &[3, 2]);
        assert_eq!(t.to_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_1d() {
        // 1-D transpose is a no-op.
        let t = arr(vec![1.0, 2.0, 3.0], vec![3]).transpose();
        assert_eq!(t.shape().as_slice(), &[3]);
        assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_matmul_2d() {
        let product = arr(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .matmul(&arr(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]));
        assert_eq!(product.shape().as_slice(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_matmul_non_square() {
        let product = arr(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .matmul(&arr(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]));
        assert_eq!(product.shape().as_slice(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_matmul_vector() {
        // (2, 2) @ (2,) -> (2,)
        let product = arr(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .matmul(&arr(vec![5.0, 6.0], vec![2]));
        assert_eq!(product.shape().as_slice(), &[2]);
        assert_eq!(product.to_vec(), vec![17.0, 39.0]);
    }

    #[test]
    fn test_dot_1d() {
        let scalar =
            arr(vec![1.0, 2.0, 3.0], vec![3]).dot(&arr(vec![4.0, 5.0, 6.0], vec![3]));
        assert!(scalar.is_scalar());
        assert_eq!(scalar.to_vec(), vec![32.0]);
    }

    #[test]
    fn test_dot_2d() {
        // 2-D dot falls through to matmul.
        let product = arr(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .dot(&arr(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]));
        assert_eq!(product.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    #[should_panic(expected = "Incompatible shapes")]
    fn test_matmul_incompatible() {
        let _ = arr(vec![1.0, 2.0, 3.0], vec![1, 3])
            .matmul(&arr(vec![4.0, 5.0], vec![2, 1]));
    }

    #[test]
    fn test_outer() {
        let product = arr(vec![1.0, 2.0, 3.0], vec![3]).outer(&arr(vec![4.0, 5.0], vec![2]));
        assert_eq!(product.shape().as_slice(), &[3, 2]);
        assert_eq!(product.to_vec(), vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
    }

    #[test]
    fn test_outer_square() {
        let product = arr(vec![1.0, 2.0], vec![2]).outer(&arr(vec![3.0, 4.0], vec![2]));
        assert_eq!(product.shape().as_slice(), &[2, 2]);
        assert_eq!(product.to_vec(), vec![3.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_inner() {
        let value =
            arr(vec![1.0, 2.0, 3.0], vec![3]).inner(&arr(vec![4.0, 5.0, 6.0], vec![3]));
        assert_eq!(value, 32.0);
    }

    #[test]
    fn test_inner_zeros() {
        let value =
            arr(vec![1.0, 2.0, 3.0], vec![3]).inner(&arr(vec![0.0, 0.0, 0.0], vec![3]));
        assert_eq!(value, 0.0);
    }

    #[test]
    fn test_cross_basic() {
        // x-hat × y-hat = z-hat.
        let k = arr(vec![1.0, 0.0, 0.0], vec![3]).cross(&arr(vec![0.0, 1.0, 0.0], vec![3]));
        assert_eq!(k.to_vec(), vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_cross_general() {
        let c = arr(vec![2.0, 3.0, 4.0], vec![3]).cross(&arr(vec![5.0, 6.0, 7.0], vec![3]));
        assert_eq!(c.to_vec(), vec![-3.0, 6.0, -3.0]);
    }

    #[test]
    fn test_cross_anticommutative() {
        let a = arr(vec![1.0, 2.0, 3.0], vec![3]);
        let b = arr(vec![4.0, 5.0, 6.0], vec![3]);
        // a × b == -(b × a)
        assert_eq!(a.cross(&b).to_vec(), b.cross(&a).neg().to_vec());
    }

    #[test]
    fn test_trace() {
        assert_eq!(arr(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).trace(), 5.0);
        assert_eq!(
            arr(
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
                vec![3, 3]
            )
            .trace(),
            15.0
        );
    }

    #[test]
    fn test_diagonal_square() {
        let diag = arr(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).diagonal();
        assert_eq!(diag.to_vec(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_diagonal_rectangular() {
        // Wide matrix: diagonal length is min(2, 3) = 2.
        let wide = arr(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(wide.diagonal().to_vec(), vec![1.0, 5.0]);

        // Tall matrix: diagonal length is min(3, 2) = 2.
        let tall = arr(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        assert_eq!(tall.diagonal().to_vec(), vec![1.0, 4.0]);
    }
}