1use axonml_core::dtype::{Float, Numeric, Scalar};
21use axonml_core::error::Result;
22
23use crate::tensor::Tensor;
24
25pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
31 if a.shape() != b.shape() {
32 return Err(axonml_core::error::Error::shape_mismatch(
33 a.shape(),
34 b.shape(),
35 ));
36 }
37
38 let a_data = a.to_vec();
39 let b_data = b.to_vec();
40
41 Ok(a_data
42 .iter()
43 .zip(b_data.iter())
44 .map(|(x, y)| x == y)
45 .collect())
46}
47
48pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
50 if a.shape() != b.shape() {
51 return Err(axonml_core::error::Error::shape_mismatch(
52 a.shape(),
53 b.shape(),
54 ));
55 }
56
57 let a_data = a.to_vec();
58 let b_data = b.to_vec();
59
60 Ok(a_data
61 .iter()
62 .zip(b_data.iter())
63 .map(|(x, y)| x < y)
64 .collect())
65}
66
67pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
69 if a.shape() != b.shape() {
70 return Err(axonml_core::error::Error::shape_mismatch(
71 a.shape(),
72 b.shape(),
73 ));
74 }
75
76 let a_data = a.to_vec();
77 let b_data = b.to_vec();
78
79 Ok(a_data
80 .iter()
81 .zip(b_data.iter())
82 .map(|(x, y)| x > y)
83 .collect())
84}
85
86pub fn eq_mask<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
88 let bools = eq(a, b)?;
89 let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
90 Tensor::from_vec(data, a.shape())
91}
92
93pub fn lt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
95 let bools = lt(a, b)?;
96 let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
97 Tensor::from_vec(data, a.shape())
98}
99
100pub fn gt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
102 let bools = gt(a, b)?;
103 let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
104 Tensor::from_vec(data, a.shape())
105}
106
107pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
113 let data = x.to_vec();
115 let shape = x.shape();
116
117 if shape.is_empty() {
118 return Ok(Tensor::scalar(T::one()));
119 }
120
121 let max_val = data
123 .iter()
124 .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
125
126 let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
128
129 let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
131
132 let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
134
135 Tensor::from_vec(result, shape)
136}
137
138pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
140 let sm = softmax(x, dim)?;
141 Ok(sm.ln())
142}
143
144#[must_use]
146pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
147 let data = x.to_vec();
148 let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
149 let coeff = T::from(0.044715).unwrap();
150
151 let result: Vec<T> = data
152 .iter()
153 .map(|&v| {
154 let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
155 v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
156 })
157 .collect();
158
159 Tensor::from_vec(result, x.shape()).unwrap()
160}
161
162pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
164 let data = x.to_vec();
165 let result: Vec<T> = data
166 .iter()
167 .map(|&v| if v > T::zero() { v } else { negative_slope * v })
168 .collect();
169
170 Tensor::from_vec(result, x.shape()).unwrap()
171}
172
173pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
175 let data = x.to_vec();
176 let result: Vec<T> = data
177 .iter()
178 .map(|&v| {
179 if v > T::zero() {
180 v
181 } else {
182 alpha * (v.exp_value() - T::one())
183 }
184 })
185 .collect();
186
187 Tensor::from_vec(result, x.shape()).unwrap()
188}
189
190#[must_use]
192pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
193 let sig = x.sigmoid();
194 x.mul(&sig).expect("tensor mul failed")
195}
196
197pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
203 let data = x.to_vec();
204 let result: Vec<T> = data
205 .iter()
206 .map(|&v| {
207 if v < min {
208 min
209 } else if v > max {
210 max
211 } else {
212 v
213 }
214 })
215 .collect();
216
217 Tensor::from_vec(result, x.shape()).unwrap()
218}
219
220pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
222 let data = x.to_vec();
223 let result: Vec<T> = data
224 .iter()
225 .map(|&v| if v < min { min } else { v })
226 .collect();
227
228 Tensor::from_vec(result, x.shape()).unwrap()
229}
230
231pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
233 let data = x.to_vec();
234 let result: Vec<T> = data
235 .iter()
236 .map(|&v| if v > max { max } else { v })
237 .collect();
238
239 Tensor::from_vec(result, x.shape()).unwrap()
240}
241
242pub fn where_cond<T: Scalar>(
248 condition: &[bool],
249 x: &Tensor<T>,
250 y: &Tensor<T>,
251) -> Result<Tensor<T>> {
252 if x.shape() != y.shape() {
253 return Err(axonml_core::error::Error::shape_mismatch(
254 x.shape(),
255 y.shape(),
256 ));
257 }
258
259 if condition.len() != x.numel() {
260 return Err(axonml_core::error::Error::shape_mismatch(
261 &[condition.len()],
262 &[x.numel()],
263 ));
264 }
265
266 let x_data = x.to_vec();
267 let y_data = y.to_vec();
268
269 let result: Vec<T> = condition
270 .iter()
271 .zip(x_data.iter().zip(y_data.iter()))
272 .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
273 .collect();
274
275 Tensor::from_vec(result, x.shape())
276}
277
/// Result of a [`topk`] call.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// Selected values; same shape as the input except the chosen dim is `k`.
    pub values: Tensor<T>,
    /// Positions of those values along the chosen dimension (same shape as `values`).
    pub indices: Tensor<i64>,
}
290
// Hand-written Debug that prints only the tensor shapes, not their contents
// — presumably to keep logs compact and avoid extra bounds on `T`; confirm
// before adding value dumps.
impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
299
/// Returns the `k` largest (or smallest, when `largest` is false) elements of
/// `x` along `dim`, together with their indices along that dimension.
///
/// When `sorted` is false, the selected elements are returned in their
/// original positional order instead of value order.
///
/// # Errors
/// - scalar input,
/// - `dim` out of range (negative dims count from the end),
/// - `k` larger than the size of the selected dimension.
pub fn topk<T: Numeric>(
    x: &Tensor<T>,
    k: usize,
    dim: i64,
    largest: bool,
    sorted: bool,
) -> Result<TopKResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Err(axonml_core::error::Error::invalid_operation(
            "Cannot apply topk to scalar tensor".to_string(),
        ));
    }

    // Normalize a negative dim (counted from the back) to an absolute index.
    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range for tensor with {} dimensions",
            dim,
            shape.len()
        )));
    }

    let dim_size = shape[dim];
    if k > dim_size {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "k ({}) is larger than dimension size ({})",
            k, dim_size
        )));
    }

    let data = x.to_vec();

    // Fast path: 1-D tensors sort the whole buffer directly.
    if shape.len() == 1 {
        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
        // Incomparable pairs (e.g. NaN) are treated as equal rather than panicking.
        if largest {
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        // Unsorted output: restore original index order among the k winners.
        if !sorted {
            indexed[..k].sort_by_key(|x| x.0);
        }

        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();

        return Ok(TopKResult {
            values: Tensor::from_vec(values, &[k])?,
            indices: Tensor::from_vec(indices, &[k])?,
        });
    }

    // General case: visit every 1-D lane along `dim` using the row-major
    // outer x dim x inner decomposition (assumes contiguous data from to_vec).
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();

    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);

    for outer in 0..outer_size {
        for inner in 0..inner_size {
            // Gather the lane's (position, value) pairs.
            let mut slice: Vec<(usize, T)> = (0..dim_size)
                .map(|d| {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    (d, data[idx])
                })
                .collect();

            if largest {
                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            } else {
                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }

            if !sorted {
                slice[..k].sort_by_key(|x| x.0);
            }

            for (orig_idx, val) in slice.into_iter().take(k) {
                values_data.push(val);
                indices_data.push(orig_idx as i64);
            }
        }
    }

    // Output keeps the input shape except the selected dim shrinks to k.
    let mut output_shape = shape.to_vec();
    output_shape[dim] = k;

    Ok(TopKResult {
        values: Tensor::from_vec(values_data, &output_shape)?,
        indices: Tensor::from_vec(indices_data, &output_shape)?,
    })
}
412
/// Result of a [`sort`] call.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// Sorted values, same shape as the input.
    pub values: Tensor<T>,
    /// Original position of each sorted element along the sorted dimension.
    pub indices: Tensor<i64>,
}
421
// Hand-written Debug mirroring TopKResult's: shapes only, no contents.
impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
430
431pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
441 let shape = x.shape();
442 if shape.is_empty() {
443 return Ok(SortResult {
444 values: x.clone(),
445 indices: Tensor::scalar(0i64),
446 });
447 }
448
449 let dim = if dim < 0 {
450 (shape.len() as i64 + dim) as usize
451 } else {
452 dim as usize
453 };
454
455 let dim_size = shape[dim];
456 topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
457 values: tk.values,
458 indices: tk.indices,
459 })
460}
461
462pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
469 sort(x, dim, descending).map(|r| r.indices)
470}
471
472pub fn scatter<T: Scalar>(
486 dst: &Tensor<T>,
487 dim: usize,
488 index: &Tensor<i64>,
489 src: &Tensor<T>,
490) -> Result<Tensor<T>> {
491 let dst_shape = dst.shape();
492 let idx_shape = index.shape();
493 let src_shape = src.shape();
494
495 if idx_shape != src_shape {
496 return Err(axonml_core::error::Error::shape_mismatch(
497 idx_shape, src_shape,
498 ));
499 }
500
501 if dim >= dst_shape.len() {
502 return Err(axonml_core::error::Error::invalid_operation(format!(
503 "Dimension {} out of range",
504 dim
505 )));
506 }
507
508 let mut result = dst.to_vec();
509 let idx_data = index.to_vec();
510 let src_data = src.to_vec();
511
512 let mut dst_strides = vec![1usize; dst_shape.len()];
514 for i in (0..dst_shape.len() - 1).rev() {
515 dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
516 }
517
518 let mut idx_strides = vec![1usize; idx_shape.len()];
520 for i in (0..idx_shape.len() - 1).rev() {
521 idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
522 }
523
524 let total = index.numel();
526 for linear_idx in 0..total {
527 let mut nd_idx = vec![0usize; idx_shape.len()];
529 let mut remaining = linear_idx;
530 for d in 0..idx_shape.len() {
531 nd_idx[d] = remaining / idx_strides[d];
532 remaining %= idx_strides[d];
533 }
534
535 let scatter_idx = idx_data[linear_idx] as usize;
537
538 let mut dst_nd_idx = nd_idx.clone();
540 dst_nd_idx[dim] = scatter_idx;
541
542 let mut dst_linear = 0;
544 for d in 0..dst_shape.len() {
545 dst_linear += dst_nd_idx[d] * dst_strides[d];
546 }
547
548 result[dst_linear] = src_data[linear_idx];
549 }
550
551 Tensor::from_vec(result, dst_shape)
552}
553
/// Returns the N-d coordinates of every non-zero element of `x` as a tensor
/// of shape `[num_nonzero, ndim]` (ndim is treated as 1 for scalars).
pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
    let data = x.to_vec();
    let shape = x.shape();
    let ndim = shape.len();

    let mut indices: Vec<Vec<i64>> = Vec::new();

    // Row-major strides; max(1)/saturating_sub keep the scalar case well-formed.
    let mut strides = vec![1usize; ndim.max(1)];
    for i in (0..ndim.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    for (linear_idx, &val) in data.iter().enumerate() {
        if val != T::zero() {
            // Decompose the flat offset into per-dimension coordinates.
            let mut nd_idx = vec![0i64; ndim.max(1)];
            let mut remaining = linear_idx;
            for d in 0..ndim {
                nd_idx[d] = (remaining / strides[d]) as i64;
                remaining %= strides[d];
            }
            indices.push(nd_idx);
        }
    }

    let num_nonzero = indices.len();
    if num_nonzero == 0 {
        // Keep the coordinate dimension even when empty so callers can rely
        // on the [n, ndim] layout.
        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
    }

    let flat: Vec<i64> = indices.into_iter().flatten().collect();
    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
}
599
/// Result of a [`unique`] call.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// The deduplicated values as a 1-D tensor.
    pub values: Tensor<T>,
    /// For each input element, the slot of its value in `values`
    /// (only present when `return_inverse` was requested).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Occurrence count per unique value
    /// (only present when `return_counts` was requested).
    pub counts: Option<Tensor<i64>>,
}
614
// Hand-written Debug: reports shapes and which optional outputs are present,
// never tensor contents.
impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}
624
/// Deduplicates the (flattened) elements of `x`.
///
/// - `sorted`: return unique values in ascending order; otherwise first-seen order.
/// - `return_inverse`: also map each input element to its unique-value slot.
/// - `return_counts`: also report how often each unique value occurs.
///
/// Uses an O(n * u) linear scan (u = number of uniques) since `T` is only
/// required to support equality here, not hashing or total ordering.
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // Bookkeeping in first-seen order; inverse/counts are always built
    // because remapping later is cheap.
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort the uniques, then remap inverse indices and counts from
        // first-seen slots to sorted slots. Incomparable values (e.g. NaN)
        // compare as equal.
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // old (first-seen) slot -> new (sorted) slot.
        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed
            .iter()
            .map(|(old_idx, _)| counts_map[*old_idx])
            .collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).expect("tensor creation failed"),
        inverse_indices: if return_inverse {
            // The inverse map keeps the input tensor's shape.
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).expect("tensor creation failed"))
        } else {
            None
        },
    }
}
695
/// Reverses element order along each dimension listed in `dims`.
///
/// # Errors
/// Returns an error when any entry of `dims` is out of range.
pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                d, ndim
            )));
        }
    }

    // Scalars have nothing to flip.
    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides (contiguous layout from to_vec).
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Flat offset -> per-dimension coordinates.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Mirror the coordinate on every flipped dimension.
        for &flip_dim in dims {
            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
        }

        // Coordinates -> destination flat offset.
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
752
/// Cyclically shifts elements: `shifts[i]` positions along `dims[i]`.
/// Elements shifted past the end wrap around; negative shifts roll backward.
///
/// # Errors
/// Returns an error when `shifts` and `dims` differ in length or a dim is
/// out of range.
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range",
                d
            )));
        }
    }

    // Scalars are unaffected by rolling.
    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides (contiguous layout from to_vec).
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Flat offset -> per-dimension coordinates.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply each shift; the double-modulo normalizes negative shifts.
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        // Coordinates -> destination flat offset.
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
817
#[cfg(test)]
mod tests {
    //! Smoke tests for the ops in this module, exercised on small 1-D/2-D
    //! f32 tensors with hand-computed expected values.
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        // Probabilities along the dim must sum to 1.
        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        // 1-D input: one coordinate per non-zero element.
        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        // Non-zeros at (0,0) and (1,1).
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        // Unsorted mode preserves first-seen order.
        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        // Flipping dim 0 swaps the two rows.
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).expect("tensor creation failed");
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed");

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }
}