1use axonml_core::dtype::{Float, Numeric, Scalar};
30use axonml_core::error::Result;
31
32use crate::tensor::Tensor;
33
34pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
40 if a.shape() != b.shape() {
41 return Err(axonml_core::error::Error::shape_mismatch(
42 a.shape(),
43 b.shape(),
44 ));
45 }
46
47 let a_data = a.to_vec();
48 let b_data = b.to_vec();
49
50 Ok(a_data
51 .iter()
52 .zip(b_data.iter())
53 .map(|(x, y)| x == y)
54 .collect())
55}
56
57pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
59 if a.shape() != b.shape() {
60 return Err(axonml_core::error::Error::shape_mismatch(
61 a.shape(),
62 b.shape(),
63 ));
64 }
65
66 let a_data = a.to_vec();
67 let b_data = b.to_vec();
68
69 Ok(a_data
70 .iter()
71 .zip(b_data.iter())
72 .map(|(x, y)| x < y)
73 .collect())
74}
75
76pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
78 if a.shape() != b.shape() {
79 return Err(axonml_core::error::Error::shape_mismatch(
80 a.shape(),
81 b.shape(),
82 ));
83 }
84
85 let a_data = a.to_vec();
86 let b_data = b.to_vec();
87
88 Ok(a_data
89 .iter()
90 .zip(b_data.iter())
91 .map(|(x, y)| x > y)
92 .collect())
93}
94
95pub fn eq_mask<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
97 let bools = eq(a, b)?;
98 let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
99 Tensor::from_vec(data, a.shape())
100}
101
102pub fn lt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
104 let bools = lt(a, b)?;
105 let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
106 Tensor::from_vec(data, a.shape())
107}
108
109pub fn gt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
111 let bools = gt(a, b)?;
112 let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
113 Tensor::from_vec(data, a.shape())
114}
115
116pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
122 let data = x.to_vec();
124 let shape = x.shape();
125
126 if shape.is_empty() {
127 return Ok(Tensor::scalar(T::one()));
128 }
129
130 let max_val = data
132 .iter()
133 .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
134
135 let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
137
138 let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
140
141 let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
143
144 Tensor::from_vec(result, shape)
145}
146
147pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
149 let sm = softmax(x, dim)?;
150 Ok(sm.ln())
151}
152
153#[must_use]
155pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
156 let data = x.to_vec();
157 let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
158 let coeff = T::from(0.044715).unwrap();
159
160 let result: Vec<T> = data
161 .iter()
162 .map(|&v| {
163 let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
164 v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
165 })
166 .collect();
167
168 Tensor::from_vec(result, x.shape()).unwrap()
169}
170
171pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
173 let data = x.to_vec();
174 let result: Vec<T> = data
175 .iter()
176 .map(|&v| if v > T::zero() { v } else { negative_slope * v })
177 .collect();
178
179 Tensor::from_vec(result, x.shape()).unwrap()
180}
181
182pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
184 let data = x.to_vec();
185 let result: Vec<T> = data
186 .iter()
187 .map(|&v| {
188 if v > T::zero() {
189 v
190 } else {
191 alpha * (v.exp_value() - T::one())
192 }
193 })
194 .collect();
195
196 Tensor::from_vec(result, x.shape()).unwrap()
197}
198
199#[must_use]
201pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
202 let sig = x.sigmoid();
203 x.mul(&sig).expect("tensor mul failed")
204}
205
206pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
212 let data = x.to_vec();
213 let result: Vec<T> = data
214 .iter()
215 .map(|&v| {
216 if v < min {
217 min
218 } else if v > max {
219 max
220 } else {
221 v
222 }
223 })
224 .collect();
225
226 Tensor::from_vec(result, x.shape()).unwrap()
227}
228
229pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
231 let data = x.to_vec();
232 let result: Vec<T> = data
233 .iter()
234 .map(|&v| if v < min { min } else { v })
235 .collect();
236
237 Tensor::from_vec(result, x.shape()).unwrap()
238}
239
240pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
242 let data = x.to_vec();
243 let result: Vec<T> = data
244 .iter()
245 .map(|&v| if v > max { max } else { v })
246 .collect();
247
248 Tensor::from_vec(result, x.shape()).unwrap()
249}
250
251pub fn where_cond<T: Scalar>(
257 condition: &[bool],
258 x: &Tensor<T>,
259 y: &Tensor<T>,
260) -> Result<Tensor<T>> {
261 if x.shape() != y.shape() {
262 return Err(axonml_core::error::Error::shape_mismatch(
263 x.shape(),
264 y.shape(),
265 ));
266 }
267
268 if condition.len() != x.numel() {
269 return Err(axonml_core::error::Error::shape_mismatch(
270 &[condition.len()],
271 &[x.numel()],
272 ));
273 }
274
275 let x_data = x.to_vec();
276 let y_data = y.to_vec();
277
278 let result: Vec<T> = condition
279 .iter()
280 .zip(x_data.iter().zip(y_data.iter()))
281 .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
282 .collect();
283
284 Tensor::from_vec(result, x.shape())
285}
286
/// Output of [`topk`]: the selected values and the positions they came from.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// Top-k values along the requested dimension.
    pub values: Tensor<T>,
    /// Original index (within the top-k dimension) of each value.
    pub indices: Tensor<i64>,
}
299
// Manual Debug impl: shows only the shapes, since dumping full tensor
// contents would be noisy for large results (and T need not be Debug).
impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
308
/// Returns the `k` largest (or smallest) elements of `x` along `dim`,
/// together with their original indices.
///
/// * `dim` may be negative to count from the last dimension.
/// * `largest` selects maxima when true, minima when false.
/// * `sorted` keeps results in rank order when true; when false the winners
///   are returned in their original index order instead.
///
/// # Errors
/// Fails for scalar tensors, an out-of-range `dim`, or `k` larger than the
/// size of the chosen dimension.
pub fn topk<T: Numeric>(
    x: &Tensor<T>,
    k: usize,
    dim: i64,
    largest: bool,
    sorted: bool,
) -> Result<TopKResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Err(axonml_core::error::Error::invalid_operation(
            "Cannot apply topk to scalar tensor".to_string(),
        ));
    }

    // Normalize a possibly-negative dimension index.
    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range for tensor with {} dimensions",
            dim,
            shape.len()
        )));
    }

    let dim_size = shape[dim];
    if k > dim_size {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "k ({}) is larger than dimension size ({})",
            k, dim_size
        )));
    }

    let data = x.to_vec();

    // Fast path for 1-D tensors: sort (index, value) pairs directly.
    if shape.len() == 1 {
        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
        // Incomparable values (e.g. NaN) compare as Equal and keep their order.
        if largest {
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        // Unsorted output restores original index order among the winners.
        if !sorted {
            indexed[..k].sort_by_key(|x| x.0);
        }

        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();

        return Ok(TopKResult {
            values: Tensor::from_vec(values, &[k])?,
            indices: Tensor::from_vec(indices, &[k])?,
        });
    }

    // General case: view the tensor as (outer, dim, inner) and process each
    // (outer, inner) lane independently.
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();

    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);

    for outer in 0..outer_size {
        for inner in 0..inner_size {
            // Gather the lane's elements with their position along `dim`.
            let mut slice: Vec<(usize, T)> = (0..dim_size)
                .map(|d| {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    (d, data[idx])
                })
                .collect();

            if largest {
                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            } else {
                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }

            if !sorted {
                slice[..k].sort_by_key(|x| x.0);
            }

            for (orig_idx, val) in slice.into_iter().take(k) {
                values_data.push(val);
                indices_data.push(orig_idx as i64);
            }
        }
    }

    // Output keeps the input shape except the selected dimension shrinks to k.
    let mut output_shape = shape.to_vec();
    output_shape[dim] = k;

    Ok(TopKResult {
        values: Tensor::from_vec(values_data, &output_shape)?,
        indices: Tensor::from_vec(indices_data, &output_shape)?,
    })
}
421
/// Output of [`sort`]: the sorted values and their original positions.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// The values of `x`, sorted along the requested dimension.
    pub values: Tensor<T>,
    /// Original index (within the sorted dimension) of each value.
    pub indices: Tensor<i64>,
}
430
// Manual Debug impl: shows only the shapes, since dumping full tensor
// contents would be noisy for large results (and T need not be Debug).
impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
439
440pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
450 let shape = x.shape();
451 if shape.is_empty() {
452 return Ok(SortResult {
453 values: x.clone(),
454 indices: Tensor::scalar(0i64),
455 });
456 }
457
458 let dim = if dim < 0 {
459 (shape.len() as i64 + dim) as usize
460 } else {
461 dim as usize
462 };
463
464 let dim_size = shape[dim];
465 topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
466 values: tk.values,
467 indices: tk.indices,
468 })
469}
470
471pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
478 sort(x, dim, descending).map(|r| r.indices)
479}
480
481pub fn scatter<T: Scalar>(
495 dst: &Tensor<T>,
496 dim: usize,
497 index: &Tensor<i64>,
498 src: &Tensor<T>,
499) -> Result<Tensor<T>> {
500 let dst_shape = dst.shape();
501 let idx_shape = index.shape();
502 let src_shape = src.shape();
503
504 if idx_shape != src_shape {
505 return Err(axonml_core::error::Error::shape_mismatch(
506 idx_shape, src_shape,
507 ));
508 }
509
510 if dim >= dst_shape.len() {
511 return Err(axonml_core::error::Error::invalid_operation(format!(
512 "Dimension {} out of range",
513 dim
514 )));
515 }
516
517 let mut result = dst.to_vec();
518 let idx_data = index.to_vec();
519 let src_data = src.to_vec();
520
521 let mut dst_strides = vec![1usize; dst_shape.len()];
523 for i in (0..dst_shape.len() - 1).rev() {
524 dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
525 }
526
527 let mut idx_strides = vec![1usize; idx_shape.len()];
529 for i in (0..idx_shape.len() - 1).rev() {
530 idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
531 }
532
533 let total = index.numel();
535 for linear_idx in 0..total {
536 let mut nd_idx = vec![0usize; idx_shape.len()];
538 let mut remaining = linear_idx;
539 for d in 0..idx_shape.len() {
540 nd_idx[d] = remaining / idx_strides[d];
541 remaining %= idx_strides[d];
542 }
543
544 let scatter_idx = idx_data[linear_idx] as usize;
546
547 let mut dst_nd_idx = nd_idx.clone();
549 dst_nd_idx[dim] = scatter_idx;
550
551 let mut dst_linear = 0;
553 for d in 0..dst_shape.len() {
554 dst_linear += dst_nd_idx[d] * dst_strides[d];
555 }
556
557 result[dst_linear] = src_data[linear_idx];
558 }
559
560 Tensor::from_vec(result, dst_shape)
561}
562
563pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
575 let data = x.to_vec();
576 let shape = x.shape();
577 let ndim = shape.len();
578
579 let mut indices: Vec<Vec<i64>> = Vec::new();
581
582 let mut strides = vec![1usize; ndim.max(1)];
584 for i in (0..ndim.saturating_sub(1)).rev() {
585 strides[i] = strides[i + 1] * shape[i + 1];
586 }
587
588 for (linear_idx, &val) in data.iter().enumerate() {
589 if val != T::zero() {
590 let mut nd_idx = vec![0i64; ndim.max(1)];
591 let mut remaining = linear_idx;
592 for d in 0..ndim {
593 nd_idx[d] = (remaining / strides[d]) as i64;
594 remaining %= strides[d];
595 }
596 indices.push(nd_idx);
597 }
598 }
599
600 let num_nonzero = indices.len();
601 if num_nonzero == 0 {
602 return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
603 }
604
605 let flat: Vec<i64> = indices.into_iter().flatten().collect();
606 Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
607}
608
/// Output of [`unique`]: deduplicated values plus optional companions.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// The unique values as a 1-D tensor.
    pub values: Tensor<T>,
    /// For each input element, the position of its value within `values`
    /// (present only when requested).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Occurrence count of each unique value (present only when requested).
    pub counts: Option<Tensor<i64>>,
}
623
// Manual Debug impl: shows only shapes and which optional fields are
// present, since dumping full tensor contents would be noisy.
impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}
633
/// Deduplicates the elements of `x` (flattened).
///
/// * `sorted` — return values in ascending order instead of first-seen order.
/// * `return_inverse` — also map each input element to its unique-value index.
/// * `return_counts` — also report how often each unique value occurs.
///
/// Uses a linear scan per element (O(n * u) overall); `T` is only required
/// to support `PartialEq`/`PartialOrd`, so a hash-based dedup is not an
/// option here.
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // First pass: collect values in first-seen order, building counts and
    // the inverse mapping as we go.
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort the unique values, then remap counts and inverse indices from
        // first-seen order to sorted order via `old_to_new`.
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed
            .iter()
            .map(|(old_idx, _)| counts_map[*old_idx])
            .collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).expect("tensor creation failed"),
        inverse_indices: if return_inverse {
            // One entry per input element, so it keeps the input's shape.
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).expect("tensor creation failed"))
        } else {
            None
        },
    }
}
704
/// Reverses the element order along each dimension listed in `dims`.
///
/// # Errors
/// Returns an error when any entry of `dims` is out of range.
pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                d, ndim
            )));
        }
    }

    // A scalar has nothing to flip.
    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides for linear <-> n-d coordinate conversion.
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Decompose the source position into an n-d coordinate.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Mirror the coordinate on every flipped dimension.
        for &flip_dim in dims {
            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
        }

        // Re-linearize and copy the element to its mirrored position.
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
761
/// Cyclically shifts elements by `shifts[i]` positions along `dims[i]`.
/// Negative shifts roll toward the front; shifts wrap modulo the dimension
/// size.
///
/// # Errors
/// `shifts` and `dims` must have the same length, and every dimension must
/// be in range.
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range",
                d
            )));
        }
    }

    // A scalar has nothing to roll.
    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Row-major strides for linear <-> n-d coordinate conversion.
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Decompose the source position into an n-d coordinate.
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply each shift; the double modulo keeps negative shifts positive.
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        // Re-linearize and copy the element to its rolled position.
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
826
#[cfg(test)]
mod tests {
    //! Unit tests for the comparison, activation, selection, and
    //! rearrangement helpers in this module.
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        // A softmax distribution must sum to 1.
        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        // 1-D input: one row per non-zero element, one coordinate column.
        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        // Rows are (row, col) coordinates of the non-zero entries.
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        // Unsorted mode preserves first-seen order.
        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        // Flipping dim 0 swaps the rows.
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).expect("tensor creation failed");
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed");

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }

    #[test]
    fn test_where_cond_basic() {
        let cond = vec![true, false, true, false];
        let a = Tensor::from_vec(vec![10.0f32, 20.0, 30.0, 40.0], &[4]).unwrap();
        let b = Tensor::from_vec(vec![-1.0f32, -2.0, -3.0, -4.0], &[4]).unwrap();
        let result = where_cond(&cond, &a, &b).unwrap();
        assert_eq!(result.to_vec(), vec![10.0, -2.0, 30.0, -4.0]);
    }

    #[test]
    fn test_scatter_overwrites() {
        // Duplicate indices: later writes win, so slot 1 ends up non-zero.
        let dst = Tensor::from_vec(vec![0.0f32; 4], &[4]).unwrap();
        let index = Tensor::from_vec(vec![1_i64, 1], &[2]).unwrap();
        let src = Tensor::from_vec(vec![5.0f32, 10.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert!(result.to_vec()[1] > 0.0, "Scatter should write to index 1");
    }

    #[test]
    fn test_unique_all_same() {
        let t = Tensor::from_vec(vec![5.0f32, 5.0, 5.0, 5.0], &[4]).unwrap();
        let result = unique(&t, true, false, false);
        assert_eq!(result.values.to_vec().len(), 1);
        assert!((result.values.to_vec()[0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_unique_already_unique() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = unique(&t, true, false, false);
        assert_eq!(result.values.to_vec().len(), 4);
    }

    #[test]
    fn test_flip_both_dims() {
        // Flipping both dims of a 2x2 reverses the flat order entirely.
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0, 1]).unwrap();
        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_col_only() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[1]).unwrap();
        assert_eq!(flipped.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_roll_2d() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let rolled = roll(&t, &[1], &[1]).unwrap();
        assert_eq!(rolled.to_vec(), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_full_cycle() {
        // Shifting by the dimension size is a no-op.
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], &[3]).unwrap();
        let rolled = roll(&t, &[3], &[0]).unwrap();
        assert_eq!(rolled.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_nonzero_all_zeros() {
        let t = Tensor::from_vec(vec![0.0f32, 0.0, 0.0], &[3]).unwrap();
        let result = nonzero(&t);
        assert_eq!(result.to_vec().len(), 0);
    }

    #[test]
    fn test_nonzero_all_nonzero() {
        let t = Tensor::from_vec(vec![1.0f32, -2.0, 0.5], &[3]).unwrap();
        let result = nonzero(&t);
        assert_eq!(result.to_vec().len(), 3);
    }

    #[test]
    fn test_softmax_large_values() {
        // Exercises the max-subtraction path: naive exp(1000) would overflow.
        let t = Tensor::from_vec(vec![1000.0f32, 1001.0, 999.0], &[3]).unwrap();
        let result = softmax(&t, 0).unwrap();
        let rv = result.to_vec();
        assert!(
            rv.iter().all(|v: &f32| v.is_finite()),
            "Softmax should handle large values"
        );
        let sum: f32 = rv.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "Softmax should sum to 1.0, got {}",
            sum
        );
    }

    #[test]
    fn test_softmax_negative_values() {
        let t = Tensor::from_vec(vec![-100.0f32, -200.0, -150.0], &[3]).unwrap();
        let result = softmax(&t, 0).unwrap();
        let rv = result.to_vec();
        assert!(rv.iter().all(|v: &f32| v.is_finite() && *v >= 0.0));
        let sum: f32 = rv.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp_no_op() {
        let t = Tensor::from_vec(vec![0.5f32, 0.3, 0.8], &[3]).unwrap();
        let clamped = clamp_min(&t, 0.0);
        assert_eq!(clamped.to_vec(), vec![0.5, 0.3, 0.8]);
    }

    #[test]
    fn test_clamp_min_all_negative() {
        let t = Tensor::from_vec(vec![-5.0f32, -3.0, -1.0], &[3]).unwrap();
        let clamped = clamp_min(&t, 0.0);
        assert_eq!(clamped.to_vec(), vec![0.0, 0.0, 0.0]);
    }
}