matten/ops/tensor_ops.rs
1//! Element-wise binary operators for borrowed tensor pairs (RFC-006).
2//!
3//! `*` is element-wise multiplication; matrix multiplication is explicit and
4//! arrives in RFC-010 / M6.
5
6use crate::Tensor;
7use crate::ops::broadcast::apply_binary;
8use std::ops::{Add, Div, Mul, Sub};
9
10impl Add for &Tensor {
11 type Output = Tensor;
12 /// Element-wise addition with NumPy-style broadcasting.
13 ///
14 /// # Panics
15 ///
16 /// Panics with `"matten broadcast error in add: ..."` if the shapes are
17 /// incompatible.
18 ///
19 /// ```
20 /// use matten::Tensor;
21 /// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
22 /// let b = Tensor::ones(&[2, 2]);
23 /// let c = &a + &b;
24 /// assert_eq!(c.as_slice(), &[2.0, 3.0, 4.0, 5.0]);
25 /// ```
26 fn add(self, rhs: &Tensor) -> Tensor {
27 apply_binary(self, rhs, "add", |a, b| a + b)
28 }
29}
30
31impl Sub for &Tensor {
32 type Output = Tensor;
33 /// Element-wise subtraction with broadcasting.
34 ///
35 /// # Panics
36 ///
37 /// Panics on incompatible shapes.
38 ///
39 /// ```
40 /// use matten::Tensor;
41 /// let a = Tensor::new(vec![5.0, 4.0, 3.0, 2.0], &[2, 2]);
42 /// let b = Tensor::ones(&[2, 2]);
43 /// let c = &a - &b;
44 /// assert_eq!(c.as_slice(), &[4.0, 3.0, 2.0, 1.0]);
45 /// ```
46 fn sub(self, rhs: &Tensor) -> Tensor {
47 apply_binary(self, rhs, "sub", |a, b| a - b)
48 }
49}
50
51impl Mul for &Tensor {
52 type Output = Tensor;
53 /// Element-wise multiplication with broadcasting (`*` is **not** matrix
54 /// multiply; use `matmul` for that).
55 ///
56 /// # Panics
57 ///
58 /// Panics on incompatible shapes.
59 ///
60 /// ```
61 /// use matten::Tensor;
62 /// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
63 /// let b = Tensor::full(&[2, 2], 2.0);
64 /// let c = &a * &b;
65 /// assert_eq!(c.as_slice(), &[2.0, 4.0, 6.0, 8.0]);
66 /// ```
67 fn mul(self, rhs: &Tensor) -> Tensor {
68 apply_binary(self, rhs, "mul", |a, b| a * b)
69 }
70}
71
72impl Div for &Tensor {
73 type Output = Tensor;
74 /// Element-wise division with broadcasting. Division by zero follows IEEE 754
75 /// `f64` behavior (yields `inf`, `-inf`, or `NaN`); no error is produced.
76 ///
77 /// # Panics
78 ///
79 /// Panics on incompatible shapes.
80 ///
81 /// ```
82 /// use matten::Tensor;
83 /// let a = Tensor::new(vec![4.0, 9.0], &[2]);
84 /// let b = Tensor::new(vec![2.0, 3.0], &[2]);
85 /// let c = &a / &b;
86 /// assert_eq!(c.as_slice(), &[2.0, 3.0]);
87 /// ```
88 fn div(self, rhs: &Tensor) -> Tensor {
89 apply_binary(self, rhs, "div", |a, b| a / b)
90 }
91}