rstsr_common/layout/
broadcast.rs

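//! Layout broadcasting.
//!
//! Broadcasting semantics follow the Python array API standard:
//! [broadcasting](https://data-apis.org/array-api/2024.12/API_specification/broadcasting.html).
//! For row-major order shapes are aligned at their trailing axes (as in the
//! standard); for column-major order they are aligned at their leading axes.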
4
5// use super::DimMaxAPI;
6use crate::prelude_dev::*;
7
/// How one operand's axis relates to the corresponding axis of a broadcast
/// result.
///
/// [`broadcast_shape`] produces one value per result axis for each operand.
/// [`update_layout_by_shape`] consumes them and gives `Upcast` and `Expand`
/// axes a stride of zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastType {
    /// The operand has this axis with size 1. It is stretched to the other
    /// operand's size on that axis, which is not 1 (it may be 0).
    Upcast,
    /// The operand does not have this axis at all. The axis is added because
    /// the other operand has more dimensions.
    Expand,
    /// The operand's axis already has the result size and is kept as is.
    Preserve,
    /// Placeholder value. [`broadcast_shape`] initializes its output with it
    /// and overwrites every entry before returning.
    Undefined,
}
15
16/// Shape broadcasting.
17///
18/// # See also
19///
20/// [broadcasting](https://data-apis.org/array-api/2024.12/API_specification/broadcasting.html)
21pub fn broadcast_shape<D1, D2, D>(
22    shape1: &D1,
23    shape2: &D2,
24    order: FlagOrder,
25) -> Result<(D, Vec<BroadcastType>, Vec<BroadcastType>)>
26where
27    D1: DimBaseAPI + DimMaxAPI<D2, Max = D>,
28    D2: DimBaseAPI,
29    D: DimBaseAPI,
30{
31    // order: flip if col-major
32    let mut shape1: Vec<usize> = shape1.clone().into();
33    let mut shape2: Vec<usize> = shape2.clone().into();
34    if order == ColMajor {
35        shape1.reverse();
36        shape2.reverse();
37    };
38    // step 1-6: determine maximum shape
39    let (n1, n2) = (shape1.ndim(), shape2.ndim());
40    let n = usize::max(n1, n2);
41    // step 7: declare result shape and corresponding broadcast type
42    let mut shape = vec![0; n];
43    let mut tp1 = vec![BroadcastType::Undefined; n];
44    let mut tp2 = vec![BroadcastType::Undefined; n];
45    // step 8-10: iterate over the shape
46    for i in (0..n).rev() {
47        let in1 = (n1 + i) as isize - n as isize;
48        let in2 = (n2 + i) as isize - n as isize;
49
50        let d1 = if in1 >= 0 { shape1[in1 as usize] } else { 1 };
51        let d2 = if in2 >= 0 { shape2[in2 as usize] } else { 1 };
52
53        match (d1 == 1, d2 == 1) {
54            (true, true) => {
55                tp1[i] = BroadcastType::Preserve;
56                tp2[i] = BroadcastType::Preserve;
57                shape[i] = 1;
58            },
59            (false, true) => {
60                tp1[i] = BroadcastType::Preserve;
61                tp2[i] = BroadcastType::Upcast;
62                shape[i] = d1;
63            },
64            (true, false) => {
65                tp1[i] = BroadcastType::Upcast;
66                tp2[i] = BroadcastType::Preserve;
67                shape[i] = d2;
68            },
69            (false, false) => {
70                rstsr_assert_eq!(d1, d2, InvalidLayout, "Broadcasting failed.")?;
71                tp1[i] = BroadcastType::Preserve;
72                tp2[i] = BroadcastType::Preserve;
73                shape[i] = d1;
74            },
75        }
76
77        if in1 < 0 {
78            tp1[i] = BroadcastType::Expand;
79        }
80        if in2 < 0 {
81            tp2[i] = BroadcastType::Expand;
82        }
83    }
84    // flip back if col-major
85    if order == ColMajor {
86        shape.reverse();
87        tp1.reverse();
88        tp2.reverse();
89    }
90    // convert to the final shape
91    let shape = TryInto::<D>::try_into(shape);
92    let shape = shape.map_err(|_| rstsr_error!(InvalidLayout, "Type cast error."))?;
93
94    return Ok((shape, tp1, tp2));
95}
96
97pub trait DimBroadcastableAPI: DimBaseAPI {
98    /// Check whether second shape can be broadcasted to first shape.
99    ///
100    /// Order of the two parameters depends.
101    fn broadcastable_from<D2>(&self, other: &D2) -> bool
102    where
103        D2: DimBaseAPI,
104    {
105        let (shape1, shape2) = (self.as_ref(), other.as_ref());
106        let (n1, n2) = (shape1.len(), shape2.len());
107        let n = usize::max(n1, n2);
108        if n != n1 {
109            return false;
110        }
111        for i in (0..n).rev() {
112            let in1 = (n1 + i) as isize - n as isize;
113            let in2 = (n2 + i) as isize - n as isize;
114
115            let d1 = if in1 >= 0 { shape1[in1 as usize] } else { 1 };
116            let d2 = if in2 >= 0 { shape2[in2 as usize] } else { 1 };
117
118            if d1 != d2 && d2 != 1 {
119                return false;
120            }
121        }
122        return true;
123    }
124
125    /// Check whether first shape can be broadcasted to second shape.
126    ///
127    /// Order of the two parameters depends.
128    fn broadcastable_to<D2>(&self, other: &D2) -> bool
129    where
130        D2: DimBaseAPI,
131    {
132        let (shape1, shape2) = (self.as_ref(), other.as_ref());
133        let (n1, n2) = (shape1.len(), shape2.len());
134        let n = usize::max(n1, n2);
135        if n != n2 {
136            return false;
137        }
138        for i in (0..n).rev() {
139            let in1 = (n1 + i) as isize - n as isize;
140            let in2 = (n2 + i) as isize - n as isize;
141
142            let d1 = if in1 >= 0 { shape1[in1 as usize] } else { 1 };
143            let d2 = if in2 >= 0 { shape2[in2 as usize] } else { 1 };
144
145            if d1 != d2 && d1 != 1 {
146                return false;
147            }
148        }
149        return true;
150    }
151}
152
153impl<D> DimBroadcastableAPI for D where D: DimAPI {}
154
155/// Layout broadcasting.
156///
157/// Dimensions that to be upcasted or expanded will have stride length of zero.
158///
159/// Note that zero stride length is generally not accepted, since different
160/// indices will point to the same memory, which is not expected in most cases
161/// for this library. But this will be convenient when we need to broadcast.
162///
163/// # See also
164///
165/// [broadcasting](https://data-apis.org/array-api/2024.12/API_specification/broadcasting.html)
166pub fn broadcast_layout<D1, D2, D>(
167    layout1: &Layout<D1>,
168    layout2: &Layout<D2>,
169    order: FlagOrder,
170) -> Result<(Layout<D>, Layout<D>)>
171where
172    D1: DimDevAPI + DimMaxAPI<D2, Max = D>,
173    D2: DimDevAPI,
174    D: DimDevAPI,
175{
176    let shape1 = layout1.shape();
177    let shape2 = layout2.shape();
178    let (shape, tp1, tp2) = broadcast_shape(shape1, shape2, order)?;
179    let layout1 = update_layout_by_shape(layout1, &shape, &tp1, order)?;
180    let layout2 = update_layout_by_shape(layout2, &shape, &tp2, order)?;
181    return Ok((layout1, layout2));
182}
183
184/// Layout broadcasting.
185///
186/// This function will broadcast the layout to the first layout.
187///
188/// # See also
189///
190/// [`broadcast_layout`]
191pub fn broadcast_layout_to_first<D1, D2, D>(
192    layout1: &Layout<D1>,
193    layout2: &Layout<D2>,
194    order: FlagOrder,
195) -> Result<(Layout<D1>, Layout<D1>)>
196where
197    D1: DimDevAPI + DimMaxAPI<D2, Max = D>,
198    D2: DimDevAPI,
199    D: DimIntoAPI<D1> + DimDevAPI,
200{
201    let (layout1, layout2) = broadcast_layout(layout1, layout2, order)?;
202    let layout1 = layout1.into_dim::<D1>()?;
203    let layout2 = layout2.into_dim::<D1>()?;
204    return Ok((layout1, layout2));
205}
206
207pub fn update_layout_by_shape<D, DMax>(
208    layout: &Layout<D>,
209    shape: &DMax,
210    broadcast_type: &[BroadcastType],
211    order: FlagOrder,
212) -> Result<Layout<DMax>>
213where
214    D: DimDevAPI,
215    DMax: DimDevAPI,
216{
217    // handle col-major
218    if order == ColMajor {
219        let mut shape: IxD = shape.clone().into();
220        shape.reverse();
221        let shape: DMax = unsafe { shape.try_into().unwrap_unchecked() };
222        let mut broadcast_type = broadcast_type.to_vec();
223        broadcast_type.reverse();
224        let layout = layout.reverse_axes();
225        let result = update_layout_by_shape(&layout, &shape, &broadcast_type, RowMajor);
226        return result.map(|layout| layout.reverse_axes());
227    }
228    assert_eq!(order, RowMajor);
229    let n_old = layout.ndim();
230    let stride_old = layout.stride();
231    let n = shape.ndim();
232    let mut stride = vec![0; n];
233    stride[n - n_old..n].copy_from_slice(stride_old.as_ref());
234    for i in 0..n {
235        match broadcast_type[i] {
236            BroadcastType::Expand | BroadcastType::Upcast => {
237                stride[i] = 0;
238            },
239            _ => {},
240        }
241    }
242    let stride = stride.try_into();
243    let stride = stride.map_err(|_| rstsr_error!(InvalidLayout, "Type cast error."))?;
244    unsafe { Ok(Layout::new_unchecked(shape.clone(), stride, layout.offset())) }
245}
246
247impl<D> Layout<D>
248where
249    D: DimBaseAPI,
250{
251    /// Get the size of the non-broadcasted part.
252    ///
253    /// Equivalent to `size()` if there is no broadcast (setting axis size = 1
254    /// where stride = 0).
255    pub fn size_non_broadcast(&self) -> usize {
256        if self.size() == 0 {
257            return 0;
258        }
259        let mut size = 1;
260        for i in 0..self.ndim() {
261            if self.stride[i] != 0 {
262                size *= self.shape[i];
263            }
264        }
265        return size;
266    }
267
268    /// Check whether current layout has been broadcasted.
269    ///
270    /// This check is done by checking whether any stride of axis is zero.
271    pub fn is_broadcasted(&self) -> bool {
272        self.stride().as_ref().contains(&0)
273    }
274}
275
#[cfg(test)]
mod test {
    use super::*;
    use BroadcastType::*;

    #[test]
    fn test_broadcast_shape() {
        // A      (4d array):  8 x 1 x 6 x 1
        // B      (3d array):      7 x 1 x 5
        // ---------------------------------
        // Result (4d array):  8 x 7 x 6 x 5
        let shape1 = [8, 1, 6, 1];
        let shape2 = [7, 1, 5];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(!shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [8, 7, 6, 5]);
        assert_eq!(broadcast.1, [Preserve, Upcast, Preserve, Upcast]);
        assert_eq!(broadcast.2, [Expand, Preserve, Upcast, Preserve]);
        // A      (2d array):  5 x 4
        // B      (1d array):      1
        // -------------------------
        // Result (2d array):  5 x 4
        let shape1 = [5, 4];
        let shape2 = [1];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [5, 4]);
        assert_eq!(broadcast.1, [Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Upcast]);
        // A      (2d array):  5 x 4
        // B      (1d array):      4
        // -------------------------
        // Result (2d array):  5 x 4
        let shape1 = [5, 4];
        let shape2 = [4];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [5, 4]);
        assert_eq!(broadcast.1, [Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve]);
        // A      (3d array):  15 x 3 x 5
        // B      (3d array):  15 x 1 x 5
        // ------------------------------
        // Result (3d array):  15 x 3 x 5
        let shape1 = [15, 3, 5];
        let shape2 = [15, 1, 5];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Upcast, Preserve]);
        // A      (3d array):  15 x 3 x 5
        // B      (2d array):       3 x 5
        // ------------------------------
        // Result (3d array):  15 x 3 x 5
        let shape1 = [15, 3, 5];
        let shape2 = [3, 5];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve, Preserve]);
        // A      (3d array):  15 x 3 x 5
        // B      (2d array):       3 x 1
        // ------------------------------
        // Result (3d array):  15 x 3 x 5
        let shape1 = [15, 3, 5];
        let shape2 = [3, 1];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve, Upcast]);

        // other test cases
        let shape1 = [1, 1, 2];
        let shape2 = [1, 2];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(shape1.broadcastable_from(&shape2));
        assert!(!shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [1, 1, 2]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Expand, Preserve, Preserve]);

        // other test cases
        let shape1 = [1, 2];
        let shape2 = [1, 1, 2];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor).unwrap();
        assert!(!shape1.broadcastable_from(&shape2));
        assert!(shape1.broadcastable_to(&shape2));
        assert_eq!(broadcast.0, [1, 1, 2]);
        assert_eq!(broadcast.1, [Expand, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Preserve, Preserve]);
    }

    #[test]
    fn test_broadcast_shape_col_major() {
        // Column-major order aligns shapes at their leading axes. This mirrors
        // the first case of `test_broadcast_shape`.
        // A      (4d array):  1 x 6 x 1 x 8
        // B      (3d array):  5 x 1 x 7
        // ---------------------------------
        // Result (4d array):  5 x 6 x 7 x 8
        let shape1 = [1, 6, 1, 8];
        let shape2 = [5, 1, 7];
        let broadcast = broadcast_shape(&shape1, &shape2, ColMajor).unwrap();
        assert_eq!(broadcast.0, [5, 6, 7, 8]);
        assert_eq!(broadcast.1, [Upcast, Preserve, Upcast, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Upcast, Preserve, Expand]);
        // A      (3d array):  15 x 3 x 5
        // B      (2d array):  15 x 3
        // ------------------------------
        // Result (3d array):  15 x 3 x 5
        let shape1 = [15, 3, 5];
        let shape2 = [15, 3];
        let broadcast = broadcast_shape(&shape1, &shape2, ColMajor).unwrap();
        assert_eq!(broadcast.0, [15, 3, 5]);
        assert_eq!(broadcast.1, [Preserve, Preserve, Preserve]);
        assert_eq!(broadcast.2, [Preserve, Preserve, Expand]);
        // A      (3d array):  15 x 3 x 5
        // B      (2d array):   3 x 5      # leading dimension does not match
        let shape1 = [15, 3, 5];
        let shape2 = [3, 5];
        assert!(broadcast_shape(&shape1, &shape2, ColMajor).is_err());
    }

    #[test]
    fn test_broadcast_shape_fail() {
        // A      (1d array):  3
        // B      (1d array):  4           # dimension does not match
        let shape1 = [3];
        let shape2 = [4];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor);
        assert!(broadcast.is_err());
        // A      (2d array):      2 x 1
        // B      (3d array):  8 x 4 x 3   # second dimension does not match
        let shape1 = [2, 1];
        let shape2 = [8, 4, 3];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor);
        assert!(broadcast.is_err());
        // A      (3d array):  15 x 3 x 5
        // B      (2d array):  15 x 3
        // # singleton dimensions can only be prepended, not appended
        let shape1 = [15, 3, 5];
        let shape2 = [15, 3];
        let broadcast = broadcast_shape(&shape1, &shape2, RowMajor);
        assert!(broadcast.is_err());
    }

    #[test]
    fn test_broadcast_layout() {
        // A      (5d array):  8 x 1 x 6 x 3 x 1
        // B      (4d array):      7 x 1 x 3 x 5
        // -------------------------------------
        // Result (5d array):  8 x 7 x 6 x 3 x 5
        let shape1 = [8, 1, 6, 3, 1];
        let shape2 = [7, 1, 3, 5];
        let layout1 = shape1.c();
        let layout2 = shape2.f();
        let (layout1, layout2) = broadcast_layout(&layout1, &layout2, RowMajor).unwrap();
        assert_eq!(layout1.shape(), &[8, 7, 6, 3, 5]);
        assert_eq!(layout2.shape(), &[8, 7, 6, 3, 5]);
        assert_eq!(layout1.stride(), &[18, 0, 3, 1, 0]);
        assert_eq!(layout2.stride(), &[0, 1, 0, 7, 21]);
        // Zero-stride axes are excluded from the non-broadcast size.
        assert!(layout1.is_broadcasted());
        assert!(layout2.is_broadcasted());
        assert_eq!(layout1.size_non_broadcast(), 8 * 6 * 3);
        assert_eq!(layout2.size_non_broadcast(), 7 * 3 * 5);
    }

    #[test]
    fn test_broadcast_layout_col_major() {
        // Mirror of `test_broadcast_layout`: column-major aligns leading axes.
        // A      (5d array):  1 x 3 x 6 x 1 x 8
        // B      (4d array):  5 x 3 x 1 x 7
        // -------------------------------------
        // Result (5d array):  5 x 3 x 6 x 7 x 8
        let shape1 = [1, 3, 6, 1, 8];
        let shape2 = [5, 3, 1, 7];
        let layout1 = shape1.f();
        let layout2 = shape2.c();
        let (layout1, layout2) = broadcast_layout(&layout1, &layout2, ColMajor).unwrap();
        assert_eq!(layout1.shape(), &[5, 3, 6, 7, 8]);
        assert_eq!(layout2.shape(), &[5, 3, 6, 7, 8]);
        assert_eq!(layout1.stride(), &[0, 1, 3, 0, 18]);
        assert_eq!(layout2.stride(), &[21, 7, 0, 1, 0]);
    }
}