1use axonml_core::dtype::{Float, Numeric, Scalar};
20use axonml_core::error::Result;
21
22use crate::tensor::Tensor;
23
24pub fn eq<T: Numeric + PartialEq>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
30 if a.shape() != b.shape() {
31 return Err(axonml_core::error::Error::shape_mismatch(
32 a.shape(),
33 b.shape(),
34 ));
35 }
36
37 let a_data = a.to_vec();
38 let b_data = b.to_vec();
39
40 Ok(a_data
41 .iter()
42 .zip(b_data.iter())
43 .map(|(x, y)| x == y)
44 .collect())
45}
46
47pub fn lt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
49 if a.shape() != b.shape() {
50 return Err(axonml_core::error::Error::shape_mismatch(
51 a.shape(),
52 b.shape(),
53 ));
54 }
55
56 let a_data = a.to_vec();
57 let b_data = b.to_vec();
58
59 Ok(a_data
60 .iter()
61 .zip(b_data.iter())
62 .map(|(x, y)| x < y)
63 .collect())
64}
65
66pub fn gt<T: Numeric>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Vec<bool>> {
68 if a.shape() != b.shape() {
69 return Err(axonml_core::error::Error::shape_mismatch(
70 a.shape(),
71 b.shape(),
72 ));
73 }
74
75 let a_data = a.to_vec();
76 let b_data = b.to_vec();
77
78 Ok(a_data
79 .iter()
80 .zip(b_data.iter())
81 .map(|(x, y)| x > y)
82 .collect())
83}
84
85pub fn softmax<T: Float>(x: &Tensor<T>, _dim: i64) -> Result<Tensor<T>> {
91 let data = x.to_vec();
93 let shape = x.shape();
94
95 if shape.is_empty() {
96 return Ok(Tensor::scalar(T::one()));
97 }
98
99 let max_val = data
101 .iter()
102 .fold(T::neg_infinity(), |a, &b| if b > a { b } else { a });
103
104 let exp_data: Vec<T> = data.iter().map(|&v| (v - max_val).exp_value()).collect();
106
107 let sum: T = exp_data.iter().fold(T::zero(), |a, &b| a + b);
109
110 let result: Vec<T> = exp_data.iter().map(|&v| v / sum).collect();
112
113 Tensor::from_vec(result, shape)
114}
115
116pub fn log_softmax<T: Float>(x: &Tensor<T>, dim: i64) -> Result<Tensor<T>> {
118 let sm = softmax(x, dim)?;
119 Ok(sm.ln())
120}
121
122#[must_use] pub fn gelu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
124 let data = x.to_vec();
125 let sqrt_2_over_pi = T::from(0.7978845608028654).unwrap();
126 let coeff = T::from(0.044715).unwrap();
127
128 let result: Vec<T> = data
129 .iter()
130 .map(|&v| {
131 let inner = sqrt_2_over_pi * (v + coeff * v * v * v);
132 v * T::from(0.5).unwrap() * (T::one() + inner.tanh_value())
133 })
134 .collect();
135
136 Tensor::from_vec(result, x.shape()).unwrap()
137}
138
139pub fn leaky_relu<T: Float>(x: &Tensor<T>, negative_slope: T) -> Tensor<T> {
141 let data = x.to_vec();
142 let result: Vec<T> = data
143 .iter()
144 .map(|&v| if v > T::zero() { v } else { negative_slope * v })
145 .collect();
146
147 Tensor::from_vec(result, x.shape()).unwrap()
148}
149
150pub fn elu<T: Float>(x: &Tensor<T>, alpha: T) -> Tensor<T> {
152 let data = x.to_vec();
153 let result: Vec<T> = data
154 .iter()
155 .map(|&v| {
156 if v > T::zero() {
157 v
158 } else {
159 alpha * (v.exp_value() - T::one())
160 }
161 })
162 .collect();
163
164 Tensor::from_vec(result, x.shape()).unwrap()
165}
166
167#[must_use] pub fn silu<T: Float>(x: &Tensor<T>) -> Tensor<T> {
169 let sig = x.sigmoid();
170 x.mul(&sig).unwrap()
171}
172
173pub fn clamp<T: Numeric>(x: &Tensor<T>, min: T, max: T) -> Tensor<T> {
179 let data = x.to_vec();
180 let result: Vec<T> = data
181 .iter()
182 .map(|&v| {
183 if v < min {
184 min
185 } else if v > max {
186 max
187 } else {
188 v
189 }
190 })
191 .collect();
192
193 Tensor::from_vec(result, x.shape()).unwrap()
194}
195
196pub fn clamp_min<T: Numeric>(x: &Tensor<T>, min: T) -> Tensor<T> {
198 let data = x.to_vec();
199 let result: Vec<T> = data
200 .iter()
201 .map(|&v| if v < min { min } else { v })
202 .collect();
203
204 Tensor::from_vec(result, x.shape()).unwrap()
205}
206
207pub fn clamp_max<T: Numeric>(x: &Tensor<T>, max: T) -> Tensor<T> {
209 let data = x.to_vec();
210 let result: Vec<T> = data
211 .iter()
212 .map(|&v| if v > max { max } else { v })
213 .collect();
214
215 Tensor::from_vec(result, x.shape()).unwrap()
216}
217
218pub fn where_cond<T: Scalar>(
224 condition: &[bool],
225 x: &Tensor<T>,
226 y: &Tensor<T>,
227) -> Result<Tensor<T>> {
228 if x.shape() != y.shape() {
229 return Err(axonml_core::error::Error::shape_mismatch(
230 x.shape(),
231 y.shape(),
232 ));
233 }
234
235 if condition.len() != x.numel() {
236 return Err(axonml_core::error::Error::shape_mismatch(
237 &[condition.len()],
238 &[x.numel()],
239 ));
240 }
241
242 let x_data = x.to_vec();
243 let y_data = y.to_vec();
244
245 let result: Vec<T> = condition
246 .iter()
247 .zip(x_data.iter().zip(y_data.iter()))
248 .map(|(&c, (&xv, &yv))| if c { xv } else { yv })
249 .collect();
250
251 Tensor::from_vec(result, x.shape())
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Softmax of a 1-D tensor is a probability distribution (sums to one).
    #[test]
    fn test_softmax() {
        let logits = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let probs = softmax(&logits, -1).unwrap().to_vec();

        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    /// Values are pinned to the bounds and in-range values pass through.
    #[test]
    fn test_clamp() {
        let input = Tensor::<f32>::from_vec(vec![-1.0, 0.5, 2.0], &[3]).unwrap();
        let clamped = clamp(&input, 0.0, 1.0).to_vec();
        assert_eq!(clamped, vec![0.0, 0.5, 1.0]);
    }

    /// Negative inputs are scaled by the slope; non-negative ones are kept.
    #[test]
    fn test_leaky_relu() {
        let input = Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap();
        let activated = leaky_relu(&input, 0.01).to_vec();
        assert_eq!(activated, vec![-0.01, 0.0, 1.0]);
    }

    /// `eq`, `lt` and `gt` produce per-element masks.
    #[test]
    fn test_comparison() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![1.0, 3.0, 2.0], &[3]).unwrap();

        let equal = eq(&lhs, &rhs).unwrap();
        let less = lt(&lhs, &rhs).unwrap();
        let greater = gt(&lhs, &rhs).unwrap();

        assert_eq!(equal, vec![true, false, false]);
        assert_eq!(less, vec![false, true, false]);
        assert_eq!(greater, vec![false, false, true]);
    }
}