1use axonml_core::dtype::{Float, Numeric, Scalar};
64use axonml_core::error::Result;
65
66use crate::tensor::Tensor;
67
68pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
74 if a.shape() != b.shape() {
75 return Err(axonml_core::error::Error::shape_mismatch(
76 a.shape(),
77 b.shape(),
78 ));
79 }
80
81 let a_data = a.to_vec();
82 let b_data = b.to_vec();
83
84 Ok(a_data
85 .iter()
86 .zip(b_data.iter())
87 .map(|(x, y)| x == y)
88 .collect())
89}
90
91pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
93 if a.shape() != b.shape() {
94 return Err(axonml_core::error::Error::shape_mismatch(
95 a.shape(),
96 b.shape(),
97 ));
98 }
99
100 let a_data = a.to_vec();
101 let b_data = b.to_vec();
102
103 Ok(a_data
104 .iter()
105 .zip(b_data.iter())
106 .map(|(x, y)| x < y)
107 .collect())
108}
109
110pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
112 if a.shape() != b.shape() {
113 return Err(axonml_core::error::Error::shape_mismatch(
114 a.shape(),
115 b.shape(),
116 ));
117 }
118
119 let a_data = a.to_vec();
120 let b_data = b.to_vec();
121
122 Ok(a_data
123 .iter()
124 .zip(b_data.iter())
125 .map(|(x, y)| x > y)
126 .collect())
127}
128
129pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
135 let data = x.to_vec();
137 let shape = x.shape();
138
139 if shape.is_empty() {
140 return Ok(Tensor::scalar(T::one()));
141 }
142
143 let max_val = data
145 .iter()
146 .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
147
148 let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
150
151 let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
153
154 let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
156
157 Tensor::from_vec(result, shape)
158}
159
160pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
162 let sm = softmax(x, dim)?;
163 Ok(sm.ln())
164}
165
166#[must_use]
168pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
169 let data = x.to_vec();
170 let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
171 let coeff = T::from(0.044715).unwrap();
172
173 let result: Vec<T> = data
174 .iter()
175 .map(|&v| {
176 let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
177 v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
178 })
179 .collect();
180
181 Tensor::from_vec(result, x.shape()).unwrap()
182}
183
184pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
186 let data = x.to_vec();
187 let result: Vec<T> = data
188 .iter()
189 .map(|&v| if v > T::zero() { v } else { negative_slope * v })
190 .collect();
191
192 Tensor::from_vec(result, x.shape()).unwrap()
193}
194
195pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
197 let data = x.to_vec();
198 let result: Vec<T> = data
199 .iter()
200 .map(|&v| {
201 if v > T::zero() {
202 v
203 } else {
204 alpha * (v.exp_value() - T::one())
205 }
206 })
207 .collect();
208
209 Tensor::from_vec(result, x.shape()).unwrap()
210}
211
212#[must_use]
214pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
215 let sig = x.sigmoid();
216 x.mul(&sig).unwrap()
217}
218
219pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
225 let data = x.to_vec();
226 let result: Vec<T> = data
227 .iter()
228 .map(|&v| {
229 if v < min {
230 min
231 } else if v > max {
232 max
233 } else {
234 v
235 }
236 })
237 .collect();
238
239 Tensor::from_vec(result, x.shape()).unwrap()
240}
241
242pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
244 let data = x.to_vec();
245 let result: Vec<T> = data
246 .iter()
247 .map(|&v| if v < min { min } else { v })
248 .collect();
249
250 Tensor::from_vec(result, x.shape()).unwrap()
251}
252
253pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
255 let data = x.to_vec();
256 let result: Vec<T> = data
257 .iter()
258 .map(|&v| if v > max { max } else { v })
259 .collect();
260
261 Tensor::from_vec(result, x.shape()).unwrap()
262}
263
264pub fn where_cond<T: Scalar>(
270 condition: &[bool],
271 x: &Tensor<T>,
272 y: &Tensor<T>,
273) -> Result<Tensor<T>> {
274 if x.shape() != y.shape() {
275 return Err(axonml_core::error::Error::shape_mismatch(
276 x.shape(),
277 y.shape(),
278 ));
279 }
280
281 if condition.len() != x.numel() {
282 return Err(axonml_core::error::Error::shape_mismatch(
283 &[condition.len()],
284 &[x.numel()],
285 ));
286 }
287
288 let x_data = x.to_vec();
289 let y_data = y.to_vec();
290
291 let result: Vec<T> = condition
292 .iter()
293 .zip(x_data.iter().zip(y_data.iter()))
294 .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
295 .collect();
296
297 Tensor::from_vec(result, x.shape())
298}
299
/// Result of [`topk`]: the selected values and their positions along the
/// reduced dimension.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// Selected elements; same shape as the input except the target
    /// dimension has size `k`.
    pub values: Tensor<T>,
    /// For each selected value, its index within the target dimension of
    /// the input tensor. Same shape as `values`.
    pub indices: Tensor<i64>,
}
312
// Manual Debug impl: prints only the shapes, not the (possibly large)
// tensor contents, and avoids requiring `T: Debug`.
impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
321
322pub fn topk<T: Numeric>(
334 x: &Tensor<T>,
335 k: usize,
336 dim: i64,
337 largest: bool,
338 sorted: bool,
339) -> Result<TopKResult<T>> {
340 let shape = x.shape();
341 if shape.is_empty() {
342 return Err(axonml_core::error::Error::invalid_operation(
343 "Cannot apply topk to scalar tensor".to_string(),
344 ));
345 }
346
347 let dim = if dim < 0 {
348 (shape.len() as i64 + dim) as usize
349 } else {
350 dim as usize
351 };
352
353 if dim >= shape.len() {
354 return Err(axonml_core::error::Error::invalid_operation(format!(
355 "Dimension {} out of range for tensor with {} dimensions",
356 dim,
357 shape.len()
358 )));
359 }
360
361 let dim_size = shape[dim];
362 if k > dim_size {
363 return Err(axonml_core::error::Error::invalid_operation(format!(
364 "k ({}) is larger than dimension size ({})",
365 k, dim_size
366 )));
367 }
368
369 let data = x.to_vec();
370
371 if shape.len() == 1 {
373 let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
374 if largest {
375 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
376 } else {
377 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
378 }
379
380 if !sorted {
381 indexed[..k].sort_by_key(|x| x.0);
382 }
383
384 let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
385 let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();
386
387 return Ok(TopKResult {
388 values: Tensor::from_vec(values, &[k])?,
389 indices: Tensor::from_vec(indices, &[k])?,
390 });
391 }
392
393 let outer_size: usize = shape[..dim].iter().product();
395 let inner_size: usize = shape[dim + 1..].iter().product();
396
397 let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
398 let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);
399
400 for outer in 0..outer_size {
401 for inner in 0..inner_size {
402 let mut slice: Vec<(usize, T)> = (0..dim_size)
403 .map(|d| {
404 let idx = outer * dim_size * inner_size + d * inner_size + inner;
405 (d, data[idx])
406 })
407 .collect();
408
409 if largest {
410 slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
411 } else {
412 slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
413 }
414
415 if !sorted {
416 slice[..k].sort_by_key(|x| x.0);
417 }
418
419 for (orig_idx, val) in slice.into_iter().take(k) {
420 values_data.push(val);
421 indices_data.push(orig_idx as i64);
422 }
423 }
424 }
425
426 let mut output_shape = shape.to_vec();
427 output_shape[dim] = k;
428
429 Ok(TopKResult {
430 values: Tensor::from_vec(values_data, &output_shape)?,
431 indices: Tensor::from_vec(indices_data, &output_shape)?,
432 })
433}
434
/// Result of [`sort`]: sorted values and the permutation that produced them.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// The input values sorted along the requested dimension.
    pub values: Tensor<T>,
    /// For each sorted value, its original index within the sorted
    /// dimension. Same shape as `values`.
    pub indices: Tensor<i64>,
}
443
// Manual Debug impl: prints only the shapes, not the (possibly large)
// tensor contents, and avoids requiring `T: Debug`.
impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
452
453pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
463 let shape = x.shape();
464 if shape.is_empty() {
465 return Ok(SortResult {
466 values: x.clone(),
467 indices: Tensor::scalar(0i64),
468 });
469 }
470
471 let dim = if dim < 0 {
472 (shape.len() as i64 + dim) as usize
473 } else {
474 dim as usize
475 };
476
477 let dim_size = shape[dim];
478 topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
479 values: tk.values,
480 indices: tk.indices,
481 })
482}
483
484pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
491 sort(x, dim, descending).map(|r| r.indices)
492}
493
494pub fn scatter<T: Scalar>(
508 dst: &Tensor<T>,
509 dim: usize,
510 index: &Tensor<i64>,
511 src: &Tensor<T>,
512) -> Result<Tensor<T>> {
513 let dst_shape = dst.shape();
514 let idx_shape = index.shape();
515 let src_shape = src.shape();
516
517 if idx_shape != src_shape {
518 return Err(axonml_core::error::Error::shape_mismatch(
519 idx_shape, src_shape,
520 ));
521 }
522
523 if dim >= dst_shape.len() {
524 return Err(axonml_core::error::Error::invalid_operation(format!(
525 "Dimension {} out of range",
526 dim
527 )));
528 }
529
530 let mut result = dst.to_vec();
531 let idx_data = index.to_vec();
532 let src_data = src.to_vec();
533
534 let mut dst_strides = vec![1usize; dst_shape.len()];
536 for i in (0..dst_shape.len() - 1).rev() {
537 dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
538 }
539
540 let mut idx_strides = vec![1usize; idx_shape.len()];
542 for i in (0..idx_shape.len() - 1).rev() {
543 idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
544 }
545
546 let total = index.numel();
548 for linear_idx in 0..total {
549 let mut nd_idx = vec![0usize; idx_shape.len()];
551 let mut remaining = linear_idx;
552 for d in 0..idx_shape.len() {
553 nd_idx[d] = remaining / idx_strides[d];
554 remaining %= idx_strides[d];
555 }
556
557 let scatter_idx = idx_data[linear_idx] as usize;
559
560 let mut dst_nd_idx = nd_idx.clone();
562 dst_nd_idx[dim] = scatter_idx;
563
564 let mut dst_linear = 0;
566 for d in 0..dst_shape.len() {
567 dst_linear += dst_nd_idx[d] * dst_strides[d];
568 }
569
570 result[dst_linear] = src_data[linear_idx];
571 }
572
573 Tensor::from_vec(result, dst_shape)
574}
575
576pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
588 let data = x.to_vec();
589 let shape = x.shape();
590 let ndim = shape.len();
591
592 let mut indices: Vec<Vec<i64>> = Vec::new();
594
595 let mut strides = vec![1usize; ndim.max(1)];
597 for i in (0..ndim.saturating_sub(1)).rev() {
598 strides[i] = strides[i + 1] * shape[i + 1];
599 }
600
601 for (linear_idx, &val) in data.iter().enumerate() {
602 if val != T::zero() {
603 let mut nd_idx = vec![0i64; ndim.max(1)];
604 let mut remaining = linear_idx;
605 for d in 0..ndim {
606 nd_idx[d] = (remaining / strides[d]) as i64;
607 remaining %= strides[d];
608 }
609 indices.push(nd_idx);
610 }
611 }
612
613 let num_nonzero = indices.len();
614 if num_nonzero == 0 {
615 return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
616 }
617
618 let flat: Vec<i64> = indices.into_iter().flatten().collect();
619 Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
620}
621
/// Result of [`unique`]: deduplicated values plus optional bookkeeping.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// The distinct values, as a 1-D tensor.
    pub values: Tensor<T>,
    /// For each input element, the position of its value in `values`.
    /// Present only when `return_inverse` was requested; same shape as
    /// the input.
    pub inverse_indices: Option<Tensor<i64>>,
    /// Occurrence count for each entry of `values`. Present only when
    /// `return_counts` was requested.
    pub counts: Option<Tensor<i64>>,
}
636
// Manual Debug impl: prints shapes and option-presence only, avoiding a
// `T: Debug` bound and large data dumps.
impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}
646
/// Deduplicates the tensor's elements.
///
/// * `sorted` — return unique values in ascending order (incomparable
///   values, e.g. NaN, compare as equal) instead of first-seen order.
/// * `return_inverse` — also map each input element to its unique-value
///   index (shape of the input).
/// * `return_counts` — also report how often each unique value occurs.
///
/// Uses a linear scan per element (O(n * u) overall) because `T` is only
/// required to support `PartialEq`/`PartialOrd`, not `Hash`/`Ord` — so a
/// hash-set fast path is unavailable (floats, for instance).
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // `seen` holds unique values in first-seen order; `counts_map[i]` and
    // `inverse` refer to positions in `seen`.
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort unique values while remembering their first-seen position,
        // so counts and inverse indices can be remapped to the new order.
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // old position (first-seen order) -> new position (sorted order).
        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed
            .iter()
            .map(|(old_idx, _)| counts_map[*old_idx])
            .collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).unwrap(),
        inverse_indices: if return_inverse {
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).unwrap())
        } else {
            None
        },
    }
}
717
718pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
724 let shape = x.shape();
725 let data = x.to_vec();
726 let ndim = shape.len();
727
728 for &d in dims {
729 if d >= ndim {
730 return Err(axonml_core::error::Error::invalid_operation(format!(
731 "Dimension {} out of range for tensor with {} dimensions",
732 d, ndim
733 )));
734 }
735 }
736
737 if shape.is_empty() {
738 return Ok(x.clone());
739 }
740
741 let mut strides = vec![1usize; ndim];
743 for i in (0..ndim - 1).rev() {
744 strides[i] = strides[i + 1] * shape[i + 1];
745 }
746
747 let mut result = vec![T::zero(); data.len()];
748
749 for src_linear in 0..data.len() {
750 let mut nd_idx = vec![0usize; ndim];
752 let mut remaining = src_linear;
753 for d in 0..ndim {
754 nd_idx[d] = remaining / strides[d];
755 remaining %= strides[d];
756 }
757
758 for &flip_dim in dims {
760 nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
761 }
762
763 let mut dst_linear = 0;
765 for d in 0..ndim {
766 dst_linear += nd_idx[d] * strides[d];
767 }
768
769 result[dst_linear] = data[src_linear];
770 }
771
772 Tensor::from_vec(result, shape)
773}
774
/// Cyclically shifts elements along the given dimensions.
///
/// `shifts[i]` (may be negative) is applied along `dims[i]`; the two
/// slices are paired positionally and must have equal length. A positive
/// shift moves elements toward higher indices, wrapping around. Scalars
/// (empty shape) are returned unchanged. Flat-index math assumes
/// contiguous row-major layout, as elsewhere in this file.
///
/// # Errors
/// Returns an error when the slice lengths differ or a dimension is out
/// of range.
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range",
                d
            )));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides.
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    // Scatter: for each source element, compute where it lands.
    for src_linear in 0..data.len() {
        // Decompose the flat source position into coordinates.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply each shift with double-modulo to handle negative shifts
        // (Rust's % keeps the sign of the dividend).
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        // Recompose the destination flat position.
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
839
// Unit tests for the comparison, activation, and manipulation helpers above.
#[cfg(test)]
mod tests {
    use super::*;

    // Softmax output must be a probability distribution (sums to 1).
    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    // eq/lt/gt over the same operand pair.
    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    // Largest-3, rank-sorted: values descending, indices into the input.
    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    // 1-D input: one coordinate column per non-zero element.
    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    // 2-D input: each row of the result is an (i, j) coordinate pair.
    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    // sorted + inverse + counts all requested at once.
    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    // Unsorted mode preserves first-seen order.
    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    // Flipping dim 0 of a 2x2 swaps the rows.
    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    // Positive shift rotates toward higher indices, wrapping around.
    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }
}