// runmat_runtime/builtins/common/broadcast.rs

/// Computes the shape that results from broadcasting `left` against `right`.
///
/// The shorter shape is first padded with leading ones so that both shapes
/// have the same rank. Each pair of extents is then combined as follows:
///
/// * equal extents are kept as-is;
/// * an extent of 1 yields the other extent;
/// * otherwise, if either extent is 0, the result extent is 0.
///
/// # Errors
///
/// Returns a message prefixed with `fn_name` when a pair of extents differs
/// and neither of them is 0 or 1. The reported dimension number is 1-based
/// and refers to the padded shapes.
pub fn broadcast_shapes(
    fn_name: &str,
    left: &[usize],
    right: &[usize],
) -> Result<Vec<usize>, String> {
    let rank = left.len().max(right.len());

    // `rank` is the larger of the two lengths, so neither subtraction can underflow.
    let lhs = std::iter::repeat(1usize)
        .take(rank - left.len())
        .chain(left.iter().copied());
    let rhs = std::iter::repeat(1usize)
        .take(rank - right.len())
        .chain(right.iter().copied());

    // Collecting into `Result` stops at the first incompatible dimension.
    lhs.zip(rhs)
        .enumerate()
        .map(|(axis, extents)| match extents {
            (a, b) if a == b => Ok(a),
            (1, b) => Ok(b),
            (a, 1) => Ok(a),
            (0, _) | (_, 0) => Ok(0),
            (a, b) => Err(format!(
                "{fn_name}: size mismatch between inputs (dimension {} has lengths {} and {})",
                axis + 1,
                a,
                b
            )),
        })
        .collect()
}
47
/// Returns the linear stride of every dimension of `shape`, with the first
/// dimension varying fastest (its stride is always 1).
///
/// Each later stride is the running product of the preceding extents. An
/// extent of 0 is counted as 1 while accumulating, and the product saturates
/// at `usize::MAX` instead of overflowing.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    strides.extend(shape.iter().scan(1usize, |running, &extent| {
        let current = *running;
        *running = current.saturating_mul(extent.max(1));
        Some(current)
    }));
    strides
}
58
/// Maps a linear index into a broadcast output onto a linear offset into one
/// of the inputs.
///
/// `linear` is decomposed into per-dimension coordinates of `out_shape`, with
/// the first dimension varying fastest. Each coordinate is multiplied by the
/// matching entry of `strides` and summed, except that:
///
/// * dimensions whose input extent is 1, or which are missing from
///   `in_shape`, contribute nothing (the input is repeated along them);
/// * dimensions whose output extent is 0 use coordinate 0 and leave `linear`
///   untouched;
/// * dimensions with no entry in `strides` are ignored.
///
/// Returns 0 when `in_shape` is empty.
pub fn broadcast_index(
    mut linear: usize,
    out_shape: &[usize],
    in_shape: &[usize],
    strides: &[usize],
) -> usize {
    if in_shape.is_empty() {
        return 0;
    }

    let mut offset = 0usize;
    for (dim, &out_extent) in out_shape.iter().enumerate() {
        // Peel this dimension's coordinate off the remaining linear index.
        let coord = match out_extent {
            0 => 0,
            extent => {
                let within = linear % extent;
                linear /= extent;
                within
            }
        };

        // A singleton (or absent) input dimension always maps to position 0,
        // so it adds nothing to the offset.
        let in_extent = in_shape.get(dim).copied().unwrap_or(1);
        if in_extent == 1 {
            continue;
        }
        if let Some(&stride) = strides.get(dim) {
            offset += coord * stride;
        }
    }
    offset
}
92
93#[derive(Debug, Clone)]
95pub struct BroadcastPlan {
96 output_shape: Vec<usize>,
97 len: usize,
98 advance_a: Vec<usize>,
99 advance_b: Vec<usize>,
100}
101
102impl BroadcastPlan {
103 pub fn new(shape_a: &[usize], shape_b: &[usize]) -> Result<Self, String> {
106 let ndims = shape_a.len().max(shape_b.len());
107
108 let mut ext_a = Vec::with_capacity(ndims);
110 ext_a.extend(std::iter::repeat_n(1, ndims.saturating_sub(shape_a.len())));
111 ext_a.extend_from_slice(shape_a);
112
113 let mut ext_b = Vec::with_capacity(ndims);
114 ext_b.extend(std::iter::repeat_n(1, ndims.saturating_sub(shape_b.len())));
115 ext_b.extend_from_slice(shape_b);
116
117 let mut output_shape = Vec::with_capacity(ndims);
118 for i in 0..ndims {
119 let da = ext_a[i];
120 let db = ext_b[i];
121 if da == db {
122 output_shape.push(da);
123 } else if da == 1 {
124 output_shape.push(db);
125 } else if db == 1 {
126 output_shape.push(da);
127 } else {
128 return Err(format!(
129 "broadcast: non-singleton dimension mismatch (dimension {}: {} vs {})",
130 i + 1,
131 da,
132 db
133 ));
134 }
135 }
136
137 let len = output_shape.iter().copied().product();
138 let strides_a = compute_strides(&ext_a);
139 let strides_b = compute_strides(&ext_b);
140
141 let advance_a = ext_a
142 .iter()
143 .enumerate()
144 .map(|(dim, &size)| if size <= 1 { 0 } else { strides_a[dim] })
145 .collect::<Vec<_>>();
146 let advance_b = ext_b
147 .iter()
148 .enumerate()
149 .map(|(dim, &size)| if size <= 1 { 0 } else { strides_b[dim] })
150 .collect::<Vec<_>>();
151
152 Ok(Self {
153 output_shape,
154 len,
155 advance_a,
156 advance_b,
157 })
158 }
159
160 pub fn len(&self) -> usize {
162 self.len
163 }
164
165 pub fn is_empty(&self) -> bool {
167 self.len == 0
168 }
169
170 pub fn output_shape(&self) -> &[usize] {
172 &self.output_shape
173 }
174
175 pub fn iter(&self) -> BroadcastIter<'_> {
177 BroadcastIter {
178 plan: self,
179 offset: 0,
180 index_a: 0,
181 index_b: 0,
182 coords: vec![0usize; self.output_shape.len()],
183 }
184 }
185}
186
187pub struct BroadcastIter<'a> {
189 plan: &'a BroadcastPlan,
190 offset: usize,
191 index_a: usize,
192 index_b: usize,
193 coords: Vec<usize>,
194}
195
196impl<'a> Iterator for BroadcastIter<'a> {
197 type Item = (usize, usize, usize);
198
199 fn next(&mut self) -> Option<Self::Item> {
200 if self.offset >= self.plan.len {
201 return None;
202 }
203 let current = (self.offset, self.index_a, self.index_b);
204 self.offset += 1;
205 if self.offset == self.plan.len {
206 return Some(current);
207 }
208 for dim in 0..self.plan.output_shape.len() {
209 if self.plan.output_shape[dim] == 0 {
210 continue;
211 }
212 self.coords[dim] += 1;
213 if self.coords[dim] < self.plan.output_shape[dim] {
214 self.index_a += self.plan.advance_a[dim];
215 self.index_b += self.plan.advance_b[dim];
216 break;
217 }
218 self.coords[dim] = 0;
219 let rewind = self.plan.output_shape[dim].saturating_sub(1);
220 let rewind_a = self.plan.advance_a[dim] * rewind;
221 let rewind_b = self.plan.advance_b[dim] * rewind;
222 if rewind_a != 0 {
223 self.index_a = self.index_a.saturating_sub(rewind_a);
224 }
225 if rewind_b != 0 {
226 self.index_b = self.index_b.saturating_sub(rewind_b);
227 }
228 }
229 Some(current)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathers every `(output, a, b)` triple produced by a plan.
    fn triples(plan: &BroadcastPlan) -> Vec<(usize, usize, usize)> {
        plan.iter().collect()
    }

    #[test]
    fn broadcast_equal_shapes() {
        let shape = broadcast_shapes("test", &[2, 3], &[2, 3]).expect("equal shapes broadcast");
        assert_eq!(shape, [2, 3]);
    }

    #[test]
    fn broadcast_scalar() {
        let shape = broadcast_shapes("test", &[1, 1], &[4, 5]).expect("scalar broadcasts");
        assert_eq!(shape, [4, 5]);
    }

    #[test]
    fn broadcast_mismatched_dimension_errors() {
        let message = broadcast_shapes("test", &[2, 3], &[4, 3])
            .expect_err("2 vs 4 in the first dimension must be rejected");
        assert!(message.contains("dimension 1"));
    }

    #[test]
    fn compute_strides_column_major() {
        assert_eq!(compute_strides(&[2, 3, 4]), [1, 2, 6]);
    }

    #[test]
    fn broadcast_index_maps_scalar_inputs() {
        let scalar_shape = [1, 1];
        let strides = compute_strides(&scalar_shape);
        assert_eq!(broadcast_index(5, &[2, 3], &scalar_shape, &strides), 0);
    }

    #[test]
    fn broadcast_same_shape() {
        let plan = BroadcastPlan::new(&[2, 3], &[2, 3]).expect("identical shapes");
        assert_eq!(plan.output_shape(), &[2, 3]);
        assert_eq!(plan.len(), 6);

        // With identical shapes every index advances in lockstep.
        let expected: Vec<(usize, usize, usize)> = (0..6).map(|i| (i, i, i)).collect();
        assert_eq!(triples(&plan), expected);
    }

    #[test]
    fn broadcast_scalar_expansion() {
        let plan = BroadcastPlan::new(&[1, 3], &[1, 1]).expect("scalar expands");
        assert_eq!(plan.output_shape(), &[1, 3]);
        assert_eq!(plan.len(), 3);
        assert_eq!(triples(&plan), [(0, 0, 0), (1, 1, 0), (2, 2, 0)]);
    }

    #[test]
    fn broadcast_zero_sized_dimension() {
        let plan = BroadcastPlan::new(&[0, 3], &[1, 3]).expect("empty operand broadcasts");
        assert_eq!(plan.output_shape(), &[0, 3]);
        assert_eq!(plan.len(), 0);
        assert!(plan.iter().next().is_none());
    }
}