1use crate::tensor::Tensor;
7use anyhow::{Result, anyhow};
8
9pub trait ArithmeticOps {
11 fn add(&self, other: &Tensor) -> Result<Tensor>;
13
14 fn sub(&self, other: &Tensor) -> Result<Tensor>;
16
17 fn mul(&self, other: &Tensor) -> Result<Tensor>;
19
20 fn div(&self, other: &Tensor) -> Result<Tensor>;
22
23 fn add_scalar(&self, scalar: f32) -> Result<Tensor>;
25
26 fn sub_scalar(&self, scalar: f32) -> Result<Tensor>;
28
29 fn mul_scalar(&self, scalar: f32) -> Result<Tensor>;
31
32 fn div_scalar(&self, scalar: f32) -> Result<Tensor>;
34
35 fn neg(&self) -> Result<Tensor>;
37
38 fn abs(&self) -> Result<Tensor>;
40
41 fn pow(&self, exponent: f32) -> Result<Tensor>;
43
44 fn sqrt(&self) -> Result<Tensor>;
46
47 fn exp(&self) -> Result<Tensor>;
49
50 fn log(&self) -> Result<Tensor>;
52}
53
54impl ArithmeticOps for Tensor {
55 fn add(&self, other: &Tensor) -> Result<Tensor> {
56 if !self.is_broadcastable_with(other) {
58 return Err(anyhow!(
59 "Cannot broadcast tensors with shapes {:?} and {:?}",
60 self.shape(),
61 other.shape()
62 ));
63 }
64
65 let result_candle = self.candle_tensor().broadcast_add(other.candle_tensor())?;
66
67 Ok(Tensor::from_candle(
68 result_candle,
69 self.dtype(),
70 self.layout(),
71 ))
72 }
73
74 fn sub(&self, other: &Tensor) -> Result<Tensor> {
75 if !self.is_broadcastable_with(other) {
76 return Err(anyhow!(
77 "Cannot broadcast tensors with shapes {:?} and {:?}",
78 self.shape(),
79 other.shape()
80 ));
81 }
82
83 let result_candle = self.candle_tensor().broadcast_sub(other.candle_tensor())?;
84
85 Ok(Tensor::from_candle(
86 result_candle,
87 self.dtype(),
88 self.layout(),
89 ))
90 }
91
92 fn mul(&self, other: &Tensor) -> Result<Tensor> {
93 if !self.is_broadcastable_with(other) {
94 return Err(anyhow!(
95 "Cannot broadcast tensors with shapes {:?} and {:?}",
96 self.shape(),
97 other.shape()
98 ));
99 }
100
101 let result_candle = self.candle_tensor().broadcast_mul(other.candle_tensor())?;
102
103 Ok(Tensor::from_candle(
104 result_candle,
105 self.dtype(),
106 self.layout(),
107 ))
108 }
109
110 fn div(&self, other: &Tensor) -> Result<Tensor> {
111 if !self.is_broadcastable_with(other) {
112 return Err(anyhow!(
113 "Cannot broadcast tensors with shapes {:?} and {:?}",
114 self.shape(),
115 other.shape()
116 ));
117 }
118
119 let result_candle = self.candle_tensor().broadcast_div(other.candle_tensor())?;
120
121 Ok(Tensor::from_candle(
122 result_candle,
123 self.dtype(),
124 self.layout(),
125 ))
126 }
127
128 fn add_scalar(&self, scalar: f32) -> Result<Tensor> {
129 let result_candle = (self.candle_tensor() + scalar as f64)?;
130
131 Ok(Tensor::from_candle(
132 result_candle,
133 self.dtype(),
134 self.layout(),
135 ))
136 }
137
138 fn sub_scalar(&self, scalar: f32) -> Result<Tensor> {
139 let result_candle = (self.candle_tensor() - scalar as f64)?;
140
141 Ok(Tensor::from_candle(
142 result_candle,
143 self.dtype(),
144 self.layout(),
145 ))
146 }
147
148 fn mul_scalar(&self, scalar: f32) -> Result<Tensor> {
149 let result_candle = (self.candle_tensor() * scalar as f64)?;
150
151 Ok(Tensor::from_candle(
152 result_candle,
153 self.dtype(),
154 self.layout(),
155 ))
156 }
157
158 fn div_scalar(&self, scalar: f32) -> Result<Tensor> {
159 if scalar == 0.0 {
160 return Err(anyhow!("Division by zero"));
161 }
162
163 let result_candle = (self.candle_tensor() / scalar as f64)?;
164
165 Ok(Tensor::from_candle(
166 result_candle,
167 self.dtype(),
168 self.layout(),
169 ))
170 }
171
172 fn neg(&self) -> Result<Tensor> {
173 let result_candle = self.candle_tensor().neg()?;
174
175 Ok(Tensor::from_candle(
176 result_candle,
177 self.dtype(),
178 self.layout(),
179 ))
180 }
181
182 fn abs(&self) -> Result<Tensor> {
183 let result_candle = self.candle_tensor().abs()?;
184
185 Ok(Tensor::from_candle(
186 result_candle,
187 self.dtype(),
188 self.layout(),
189 ))
190 }
191
192 fn pow(&self, exponent: f32) -> Result<Tensor> {
193 let result_candle = self.candle_tensor().powf(exponent as f64)?;
194
195 Ok(Tensor::from_candle(
196 result_candle,
197 self.dtype(),
198 self.layout(),
199 ))
200 }
201
202 fn sqrt(&self) -> Result<Tensor> {
203 let result_candle = self.candle_tensor().sqrt()?;
204
205 Ok(Tensor::from_candle(
206 result_candle,
207 self.dtype(),
208 self.layout(),
209 ))
210 }
211
212 fn exp(&self) -> Result<Tensor> {
213 let result_candle = self.candle_tensor().exp()?;
214
215 Ok(Tensor::from_candle(
216 result_candle,
217 self.dtype(),
218 self.layout(),
219 ))
220 }
221
222 fn log(&self) -> Result<Tensor> {
223 let result_candle = self.candle_tensor().log()?;
224
225 Ok(Tensor::from_candle(
226 result_candle,
227 self.dtype(),
228 self.layout(),
229 ))
230 }
231}
232
233impl Tensor {
235 pub fn clamp(&self, min: f32, max: f32) -> Result<Tensor> {
237 if min > max {
238 return Err(anyhow!(
239 "Min value {} is greater than max value {}",
240 min,
241 max
242 ));
243 }
244
245 let result_candle = self.candle_tensor().clamp(min as f64, max as f64)?;
246
247 Ok(Tensor::from_candle(
248 result_candle,
249 self.dtype(),
250 self.layout(),
251 ))
252 }
253
254 pub fn relu(&self) -> Result<Tensor> {
256 self.clamp(0.0, f32::INFINITY)
257 }
258
259 pub fn sigmoid(&self) -> Result<Tensor> {
261 let neg_x = self.neg()?;
263 let exp_neg_x = neg_x.exp()?;
264 let one = Tensor::ones(vec![1], self.dtype(), self.layout())?;
265 let one_plus_exp = one.add(&exp_neg_x)?;
266 one.div(&one_plus_exp)
267 }
268
269 pub fn tanh(&self) -> Result<Tensor> {
271 let result_candle = self.candle_tensor().tanh()?;
272
273 Ok(Tensor::from_candle(
274 result_candle,
275 self.dtype(),
276 self.layout(),
277 ))
278 }
279
280 pub fn gelu(&self) -> Result<Tensor> {
282 let x = self;
285 let x_cubed = x.pow(3.0)?;
286 let term1 = x_cubed.mul_scalar(0.044715)?;
287 let term2 = x.add(&term1)?;
288 let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
289 let term3 = term2.mul_scalar(sqrt_2_over_pi)?;
290 let tanh_term = term3.tanh()?;
291 let one = Tensor::ones(vec![1], self.dtype(), self.layout())?;
292 let one_plus_tanh = one.add(&tanh_term)?;
293 let half = Tensor::from_data(vec![0.5], vec![1], self.dtype(), self.layout())?;
294 let result = x.mul(&half)?.mul(&one_plus_tanh)?;
295 Ok(result)
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    #[test]
    fn test_arithmetic_operations() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let b = Tensor::from_data(
            vec![2.0, 1.0, 1.0, 2.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let sum = a.add(&b)?;
        let sum_data = sum.to_vec()?;
        assert_eq!(sum_data, vec![3.0, 3.0, 4.0, 6.0]);

        let diff = a.sub(&b)?;
        let diff_data = diff.to_vec()?;
        assert_eq!(diff_data, vec![-1.0, 1.0, 2.0, 2.0]);

        let product = a.mul(&b)?;
        let product_data = product.to_vec()?;
        assert_eq!(product_data, vec![2.0, 2.0, 3.0, 8.0]);

        let quotient = a.div(&b)?;
        let quotient_data = quotient.to_vec()?;
        assert_eq!(quotient_data, vec![0.5, 2.0, 3.0, 2.0]);

        Ok(())
    }

    #[test]
    fn test_scalar_operations() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let sum = a.add_scalar(5.0)?;
        let sum_data = sum.to_vec()?;
        assert_eq!(sum_data, vec![6.0, 7.0, 8.0, 9.0]);

        let product = a.mul_scalar(2.0)?;
        let product_data = product.to_vec()?;
        assert_eq!(product_data, vec![2.0, 4.0, 6.0, 8.0]);

        // Both results are exactly representable, so exact comparison is safe.
        let diff = a.sub_scalar(1.0)?;
        let diff_data = diff.to_vec()?;
        assert_eq!(diff_data, vec![0.0, 1.0, 2.0, 3.0]);

        let quotient = a.div_scalar(2.0)?;
        let quotient_data = quotient.to_vec()?;
        assert_eq!(quotient_data, vec![0.5, 1.0, 1.5, 2.0]);

        Ok(())
    }

    #[test]
    fn test_unary_math() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![4],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let squared = a.pow(2.0)?.to_vec()?;
        let expected_sq = [1.0, 4.0, 9.0, 16.0];
        assert_eq!(squared.len(), expected_sq.len());
        for (i, &want) in expected_sq.iter().enumerate() {
            assert!(
                (squared[i] - want).abs() < 1e-4,
                "pow[{}] = {}, expected {}",
                i,
                squared[i],
                want
            );
        }

        let roots = a.pow(2.0)?.sqrt()?.to_vec()?;
        let expected_roots = [1.0, 2.0, 3.0, 4.0];
        for (i, &want) in expected_roots.iter().enumerate() {
            assert!(
                (roots[i] - want).abs() < 1e-4,
                "sqrt[{}] = {}, expected {}",
                i,
                roots[i],
                want
            );
        }

        // log(exp(x)) should recover x up to rounding.
        let round_trip = a.exp()?.log()?.to_vec()?;
        for (i, &want) in expected_roots.iter().enumerate() {
            assert!(
                (round_trip[i] - want).abs() < 1e-4,
                "log(exp)[{}] = {}, expected {}",
                i,
                round_trip[i],
                want
            );
        }

        Ok(())
    }

    #[test]
    fn test_broadcasting() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let b = Tensor::from_data(
            vec![10.0, 20.0, 30.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let sum = a.add(&b)?;
        let sum_data = sum.to_vec()?;
        assert_eq!(sum_data, vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);

        Ok(())
    }

    #[test]
    fn test_activation_functions() -> Result<()> {
        let a = Tensor::from_data(
            vec![-1.0, 0.0, 1.0, 2.0],
            vec![4],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let relu_result = a.relu()?;
        let relu_data = relu_result.to_vec()?;
        assert_eq!(relu_data, vec![0.0, 0.0, 1.0, 2.0]);

        let abs_result = a.abs()?;
        let abs_data = abs_result.to_vec()?;
        assert_eq!(abs_data, vec![1.0, 0.0, 1.0, 2.0]);

        let neg_result = a.neg()?;
        let neg_data = neg_result.to_vec()?;
        assert_eq!(neg_data, vec![1.0, 0.0, -1.0, -2.0]);

        Ok(())
    }

    #[test]
    fn test_sigmoid() -> Result<()> {
        let x = Tensor::from_data(vec![0.0], vec![1], DataType::F32, TensorLayout::RowMajor)?;
        let sigmoid_result = x.sigmoid()?;
        let sigmoid_data = sigmoid_result.to_vec()?;

        assert!((sigmoid_data[0] - 0.5).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_sigmoid_values() -> Result<()> {
        let x = Tensor::from_data(
            vec![-2.0, 0.0, 2.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let data = x.sigmoid()?.to_vec()?;

        // 1 / (1 + e^2), 0.5, 1 / (1 + e^-2)
        let expected = [0.119_203, 0.5, 0.880_797];
        assert_eq!(data.len(), expected.len());
        for (i, &want) in expected.iter().enumerate() {
            assert!(
                (data[i] - want).abs() < 1e-5,
                "sigmoid[{}] = {}, expected {}",
                i,
                data[i],
                want
            );
        }

        Ok(())
    }

    #[test]
    fn test_tanh() -> Result<()> {
        let x = Tensor::from_data(
            vec![0.0, 1.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let data = x.tanh()?.to_vec()?;

        let expected = [0.0, 0.761_594];
        assert_eq!(data.len(), expected.len());
        for (i, &want) in expected.iter().enumerate() {
            assert!(
                (data[i] - want).abs() < 1e-5,
                "tanh[{}] = {}, expected {}",
                i,
                data[i],
                want
            );
        }

        Ok(())
    }

    #[test]
    fn test_gelu() -> Result<()> {
        let x = Tensor::from_data(
            vec![-1.0, 0.0, 1.0, 2.0],
            vec![4],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let data = x.gelu()?.to_vec()?;

        // Reference values of the tanh-approximated GELU.
        let expected = [-0.158_808, 0.0, 0.841_192, 1.954_598];
        assert_eq!(data.len(), expected.len());
        for (i, &want) in expected.iter().enumerate() {
            assert!(
                (data[i] - want).abs() < 1e-4,
                "gelu[{}] = {}, expected {}",
                i,
                data[i],
                want
            );
        }

        Ok(())
    }

    #[test]
    fn test_error_handling() {
        let a = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let b = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();

        assert!(a.add(&b).is_err());

        assert!(a.div_scalar(0.0).is_err());

        assert!(a.clamp(5.0, 1.0).is_err());
    }
}