// runmat_runtime/builtins/common/broadcast.rs

/// Computes the shape obtained by broadcasting (implicitly expanding) `left` against `right`.
///
/// Dimensions are aligned starting from the first (fastest-varying, column-major) axis and
/// any missing trailing dimensions count as singletons, so `[2, 3]` broadcasts against
/// `[2, 3, 4]` to give `[2, 3, 4]`. In each dimension the two extents must be equal or one of
/// them must be `1`; the result then takes the other extent.
///
/// A `0` extent paired with any other non-singleton extent yields `0` rather than an error.
/// NOTE(review): `BroadcastPlan::new` rejects that pairing — confirm which rule callers expect.
///
/// # Errors
///
/// Returns `"{fn_name}: size mismatch between inputs (dimension D has lengths A and B)"`,
/// with `D` 1-based, when two non-singleton, non-zero extents differ.
pub fn broadcast_shapes(
    fn_name: &str,
    left: &[usize],
    right: &[usize],
) -> Result<Vec<usize>, String> {
    let rank = left.len().max(right.len());
    (0..rank)
        .map(|dim| {
            // Column-major arrays carry implicit *trailing* singleton dimensions (a 2x3
            // matrix is a 2x3x1 array), so the shorter shape is padded on the right — the
            // same convention `broadcast_index` uses for its `in_shape`.
            let a = left.get(dim).copied().unwrap_or(1);
            let b = right.get(dim).copied().unwrap_or(1);
            match (a, b) {
                _ if a == b => Ok(a),
                (1, _) => Ok(b),
                (_, 1) => Ok(a),
                (0, _) | (_, 0) => Ok(0),
                _ => Err(format!(
                    "{fn_name}: size mismatch between inputs (dimension {} has lengths {} and {})",
                    dim + 1,
                    a,
                    b
                )),
            }
        })
        .collect()
}

/// Returns column-major strides for `shape`: the first dimension has stride 1 and every
/// later stride is the product of the extents before it.
///
/// Zero-length dimensions are counted as length 1 while accumulating, and the running
/// product saturates at `usize::MAX` instead of overflowing.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &extent| {
            let this_stride = *running;
            *running = running.saturating_mul(extent.max(1));
            Some(this_stride)
        })
        .collect()
}

/// Maps a linear index into the broadcast output to the matching linear offset into one input.
///
/// * `linear` — index into an array of shape `out_shape`, decomposed with the first
///   dimension varying fastest (column-major).
/// * `in_shape` — the input's shape; dimensions past its end are treated as singletons.
/// * `strides` — per-dimension strides of the input (e.g. `compute_strides(in_shape)`);
///   a dimension with no stride entry contributes nothing to the offset.
///
/// Singleton input dimensions always map to coordinate 0, so their data is repeated across
/// the output. An empty `in_shape` is treated as a scalar and always maps to offset 0.
pub fn broadcast_index(
    linear: usize,
    out_shape: &[usize],
    in_shape: &[usize],
    strides: &[usize],
) -> usize {
    if in_shape.is_empty() {
        return 0;
    }

    let mut remaining = linear;
    let mut offset = 0usize;
    for (dim, &out_extent) in out_shape.iter().enumerate() {
        // An empty output dimension has no coordinate to peel off and adds nothing.
        if out_extent == 0 {
            continue;
        }
        let coord = remaining % out_extent;
        remaining /= out_extent;

        // Broadcast (singleton) input dimensions always read coordinate 0.
        if in_shape.get(dim).copied().unwrap_or(1) == 1 {
            continue;
        }
        if let Some(&stride) = strides.get(dim) {
            offset += coord * stride;
        }
    }
    offset
}

/// Precomputed plan for walking two column-major operands under implicit expansion.
///
/// Shapes are aligned starting at the first dimension and any missing trailing dimensions
/// are treated as singletons (`[2, 3]` behaves like `[2, 3, 1]`). In every dimension the two
/// extents must be equal or one of them must be `1`; a `1` is broadcast against the other
/// extent (which may be `0`).
#[derive(Debug, Clone)]
pub struct BroadcastPlan {
    /// Shape of the broadcast result; its rank is the larger of the two input ranks.
    output_shape: Vec<usize>,
    /// Number of elements in `output_shape`.
    len: usize,
    /// Per-dimension step through operand A's column-major buffer; `0` where A is singleton.
    advance_a: Vec<usize>,
    /// Per-dimension step through operand B's column-major buffer; `0` where B is singleton.
    advance_b: Vec<usize>,
}

impl BroadcastPlan {
    /// Builds the plan for broadcasting `shape_a` against `shape_b`.
    ///
    /// # Errors
    ///
    /// * A dimension whose extents differ while neither is `1` (this includes `0` paired
    ///   with an extent greater than `1`).
    /// * The number of output elements does not fit in `usize`.
    pub fn new(shape_a: &[usize], shape_b: &[usize]) -> Result<Self, String> {
        let ndims = shape_a.len().max(shape_b.len());
        // Column-major arrays carry implicit *trailing* singleton dimensions, so shorter
        // shapes are padded on the right (the same convention `broadcast_index` uses).
        let ext_a = Self::pad_trailing(shape_a, ndims);
        let ext_b = Self::pad_trailing(shape_b, ndims);

        let output_shape = ext_a
            .iter()
            .zip(&ext_b)
            .enumerate()
            .map(|(i, (&da, &db))| {
                if da == db || db == 1 {
                    Ok(da)
                } else if da == 1 {
                    Ok(db)
                } else {
                    Err(format!(
                        "broadcast: non-singleton dimension mismatch (dimension {}: {} vs {})",
                        i + 1,
                        da,
                        db
                    ))
                }
            })
            .collect::<Result<Vec<_>, String>>()?;

        // Any empty dimension empties the result; check it first so a huge extent elsewhere
        // cannot produce a spurious overflow error. Otherwise refuse to wrap around.
        let len = if output_shape.contains(&0) {
            0
        } else {
            output_shape
                .iter()
                .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
                .ok_or_else(|| {
                    format!("broadcast: output shape {output_shape:?} has too many elements")
                })?
        };

        Ok(Self {
            advance_a: Self::advances(&ext_a),
            advance_b: Self::advances(&ext_b),
            output_shape,
            len,
        })
    }

    /// Copies `shape` and appends singleton dimensions until it has `ndims` entries.
    fn pad_trailing(shape: &[usize], ndims: usize) -> Vec<usize> {
        let mut ext = Vec::with_capacity(ndims);
        ext.extend_from_slice(shape);
        ext.resize(ndims, 1);
        ext
    }

    /// Column-major stride of each dimension of `ext`, zeroed where the extent is 0 or 1 so
    /// that broadcast dimensions keep revisiting the same elements.
    fn advances(ext: &[usize]) -> Vec<usize> {
        let mut stride = 1usize;
        ext.iter()
            .map(|&extent| {
                let step = if extent <= 1 { 0 } else { stride };
                stride = stride.saturating_mul(extent.max(1));
                step
            })
            .collect()
    }

    /// Number of elements in the broadcast output.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the broadcast output has no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Shape of the broadcast output.
    pub fn output_shape(&self) -> &[usize] {
        &self.output_shape
    }

    /// Iterates `(output_index, index_a, index_b)` over every output element in
    /// column-major order.
    pub fn iter(&self) -> BroadcastIter<'_> {
        BroadcastIter {
            plan: self,
            offset: 0,
            index_a: 0,
            index_b: 0,
            coords: vec![0usize; self.output_shape.len()],
        }
    }
}

/// Iterator over a [`BroadcastPlan`], yielding `(output_index, index_a, index_b)`.
///
/// Output elements are visited in column-major order (first dimension fastest); `index_a`
/// and `index_b` are the matching linear offsets into each operand's buffer.
#[derive(Debug, Clone)]
pub struct BroadcastIter<'a> {
    plan: &'a BroadcastPlan,
    /// Linear index of the next output element to yield.
    offset: usize,
    /// Offset into operand A corresponding to `offset`.
    index_a: usize,
    /// Offset into operand B corresponding to `offset`.
    index_b: usize,
    /// Coordinates of `offset` within the output shape.
    coords: Vec<usize>,
}

impl Iterator for BroadcastIter<'_> {
    type Item = (usize, usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        let plan = self.plan;
        if self.offset >= plan.len {
            return None;
        }
        let current = (self.offset, self.index_a, self.index_b);
        self.offset += 1;
        if self.offset == plan.len {
            // Last element: there is no successor to advance the coordinates to.
            return Some(current);
        }
        // Odometer increment over the output coordinates, first dimension fastest.
        for (dim, &extent) in plan.output_shape.iter().enumerate() {
            if extent == 0 {
                continue;
            }
            self.coords[dim] += 1;
            if self.coords[dim] < extent {
                self.index_a += plan.advance_a[dim];
                self.index_b += plan.advance_b[dim];
                break;
            }
            // This dimension wrapped: undo the `extent - 1` steps taken along it, then carry.
            self.coords[dim] = 0;
            let rewind = extent - 1;
            self.index_a = self.index_a.saturating_sub(plan.advance_a[dim] * rewind);
            self.index_b = self.index_b.saturating_sub(plan.advance_b[dim] * rewind);
        }
        Some(current)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `offset` never exceeds `len`, so this is the exact number of items left.
        let remaining = self.plan.len.saturating_sub(self.offset);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for BroadcastIter<'_> {}

impl std::iter::FusedIterator for BroadcastIter<'_> {}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_equal_shapes() {
        assert_eq!(broadcast_shapes("test", &[2, 3], &[2, 3]), Ok(vec![2, 3]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_scalar() {
        assert_eq!(broadcast_shapes("test", &[1, 1], &[4, 5]), Ok(vec![4, 5]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_mismatched_dimension_errors() {
        let message = broadcast_shapes("test", &[2, 3], &[4, 3])
            .expect_err("non-singleton mismatch must be rejected");
        assert!(message.contains("dimension 1"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compute_strides_column_major() {
        assert_eq!(compute_strides(&[2, 3, 4]), [1, 2, 6]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_index_maps_scalar_inputs() {
        let scalar_strides = compute_strides(&[1, 1]);
        assert_eq!(broadcast_index(5, &[2, 3], &[1, 1], &scalar_strides), 0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_same_shape() {
        let plan = BroadcastPlan::new(&[2, 3], &[2, 3]).expect("identical shapes broadcast");
        assert_eq!(plan.output_shape(), &[2, 3]);
        assert_eq!(plan.len(), 6);
        let expected: Vec<(usize, usize, usize)> = (0..6).map(|i| (i, i, i)).collect();
        assert_eq!(plan.iter().collect::<Vec<_>>(), expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_scalar_expansion() {
        let plan = BroadcastPlan::new(&[1, 3], &[1, 1]).expect("scalar operand expands");
        assert_eq!(plan.output_shape(), &[1, 3]);
        assert_eq!(plan.len(), 3);
        let expected: Vec<(usize, usize, usize)> = (0..3).map(|i| (i, i, 0)).collect();
        assert_eq!(plan.iter().collect::<Vec<_>>(), expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn broadcast_zero_sized_dimension() {
        let plan = BroadcastPlan::new(&[0, 3], &[1, 3]).expect("zero extent pairs with one");
        assert_eq!(plan.output_shape(), &[0, 3]);
        assert_eq!(plan.len(), 0);
        assert!(plan.iter().next().is_none());
    }
}