1use crate::{buffer::Buffer, Array, DType, Shape};
4
5impl Array {
6 pub fn arange(start: f32, stop: f32, step: f32, dtype: DType) -> Self {
24 assert_ne!(step, 0.0, "Step must be non-zero");
25 assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
26
27 let size = ((stop - start) / step).ceil().max(0.0) as usize;
28 if size == 0 {
29 return Array::zeros(Shape::new(vec![0]), dtype);
30 }
31
32 let data: Vec<f32> =
33 (0..size).map(|i| start + (i as f32) * step).collect();
34
35 let device = crate::default_device();
36 let buffer = Buffer::from_f32(data, device);
37 Array::from_buffer(buffer, Shape::new(vec![size]))
38 }
39
40 pub fn linspace(
59 start: f32,
60 stop: f32,
61 num: usize,
62 endpoint: bool,
63 dtype: DType,
64 ) -> Self {
65 assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
66
67 if num == 0 {
68 return Array::zeros(Shape::new(vec![0]), dtype);
69 }
70
71 if num == 1 {
72 return Array::full(start, Shape::new(vec![1]), dtype);
73 }
74
75 if start == stop {
76 return Array::full(start, Shape::new(vec![num]), dtype);
77 }
78
79 let delta = stop - start;
80 let denom = if endpoint { num - 1 } else { num } as f32;
81
82 let data: Vec<f32> =
83 (0..num).map(|i| start + (i as f32) * delta / denom).collect();
84
85 let device = crate::default_device();
86 let buffer = Buffer::from_f32(data, device);
87 Array::from_buffer(buffer, Shape::new(vec![num]))
88 }
89
90 pub fn eye(n: usize, m: Option<usize>, dtype: DType) -> Self {
107 assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
108
109 let m = m.unwrap_or(n);
110 let size = n * m;
111 let mut data = vec![0.0; size];
112
113 for i in 0..n.min(m) {
115 data[i * m + i] = 1.0;
116 }
117
118 let device = crate::default_device();
119 let buffer = Buffer::from_f32(data, device);
120 Array::from_buffer(buffer, Shape::new(vec![n, m]))
121 }
122
123 pub fn identity(n: usize, dtype: DType) -> Self {
132 Self::eye(n, None, dtype)
133 }
134
135 pub fn diag(v: &Self, k: i32) -> Self {
151 assert_eq!(v.dtype(), DType::Float32, "Only Float32 supported");
152
153 if v.ndim() == 1 {
154 let n = v.size();
156 let offset = k.unsigned_abs() as usize;
157 let matrix_size = n + offset;
158
159 let mut data = vec![0.0; matrix_size * matrix_size];
160 let v_data = v.to_vec();
161
162 for (i, &val) in v_data.iter().enumerate() {
163 let (row, col) =
164 if k >= 0 { (i, i + offset) } else { (i + offset, i) };
165 data[row * matrix_size + col] = val;
166 }
167
168 Self::from_vec(data, Shape::new(vec![matrix_size, matrix_size]))
169 } else if v.ndim() == 2 {
170 let shape = v.shape().as_slice();
172 let (rows, cols) = (shape[0], shape[1]);
173 let data = v.to_vec();
174
175 let diag_len = if k >= 0 {
176 (cols as i32 - k).min(rows as i32).max(0) as usize
177 } else {
178 (rows as i32 + k).min(cols as i32).max(0) as usize
179 };
180
181 let mut diag_data = Vec::with_capacity(diag_len);
182
183 for i in 0..diag_len {
184 let (row, col) = if k >= 0 {
185 (i, i + k as usize)
186 } else {
187 (i + (-k) as usize, i)
188 };
189 diag_data.push(data[row * cols + col]);
190 }
191
192 Self::from_vec(diag_data, Shape::new(vec![diag_len]))
193 } else {
194 panic!("diag only supports 1-D and 2-D arrays");
195 }
196 }
197
198 pub fn tril(&self, k: i32) -> Self {
212 assert_eq!(self.ndim(), 2, "tril only supports 2-D arrays");
213 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
214
215 let shape = self.shape().as_slice();
216 let (rows, cols) = (shape[0], shape[1]);
217 let data = self.to_vec();
218
219 let mut result = Vec::with_capacity(data.len());
220
221 for i in 0..rows {
222 for j in 0..cols {
223 let val = if (j as i32) <= (i as i32 + k) {
224 data[i * cols + j]
225 } else {
226 0.0
227 };
228 result.push(val);
229 }
230 }
231
232 Self::from_vec(result, self.shape().clone())
233 }
234
235 pub fn triu(&self, k: i32) -> Self {
249 assert_eq!(self.ndim(), 2, "triu only supports 2-D arrays");
250 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
251
252 let shape = self.shape().as_slice();
253 let (rows, cols) = (shape[0], shape[1]);
254 let data = self.to_vec();
255
256 let mut result = Vec::with_capacity(data.len());
257
258 for i in 0..rows {
259 for j in 0..cols {
260 let val = if (j as i32) >= (i as i32 + k) {
261 data[i * cols + j]
262 } else {
263 0.0
264 };
265 result.push(val);
266 }
267 }
268
269 Self::from_vec(result, self.shape().clone())
270 }
271
272 pub fn tri(n: usize, m: Option<usize>, k: i32, dtype: DType) -> Self {
282 assert_eq!(dtype, DType::Float32, "Only Float32 supported");
283
284 let cols = m.unwrap_or(n);
285 let mut data = Vec::with_capacity(n * cols);
286
287 for i in 0..n {
288 for j in 0..cols {
289 let val = if (j as i32) <= (i as i32 + k) { 1.0 } else { 0.0 };
290 data.push(val);
291 }
292 }
293
294 Self::from_vec(data, Shape::new(vec![n, cols]))
295 }
296
297 pub fn zeros_like(other: &Array) -> Array {
308 Array::zeros(other.shape().clone(), other.dtype())
309 }
310
311 pub fn ones_like(other: &Array) -> Array {
322 Array::ones(other.shape().clone(), other.dtype())
323 }
324
325 pub fn full_like(other: &Array, value: f32) -> Array {
336 Array::full(value, other.shape().clone(), other.dtype())
337 }
338
339 pub fn repeat(&self, repeats: usize, axis: usize) -> Array {
350 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
351 assert!(axis < self.ndim(), "Axis out of bounds");
352
353 let shape = self.shape().as_slice();
354 let data = self.to_vec();
355
356 if self.ndim() == 1 {
358 let mut result = Vec::with_capacity(data.len() * repeats);
359 for &val in data.iter() {
360 for _ in 0..repeats {
361 result.push(val);
362 }
363 }
364 return Array::from_vec(result, Shape::new(vec![shape[0] * repeats]));
365 }
366
367 assert_eq!(axis, 0, "repeat only supports axis=0 for multi-dimensional arrays");
369
370 let slice_size = data.len() / shape[0];
371 let mut result = Vec::with_capacity(data.len() * repeats);
372
373 for i in 0..shape[0] {
374 let start = i * slice_size;
375 let end = start + slice_size;
376 for _ in 0..repeats {
377 result.extend_from_slice(&data[start..end]);
378 }
379 }
380
381 let mut result_shape = shape.to_vec();
382 result_shape[axis] *= repeats;
383 Array::from_vec(result, Shape::new(result_shape))
384 }
385
386 pub fn tile(&self, reps: usize) -> Array {
397 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
398
399 let data = self.to_vec();
400 let mut result = Vec::with_capacity(data.len() * reps);
401
402 for _ in 0..reps {
403 result.extend_from_slice(&data);
404 }
405
406 let shape = self.shape().as_slice();
407 let mut result_shape = shape.to_vec();
408 result_shape[0] *= reps;
409
410 Array::from_vec(result, Shape::new(result_shape))
411 }
412
413 pub fn meshgrid(x: &Array, y: &Array) -> (Array, Array) {
426 assert_eq!(x.ndim(), 1, "meshgrid requires 1D arrays");
427 assert_eq!(y.ndim(), 1, "meshgrid requires 1D arrays");
428
429 let x_data = x.to_vec();
430 let y_data = y.to_vec();
431 let nx = x_data.len();
432 let ny = y_data.len();
433
434 let mut xx_data = Vec::with_capacity(nx * ny);
436 for _ in 0..ny {
437 xx_data.extend_from_slice(&x_data);
438 }
439
440 let mut yy_data = Vec::with_capacity(nx * ny);
442 for &y_val in y_data.iter() {
443 for _ in 0..nx {
444 yy_data.push(y_val);
445 }
446 }
447
448 let xx = Array::from_vec(xx_data, Shape::new(vec![ny, nx]));
449 let yy = Array::from_vec(yy_data, Shape::new(vec![ny, nx]));
450
451 (xx, yy)
452 }
453
454 pub fn indices(dimensions: &[usize]) -> Vec<Array> {
467 let total_size: usize = dimensions.iter().product();
468 let mut result = Vec::with_capacity(dimensions.len());
469
470 for (dim_idx, &dim_size) in dimensions.iter().enumerate() {
471 let mut data = Vec::with_capacity(total_size);
472
473 let stride: usize = dimensions.iter().skip(dim_idx + 1).product();
475
476 for i in 0..total_size {
477 let idx = (i / stride) % dim_size;
478 data.push(idx as f32);
479 }
480
481 result.push(Array::from_vec(data, Shape::new(dimensions.to_vec())));
482 }
483
484 result
485 }
486
487 pub fn unravel_index(index: usize, shape: &Shape) -> Vec<usize> {
498 let dims = shape.as_slice();
499 let mut coords = vec![0; dims.len()];
500 let mut idx = index;
501
502 for i in (0..dims.len()).rev() {
503 coords[i] = idx % dims[i];
504 idx /= dims[i];
505 }
506
507 coords
508 }
509
510 pub fn ravel_multi_index(multi_index: &[usize], shape: &Shape) -> usize {
521 let dims = shape.as_slice();
522 assert_eq!(
523 multi_index.len(),
524 dims.len(),
525 "Index dimensions must match shape"
526 );
527
528 let mut index = 0;
529 let mut stride = 1;
530
531 for i in (0..dims.len()).rev() {
532 assert!(
533 multi_index[i] < dims[i],
534 "Index out of bounds at dimension {}", i
535 );
536 index += multi_index[i] * stride;
537 stride *= dims[i];
538 }
539
540 index
541 }
542
543 pub fn diag_indices(n: usize) -> (Vec<usize>, Vec<usize>) {
554 let indices: Vec<usize> = (0..n).collect();
555 (indices.clone(), indices)
556 }
557
558 pub fn tril_indices(n: usize, k: isize) -> (Vec<usize>, Vec<usize>) {
574 let mut rows = Vec::new();
575 let mut cols = Vec::new();
576
577 for i in 0..n {
578 for j in 0..n {
579 if (j as isize) <= (i as isize + k) {
580 rows.push(i);
581 cols.push(j);
582 }
583 }
584 }
585
586 (rows, cols)
587 }
588
589 pub fn triu_indices(n: usize, k: isize) -> (Vec<usize>, Vec<usize>) {
605 let mut rows = Vec::new();
606 let mut cols = Vec::new();
607
608 for i in 0..n {
609 for j in 0..n {
610 if (j as isize) >= (i as isize + k) {
611 rows.push(i);
612 cols.push(j);
613 }
614 }
615 }
616
617 (rows, cols)
618 }
619
620 pub fn geomspace(start: f32, stop: f32, num: usize, dtype: DType) -> Self {
630 assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
631 assert!(num > 0, "Number of samples must be positive");
632 assert!(start > 0.0 && stop > 0.0, "Start and stop must be positive for geomspace");
633
634 if num == 1 {
635 return Array::from_vec(vec![start], Shape::new(vec![1]));
636 }
637
638 let log_start = start.ln();
639 let log_stop = stop.ln();
640 let step = (log_stop - log_start) / (num - 1) as f32;
641
642 let mut data = Vec::with_capacity(num);
643 for i in 0..num {
644 data.push((log_start + step * i as f32).exp());
645 }
646
647 let device = crate::default_device();
648 let buffer = Buffer::from_f32(data, device);
649 Array::from_buffer(buffer, Shape::new(vec![num]))
650 }
651
652 pub fn logspace(start: f32, stop: f32, num: usize, dtype: DType) -> Self {
662 assert_eq!(dtype, DType::Float32, "Only Float32 supported for now");
663 assert!(num > 0, "Number of samples must be positive");
664
665 if num == 1 {
666 return Array::from_vec(vec![10_f32.powf(start)], Shape::new(vec![1]));
667 }
668
669 let step = (stop - start) / (num - 1) as f32;
670 let mut data = Vec::with_capacity(num);
671 for i in 0..num {
672 data.push(10_f32.powf(start + step * i as f32));
673 }
674
675 let device = crate::default_device();
676 let buffer = Buffer::from_f32(data, device);
677 Array::from_buffer(buffer, Shape::new(vec![num]))
678 }
679
680 pub fn empty_like(&self) -> Array {
692 Array::zeros(self.shape().clone(), self.dtype())
694 }
695
696 pub fn is_contiguous(&self) -> bool {
706 true
708 }
709
710 pub fn is_fortran_contiguous(&self) -> bool {
713 self.ndim() <= 1
715 }
716
717 pub fn ascontiguousarray(&self) -> Array {
728 self.clone()
730 }
731
732 pub fn hamming(n: usize) -> Array {
742 let data: Vec<f32> = (0..n)
743 .map(|i| {
744 0.54 - 0.46 * (2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32).cos()
745 })
746 .collect();
747 let buffer = Buffer::from_f32(data, crate::default_device());
748 Array::from_buffer(buffer, Shape::new(vec![n]))
749 }
750
751 pub fn hanning(n: usize) -> Array {
761 let data: Vec<f32> = (0..n)
762 .map(|i| {
763 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32).cos())
764 })
765 .collect();
766 let buffer = Buffer::from_f32(data, crate::default_device());
767 Array::from_buffer(buffer, Shape::new(vec![n]))
768 }
769
770 pub fn blackman(n: usize) -> Array {
780 let data: Vec<f32> = (0..n)
781 .map(|i| {
782 let x = i as f32 / (n - 1) as f32;
783 0.42 - 0.5 * (2.0 * std::f32::consts::PI * x).cos()
784 + 0.08 * (4.0 * std::f32::consts::PI * x).cos()
785 })
786 .collect();
787 let buffer = Buffer::from_f32(data, crate::default_device());
788 Array::from_buffer(buffer, Shape::new(vec![n]))
789 }
790
791 pub fn kaiser(n: usize, beta: f32) -> Array {
801 fn i0(x: f32) -> f32 {
803 let ax = x.abs();
804 if ax < 3.75 {
805 let y = (x / 3.75).powi(2);
806 1.0 + y * (3.5156229 + y * (3.0899424 + y * (1.2067492
807 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))))
808 } else {
809 let y = 3.75 / ax;
810 (ax.exp() / ax.sqrt()) * (0.398_942_3 + y * (0.01328592
811 + y * (0.00225319 + y * (-0.00157565 + y * (0.00916281
812 + y * (-0.02057706 + y * (0.02635537 + y * (-0.01647633
813 + y * 0.00392377))))))))
814 }
815 }
816
817 let data: Vec<f32> = (0..n)
818 .map(|i| {
819 let x = 2.0 * i as f32 / (n - 1) as f32 - 1.0;
820 i0(beta * (1.0 - x * x).sqrt()) / i0(beta)
821 })
822 .collect();
823 let buffer = Buffer::from_f32(data, crate::default_device());
824 Array::from_buffer(buffer, Shape::new(vec![n]))
825 }
826
827 pub fn bartlett(n: usize) -> Array {
837 let data: Vec<f32> = (0..n)
838 .map(|i| {
839 let x = i as f32;
840 let half = (n - 1) as f32 / 2.0;
841 1.0 - ((x - half) / half).abs()
842 })
843 .collect();
844 let buffer = Buffer::from_f32(data, crate::default_device());
845 Array::from_buffer(buffer, Shape::new(vec![n]))
846 }
847
848 pub fn flattop(n: usize) -> Array {
858 let a0 = 0.21557895;
859 let a1 = 0.41663158;
860 let a2 = 0.277_263_16;
861 let a3 = 0.083578947;
862 let a4 = 0.006947368;
863
864 let data: Vec<f32> = (0..n)
865 .map(|i| {
866 let x = 2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32;
867 a0 - a1 * x.cos() + a2 * (2.0 * x).cos()
868 - a3 * (3.0 * x).cos() + a4 * (4.0 * x).cos()
869 })
870 .collect();
871 let buffer = Buffer::from_f32(data, crate::default_device());
872 Array::from_buffer(buffer, Shape::new(vec![n]))
873 }
874
875 pub fn triang(n: usize) -> Array {
886 let data: Vec<f32> = (0..n)
887 .map(|i| {
888 let half = (n as f32 + 1.0) / 2.0;
889 if i as f32 + 1.0 <= half {
890 2.0 * (i as f32 + 1.0) / (n as f32 + 1.0)
891 } else {
892 2.0 - 2.0 * (i as f32 + 1.0) / (n as f32 + 1.0)
893 }
894 })
895 .collect();
896 let buffer = Buffer::from_f32(data, crate::default_device());
897 Array::from_buffer(buffer, Shape::new(vec![n]))
898 }
899}
900
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Shorthand for a Float32 `arange` materialized as a `Vec`.
    fn arange_f32(start: f32, stop: f32, step: f32) -> Vec<f32> {
        Array::arange(start, stop, step, DType::Float32).to_vec()
    }

    /// Compares every element of `actual` with the same position of `expected`,
    /// using an absolute tolerance of 1e-6.
    fn assert_close_elementwise(actual: &[f32], expected: &[f32]) {
        for (pos, &got) in actual.iter().enumerate() {
            assert_abs_diff_eq!(got, expected[pos], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_arange() {
        assert_eq!(arange_f32(0.0, 10.0, 2.0), vec![0.0, 2.0, 4.0, 6.0, 8.0]);
        assert_eq!(arange_f32(0.0, 5.0, 1.0), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        assert_eq!(arange_f32(1.0, 2.0, 0.25), vec![1.0, 1.25, 1.5, 1.75]);
    }

    #[test]
    fn test_arange_negative_step() {
        assert_eq!(arange_f32(10.0, 0.0, -2.0), vec![10.0, 8.0, 6.0, 4.0, 2.0]);
    }

    #[test]
    fn test_arange_empty() {
        let empty = Array::arange(0.0, 0.0, 1.0, DType::Float32);
        assert_eq!(empty.size(), 0);
    }

    #[test]
    #[should_panic(expected = "Step must be non-zero")]
    fn test_arange_zero_step() {
        let _ = arange_f32(0.0, 10.0, 0.0);
    }

    #[test]
    fn test_linspace() {
        let samples = Array::linspace(0.0, 1.0, 5, true, DType::Float32).to_vec();
        assert_close_elementwise(&samples, &[0.0, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn test_linspace_no_endpoint() {
        let samples = Array::linspace(0.0, 1.0, 5, false, DType::Float32).to_vec();
        assert_close_elementwise(&samples, &[0.0, 0.2, 0.4, 0.6, 0.8]);
    }

    #[test]
    fn test_linspace_single() {
        let single = Array::linspace(5.0, 10.0, 1, true, DType::Float32);
        assert_eq!(single.to_vec(), vec![5.0]);
    }

    #[test]
    fn test_linspace_same_start_stop() {
        let values = Array::linspace(5.0, 5.0, 10, true, DType::Float32).to_vec();
        assert!(!values.iter().any(|&x| x != 5.0));
    }

    #[test]
    fn test_eye() {
        let square = Array::eye(3, None, DType::Float32);
        assert_eq!(square.shape().as_slice(), &[3, 3]);
        let expected = vec![
            1.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, //
            0.0, 0.0, 1.0,
        ];
        assert_eq!(square.to_vec(), expected);
    }

    #[test]
    fn test_eye_rectangular() {
        let wide = Array::eye(2, Some(4), DType::Float32);
        assert_eq!(wide.shape().as_slice(), &[2, 4]);
        let expected = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0,
        ];
        assert_eq!(wide.to_vec(), expected);
    }

    #[test]
    fn test_identity() {
        let ident = Array::identity(4, DType::Float32);
        assert_eq!(ident.shape().as_slice(), &[4, 4]);
        let values = ident.to_vec();
        // Diagonal entries of a 4x4 row-major matrix sit at flat offsets 0, 5, 10, 15.
        for flat in (0..16).step_by(5) {
            assert_eq!(values[flat], 1.0);
        }
        // Spot-check a few off-diagonal entries.
        for &flat in &[1usize, 2, 4] {
            assert_eq!(values[flat], 0.0);
        }
    }

    #[test]
    fn test_indices() {
        let grids = Array::indices(&[2, 3]);
        assert_eq!(grids.len(), 2);
        for grid in &grids {
            assert_eq!(grid.shape().as_slice(), &[2, 3]);
        }
        assert_eq!(grids[0].to_vec(), vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
        assert_eq!(grids[1].to_vec(), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_unravel_index() {
        let shape = Shape::new(vec![3, 4]);
        let cases: [(usize, [usize; 2]); 3] = [(0, [0, 0]), (5, [1, 1]), (11, [2, 3])];
        for &(flat, coords) in cases.iter() {
            assert_eq!(Array::unravel_index(flat, &shape), coords.to_vec());
        }
    }

    #[test]
    fn test_ravel_multi_index() {
        let shape = Shape::new(vec![3, 4]);
        let cases: [([usize; 2], usize); 3] = [([0, 0], 0), ([1, 2], 6), ([2, 3], 11)];
        for (coords, flat) in cases.iter() {
            assert_eq!(Array::ravel_multi_index(coords, &shape), *flat);
        }
    }

    #[test]
    fn test_diag_indices() {
        let (rows, cols) = Array::diag_indices(3);
        let expected = vec![0, 1, 2];
        assert_eq!(rows, expected);
        assert_eq!(cols, expected);
    }

    #[test]
    fn test_tril_indices() {
        assert_eq!(
            Array::tril_indices(3, 0),
            (vec![0, 1, 1, 2, 2, 2], vec![0, 0, 1, 0, 1, 2])
        );
        assert_eq!(
            Array::tril_indices(3, 1),
            (vec![0, 0, 1, 1, 1, 2, 2, 2], vec![0, 1, 0, 1, 2, 0, 1, 2])
        );
    }

    #[test]
    fn test_triu_indices() {
        assert_eq!(
            Array::triu_indices(3, 0),
            (vec![0, 0, 0, 1, 1, 2], vec![0, 1, 2, 1, 2, 2])
        );
        assert_eq!(
            Array::triu_indices(3, -1),
            (vec![0, 0, 0, 1, 1, 1, 2, 2], vec![0, 1, 2, 0, 1, 2, 1, 2])
        );
    }
}