axonml_tensor/ops/
mod.rs

//! Tensor Operations - Mathematical and Structural Operations
//!
//! Core operations are implemented directly on `Tensor`; this module provides
//! additional standalone functions, grouped by category, for convenient access.
//!
//! # Categories
//! - Arithmetic: +, -, *, /, power
//! - Comparison: ==, <, >, <=, >=
//! - Reduction: sum, mean, max, min
//! - Matrix: matmul, transpose, inverse
//! - Activation: relu, sigmoid, tanh, softmax
//!
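//! # Example
//!
//! A minimal usage sketch; the import paths below are assumed from this
//! file's location and the crate's re-exports, so the block is marked
//! `ignore` rather than compiled as a doctest:
//!
//! ```ignore
//! use axonml_tensor::ops::{clamp, softmax};
//! use axonml_tensor::tensor::Tensor;
//!
//! let x = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
//! let probs = softmax(&x, -1).unwrap(); // entries sum to 1.0
//! let bounded = clamp(&x, 0.0, 2.0);    // [1.0, 2.0, 2.0]
//! ```
//!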
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

// Operations are implemented directly on Tensor in tensor.rs.
// This module provides additional standalone functions.

use axonml_core::dtype::{Float, Numeric, Scalar};
use axonml_core::error::Result;

use crate::tensor::Tensor;

// =============================================================================
// Comparison Operations
// =============================================================================

/// Element-wise equality comparison.
pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x == y)
        .collect())
}

/// Element-wise less-than comparison.
pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x < y)
        .collect())
}

/// Element-wise greater-than comparison.
pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
    if a.shape() != b.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            a.shape(),
            b.shape(),
        ));
    }

    let a_data = a.to_vec();
    let b_data = b.to_vec();

    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(x, y)| x > y)
        .collect())
}

// =============================================================================
// Advanced Activation Functions
// =============================================================================

/// Applies softmax along the last dimension.
///
/// The `dim` argument is currently ignored; softmax is always computed over
/// the last dimension of `x`.
pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
    let data = x.to_vec();
    let shape = x.shape();

    if shape.is_empty() {
        return Ok(Tensor::scalar(T::one()));
    }

    // Treat the tensor as rows of length `shape.last()`; this assumes
    // `to_vec` returns elements in row-major order.
    let row_len = shape[shape.len() - 1];
    if row_len == 0 {
        return Tensor::from_vec(Vec::new(), shape);
    }

    let mut result = Vec::with_capacity(data.len());
    for row in data.chunks(row_len) {
        // Subtract the row maximum for numerical stability.
        let max_val = row
            .iter()
            .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });

        // exp(x - max), then normalize so each row sums to one.
        let exp_row: Vec<T> = row.iter().map(|&v| (v - max_val).exp_value()).collect();
        let sum: T = exp_row.iter().fold(T::zero(), |a, &b| a + b);
        result.extend(exp_row.iter().map(|&v| v / sum));
    }

    Tensor::from_vec(result, shape)
}

/// Applies log-softmax along the last dimension, computed as `ln(softmax(x))`.
pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
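    // A fused log-softmax would be more numerically stable for inputs with
    // large magnitude; this version simply reuses the row-wise softmax above.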
    let sm = softmax(x, dim)?;
    Ok(sm.ln())
}

/// Applies GELU (Gaussian Error Linear Unit) activation, using the tanh
/// approximation: `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[must_use]
pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let data = x.to_vec();
    let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
    let coeff = T::from(0.044715).unwrap();

    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
            v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies Leaky `ReLU` activation.
pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > T::zero() { v } else { negative_slope * v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies ELU (Exponential Linear Unit) activation.
pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v > T::zero() {
                v
            } else {
                alpha * (v.exp_value() - T::one())
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Applies `SiLU` (Sigmoid Linear Unit) / Swish activation: `x * sigmoid(x)`.
#[must_use]
pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
    let sig = x.sigmoid();
    x.mul(&sig).unwrap()
}

// =============================================================================
// Clipping Operations
// =============================================================================

/// Clamps all elements to the range `[min, max]`.
pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| {
            if v < min {
                min
            } else if v > max {
                max
            } else {
                v
            }
        })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps all elements to be at least `min`.
pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v < min { min } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

/// Clamps all elements to be at most `max`.
pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
    let data = x.to_vec();
    let result: Vec<T> = data
        .iter()
        .map(|&v| if v > max { max } else { v })
        .collect();

    Tensor::from_vec(result, x.shape()).unwrap()
}

// =============================================================================
// Where Operation
// =============================================================================

/// Selects elements from `x` where `condition` is `true` and from `y` otherwise.
pub fn where_cond<T: Scalar>(
    condition: &[bool],
    x: &Tensor<T>,
    y: &Tensor<T>,
) -> Result<Tensor<T>> {
    if x.shape() != y.shape() {
        return Err(axonml_core::error::Error::shape_mismatch(
            x.shape(),
            y.shape(),
        ));
    }

    if condition.len() != x.numel() {
        return Err(axonml_core::error::Error::shape_mismatch(
            &[condition.len()],
            &[x.numel()],
        ));
    }

    let x_data = x.to_vec();
    let y_data = y.to_vec();

    let result: Vec<T> = condition
        .iter()
        .zip(x_data.iter().zip(y_data.iter()))
        .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
        .collect();

    Tensor::from_vec(result, x.shape())
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let s = softmax(&t, -1).unwrap();

        let sum: f32 = s.to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
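
    // Additional illustrative test (a sketch, not exhaustive): exponentiating
    // log_softmax should recover a distribution that sums to one, and every
    // log-probability should be non-positive.
    #[test]
    fn test_log_softmax() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let ls = log_softmax(&t, -1).unwrap();

        let vals = ls.to_vec();
        let sum: f32 = vals.iter().map(|v| v.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(vals.iter().all(|&v| v <= 0.0));
    }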

    #[test]
    fn test_clamp() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let c = clamp(&t, 0.0, 1.0);
        assert_eq!(c.to_vec(), vec![0.0, 0.5, 1.0]);
    }
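
    // Additional illustrative test (sketch): clamp_min / clamp_max are the
    // one-sided counterparts of clamp above.
    #[test]
    fn test_clamp_min_max() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        assert_eq!(clamp_min(&t, 0.0).to_vec(), vec![0.0, 0.5, 2.0]);
        assert_eq!(clamp_max(&t, 1.0).to_vec(), vec![-1.0, 0.5, 1.0]);
    }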

    #[test]
    fn test_leaky_relu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let r = leaky_relu(&t, 0.01);
        assert_eq!(r.to_vec(), vec![-0.01, 0.0, 1.0]);
    }
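
    // Additional illustrative tests (sketches): expected values are derived
    // from the activation formulas above, checked with loose f32 tolerances.
    #[test]
    fn test_gelu() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let g = gelu(&t).to_vec();
        assert!(g[0].abs() < 1e-6); // GELU(0) = 0
        assert!((g[1] - 0.8412).abs() < 1e-3); // tanh approximation of GELU(1)
    }

    #[test]
    fn test_elu() {
        let t = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 2.0], &[3]).unwrap();
        let e = elu(&t, 1.0).to_vec();
        assert!((e[0] - ((-1.0f32).exp() - 1.0)).abs() < 1e-5); // alpha * (e^-1 - 1)
        assert!(e[1].abs() < 1e-6);
        assert_eq!(e[2], 2.0);
    }

    #[test]
    fn test_silu() {
        let t = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let s = silu(&t).to_vec();
        assert!(s[0].abs() < 1e-6); // 0 * sigmoid(0) = 0
        assert!((s[1] - 0.7310586).abs() < 1e-5); // 1 * sigmoid(1)
    }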

    #[test]
    fn test_comparison() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        assert_eq!(eq(&a, &b).unwrap(), vec![true, false, false]);
        assert_eq!(lt(&a, &b).unwrap(), vec![false, true, false]);
        assert_eq!(gt(&a, &b).unwrap(), vec![false, false, true]);
    }
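
    // Additional illustrative test (sketch): where_cond picks from x where the
    // condition is true and from y elsewhere, and rejects a mask whose length
    // does not match the element count.
    #[test]
    fn test_where_cond() {
        let x = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let y = Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let r = where_cond(&[true, false, true], &x, &y).unwrap();
        assert_eq!(r.to_vec(), vec![1.0, 5.0, 3.0]);

        assert!(where_cond(&[true], &x, &y).is_err());
    }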
}