1pub mod dtype;
2pub mod index_conversion;
3pub mod indexing;
4pub mod quantization;
5pub mod shape;
6pub mod slice;
7
8pub use dtype::*;
9pub use index_conversion::*;
10pub use indexing::*;
11pub use quantization::*;
12pub use shape::*;
13pub use slice::*;
14
15use alloc::{vec, vec::Vec};
16
17pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
25 if shape.is_empty() {
26 return true;
27 }
28
29 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
30 if expected != stride {
31 return false;
32 }
33 }
34
35 true
36}
37
/// Computes row-major (C-order) strides for `shape`.
///
/// The stride of each dimension is the product of the sizes of all dimensions to its
/// right, so the last dimension always has stride `1`. An empty `shape` yields an empty
/// vector.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut out = vec![0; shape.len()];
    let mut running = 1;

    // Walk from the innermost (rightmost) dimension outwards.
    for (slot, &size) in out.iter_mut().zip(shape).rev() {
        *slot = running;
        running *= size;
    }

    out
}
54
/// The concrete action to take in order to perform a reshape, as derived from a
/// [`ReshapeAnalysis`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReshapeAction {
    /// Updating the strides is sufficient to express the new shape.
    UpdateStrides {
        /// The strides to use together with the new shape.
        strides: Vec<usize>,
    },
    /// The current strides cannot express the new shape; the data presumably has to be
    /// rewritten into a new layout.
    Recompute,
    /// The new shape equals the old one; nothing needs to change.
    NoChange,
}
68
/// Classification of a reshape, produced by [`reshape_analysis`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReshapeAnalysis {
    /// The input strides are row-major contiguous, so contiguous strides for the new
    /// shape can be used directly.
    IsContiguous,
    /// The new shape cannot be expressed by adjusting the strides (this is also the
    /// conservative fallback when the analysis cannot prove a cheaper case).
    HighlyPermuted,
    /// Only leading size-1 (batch) dimensions are prepended to the unchanged old shape.
    Broadcasted,
    /// Each old dimension is split into one or more consecutive new dimensions.
    Split,
    /// The new shape has fewer dimensions than the old one.
    SmallerRank,
    /// The new shape is identical to the old one.
    NoChange,
}
85
86impl ReshapeAnalysis {
87 fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
89 match self {
90 ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
91 strides: contiguous_strides(shape_new),
92 },
93 ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
94 ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {
95 ReshapeAction::Recompute
96 }
97 ReshapeAnalysis::Broadcasted => {
98 let shape_rank = shape.len();
99 let shape_new_rank = shape_new.len();
100 let n_new_batch = shape_new_rank - shape_rank;
101 let num_elems = shape.iter().product::<usize>();
102 let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
103
104 ReshapeAction::UpdateStrides {
105 strides: strides_new,
106 }
107 }
108 ReshapeAnalysis::Split => {
109 let strides_new = split_strides(shape, strides, shape_new);
110
111 ReshapeAction::UpdateStrides {
112 strides: strides_new,
113 }
114 }
115 }
116 }
117}
118
119pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
121 reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
122}
123
/// Builds strides for a shape that prepends `n_new_batch` leading dimensions to a
/// tensor of rank `rank_prev`.
///
/// The result has length `rank_prev + n_new_batch`. Every slot starts out as
/// `num_elems`, and `strides` is then copied in starting at index `n_new_batch`. Any
/// slots past the copied strides therefore keep `num_elems`.
///
/// # Panics
///
/// Panics if `strides.len() > rank_prev`.
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let mut out = vec![num_elems; rank_prev + n_new_batch];
    out[n_new_batch..n_new_batch + strides.len()].copy_from_slice(strides);
    out
}
139
/// Computes strides for a reshape that only splits existing dimensions into several
/// consecutive smaller ones (the [`ReshapeAnalysis::Split`] case).
///
/// The new dimensions are walked right to left while a cursor walks the old dimensions
/// right to left. A run of consecutive new dimensions whose product equals the current
/// old dimension reuses that old dimension's layout:
/// - the rightmost dimension of the run inherits the old stride;
/// - each dimension to its left multiplies that stride by the size of the dimension
///   just processed.
///
/// Size-1 dimensions do not affect memory addressing:
/// - New size-1 dimensions receive whatever stride is currently in effect.
/// - Old size-1 dimensions are skipped, so their arbitrary stride is never propagated.
///   Previously a trailing old size-1 dimension (e.g. `[4, 1]` with strides `[1, 4]`
///   reshaped to `[2, 2, 1]`) produced wrong strides.
///
/// An empty `shape` describes a single element, so all returned strides are `1`.
///
/// NOTE(review): assumes the caller has already established that the reshape is a pure
/// split (see [`reshape_analysis`]); other inputs produce meaningless strides.
///
/// # Panics
///
/// Panics if `strides` is shorter than `shape`.
pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Vec<usize> {
    let mut strides_new = vec![1; shape_new.len()];

    // A scalar has a single element: any stride is valid, keep the 1s.
    if shape.is_empty() {
        return strides_new;
    }

    // Moves `idx` leftwards past old size-1 dimensions (never below 0).
    let skip_unit_dims = |mut idx: usize| {
        while idx > 0 && shape[idx] == 1 {
            idx -= 1;
        }
        idx
    };

    let mut old_idx = skip_unit_dims(shape.len() - 1);
    let mut current_stride = strides[old_idx];
    let mut dim_prod = 1;

    for (i, &dim) in shape_new.iter().enumerate().rev() {
        dim_prod *= dim;
        strides_new[i] = current_stride;

        if dim == 1 {
            continue;
        }

        if dim_prod == shape[old_idx] {
            // The current old dimension is fully covered; move on to the next one.
            old_idx = skip_unit_dims(old_idx.saturating_sub(1));
            current_stride = strides[old_idx];
            dim_prod = 1;
        } else {
            // Still inside the same old dimension: the next (outer) new dim steps over `dim`.
            current_stride *= dim;
        }
    }

    strides_new
}
164
165pub fn reshape_analysis(
167 shape: &[usize],
168 strides: Option<&[usize]>,
169 shape_new: &[usize],
170) -> ReshapeAnalysis {
171 let shape_rank = shape.len();
172 let shape_new_rank = shape_new.len();
173
174 let is_contiguous = match strides {
175 Some(strides) => is_contiguous(shape, strides),
176 None => false,
177 };
178
179 if is_contiguous {
180 return ReshapeAnalysis::IsContiguous;
181 }
182
183 if shape_new_rank < shape_rank {
184 return ReshapeAnalysis::SmallerRank;
185 }
186
187 let n_new_batch = shape_new_rank - shape_rank;
188
189 match n_new_batch > 0 {
190 true => {
191 if shape == &shape_new[n_new_batch..shape_new_rank]
192 && shape_new[0..n_new_batch] == vec![1; n_new_batch]
193 {
194 return ReshapeAnalysis::Broadcasted;
195 } else {
196 let mut dim_prod = 1;
197 let mut old_idx = 0;
198 for dim in shape_new {
199 dim_prod *= *dim;
200
201 if *dim == 1 {
205 continue;
206 } else if dim_prod == shape[old_idx] {
207 dim_prod = 1;
208 old_idx += 1;
209 } else if dim_prod > shape[old_idx] {
210 return ReshapeAnalysis::HighlyPermuted;
211 }
212 }
213 return ReshapeAnalysis::Split;
214 }
215 }
216
217 false => {
218 if shape == shape_new {
219 return ReshapeAnalysis::NoChange;
220 }
221 }
222 };
223
224 ReshapeAnalysis::HighlyPermuted
225}