use crate::tensor::Tensor;
use crate::error::Result;
#[cfg(feature = "rayon")]
use rayon::prelude::*;

/// Applies `$op` to every element of `$data`, running in parallel when the
/// `rayon` feature is enabled and sequentially otherwise.
macro_rules! map_elements {
    ($data:expr, $op:expr) => {{
        #[cfg(feature = "rayon")]
        { $data.par_iter().map($op).collect() }
        #[cfg(not(feature = "rayon"))]
        { $data.iter().map($op).collect() }
    }};
}

impl Tensor {
    /// Element-wise addition with broadcasting.
    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x + y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Element-wise subtraction with broadcasting.
    pub fn sub(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x - y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Element-wise multiplication with broadcasting.
    pub fn mul(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x * y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Element-wise division with broadcasting.
    pub fn div(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x / y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Adds `scalar` to every element.
    pub fn add_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x + scalar);
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Subtracts `scalar` from every element.
    pub fn sub_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x - scalar);
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Multiplies every element by `scalar`.
    pub fn mul_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x * scalar);
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Divides every element by `scalar`.
    pub fn div_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x / scalar);
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Element-wise negation.
    pub fn neg(&self) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| -x);
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Element-wise absolute value.
    pub fn abs(&self) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.abs());
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Raises every element to the power `exp`.
    pub fn pow(&self, exp: f32) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.powf(exp));
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Element-wise square root.
    pub fn sqrt(&self) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.sqrt());
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Element-wise exponential (`e^x`).
    pub fn exp(&self) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.exp());
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Element-wise natural logarithm.
    pub fn log(&self) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.ln());
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Clamps every element to the range `[min, max]`.
    pub fn clamp(&self, min: f32, max: f32) -> Tensor {
        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.clamp(min, max));
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

/// Applies `op` element-wise to two flat `f32` buffers, broadcasting their
/// shapes to a common result shape. Returns the result data and its shape.
fn broadcast_binary_op<F>(
    a: &[f32],
    a_shape: &[usize],
    b: &[f32],
    b_shape: &[usize],
    op: F,
) -> Result<(Vec<f32>, Vec<usize>)>
where
    F: Fn(f32, f32) -> f32 + Sync,
{
    use crate::shape::Shape;

    let shape_a = Shape::new(a_shape);
    let shape_b = Shape::new(b_shape);
    let result_shape = shape_a.broadcast_with(&shape_b)?;
    let result_dims = result_shape.dims().to_vec();
    let numel = result_shape.numel();

    // Fast path: identical shapes need no index translation.
    if a_shape == b_shape {
        #[cfg(feature = "rayon")]
        let result: Vec<f32> = a.par_iter()
            .zip(b.par_iter())
            .map(|(&x, &y)| op(x, y))
            .collect();
        #[cfg(not(feature = "rayon"))]
        let result: Vec<f32> = a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| op(x, y))
            .collect();
        return Ok((result, result_dims));
    }

    // General path: translate each flat output index back into a source
    // index for `a` and `b` via broadcast strides.
    let a_strides = compute_broadcast_strides(a_shape, &result_dims);
    let b_strides = compute_broadcast_strides(b_shape, &result_dims);

    #[cfg(feature = "rayon")]
    let result: Vec<f32> = (0..numel)
        .into_par_iter()
        .map(|i| {
            let a_idx = compute_broadcast_index(i, &result_dims, &a_strides);
            let b_idx = compute_broadcast_index(i, &result_dims, &b_strides);
            op(a[a_idx], b[b_idx])
        })
        .collect();
    #[cfg(not(feature = "rayon"))]
    let result: Vec<f32> = (0..numel)
        .map(|i| {
            let a_idx = compute_broadcast_index(i, &result_dims, &a_strides);
            let b_idx = compute_broadcast_index(i, &result_dims, &b_strides);
            op(a[a_idx], b[b_idx])
        })
        .collect();

    Ok((result, result_dims))
}

/// Computes row-major strides that map `shape` into `target_shape`,
/// right-aligning the dimensions and using a stride of 0 for broadcast
/// (size-1) dimensions so they repeat the same source element.
fn compute_broadcast_strides(shape: &[usize], target_shape: &[usize]) -> Vec<usize> {
    let ndim = target_shape.len();
    let offset = ndim - shape.len();

    let mut strides = vec![0usize; ndim];
    let mut stride = 1usize;

    for i in (0..shape.len()).rev() {
        if shape[i] == target_shape[i + offset] {
            strides[i + offset] = stride;
            stride *= shape[i];
        } else {
            // Broadcast dimension (size 1 after a successful broadcast
            // check): stride 0, and the running stride is unchanged.
            strides[i + offset] = 0;
        }
    }

    strides
}

/// Converts a flat index into the row-major result `shape` into a flat index
/// in a source buffer described by `strides`.
fn compute_broadcast_index(flat_idx: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut idx = 0;
    let mut remaining = flat_idx;

    for i in (0..shape.len()).rev() {
        let coord = remaining % shape[i];
        remaining /= shape[i];
        idx += coord * strides[i];
    }

    idx
}

// Operator sugar for references. Each impl delegates to the inherent method
// of the same name (inherent methods take precedence over the trait methods
// during resolution); the binary ops panic on shape errors via `unwrap`.
impl std::ops::Add for &Tensor {
    type Output = Tensor;
    fn add(self, other: &Tensor) -> Tensor {
        self.add(other).unwrap()
    }
}

impl std::ops::Sub for &Tensor {
    type Output = Tensor;
    fn sub(self, other: &Tensor) -> Tensor {
        self.sub(other).unwrap()
    }
}

impl std::ops::Mul for &Tensor {
    type Output = Tensor;
    fn mul(self, other: &Tensor) -> Tensor {
        self.mul(other).unwrap()
    }
}

impl std::ops::Div for &Tensor {
    type Output = Tensor;
    fn div(self, other: &Tensor) -> Tensor {
        self.div(other).unwrap()
    }
}

impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        self.neg()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.data_f32(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_broadcast_add() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3, 1]).unwrap();
        let b = Tensor::from_slice(&[10.0f32, 20.0], &[1, 2]).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.dims(), &[3, 2]);
    }
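
    // Added coverage (sketch): checks the private broadcast helpers directly.
    // Only uses items defined in this module, so no extra assumptions about
    // the Tensor API are made.
    #[test]
    fn test_broadcast_helpers() {
        // [3, 1] viewed as [3, 2]: the size-1 dimension gets stride 0.
        assert_eq!(compute_broadcast_strides(&[3, 1], &[3, 2]), vec![1, 0]);
        // [1, 2] viewed as [3, 2]: the leading size-1 dimension gets stride 0.
        assert_eq!(compute_broadcast_strides(&[1, 2], &[3, 2]), vec![0, 1]);
        // Flat index 5 in a [3, 2] result is coordinate (2, 1); with strides
        // [1, 0] it maps back to source element 2.
        assert_eq!(compute_broadcast_index(5, &[3, 2], &[1, 0]), 2);
    }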

    #[test]
    fn test_scalar_ops() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = a.mul_scalar(2.0);
        assert_eq!(b.data_f32(), vec![2.0, 4.0, 6.0]);
    }
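
    // Added coverage (sketch): exercises the std::ops impls, assuming the same
    // from_slice/data_f32 behavior used by the tests above.
    #[test]
    fn test_operator_overloads() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        let sum = &a + &b;
        assert_eq!(sum.data_f32(), vec![5.0, 7.0, 9.0]);
        let prod = &a * &b;
        assert_eq!(prod.data_f32(), vec![4.0, 10.0, 18.0]);
        let neg = -&a;
        assert_eq!(neg.data_f32(), vec![-1.0, -2.0, -3.0]);
    }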
}