Skip to main content

axonml_tensor/ops/
mod.rs

1//! Higher-level tensor operations — activations, comparisons, clamping.
2//!
3//! 1133 lines. Free functions operating on `Tensor<T>` references:
4//! comparisons (eq, lt, gt, eq_mask, lt_mask, gt_mask), softmax /
5//! log_softmax (numerically stable, with dim parameter), activations
6//! (gelu, leaky_relu, elu, silu, mish), clamping (clamp, clamp_min,
7//! clamp_max), `var_dim` (Welford single-pass variance along a dim),
8//! `where_cond` (ternary select), `stack` (new dim), `dropout` (train/eval
9//! aware), `layer_norm`, and `batch_norm` (running mean/var + affine).
10//!
11//! # File
12//! `crates/axonml-tensor/src/ops/mod.rs`
13//!
14//! # Author
15//! Andrew Jewell Sr. — AutomataNexus LLC
16//! ORCID: 0009-0005-2158-7060
17//!
18//! # Updated
19//! April 14, 2026 11:15 PM EST
20//!
21//! # Disclaimer
22//! Use at own risk. This software is provided "as is", without warranty of any
23//! kind, express or implied. The author and AutomataNexus shall not be held
24//! liable for any damages arising from the use of this software.
25
26// Operations are implemented directly on Tensor in tensor.rs
27// This module provides additional standalone functions
28
29use axonml_core::dtype::{Float, Numeric, Scalar};
30use axonml_core::error::Result;
31
32use crate::tensor::Tensor;
33
34// =============================================================================
35// Comparison Operations
36// =============================================================================
37
38/// Element-wise equality comparison.
39pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
40    if a.shape() != b.shape() {
41        return Err(axonml_core::error::Error::shape_mismatch(
42            a.shape(),
43            b.shape(),
44        ));
45    }
46
47    let a_data = a.to_vec();
48    let b_data = b.to_vec();
49
50    Ok(a_data
51        .iter()
52        .zip(b_data.iter())
53        .map(|(x, y)| x == y)
54        .collect())
55}
56
57/// Element-wise less-than comparison.
58pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
59    if a.shape() != b.shape() {
60        return Err(axonml_core::error::Error::shape_mismatch(
61            a.shape(),
62            b.shape(),
63        ));
64    }
65
66    let a_data = a.to_vec();
67    let b_data = b.to_vec();
68
69    Ok(a_data
70        .iter()
71        .zip(b_data.iter())
72        .map(|(x, y)| x < y)
73        .collect())
74}
75
76/// Element-wise greater-than comparison.
77pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
78    if a.shape() != b.shape() {
79        return Err(axonml_core::error::Error::shape_mismatch(
80            a.shape(),
81            b.shape(),
82        ));
83    }
84
85    let a_data = a.to_vec();
86    let b_data = b.to_vec();
87
88    Ok(a_data
89        .iter()
90        .zip(b_data.iter())
91        .map(|(x, y)| x > y)
92        .collect())
93}
94
95/// Element-wise equality → Tensor<f32> mask (1.0 where equal, 0.0 where not).
96pub fn eq_mask<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
97    let bools = eq(a, b)?;
98    let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
99    Tensor::from_vec(data, a.shape())
100}
101
102/// Element-wise less-than → Tensor<f32> mask.
103pub fn lt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
104    let bools = lt(a, b)?;
105    let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
106    Tensor::from_vec(data, a.shape())
107}
108
109/// Element-wise greater-than → Tensor<f32> mask.
110pub fn gt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
111    let bools = gt(a, b)?;
112    let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
113    Tensor::from_vec(data, a.shape())
114}
115
116// =============================================================================
117// Advanced Activation Functions
118// =============================================================================
119
120/// Applies softmax along the specified dimension.
121pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
122    // For simplicity, this handles the last dimension case
123    let data = x.to_vec();
124    let shape = x.shape();
125
126    if shape.is_empty() {
127        return Ok(Tensor::scalar(T::one()));
128    }
129
130    // Find max for numerical stability
131    let max_val = data
132        .iter()
133        .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
134
135    // Compute exp(x - max)
136    let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
137
138    // Compute sum
139    let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
140
141    // Normalize
142    let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
143
144    Tensor::from_vec(result, shape)
145}
146
147/// Applies log-softmax along the specified dimension.
148pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
149    let sm = softmax(x, dim)?;
150    Ok(sm.ln())
151}
152
153/// Applies GELU (Gaussian Error Linear Unit) activation.
154#[must_use]
155pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
156    let data = x.to_vec();
157    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
158    let coeff = T::from(0.044715).unwrap();
159
160    let result: Vec<T> = data
161        .iter()
162        .map(|&v| {
163            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
164            v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
165        })
166        .collect();
167
168    Tensor::from_vec(result, x.shape()).unwrap()
169}
170
171/// Applies Leaky `ReLU` activation.
172pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
173    let data = x.to_vec();
174    let result: Vec<T> = data
175        .iter()
176        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
177        .collect();
178
179    Tensor::from_vec(result, x.shape()).unwrap()
180}
181
182/// Applies ELU (Exponential Linear Unit) activation.
183pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
184    let data = x.to_vec();
185    let result: Vec<T> = data
186        .iter()
187        .map(|&v| {
188            if v > T::zero() {
189                v
190            } else {
191                alpha * (v.exp_value() - T::one())
192            }
193        })
194        .collect();
195
196    Tensor::from_vec(result, x.shape()).unwrap()
197}
198
199/// Applies `SiLU` (Sigmoid Linear Unit) / Swish activation.
200#[must_use]
201pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
202    let sig = x.sigmoid();
203    x.mul(&sig).expect("tensor mul failed")
204}
205
206// =============================================================================
207// Clipping Operations
208// =============================================================================
209
210/// Clamps all elements to the range [min, max].
211pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
212    let data = x.to_vec();
213    let result: Vec<T> = data
214        .iter()
215        .map(|&v| {
216            if v < min {
217                min
218            } else if v > max {
219                max
220            } else {
221                v
222            }
223        })
224        .collect();
225
226    Tensor::from_vec(result, x.shape()).unwrap()
227}
228
229/// Clamps all elements to be at least min.
230pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
231    let data = x.to_vec();
232    let result: Vec<T> = data
233        .iter()
234        .map(|&v| if v < min { min } else { v })
235        .collect();
236
237    Tensor::from_vec(result, x.shape()).unwrap()
238}
239
240/// Clamps all elements to be at most max.
241pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
242    let data = x.to_vec();
243    let result: Vec<T> = data
244        .iter()
245        .map(|&v| if v > max { max } else { v })
246        .collect();
247
248    Tensor::from_vec(result, x.shape()).unwrap()
249}
250
251// =============================================================================
252// Where Operation
253// =============================================================================
254
255/// Selects elements from x or y based on condition.
256pub fn where_cond<T: Scalar>(
257    condition: &[bool],
258    x: &Tensor<T>,
259    y: &Tensor<T>,
260) -> Result<Tensor<T>> {
261    if x.shape() != y.shape() {
262        return Err(axonml_core::error::Error::shape_mismatch(
263            x.shape(),
264            y.shape(),
265        ));
266    }
267
268    if condition.len() != x.numel() {
269        return Err(axonml_core::error::Error::shape_mismatch(
270            &[condition.len()],
271            &[x.numel()],
272        ));
273    }
274
275    let x_data = x.to_vec();
276    let y_data = y.to_vec();
277
278    let result: Vec<T> = condition
279        .iter()
280        .zip(x_data.iter().zip(y_data.iter()))
281        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
282        .collect();
283
284    Tensor::from_vec(result, x.shape())
285}
286
287// =============================================================================
288// Sorting and Top-K Operations
289// =============================================================================
290
/// Result of topk operation containing values and indices.
///
/// Returned by [`topk`]. Both tensors share the same shape: the input
/// shape with the selected dimension shrunk to `k`.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// The top-k values.
    pub values: Tensor<T>,
    /// The indices of the top-k values in the original tensor.
    pub indices: Tensor<i64>,
}
299
// Manual Debug impl: avoids requiring `T: Debug` and prints only the
// component shapes, not the (potentially large) tensor contents.
impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
308
/// Returns the k largest (or smallest) elements along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `k` - Number of top elements to return
/// * `dim` - Dimension to sort along (negative counts from the end; -1 = last)
/// * `largest` - If true, return largest elements; if false, return smallest
/// * `sorted` - If true, return elements in sorted (value) order; if false,
///   the selected elements are returned in their original index order
///
/// # Errors
/// Fails for scalar tensors, an out-of-range `dim`, or `k` larger than the
/// size of the selected dimension.
///
/// # Returns
/// TopKResult containing values and indices tensors
pub fn topk<T: Numeric>(
    x: &Tensor<T>,
    k: usize,
    dim: i64,
    largest: bool,
    sorted: bool,
) -> Result<TopKResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Err(axonml_core::error::Error::invalid_operation(
            "Cannot apply topk to scalar tensor".to_string(),
        ));
    }

    // Resolve a negative dim relative to the rank (e.g. -1 → last dim).
    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range for tensor with {} dimensions",
            dim,
            shape.len()
        )));
    }

    let dim_size = shape[dim];
    if k > dim_size {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "k ({}) is larger than dimension size ({})",
            k, dim_size
        )));
    }

    let data = x.to_vec();

    // Fast path: 1-D tensor — sort (index, value) pairs directly.
    if shape.len() == 1 {
        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
        // partial_cmp with an Equal fallback keeps the sort total even for
        // incomparable values (e.g. float NaN).
        if largest {
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        // `sorted == false` means: restore original element order among the
        // selected top-k entries.
        if !sorted {
            indexed[..k].sort_by_key(|x| x.0);
        }

        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();

        return Ok(TopKResult {
            values: Tensor::from_vec(values, &[k])?,
            indices: Tensor::from_vec(indices, &[k])?,
        });
    }

    // General n-dimensional case: view the row-major tensor as
    // [outer_size, dim_size, inner_size] and process each (outer, inner)
    // lane independently.
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();

    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);

    for outer in 0..outer_size {
        for inner in 0..inner_size {
            // Gather this lane's (position-along-dim, value) pairs.
            let mut slice: Vec<(usize, T)> = (0..dim_size)
                .map(|d| {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    (d, data[idx])
                })
                .collect();

            if largest {
                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            } else {
                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }

            // As in the 1-D path: unsorted output keeps original order.
            if !sorted {
                slice[..k].sort_by_key(|x| x.0);
            }

            for (orig_idx, val) in slice.into_iter().take(k) {
                values_data.push(val);
                indices_data.push(orig_idx as i64);
            }
        }
    }

    // Output keeps every dimension except `dim`, which shrinks to k.
    let mut output_shape = shape.to_vec();
    output_shape[dim] = k;

    Ok(TopKResult {
        values: Tensor::from_vec(values_data, &output_shape)?,
        indices: Tensor::from_vec(indices_data, &output_shape)?,
    })
}
421
/// Result of sort operation containing sorted values and indices.
///
/// Returned by [`sort`]. Both tensors have the same shape as the input.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// Sorted values.
    pub values: Tensor<T>,
    /// Indices that would sort the tensor.
    pub indices: Tensor<i64>,
}
430
// Manual Debug impl: avoids requiring `T: Debug` and prints only the
// component shapes, not the (potentially large) tensor contents.
impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}
439
440/// Sorts the elements of the tensor along a dimension.
441///
442/// # Arguments
443/// * `x` - Input tensor
444/// * `dim` - Dimension to sort along (default: -1)
445/// * `descending` - If true, sort in descending order
446///
447/// # Returns
448/// SortResult containing sorted values and indices
449pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
450    let shape = x.shape();
451    if shape.is_empty() {
452        return Ok(SortResult {
453            values: x.clone(),
454            indices: Tensor::scalar(0i64),
455        });
456    }
457
458    let dim = if dim < 0 {
459        (shape.len() as i64 + dim) as usize
460    } else {
461        dim as usize
462    };
463
464    let dim_size = shape[dim];
465    topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
466        values: tk.values,
467        indices: tk.indices,
468    })
469}
470
471/// Returns the indices that would sort the tensor along a dimension.
472///
473/// # Arguments
474/// * `x` - Input tensor
475/// * `dim` - Dimension to sort along (default: -1)
476/// * `descending` - If true, sort in descending order
477pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
478    sort(x, dim, descending).map(|r| r.indices)
479}
480
481// =============================================================================
482// Scatter Operation
483// =============================================================================
484
485/// Writes values from src into self at locations specified by index.
486///
487/// This is the inverse of gather.
488///
489/// # Arguments
490/// * `dst` - Destination tensor (modified in place conceptually, returns new tensor)
491/// * `dim` - Dimension along which to scatter
492/// * `index` - Indices to scatter to
493/// * `src` - Source values to scatter
494pub fn scatter<T: Scalar>(
495    dst: &Tensor<T>,
496    dim: usize,
497    index: &Tensor<i64>,
498    src: &Tensor<T>,
499) -> Result<Tensor<T>> {
500    let dst_shape = dst.shape();
501    let idx_shape = index.shape();
502    let src_shape = src.shape();
503
504    if idx_shape != src_shape {
505        return Err(axonml_core::error::Error::shape_mismatch(
506            idx_shape, src_shape,
507        ));
508    }
509
510    if dim >= dst_shape.len() {
511        return Err(axonml_core::error::Error::invalid_operation(format!(
512            "Dimension {} out of range",
513            dim
514        )));
515    }
516
517    let mut result = dst.to_vec();
518    let idx_data = index.to_vec();
519    let src_data = src.to_vec();
520
521    // Calculate strides for the destination
522    let mut dst_strides = vec![1usize; dst_shape.len()];
523    for i in (0..dst_shape.len() - 1).rev() {
524        dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
525    }
526
527    // Calculate strides for index/src
528    let mut idx_strides = vec![1usize; idx_shape.len()];
529    for i in (0..idx_shape.len() - 1).rev() {
530        idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
531    }
532
533    // Scatter values
534    let total = index.numel();
535    for linear_idx in 0..total {
536        // Convert linear index to n-dimensional index
537        let mut nd_idx = vec![0usize; idx_shape.len()];
538        let mut remaining = linear_idx;
539        for d in 0..idx_shape.len() {
540            nd_idx[d] = remaining / idx_strides[d];
541            remaining %= idx_strides[d];
542        }
543
544        // Get the scatter index
545        let scatter_idx = idx_data[linear_idx] as usize;
546
547        // Build destination index
548        let mut dst_nd_idx = nd_idx.clone();
549        dst_nd_idx[dim] = scatter_idx;
550
551        // Convert to linear destination index
552        let mut dst_linear = 0;
553        for d in 0..dst_shape.len() {
554            dst_linear += dst_nd_idx[d] * dst_strides[d];
555        }
556
557        result[dst_linear] = src_data[linear_idx];
558    }
559
560    Tensor::from_vec(result, dst_shape)
561}
562
563// =============================================================================
564// Nonzero Operation
565// =============================================================================
566
567/// Returns the indices of non-zero elements.
568///
569/// # Arguments
570/// * `x` - Input tensor
571///
572/// # Returns
573/// Tensor of shape (num_nonzero, ndim) containing indices of non-zero elements
574pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
575    let data = x.to_vec();
576    let shape = x.shape();
577    let ndim = shape.len();
578
579    // Find all non-zero indices
580    let mut indices: Vec<Vec<i64>> = Vec::new();
581
582    // Calculate strides for index conversion
583    let mut strides = vec![1usize; ndim.max(1)];
584    for i in (0..ndim.saturating_sub(1)).rev() {
585        strides[i] = strides[i + 1] * shape[i + 1];
586    }
587
588    for (linear_idx, &val) in data.iter().enumerate() {
589        if val != T::zero() {
590            let mut nd_idx = vec![0i64; ndim.max(1)];
591            let mut remaining = linear_idx;
592            for d in 0..ndim {
593                nd_idx[d] = (remaining / strides[d]) as i64;
594                remaining %= strides[d];
595            }
596            indices.push(nd_idx);
597        }
598    }
599
600    let num_nonzero = indices.len();
601    if num_nonzero == 0 {
602        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
603    }
604
605    let flat: Vec<i64> = indices.into_iter().flatten().collect();
606    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
607}
608
609// =============================================================================
610// Unique Operation
611// =============================================================================
612
/// Result of unique operation.
///
/// Returned by [`unique`]. Optional components are populated according to
/// the `return_inverse` / `return_counts` flags passed to [`unique`].
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// Unique values.
    pub values: Tensor<T>,
    /// Indices of unique values in the original tensor (if return_inverse).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Counts of each unique value (if return_counts).
    pub counts: Option<Tensor<i64>>,
}
623
// Manual Debug impl: avoids requiring `T: Debug` and reports only shapes
// and the presence of the optional components.
impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}
633
/// Returns the unique elements of the input tensor.
///
/// # Arguments
/// * `x` - Input tensor (flattened; uniqueness is over all elements)
/// * `sorted` - Whether to sort the unique elements ascending
/// * `return_inverse` - Whether to return inverse indices (maps each input
///   element to its position in `values`; shaped like `x`)
/// * `return_counts` - Whether to return per-value occurrence counts
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // Use a vec to preserve insertion order (for unsorted case).
    // NOTE: linear scan per element — O(n * u) overall — but avoids
    // requiring `T: Hash`/`Ord` (floats only implement PartialOrd).
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            // Already seen: record its slot and bump the count.
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            // First occurrence: open a new slot.
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort unique values and update inverse indices to match.
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        // partial_cmp with Equal fallback keeps the sort total (NaN-safe).
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Create mapping from old (insertion-order) index to new (sorted) index.
        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed
            .iter()
            .map(|(old_idx, _)| counts_map[*old_idx])
            .collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).expect("tensor creation failed"),
        // Inverse indices keep the input tensor's shape.
        inverse_indices: if return_inverse {
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).expect("tensor creation failed"))
        } else {
            None
        },
    }
}
704
705// =============================================================================
706// Flip Operation
707// =============================================================================
708
709/// Reverses the order of elements along specified dimensions.
710pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
711    let shape = x.shape();
712    let data = x.to_vec();
713    let ndim = shape.len();
714
715    for &d in dims {
716        if d >= ndim {
717            return Err(axonml_core::error::Error::invalid_operation(format!(
718                "Dimension {} out of range for tensor with {} dimensions",
719                d, ndim
720            )));
721        }
722    }
723
724    if shape.is_empty() {
725        return Ok(x.clone());
726    }
727
728    // Calculate strides
729    let mut strides = vec![1usize; ndim];
730    for i in (0..ndim - 1).rev() {
731        strides[i] = strides[i + 1] * shape[i + 1];
732    }
733
734    let mut result = vec![T::zero(); data.len()];
735
736    for src_linear in 0..data.len() {
737        // Convert to n-dimensional index
738        let mut nd_idx = vec![0usize; ndim];
739        let mut remaining = src_linear;
740        for d in 0..ndim {
741            nd_idx[d] = remaining / strides[d];
742            remaining %= strides[d];
743        }
744
745        // Flip specified dimensions
746        for &flip_dim in dims {
747            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
748        }
749
750        // Convert back to linear index
751        let mut dst_linear = 0;
752        for d in 0..ndim {
753            dst_linear += nd_idx[d] * strides[d];
754        }
755
756        result[dst_linear] = data[src_linear];
757    }
758
759    Tensor::from_vec(result, shape)
760}
761
762// =============================================================================
763// Roll Operation
764// =============================================================================
765
766/// Rolls tensor elements along specified dimensions.
767pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
768    if shifts.len() != dims.len() {
769        return Err(axonml_core::error::Error::invalid_operation(
770            "shifts and dims must have the same length".to_string(),
771        ));
772    }
773
774    let shape = x.shape();
775    let data = x.to_vec();
776    let ndim = shape.len();
777
778    for &d in dims {
779        if d >= ndim {
780            return Err(axonml_core::error::Error::invalid_operation(format!(
781                "Dimension {} out of range",
782                d
783            )));
784        }
785    }
786
787    if shape.is_empty() {
788        return Ok(x.clone());
789    }
790
791    // Calculate strides
792    let mut strides = vec![1usize; ndim];
793    for i in (0..ndim - 1).rev() {
794        strides[i] = strides[i + 1] * shape[i + 1];
795    }
796
797    let mut result = vec![T::zero(); data.len()];
798
799    for src_linear in 0..data.len() {
800        // Convert to n-dimensional index
801        let mut nd_idx = vec![0usize; ndim];
802        let mut remaining = src_linear;
803        for d in 0..ndim {
804            nd_idx[d] = remaining / strides[d];
805            remaining %= strides[d];
806        }
807
808        // Apply shifts
809        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
810            let dim_size = shape[dim] as i64;
811            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
812            nd_idx[dim] = new_idx as usize;
813        }
814
815        // Convert back to linear index
816        let mut dst_linear = 0;
817        for d in 0..ndim {
818            dst_linear += nd_idx[d] * strides[d];
819        }
820
821        result[dst_linear] = data[src_linear];
822    }
823
824    Tensor::from_vec(result, shape)
825}
826
827// =============================================================================
828// Tests
829// =============================================================================
830
831#[cfg(test)]
832mod tests {
833    use super::*;
834
835    #[test]
836    fn test_softmax() {
837        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
838        let s = softmax(&t, -1).unwrap();
839
840        let sum: f32 = s.to_vec().iter().sum();
841        assert!((sum - 1.0).abs() < 1e-5);
842    }
843
844    #[test]
845    fn test_clamp() {
846        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
847        let c = clamp(&t, 0.0, 1.0);
848        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
849    }
850
851    #[test]
852    fn test_leaky_relu() {
853        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
854        let r = leaky_relu(&t, 0.01);
855        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
856    }
857
858    #[test]
859    fn test_comparison() {
860        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
861        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();
862
863        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
864        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
865        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
866    }
867
868    #[test]
869    fn test_topk() {
870        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
871        let result = topk(&t, 3, -1, true, true).unwrap();
872
873        assert_eq!(result.values.shape(), &[3]);
874        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
875        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
876    }
877
878    #[test]
879    fn test_topk_smallest() {
880        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
881        let result = topk(&t, 2, -1, false, true).unwrap();
882
883        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
884    }
885
886    #[test]
887    fn test_sort() {
888        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
889        let result = sort(&t, -1, false).unwrap();
890
891        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
892    }
893
894    #[test]
895    fn test_sort_descending() {
896        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
897        let result = sort(&t, -1, true).unwrap();
898
899        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
900    }
901
902    #[test]
903    fn test_argsort() {
904        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
905        let indices = argsort(&t, -1, false).unwrap();
906
907        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
908    }
909
910    #[test]
911    fn test_nonzero() {
912        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
913        let result = nonzero(&t);
914
915        assert_eq!(result.shape(), &[3, 1]);
916        assert_eq!(result.to_vec(), vec![1, 3, 4]);
917    }
918
919    #[test]
920    fn test_nonzero_2d() {
921        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
922        let result = nonzero(&t);
923
924        assert_eq!(result.shape(), &[2, 2]);
925        // (0,0) and (1,1) are non-zero
926        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
927    }
928
929    #[test]
930    fn test_unique() {
931        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
932        let result = unique(&t, true, true, true);
933
934        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
935        assert_eq!(
936            result.inverse_indices.unwrap().to_vec(),
937            vec![0, 1, 0, 2, 1, 0]
938        );
939        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
940    }
941
942    #[test]
943    fn test_unique_unsorted() {
944        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
945        let result = unique(&t, false, false, false);
946
947        // Preserves insertion order
948        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
949    }
950
951    #[test]
952    fn test_flip() {
953        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
954        let flipped = flip(&t, &[0]).unwrap();
955
956        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
957    }
958
959    #[test]
960    fn test_flip_2d() {
961        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
962        let flipped = flip(&t, &[0]).unwrap();
963
964        // Flip along dim 0: [[3,4], [1,2]]
965        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
966    }
967
968    #[test]
969    fn test_roll() {
970        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
971        let rolled = roll(&t, &[1], &[0]).unwrap();
972
973        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
974    }
975
976    #[test]
977    fn test_roll_negative() {
978        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
979        let rolled = roll(&t, &[-1], &[0]).unwrap();
980
981        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
982    }
983
984    #[test]
985    fn test_scatter() {
986        let dst = Tensor::<f32>::zeros(&[3]);
987        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).expect("tensor creation failed");
988        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed");
989
990        let result = scatter(&dst, 0, &index, &src).unwrap();
991        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
992    }
993
994    // =========================================================================
995    // where_cond
996    // =========================================================================
997
998    #[test]
999    fn test_where_cond_basic() {
1000        let cond = vec![true, false, true, false];
1001        let a = Tensor::from_vec(vec![10.0f32, 20.0, 30.0, 40.0], &[4]).unwrap();
1002        let b = Tensor::from_vec(vec![-1.0f32, -2.0, -3.0, -4.0], &[4]).unwrap();
1003        let result = where_cond(&cond, &a, &b).unwrap();
1004        assert_eq!(result.to_vec(), vec![10.0, -2.0, 30.0, -4.0]);
1005    }
1006
1007    // =========================================================================
1008    // scatter — additional
1009    // =========================================================================
1010
1011    #[test]
1012    fn test_scatter_overwrites() {
1013        let dst = Tensor::from_vec(vec![0.0f32; 4], &[4]).unwrap();
1014        let index = Tensor::from_vec(vec![1_i64, 1], &[2]).unwrap();
1015        let src = Tensor::from_vec(vec![5.0f32, 10.0], &[2]).unwrap();
1016
1017        let result = scatter(&dst, 0, &index, &src).unwrap();
1018        assert!(result.to_vec()[1] > 0.0, "Scatter should write to index 1");
1019    }
1020
1021    // =========================================================================
1022    // unique — edge cases
1023    // =========================================================================
1024
1025    #[test]
1026    fn test_unique_all_same() {
1027        let t = Tensor::from_vec(vec![5.0f32, 5.0, 5.0, 5.0], &[4]).unwrap();
1028        let result = unique(&t, true, false, false);
1029        assert_eq!(result.values.to_vec().len(), 1);
1030        assert!((result.values.to_vec()[0] - 5.0).abs() < 1e-5);
1031    }
1032
1033    #[test]
1034    fn test_unique_already_unique() {
1035        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
1036        let result = unique(&t, true, false, false);
1037        assert_eq!(result.values.to_vec().len(), 4);
1038    }
1039
1040    // =========================================================================
1041    // flip — multi-dimensional edge cases
1042    // =========================================================================
1043
1044    #[test]
1045    fn test_flip_both_dims() {
1046        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1047        let flipped = flip(&t, &[0, 1]).unwrap();
1048        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
1049    }
1050
1051    #[test]
1052    fn test_flip_col_only() {
1053        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1054        let flipped = flip(&t, &[1]).unwrap();
1055        assert_eq!(flipped.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
1056    }
1057
1058    // =========================================================================
1059    // roll — multi-dimensional
1060    // =========================================================================
1061
1062    #[test]
1063    fn test_roll_2d() {
1064        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
1065        let rolled = roll(&t, &[1], &[1]).unwrap();
1066        assert_eq!(rolled.to_vec(), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
1067    }
1068
1069    #[test]
1070    fn test_roll_full_cycle() {
1071        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], &[3]).unwrap();
1072        let rolled = roll(&t, &[3], &[0]).unwrap();
1073        assert_eq!(rolled.to_vec(), vec![1.0, 2.0, 3.0]);
1074    }
1075
1076    // =========================================================================
1077    // nonzero — edge cases
1078    // =========================================================================
1079
1080    #[test]
1081    fn test_nonzero_all_zeros() {
1082        let t = Tensor::from_vec(vec![0.0f32, 0.0, 0.0], &[3]).unwrap();
1083        let result = nonzero(&t);
1084        assert_eq!(result.to_vec().len(), 0);
1085    }
1086
1087    #[test]
1088    fn test_nonzero_all_nonzero() {
1089        let t = Tensor::from_vec(vec![1.0f32, -2.0, 0.5], &[3]).unwrap();
1090        let result = nonzero(&t);
1091        assert_eq!(result.to_vec().len(), 3);
1092    }
1093
1094    // =========================================================================
1095    // Numerical stability of softmax
1096    // =========================================================================
1097
1098    #[test]
1099    fn test_softmax_large_values() {
1100        let t = Tensor::from_vec(vec![1000.0f32, 1001.0, 999.0], &[3]).unwrap();
1101        let result = softmax(&t, 0).unwrap();
1102        let rv = result.to_vec();
1103        assert!(
1104            rv.iter().all(|v: &f32| v.is_finite()),
1105            "Softmax should handle large values"
1106        );
1107        let sum: f32 = rv.iter().sum();
1108        assert!(
1109            (sum - 1.0).abs() < 1e-5,
1110            "Softmax should sum to 1.0, got {}",
1111            sum
1112        );
1113    }
1114
1115    #[test]
1116    fn test_softmax_negative_values() {
1117        let t = Tensor::from_vec(vec![-100.0f32, -200.0, -150.0], &[3]).unwrap();
1118        let result = softmax(&t, 0).unwrap();
1119        let rv = result.to_vec();
1120        assert!(rv.iter().all(|v: &f32| v.is_finite() && *v >= 0.0));
1121        let sum: f32 = rv.iter().sum();
1122        assert!((sum - 1.0).abs() < 1e-5);
1123    }
1124
1125    // =========================================================================
1126    // clamp edge cases
1127    // =========================================================================
1128
1129    #[test]
1130    fn test_clamp_no_op() {
1131        let t = Tensor::from_vec(vec![0.5f32, 0.3, 0.8], &[3]).unwrap();
1132        let clamped = clamp_min(&t, 0.0);
1133        assert_eq!(clamped.to_vec(), vec![0.5, 0.3, 0.8]);
1134    }
1135
1136    #[test]
1137    fn test_clamp_min_all_negative() {
1138        let t = Tensor::from_vec(vec![-5.0f32, -3.0, -1.0], &[3]).unwrap();
1139        let clamped = clamp_min(&t, 0.0);
1140        assert_eq!(clamped.to_vec(), vec![0.0, 0.0, 0.0]);
1141    }
1142}