axonml_tensor/ops/mod.rs

//! Tensor Operations - Mathematical and Structural Operations
//!
//! This module provides standalone tensor operations for convenient access.
//! Operations are organized by category.
//!
//! # Categories
//!
//! ## Comparison Operations
//! - `eq`, `lt`, `gt` - Element-wise comparison returning boolean vectors
//!
//! ## Activation Functions
//! - `softmax`, `log_softmax` - Probability distributions
//! - `gelu`, `silu`, `elu`, `leaky_relu` - Advanced activations
//!
//! ## Clipping Operations
//! - `clamp`, `clamp_min`, `clamp_max` - Value range limiting
//!
//! ## Conditional Operations
//! - `where_cond` - Select elements based on condition
//!
//! ## Sorting and Top-K
//! - `topk` - Returns k largest/smallest elements with indices
//! - `sort` - Sorts tensor along dimension with indices
//! - `argsort` - Returns indices that would sort tensor
//!
//! ## Indexing Operations
//! - `scatter` - Scatter values to specified indices (inverse of gather)
//! - `nonzero` - Returns indices of non-zero elements
//! - `unique` - Returns unique elements with optional counts/inverse
//!
//! ## Shape Manipulation
//! - `flip` - Reverses tensor along specified dimensions
//! - `roll` - Rolls tensor elements along dimensions (circular shift)
//!
//! # Example
//!
//! ```ignore
//! use axonml_tensor::ops::{topk, sort, unique};
//!
//! let t = Tensor::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
//!
//! // Get top 3 largest values
//! let result = topk(&t, 3, -1, true, true).unwrap();
//! // result.values = [5.0, 4.0, 3.0]
//! // result.indices = [4, 2, 0]
//!
//! // Sort ascending
//! let sorted = sort(&t, -1, false).unwrap();
//! // sorted.values = [1.0, 1.0, 3.0, 4.0, 5.0]
//!
//! // Get unique values
//! let uniq = unique(&t, true, true, true);
//! // uniq.values = [1.0, 3.0, 4.0, 5.0]
//! // uniq.counts = [2, 1, 1, 1]
//! ```
//!
//! @version 0.2.6
//! @author `AutomataNexus` Development Team

// Operations are implemented directly on Tensor in tensor.rs
// This module provides additional standalone functions

use axonml_core::dtype::{Float, Numeric, Scalar};
use axonml_core::error::Result;

use crate::tensor::Tensor;

// =============================================================================
// Comparison Operations
// =============================================================================

/// Element-wise equality comparison.
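///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), following the module-level
/// example and assuming `Tensor` is in scope:
///
/// ```ignore
/// let a = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0], &[3]).unwrap();
/// let b = Tensor::from_vec(vec![1.0_f32, 5.0, 3.0], &[3]).unwrap();
/// assert_eq!(eq(&a, &b).unwrap(), vec![true, false, true]);
/// ```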
pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x == y)
        .collect())
}

/// Element-wise less-than comparison.
pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x < y)
        .collect())
}

/// Element-wise greater-than comparison.
pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x > y)
        .collect())
}

// =============================================================================
// Advanced Activation Functions
// =============================================================================

/// Applies softmax along the specified dimension.
///
/// Note: the current implementation computes softmax over all elements of the
/// tensor (the `dim` argument is ignored), which matches the per-dimension
/// result only for 1-D tensors.
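///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0], &[3]).unwrap();
/// let s = softmax(&t, -1).unwrap();
/// // The resulting probabilities sum to 1.0.
/// let sum: f32 = s.to_vec().iter().sum();
/// assert!((sum - 1.0).abs() < 1e-5);
/// ```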
pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
    // For simplicity, softmax is computed over all elements of the tensor;
    // this is equivalent to a last-dimension softmax only for 1-D inputs.
    let data = x.to_vec();
    let shape = x.shape();

    if shape.is_empty() {
        return Ok(Tensor::scalar(T::one()));
    }

    // Find max for numerical stability
    let max_val = data
        .iter()
        .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });

    // Compute exp(x - max)
    let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();

    // Compute sum
    let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);

    // Normalize
    let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();

    Tensor::from_vec(result, shape)
}

/// Applies log-softmax along the specified dimension.
pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
    let sm = softmax(x, dim)?;
    Ok(sm.ln())
}

/// Applies GELU (Gaussian Error Linear Unit) activation.
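///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope.
/// This uses the tanh approximation implemented below:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![-1.0_f32, 0.0, 1.0], &[3]).unwrap();
/// let g = gelu(&t);
/// // gelu(0) == 0, gelu(1) ≈ 0.841, gelu(-1) ≈ -0.159
/// ```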
#[must_use] pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let data = x.to_vec();
    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
    let coeff = T::from(0.044715).unwrap();

    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
            v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies Leaky `ReLU` activation.
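///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![-1.0_f32, 0.0, 1.0], &[3]).unwrap();
/// let r = leaky_relu(&t, 0.01);
/// assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
/// ```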
pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies ELU (Exponential Linear Unit) activation.
pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v > T::zero() {
                v
            } else {
                alpha * (v.exp_value() - T::one())
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies `SiLU` (Sigmoid Linear Unit) / Swish activation.
#[must_use] pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let sig = x.sigmoid();
    x.mul(&sig).unwrap()
}

// =============================================================================
// Clipping Operations
// =============================================================================

/// Clamps all elements to the range [min, max].
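///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![-1.0_f32, 0.5, 2.0], &[3]).unwrap();
/// let c = clamp(&t, 0.0, 1.0);
/// assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
/// ```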
pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v < min {
                min
            } else if v > max {
                max
            } else {
                v
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps all elements to be at least min.
pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v < min { min } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps all elements to be at most max.
pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > max { max } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

// =============================================================================
// Where Operation
// =============================================================================

/// Selects elements from x or y based on condition.
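///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope.
/// The condition is a flat boolean slice with one entry per element:
///
/// ```ignore
/// let x = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0], &[3]).unwrap();
/// let y = Tensor::from_vec(vec![10.0_f32, 20.0, 30.0], &[3]).unwrap();
/// let cond = [true, false, true];
/// let out = where_cond(&cond, &x, &y).unwrap();
/// assert_eq!(out.to_vec(), vec![1.0, 20.0, 3.0]);
/// ```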
pub fn where_cond<T: Scalar>(
    condition: &[bool],
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> Result<Tensor<T>> {
    if x.shape() != y.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            x.shape(),
            y.shape(),
        ));
    }

    if condition.len() != x.numel() {
        return Err(axonml_core::error::Error::shape_mismatch(
            &[condition.len()],
            &[x.numel()],
        ));
    }

    let x_data = x.to_vec();
    let y_data = y.to_vec();

    let result: Vec<T> = condition
        .iter()
        .zip(x_data.iter().zip(y_data.iter()))
        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
        .collect();

    Tensor::from_vec(result, x.shape())
}

// =============================================================================
// Sorting and Top-K Operations
// =============================================================================

/// Result of topk operation containing values and indices.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// The top-k values.
    pub values: Tensor<T>,
    /// The indices of the top-k values in the original tensor.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

/// Returns the k largest elements along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `k` - Number of top elements to return
/// * `dim` - Dimension to sort along (default: -1)
/// * `largest` - If true, return largest elements; if false, return smallest
/// * `sorted` - If true, return elements in sorted order
///
/// # Returns
/// TopKResult containing values and indices tensors
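///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![3.0_f32, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
/// let top = topk(&t, 3, -1, true, true).unwrap();
/// assert_eq!(top.values.to_vec(), vec![9.0, 5.0, 4.0]);
/// assert_eq!(top.indices.to_vec(), vec![5, 4, 2]);
/// ```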
pub fn topk<T: Numeric>(
    x: &Tensor<T>,
    k: usize,
    dim: i64,
    largest: bool,
    sorted: bool,
) -> Result<TopKResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Err(axonml_core::error::Error::invalid_operation(
            "Cannot apply topk to scalar tensor".to_string(),
        ));
    }

    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("Dimension {} out of range for tensor with {} dimensions", dim, shape.len()),
        ));
    }

    let dim_size = shape[dim];
    if k > dim_size {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("k ({}) is larger than dimension size ({})", k, dim_size),
        ));
    }

    let data = x.to_vec();

    // For simplicity, handle the 1D case specially
    if shape.len() == 1 {
        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
        if largest {
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        if !sorted {
            indexed[..k].sort_by_key(|x| x.0);
        }

        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();

        return Ok(TopKResult {
            values: Tensor::from_vec(values, &[k])?,
            indices: Tensor::from_vec(indices, &[k])?,
        });
    }

    // General n-dimensional case
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();

    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);

    for outer in 0..outer_size {
        for inner in 0..inner_size {
            let mut slice: Vec<(usize, T)> = (0..dim_size)
                .map(|d| {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    (d, data[idx])
                })
                .collect();

            if largest {
                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            } else {
                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }

            if !sorted {
                slice[..k].sort_by_key(|x| x.0);
            }

            for (orig_idx, val) in slice.into_iter().take(k) {
                values_data.push(val);
                indices_data.push(orig_idx as i64);
            }
        }
    }

    let mut output_shape = shape.to_vec();
    output_shape[dim] = k;

    Ok(TopKResult {
        values: Tensor::from_vec(values_data, &output_shape)?,
        indices: Tensor::from_vec(indices_data, &output_shape)?,
    })
}

/// Result of sort operation containing sorted values and indices.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// Sorted values.
    pub values: Tensor<T>,
    /// Indices that would sort the tensor.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

/// Sorts the elements of the tensor along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `dim` - Dimension to sort along (default: -1)
/// * `descending` - If true, sort in descending order
///
/// # Returns
/// SortResult containing sorted values and indices
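///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![3.0_f32, 1.0, 4.0], &[3]).unwrap();
/// let asc = sort(&t, -1, false).unwrap();
/// assert_eq!(asc.values.to_vec(), vec![1.0, 3.0, 4.0]);
/// assert_eq!(asc.indices.to_vec(), vec![1, 0, 2]);
/// ```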
pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Ok(SortResult {
            values: x.clone(),
            indices: Tensor::scalar(0i64),
        });
    }

    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    // Validate the normalized dimension before indexing into the shape,
    // so an invalid `dim` returns an error instead of panicking.
    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("Dimension {} out of range for tensor with {} dimensions", dim, shape.len()),
        ));
    }

    let dim_size = shape[dim];
    topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
        values: tk.values,
        indices: tk.indices,
    })
}

/// Returns the indices that would sort the tensor along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `dim` - Dimension to sort along (default: -1)
/// * `descending` - If true, sort in descending order
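///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![3.0_f32, 1.0, 2.0], &[3]).unwrap();
/// let idx = argsort(&t, -1, false).unwrap();
/// assert_eq!(idx.to_vec(), vec![1, 2, 0]);
/// ```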
pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
    sort(x, dim, descending).map(|r| r.indices)
}

// =============================================================================
// Scatter Operation
// =============================================================================

/// Writes values from `src` into a copy of `dst` at locations specified by `index`.
///
/// This is the inverse of gather.
///
/// # Arguments
/// * `dst` - Destination tensor (left untouched; a new tensor with the scattered values is returned)
/// * `dim` - Dimension along which to scatter
/// * `index` - Indices to scatter to
/// * `src` - Source values to scatter
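///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope
/// and the `Tensor::zeros` constructor used in the tests below:
///
/// ```ignore
/// let dst = Tensor::<f32>::zeros(&[3]);
/// let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
/// let src = Tensor::from_vec(vec![1.0_f32, 2.0], &[2]).unwrap();
/// let out = scatter(&dst, 0, &index, &src).unwrap();
/// assert_eq!(out.to_vec(), vec![1.0, 0.0, 2.0]);
/// ```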
pub fn scatter<T: Scalar>(
    dst: &Tensor<T>,
    dim: usize,
    index: &Tensor<i64>,
    src: &Tensor<T>,
) -> Result<Tensor<T>> {
    let dst_shape = dst.shape();
    let idx_shape = index.shape();
    let src_shape = src.shape();

    if idx_shape != src_shape {
        return Err(axonml_core::error::Error::shape_mismatch(idx_shape, src_shape));
    }

    if dim >= dst_shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            format!("Dimension {} out of range", dim),
        ));
    }

    let mut result = dst.to_vec();
    let idx_data = index.to_vec();
    let src_data = src.to_vec();

    // Calculate strides for the destination
    let mut dst_strides = vec![1usize; dst_shape.len()];
    for i in (0..dst_shape.len() - 1).rev() {
        dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
    }

    // Calculate strides for index/src
    let mut idx_strides = vec![1usize; idx_shape.len()];
    for i in (0..idx_shape.len() - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
    }

    // Scatter values
    let total = index.numel();
    for linear_idx in 0..total {
        // Convert linear index to n-dimensional index
        let mut nd_idx = vec![0usize; idx_shape.len()];
        let mut remaining = linear_idx;
        for d in 0..idx_shape.len() {
            nd_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }

        // Get the scatter index
        let scatter_idx = idx_data[linear_idx] as usize;

        // Build destination index
        let mut dst_nd_idx = nd_idx.clone();
        dst_nd_idx[dim] = scatter_idx;

        // Convert to linear destination index
        let mut dst_linear = 0;
        for d in 0..dst_shape.len() {
            dst_linear += dst_nd_idx[d] * dst_strides[d];
        }

        result[dst_linear] = src_data[linear_idx];
    }

    Tensor::from_vec(result, dst_shape)
}

// =============================================================================
// Nonzero Operation
// =============================================================================

/// Returns the indices of non-zero elements.
///
/// # Arguments
/// * `x` - Input tensor
///
/// # Returns
/// Tensor of shape (num_nonzero, ndim) containing indices of non-zero elements
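///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope.
/// Each row of the result is the n-dimensional index of one non-zero element:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
/// let idx = nonzero(&t);
/// // Non-zero elements at (0, 0) and (1, 1).
/// assert_eq!(idx.shape(), &[2, 2]);
/// assert_eq!(idx.to_vec(), vec![0, 0, 1, 1]);
/// ```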
pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
    let data = x.to_vec();
    let shape = x.shape();
    let ndim = shape.len();

    // Find all non-zero indices
    let mut indices: Vec<Vec<i64>> = Vec::new();

    // Calculate strides for index conversion
    let mut strides = vec![1usize; ndim.max(1)];
    for i in (0..ndim.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    for (linear_idx, &val) in data.iter().enumerate() {
        if val != T::zero() {
            let mut nd_idx = vec![0i64; ndim.max(1)];
            let mut remaining = linear_idx;
            for d in 0..ndim {
                nd_idx[d] = (remaining / strides[d]) as i64;
                remaining %= strides[d];
            }
            indices.push(nd_idx);
        }
    }

    let num_nonzero = indices.len();
    if num_nonzero == 0 {
        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
    }

    let flat: Vec<i64> = indices.into_iter().flatten().collect();
    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
}

// =============================================================================
// Unique Operation
// =============================================================================

/// Result of unique operation.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// Unique values.
    pub values: Tensor<T>,
    /// For each input element, the index of its value in `values` (if `return_inverse`).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Counts of each unique value (if return_counts).
    pub counts: Option<Tensor<i64>>,
}

impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}

/// Returns the unique elements of the input tensor.
///
/// # Arguments
/// * `x` - Input tensor
/// * `sorted` - Whether to sort the unique elements
/// * `return_inverse` - Whether to return inverse indices
/// * `return_counts` - Whether to return counts
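///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![1.0_f32, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
/// let uniq = unique(&t, true, true, true);
/// assert_eq!(uniq.values.to_vec(), vec![1.0, 2.0, 3.0]);
/// assert_eq!(uniq.inverse_indices.unwrap().to_vec(), vec![0, 1, 0, 2, 1, 0]);
/// assert_eq!(uniq.counts.unwrap().to_vec(), vec![3, 2, 1]);
/// ```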
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // Use a vec to preserve insertion order (for unsorted case)
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort unique values and update inverse indices
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Create mapping from old index to new index
        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed.iter().map(|(old_idx, _)| counts_map[*old_idx]).collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).unwrap(),
        inverse_indices: if return_inverse {
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).unwrap())
        } else {
            None
        },
    }
}

// =============================================================================
// Flip Operation
// =============================================================================

/// Reverses the order of elements along specified dimensions.
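///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let flipped = flip(&t, &[0]).unwrap();
/// // Rows are reversed: [[3, 4], [1, 2]]
/// assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
/// ```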
pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(
                format!("Dimension {} out of range for tensor with {} dimensions", d, ndim),
            ));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Calculate strides
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Convert to n-dimensional index
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Flip specified dimensions
        for &flip_dim in dims {
            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
        }

        // Convert back to linear index
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}

// =============================================================================
// Roll Operation
// =============================================================================

/// Rolls tensor elements along specified dimensions.
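///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `Tensor` is in scope.
/// Positive shifts move elements toward higher indices, wrapping around:
///
/// ```ignore
/// let t = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], &[4]).unwrap();
/// let rolled = roll(&t, &[1], &[0]).unwrap();
/// assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
/// ```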
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(
                format!("Dimension {} out of range", d),
            ));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Calculate strides
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Convert to n-dimensional index
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply shifts
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        // Convert back to linear index
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        // (0,0) and (1,1) are non-zero
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(result.inverse_indices.unwrap().to_vec(), vec![0, 1, 0, 2, 1, 0]);
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        // Preserves insertion order
        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        // Flip along dim 0: [[3,4], [1,2]]
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }
}