1use std::sync::Arc;
2use crate::{AutogradMetaT, NumDType, Result, Shape, WithDType};
3use super::{Tensor, TensorId, TensorImpl};
4
5impl<T: WithDType> Tensor<T> {
6 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
14 let ndarry_ = TensorImpl {
15 id: TensorId::new(),
16 storage: self.0.storage.clone(),
17 layout: self.layout().broadcast_as(shape)?,
18 meta: T::AutogradMeta::on_broadcast_op(self)
19 };
20 Ok(Tensor(Arc::new(ndarry_)))
21 }
22}
23
/// Shared expansion behind `broadcast_binary_op!` and `broadcast_cmp_op!`.
///
/// Generates `pub fn $fn_name(&self, rhs: &Self) -> Result<$ret>`. The
/// generated method computes the common broadcast shape of both operands,
/// broadcasts whichever operand(s) differ from it, and then delegates to the
/// same-shape element-wise method `$inner_fn_name`.
macro_rules! broadcast_op_impl {
    ($fn_name:ident, $inner_fn_name:ident, $ret:ty) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<$ret> {
            // Fast path: identical shapes need neither a broadcast-shape
            // computation nor any broadcast views.
            if self.shape() == rhs.shape() {
                return self.$inner_fn_name(rhs);
            }
            let lhs = self;
            // The public op name is forwarded, presumably so the shape error
            // can report which broadcast operation failed.
            let shape = lhs
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            // Only operands whose shape differs from the target get a
            // broadcast view; the left one is broadcast first.
            let l_broadcast = shape != *lhs.shape();
            let r_broadcast = shape != *rhs.shape();
            match (l_broadcast, r_broadcast) {
                (true, true) => lhs
                    .broadcast_as(&shape)?
                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
                // Should be unreachable after the fast path (a target equal to
                // both input shapes implies the inputs were equal); kept so the
                // match stays exhaustive without panicking.
                (false, false) => lhs.$inner_fn_name(rhs),
            }
        }
    };
}

/// Defines a broadcasting element-wise method `$fn_name` that returns
/// `Result<Self>` and delegates to the same-shape method `$inner_fn_name`.
macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        broadcast_op_impl!($fn_name, $inner_fn_name, Self);
    };
}

/// Defines a broadcasting comparison method `$fn_name` that returns
/// `Result<Tensor<bool>>` and delegates to the same-shape method
/// `$inner_fn_name`.
macro_rules! broadcast_cmp_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        broadcast_op_impl!($fn_name, $inner_fn_name, Tensor<bool>);
    };
}
71
72impl<T: NumDType> Tensor<T> {
73 broadcast_binary_op!(broadcast_add, add);
74 broadcast_binary_op!(broadcast_mul, mul);
75 broadcast_binary_op!(broadcast_sub, sub);
76 broadcast_binary_op!(broadcast_div, div);
77 broadcast_binary_op!(broadcast_maximum, maximum);
78 broadcast_binary_op!(broadcast_minimum, minimum);
79 broadcast_cmp_op!(broadcast_eq, eq);
80 broadcast_cmp_op!(broadcast_ne, ne);
81 broadcast_cmp_op!(broadcast_lt, lt);
82 broadcast_cmp_op!(broadcast_le, le);
83 broadcast_cmp_op!(broadcast_gt, gt);
84 broadcast_cmp_op!(broadcast_ge, ge);
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_broadcast_add_scalar() {
        let a = Tensor::new(&[1., 2., 3.]).unwrap();
        let b = Tensor::new(&[10.]).unwrap();
        let res = a.broadcast_add(&b).unwrap();

        let expected = Tensor::new(&[11., 12., 13.]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_add_vector_to_matrix() {
        let a = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]]).unwrap();
        let b = Tensor::new(&[10., 20., 30.]).unwrap();
        let res = a.broadcast_add(&b).unwrap();

        let expected = Tensor::new(&[[11., 22., 33.], [14., 25., 36.]]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_add_both_operands() {
        // (2, 1) with (3,) -> (2, 3): both sides need a broadcast view.
        let col = Tensor::new(&[[1.], [2.]]).unwrap();
        let row = Tensor::new(&[10., 20., 30.]).unwrap();
        let res = col.broadcast_add(&row).unwrap();

        let expected = Tensor::new(&[[11., 21., 31.], [12., 22., 32.]]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_sub_scalar_on_left() {
        // The smaller operand is on the left, exercising the lhs-only path.
        let a = Tensor::new(&[10.]).unwrap();
        let b = Tensor::new(&[1., 2., 3.]).unwrap();
        let res = a.broadcast_sub(&b).unwrap();

        let expected = Tensor::new(&[9., 8., 7.]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_mul() {
        let a = Tensor::new(&[[1., 2.], [3., 4.]]).unwrap();
        let b = Tensor::new(&[10.]).unwrap();
        let res = a.broadcast_mul(&b).unwrap();

        let expected = Tensor::new(&[[10., 20.], [30., 40.]]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_maximum_minimum() {
        // Equal shapes: exercises the no-broadcast fast path.
        let a = Tensor::new(&[1., 5., 3.]).unwrap();
        let b = Tensor::new(&[2., 2., 2.]).unwrap();

        let max_res = a.broadcast_maximum(&b).unwrap();
        let min_res = a.broadcast_minimum(&b).unwrap();

        let expected_max = Tensor::new(&[2., 5., 3.]).unwrap();
        let expected_min = Tensor::new(&[1., 2., 2.]).unwrap();

        assert!(max_res.allclose(&expected_max, 1e-6, 1e-6).unwrap());
        assert!(min_res.allclose(&expected_min, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_maximum_minimum_with_scalar() {
        // Same expectations as above, but the rhs actually has to broadcast.
        let a = Tensor::new(&[1., 5., 3.]).unwrap();
        let b = Tensor::new(&[2.]).unwrap();

        let max_res = a.broadcast_maximum(&b).unwrap();
        let min_res = a.broadcast_minimum(&b).unwrap();

        let expected_max = Tensor::new(&[2., 5., 3.]).unwrap();
        let expected_min = Tensor::new(&[1., 2., 2.]).unwrap();

        assert!(max_res.allclose(&expected_max, 1e-6, 1e-6).unwrap());
        assert!(min_res.allclose(&expected_min, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_comparisons() {
        let a = Tensor::new(&[1., 2., 3.]).unwrap();
        let b = Tensor::new(&[2.]).unwrap();

        let eq = a.broadcast_eq(&b).unwrap();
        let lt = a.broadcast_lt(&b).unwrap();
        let le = a.broadcast_le(&b).unwrap();
        let gt = a.broadcast_gt(&b).unwrap();
        let ge = a.broadcast_ge(&b).unwrap();
        let ne = a.broadcast_ne(&b).unwrap();

        assert_eq!(eq.to_vec().unwrap(), [false, true, false]);
        assert_eq!(lt.to_vec().unwrap(), [true, false, false]);
        assert_eq!(le.to_vec().unwrap(), [true, true, false]);
        assert_eq!(gt.to_vec().unwrap(), [false, false, true]);
        assert_eq!(ge.to_vec().unwrap(), [false, true, true]);
        assert_eq!(ne.to_vec().unwrap(), [true, false, true]);
    }

    #[test]
    fn test_broadcast_div() {
        let a = Tensor::new(&[10., 20., 30.]).unwrap();
        let b = Tensor::new(&[10.]).unwrap();
        let res = a.broadcast_div(&b).unwrap();

        let expected = Tensor::new(&[1., 2., 3.]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    #[test]
    fn test_broadcast_incompatible_shapes() {
        // Trailing dims 3 and 2 cannot be broadcast together.
        let a = Tensor::new(&[1., 2., 3.]).unwrap();
        let b = Tensor::new(&[1., 2.]).unwrap();

        assert!(a.broadcast_add(&b).is_err());
        assert!(a.broadcast_lt(&b).is_err());
    }
}