//! Broadcasting support for `Tensor` (lumen_core/tensor/broadcast.rs).

use std::sync::Arc;
use crate::{AutogradMetaT, NumDType, Result, Shape, WithDType};
use super::{Tensor, TensorId, TensorImpl};

5impl<T: WithDType> Tensor<T> {
6    /// Broadcast the input tensor to the target shape. This returns an error if the input shape is
7    /// not compatible with the target shape.
8    ///
9    /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
10    /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have
11    /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
12    /// `i_a` is equal to 1, any value can be used.
13    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
14        let ndarry_ = TensorImpl {
15            id: TensorId::new(),
16            storage: self.0.storage.clone(),
17            layout: self.layout().broadcast_as(shape)?,
18            meta: T::AutogradMeta::on_broadcast_op(self)
19        };
20        Ok(Tensor(Arc::new(ndarry_)))
21    }
22}
23
macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        /// Broadcasting wrapper around the corresponding same-shape
        /// element-wise operation: both operands are first broadcast to a
        /// common shape, then combined.
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            // Fast path: identical shapes need no broadcasting at all.
            if self.shape() == rhs.shape() {
                return self.$inner_fn_name(rhs);
            }
            let target = self
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            let needs_lhs = *self.shape() != target;
            let needs_rhs = *rhs.shape() != target;
            // Broadcast only the operand(s) whose shape differs from the
            // common target shape.
            match (needs_lhs, needs_rhs) {
                (false, false) => self.$inner_fn_name(rhs),
                (false, true) => self.$inner_fn_name(&rhs.broadcast_as(&target)?),
                (true, false) => self.broadcast_as(&target)?.$inner_fn_name(rhs),
                (true, true) => self
                    .broadcast_as(&target)?
                    .$inner_fn_name(&rhs.broadcast_as(&target)?),
            }
        }
    };
}

macro_rules! broadcast_cmp_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        /// Broadcasting wrapper around the corresponding same-shape
        /// element-wise comparison; yields a `Tensor<bool>` with the common
        /// broadcast shape.
        pub fn $fn_name(&self, rhs: &Self) -> Result<Tensor<bool>> {
            // Fast path: identical shapes need no broadcasting at all.
            if self.shape() == rhs.shape() {
                return self.$inner_fn_name(rhs);
            }
            let target = self
                .shape()
                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
            let needs_lhs = *self.shape() != target;
            let needs_rhs = *rhs.shape() != target;
            // Broadcast only the operand(s) whose shape differs from the
            // common target shape.
            match (needs_lhs, needs_rhs) {
                (false, false) => self.$inner_fn_name(rhs),
                (false, true) => self.$inner_fn_name(&rhs.broadcast_as(&target)?),
                (true, false) => self.broadcast_as(&target)?.$inner_fn_name(rhs),
                (true, true) => self
                    .broadcast_as(&target)?
                    .$inner_fn_name(&rhs.broadcast_as(&target)?),
            }
        }
    };
}

impl<T: NumDType> Tensor<T> {
    // Arithmetic element-wise ops with broadcasting; each `broadcast_*`
    // method aligns the operand shapes and then delegates to the same-shape
    // inner op named in the second macro argument.
    broadcast_binary_op!(broadcast_add, add);
    broadcast_binary_op!(broadcast_mul, mul);
    broadcast_binary_op!(broadcast_sub, sub);
    broadcast_binary_op!(broadcast_div, div);
    broadcast_binary_op!(broadcast_maximum, maximum);
    broadcast_binary_op!(broadcast_minimum, minimum);
    // Comparison ops with broadcasting; these return `Tensor<bool>`.
    broadcast_cmp_op!(broadcast_eq, eq);
    broadcast_cmp_op!(broadcast_ne, ne);
    broadcast_cmp_op!(broadcast_lt, lt);
    broadcast_cmp_op!(broadcast_le, le);
    broadcast_cmp_op!(broadcast_gt, gt);
    broadcast_cmp_op!(broadcast_ge, ge);
}

#[cfg(test)]
mod tests {
    use super::*;

    // (1,) operand broadcast across a (3,) vector.
    #[test]
    fn test_broadcast_add_scalar() {
        let a = Tensor::new(&[1., 2., 3.]).unwrap();
        let b = Tensor::new(&[10.]).unwrap(); // scalar-like, shape (1,)
        let res = a.broadcast_add(&b).unwrap();

        let expected = Tensor::new(&[11., 12., 13.]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    // (3,) row vector broadcast across each row of a (2,3) matrix.
    #[test]
    fn test_broadcast_add_vector_to_matrix() {
        let a = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]]).unwrap(); // (2,3)
        let b = Tensor::new(&[10., 20., 30.]).unwrap(); // (3,)
        let res = a.broadcast_add(&b).unwrap();

        let expected = Tensor::new(&[[11., 22., 33.], [14., 25., 36.]]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    // Multiplication with a (1,) operand against a (2,2) matrix.
    #[test]
    fn test_broadcast_mul() {
        let a = Tensor::new(&[[1., 2.], [3., 4.]]).unwrap(); // (2,2)
        let b = Tensor::new(&[10.]).unwrap(); // (1,)
        let res = a.broadcast_mul(&b).unwrap();

        let expected = Tensor::new(&[[10., 20.], [30., 40.]]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }

    // Element-wise max/min; shapes already match so this exercises the
    // same-shape fast path.
    #[test]
    fn test_broadcast_maximum_minimum() {
        let a = Tensor::new(&[1., 5., 3.]).unwrap();
        let b = Tensor::new(&[2., 2., 2.]).unwrap();

        let max_res = a.broadcast_maximum(&b).unwrap();
        let min_res = a.broadcast_minimum(&b).unwrap();

        let expected_max = Tensor::new(&[2., 5., 3.]).unwrap();
        let expected_min = Tensor::new(&[1., 2., 2.]).unwrap();

        assert!(max_res.allclose(&expected_max, 1e-6, 1e-6).unwrap());
        assert!(min_res.allclose(&expected_min, 1e-6, 1e-6).unwrap());
    }

    // All six comparison ops, broadcasting a (1,) operand; results are
    // boolean tensors.
    #[test]
    fn test_broadcast_comparisons() {
        let a = Tensor::new(&[1., 2., 3.]).unwrap();
        let b = Tensor::new(&[2.]).unwrap();

        let eq = a.broadcast_eq(&b).unwrap();
        let lt = a.broadcast_lt(&b).unwrap();
        let le = a.broadcast_le(&b).unwrap();
        let gt = a.broadcast_gt(&b).unwrap();
        let ge = a.broadcast_ge(&b).unwrap();
        let ne = a.broadcast_ne(&b).unwrap();

        assert_eq!(eq.to_vec().unwrap(), [false, true, false]);
        assert_eq!(lt.to_vec().unwrap(), [true, false, false]);
        assert_eq!(le.to_vec().unwrap(), [true, true, false]);
        assert_eq!(gt.to_vec().unwrap(), [false, false, true]);
        assert_eq!(ge.to_vec().unwrap(), [false, true, true]);
        assert_eq!(ne.to_vec().unwrap(), [true, false, true]);
    }

    // Division with a broadcast (1,) divisor.
    #[test]
    fn test_broadcast_div() {
        let a = Tensor::new(&[10., 20., 30.]).unwrap();
        let b = Tensor::new(&[10.]).unwrap();
        let res = a.broadcast_div(&b).unwrap();

        let expected = Tensor::new(&[1., 2., 3.]).unwrap();
        assert!(res.allclose(&expected, 1e-6, 1e-6).unwrap());
    }
}