// crates/axonml-tensor/src/ops/mod.rs
//! Tensor Operations - Mathematical and Structural Operations
//!
//! # File
//! `crates/axonml-tensor/src/ops/mod.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// Operations are implemented directly on Tensor in tensor.rs
// This module provides additional standalone functions

use axonml_core::dtype::{Float, Numeric, Scalar};
use axonml_core::error::Result;

use crate::tensor::Tensor;

25// =============================================================================
26// Comparison Operations
27// =============================================================================
28
29/// Element-wise equality comparison.
30pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
31    if a.shape() != b.shape() {
32        return Err(axonml_core::error::Error::shape_mismatch(
33            a.shape(),
34            b.shape(),
35        ));
36    }
37
38    let a_data = a.to_vec();
39    let b_data = b.to_vec();
40
41    Ok(a_data
42        .iter()
43        .zip(b_data.iter())
44        .map(|(x, y)| x == y)
45        .collect())
46}
47
48/// Element-wise less-than comparison.
49pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
50    if a.shape() != b.shape() {
51        return Err(axonml_core::error::Error::shape_mismatch(
52            a.shape(),
53            b.shape(),
54        ));
55    }
56
57    let a_data = a.to_vec();
58    let b_data = b.to_vec();
59
60    Ok(a_data
61        .iter()
62        .zip(b_data.iter())
63        .map(|(x, y)| x < y)
64        .collect())
65}
66
67/// Element-wise greater-than comparison.
68pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
69    if a.shape() != b.shape() {
70        return Err(axonml_core::error::Error::shape_mismatch(
71            a.shape(),
72            b.shape(),
73        ));
74    }
75
76    let a_data = a.to_vec();
77    let b_data = b.to_vec();
78
79    Ok(a_data
80        .iter()
81        .zip(b_data.iter())
82        .map(|(x, y)| x > y)
83        .collect())
84}
85
86// =============================================================================
87// Advanced Activation Functions
88// =============================================================================
89
90/// Applies softmax along the specified dimension.
91pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
92    // For simplicity, this handles the last dimension case
93    let data = x.to_vec();
94    let shape = x.shape();
95
96    if shape.is_empty() {
97        return Ok(Tensor::scalar(T::one()));
98    }
99
100    // Find max for numerical stability
101    let max_val = data
102        .iter()
103        .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
104
105    // Compute exp(x - max)
106    let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
107
108    // Compute sum
109    let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
110
111    // Normalize
112    let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
113
114    Tensor::from_vec(result, shape)
115}
116
117/// Applies log-softmax along the specified dimension.
118pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
119    let sm = softmax(x, dim)?;
120    Ok(sm.ln())
121}
122
123/// Applies GELU (Gaussian Error Linear Unit) activation.
124#[must_use]
125pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
126    let data = x.to_vec();
127    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
128    let coeff = T::from(0.044715).unwrap();
129
130    let result: Vec<T> = data
131        .iter()
132        .map(|&v| {
133            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
134            v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
135        })
136        .collect();
137
138    Tensor::from_vec(result, x.shape()).unwrap()
139}
140
141/// Applies Leaky `ReLU` activation.
142pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
143    let data = x.to_vec();
144    let result: Vec<T> = data
145        .iter()
146        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
147        .collect();
148
149    Tensor::from_vec(result, x.shape()).unwrap()
150}
151
152/// Applies ELU (Exponential Linear Unit) activation.
153pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
154    let data = x.to_vec();
155    let result: Vec<T> = data
156        .iter()
157        .map(|&v| {
158            if v > T::zero() {
159                v
160            } else {
161                alpha * (v.exp_value() - T::one())
162            }
163        })
164        .collect();
165
166    Tensor::from_vec(result, x.shape()).unwrap()
167}
168
169/// Applies `SiLU` (Sigmoid Linear Unit) / Swish activation.
170#[must_use]
171pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
172    let sig = x.sigmoid();
173    x.mul(&sig).unwrap()
174}
175
176// =============================================================================
177// Clipping Operations
178// =============================================================================
179
180/// Clamps all elements to the range [min, max].
181pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
182    let data = x.to_vec();
183    let result: Vec<T> = data
184        .iter()
185        .map(|&v| {
186            if v < min {
187                min
188            } else if v > max {
189                max
190            } else {
191                v
192            }
193        })
194        .collect();
195
196    Tensor::from_vec(result, x.shape()).unwrap()
197}
198
199/// Clamps all elements to be at least min.
200pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
201    let data = x.to_vec();
202    let result: Vec<T> = data
203        .iter()
204        .map(|&v| if v < min { min } else { v })
205        .collect();
206
207    Tensor::from_vec(result, x.shape()).unwrap()
208}
209
210/// Clamps all elements to be at most max.
211pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
212    let data = x.to_vec();
213    let result: Vec<T> = data
214        .iter()
215        .map(|&v| if v > max { max } else { v })
216        .collect();
217
218    Tensor::from_vec(result, x.shape()).unwrap()
219}
220
221// =============================================================================
222// Where Operation
223// =============================================================================
224
225/// Selects elements from x or y based on condition.
226pub fn where_cond<T: Scalar>(
227    condition: &[bool],
228    x: &Tensor<T>,
229    y: &Tensor<T>,
230) -> Result<Tensor<T>> {
231    if x.shape() != y.shape() {
232        return Err(axonml_core::error::Error::shape_mismatch(
233            x.shape(),
234            y.shape(),
235        ));
236    }
237
238    if condition.len() != x.numel() {
239        return Err(axonml_core::error::Error::shape_mismatch(
240            &[condition.len()],
241            &[x.numel()],
242        ));
243    }
244
245    let x_data = x.to_vec();
246    let y_data = y.to_vec();
247
248    let result: Vec<T> = condition
249        .iter()
250        .zip(x_data.iter().zip(y_data.iter()))
251        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
252        .collect();
253
254    Tensor::from_vec(result, x.shape())
255}
256
// =============================================================================
// Sorting and Top-K Operations
// =============================================================================

/// Result of topk operation containing values and indices.
//
// Only `Clone` is derived: `Debug` is implemented manually below so that
// formatting prints tensor shapes rather than full contents.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// The top-k values.
    pub values: Tensor<T>,
    /// The indices of the top-k values in the original tensor.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    // Compact Debug: shows only the shapes of the contained tensors, keeping
    // output readable even when the tensors are large.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

279/// Returns the k largest elements along a dimension.
280///
281/// # Arguments
282/// * `x` - Input tensor
283/// * `k` - Number of top elements to return
284/// * `dim` - Dimension to sort along (default: -1)
285/// * `largest` - If true, return largest elements; if false, return smallest
286/// * `sorted` - If true, return elements in sorted order
287///
288/// # Returns
289/// TopKResult containing values and indices tensors
290pub fn topk<T: Numeric>(
291    x: &Tensor<T>,
292    k: usize,
293    dim: i64,
294    largest: bool,
295    sorted: bool,
296) -> Result<TopKResult<T>> {
297    let shape = x.shape();
298    if shape.is_empty() {
299        return Err(axonml_core::error::Error::invalid_operation(
300            "Cannot apply topk to scalar tensor".to_string(),
301        ));
302    }
303
304    let dim = if dim < 0 {
305        (shape.len() as i64 + dim) as usize
306    } else {
307        dim as usize
308    };
309
310    if dim >= shape.len() {
311        return Err(axonml_core::error::Error::invalid_operation(format!(
312            "Dimension {} out of range for tensor with {} dimensions",
313            dim,
314            shape.len()
315        )));
316    }
317
318    let dim_size = shape[dim];
319    if k > dim_size {
320        return Err(axonml_core::error::Error::invalid_operation(format!(
321            "k ({}) is larger than dimension size ({})",
322            k, dim_size
323        )));
324    }
325
326    let data = x.to_vec();
327
328    // For simplicity, handle the 1D case specially
329    if shape.len() == 1 {
330        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
331        if largest {
332            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
333        } else {
334            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
335        }
336
337        if !sorted {
338            indexed[..k].sort_by_key(|x| x.0);
339        }
340
341        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
342        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();
343
344        return Ok(TopKResult {
345            values: Tensor::from_vec(values, &[k])?,
346            indices: Tensor::from_vec(indices, &[k])?,
347        });
348    }
349
350    // General n-dimensional case
351    let outer_size: usize = shape[..dim].iter().product();
352    let inner_size: usize = shape[dim + 1..].iter().product();
353
354    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
355    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);
356
357    for outer in 0..outer_size {
358        for inner in 0..inner_size {
359            let mut slice: Vec<(usize, T)> = (0..dim_size)
360                .map(|d| {
361                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
362                    (d, data[idx])
363                })
364                .collect();
365
366            if largest {
367                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
368            } else {
369                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
370            }
371
372            if !sorted {
373                slice[..k].sort_by_key(|x| x.0);
374            }
375
376            for (orig_idx, val) in slice.into_iter().take(k) {
377                values_data.push(val);
378                indices_data.push(orig_idx as i64);
379            }
380        }
381    }
382
383    let mut output_shape = shape.to_vec();
384    output_shape[dim] = k;
385
386    Ok(TopKResult {
387        values: Tensor::from_vec(values_data, &output_shape)?,
388        indices: Tensor::from_vec(indices_data, &output_shape)?,
389    })
390}
391
/// Result of sort operation containing sorted values and indices.
//
// Only `Clone` is derived: `Debug` is implemented manually below so that
// formatting prints tensor shapes rather than full contents.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// Sorted values.
    pub values: Tensor<T>,
    /// Indices that would sort the tensor.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    // Compact Debug: shows only the shapes of the contained tensors, keeping
    // output readable even when the tensors are large.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

410/// Sorts the elements of the tensor along a dimension.
411///
412/// # Arguments
413/// * `x` - Input tensor
414/// * `dim` - Dimension to sort along (default: -1)
415/// * `descending` - If true, sort in descending order
416///
417/// # Returns
418/// SortResult containing sorted values and indices
419pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
420    let shape = x.shape();
421    if shape.is_empty() {
422        return Ok(SortResult {
423            values: x.clone(),
424            indices: Tensor::scalar(0i64),
425        });
426    }
427
428    let dim = if dim < 0 {
429        (shape.len() as i64 + dim) as usize
430    } else {
431        dim as usize
432    };
433
434    let dim_size = shape[dim];
435    topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
436        values: tk.values,
437        indices: tk.indices,
438    })
439}
440
441/// Returns the indices that would sort the tensor along a dimension.
442///
443/// # Arguments
444/// * `x` - Input tensor
445/// * `dim` - Dimension to sort along (default: -1)
446/// * `descending` - If true, sort in descending order
447pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
448    sort(x, dim, descending).map(|r| r.indices)
449}
450
451// =============================================================================
452// Scatter Operation
453// =============================================================================
454
455/// Writes values from src into self at locations specified by index.
456///
457/// This is the inverse of gather.
458///
459/// # Arguments
460/// * `dst` - Destination tensor (modified in place conceptually, returns new tensor)
461/// * `dim` - Dimension along which to scatter
462/// * `index` - Indices to scatter to
463/// * `src` - Source values to scatter
464pub fn scatter<T: Scalar>(
465    dst: &Tensor<T>,
466    dim: usize,
467    index: &Tensor<i64>,
468    src: &Tensor<T>,
469) -> Result<Tensor<T>> {
470    let dst_shape = dst.shape();
471    let idx_shape = index.shape();
472    let src_shape = src.shape();
473
474    if idx_shape != src_shape {
475        return Err(axonml_core::error::Error::shape_mismatch(
476            idx_shape, src_shape,
477        ));
478    }
479
480    if dim >= dst_shape.len() {
481        return Err(axonml_core::error::Error::invalid_operation(format!(
482            "Dimension {} out of range",
483            dim
484        )));
485    }
486
487    let mut result = dst.to_vec();
488    let idx_data = index.to_vec();
489    let src_data = src.to_vec();
490
491    // Calculate strides for the destination
492    let mut dst_strides = vec![1usize; dst_shape.len()];
493    for i in (0..dst_shape.len() - 1).rev() {
494        dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
495    }
496
497    // Calculate strides for index/src
498    let mut idx_strides = vec![1usize; idx_shape.len()];
499    for i in (0..idx_shape.len() - 1).rev() {
500        idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
501    }
502
503    // Scatter values
504    let total = index.numel();
505    for linear_idx in 0..total {
506        // Convert linear index to n-dimensional index
507        let mut nd_idx = vec![0usize; idx_shape.len()];
508        let mut remaining = linear_idx;
509        for d in 0..idx_shape.len() {
510            nd_idx[d] = remaining / idx_strides[d];
511            remaining %= idx_strides[d];
512        }
513
514        // Get the scatter index
515        let scatter_idx = idx_data[linear_idx] as usize;
516
517        // Build destination index
518        let mut dst_nd_idx = nd_idx.clone();
519        dst_nd_idx[dim] = scatter_idx;
520
521        // Convert to linear destination index
522        let mut dst_linear = 0;
523        for d in 0..dst_shape.len() {
524            dst_linear += dst_nd_idx[d] * dst_strides[d];
525        }
526
527        result[dst_linear] = src_data[linear_idx];
528    }
529
530    Tensor::from_vec(result, dst_shape)
531}
532
533// =============================================================================
534// Nonzero Operation
535// =============================================================================
536
537/// Returns the indices of non-zero elements.
538///
539/// # Arguments
540/// * `x` - Input tensor
541///
542/// # Returns
543/// Tensor of shape (num_nonzero, ndim) containing indices of non-zero elements
544pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
545    let data = x.to_vec();
546    let shape = x.shape();
547    let ndim = shape.len();
548
549    // Find all non-zero indices
550    let mut indices: Vec<Vec<i64>> = Vec::new();
551
552    // Calculate strides for index conversion
553    let mut strides = vec![1usize; ndim.max(1)];
554    for i in (0..ndim.saturating_sub(1)).rev() {
555        strides[i] = strides[i + 1] * shape[i + 1];
556    }
557
558    for (linear_idx, &val) in data.iter().enumerate() {
559        if val != T::zero() {
560            let mut nd_idx = vec![0i64; ndim.max(1)];
561            let mut remaining = linear_idx;
562            for d in 0..ndim {
563                nd_idx[d] = (remaining / strides[d]) as i64;
564                remaining %= strides[d];
565            }
566            indices.push(nd_idx);
567        }
568    }
569
570    let num_nonzero = indices.len();
571    if num_nonzero == 0 {
572        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
573    }
574
575    let flat: Vec<i64> = indices.into_iter().flatten().collect();
576    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
577}
578
// =============================================================================
// Unique Operation
// =============================================================================

/// Result of unique operation.
//
// Only `Clone` is derived: `Debug` is implemented manually below so that
// formatting prints shapes/flags rather than full tensor contents.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// Unique values.
    pub values: Tensor<T>,
    /// For each element of the original tensor, the index of its value in
    /// `values` (present if `return_inverse` was requested).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Counts of each unique value (present if `return_counts` was requested).
    pub counts: Option<Tensor<i64>>,
}

impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    // Compact Debug: shows the values shape and whether the optional outputs
    // are present, instead of dumping full tensor contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}

604/// Returns the unique elements of the input tensor.
605///
606/// # Arguments
607/// * `x` - Input tensor
608/// * `sorted` - Whether to sort the unique elements
609/// * `return_inverse` - Whether to return inverse indices
610/// * `return_counts` - Whether to return counts
611pub fn unique<T: Numeric>(
612    x: &Tensor<T>,
613    sorted: bool,
614    return_inverse: bool,
615    return_counts: bool,
616) -> UniqueResult<T> {
617    let data = x.to_vec();
618
619    // Use a vec to preserve insertion order (for unsorted case)
620    let mut seen: Vec<T> = Vec::new();
621    let mut counts_map: Vec<i64> = Vec::new();
622    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());
623
624    for &val in &data {
625        if let Some(pos) = seen.iter().position(|&v| v == val) {
626            inverse.push(pos as i64);
627            counts_map[pos] += 1;
628        } else {
629            inverse.push(seen.len() as i64);
630            seen.push(val);
631            counts_map.push(1);
632        }
633    }
634
635    let (unique_vals, final_inverse, final_counts) = if sorted {
636        // Sort unique values and update inverse indices
637        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
638        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
639
640        // Create mapping from old index to new index
641        let mut old_to_new = vec![0i64; indexed.len()];
642        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
643            old_to_new[*old_idx] = new_idx as i64;
644        }
645
646        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
647        let sorted_counts: Vec<i64> = indexed
648            .iter()
649            .map(|(old_idx, _)| counts_map[*old_idx])
650            .collect();
651        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();
652
653        (sorted_vals, updated_inverse, sorted_counts)
654    } else {
655        (seen, inverse, counts_map)
656    };
657
658    let n = unique_vals.len();
659
660    UniqueResult {
661        values: Tensor::from_vec(unique_vals, &[n]).unwrap(),
662        inverse_indices: if return_inverse {
663            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
664        } else {
665            None
666        },
667        counts: if return_counts {
668            Some(Tensor::from_vec(final_counts, &[n]).unwrap())
669        } else {
670            None
671        },
672    }
673}
674
675// =============================================================================
676// Flip Operation
677// =============================================================================
678
679/// Reverses the order of elements along specified dimensions.
680pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
681    let shape = x.shape();
682    let data = x.to_vec();
683    let ndim = shape.len();
684
685    for &d in dims {
686        if d >= ndim {
687            return Err(axonml_core::error::Error::invalid_operation(format!(
688                "Dimension {} out of range for tensor with {} dimensions",
689                d, ndim
690            )));
691        }
692    }
693
694    if shape.is_empty() {
695        return Ok(x.clone());
696    }
697
698    // Calculate strides
699    let mut strides = vec![1usize; ndim];
700    for i in (0..ndim - 1).rev() {
701        strides[i] = strides[i + 1] * shape[i + 1];
702    }
703
704    let mut result = vec![T::zero(); data.len()];
705
706    for src_linear in 0..data.len() {
707        // Convert to n-dimensional index
708        let mut nd_idx = vec![0usize; ndim];
709        let mut remaining = src_linear;
710        for d in 0..ndim {
711            nd_idx[d] = remaining / strides[d];
712            remaining %= strides[d];
713        }
714
715        // Flip specified dimensions
716        for &flip_dim in dims {
717            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
718        }
719
720        // Convert back to linear index
721        let mut dst_linear = 0;
722        for d in 0..ndim {
723            dst_linear += nd_idx[d] * strides[d];
724        }
725
726        result[dst_linear] = data[src_linear];
727    }
728
729    Tensor::from_vec(result, shape)
730}
731
732// =============================================================================
733// Roll Operation
734// =============================================================================
735
736/// Rolls tensor elements along specified dimensions.
737pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
738    if shifts.len() != dims.len() {
739        return Err(axonml_core::error::Error::invalid_operation(
740            "shifts and dims must have the same length".to_string(),
741        ));
742    }
743
744    let shape = x.shape();
745    let data = x.to_vec();
746    let ndim = shape.len();
747
748    for &d in dims {
749        if d >= ndim {
750            return Err(axonml_core::error::Error::invalid_operation(format!(
751                "Dimension {} out of range",
752                d
753            )));
754        }
755    }
756
757    if shape.is_empty() {
758        return Ok(x.clone());
759    }
760
761    // Calculate strides
762    let mut strides = vec![1usize; ndim];
763    for i in (0..ndim - 1).rev() {
764        strides[i] = strides[i + 1] * shape[i + 1];
765    }
766
767    let mut result = vec![T::zero(); data.len()];
768
769    for src_linear in 0..data.len() {
770        // Convert to n-dimensional index
771        let mut nd_idx = vec![0usize; ndim];
772        let mut remaining = src_linear;
773        for d in 0..ndim {
774            nd_idx[d] = remaining / strides[d];
775            remaining %= strides[d];
776        }
777
778        // Apply shifts
779        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
780            let dim_size = shape[dim] as i64;
781            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
782            nd_idx[dim] = new_idx as usize;
783        }
784
785        // Convert back to linear index
786        let mut dst_linear = 0;
787        for d in 0..ndim {
788            dst_linear += nd_idx[d] * strides[d];
789        }
790
791        result[dst_linear] = data[src_linear];
792    }
793
794    Tensor::from_vec(result, shape)
795}
796
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Softmax output must form a valid probability distribution (sums to 1).
    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    // eq / lt / gt share the same shape-checked element-wise contract.
    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    // largest = false returns the k smallest elements instead.
    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        // (0,0) and (1,1) are non-zero
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        // Preserves insertion order
        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        // Flip along dim 0: [[3,4], [1,2]]
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    // Negative shifts roll in the opposite direction.
    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }
}
963}