1use axonml_core::dtype::{Float, Numeric, Scalar};
21use axonml_core::error::Result;
22
23use crate::tensor::Tensor;
24
25pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
31 if a.shape() != b.shape() {
32 return Err(axonml_core::error::Error::shape_mismatch(
33 a.shape(),
34 b.shape(),
35 ));
36 }
37
38 let a_data = a.to_vec();
39 let b_data = b.to_vec();
40
41 Ok(a_data
42 .iter()
43 .zip(b_data.iter())
44 .map(|(x, y)| x == y)
45 .collect())
46}
47
48pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
50 if a.shape() != b.shape() {
51 return Err(axonml_core::error::Error::shape_mismatch(
52 a.shape(),
53 b.shape(),
54 ));
55 }
56
57 let a_data = a.to_vec();
58 let b_data = b.to_vec();
59
60 Ok(a_data
61 .iter()
62 .zip(b_data.iter())
63 .map(|(x, y)| x < y)
64 .collect())
65}
66
67pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
69 if a.shape() != b.shape() {
70 return Err(axonml_core::error::Error::shape_mismatch(
71 a.shape(),
72 b.shape(),
73 ));
74 }
75
76 let a_data = a.to_vec();
77 let b_data = b.to_vec();
78
79 Ok(a_data
80 .iter()
81 .zip(b_data.iter())
82 .map(|(x, y)| x > y)
83 .collect())
84}
85
86pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
92 let data = x.to_vec();
94 let shape = x.shape();
95
96 if shape.is_empty() {
97 return Ok(Tensor::scalar(T::one()));
98 }
99
100 let max_val = data
102 .iter()
103 .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
104
105 let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
107
108 let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
110
111 let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
113
114 Tensor::from_vec(result, shape)
115}
116
117pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
119 let sm = softmax(x, dim)?;
120 Ok(sm.ln())
121}
122
123#[must_use]
125pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
126 let data = x.to_vec();
127 let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
128 let coeff = T::from(0.044715).unwrap();
129
130 let result: Vec<T> = data
131 .iter()
132 .map(|&v| {
133 let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
134 v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
135 })
136 .collect();
137
138 Tensor::from_vec(result, x.shape()).unwrap()
139}
140
141pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
143 let data = x.to_vec();
144 let result: Vec<T> = data
145 .iter()
146 .map(|&v| if v > T::zero() { v } else { negative_slope * v })
147 .collect();
148
149 Tensor::from_vec(result, x.shape()).unwrap()
150}
151
152pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
154 let data = x.to_vec();
155 let result: Vec<T> = data
156 .iter()
157 .map(|&v| {
158 if v > T::zero() {
159 v
160 } else {
161 alpha * (v.exp_value() - T::one())
162 }
163 })
164 .collect();
165
166 Tensor::from_vec(result, x.shape()).unwrap()
167}
168
169#[must_use]
171pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
172 let sig = x.sigmoid();
173 x.mul(&sig).unwrap()
174}
175
176pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
182 let data = x.to_vec();
183 let result: Vec<T> = data
184 .iter()
185 .map(|&v| {
186 if v < min {
187 min
188 } else if v > max {
189 max
190 } else {
191 v
192 }
193 })
194 .collect();
195
196 Tensor::from_vec(result, x.shape()).unwrap()
197}
198
199pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
201 let data = x.to_vec();
202 let result: Vec<T> = data
203 .iter()
204 .map(|&v| if v < min { min } else { v })
205 .collect();
206
207 Tensor::from_vec(result, x.shape()).unwrap()
208}
209
210pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
212 let data = x.to_vec();
213 let result: Vec<T> = data
214 .iter()
215 .map(|&v| if v > max { max } else { v })
216 .collect();
217
218 Tensor::from_vec(result, x.shape()).unwrap()
219}
220
221pub fn where_cond<T: Scalar>(
227 condition: &[bool],
228 x: &Tensor<T>,
229 y: &Tensor<T>,
230) -> Result<Tensor<T>> {
231 if x.shape() != y.shape() {
232 return Err(axonml_core::error::Error::shape_mismatch(
233 x.shape(),
234 y.shape(),
235 ));
236 }
237
238 if condition.len() != x.numel() {
239 return Err(axonml_core::error::Error::shape_mismatch(
240 &[condition.len()],
241 &[x.numel()],
242 ));
243 }
244
245 let x_data = x.to_vec();
246 let y_data = y.to_vec();
247
248 let result: Vec<T> = condition
249 .iter()
250 .zip(x_data.iter().zip(y_data.iter()))
251 .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
252 .collect();
253
254 Tensor::from_vec(result, x.shape())
255}
256
257#[derive(Clone)]
263pub struct TopKResult<T: Scalar> {
264 pub values: Tensor<T>,
266 pub indices: Tensor<i64>,
268}
269
270impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 f.debug_struct("TopKResult")
273 .field("values_shape", &self.values.shape())
274 .field("indices_shape", &self.indices.shape())
275 .finish()
276 }
277}
278
279pub fn topk<T: Numeric>(
291 x: &Tensor<T>,
292 k: usize,
293 dim: i64,
294 largest: bool,
295 sorted: bool,
296) -> Result<TopKResult<T>> {
297 let shape = x.shape();
298 if shape.is_empty() {
299 return Err(axonml_core::error::Error::invalid_operation(
300 "Cannot apply topk to scalar tensor".to_string(),
301 ));
302 }
303
304 let dim = if dim < 0 {
305 (shape.len() as i64 + dim) as usize
306 } else {
307 dim as usize
308 };
309
310 if dim >= shape.len() {
311 return Err(axonml_core::error::Error::invalid_operation(format!(
312 "Dimension {} out of range for tensor with {} dimensions",
313 dim,
314 shape.len()
315 )));
316 }
317
318 let dim_size = shape[dim];
319 if k > dim_size {
320 return Err(axonml_core::error::Error::invalid_operation(format!(
321 "k ({}) is larger than dimension size ({})",
322 k, dim_size
323 )));
324 }
325
326 let data = x.to_vec();
327
328 if shape.len() == 1 {
330 let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
331 if largest {
332 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
333 } else {
334 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
335 }
336
337 if !sorted {
338 indexed[..k].sort_by_key(|x| x.0);
339 }
340
341 let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
342 let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();
343
344 return Ok(TopKResult {
345 values: Tensor::from_vec(values, &[k])?,
346 indices: Tensor::from_vec(indices, &[k])?,
347 });
348 }
349
350 let outer_size: usize = shape[..dim].iter().product();
352 let inner_size: usize = shape[dim + 1..].iter().product();
353
354 let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
355 let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);
356
357 for outer in 0..outer_size {
358 for inner in 0..inner_size {
359 let mut slice: Vec<(usize, T)> = (0..dim_size)
360 .map(|d| {
361 let idx = outer * dim_size * inner_size + d * inner_size + inner;
362 (d, data[idx])
363 })
364 .collect();
365
366 if largest {
367 slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
368 } else {
369 slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
370 }
371
372 if !sorted {
373 slice[..k].sort_by_key(|x| x.0);
374 }
375
376 for (orig_idx, val) in slice.into_iter().take(k) {
377 values_data.push(val);
378 indices_data.push(orig_idx as i64);
379 }
380 }
381 }
382
383 let mut output_shape = shape.to_vec();
384 output_shape[dim] = k;
385
386 Ok(TopKResult {
387 values: Tensor::from_vec(values_data, &output_shape)?,
388 indices: Tensor::from_vec(indices_data, &output_shape)?,
389 })
390}
391
392#[derive(Clone)]
394pub struct SortResult<T: Scalar> {
395 pub values: Tensor<T>,
397 pub indices: Tensor<i64>,
399}
400
401impl<T: Scalar> std::fmt::Debug for SortResult<T> {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 f.debug_struct("SortResult")
404 .field("values_shape", &self.values.shape())
405 .field("indices_shape", &self.indices.shape())
406 .finish()
407 }
408}
409
410pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
420 let shape = x.shape();
421 if shape.is_empty() {
422 return Ok(SortResult {
423 values: x.clone(),
424 indices: Tensor::scalar(0i64),
425 });
426 }
427
428 let dim = if dim < 0 {
429 (shape.len() as i64 + dim) as usize
430 } else {
431 dim as usize
432 };
433
434 let dim_size = shape[dim];
435 topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
436 values: tk.values,
437 indices: tk.indices,
438 })
439}
440
441pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
448 sort(x, dim, descending).map(|r| r.indices)
449}
450
451pub fn scatter<T: Scalar>(
465 dst: &Tensor<T>,
466 dim: usize,
467 index: &Tensor<i64>,
468 src: &Tensor<T>,
469) -> Result<Tensor<T>> {
470 let dst_shape = dst.shape();
471 let idx_shape = index.shape();
472 let src_shape = src.shape();
473
474 if idx_shape != src_shape {
475 return Err(axonml_core::error::Error::shape_mismatch(
476 idx_shape, src_shape,
477 ));
478 }
479
480 if dim >= dst_shape.len() {
481 return Err(axonml_core::error::Error::invalid_operation(format!(
482 "Dimension {} out of range",
483 dim
484 )));
485 }
486
487 let mut result = dst.to_vec();
488 let idx_data = index.to_vec();
489 let src_data = src.to_vec();
490
491 let mut dst_strides = vec![1usize; dst_shape.len()];
493 for i in (0..dst_shape.len() - 1).rev() {
494 dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
495 }
496
497 let mut idx_strides = vec![1usize; idx_shape.len()];
499 for i in (0..idx_shape.len() - 1).rev() {
500 idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
501 }
502
503 let total = index.numel();
505 for linear_idx in 0..total {
506 let mut nd_idx = vec![0usize; idx_shape.len()];
508 let mut remaining = linear_idx;
509 for d in 0..idx_shape.len() {
510 nd_idx[d] = remaining / idx_strides[d];
511 remaining %= idx_strides[d];
512 }
513
514 let scatter_idx = idx_data[linear_idx] as usize;
516
517 let mut dst_nd_idx = nd_idx.clone();
519 dst_nd_idx[dim] = scatter_idx;
520
521 let mut dst_linear = 0;
523 for d in 0..dst_shape.len() {
524 dst_linear += dst_nd_idx[d] * dst_strides[d];
525 }
526
527 result[dst_linear] = src_data[linear_idx];
528 }
529
530 Tensor::from_vec(result, dst_shape)
531}
532
533pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
545 let data = x.to_vec();
546 let shape = x.shape();
547 let ndim = shape.len();
548
549 let mut indices: Vec<Vec<i64>> = Vec::new();
551
552 let mut strides = vec![1usize; ndim.max(1)];
554 for i in (0..ndim.saturating_sub(1)).rev() {
555 strides[i] = strides[i + 1] * shape[i + 1];
556 }
557
558 for (linear_idx, &val) in data.iter().enumerate() {
559 if val != T::zero() {
560 let mut nd_idx = vec![0i64; ndim.max(1)];
561 let mut remaining = linear_idx;
562 for d in 0..ndim {
563 nd_idx[d] = (remaining / strides[d]) as i64;
564 remaining %= strides[d];
565 }
566 indices.push(nd_idx);
567 }
568 }
569
570 let num_nonzero = indices.len();
571 if num_nonzero == 0 {
572 return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
573 }
574
575 let flat: Vec<i64> = indices.into_iter().flatten().collect();
576 Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
577}
578
579#[derive(Clone)]
585pub struct UniqueResult<T: Scalar> {
586 pub values: Tensor<T>,
588 pub inverse_indices: Option<Tensor<i64>>,
590 pub counts: Option<Tensor<i64>>,
592}
593
594impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
595 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596 f.debug_struct("UniqueResult")
597 .field("values_shape", &self.values.shape())
598 .field("has_inverse", &self.inverse_indices.is_some())
599 .field("has_counts", &self.counts.is_some())
600 .finish()
601 }
602}
603
604pub fn unique<T: Numeric>(
612 x: &Tensor<T>,
613 sorted: bool,
614 return_inverse: bool,
615 return_counts: bool,
616) -> UniqueResult<T> {
617 let data = x.to_vec();
618
619 let mut seen: Vec<T> = Vec::new();
621 let mut counts_map: Vec<i64> = Vec::new();
622 let mut inverse: Vec<i64> = Vec::with_capacity(data.len());
623
624 for &val in &data {
625 if let Some(pos) = seen.iter().position(|&v| v == val) {
626 inverse.push(pos as i64);
627 counts_map[pos] += 1;
628 } else {
629 inverse.push(seen.len() as i64);
630 seen.push(val);
631 counts_map.push(1);
632 }
633 }
634
635 let (unique_vals, final_inverse, final_counts) = if sorted {
636 let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
638 indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
639
640 let mut old_to_new = vec![0i64; indexed.len()];
642 for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
643 old_to_new[*old_idx] = new_idx as i64;
644 }
645
646 let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
647 let sorted_counts: Vec<i64> = indexed
648 .iter()
649 .map(|(old_idx, _)| counts_map[*old_idx])
650 .collect();
651 let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();
652
653 (sorted_vals, updated_inverse, sorted_counts)
654 } else {
655 (seen, inverse, counts_map)
656 };
657
658 let n = unique_vals.len();
659
660 UniqueResult {
661 values: Tensor::from_vec(unique_vals, &[n]).unwrap(),
662 inverse_indices: if return_inverse {
663 Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
664 } else {
665 None
666 },
667 counts: if return_counts {
668 Some(Tensor::from_vec(final_counts, &[n]).unwrap())
669 } else {
670 None
671 },
672 }
673}
674
675pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
681 let shape = x.shape();
682 let data = x.to_vec();
683 let ndim = shape.len();
684
685 for &d in dims {
686 if d >= ndim {
687 return Err(axonml_core::error::Error::invalid_operation(format!(
688 "Dimension {} out of range for tensor with {} dimensions",
689 d, ndim
690 )));
691 }
692 }
693
694 if shape.is_empty() {
695 return Ok(x.clone());
696 }
697
698 let mut strides = vec![1usize; ndim];
700 for i in (0..ndim - 1).rev() {
701 strides[i] = strides[i + 1] * shape[i + 1];
702 }
703
704 let mut result = vec![T::zero(); data.len()];
705
706 for src_linear in 0..data.len() {
707 let mut nd_idx = vec![0usize; ndim];
709 let mut remaining = src_linear;
710 for d in 0..ndim {
711 nd_idx[d] = remaining / strides[d];
712 remaining %= strides[d];
713 }
714
715 for &flip_dim in dims {
717 nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
718 }
719
720 let mut dst_linear = 0;
722 for d in 0..ndim {
723 dst_linear += nd_idx[d] * strides[d];
724 }
725
726 result[dst_linear] = data[src_linear];
727 }
728
729 Tensor::from_vec(result, shape)
730}
731
732pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
738 if shifts.len() != dims.len() {
739 return Err(axonml_core::error::Error::invalid_operation(
740 "shifts and dims must have the same length".to_string(),
741 ));
742 }
743
744 let shape = x.shape();
745 let data = x.to_vec();
746 let ndim = shape.len();
747
748 for &d in dims {
749 if d >= ndim {
750 return Err(axonml_core::error::Error::invalid_operation(format!(
751 "Dimension {} out of range",
752 d
753 )));
754 }
755 }
756
757 if shape.is_empty() {
758 return Ok(x.clone());
759 }
760
761 let mut strides = vec![1usize; ndim];
763 for i in (0..ndim - 1).rev() {
764 strides[i] = strides[i + 1] * shape[i + 1];
765 }
766
767 let mut result = vec![T::zero(); data.len()];
768
769 for src_linear in 0..data.len() {
770 let mut nd_idx = vec![0usize; ndim];
772 let mut remaining = src_linear;
773 for d in 0..ndim {
774 nd_idx[d] = remaining / strides[d];
775 remaining %= strides[d];
776 }
777
778 for (shift, &dim) in shifts.iter().zip(dims.iter()) {
780 let dim_size = shape[dim] as i64;
781 let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
782 nd_idx[dim] = new_idx as usize;
783 }
784
785 let mut dst_linear = 0;
787 for d in 0..ndim {
788 dst_linear += nd_idx[d] * strides[d];
789 }
790
791 result[dst_linear] = data[src_linear];
792 }
793
794 Tensor::from_vec(result, shape)
795}
796
// Unit tests for the tensor utility operations in this module. The tests after
// `test_scatter` cover functions and multi-dimensional cases that had no tests.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }

    #[test]
    fn test_clamp_min_max() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        assert_eq!(clamp_min(&t, 0.0).to_vec(), vec![0.0, 0.5, 2.0]);
        assert_eq!(clamp_max(&t, 1.0).to_vec(), vec![-1.0, 0.5, 1.0]);
    }

    #[test]
    fn test_elu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 2.0], &[3]).unwrap();
        let r = elu(&t, 1.0).to_vec();
        assert!((r[0] - ((-1.0f32).exp() - 1.0)).abs() < 1e-5);
        assert!(r[1].abs() < 1e-6);
        assert_eq!(r[2], 2.0);
    }

    #[test]
    fn test_gelu() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 3.0], &[2]).unwrap();
        let g = gelu(&t).to_vec();
        assert!(g[0].abs() < 1e-6);
        // Tanh approximation of GELU(3) is about 2.99636.
        assert!((g[1] - 2.9964).abs() < 1e-3);
    }

    #[test]
    fn test_where_cond() {
        let x = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let y = Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();

        let r = where_cond(&[true, false, true], &x, &y).unwrap();
        assert_eq!(r.to_vec(), vec![1.0, 20.0, 3.0]);

        // Mask length must match the element count.
        assert!(where_cond(&[true, false], &x, &y).is_err());
    }

    #[test]
    fn test_topk_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3]).unwrap();

        let rows = topk(&t, 2, 1, true, true).unwrap();
        assert_eq!(rows.values.shape(), &[2, 2]);
        assert_eq!(rows.values.to_vec(), vec![5.0, 3.0, 6.0, 4.0]);
        assert_eq!(rows.indices.to_vec(), vec![1, 2, 2, 0]);

        let cols = topk(&t, 1, 0, true, true).unwrap();
        assert_eq!(cols.values.shape(), &[1, 3]);
        assert_eq!(cols.values.to_vec(), vec![4.0, 5.0, 6.0]);
        assert_eq!(cols.indices.to_vec(), vec![1, 0, 1]);
    }

    #[test]
    fn test_roll_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let rolled = roll(&t, &[1], &[1]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_flip_all_dims() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        assert_eq!(flip(&t, &[0, 1]).unwrap().to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
        assert!(flip(&t, &[2]).is_err());
    }

    #[test]
    fn test_unique_2d_inverse_shape() {
        let t = Tensor::<f32>::from_vec(vec![2.0, 1.0, 2.0, 1.0], &[2, 2]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0]);
        let inverse = result.inverse_indices.unwrap();
        assert_eq!(inverse.shape(), &[2, 2]);
        assert_eq!(inverse.to_vec(), vec![1, 0, 1, 0]);
        assert_eq!(result.counts.unwrap().to_vec(), vec![2, 2]);
    }
}