runmat_runtime/builtins/common/
broadcast.rs

1//! Broadcasting utilities shared across builtin implementations.
2//!
3//! The helpers in this module mirror MATLAB's implicit expansion rules and
4//! operate on column-major shapes expressed as `[usize]` vectors.
5
/// Compute the broadcasted shape for two operands using MATLAB implicit
/// expansion rules.
///
/// Shapes are column-major (`shape[0]` is the number of rows). As in MATLAB,
/// an operand with fewer dimensions is treated as having trailing singleton
/// dimensions, so `[2, 3]` combines with `[2, 3, 4]` as if it were `[2, 3, 1]`.
/// For every dimension the two extents must be equal, or one of them must be 1
/// (which expands to the other extent).
///
/// # Errors
///
/// Returns a message prefixed with `fn_name` that names the first (1-based)
/// dimension whose extents cannot be reconciled.
pub fn broadcast_shapes(
    fn_name: &str,
    left: &[usize],
    right: &[usize],
) -> Result<Vec<usize>, String> {
    let rank = left.len().max(right.len());
    // MATLAB pads missing dimensions at the END with 1 (a 2-D array is an
    // N-D array with trailing singleton dimensions). Padding at the front
    // would misalign e.g. `[2, 3]` against `[2, 3, 4]`.
    let extent = |shape: &[usize], dim: usize| shape.get(dim).copied().unwrap_or(1);

    (0..rank)
        .map(|dim| {
            let a = extent(left, dim);
            let b = extent(right, dim);
            if a == b {
                Ok(a)
            } else if a == 1 {
                Ok(b)
            } else if b == 1 {
                Ok(a)
            } else if a == 0 || b == 0 {
                // NOTE(review): MATLAB reports incompatible sizes for e.g. 0 vs 3,
                // and `BroadcastPlan::new` rejects this case too; this helper keeps
                // its existing lenient behavior of yielding an empty extent.
                // Confirm whether that leniency is intended.
                Ok(0)
            } else {
                Err(format!(
                    "{fn_name}: size mismatch between inputs (dimension {} has lengths {} and {})",
                    dim + 1,
                    a,
                    b
                ))
            }
        })
        // Collecting into `Result` stops at the first incompatible dimension.
        .collect()
}
47
/// Compute column-major strides for a given shape.
///
/// The stride of dimension `k` is the product of the extents of dimensions
/// `0..k`. Zero-length extents contribute a factor of 1, and the running
/// product saturates at `usize::MAX` rather than overflowing.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &extent| {
            let this_stride = *running;
            *running = running.saturating_mul(extent.max(1));
            Some(this_stride)
        })
        .collect()
}
58
/// Map a linear index in the broadcasted result back to a source operand.
///
/// `linear` is decomposed into a column-major coordinate over `out_shape`.
/// Each coordinate component is then scaled by the matching entry of `strides`
/// and summed.
///
/// Some components contribute nothing:
/// - dimensions where the input extent is 1 (they are broadcast);
/// - dimensions beyond `in_shape`, whose extent is taken to be 1;
/// - dimensions beyond `strides`.
///
/// An empty `in_shape` always maps to offset 0.
pub fn broadcast_index(
    mut linear: usize,
    out_shape: &[usize],
    in_shape: &[usize],
    strides: &[usize],
) -> usize {
    if in_shape.is_empty() {
        return 0;
    }
    let mut offset = 0usize;
    for (dim, &out_extent) in out_shape.iter().enumerate() {
        // A zero-length output dimension yields coordinate 0 and consumes
        // nothing from the remaining linear index.
        let coord = match out_extent {
            0 => 0,
            extent => {
                let component = linear % extent;
                linear /= extent;
                component
            }
        };
        let in_extent = in_shape.get(dim).copied().unwrap_or(1);
        if in_extent != 1 {
            if let Some(&stride) = strides.get(dim) {
                offset += coord * stride;
            }
        }
    }
    offset
}
92
93/// Broadcast plan describing how two tensors can be implicitly expanded.
94#[derive(Debug, Clone)]
95pub struct BroadcastPlan {
96    output_shape: Vec<usize>,
97    len: usize,
98    advance_a: Vec<usize>,
99    advance_b: Vec<usize>,
100}
101
102impl BroadcastPlan {
103    /// Construct a broadcast plan for two shapes, returning an error when they
104    /// cannot be implicitly expanded under MATLAB rules.
105    pub fn new(shape_a: &[usize], shape_b: &[usize]) -> Result<Self, String> {
106        let ndims = shape_a.len().max(shape_b.len());
107
108        // Pad on the FRONT to align trailing dimensions per MATLAB rules.
109        let mut ext_a = Vec::with_capacity(ndims);
110        ext_a.extend(std::iter::repeat_n(1, ndims.saturating_sub(shape_a.len())));
111        ext_a.extend_from_slice(shape_a);
112
113        let mut ext_b = Vec::with_capacity(ndims);
114        ext_b.extend(std::iter::repeat_n(1, ndims.saturating_sub(shape_b.len())));
115        ext_b.extend_from_slice(shape_b);
116
117        let mut output_shape = Vec::with_capacity(ndims);
118        for i in 0..ndims {
119            let da = ext_a[i];
120            let db = ext_b[i];
121            if da == db {
122                output_shape.push(da);
123            } else if da == 1 {
124                output_shape.push(db);
125            } else if db == 1 {
126                output_shape.push(da);
127            } else {
128                return Err(format!(
129                    "broadcast: non-singleton dimension mismatch (dimension {}: {} vs {})",
130                    i + 1,
131                    da,
132                    db
133                ));
134            }
135        }
136
137        let len = output_shape.iter().copied().product();
138        let strides_a = compute_strides(&ext_a);
139        let strides_b = compute_strides(&ext_b);
140
141        let advance_a = ext_a
142            .iter()
143            .enumerate()
144            .map(|(dim, &size)| if size <= 1 { 0 } else { strides_a[dim] })
145            .collect::<Vec<_>>();
146        let advance_b = ext_b
147            .iter()
148            .enumerate()
149            .map(|(dim, &size)| if size <= 1 { 0 } else { strides_b[dim] })
150            .collect::<Vec<_>>();
151
152        Ok(Self {
153            output_shape,
154            len,
155            advance_a,
156            advance_b,
157        })
158    }
159
160    /// Total number of elements produced by the broadcast.
161    pub fn len(&self) -> usize {
162        self.len
163    }
164
165    /// Returns true if the broadcast produces no elements.
166    pub fn is_empty(&self) -> bool {
167        self.len == 0
168    }
169
170    /// Output shape after broadcasting both operands.
171    pub fn output_shape(&self) -> &[usize] {
172        &self.output_shape
173    }
174
175    /// Iterator yielding `(output_index, index_a, index_b)` triples for each element.
176    pub fn iter(&self) -> BroadcastIter<'_> {
177        BroadcastIter {
178            plan: self,
179            offset: 0,
180            index_a: 0,
181            index_b: 0,
182            coords: vec![0usize; self.output_shape.len()],
183        }
184    }
185}
186
187/// Iterator over broadcast indices.
188pub struct BroadcastIter<'a> {
189    plan: &'a BroadcastPlan,
190    offset: usize,
191    index_a: usize,
192    index_b: usize,
193    coords: Vec<usize>,
194}
195
196impl<'a> Iterator for BroadcastIter<'a> {
197    type Item = (usize, usize, usize);
198
199    fn next(&mut self) -> Option<Self::Item> {
200        if self.offset >= self.plan.len {
201            return None;
202        }
203        let current = (self.offset, self.index_a, self.index_b);
204        self.offset += 1;
205        if self.offset == self.plan.len {
206            return Some(current);
207        }
208        for dim in 0..self.plan.output_shape.len() {
209            if self.plan.output_shape[dim] == 0 {
210                continue;
211            }
212            self.coords[dim] += 1;
213            if self.coords[dim] < self.plan.output_shape[dim] {
214                self.index_a += self.plan.advance_a[dim];
215                self.index_b += self.plan.advance_b[dim];
216                break;
217            }
218            self.coords[dim] = 0;
219            let rewind = self.plan.output_shape[dim].saturating_sub(1);
220            let rewind_a = self.plan.advance_a[dim] * rewind;
221            let rewind_b = self.plan.advance_b[dim] * rewind;
222            if rewind_a != 0 {
223                self.index_a = self.index_a.saturating_sub(rewind_a);
224            }
225            if rewind_b != 0 {
226                self.index_b = self.index_b.saturating_sub(rewind_b);
227            }
228        }
229        Some(current)
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathers every `(output, a, b)` index triple a plan yields.
    fn triples(plan: &BroadcastPlan) -> Vec<(usize, usize, usize)> {
        plan.iter().collect()
    }

    #[test]
    fn broadcast_equal_shapes() {
        let shape = broadcast_shapes("test", &[2, 3], &[2, 3]).unwrap();
        assert_eq!(shape, vec![2, 3]);
    }

    #[test]
    fn broadcast_scalar() {
        let shape = broadcast_shapes("test", &[1, 1], &[4, 5]).unwrap();
        assert_eq!(shape, vec![4, 5]);
    }

    #[test]
    fn broadcast_mismatched_dimension_errors() {
        let message = broadcast_shapes("test", &[2, 3], &[4, 3]).unwrap_err();
        assert!(message.contains("dimension 1"), "unexpected error: {message}");
    }

    #[test]
    fn compute_strides_column_major() {
        assert_eq!(compute_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }

    #[test]
    fn broadcast_index_maps_scalar_inputs() {
        let scalar = [1usize, 1];
        let strides = compute_strides(&scalar);
        assert_eq!(broadcast_index(5, &[2, 3], &scalar, &strides), 0);
    }

    #[test]
    fn broadcast_same_shape() {
        let plan = BroadcastPlan::new(&[2, 3], &[2, 3]).unwrap();
        assert_eq!(plan.output_shape(), &[2, 3]);
        assert_eq!(plan.len(), 6);
        let expected: Vec<(usize, usize, usize)> = (0..6usize).map(|i| (i, i, i)).collect();
        assert_eq!(triples(&plan), expected);
    }

    #[test]
    fn broadcast_scalar_expansion() {
        let plan = BroadcastPlan::new(&[1, 3], &[1, 1]).unwrap();
        assert_eq!(plan.output_shape(), &[1, 3]);
        assert_eq!(plan.len(), 3);
        let expected: Vec<(usize, usize, usize)> = (0..3usize).map(|i| (i, i, 0)).collect();
        assert_eq!(triples(&plan), expected);
    }

    #[test]
    fn broadcast_zero_sized_dimension() {
        let plan = BroadcastPlan::new(&[0, 3], &[1, 3]).unwrap();
        assert_eq!(plan.output_shape(), &[0, 3]);
        assert_eq!(plan.len(), 0);
        assert_eq!(plan.iter().next(), None);
    }
}