Skip to main content

axonml_tensor/ops/
mod.rs

1//! Tensor Operations - Mathematical and Structural Operations
2//!
3//! This module provides standalone tensor operations for convenient access.
4//! Operations are organized by category.
5//!
6//! # Categories
7//!
8//! ## Comparison Operations
9//! - `eq`, `lt`, `gt` - Element-wise comparison returning boolean vectors
10//!
11//! ## Activation Functions
12//! - `softmax`, `log_softmax` - Probability distributions
13//! - `gelu`, `silu`, `elu`, `leaky_relu` - Advanced activations
14//!
15//! ## Clipping Operations
16//! - `clamp`, `clamp_min`, `clamp_max` - Value range limiting
17//!
18//! ## Conditional Operations
19//! - `where_cond` - Select elements based on condition
20//!
21//! ## Sorting and Top-K
22//! - `topk` - Returns k largest/smallest elements with indices
23//! - `sort` - Sorts tensor along dimension with indices
24//! - `argsort` - Returns indices that would sort tensor
25//!
26//! ## Indexing Operations
27//! - `scatter` - Scatter values to specified indices (inverse of gather)
28//! - `nonzero` - Returns indices of non-zero elements
29//! - `unique` - Returns unique elements with optional counts/inverse
30//!
31//! ## Shape Manipulation
32//! - `flip` - Reverses tensor along specified dimensions
33//! - `roll` - Rolls tensor elements along dimensions (circular shift)
34//!
35//! # Example
36//!
37//! ```ignore
38//! use axonml_tensor::ops::{topk, sort, unique};
39//!
40//! let t = Tensor::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
41//!
42//! // Get top 3 largest values
43//! let result = topk(&t, 3, -1, true, true).unwrap();
44//! // result.values = [5.0, 4.0, 3.0]
45//! // result.indices = [4, 2, 0]
46//!
47//! // Sort ascending
48//! let sorted = sort(&t, -1, false).unwrap();
49//! // sorted.values = [1.0, 1.0, 3.0, 4.0, 5.0]
50//!
51//! // Get unique values
52//! let uniq = unique(&t, true, true, true);
53//! // uniq.values = [1.0, 3.0, 4.0, 5.0]
54//! // uniq.counts = [2, 1, 1, 1]
55//! ```
56//!
57//! @version 0.2.6
58//! @author `AutomataNexus` Development Team
59
60// Operations are implemented directly on Tensor in tensor.rs
61// This module provides additional standalone functions
62
63use axonml_core::dtype::{Float, Numeric, Scalar};
64use axonml_core::error::Result;
65
66use crate::tensor::Tensor;
67
68// =============================================================================
69// Comparison Operations
70// =============================================================================
71
72/// Element-wise equality comparison.
73pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
74    if a.shape() != b.shape() {
75        return Err(axonml_core::error::Error::shape_mismatch(
76            a.shape(),
77            b.shape(),
78        ));
79    }
80
81    let a_data = a.to_vec();
82    let b_data = b.to_vec();
83
84    Ok(a_data
85        .iter()
86        .zip(b_data.iter())
87        .map(|(x, y)| x == y)
88        .collect())
89}
90
91/// Element-wise less-than comparison.
92pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
93    if a.shape() != b.shape() {
94        return Err(axonml_core::error::Error::shape_mismatch(
95            a.shape(),
96            b.shape(),
97        ));
98    }
99
100    let a_data = a.to_vec();
101    let b_data = b.to_vec();
102
103    Ok(a_data
104        .iter()
105        .zip(b_data.iter())
106        .map(|(x, y)| x < y)
107        .collect())
108}
109
110/// Element-wise greater-than comparison.
111pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
112    if a.shape() != b.shape() {
113        return Err(axonml_core::error::Error::shape_mismatch(
114            a.shape(),
115            b.shape(),
116        ));
117    }
118
119    let a_data = a.to_vec();
120    let b_data = b.to_vec();
121
122    Ok(a_data
123        .iter()
124        .zip(b_data.iter())
125        .map(|(x, y)| x > y)
126        .collect())
127}
128
129// =============================================================================
130// Advanced Activation Functions
131// =============================================================================
132
133/// Applies softmax along the specified dimension.
134pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
135    // For simplicity, this handles the last dimension case
136    let data = x.to_vec();
137    let shape = x.shape();
138
139    if shape.is_empty() {
140        return Ok(Tensor::scalar(T::one()));
141    }
142
143    // Find max for numerical stability
144    let max_val = data
145        .iter()
146        .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
147
148    // Compute exp(x - max)
149    let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
150
151    // Compute sum
152    let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
153
154    // Normalize
155    let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
156
157    Tensor::from_vec(result, shape)
158}
159
160/// Applies log-softmax along the specified dimension.
161pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
162    let sm = softmax(x, dim)?;
163    Ok(sm.ln())
164}
165
166/// Applies GELU (Gaussian Error Linear Unit) activation.
167#[must_use]
168pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
169    let data = x.to_vec();
170    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
171    let coeff = T::from(0.044715).unwrap();
172
173    let result: Vec<T> = data
174        .iter()
175        .map(|&v| {
176            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
177            v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
178        })
179        .collect();
180
181    Tensor::from_vec(result, x.shape()).unwrap()
182}
183
184/// Applies Leaky `ReLU` activation.
185pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
186    let data = x.to_vec();
187    let result: Vec<T> = data
188        .iter()
189        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
190        .collect();
191
192    Tensor::from_vec(result, x.shape()).unwrap()
193}
194
195/// Applies ELU (Exponential Linear Unit) activation.
196pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
197    let data = x.to_vec();
198    let result: Vec<T> = data
199        .iter()
200        .map(|&v| {
201            if v > T::zero() {
202                v
203            } else {
204                alpha * (v.exp_value() - T::one())
205            }
206        })
207        .collect();
208
209    Tensor::from_vec(result, x.shape()).unwrap()
210}
211
212/// Applies `SiLU` (Sigmoid Linear Unit) / Swish activation.
213#[must_use]
214pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
215    let sig = x.sigmoid();
216    x.mul(&sig).unwrap()
217}
218
219// =============================================================================
220// Clipping Operations
221// =============================================================================
222
223/// Clamps all elements to the range [min, max].
224pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
225    let data = x.to_vec();
226    let result: Vec<T> = data
227        .iter()
228        .map(|&v| {
229            if v < min {
230                min
231            } else if v > max {
232                max
233            } else {
234                v
235            }
236        })
237        .collect();
238
239    Tensor::from_vec(result, x.shape()).unwrap()
240}
241
242/// Clamps all elements to be at least min.
243pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
244    let data = x.to_vec();
245    let result: Vec<T> = data
246        .iter()
247        .map(|&v| if v < min { min } else { v })
248        .collect();
249
250    Tensor::from_vec(result, x.shape()).unwrap()
251}
252
253/// Clamps all elements to be at most max.
254pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
255    let data = x.to_vec();
256    let result: Vec<T> = data
257        .iter()
258        .map(|&v| if v > max { max } else { v })
259        .collect();
260
261    Tensor::from_vec(result, x.shape()).unwrap()
262}
263
264// =============================================================================
265// Where Operation
266// =============================================================================
267
268/// Selects elements from x or y based on condition.
269pub fn where_cond<T: Scalar>(
270    condition: &[bool],
271    x: &Tensor<T>,
272    y: &Tensor<T>,
273) -> Result<Tensor<T>> {
274    if x.shape() != y.shape() {
275        return Err(axonml_core::error::Error::shape_mismatch(
276            x.shape(),
277            y.shape(),
278        ));
279    }
280
281    if condition.len() != x.numel() {
282        return Err(axonml_core::error::Error::shape_mismatch(
283            &[condition.len()],
284            &[x.numel()],
285        ));
286    }
287
288    let x_data = x.to_vec();
289    let y_data = y.to_vec();
290
291    let result: Vec<T> = condition
292        .iter()
293        .zip(x_data.iter().zip(y_data.iter()))
294        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
295        .collect();
296
297    Tensor::from_vec(result, x.shape())
298}
299
300// =============================================================================
301// Sorting and Top-K Operations
302// =============================================================================
303
/// Result of topk operation containing values and indices.
///
/// `values` and `indices` share the same shape: the input shape with the
/// selected dimension resized to `k`. Each index is a position along that
/// dimension in the original tensor.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// The top-k values.
    pub values: Tensor<T>,
    /// The indices of the top-k values in the original tensor.
    pub indices: Tensor<i64>,
}
312
313impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("TopKResult")
316            .field("values_shape", &self.values.shape())
317            .field("indices_shape", &self.indices.shape())
318            .finish()
319    }
320}
321
322/// Returns the k largest elements along a dimension.
323///
324/// # Arguments
325/// * `x` - Input tensor
326/// * `k` - Number of top elements to return
327/// * `dim` - Dimension to sort along (default: -1)
328/// * `largest` - If true, return largest elements; if false, return smallest
329/// * `sorted` - If true, return elements in sorted order
330///
331/// # Returns
332/// TopKResult containing values and indices tensors
333pub fn topk<T: Numeric>(
334    x: &Tensor<T>,
335    k: usize,
336    dim: i64,
337    largest: bool,
338    sorted: bool,
339) -> Result<TopKResult<T>> {
340    let shape = x.shape();
341    if shape.is_empty() {
342        return Err(axonml_core::error::Error::invalid_operation(
343            "Cannot apply topk to scalar tensor".to_string(),
344        ));
345    }
346
347    let dim = if dim < 0 {
348        (shape.len() as i64 + dim) as usize
349    } else {
350        dim as usize
351    };
352
353    if dim >= shape.len() {
354        return Err(axonml_core::error::Error::invalid_operation(format!(
355            "Dimension {} out of range for tensor with {} dimensions",
356            dim,
357            shape.len()
358        )));
359    }
360
361    let dim_size = shape[dim];
362    if k > dim_size {
363        return Err(axonml_core::error::Error::invalid_operation(format!(
364            "k ({}) is larger than dimension size ({})",
365            k, dim_size
366        )));
367    }
368
369    let data = x.to_vec();
370
371    // For simplicity, handle the 1D case specially
372    if shape.len() == 1 {
373        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
374        if largest {
375            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
376        } else {
377            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
378        }
379
380        if !sorted {
381            indexed[..k].sort_by_key(|x| x.0);
382        }
383
384        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
385        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();
386
387        return Ok(TopKResult {
388            values: Tensor::from_vec(values, &[k])?,
389            indices: Tensor::from_vec(indices, &[k])?,
390        });
391    }
392
393    // General n-dimensional case
394    let outer_size: usize = shape[..dim].iter().product();
395    let inner_size: usize = shape[dim + 1..].iter().product();
396
397    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
398    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);
399
400    for outer in 0..outer_size {
401        for inner in 0..inner_size {
402            let mut slice: Vec<(usize, T)> = (0..dim_size)
403                .map(|d| {
404                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
405                    (d, data[idx])
406                })
407                .collect();
408
409            if largest {
410                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
411            } else {
412                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
413            }
414
415            if !sorted {
416                slice[..k].sort_by_key(|x| x.0);
417            }
418
419            for (orig_idx, val) in slice.into_iter().take(k) {
420                values_data.push(val);
421                indices_data.push(orig_idx as i64);
422            }
423        }
424    }
425
426    let mut output_shape = shape.to_vec();
427    output_shape[dim] = k;
428
429    Ok(TopKResult {
430        values: Tensor::from_vec(values_data, &output_shape)?,
431        indices: Tensor::from_vec(indices_data, &output_shape)?,
432    })
433}
434
/// Result of sort operation containing sorted values and indices.
///
/// `values` and `indices` share the input tensor's shape; `indices[i]` is the
/// original position (along the sorted dimension) of `values[i]`.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// Sorted values.
    pub values: Tensor<T>,
    /// Indices that would sort the tensor.
    pub indices: Tensor<i64>,
}
443
444impl<T: Scalar> std::fmt::Debug for SortResult<T> {
445    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446        f.debug_struct("SortResult")
447            .field("values_shape", &self.values.shape())
448            .field("indices_shape", &self.indices.shape())
449            .finish()
450    }
451}
452
453/// Sorts the elements of the tensor along a dimension.
454///
455/// # Arguments
456/// * `x` - Input tensor
457/// * `dim` - Dimension to sort along (default: -1)
458/// * `descending` - If true, sort in descending order
459///
460/// # Returns
461/// SortResult containing sorted values and indices
462pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
463    let shape = x.shape();
464    if shape.is_empty() {
465        return Ok(SortResult {
466            values: x.clone(),
467            indices: Tensor::scalar(0i64),
468        });
469    }
470
471    let dim = if dim < 0 {
472        (shape.len() as i64 + dim) as usize
473    } else {
474        dim as usize
475    };
476
477    let dim_size = shape[dim];
478    topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
479        values: tk.values,
480        indices: tk.indices,
481    })
482}
483
484/// Returns the indices that would sort the tensor along a dimension.
485///
486/// # Arguments
487/// * `x` - Input tensor
488/// * `dim` - Dimension to sort along (default: -1)
489/// * `descending` - If true, sort in descending order
490pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
491    sort(x, dim, descending).map(|r| r.indices)
492}
493
494// =============================================================================
495// Scatter Operation
496// =============================================================================
497
/// Writes values from src into self at locations specified by index.
///
/// This is the inverse of gather.
///
/// # Arguments
/// * `dst` - Destination tensor (modified in place conceptually, returns new tensor)
/// * `dim` - Dimension along which to scatter
/// * `index` - Indices to scatter to
/// * `src` - Source values to scatter
///
/// # Errors
/// Fails when `index` and `src` differ in shape or `dim` is out of range.
///
/// NOTE(review): `index`/`src` appear to be assumed to have the same rank as
/// `dst` (only the coordinate along `dim` is replaced), and index values are
/// not bounds-checked — an out-of-range index panics at the final write, and a
/// scalar `index` would underflow in the stride loop. Confirm whether these
/// should become proper errors.
pub fn scatter<T: Scalar>(
    dst: &Tensor<T>,
    dim: usize,
    index: &Tensor<i64>,
    src: &Tensor<T>,
) -> Result<Tensor<T>> {
    let dst_shape = dst.shape();
    let idx_shape = index.shape();
    let src_shape = src.shape();

    if idx_shape != src_shape {
        return Err(axonml_core::error::Error::shape_mismatch(
            idx_shape, src_shape,
        ));
    }

    if dim >= dst_shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range",
            dim
        )));
    }

    // Start from a copy of dst; scattered positions are overwritten below.
    let mut result = dst.to_vec();
    let idx_data = index.to_vec();
    let src_data = src.to_vec();

    // Calculate strides for the destination
    // (row-major: innermost dimension has stride 1)
    let mut dst_strides = vec![1usize; dst_shape.len()];
    for i in (0..dst_shape.len() - 1).rev() {
        dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
    }

    // Calculate strides for index/src
    let mut idx_strides = vec![1usize; idx_shape.len()];
    for i in (0..idx_shape.len() - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
    }

    // Scatter values
    let total = index.numel();
    for linear_idx in 0..total {
        // Convert linear index to n-dimensional index
        let mut nd_idx = vec![0usize; idx_shape.len()];
        let mut remaining = linear_idx;
        for d in 0..idx_shape.len() {
            nd_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }

        // Get the scatter index
        let scatter_idx = idx_data[linear_idx] as usize;

        // Build destination index: same coordinates as the source element,
        // with the `dim` coordinate replaced by the index value.
        let mut dst_nd_idx = nd_idx.clone();
        dst_nd_idx[dim] = scatter_idx;

        // Convert to linear destination index
        let mut dst_linear = 0;
        for d in 0..dst_shape.len() {
            dst_linear += dst_nd_idx[d] * dst_strides[d];
        }

        result[dst_linear] = src_data[linear_idx];
    }

    Tensor::from_vec(result, dst_shape)
}
575
576// =============================================================================
577// Nonzero Operation
578// =============================================================================
579
580/// Returns the indices of non-zero elements.
581///
582/// # Arguments
583/// * `x` - Input tensor
584///
585/// # Returns
586/// Tensor of shape (num_nonzero, ndim) containing indices of non-zero elements
587pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
588    let data = x.to_vec();
589    let shape = x.shape();
590    let ndim = shape.len();
591
592    // Find all non-zero indices
593    let mut indices: Vec<Vec<i64>> = Vec::new();
594
595    // Calculate strides for index conversion
596    let mut strides = vec![1usize; ndim.max(1)];
597    for i in (0..ndim.saturating_sub(1)).rev() {
598        strides[i] = strides[i + 1] * shape[i + 1];
599    }
600
601    for (linear_idx, &val) in data.iter().enumerate() {
602        if val != T::zero() {
603            let mut nd_idx = vec![0i64; ndim.max(1)];
604            let mut remaining = linear_idx;
605            for d in 0..ndim {
606                nd_idx[d] = (remaining / strides[d]) as i64;
607                remaining %= strides[d];
608            }
609            indices.push(nd_idx);
610        }
611    }
612
613    let num_nonzero = indices.len();
614    if num_nonzero == 0 {
615        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
616    }
617
618    let flat: Vec<i64> = indices.into_iter().flatten().collect();
619    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
620}
621
622// =============================================================================
623// Unique Operation
624// =============================================================================
625
/// Result of unique operation.
///
/// `values` is a 1-D tensor of the distinct elements; the optional fields are
/// populated according to the flags passed to `unique`.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// Unique values.
    pub values: Tensor<T>,
    /// Indices of unique values in the original tensor (if return_inverse).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Counts of each unique value (if return_counts).
    pub counts: Option<Tensor<i64>>,
}
636
637impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
638    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639        f.debug_struct("UniqueResult")
640            .field("values_shape", &self.values.shape())
641            .field("has_inverse", &self.inverse_indices.is_some())
642            .field("has_counts", &self.counts.is_some())
643            .finish()
644    }
645}
646
/// Returns the unique elements of the input tensor.
///
/// # Arguments
/// * `x` - Input tensor
/// * `sorted` - Whether to sort the unique elements
/// * `return_inverse` - Whether to return inverse indices
/// * `return_counts` - Whether to return counts
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // Use a vec to preserve insertion order (for unsorted case)
    // A linear scan is used rather than a hash set because T may be a float
    // type (not Eq/Hash), making this O(n * u) in the number of uniques.
    // NOTE(review): NaN != NaN, so every NaN would become its own "unique"
    // value — confirm that is the intended behavior.
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            // First occurrence: register it and point the inverse at it.
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort unique values and update inverse indices
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Create mapping from old index to new index
        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        // Counts and inverse indices are permuted to match the sorted order.
        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed
            .iter()
            .map(|(old_idx, _)| counts_map[*old_idx])
            .collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).unwrap(),
        inverse_indices: if return_inverse {
            // Inverse indices mirror the input's shape, one entry per element.
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).unwrap())
        } else {
            None
        },
    }
}
717
718// =============================================================================
719// Flip Operation
720// =============================================================================
721
/// Reverses the order of elements along specified dimensions.
///
/// # Errors
/// Fails when any entry of `dims` is out of range for the tensor's rank.
pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    // Validate every dimension before touching any data.
    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                d, ndim
            )));
        }
    }

    // A scalar has nothing to flip.
    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Calculate strides (row-major: innermost dimension has stride 1).
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Convert to n-dimensional index
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Flip specified dimensions
        // (a dimension listed twice flips twice, i.e. is left unchanged)
        for &flip_dim in dims {
            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
        }

        // Convert back to linear index
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
774
775// =============================================================================
776// Roll Operation
777// =============================================================================
778
/// Rolls tensor elements along specified dimensions.
///
/// Elements shifted beyond the end wrap around to the beginning (circular
/// shift); negative shifts roll toward the front.
///
/// # Errors
/// Fails when `shifts` and `dims` differ in length, or a dim is out of range.
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    // Validate every dimension before touching any data.
    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range",
                d
            )));
        }
    }

    // A scalar cannot be rolled.
    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Calculate strides (row-major: innermost dimension has stride 1).
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Convert to n-dimensional index
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply shifts
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            // Double modulo keeps the result in [0, dim_size) even for
            // negative shifts.
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        // Convert back to linear index
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}
839
840// =============================================================================
841// Tests
842// =============================================================================
843
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        // A softmax output is a probability distribution: it must sum to 1.
        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        // Negative inputs are scaled by the slope; zero stays zero.
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        // Largest three values in descending order, with original positions.
        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        // largest=false selects the smallest values (duplicates allowed).
        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        // Positions that would sort the input ascending.
        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        // One row per non-zero element, one column per dimension.
        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        // (0,0) and (1,1) are non-zero
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        // Preserves insertion order
        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        // Flip along dim 0: [[3,4], [1,2]]
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        // Positive shift rolls toward higher indices, wrapping at the end.
        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();

        // src[0] lands at dst[0], src[1] at dst[2]; dst[1] is untouched.
        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }
}