1use crate::prelude_dev::*;
7
/// How a single axis of one operand is treated when two shapes are broadcast
/// together (see `broadcast_shape` and `update_layout_by_shape`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastType {
    /// The operand has this axis with length 1 while the other operand does
    /// not; the axis is stretched and `update_layout_by_shape` sets its
    /// stride to 0.
    Upcast,
    /// The operand does not have this axis at all (it has fewer dimensions);
    /// the axis is inserted and `update_layout_by_shape` gives it stride 0.
    Expand,
    /// The axis already has the broadcast length and keeps its stride.
    Preserve,
    /// Placeholder used to initialise the per-axis vectors in
    /// `broadcast_shape`; on success every entry is overwritten.
    Undefined,
}
15
16pub fn broadcast_shape<D1, D2, D>(
22 shape1: &D1,
23 shape2: &D2,
24 order: FlagOrder,
25) -> Result<(D, Vec<BroadcastType>, Vec<BroadcastType>)>
26where
27 D1: DimBaseAPI + DimMaxAPI<D2, Max = D>,
28 D2: DimBaseAPI,
29 D: DimBaseAPI,
30{
31 let mut shape1: Vec<usize> = shape1.clone().into();
33 let mut shape2: Vec<usize> = shape2.clone().into();
34 if order == ColMajor {
35 shape1.reverse();
36 shape2.reverse();
37 };
38 let (n1, n2) = (shape1.ndim(), shape2.ndim());
40 let n = usize::max(n1, n2);
41 let mut shape = vec![0; n];
43 let mut tp1 = vec![BroadcastType::Undefined; n];
44 let mut tp2 = vec![BroadcastType::Undefined; n];
45 for i in (0..n).rev() {
47 let in1 = (n1 + i) as isize - n as isize;
48 let in2 = (n2 + i) as isize - n as isize;
49
50 let d1 = if in1 >= 0 { shape1[in1 as usize] } else { 1 };
51 let d2 = if in2 >= 0 { shape2[in2 as usize] } else { 1 };
52
53 match (d1 == 1, d2 == 1) {
54 (true, true) => {
55 tp1[i] = BroadcastType::Preserve;
56 tp2[i] = BroadcastType::Preserve;
57 shape[i] = 1;
58 },
59 (false, true) => {
60 tp1[i] = BroadcastType::Preserve;
61 tp2[i] = BroadcastType::Upcast;
62 shape[i] = d1;
63 },
64 (true, false) => {
65 tp1[i] = BroadcastType::Upcast;
66 tp2[i] = BroadcastType::Preserve;
67 shape[i] = d2;
68 },
69 (false, false) => {
70 rstsr_assert_eq!(d1, d2, InvalidLayout, "Broadcasting failed.")?;
71 tp1[i] = BroadcastType::Preserve;
72 tp2[i] = BroadcastType::Preserve;
73 shape[i] = d1;
74 },
75 }
76
77 if in1 < 0 {
78 tp1[i] = BroadcastType::Expand;
79 }
80 if in2 < 0 {
81 tp2[i] = BroadcastType::Expand;
82 }
83 }
84 if order == ColMajor {
86 shape.reverse();
87 tp1.reverse();
88 tp2.reverse();
89 }
90 let shape = TryInto::<D>::try_into(shape);
92 let shape = shape.map_err(|_| rstsr_error!(InvalidLayout, "Type cast error."))?;
93
94 return Ok((shape, tp1, tp2));
95}
96
97pub trait DimBroadcastableAPI: DimBaseAPI {
98 fn broadcastable_from<D2>(&self, other: &D2) -> bool
102 where
103 D2: DimBaseAPI,
104 {
105 let (shape1, shape2) = (self.as_ref(), other.as_ref());
106 let (n1, n2) = (shape1.len(), shape2.len());
107 let n = usize::max(n1, n2);
108 if n != n1 {
109 return false;
110 }
111 for i in (0..n).rev() {
112 let in1 = (n1 + i) as isize - n as isize;
113 let in2 = (n2 + i) as isize - n as isize;
114
115 let d1 = if in1 >= 0 { shape1[in1 as usize] } else { 1 };
116 let d2 = if in2 >= 0 { shape2[in2 as usize] } else { 1 };
117
118 if d1 != d2 && d2 != 1 {
119 return false;
120 }
121 }
122 return true;
123 }
124
125 fn broadcastable_to<D2>(&self, other: &D2) -> bool
129 where
130 D2: DimBaseAPI,
131 {
132 let (shape1, shape2) = (self.as_ref(), other.as_ref());
133 let (n1, n2) = (shape1.len(), shape2.len());
134 let n = usize::max(n1, n2);
135 if n != n2 {
136 return false;
137 }
138 for i in (0..n).rev() {
139 let in1 = (n1 + i) as isize - n as isize;
140 let in2 = (n2 + i) as isize - n as isize;
141
142 let d1 = if in1 >= 0 { shape1[in1 as usize] } else { 1 };
143 let d2 = if in2 >= 0 { shape2[in2 as usize] } else { 1 };
144
145 if d1 != d2 && d1 != 1 {
146 return false;
147 }
148 }
149 return true;
150 }
151}
152
153impl<D> DimBroadcastableAPI for D where D: DimAPI {}
154
155pub fn broadcast_layout<D1, D2, D>(
167 layout1: &Layout<D1>,
168 layout2: &Layout<D2>,
169 order: FlagOrder,
170) -> Result<(Layout<D>, Layout<D>)>
171where
172 D1: DimDevAPI + DimMaxAPI<D2, Max = D>,
173 D2: DimDevAPI,
174 D: DimDevAPI,
175{
176 let shape1 = layout1.shape();
177 let shape2 = layout2.shape();
178 let (shape, tp1, tp2) = broadcast_shape(shape1, shape2, order)?;
179 let layout1 = update_layout_by_shape(layout1, &shape, &tp1, order)?;
180 let layout2 = update_layout_by_shape(layout2, &shape, &tp2, order)?;
181 return Ok((layout1, layout2));
182}
183
184pub fn broadcast_layout_to_first<D1, D2, D>(
192 layout1: &Layout<D1>,
193 layout2: &Layout<D2>,
194 order: FlagOrder,
195) -> Result<(Layout<D1>, Layout<D1>)>
196where
197 D1: DimDevAPI + DimMaxAPI<D2, Max = D>,
198 D2: DimDevAPI,
199 D: DimIntoAPI<D1> + DimDevAPI,
200{
201 let (layout1, layout2) = broadcast_layout(layout1, layout2, order)?;
202 let layout1 = layout1.into_dim::<D1>()?;
203 let layout2 = layout2.into_dim::<D1>()?;
204 return Ok((layout1, layout2));
205}
206
207pub fn update_layout_by_shape<D, DMax>(
208 layout: &Layout<D>,
209 shape: &DMax,
210 broadcast_type: &[BroadcastType],
211 order: FlagOrder,
212) -> Result<Layout<DMax>>
213where
214 D: DimDevAPI,
215 DMax: DimDevAPI,
216{
217 if order == ColMajor {
219 let mut shape: IxD = shape.clone().into();
220 shape.reverse();
221 let shape: DMax = unsafe { shape.try_into().unwrap_unchecked() };
222 let mut broadcast_type = broadcast_type.to_vec();
223 broadcast_type.reverse();
224 let layout = layout.reverse_axes();
225 let result = update_layout_by_shape(&layout, &shape, &broadcast_type, RowMajor);
226 return result.map(|layout| layout.reverse_axes());
227 }
228 assert_eq!(order, RowMajor);
229 let n_old = layout.ndim();
230 let stride_old = layout.stride();
231 let n = shape.ndim();
232 let mut stride = vec![0; n];
233 stride[n - n_old..n].copy_from_slice(stride_old.as_ref());
234 for i in 0..n {
235 match broadcast_type[i] {
236 BroadcastType::Expand | BroadcastType::Upcast => {
237 stride[i] = 0;
238 },
239 _ => {},
240 }
241 }
242 let stride = stride.try_into();
243 let stride = stride.map_err(|_| rstsr_error!(InvalidLayout, "Type cast error."))?;
244 unsafe { Ok(Layout::new_unchecked(shape.clone(), stride, layout.offset())) }
245}
246
247impl<D> Layout<D>
248where
249 D: DimBaseAPI,
250{
251 pub fn size_non_broadcast(&self) -> usize {
256 if self.size() == 0 {
257 return 0;
258 }
259 let mut size = 1;
260 for i in 0..self.ndim() {
261 if self.stride[i] != 0 {
262 size *= self.shape[i];
263 }
264 }
265 return size;
266 }
267
268 pub fn is_broadcasted(&self) -> bool {
272 self.stride().as_ref().contains(&0)
273 }
274}
275
#[cfg(test)]
mod test {
    use super::*;
    use BroadcastType::*;

    #[test]
    fn test_broadcast_shape() {
        let shape1 = [8, 1, 6, 1];
        let shape2 = [7, 1, 5];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(!shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [8, 7, 6, 5]);
        assert_eq!(broadcast.1, [Preserve, Upcast, Preserve, Upcast]);
        assert_eq!(broadcast.2, [Expand, Preserve, Upcast, Preserve]);
        let shape1 = [5, 4];
        let shape2 = [1];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [5, 4]);
        assert_eq!(broadcast.1, [Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Upcast]);
        let shape1 = [5, 4];
        let shape2 = [4];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [5, 4]);
        assert_eq!(broadcast.1, [Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve]);
        let shape1 = [15, 3, 5];
        let shape2 = [15, 1, 5];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Upcast, Preserve]);
        let shape1 = [15, 3, 5];
        let shape2 = [3, 5];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve, Preserve]);
        let shape1 = [15, 3, 5];
        let shape2 = [3, 1];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve, Upcast]);

        let shape1 = [1, 1, 2];
        let shape2 = [1, 2];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [1, 1, 2]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve, Preserve]);

        let shape1 = [1, 2];
        let shape2 = [1, 1, 2];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(!shape1.broadcastable_from(&shape2));
        assert!(shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [1, 1, 2]);
        assert_eq!(broadcast.1, [Expand, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Preserve, Preserve]);
    }

    #[test]
    fn test_broadcast_shape_col_major() {
        // Column-major broadcasting aligns shapes on their leading axes; this
        // is the row-major case [8, 1, 6, 1] x [7, 1, 5] with axes reversed.
        let shape1 = [1, 6, 1, 8];
        let shape2 = [5, 1, 7];
        let broadcast = broadcast_shape(&shape1, &shape2, ColMajor).unwrap();
        assert_eq!(broadcast.0, [5, 6, 7, 8]);
        assert_eq!(broadcast.1, [Upcast, Preserve, Upcast, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Upcast, Preserve, Expand]);

        // Compatible only under leading-axis alignment.
        let shape1 = [3, 2];
        let shape2 = [3];
        assert!(broadcast_shape(&shape1, &shape2, RowMajor).is_err());
        let broadcast = broadcast_shape(&shape1, &shape2, ColMajor).unwrap();
        assert_eq!(broadcast.0, [3, 2]);
        assert_eq!(broadcast.1, [Preserve, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Expand]);
    }

    #[test]
    fn test_broadcast_shape_fail() {
        let shape1 = [3];
        let shape2 = [4];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor);
        assert!(broadcast.is_err());
        let shape1 = [2, 1];
        let shape2 = [8, 4, 3];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor);
        assert!(broadcast.is_err());
        let shape1 = [15, 3, 5];
        let shape2 = [15, 3];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor);
        assert!(broadcast.is_err());
    }

    #[test]
    fn test_broadcast_layout() {
        let shape1 = [8, 1, 6, 3, 1];
        let shape2 = [7, 1, 3, 5];
        let layout1 = shape1.c();
        let layout2 = shape2.f();
        let (layout1, layout2) = broadcast_layout(&layout1, &layout2, RowMajor).unwrap();
        assert_eq!(layout1.shape(), &[8, 7, 6, 3, 5]);
        assert_eq!(layout2.shape(), &[8, 7, 6, 3, 5]);
        assert_eq!(layout1.stride(), &[18, 0, 3, 1, 0]);
        assert_eq!(layout2.stride(), &[0, 1, 0, 7, 21]);
    }

    #[test]
    fn test_broadcast_layout_col_major() {
        // Axis-reversed mirror of `test_broadcast_layout`: the missing axis of
        // `shape2` is now trailing, and strides come out reversed.
        let shape1 = [1, 3, 6, 1, 8];
        let shape2 = [5, 3, 1, 7];
        let layout1 = shape1.f();
        let layout2 = shape2.c();
        let (layout1, layout2) = broadcast_layout(&layout1, &layout2, ColMajor).unwrap();
        assert_eq!(layout1.shape(), &[5, 3, 6, 7, 8]);
        assert_eq!(layout2.shape(), &[5, 3, 6, 7, 8]);
        assert_eq!(layout1.stride(), &[0, 1, 3, 0, 18]);
        assert_eq!(layout2.stride(), &[21, 7, 0, 1, 0]);
    }
}