use axonml_core::dtype::{Float, Numeric, Scalar};
use axonml_core::error::Result;

use crate::tensor::Tensor;

/// Element-wise equality comparison. Returns one `bool` per element in
/// row-major order, or a shape-mismatch error if the shapes differ.
pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x == y)
        .collect())
}

/// Element-wise less-than comparison (`a < b`), with the same shape rules
/// as [`eq`].
pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x < y)
        .collect())
}

/// Element-wise greater-than comparison (`a > b`), with the same shape rules
/// as [`eq`].
pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x > y)
        .collect())
}

/// Softmax over the tensor's elements.
///
/// Note: the `dim` argument is currently ignored; normalization is performed
/// over the entire flattened tensor. The input maximum is subtracted before
/// exponentiating for numerical stability.
pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
    let data = x.to_vec();
    let shape = x.shape();

    if shape.is_empty() {
        return Ok(Tensor::scalar(T::one()));
    }

    // Subtract the max before exponentiating to avoid overflow.
    let max_val = data
        .iter()
        .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });

    let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();

    let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);

    let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();

    Tensor::from_vec(result, shape)
}

/// Log-softmax, computed as `ln(softmax(x))`. Inherits the `dim` limitation
/// of [`softmax`].
pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
    let sm = softmax(x, dim)?;
    Ok(sm.ln())
}

/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[must_use]
pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let data = x.to_vec();
    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
    let coeff = T::from(0.044715).unwrap();
    let half = T::from(0.5).unwrap();

    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
            v * half * (T::one() + inner.tanh_value())
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Leaky ReLU: returns `v` for positive elements and `negative_slope * v`
/// otherwise.
pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// ELU: returns `v` for positive elements and `alpha * (exp(v) - 1)`
/// otherwise.
pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v > T::zero() {
                v
            } else {
                alpha * (v.exp_value() - T::one())
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// SiLU (a.k.a. swish): `x * sigmoid(x)`.
#[must_use]
pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let sig = x.sigmoid();
    x.mul(&sig).unwrap()
}

/// Clamps every element to the closed range `[min, max]`.
pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v < min {
                min
            } else if v > max {
                max
            } else {
                v
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps every element to be at least `min`.
pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v < min { min } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps every element to be at most `max`.
pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > max { max } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Element-wise selection: picks from `x` where `condition` is `true` and
/// from `y` otherwise. `condition` is a flat mask with one entry per element.
pub fn where_cond<T: Scalar>(
    condition: &[bool],
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> Result<Tensor<T>> {
    if x.shape() != y.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            x.shape(),
            y.shape(),
        ));
    }

    if condition.len() != x.numel() {
        return Err(axonml_core::error::Error::shape_mismatch(
            &[condition.len()],
            &[x.numel()],
        ));
    }

    let x_data = x.to_vec();
    let y_data = y.to_vec();

    let result: Vec<T> = condition
        .iter()
        .zip(x_data.iter().zip(y_data.iter()))
        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
        .collect();

    Tensor::from_vec(result, x.shape())
}

/// Result of a [`topk`] call.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// The selected values.
    pub values: Tensor<T>,
    /// Indices of the selected values along the reduced dimension.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

/// Returns the `k` largest (or smallest, if `largest` is `false`) elements
/// along `dim`, together with their indices.
///
/// Negative `dim` counts from the last dimension. When `sorted` is `true`
/// the results are ordered by value; otherwise they keep their original
/// index order.
pub fn topk<T: Numeric>(
    x: &Tensor<T>,
    k: usize,
    dim: i64,
    largest: bool,
    sorted: bool,
) -> Result<TopKResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Err(axonml_core::error::Error::invalid_operation(
            "Cannot apply topk to scalar tensor".to_string(),
        ));
    }

    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("Dimension {} out of range for tensor with {} dimensions", dim, shape.len()),
        ));
    }

    let dim_size = shape[dim];
    if k > dim_size {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("k ({}) is larger than dimension size ({})", k, dim_size),
        ));
    }

    let data = x.to_vec();

    // Fast path for 1-D tensors: sort (index, value) pairs directly.
    if shape.len() == 1 {
        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
        if largest {
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        if !sorted {
            indexed[..k].sort_by_key(|p| p.0);
        }

        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();

        return Ok(TopKResult {
            values: Tensor::from_vec(values, &[k])?,
            indices: Tensor::from_vec(indices, &[k])?,
        });
    }

    // General case: treat the tensor as [outer, dim, inner] and select the
    // top k entries independently for every (outer, inner) slice.
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();

    let mut values_data = vec![T::zero(); outer_size * k * inner_size];
    let mut indices_data = vec![0i64; outer_size * k * inner_size];

    for outer in 0..outer_size {
        for inner in 0..inner_size {
            let mut slice: Vec<(usize, T)> = (0..dim_size)
                .map(|d| {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    (d, data[idx])
                })
                .collect();

            if largest {
                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            } else {
                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }

            if !sorted {
                slice[..k].sort_by_key(|p| p.0);
            }

            // Write at the computed offset so the output stays in row-major
            // [outer, k, inner] order (pushing sequentially would interleave
            // the k and inner axes whenever inner_size > 1).
            for (j, (orig_idx, val)) in slice.into_iter().take(k).enumerate() {
                let out_idx = outer * k * inner_size + j * inner_size + inner;
                values_data[out_idx] = val;
                indices_data[out_idx] = orig_idx as i64;
            }
        }
    }

    let mut output_shape = shape.to_vec();
    output_shape[dim] = k;

    Ok(TopKResult {
        values: Tensor::from_vec(values_data, &output_shape)?,
        indices: Tensor::from_vec(indices_data, &output_shape)?,
    })
}

/// Result of a [`sort`] call.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// The sorted values.
    pub values: Tensor<T>,
    /// The original index of each sorted value.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

/// Sorts `x` along `dim`, returning both the sorted values and the original
/// indices. Implemented via [`topk`] with `k` equal to the dimension size.
pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Ok(SortResult {
            values: x.clone(),
            indices: Tensor::scalar(0i64),
        });
    }

    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    // Validate before indexing so an out-of-range dim yields an error
    // rather than a panic.
    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("Dimension {} out of range for tensor with {} dimensions", dim, shape.len()),
        ));
    }

    let dim_size = shape[dim];
    topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
        values: tk.values,
        indices: tk.indices,
    })
}

/// Returns the indices that would sort `x` along `dim`.
pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
    sort(x, dim, descending).map(|r| r.indices)
}

/// Writes `src` into a copy of `dst` along `dim`, using `index` to choose
/// the destination coordinate in that dimension. `index` and `src` must have
/// the same shape (and the same rank as `dst`).
pub fn scatter<T: Scalar>(
    dst: &Tensor<T>,
    dim: usize,
    index: &Tensor<i64>,
    src: &Tensor<T>,
) -> Result<Tensor<T>> {
    let dst_shape = dst.shape();
    let idx_shape = index.shape();
    let src_shape = src.shape();

    if idx_shape != src_shape {
        return Err(axonml_core::error::Error::shape_mismatch(idx_shape, src_shape));
    }

    if dim >= dst_shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("Dimension {} out of range", dim),
        ));
    }

    let mut result = dst.to_vec();
    let idx_data = index.to_vec();
    let src_data = src.to_vec();

    // Row-major strides for both tensors.
    let mut dst_strides = vec![1usize; dst_shape.len()];
    for i in (0..dst_shape.len() - 1).rev() {
        dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
    }

    let mut idx_strides = vec![1usize; idx_shape.len()];
    for i in (0..idx_shape.len() - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
    }

    let total = index.numel();
    for linear_idx in 0..total {
        // Unravel the linear index into n-d coordinates.
        let mut nd_idx = vec![0usize; idx_shape.len()];
        let mut remaining = linear_idx;
        for d in 0..idx_shape.len() {
            nd_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }

        let scatter_idx = idx_data[linear_idx] as usize;
        if scatter_idx >= dst_shape[dim] {
            return Err(axonml_core::error::Error::invalid_operation(
                format!("Index {} out of range for dimension of size {}", scatter_idx, dst_shape[dim]),
            ));
        }

        // Replace the coordinate along `dim` with the scatter index.
        let mut dst_nd_idx = nd_idx.clone();
        dst_nd_idx[dim] = scatter_idx;

        let mut dst_linear = 0;
        for d in 0..dst_shape.len() {
            dst_linear += dst_nd_idx[d] * dst_strides[d];
        }

        result[dst_linear] = src_data[linear_idx];
    }

    Tensor::from_vec(result, dst_shape)
}

/// Returns the coordinates of every nonzero element as an `[n, ndim]` tensor
/// of `i64` indices; scalars are treated as one-dimensional, and an all-zero
/// input yields an empty `[0, ndim]` tensor.
pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
    let data = x.to_vec();
    let shape = x.shape();
    let ndim = shape.len();

    let mut indices: Vec<Vec<i64>> = Vec::new();

    // Row-major strides (scalars get a single stride of 1).
    let mut strides = vec![1usize; ndim.max(1)];
    for i in (0..ndim.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    for (linear_idx, &val) in data.iter().enumerate() {
        if val != T::zero() {
            let mut nd_idx = vec![0i64; ndim.max(1)];
            let mut remaining = linear_idx;
            for d in 0..ndim {
                nd_idx[d] = (remaining / strides[d]) as i64;
                remaining %= strides[d];
            }
            indices.push(nd_idx);
        }
    }

    let num_nonzero = indices.len();
    if num_nonzero == 0 {
        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
    }

    let flat: Vec<i64> = indices.into_iter().flatten().collect();
    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
}

/// Result of a [`unique`] call.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// The unique values.
    pub values: Tensor<T>,
    /// For each input element, the position of its value in `values`
    /// (present when `return_inverse` is set).
    pub inverse_indices: Option<Tensor<i64>>,
    /// The number of occurrences of each unique value (present when
    /// `return_counts` is set).
    pub counts: Option<Tensor<i64>>,
}

impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}

/// Finds the unique values in `x`, optionally returning inverse indices and
/// per-value counts. Uses a linear scan per element (O(n^2) overall), which
/// is adequate for small inputs.
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // First pass: record values in first-seen order, together with counts
    // and the inverse mapping.
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort the unique values and remap the inverse indices and counts to
        // the new order.
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed.iter().map(|(old_idx, _)| counts_map[*old_idx]).collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).unwrap(),
        inverse_indices: if return_inverse {
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).unwrap())
        } else {
            None
        },
    }
}

/// Reverses the order of elements along each dimension in `dims`.
pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(
                format!("Dimension {} out of range for tensor with {} dimensions", d, ndim),
            ));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides.
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Unravel the linear index into n-d coordinates.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Mirror the coordinate along every flipped dimension.
        for &flip_dim in dims {
            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
        }

        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}

/// Cyclically shifts elements along the given dimensions. `shifts[i]`
/// applies to `dims[i]`; negative shifts roll toward lower indices.
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(
                format!("Dimension {} out of range", d),
            ));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides.
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Unravel the linear index into n-d coordinates.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply each shift with wrap-around (the double modulo keeps the
        // result non-negative for negative shifts).
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }
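
    // An all-zero input yields an empty [0, ndim] index tensor.
    #[test]
    fn test_nonzero_empty() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[0, 1]);
    }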

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(result.inverse_indices.unwrap().to_vec(), vec![0, 1, 0, 2, 1, 0]);
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }
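
    // Flipping along the last dimension reverses each row independently.
    #[test]
    fn test_flip_2d_dim1() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[1]).unwrap();

        assert_eq!(flipped.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }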

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }
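
    // Rolling along dim 1 shifts each row independently, wrapping around.
    #[test]
    fn test_roll_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let rolled = roll(&t, &[1], &[1]).unwrap();

        assert_eq!(rolled.to_vec(), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }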

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }
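
    // 2-D scatter along dim 0: each src element lands at the row given by
    // the matching index entry, keeping its column.
    #[test]
    fn test_scatter_2d() {
        let dst = Tensor::<f32>::zeros(&[2, 2]);
        let index = Tensor::from_vec(vec![1_i64, 0, 0, 1], &[2, 2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![3.0, 2.0, 1.0, 4.0]);
    }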
}