//! Tensor Operations - Mathematical and Structural Operations
//!
//! # File
//! `crates/axonml-tensor/src/ops/mod.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of
//! any kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// Operations are implemented directly on Tensor in tensor.rs.
// This module provides additional standalone functions.

use axonml_core::dtype::{Float, Numeric, Scalar};
use axonml_core::error::Result;

use crate::tensor::Tensor;

// =============================================================================
// Comparison Operations
// =============================================================================

/// Element-wise equality comparison.
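///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not run as a doctest;
/// it assumes `Tensor::from_vec` as used in this module's tests):
///
/// ```ignore
/// let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
/// let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();
/// assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
/// ```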
pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x == y)
        .collect())
}

/// Element-wise less-than comparison.
pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x < y)
        .collect())
}

/// Element-wise greater-than comparison.
pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x > y)
        .collect())
}

/// Element-wise equality → `Tensor<f32>` mask (1.0 where equal, 0.0 where not).
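///
/// # Example
///
/// A minimal sketch (marked `ignore`; the values follow from the `eq` example
/// above):
///
/// ```ignore
/// let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let b = Tensor::<f32>::from_vec(vec![1.0, 3.0], &[2]).unwrap();
/// let mask = eq_mask(&a, &b).unwrap();
/// assert_eq!(mask.to_vec(), vec![1.0, 0.0]);
/// ```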
pub fn eq_mask<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
    let bools = eq(a, b)?;
    let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
    Tensor::from_vec(data, a.shape())
}

/// Element-wise less-than → `Tensor<f32>` mask.
pub fn lt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
    let bools = lt(a, b)?;
    let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
    Tensor::from_vec(data, a.shape())
}

/// Element-wise greater-than → `Tensor<f32>` mask.
pub fn gt_mask<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<f32>> {
    let bools = gt(a, b)?;
    let data: Vec<f32> = bools.iter().map(|&v| if v { 1.0 } else { 0.0 }).collect();
    Tensor::from_vec(data, a.shape())
}

// =============================================================================
// Advanced Activation Functions
// =============================================================================

/// Applies softmax along the last dimension.
///
/// The `dim` argument is accepted for API compatibility, but only the
/// last-dimension case is currently implemented.
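///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
/// let s = softmax(&t, -1).unwrap();
/// let sum: f32 = s.to_vec().iter().sum();
/// assert!((sum - 1.0).abs() < 1e-5);
/// ```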
pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
    let data = x.to_vec();
    let shape = x.shape();

    if shape.is_empty() {
        return Ok(Tensor::scalar(T::one()));
    }

    let row_len = shape[shape.len() - 1];
    if row_len == 0 {
        return Tensor::from_vec(Vec::new(), shape);
    }

    // Normalize each contiguous slice along the last dimension independently.
    let mut result = Vec::with_capacity(data.len());
    for row in data.chunks(row_len) {
        // Subtract the row max for numerical stability before exponentiating.
        let max_val = row
            .iter()
            .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
        let exp_row: Vec<T> = row.iter().map(|&v| (v - max_val).exp_value()).collect();
        let sum: T = exp_row.iter().fold(T::zero(), |a, &b| a + b);
        result.extend(exp_row.into_iter().map(|v| v / sum));
    }

    Tensor::from_vec(result, shape)
}

/// Applies log-softmax along the last dimension.
///
/// Computed as `ln(softmax(x))`; a fused implementation would be more
/// numerically stable for extreme inputs.
pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
    let sm = softmax(x, dim)?;
    Ok(sm.ln())
}

/// Applies GELU (Gaussian Error Linear Unit) activation, using the tanh
/// approximation.
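///
/// # Example
///
/// A minimal sketch (marked `ignore`): GELU is approximately the identity for
/// large positive inputs and near zero for large negative inputs.
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![-10.0, 0.0, 10.0], &[3]).unwrap();
/// let g = gelu(&t).to_vec();
/// assert!(g[0].abs() < 1e-3);          // ≈ 0 for very negative input
/// assert!(g[1].abs() < 1e-6);          // GELU(0) = 0
/// assert!((g[2] - 10.0).abs() < 1e-3); // ≈ x for very positive input
/// ```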
#[must_use]
pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let data = x.to_vec();
    // Constants for the tanh approximation: sqrt(2/pi) and the cubic coefficient.
    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
    let coeff = T::from(0.044715).unwrap();

    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
            v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies Leaky `ReLU` activation.
pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies ELU (Exponential Linear Unit) activation.
pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v > T::zero() {
                v
            } else {
                alpha * (v.exp_value() - T::one())
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies `SiLU` (Sigmoid Linear Unit) / Swish activation.
#[must_use]
pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let sig = x.sigmoid();
    x.mul(&sig).expect("tensor mul failed")
}

// =============================================================================
// Clipping Operations
// =============================================================================

/// Clamps all elements to the range `[min, max]`.
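///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
/// assert_eq!(clamp(&t, 0.0, 1.0).to_vec(), vec![0.0, 0.5, 1.0]);
/// ```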
pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v < min {
                min
            } else if v > max {
                max
            } else {
                v
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps all elements to be at least `min`.
pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v < min { min } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps all elements to be at most `max`.
pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > max { max } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

// =============================================================================
// Where Operation
// =============================================================================

/// Selects elements from `x` where the condition is true, otherwise from `y`.
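///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let cond = vec![true, false, true];
/// let x = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
/// let y = Tensor::<f32>::from_vec(vec![-1.0, -2.0, -3.0], &[3]).unwrap();
/// let r = where_cond(&cond, &x, &y).unwrap();
/// assert_eq!(r.to_vec(), vec![1.0, -2.0, 3.0]);
/// ```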
pub fn where_cond<T: Scalar>(
    condition: &[bool],
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> Result<Tensor<T>> {
    if x.shape() != y.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            x.shape(),
            y.shape(),
        ));
    }

    if condition.len() != x.numel() {
        return Err(axonml_core::error::Error::shape_mismatch(
            &[condition.len()],
            &[x.numel()],
        ));
    }

    let x_data = x.to_vec();
    let y_data = y.to_vec();

    let result: Vec<T> = condition
        .iter()
        .zip(x_data.iter().zip(y_data.iter()))
        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
        .collect();

    Tensor::from_vec(result, x.shape())
}

// =============================================================================
// Sorting and Top-K Operations
// =============================================================================

/// Result of the `topk` operation, containing values and indices.
#[derive(Clone)]
pub struct TopKResult<T: Scalar> {
    /// The top-k values.
    pub values: Tensor<T>,
    /// The indices of the top-k values in the original tensor.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for TopKResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopKResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

/// Returns the k largest (or smallest) elements along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `k` - Number of top elements to return
/// * `dim` - Dimension to sort along (negative values count from the end)
/// * `largest` - If true, return the largest elements; if false, the smallest
/// * `sorted` - If true, return elements in sorted order; if false, in their
///   original order
///
/// # Returns
/// `TopKResult` containing values and indices tensors
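///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
/// let top3 = topk(&t, 3, -1, true, true).unwrap();
/// assert_eq!(top3.values.to_vec(), vec![9.0, 5.0, 4.0]);
/// assert_eq!(top3.indices.to_vec(), vec![5, 4, 2]);
/// ```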
pub fn topk<T: Numeric>(
    x: &Tensor<T>,
    k: usize,
    dim: i64,
    largest: bool,
    sorted: bool,
) -> Result<TopKResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Err(axonml_core::error::Error::invalid_operation(
            "Cannot apply topk to scalar tensor".to_string(),
        ));
    }

    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range for tensor with {} dimensions",
            dim,
            shape.len()
        )));
    }

    let dim_size = shape[dim];
    if k > dim_size {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "k ({}) is larger than dimension size ({})",
            k, dim_size
        )));
    }

    let data = x.to_vec();

    // For simplicity, handle the 1D case specially
    if shape.len() == 1 {
        let mut indexed: Vec<(usize, T)> = data.into_iter().enumerate().collect();
        if largest {
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        } else {
            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        if !sorted {
            indexed[..k].sort_by_key(|x| x.0);
        }

        let values: Vec<T> = indexed[..k].iter().map(|(_, v)| *v).collect();
        let indices: Vec<i64> = indexed[..k].iter().map(|(i, _)| *i as i64).collect();

        return Ok(TopKResult {
            values: Tensor::from_vec(values, &[k])?,
            indices: Tensor::from_vec(indices, &[k])?,
        });
    }

    // General n-dimensional case
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();

    let mut values_data = Vec::with_capacity(outer_size * k * inner_size);
    let mut indices_data = Vec::with_capacity(outer_size * k * inner_size);

    for outer in 0..outer_size {
        for inner in 0..inner_size {
            let mut slice: Vec<(usize, T)> = (0..dim_size)
                .map(|d| {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    (d, data[idx])
                })
                .collect();

            if largest {
                slice.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            } else {
                slice.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }

            if !sorted {
                slice[..k].sort_by_key(|x| x.0);
            }

            for (orig_idx, val) in slice.into_iter().take(k) {
                values_data.push(val);
                indices_data.push(orig_idx as i64);
            }
        }
    }

    let mut output_shape = shape.to_vec();
    output_shape[dim] = k;

    Ok(TopKResult {
        values: Tensor::from_vec(values_data, &output_shape)?,
        indices: Tensor::from_vec(indices_data, &output_shape)?,
    })
}

/// Result of the `sort` operation, containing sorted values and indices.
#[derive(Clone)]
pub struct SortResult<T: Scalar> {
    /// Sorted values.
    pub values: Tensor<T>,
    /// Indices that would sort the tensor.
    pub indices: Tensor<i64>,
}

impl<T: Scalar> std::fmt::Debug for SortResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortResult")
            .field("values_shape", &self.values.shape())
            .field("indices_shape", &self.indices.shape())
            .finish()
    }
}

/// Sorts the elements of the tensor along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `dim` - Dimension to sort along (negative values count from the end)
/// * `descending` - If true, sort in descending order
///
/// # Returns
/// `SortResult` containing sorted values and indices
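///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
/// let sorted = sort(&t, -1, false).unwrap();
/// assert_eq!(sorted.values.to_vec(), vec![1.0, 3.0, 4.0]);
/// ```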
pub fn sort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<SortResult<T>> {
    let shape = x.shape();
    if shape.is_empty() {
        return Ok(SortResult {
            values: x.clone(),
            indices: Tensor::scalar(0i64),
        });
    }

    let dim = if dim < 0 {
        (shape.len() as i64 + dim) as usize
    } else {
        dim as usize
    };

    if dim >= shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range for tensor with {} dimensions",
            dim,
            shape.len()
        )));
    }

    // A full sort is top-k with k equal to the dimension size.
    let dim_size = shape[dim];
    topk(x, dim_size, dim as i64, descending, true).map(|tk| SortResult {
        values: tk.values,
        indices: tk.indices,
    })
}

/// Returns the indices that would sort the tensor along a dimension.
///
/// # Arguments
/// * `x` - Input tensor
/// * `dim` - Dimension to sort along (negative values count from the end)
/// * `descending` - If true, sort in descending order
pub fn argsort<T: Numeric>(x: &Tensor<T>, dim: i64, descending: bool) -> Result<Tensor<i64>> {
    sort(x, dim, descending).map(|r| r.indices)
}

// =============================================================================
// Scatter Operation
// =============================================================================

/// Writes values from `src` into `dst` at locations specified by `index`.
///
/// This is the inverse of gather. The destination tensor is not modified;
/// a new tensor with the scattered values is returned.
///
/// # Arguments
/// * `dst` - Destination tensor (copied into the result)
/// * `dim` - Dimension along which to scatter
/// * `index` - Indices to scatter to
/// * `src` - Source values to scatter
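///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let dst = Tensor::<f32>::zeros(&[3]);
/// let index = Tensor::from_vec(vec![0_i64, 2], &[2]).unwrap();
/// let src = Tensor::from_vec(vec![1.0f32, 2.0], &[2]).unwrap();
/// let out = scatter(&dst, 0, &index, &src).unwrap();
/// assert_eq!(out.to_vec(), vec![1.0, 0.0, 2.0]);
/// ```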
pub fn scatter<T: Scalar>(
    dst: &Tensor<T>,
    dim: usize,
    index: &Tensor<i64>,
    src: &Tensor<T>,
) -> Result<Tensor<T>> {
    let dst_shape = dst.shape();
    let idx_shape = index.shape();
    let src_shape = src.shape();

    if idx_shape != src_shape {
        return Err(axonml_core::error::Error::shape_mismatch(
            idx_shape, src_shape,
        ));
    }

    if dim >= dst_shape.len() {
        return Err(axonml_core::error::Error::invalid_operation(format!(
            "Dimension {} out of range",
            dim
        )));
    }

    let mut result = dst.to_vec();
    let idx_data = index.to_vec();
    let src_data = src.to_vec();

    // Calculate strides for the destination
    let mut dst_strides = vec![1usize; dst_shape.len()];
    for i in (0..dst_shape.len() - 1).rev() {
        dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
    }

    // Calculate strides for index/src
    let mut idx_strides = vec![1usize; idx_shape.len()];
    for i in (0..idx_shape.len() - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * idx_shape[i + 1];
    }

    // Scatter values
    let total = index.numel();
    for linear_idx in 0..total {
        // Convert linear index to n-dimensional index
        let mut nd_idx = vec![0usize; idx_shape.len()];
        let mut remaining = linear_idx;
        for d in 0..idx_shape.len() {
            nd_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }

        // Get the scatter index and reject out-of-range values, which would
        // otherwise alias into unrelated positions of the flat buffer.
        let scatter_idx = idx_data[linear_idx] as usize;
        if scatter_idx >= dst_shape[dim] {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Scatter index {} out of range for dimension of size {}",
                scatter_idx, dst_shape[dim]
            )));
        }

        // Build destination index
        let mut dst_nd_idx = nd_idx.clone();
        dst_nd_idx[dim] = scatter_idx;

        // Convert to linear destination index
        let mut dst_linear = 0;
        for d in 0..dst_shape.len() {
            dst_linear += dst_nd_idx[d] * dst_strides[d];
        }

        result[dst_linear] = src_data[linear_idx];
    }

    Tensor::from_vec(result, dst_shape)
}

// =============================================================================
// Nonzero Operation
// =============================================================================

/// Returns the indices of non-zero elements.
///
/// # Arguments
/// * `x` - Input tensor
///
/// # Returns
/// Tensor of shape `(num_nonzero, ndim)` containing indices of non-zero elements
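///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
/// let idx = nonzero(&t);
/// assert_eq!(idx.shape(), &[2, 2]); // two non-zero elements, two dims
/// assert_eq!(idx.to_vec(), vec![0, 0, 1, 1]); // (0, 0) and (1, 1)
/// ```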
pub fn nonzero<T: Numeric>(x: &Tensor<T>) -> Tensor<i64> {
    let data = x.to_vec();
    let shape = x.shape();
    let ndim = shape.len();

    // Find all non-zero indices
    let mut indices: Vec<Vec<i64>> = Vec::new();

    // Calculate strides for index conversion
    let mut strides = vec![1usize; ndim.max(1)];
    for i in (0..ndim.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    for (linear_idx, &val) in data.iter().enumerate() {
        if val != T::zero() {
            let mut nd_idx = vec![0i64; ndim.max(1)];
            let mut remaining = linear_idx;
            for d in 0..ndim {
                nd_idx[d] = (remaining / strides[d]) as i64;
                remaining %= strides[d];
            }
            indices.push(nd_idx);
        }
    }

    let num_nonzero = indices.len();
    if num_nonzero == 0 {
        return Tensor::from_vec(vec![], &[0, ndim.max(1)]).unwrap();
    }

    let flat: Vec<i64> = indices.into_iter().flatten().collect();
    Tensor::from_vec(flat, &[num_nonzero, ndim.max(1)]).unwrap()
}

// =============================================================================
// Unique Operation
// =============================================================================

/// Result of the `unique` operation.
#[derive(Clone)]
pub struct UniqueResult<T: Scalar> {
    /// Unique values.
    pub values: Tensor<T>,
    /// For each input element, the index of its value in `values`
    /// (if `return_inverse`).
    pub inverse_indices: Option<Tensor<i64>>,
    /// Counts of each unique value (if `return_counts`).
    pub counts: Option<Tensor<i64>>,
}

impl<T: Scalar> std::fmt::Debug for UniqueResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniqueResult")
            .field("values_shape", &self.values.shape())
            .field("has_inverse", &self.inverse_indices.is_some())
            .field("has_counts", &self.counts.is_some())
            .finish()
    }
}

/// Returns the unique elements of the input tensor.
///
/// # Arguments
/// * `x` - Input tensor
/// * `sorted` - Whether to sort the unique elements; if false, insertion
///   order is preserved
/// * `return_inverse` - Whether to return inverse indices
/// * `return_counts` - Whether to return counts
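///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit test below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
/// let u = unique(&t, true, true, true);
/// assert_eq!(u.values.to_vec(), vec![1.0, 2.0, 3.0]);
/// assert_eq!(u.inverse_indices.unwrap().to_vec(), vec![0, 1, 0, 2, 1, 0]);
/// assert_eq!(u.counts.unwrap().to_vec(), vec![3, 2, 1]);
/// ```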
pub fn unique<T: Numeric>(
    x: &Tensor<T>,
    sorted: bool,
    return_inverse: bool,
    return_counts: bool,
) -> UniqueResult<T> {
    let data = x.to_vec();

    // Use a vec to preserve insertion order (for the unsorted case)
    let mut seen: Vec<T> = Vec::new();
    let mut counts_map: Vec<i64> = Vec::new();
    let mut inverse: Vec<i64> = Vec::with_capacity(data.len());

    for &val in &data {
        if let Some(pos) = seen.iter().position(|&v| v == val) {
            inverse.push(pos as i64);
            counts_map[pos] += 1;
        } else {
            inverse.push(seen.len() as i64);
            seen.push(val);
            counts_map.push(1);
        }
    }

    let (unique_vals, final_inverse, final_counts) = if sorted {
        // Sort unique values and update inverse indices
        let mut indexed: Vec<(usize, T)> = seen.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Create mapping from old index to new index
        let mut old_to_new = vec![0i64; indexed.len()];
        for (new_idx, (old_idx, _)) in indexed.iter().enumerate() {
            old_to_new[*old_idx] = new_idx as i64;
        }

        let sorted_vals: Vec<T> = indexed.iter().map(|(_, v)| *v).collect();
        let sorted_counts: Vec<i64> = indexed
            .iter()
            .map(|(old_idx, _)| counts_map[*old_idx])
            .collect();
        let updated_inverse: Vec<i64> = inverse.iter().map(|&i| old_to_new[i as usize]).collect();

        (sorted_vals, updated_inverse, sorted_counts)
    } else {
        (seen, inverse, counts_map)
    };

    let n = unique_vals.len();

    UniqueResult {
        values: Tensor::from_vec(unique_vals, &[n]).expect("tensor creation failed"),
        inverse_indices: if return_inverse {
            Some(Tensor::from_vec(final_inverse, x.shape()).unwrap())
        } else {
            None
        },
        counts: if return_counts {
            Some(Tensor::from_vec(final_counts, &[n]).expect("tensor creation failed"))
        } else {
            None
        },
    }
}

// =============================================================================
// Flip Operation
// =============================================================================

/// Reverses the order of elements along the specified dimensions.
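///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit tests below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let f = flip(&t, &[0]).unwrap(); // reverse rows: [[3, 4], [1, 2]]
/// assert_eq!(f.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
/// ```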
pub fn flip<T: Numeric>(x: &Tensor<T>, dims: &[usize]) -> Result<Tensor<T>> {
    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                d, ndim
            )));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Calculate strides
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Convert to n-dimensional index
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Flip specified dimensions
        for &flip_dim in dims {
            nd_idx[flip_dim] = shape[flip_dim] - 1 - nd_idx[flip_dim];
        }

        // Convert back to linear index
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}

// =============================================================================
// Roll Operation
// =============================================================================

/// Rolls tensor elements along the specified dimensions, wrapping around at
/// the boundaries. Negative shifts roll toward lower indices.
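///
/// # Example
///
/// A minimal sketch (marked `ignore`; mirrors the unit tests below):
///
/// ```ignore
/// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
/// let r = roll(&t, &[1], &[0]).unwrap(); // shift right by one, wrapping
/// assert_eq!(r.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
/// ```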
pub fn roll<T: Numeric>(x: &Tensor<T>, shifts: &[i64], dims: &[usize]) -> Result<Tensor<T>> {
    if shifts.len() != dims.len() {
        return Err(axonml_core::error::Error::invalid_operation(
            "shifts and dims must have the same length".to_string(),
        ));
    }

    let shape = x.shape();
    let data = x.to_vec();
    let ndim = shape.len();

    for &d in dims {
        if d >= ndim {
            return Err(axonml_core::error::Error::invalid_operation(format!(
                "Dimension {} out of range",
                d
            )));
        }
    }

    if shape.is_empty() {
        return Ok(x.clone());
    }

    // Calculate strides
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut result = vec![T::zero(); data.len()];

    for src_linear in 0..data.len() {
        // Convert to n-dimensional index
        let mut nd_idx = vec![0usize; ndim];
        let mut remaining = src_linear;
        for d in 0..ndim {
            nd_idx[d] = remaining / strides[d];
            remaining %= strides[d];
        }

        // Apply shifts (the double modulo keeps negative shifts in range)
        for (shift, &dim) in shifts.iter().zip(dims.iter()) {
            let dim_size = shape[dim] as i64;
            let new_idx = ((nd_idx[dim] as i64 + shift) % dim_size + dim_size) % dim_size;
            nd_idx[dim] = new_idx as usize;
        }

        // Convert back to linear index
        let mut dst_linear = 0;
        for d in 0..ndim {
            dst_linear += nd_idx[d] * strides[d];
        }

        result[dst_linear] = data[src_linear];
    }

    Tensor::from_vec(result, shape)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }

    #[test]
    fn test_topk() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 3, -1, true, true).unwrap();

        assert_eq!(result.values.shape(), &[3]);
        assert_eq!(result.values.to_vec(), vec![9.0, 5.0, 4.0]);
        assert_eq!(result.indices.to_vec(), vec![5, 4, 2]);
    }

    #[test]
    fn test_topk_smallest() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[6]).unwrap();
        let result = topk(&t, 2, -1, false, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0]);
    }

    #[test]
    fn test_sort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        let result = sort(&t, -1, false).unwrap();

        assert_eq!(result.values.to_vec(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sort_descending() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 4.0], &[3]).unwrap();
        let result = sort(&t, -1, true).unwrap();

        assert_eq!(result.values.to_vec(), vec![4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_argsort() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 2.0], &[3]).unwrap();
        let indices = argsort(&t, -1, false).unwrap();

        assert_eq!(indices.to_vec(), vec![1, 2, 0]);
    }

    #[test]
    fn test_nonzero() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 2.0, 3.0, 0.0], &[6]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[3, 1]);
        assert_eq!(result.to_vec(), vec![1, 3, 4]);
    }

    #[test]
    fn test_nonzero_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 0.0, 0.0, 2.0], &[2, 2]).unwrap();
        let result = nonzero(&t);

        assert_eq!(result.shape(), &[2, 2]);
        // (0,0) and (1,1) are non-zero
        assert_eq!(result.to_vec(), vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_unique() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0], &[6]).unwrap();
        let result = unique(&t, true, true, true);

        assert_eq!(result.values.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            result.inverse_indices.unwrap().to_vec(),
            vec![0, 1, 0, 2, 1, 0]
        );
        assert_eq!(result.counts.unwrap().to_vec(), vec![3, 2, 1]);
    }

    #[test]
    fn test_unique_unsorted() {
        let t = Tensor::<f32>::from_vec(vec![3.0, 1.0, 3.0], &[3]).unwrap();
        let result = unique(&t, false, false, false);

        // Preserves insertion order
        assert_eq!(result.values.to_vec(), vec![3.0, 1.0]);
    }

    #[test]
    fn test_flip() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0]).unwrap();

        // Flip along dim 0: [[3,4], [1,2]]
        assert_eq!(flipped.to_vec(), vec![3.0, 4.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![4.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let rolled = roll(&t, &[-1], &[0]).unwrap();

        assert_eq!(rolled.to_vec(), vec![2.0, 3.0, 4.0, 1.0]);
    }

    #[test]
    fn test_scatter() {
        let dst = Tensor::<f32>::zeros(&[3]);
        let index = Tensor::from_vec(vec![0_i64, 2], &[2]).expect("tensor creation failed");
        let src = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("tensor creation failed");

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert_eq!(result.to_vec(), vec![1.0, 0.0, 2.0]);
    }

    // =========================================================================
    // where_cond
    // =========================================================================

    #[test]
    fn test_where_cond_basic() {
        let cond = vec![true, false, true, false];
        let a = Tensor::from_vec(vec![10.0f32, 20.0, 30.0, 40.0], &[4]).unwrap();
        let b = Tensor::from_vec(vec![-1.0f32, -2.0, -3.0, -4.0], &[4]).unwrap();
        let result = where_cond(&cond, &a, &b).unwrap();
        assert_eq!(result.to_vec(), vec![10.0, -2.0, 30.0, -4.0]);
    }

    // =========================================================================
    // scatter — additional
    // =========================================================================

    #[test]
    fn test_scatter_overwrites() {
        let dst = Tensor::from_vec(vec![0.0f32; 4], &[4]).unwrap();
        let index = Tensor::from_vec(vec![1_i64, 1], &[2]).unwrap();
        let src = Tensor::from_vec(vec![5.0f32, 10.0], &[2]).unwrap();

        let result = scatter(&dst, 0, &index, &src).unwrap();
        assert!(result.to_vec()[1] > 0.0, "Scatter should write to index 1");
    }

    // =========================================================================
    // unique — edge cases
    // =========================================================================

    #[test]
    fn test_unique_all_same() {
        let t = Tensor::from_vec(vec![5.0f32, 5.0, 5.0, 5.0], &[4]).unwrap();
        let result = unique(&t, true, false, false);
        assert_eq!(result.values.to_vec().len(), 1);
        assert!((result.values.to_vec()[0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_unique_already_unique() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = unique(&t, true, false, false);
        assert_eq!(result.values.to_vec().len(), 4);
    }

    // =========================================================================
    // flip — multi-dimensional edge cases
    // =========================================================================

    #[test]
    fn test_flip_both_dims() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[0, 1]).unwrap();
        assert_eq!(flipped.to_vec(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_col_only() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let flipped = flip(&t, &[1]).unwrap();
        assert_eq!(flipped.to_vec(), vec![2.0, 1.0, 4.0, 3.0]);
    }

    // =========================================================================
    // roll — multi-dimensional
    // =========================================================================

    #[test]
    fn test_roll_2d() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let rolled = roll(&t, &[1], &[1]).unwrap();
        assert_eq!(rolled.to_vec(), vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_roll_full_cycle() {
        let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], &[3]).unwrap();
        let rolled = roll(&t, &[3], &[0]).unwrap();
        assert_eq!(rolled.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    // =========================================================================
    // nonzero — edge cases
    // =========================================================================

    #[test]
    fn test_nonzero_all_zeros() {
        let t = Tensor::from_vec(vec![0.0f32, 0.0, 0.0], &[3]).unwrap();
        let result = nonzero(&t);
        assert_eq!(result.to_vec().len(), 0);
    }

    #[test]
    fn test_nonzero_all_nonzero() {
        let t = Tensor::from_vec(vec![1.0f32, -2.0, 0.5], &[3]).unwrap();
        let result = nonzero(&t);
        assert_eq!(result.to_vec().len(), 3);
    }

    // =========================================================================
    // Numerical stability of softmax
    // =========================================================================

    #[test]
    fn test_softmax_large_values() {
        let t = Tensor::from_vec(vec![1000.0f32, 1001.0, 999.0], &[3]).unwrap();
        let result = softmax(&t, 0).unwrap();
        let rv = result.to_vec();
        assert!(
            rv.iter().all(|v: &f32| v.is_finite()),
            "Softmax should handle large values"
        );
        let sum: f32 = rv.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "Softmax should sum to 1.0, got {}",
            sum
        );
    }

    #[test]
    fn test_softmax_negative_values() {
        let t = Tensor::from_vec(vec![-100.0f32, -200.0, -150.0], &[3]).unwrap();
        let result = softmax(&t, 0).unwrap();
        let rv = result.to_vec();
        assert!(rv.iter().all(|v: &f32| v.is_finite() && *v >= 0.0));
        let sum: f32 = rv.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    // =========================================================================
    // clamp edge cases
    // =========================================================================

    #[test]
    fn test_clamp_no_op() {
        let t = Tensor::from_vec(vec![0.5f32, 0.3, 0.8], &[3]).unwrap();
        let clamped = clamp_min(&t, 0.0);
        assert_eq!(clamped.to_vec(), vec![0.5, 0.3, 0.8]);
    }

    #[test]
    fn test_clamp_min_all_negative() {
        let t = Tensor::from_vec(vec![-5.0f32, -3.0, -1.0], &[3]).unwrap();
        let clamped = clamp_min(&t, 0.0);
        assert_eq!(clamped.to_vec(), vec![0.0, 0.0, 0.0]);
    }
}