rstsr_common/layout/reshape.rs

//! Auxiliary functions for reshaping a tensor.

use crate::prelude_dev::*;

/// Check `-1` in shape and substitute it with the correct value.
///
/// # Arguments
/// * `shape_out` - The target shape, which may contain at most one `-1`.
/// * `size_in` - The size of the original tensor.
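///
/// # Example
///
/// A minimal sketch of the expected behavior (marked `ignore` since it relies
/// on crate-internal types):
///
/// ```ignore
/// // the single -1 is substituted so that the total size matches
/// assert_eq!(reshape_substitute_negatives(&[2, -1, 3], 24).unwrap(), vec![2, 4, 3]);
/// // more than one -1, or a non-divisible remainder, is an error
/// assert!(reshape_substitute_negatives(&[-1, -1], 24).is_err());
/// assert!(reshape_substitute_negatives(&[5, -1], 24).is_err());
/// ```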
pub fn reshape_substitute_negatives(shape_out: &[isize], size_in: usize) -> Result<Vec<usize>> {
    let mut shape = shape_out.to_vec();

    // check negative indices
    let mut idx_neg1: Option<usize> = None;
    for (i, &v) in shape.iter().enumerate() {
        match v {
            -1 => match idx_neg1 {
                Some(_) => rstsr_raise!(InvalidValue, "Only one -1 is allowed in shape.")?,
                None => idx_neg1 = Some(i),
            },
            ..-1 => {
                rstsr_raise!(InvalidValue, "Negative index must be -1.")?;
            },
            _ => (),
        }
    }

    // substitute negative index
    if let Some(idx_neg1) = idx_neg1 {
        let size_in = size_in as isize;
        let size_neg = shape.iter().fold(1, |acc, &v| if v == -1 { acc } else { acc * v });
        rstsr_assert!(
            size_in % size_neg == 0,
            InvalidValue,
            "Shape '-1' in {:?} could not be determined from the original tensor size {:?}",
            shape,
            size_in
        )?;
        shape[idx_neg1] = size_in / size_neg;
    }
    return Ok(shape.iter().map(|&v| v as usize).collect::<Vec<usize>>());
}

/// A quick check for reshaping a tensor.
///
/// - check if size is the same (raise if failed)
/// - check if exactly the same shape (return if true)
/// - check if contiguous (return if true)
///
/// For more complex reshaping, return `None`; other functions should handle
/// that situation.
///
/// For the order option, row-major and col-major behave differently.
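///
/// A minimal sketch of the expected behavior (marked `ignore`; assumes the
/// `new_c_contig` constructor used elsewhere in this module):
///
/// ```ignore
/// let layout_in = vec![2, 3].new_c_contig(None);
/// // same size and c-contiguous input: a layout is returned without copying
/// assert!(quick_check(&vec![6], &layout_in, RowMajor).unwrap().is_some());
/// // size mismatch: raises an error
/// assert!(quick_check(&vec![5], &layout_in, RowMajor).is_err());
/// ```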
fn quick_check(shape_out: &Vec<usize>, layout_in: &Layout<IxD>, order: FlagOrder) -> Result<Option<Layout<IxD>>> {
    // check if size is the same
    let size_in = layout_in.size();
    let size_out = shape_out.iter().product();
    rstsr_assert_eq!(size_in, size_out, InvalidValue, "Size mismatch between input tensor and output tensor.")?;

    // if size is zero or one, return immediately
    // currently, we use the broadcast approach to handle this case
    // strides will be set to zeros, which should not affect computation
    if size_in == 0 || size_in == 1 {
        let strides = vec![0; shape_out.len()];
        return Ok(Some(Layout::<IxD>::new(shape_out.clone(), strides, layout_in.offset())?));
    }

    // check if exactly the same shape
    if shape_out == layout_in.shape() {
        return Ok(Some(layout_in.clone()));
    }

    // check if contiguous
    match order {
        RowMajor => {
            if layout_in.c_contig() {
                return Ok(Some(shape_out.new_c_contig(Some(layout_in.offset()))));
            }
        },
        ColMajor => {
            if layout_in.f_contig() {
                return Ok(Some(shape_out.new_f_contig(Some(layout_in.offset()))));
            }
        },
    };

    // all easy cases checked, return None for further reshaping
    return Ok(None);
}

/// Internal function that pops the input layout.
///
/// This function is for c-prefer (row-major) only.
///
/// # Returns
/// * `usize` - The size of the partly contiguous (minimum-stride) batch popped from the input tensor.
/// * `isize` - The minimum stride of the current batch.
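///
/// A minimal sketch of the expected behavior (marked `ignore`; private
/// function, shown for illustration only):
///
/// ```ignore
/// // a fully c-contiguous 2 x 3 x 4 block merges into one batch of size 24
/// let (mut shape, mut stride) = (vec![2, 3, 4], vec![12, 4, 1]);
/// assert_eq!(pop_layout_in(&mut shape, &mut stride), (24, 1));
/// assert!(shape.is_empty());
///
/// // a padded row (stride 8 for 4 elements) stops the merge
/// let (mut shape, mut stride) = (vec![6, 4], vec![8, 1]);
/// assert_eq!(pop_layout_in(&mut shape, &mut stride), (4, 1));
/// assert_eq!((shape, stride), (vec![6], vec![8]));
/// ```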
fn pop_layout_in(shape_in: &mut Vec<usize>, stride_in: &mut Vec<isize>) -> (usize, isize) {
    rstsr_assert_eq!(shape_in.len(), stride_in.len(), RuntimeError).unwrap();
    rstsr_assert!(!shape_in.is_empty(), RuntimeError).unwrap();

    let mut stride_min = stride_in.pop().unwrap();
    let mut size = shape_in.pop().unwrap();

    // determine if current batch is broadcasted
    if size == 1 || stride_min == 0 {
        // broadcasted, reset stride_min to 0
        stride_min = 0;
        while stride_in.last().is_some_and(|&v| v == 0) || shape_in.last().is_some_and(|&v| v == 1) {
            stride_in.pop();
            size *= shape_in.pop().unwrap();
        }
        return (size, stride_min);
    } else {
        // general case: merge adjacent dimensions while they remain contiguous
        while stride_in.last().is_some_and(|&v| v == size as isize * stride_min) {
            stride_in.pop();
            size *= shape_in.pop().unwrap();
        }
        return (size, stride_min);
    }
}

/// Internal function that pops the output shape and injects output strides.
///
/// This function is for c-prefer (row-major) only.
/// However, note that `stride_out` is built in reverse order.
///
/// This function returns `true`/`false` depending on whether the shapes are compatible.
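///
/// A minimal sketch of the expected behavior (marked `ignore`; private
/// function, shown for illustration only):
///
/// ```ignore
/// // a batch of size 24 with unit stride factors into output dims [4, 6];
/// // strides are pushed innermost-first, i.e., in reverse order
/// let mut shape_out = vec![4, 6];
/// let mut stride_out = Vec::new();
/// assert!(pop_shape_out(&mut shape_out, &mut stride_out, 24, 1));
/// assert_eq!(stride_out, vec![1, 6]);
///
/// // an output dim that does not divide the batch size is incompatible
/// assert!(!pop_shape_out(&mut vec![5], &mut Vec::<isize>::new(), 24, 1));
/// ```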
fn pop_shape_out(
    shape_out: &mut Vec<usize>,
    stride_out: &mut Vec<isize>,
    mut size: usize,
    mut stride_min: isize,
) -> bool {
    rstsr_assert!(!shape_out.is_empty(), RuntimeError).unwrap();

    while size != 1 || shape_out.last().is_some_and(|&v| v == 1) {
        let s_out = shape_out.pop().unwrap();
        if size % s_out != 0 {
            return false;
        }
        size /= s_out;
        stride_out.push(stride_min);
        stride_min *= s_out as isize;
    }

    return true;
}

/// Internal function for reshaping a tensor in the general case.
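///
/// A minimal sketch of the expected behavior (marked `ignore`; private
/// function, shown for illustration only):
///
/// ```ignore
/// // rows padded to stride 8: shape [6, 4] with strides [8, 1] still splits
/// // into [2, 3, 4] with strides [24, 8, 1], without copying data
/// let layout_in = Layout::<IxD>::new(vec![6, 4], vec![8, 1], 0).unwrap();
/// let layout_out = complicated_reshape(&[2, 3, 4], &layout_in, RowMajor).unwrap();
/// assert_eq!(layout_out.stride().to_vec(), vec![24, 8, 1]);
/// ```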
fn complicated_reshape(shape_out: &[usize], layout_in: &Layout<IxD>, order: FlagOrder) -> Option<Layout<IxD>> {
    let shape_out_ref = shape_out; // the original shape_out is not modified
    let mut shape_out = shape_out.to_vec(); // the shape_out to be consumed during iteration
    let mut stride_out = Vec::new();
    let mut shape_in = layout_in.shape().to_vec();
    let mut stride_in = layout_in.stride().to_vec();
    let offset = layout_in.offset();

    // f-prefer is handled by reversing everything
    if order == FlagOrder::F {
        shape_in.reverse();
        stride_in.reverse();
        shape_out.reverse();
    }

    while !shape_in.is_empty() {
        let (size_in, stride_in_min) = pop_layout_in(&mut shape_in, &mut stride_in);
        if !pop_shape_out(&mut shape_out, &mut stride_out, size_in, stride_in_min) {
            return None;
        }
    }
    rstsr_assert!(shape_out.is_empty(), RuntimeError).unwrap();
    rstsr_assert_eq!(stride_out.len(), shape_out_ref.len(), RuntimeError).unwrap();
    // stride_out was pushed in reverse order for c-prefer, so reverse it back;
    // for f-prefer, everything was reversed up front, so stride_out is already
    // in the original order
    match order {
        RowMajor => stride_out.reverse(),
        ColMajor => (),
    };

    let layout_out = unsafe { Layout::<IxD>::new_unchecked(shape_out_ref.to_vec(), stride_out, offset) };
    return Some(layout_out);
}

/// Check if a tensor can be reshaped to a new shape without explicitly copying
/// the underlying data.
///
/// - If the sizes do not match, this function raises an error.
/// - If the sizes match but the data needs to be copied, `Ok(None)` is returned.
/// - If everything is fine, `Ok(Some(layout_out))` is returned.
///
/// For the order option, row-major and col-major behave differently.
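///
/// # Example
///
/// A minimal sketch of the expected behavior (marked `ignore` since it relies
/// on crate-internal types):
///
/// ```ignore
/// // contiguous input: reshape is just a new layout over the same data
/// let layout_in = vec![2, 3].new_c_contig(None);
/// assert!(layout_reshapeable(&layout_in, &vec![6], RowMajor).unwrap().is_some());
///
/// // transposed input: flattening in row-major order would require a copy
/// let transposed = Layout::<IxD>::new(vec![3, 2], vec![1, 3], 0).unwrap();
/// assert!(layout_reshapeable(&transposed, &vec![6], RowMajor).unwrap().is_none());
/// ```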
pub fn layout_reshapeable(
    layout_in: &Layout<IxD>,
    shape_out: &Vec<usize>,
    order: FlagOrder,
) -> Result<Option<Layout<IxD>>> {
    if let Some(layout_out) = quick_check(shape_out, layout_in, order)? {
        return Ok(Some(layout_out));
    }
    return Ok(complicated_reshape(shape_out, layout_in, order));
}
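
// The tests below are a minimal sketch rather than an exhaustive suite; they
// assume the `new_c_contig` constructor and the `Layout::new` signature used
// above, and may need adjustment to the crate's actual test conventions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layout_reshapeable_basic() -> Result<()> {
        // contiguous layouts reshape without copying
        let layout = vec![2, 3, 4].new_c_contig(None);
        assert!(layout_reshapeable(&layout, &vec![6, 4], RowMajor)?.is_some());

        // a padded layout can still be split along the contiguous axis
        let padded = Layout::<IxD>::new(vec![6, 4], vec![8, 1], 0)?;
        let out = layout_reshapeable(&padded, &vec![2, 3, 4], RowMajor)?;
        assert_eq!(out.unwrap().stride().to_vec(), vec![24, 8, 1]);

        // a transposed layout cannot be flattened in row-major order
        let transposed = Layout::<IxD>::new(vec![3, 2], vec![1, 3], 0)?;
        assert!(layout_reshapeable(&transposed, &vec![6], RowMajor)?.is_none());

        // size mismatch raises an error
        assert!(layout_reshapeable(&layout, &vec![5], RowMajor).is_err());
        Ok(())
    }
}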