rstsr_common/layout/
reshape.rs1use crate::prelude_dev::*;
4
5pub fn reshape_substitute_negatives(shape_out: &[isize], size_in: usize) -> Result<Vec<usize>> {
11 let mut shape = shape_out.to_vec();
12
13 let mut idx_neg1: Option<usize> = None;
15 for (i, &v) in shape.iter().enumerate() {
16 match v {
17 -1 => match idx_neg1 {
18 Some(_) => rstsr_raise!(InvalidValue, "Only one -1 is allowed in shape.")?,
19 None => idx_neg1 = Some(i),
20 },
21 ..-1 => {
22 rstsr_raise!(InvalidValue, "Negative index must be -1.")?;
23 },
24 _ => (),
25 }
26 }
27
28 if let Some(idx_neg1) = idx_neg1 {
30 let size_in = size_in as isize;
31 let size_neg = shape.iter().fold(1, |acc, &v| if v == -1 { acc } else { acc * v });
32 rstsr_assert!(
33 size_in % size_neg == 0,
34 InvalidValue,
35 "Shape '-1' in {:?} could not be determined to original tensor size {:?}",
36 shape,
37 size_in
38 )?;
39 shape[idx_neg1] = size_in / size_neg;
40 }
41 return Ok(shape.iter().map(|&v| v as usize).collect::<Vec<usize>>());
42}
43
44fn quick_check(shape_out: &Vec<usize>, layout_in: &Layout<IxD>, order: FlagOrder) -> Result<Option<Layout<IxD>>> {
55 let size_in = layout_in.size();
57 let size_out = shape_out.iter().product();
58 rstsr_assert_eq!(size_in, size_out, InvalidValue, "Size mismatch between input tensor and output tensor.",)?;
59
60 if size_in == 0 || size_in == 1 {
64 let strides = vec![0; shape_out.len()];
65 return Ok(Some(Layout::<IxD>::new(shape_out.clone(), strides, layout_in.offset())?));
66 }
67
68 if shape_out == layout_in.shape() {
70 return Ok(Some(layout_in.clone()));
71 }
72
73 match order {
75 RowMajor => {
76 if layout_in.c_contig() {
77 return Ok(Some(shape_out.new_c_contig(Some(layout_in.offset()))));
78 }
79 },
80 ColMajor => {
81 if layout_in.f_contig() {
82 return Ok(Some(shape_out.new_f_contig(Some(layout_in.offset()))));
83 }
84 },
85 };
86
87 return Ok(None);
89}
90
/// Pop the innermost group of input axes that can be traversed with a single stride.
///
/// Axes are consumed from the back of `shape_in` / `stride_in`; the last axis is
/// treated as the innermost one. Two kinds of group are recognised:
///
/// * **broadcast group** - started by an axis of extent 1 or stride 0. Every
///   following axis of extent 1 or stride 0 is merged in, and the group stride is 0.
/// * **strided group** - following axes are merged while their stride equals
///   `size * stride_min`, i.e. they continue the same arithmetic progression.
///   Axes of extent 1 are merged too, because their stride never affects which
///   element is addressed.
///
/// Returns `(size, stride_min)`: the number of elements in the group and the
/// stride between consecutive elements of it.
///
/// # Panics
///
/// If `shape_in` and `stride_in` differ in length, or if they are empty.
fn pop_layout_in(shape_in: &mut Vec<usize>, stride_in: &mut Vec<isize>) -> (usize, isize) {
    assert_eq!(shape_in.len(), stride_in.len(), "shape and stride must have the same length");
    assert!(!shape_in.is_empty(), "no input axes left to pop");

    let mut stride_min = stride_in.pop().expect("non-empty: asserted above");
    let mut size = shape_in.pop().expect("non-empty: asserted above");

    if size == 1 || stride_min == 0 {
        // Broadcast group: the stride is irrelevant (single element) or zero.
        stride_min = 0;
        while stride_in.last().is_some_and(|&v| v == 0) || shape_in.last().is_some_and(|&v| v == 1) {
            stride_in.pop();
            size *= shape_in.pop().expect("lengths are equal");
        }
    } else {
        // Strided group: absorb unit axes and axes continuing the progression.
        loop {
            let (Some(&extent), Some(&stride)) = (shape_in.last(), stride_in.last()) else {
                break;
            };
            if extent != 1 && stride != size as isize * stride_min {
                break;
            }
            shape_in.pop();
            stride_in.pop();
            size *= extent;
        }
    }
    (size, stride_min)
}
123
/// Split one input group of `size` elements (stride `stride_min`) over the
/// innermost remaining output axes.
///
/// Output extents are popped from the back of `shape_out`. For each popped extent,
/// the stride it receives is pushed onto `stride_out`, and the running stride is
/// multiplied by that extent. Once the group is exhausted, trailing output axes of
/// extent 1 are consumed as well (their stride is irrelevant).
///
/// A group of size 1 arriving when `shape_out` is already empty is a no-op. This
/// happens when the input has extra unit axes beyond the output's rank.
///
/// Returns `false` when the group cannot be split along output-axis boundaries:
/// an output extent does not divide the remaining group size, or the output axes
/// run out while elements remain. In that case the reshape is not expressible as
/// a strided view.
fn pop_shape_out(
    shape_out: &mut Vec<usize>,
    stride_out: &mut Vec<isize>,
    mut size: usize,
    mut stride_min: isize,
) -> bool {
    while size != 1 || shape_out.last().is_some_and(|&v| v == 1) {
        // Running out of output axes with elements left over cannot be mapped;
        // report it rather than panicking.
        let Some(s_out) = shape_out.pop() else {
            return false;
        };
        // A zero extent would make the modulo divide by zero.
        if s_out == 0 || size % s_out != 0 {
            return false;
        }
        size /= s_out;
        stride_out.push(stride_min);
        stride_min *= s_out as isize;
    }
    true
}
150
151fn complicated_reshape(shape_out: &[usize], layout_in: &Layout<IxD>, order: FlagOrder) -> Option<Layout<IxD>> {
153 let shape_out_ref = shape_out; let mut shape_out = shape_out.to_vec(); let mut stride_out = Vec::new();
156 let mut shape_in = layout_in.shape().to_vec();
157 let mut stride_in = layout_in.stride().to_vec();
158 let offset = layout_in.offset();
159
160 if order == FlagOrder::F {
162 shape_in.reverse();
163 stride_in.reverse();
164 shape_out.reverse();
165 }
166
167 while !shape_in.is_empty() {
168 let (size_in, stride_in_min) = pop_layout_in(&mut shape_in, &mut stride_in);
169 if !pop_shape_out(&mut shape_out, &mut stride_out, size_in, stride_in_min) {
170 return None;
171 }
172 }
173 rstsr_assert!(shape_out.is_empty(), RuntimeError).unwrap();
174 rstsr_assert_eq!(stride_out.len(), shape_out_ref.len(), RuntimeError).unwrap();
175 match order {
178 RowMajor => stride_out.reverse(),
179 ColMajor => shape_out.reverse(),
180 };
181
182 let layout_out = unsafe { Layout::<IxD>::new_unchecked(shape_out_ref.to_vec(), stride_out, offset) };
183 return Some(layout_out);
184}
185
186pub fn layout_reshapeable(
195 layout_in: &Layout<IxD>,
196 shape_out: &Vec<usize>,
197 order: FlagOrder,
198) -> Result<Option<Layout<IxD>>> {
199 if let Some(layout_out) = quick_check(shape_out, layout_in, order)? {
200 return Ok(Some(layout_out));
201 }
202 return Ok(complicated_reshape(shape_out, layout_in, order));
203}