1use crate::tensor::Tensor;
4use crate::error::Result;
5use rayon::prelude::*;
6
7impl Tensor {
8 pub fn add(&self, other: &Tensor) -> Result<Tensor> {
10 let a = self.data_f32();
11 let b = other.data_f32();
12
13 let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x + y)?;
15 Tensor::from_slice(&result, &shape)
16 }
17
18 pub fn sub(&self, other: &Tensor) -> Result<Tensor> {
20 let a = self.data_f32();
21 let b = other.data_f32();
22
23 let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x - y)?;
24 Tensor::from_slice(&result, &shape)
25 }
26
27 pub fn mul(&self, other: &Tensor) -> Result<Tensor> {
29 let a = self.data_f32();
30 let b = other.data_f32();
31
32 let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x * y)?;
33 Tensor::from_slice(&result, &shape)
34 }
35
36 pub fn div(&self, other: &Tensor) -> Result<Tensor> {
38 let a = self.data_f32();
39 let b = other.data_f32();
40
41 let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x / y)?;
42 Tensor::from_slice(&result, &shape)
43 }
44
45 pub fn add_scalar(&self, scalar: f32) -> Tensor {
47 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x + scalar).collect();
48 Tensor::from_slice(&data, self.dims()).unwrap()
49 }
50
51 pub fn sub_scalar(&self, scalar: f32) -> Tensor {
53 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x - scalar).collect();
54 Tensor::from_slice(&data, self.dims()).unwrap()
55 }
56
57 pub fn mul_scalar(&self, scalar: f32) -> Tensor {
59 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x * scalar).collect();
60 Tensor::from_slice(&data, self.dims()).unwrap()
61 }
62
63 pub fn div_scalar(&self, scalar: f32) -> Tensor {
65 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x / scalar).collect();
66 Tensor::from_slice(&data, self.dims()).unwrap()
67 }
68
69 pub fn neg(&self) -> Tensor {
71 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| -x).collect();
72 Tensor::from_slice(&data, self.dims()).unwrap()
73 }
74
75 pub fn abs(&self) -> Tensor {
77 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.abs()).collect();
78 Tensor::from_slice(&data, self.dims()).unwrap()
79 }
80
81 pub fn pow(&self, exp: f32) -> Tensor {
83 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.powf(exp)).collect();
84 Tensor::from_slice(&data, self.dims()).unwrap()
85 }
86
87 pub fn sqrt(&self) -> Tensor {
89 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.sqrt()).collect();
90 Tensor::from_slice(&data, self.dims()).unwrap()
91 }
92
93 pub fn exp(&self) -> Tensor {
95 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.exp()).collect();
96 Tensor::from_slice(&data, self.dims()).unwrap()
97 }
98
99 pub fn log(&self) -> Tensor {
101 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.ln()).collect();
102 Tensor::from_slice(&data, self.dims()).unwrap()
103 }
104
105 pub fn clamp(&self, min: f32, max: f32) -> Tensor {
107 let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.clamp(min, max)).collect();
108 Tensor::from_slice(&data, self.dims()).unwrap()
109 }
110}
111
112fn broadcast_binary_op<F>(
114 a: &[f32],
115 a_shape: &[usize],
116 b: &[f32],
117 b_shape: &[usize],
118 op: F,
119) -> Result<(Vec<f32>, Vec<usize>)>
120where
121 F: Fn(f32, f32) -> f32 + Sync,
122{
123 use crate::shape::Shape;
124
125 let shape_a = Shape::new(a_shape);
126 let shape_b = Shape::new(b_shape);
127 let result_shape = shape_a.broadcast_with(&shape_b)?;
128 let result_dims = result_shape.dims().to_vec();
129 let numel = result_shape.numel();
130
131 if a_shape == b_shape {
133 let result: Vec<f32> = a.par_iter()
134 .zip(b.par_iter())
135 .map(|(&x, &y)| op(x, y))
136 .collect();
137 return Ok((result, result_dims));
138 }
139
140 let a_strides = compute_broadcast_strides(a_shape, &result_dims);
142 let b_strides = compute_broadcast_strides(b_shape, &result_dims);
143
144 let result: Vec<f32> = (0..numel)
145 .into_par_iter()
146 .map(|i| {
147 let a_idx = compute_broadcast_index(i, &result_dims, &a_strides);
148 let b_idx = compute_broadcast_index(i, &result_dims, &b_strides);
149 op(a[a_idx], b[b_idx])
150 })
151 .collect();
152
153 Ok((result, result_dims))
154}
155
/// Computes per-axis element strides for reading `shape`'s row-major data while
/// iterating over the broadcast `target_shape`.
///
/// `shape` is aligned with the trailing axes of `target_shape`. Axes whose size
/// matches the target get the ordinary row-major stride. Axes being broadcast
/// (size 1 in `shape`) and leading axes that `shape` lacks get stride 0, so the
/// same source element is reused along them.
///
/// # Panics
/// Panics if `shape` has more axes than `target_shape`, or if an axis of `shape`
/// is neither equal to the matching target axis nor 1. Either means
/// `target_shape` is not a broadcast of `shape`. Previously these cases wrapped
/// around in release builds, or silently produced stride 0.
fn compute_broadcast_strides(shape: &[usize], target_shape: &[usize]) -> Vec<usize> {
    let ndim = target_shape.len();
    assert!(
        shape.len() <= ndim,
        "cannot broadcast shape {:?} to lower-rank target {:?}",
        shape,
        target_shape
    );
    // `shape` lines up with the trailing axes of `target_shape`.
    let offset = ndim - shape.len();

    // Leading axes missing from `shape` keep the initial stride of 0.
    let mut strides = vec![0usize; ndim];
    let mut stride = 1usize;

    for (axis, &dim) in shape.iter().enumerate().rev() {
        let target_dim = target_shape[axis + offset];
        if dim == target_dim {
            strides[axis + offset] = stride;
            stride *= dim;
        } else {
            // Broadcast axis: the stride stays 0 so every target coordinate along
            // it reads the single source element. Only size-1 axes may broadcast.
            assert!(
                dim == 1,
                "axis {} of shape {:?} (size {}) cannot broadcast to size {} in {:?}",
                axis,
                shape,
                dim,
                target_dim,
                target_shape
            );
        }
    }

    strides
}
176
/// Maps a flat row-major index into an array shaped `shape` to a source offset
/// using per-axis `strides`.
///
/// The flat index is decomposed into coordinates from the last axis to the first.
/// Each coordinate is weighted by its stride, so a stride of 0 pins that axis to
/// the source's single element. Panics if an axis of `shape` is 0 or if `strides`
/// is shorter than `shape`.
fn compute_broadcast_index(flat_idx: usize, shape: &[usize], strides: &[usize]) -> usize {
    (0..shape.len())
        .rev()
        .fold((0usize, flat_idx), |(offset, rest), axis| {
            let extent = shape[axis];
            (offset + (rest % extent) * strides[axis], rest / extent)
        })
        .0
}
190
191impl std::ops::Add for &Tensor {
193 type Output = Tensor;
194 fn add(self, other: &Tensor) -> Tensor {
195 self.add(other).unwrap()
196 }
197}
198
199impl std::ops::Sub for &Tensor {
200 type Output = Tensor;
201 fn sub(self, other: &Tensor) -> Tensor {
202 self.sub(other).unwrap()
203 }
204}
205
206impl std::ops::Mul for &Tensor {
207 type Output = Tensor;
208 fn mul(self, other: &Tensor) -> Tensor {
209 self.mul(other).unwrap()
210 }
211}
212
213impl std::ops::Div for &Tensor {
214 type Output = Tensor;
215 fn div(self, other: &Tensor) -> Tensor {
216 self.div(other).unwrap()
217 }
218}
219
220impl std::ops::Neg for &Tensor {
221 type Output = Tensor;
222 fn neg(self) -> Tensor {
223 self.neg()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.data_f32(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_same_shape_binary_ops() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        assert_eq!(b.sub(&a).unwrap().data_f32(), vec![3.0, 3.0, 3.0]);
        assert_eq!(a.mul(&b).unwrap().data_f32(), vec![4.0, 10.0, 18.0]);
        assert_eq!(b.div(&a).unwrap().data_f32(), vec![4.0, 2.5, 2.0]);
    }

    #[test]
    fn test_broadcast_add() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3, 1]).unwrap();
        let b = Tensor::from_slice(&[10.0f32, 20.0], &[1, 2]).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.dims(), &[3, 2]);
        // Row-major: row r is [a[r] + b[0], a[r] + b[1]].
        assert_eq!(c.data_f32(), vec![11.0, 21.0, 12.0, 22.0, 13.0, 23.0]);
    }

    #[test]
    fn test_scalar_ops() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = a.mul_scalar(2.0);
        assert_eq!(b.data_f32(), vec![2.0, 4.0, 6.0]);
        assert_eq!(a.add_scalar(1.0).data_f32(), vec![2.0, 3.0, 4.0]);
        assert_eq!(a.sub_scalar(1.0).data_f32(), vec![0.0, 1.0, 2.0]);
        assert_eq!(a.div_scalar(2.0).data_f32(), vec![0.5, 1.0, 1.5]);
        assert_eq!(a.add_scalar(1.0).dims(), a.dims());
    }

    #[test]
    fn test_unary_ops() {
        let a = Tensor::from_slice(&[-1.0f32, 2.0, -3.0], &[3]).unwrap();
        assert_eq!(a.neg().data_f32(), vec![1.0, -2.0, 3.0]);
        assert_eq!(a.abs().data_f32(), vec![1.0, 2.0, 3.0]);
        assert_eq!(a.clamp(-2.0, 1.0).data_f32(), vec![-1.0, 1.0, -2.0]);

        let squares = Tensor::from_slice(&[1.0f32, 4.0, 9.0], &[3]).unwrap();
        assert_eq!(squares.sqrt().data_f32(), vec![1.0, 2.0, 3.0]);

        let zero_one = Tensor::from_slice(&[0.0f32, 1.0], &[2]).unwrap();
        assert_eq!(zero_one.exp().data_f32()[0], 1.0);
        assert_eq!(zero_one.log().data_f32()[1], 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        assert_eq!((&a + &b).data_f32(), vec![5.0, 7.0, 9.0]);
        assert_eq!((&b - &a).data_f32(), vec![3.0, 3.0, 3.0]);
        assert_eq!((&a * &b).data_f32(), vec![4.0, 10.0, 18.0]);
        assert_eq!((&b / &a).data_f32(), vec![4.0, 2.5, 2.0]);
        assert_eq!((-&a).data_f32(), vec![-1.0, -2.0, -3.0]);
    }

    #[test]
    fn test_broadcast_helpers() {
        assert_eq!(compute_broadcast_strides(&[3, 1], &[3, 2]), vec![1, 0]);
        assert_eq!(compute_broadcast_strides(&[1, 2], &[3, 2]), vec![0, 1]);
        assert_eq!(compute_broadcast_strides(&[2, 3], &[4, 2, 3]), vec![0, 3, 1]);
        assert_eq!(compute_broadcast_index(3, &[3, 2], &[1, 0]), 1);
        assert_eq!(compute_broadcast_index(5, &[3, 2], &[0, 1]), 1);
        assert_eq!(compute_broadcast_index(23, &[4, 2, 3], &[0, 3, 1]), 5);
    }
}