rstsr_common/layout/
reshape.rs1use crate::prelude_dev::*;
4
5pub fn reshape_substitute_negatives(shape_out: &[isize], size_in: usize) -> Result<Vec<usize>> {
11 let mut shape = shape_out.to_vec();
12
13 let mut idx_neg1: Option<usize> = None;
15 for (i, &v) in shape.iter().enumerate() {
16 match v {
17 -1 => match idx_neg1 {
18 Some(_) => rstsr_raise!(InvalidValue, "Only one -1 is allowed in shape.")?,
19 None => idx_neg1 = Some(i),
20 },
21 ..-1 => {
22 rstsr_raise!(InvalidValue, "Negative index must be -1.")?;
23 },
24 _ => (),
25 }
26 }
27
28 if let Some(idx_neg1) = idx_neg1 {
30 let size_in = size_in as isize;
31 let size_neg = shape.iter().fold(1, |acc, &v| if v == -1 { acc } else { acc * v });
32 rstsr_assert!(
33 size_in % size_neg == 0,
34 InvalidValue,
35 "Shape '-1' in {:?} could not be determined to original tensor size {:?}",
36 shape,
37 size_in
38 )?;
39 shape[idx_neg1] = size_in / size_neg;
40 }
41 return Ok(shape.iter().map(|&v| v as usize).collect::<Vec<usize>>());
42}
43
44fn quick_check(
55 shape_out: &Vec<usize>,
56 layout_in: &Layout<IxD>,
57 order: FlagOrder,
58) -> Result<Option<Layout<IxD>>> {
59 let size_in = layout_in.size();
61 let size_out = shape_out.iter().product();
62 rstsr_assert_eq!(
63 size_in,
64 size_out,
65 InvalidValue,
66 "Size mismatch between input tensor and output tensor.",
67 )?;
68
69 if size_in == 0 || size_in == 1 {
73 let strides = vec![0; shape_out.len()];
74 return Ok(Some(Layout::<IxD>::new(shape_out.clone(), strides, layout_in.offset())?));
75 }
76
77 if shape_out == layout_in.shape() {
79 return Ok(Some(layout_in.clone()));
80 }
81
82 match order {
84 RowMajor => {
85 if layout_in.c_contig() {
86 return Ok(Some(shape_out.new_c_contig(Some(layout_in.offset()))));
87 }
88 },
89 ColMajor => {
90 if layout_in.f_contig() {
91 return Ok(Some(shape_out.new_f_contig(Some(layout_in.offset()))));
92 }
93 },
94 };
95
96 return Ok(None);
98}
99
100fn pop_layout_in(shape_in: &mut Vec<usize>, stride_in: &mut Vec<isize>) -> (usize, isize) {
109 rstsr_assert_eq!(shape_in.len(), stride_in.len(), RuntimeError).unwrap();
110 rstsr_assert!(!shape_in.is_empty(), RuntimeError).unwrap();
111
112 let mut stride_min = stride_in.pop().unwrap();
113 let mut size = shape_in.pop().unwrap();
114
115 if size == 1 || stride_min == 0 {
117 stride_min = 0;
119 while stride_in.last().is_some_and(|&v| v == 0) || shape_in.last().is_some_and(|&v| v == 1)
120 {
121 stride_in.pop();
122 size *= shape_in.pop().unwrap();
123 }
124 return (size, stride_min);
125 } else {
126 while stride_in.last().is_some_and(|&v| v == size as isize * stride_min) {
128 stride_in.pop();
129 size *= shape_in.pop().unwrap();
130 }
131 return (size, stride_min);
132 }
133}
134
/// Distribute one input chunk over the trailing output axes.
///
/// The chunk has `size` elements spaced `stride_min` apart. Axes are popped
/// from the back of `shape_out`, and their strides are pushed onto
/// `stride_out` (innermost axis first), until the chunk is used up. Any
/// length-1 output axes that follow are consumed as well, since their stride
/// is irrelevant.
///
/// Returns `false` when the chunk cannot be split evenly. This happens when an
/// output axis length is zero or does not divide the remaining chunk size, or
/// when the output axes run out before the chunk does. In that case
/// `shape_out` / `stride_out` may already have been partially modified.
///
/// An empty `shape_out` is valid input when `size == 1`. This occurs when
/// leading length-1 input axes are still left after every output axis has
/// been assigned. Example: a row-major `[1, 3]` layout with strides `[5, 2]`
/// reshaped to `[3]`.
fn pop_shape_out(
    shape_out: &mut Vec<usize>,
    stride_out: &mut Vec<isize>,
    mut size: usize,
    mut stride_min: isize,
) -> bool {
    while size != 1 || shape_out.last() == Some(&1) {
        let Some(s_out) = shape_out.pop() else {
            // Output axes exhausted while the chunk still has elements left.
            return false;
        };
        if s_out == 0 || size % s_out != 0 {
            return false;
        }
        size /= s_out;
        stride_out.push(stride_min);
        stride_min *= s_out as isize;
    }
    true
}
161
162fn complicated_reshape(
164 shape_out: &[usize],
165 layout_in: &Layout<IxD>,
166 order: FlagOrder,
167) -> Option<Layout<IxD>> {
168 let shape_out_ref = shape_out; let mut shape_out = shape_out.to_vec(); let mut stride_out = Vec::new();
171 let mut shape_in = layout_in.shape().to_vec();
172 let mut stride_in = layout_in.stride().to_vec();
173 let offset = layout_in.offset();
174
175 if order == FlagOrder::F {
177 shape_in.reverse();
178 stride_in.reverse();
179 shape_out.reverse();
180 }
181
182 while !shape_in.is_empty() {
183 let (size_in, stride_in_min) = pop_layout_in(&mut shape_in, &mut stride_in);
184 if !pop_shape_out(&mut shape_out, &mut stride_out, size_in, stride_in_min) {
185 return None;
186 }
187 }
188 rstsr_assert!(shape_out.is_empty(), RuntimeError).unwrap();
189 rstsr_assert_eq!(stride_out.len(), shape_out_ref.len(), RuntimeError).unwrap();
190 match order {
193 RowMajor => stride_out.reverse(),
194 ColMajor => shape_out.reverse(),
195 };
196
197 let layout_out =
198 unsafe { Layout::<IxD>::new_unchecked(shape_out_ref.to_vec(), stride_out, offset) };
199 return Some(layout_out);
200}
201
202pub fn layout_reshapeable(
211 layout_in: &Layout<IxD>,
212 shape_out: &Vec<usize>,
213 order: FlagOrder,
214) -> Result<Option<Layout<IxD>>> {
215 if let Some(layout_out) = quick_check(shape_out, layout_in, order)? {
216 return Ok(Some(layout_out));
217 }
218 return Ok(complicated_reshape(shape_out, layout_in, order));
219}