rstsr_common/layout/reshape.rs

//! Auxiliary functions for reshaping a tensor.

use crate::prelude_dev::*;

/// Check for `-1` in the shape and substitute it with the correct value.
///
/// # Arguments
/// * `shape_out` - The target shape, which may contain at most one `-1`.
/// * `size_in` - The total size of the original tensor.
pub fn reshape_substitute_negatives(shape_out: &[isize], size_in: usize) -> Result<Vec<usize>> {
    let mut shape = shape_out.to_vec();

    // locate `-1` and reject any other negative value
    let mut idx_neg1: Option<usize> = None;
    for (i, &v) in shape.iter().enumerate() {
        match v {
            -1 => match idx_neg1 {
                Some(_) => rstsr_raise!(InvalidValue, "Only one -1 is allowed in shape.")?,
                None => idx_neg1 = Some(i),
            },
            ..-1 => {
                rstsr_raise!(InvalidValue, "Negative value in shape must be -1.")?;
            },
            _ => (),
        }
    }

    // substitute the `-1` entry
    if let Some(idx_neg1) = idx_neg1 {
        let size_in = size_in as isize;
        // product of all entries except the `-1` itself
        let size_neg = shape.iter().fold(1, |acc, &v| if v == -1 { acc } else { acc * v });
        // `size_neg != 0` also guards against division by zero when the shape
        // contains a zero-sized dimension
        rstsr_assert!(
            size_neg != 0 && size_in % size_neg == 0,
            InvalidValue,
            "The -1 in shape {:?} cannot be resolved from the original tensor size {:?}",
            shape,
            size_in
        )?;
        shape[idx_neg1] = size_in / size_neg;
    }
    return Ok(shape.iter().map(|&v| v as usize).collect::<Vec<usize>>());
}
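
// Illustrative sketch (not part of the original source) of the expected
// behavior of `reshape_substitute_negatives`:
//
//     // reshaping 24 elements to (2, -1, 3) resolves the -1 to 4
//     let shape = reshape_substitute_negatives(&[2, -1, 3], 24)?;
//     assert_eq!(shape, vec![2, 4, 3]);
//     // more than one -1, or a -1 that does not divide evenly, is an error
//     assert!(reshape_substitute_negatives(&[-1, -1, 6], 24).is_err());
//     assert!(reshape_substitute_negatives(&[-1, 5], 24).is_err());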

/// A quick check for reshaping a tensor.
///
/// - check whether sizes match (raise on mismatch)
/// - check whether the shapes are exactly the same (return if true)
/// - check whether the input is contiguous (return if true)
///
/// For more complex reshaping, return `None`; other functions should handle
/// that situation.
///
/// For the order option, row-major and col-major behave differently.
fn quick_check(
    shape_out: &Vec<usize>,
    layout_in: &Layout<IxD>,
    order: FlagOrder,
) -> Result<Option<Layout<IxD>>> {
    // check whether sizes match
    let size_in = layout_in.size();
    let size_out = shape_out.iter().product();
    rstsr_assert_eq!(
        size_in,
        size_out,
        InvalidValue,
        "Size mismatch between input tensor and output tensor.",
    )?;

    // if the size is zero or one, return immediately;
    // we currently handle this case the same way as broadcasting:
    // strides are set to zero, which should not affect computation
    if size_in == 0 || size_in == 1 {
        let strides = vec![0; shape_out.len()];
        return Ok(Some(Layout::<IxD>::new(shape_out.clone(), strides, layout_in.offset())?));
    }

    // check whether the shapes are exactly the same
    if shape_out == layout_in.shape() {
        return Ok(Some(layout_in.clone()));
    }

    // check whether the input is contiguous
    match order {
        RowMajor => {
            if layout_in.c_contig() {
                return Ok(Some(shape_out.new_c_contig(Some(layout_in.offset()))));
            }
        },
        ColMajor => {
            if layout_in.f_contig() {
                return Ok(Some(shape_out.new_f_contig(Some(layout_in.offset()))));
            }
        },
    };

    // all easy cases checked; return None for further reshaping
    return Ok(None);
}
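
// Illustrative sketch (not part of the original source): cases that
// `quick_check` resolves without falling through to the general algorithm.
//
//     // c-contiguous input: the result is simply a new c-contiguous layout
//     let layout = Layout::<IxD>::new(vec![2, 12], vec![12, 1], 0)?;
//     let out = quick_check(&vec![4, 6], &layout, RowMajor)?;
//     // out is Some(layout) with shape [4, 6] and strides [6, 1]
//
//     // f-contiguous input under row-major order falls through
//     let layout = Layout::<IxD>::new(vec![12, 2], vec![1, 12], 0)?;
//     let out = quick_check(&vec![4, 6], &layout, RowMajor)?;
//     // out is None; `complicated_reshape` must decide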

/// Internal function that pops dimensions from the input layout.
///
/// This function is for c-prefer (row-major) only.
///
/// # Returns
/// * `usize` - The size of a partly contiguous batch (one with a minimum
///   stride) of the input tensor.
/// * `isize` - The minimum stride of the current batch.
fn pop_layout_in(shape_in: &mut Vec<usize>, stride_in: &mut Vec<isize>) -> (usize, isize) {
    rstsr_assert_eq!(shape_in.len(), stride_in.len(), RuntimeError).unwrap();
    rstsr_assert!(!shape_in.is_empty(), RuntimeError).unwrap();

    let mut stride_min = stride_in.pop().unwrap();
    let mut size = shape_in.pop().unwrap();

    // determine whether the current batch is broadcast
    if size == 1 || stride_min == 0 {
        // broadcast case: reset stride_min to 0
        stride_min = 0;
        while stride_in.last().is_some_and(|&v| v == 0) || shape_in.last().is_some_and(|&v| v == 1)
        {
            stride_in.pop();
            size *= shape_in.pop().unwrap();
        }
        return (size, stride_min);
    } else {
        // general case: merge dimensions while they remain contiguous
        while stride_in.last().is_some_and(|&v| v == size as isize * stride_min) {
            stride_in.pop();
            size *= shape_in.pop().unwrap();
        }
        return (size, stride_min);
    }
}
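
// Worked trace (illustrative): for a c-contiguous input with
// shape_in = [2, 3, 4] and stride_in = [12, 4, 1], the first pop yields
// (size = 4, stride_min = 1); merging continues because the next strides
// equal 4 * 1 and then 12 * 1, so the call returns (24, 1) with both vectors
// emptied: the whole tensor is one contiguous batch. For a transposed input
// with shape_in = [4, 3] and stride_in = [1, 4], the call pops
// (size = 3, stride_min = 4) and stops, since the next stride 1 differs from
// 3 * 4: the batch covers only the last axis.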

/// Internal function that pops the output shape and injects output strides.
///
/// This function is for c-prefer (row-major) only.
/// However, note that `stride_out` is in reverse order.
///
/// Returns `true` if the shape is compatible with the current batch, `false`
/// otherwise.
fn pop_shape_out(
    shape_out: &mut Vec<usize>,
    stride_out: &mut Vec<isize>,
    mut size: usize,
    mut stride_min: isize,
) -> bool {
    rstsr_assert!(!shape_out.is_empty(), RuntimeError).unwrap();

    while size != 1 || shape_out.last().is_some_and(|&v| v == 1) {
        let s_out = shape_out.pop().unwrap();
        if size % s_out != 0 {
            return false;
        }
        size /= s_out;
        stride_out.push(stride_min);
        stride_min *= s_out as isize;
    }

    return true;
}
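
// Worked trace (illustrative): with shape_out = [4, 6], size = 24 and
// stride_min = 1 (one contiguous input batch), the loop pops 6 (pushing
// stride 1, stride_min becomes 6), then pops 4 (pushing stride 6), leaving
// stride_out = [1, 6] in reverse order. With shape_out = [12] and a batch of
// size = 3, since 3 % 12 != 0 the batch cannot fill the output dimension, and
// the function returns false.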

/// Internal function for reshaping a tensor in the general case.
fn complicated_reshape(
    shape_out: &[usize],
    layout_in: &Layout<IxD>,
    order: FlagOrder,
) -> Option<Layout<IxD>> {
    let shape_out_ref = shape_out; // the original shape_out, left unmodified
    let mut shape_out = shape_out.to_vec(); // working copy, consumed during iteration
    let mut stride_out = Vec::new();
    let mut shape_in = layout_in.shape().to_vec();
    let mut stride_in = layout_in.stride().to_vec();
    let offset = layout_in.offset();

    // f-prefer is handled by reversing everything up front
    if order == FlagOrder::F {
        shape_in.reverse();
        stride_in.reverse();
        shape_out.reverse();
    }

    while !shape_in.is_empty() {
        let (size_in, stride_in_min) = pop_layout_in(&mut shape_in, &mut stride_in);
        if !pop_shape_out(&mut shape_out, &mut stride_out, size_in, stride_in_min) {
            return None;
        }
    }
    rstsr_assert!(shape_out.is_empty(), RuntimeError).unwrap();
    rstsr_assert_eq!(stride_out.len(), shape_out_ref.len(), RuntimeError).unwrap();
    // stride_out was pushed while popping shape_out from the back, so it is in
    // reverse order for row-major; for col-major everything was reversed up
    // front, so stride_out already matches the order of shape_out_ref
    match order {
        RowMajor => stride_out.reverse(),
        ColMajor => (),
    };

    let layout_out =
        unsafe { Layout::<IxD>::new_unchecked(shape_out_ref.to_vec(), stride_out, offset) };
    return Some(layout_out);
}
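
// Worked trace (illustrative): reshaping a transposed 3 x 4 tensor
// (shape [4, 3], stride [1, 4]) to [12] under row-major order first pops the
// input batch (size = 3, stride = 4); since 3 % 12 != 0, `pop_shape_out`
// reports incompatibility and the function returns None (a copy is required).
// By contrast, an evenly strided layout with shape [4, 6] and stride [12, 2]
// reshapes to [2, 12] without a copy, yielding strides [24, 2].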

/// Check whether a tensor can be reshaped to a new shape without explicitly
/// copying the underlying data.
///
/// - If the total sizes do not match, this function raises an error.
/// - If the sizes match but the data needs to be copied, return `Ok(None)`.
/// - If everything is fine, return `Ok(Some(layout_out))`.
///
/// For the order option, row-major and col-major behave differently.
pub fn layout_reshapeable(
    layout_in: &Layout<IxD>,
    shape_out: &Vec<usize>,
    order: FlagOrder,
) -> Result<Option<Layout<IxD>>> {
    if let Some(layout_out) = quick_check(shape_out, layout_in, order)? {
        return Ok(Some(layout_out));
    }
    return Ok(complicated_reshape(shape_out, layout_in, order));
}
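
// A minimal test sketch (not part of the original source). It assumes that
// `Layout::<IxD>::new(shape, stride, offset)` and the `RowMajor` flag are
// available through `prelude_dev`, as used above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reshape_without_copy() {
        // c-contiguous input: reshape succeeds without copying data
        let layout = Layout::<IxD>::new(vec![2, 12], vec![12, 1], 0).unwrap();
        let out = layout_reshapeable(&layout, &vec![4, 6], RowMajor).unwrap();
        assert_eq!(out.unwrap().stride().to_vec(), vec![6, 1]);

        // transposed input: data must be copied, so None is returned
        let layout = Layout::<IxD>::new(vec![4, 3], vec![1, 4], 0).unwrap();
        let out = layout_reshapeable(&layout, &vec![12], RowMajor).unwrap();
        assert!(out.is_none());
    }
}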