1pub mod dtype;
2pub mod index_conversion;
3pub mod indexing;
4pub mod quantization;
5pub mod shape;
6pub mod slice;
7
8pub use dtype::*;
9pub use index_conversion::*;
10pub use indexing::*;
11pub use quantization::*;
12pub use shape::*;
13pub use slice::*;
14
15use alloc::{vec, vec::Vec};
16
17pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
25 if shape.is_empty() {
26 return true;
27 }
28
29 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
30 if expected != stride {
31 return false;
32 }
33 }
34
35 true
36}
37
/// Computes row-major (C-order) strides for a densely packed tensor of the given `shape`.
///
/// The last dimension gets stride 1, and every earlier dimension's stride is the product of
/// all later dimension sizes. An empty shape yields an empty vector.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut running = 1;

    // Walk from the innermost dimension outwards, accumulating the element count.
    for (stride, &dim) in strides.iter_mut().zip(shape).rev() {
        *stride = running;
        running *= dim;
    }

    strides
}
54
/// What to do in order to apply a reshape, as derived from a [`ReshapeAnalysis`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReshapeAction {
    /// The data layout can stay as is; only the strides need to be replaced.
    UpdateStrides {
        /// Strides to use for the new shape.
        strides: Vec<usize>,
    },
    /// The strides cannot simply be updated; the layout has to be recomputed
    /// (presumably by copying into a fresh contiguous buffer — confirm at call sites).
    Recompute,
    /// The new shape equals the current one; nothing needs to change.
    NoChange,
}
68
/// How a requested reshape relates to a tensor's current shape and strides.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReshapeAnalysis {
    /// The input strides are contiguous, so the new shape can use fresh contiguous strides.
    IsContiguous,
    /// Fallback when none of the other cases apply; the layout has to be recomputed.
    HighlyPermutated,
    /// The new shape is the old shape with only size-1 dimensions prepended.
    Broadcasted,
    /// The new rank is lower and the input is not known to be contiguous.
    SmallerRank,
    /// The new shape is identical to the old one.
    NoChange,
}
83
84impl ReshapeAnalysis {
85 fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
87 match self {
88 ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
89 strides: contiguous_strides(shape_new),
90 },
91 ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
92 ReshapeAnalysis::HighlyPermutated | ReshapeAnalysis::SmallerRank => {
93 ReshapeAction::Recompute
94 }
95 ReshapeAnalysis::Broadcasted => {
96 let shape_rank = shape.len();
97 let shape_new_rank = shape_new.len();
98 let n_new_batch = shape_new_rank - shape_rank;
99 let num_elems = shape.iter().product::<usize>();
100 let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
101
102 ReshapeAction::UpdateStrides {
103 strides: strides_new,
104 }
105 }
106 }
107 }
108}
109
110pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
112 reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
113}
114
/// Builds strides for a tensor of rank `rank_prev` after `n_new_batch` leading dimensions
/// have been prepended.
///
/// Every position starts out as `num_elems`. The entries of `strides` are then copied in
/// order starting right after the new leading dimensions. Positions not covered by
/// `strides` keep `num_elems`.
///
/// # Panics
///
/// Panics if `strides.len() > rank_prev`.
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let total_rank = rank_prev + n_new_batch;
    let mut out = vec![num_elems; total_rank];

    let window = n_new_batch..n_new_batch + strides.len();
    out[window].copy_from_slice(strides);

    out
}
130
131pub fn reshape_analysis(
133 shape: &[usize],
134 strides: Option<&[usize]>,
135 shape_new: &[usize],
136) -> ReshapeAnalysis {
137 let shape_rank = shape.len();
138 let shape_new_rank = shape_new.len();
139
140 if shape_new_rank < shape_rank {
141 let is_contiguous = match strides {
142 Some(strides) => is_contiguous(shape, strides),
143 None => false,
144 };
145 return match is_contiguous {
146 true => ReshapeAnalysis::IsContiguous,
147 false => ReshapeAnalysis::SmallerRank,
148 };
149 }
150
151 let n_new_batch = shape_new_rank - shape_rank;
152
153 match n_new_batch > 0 {
154 true => {
155 if shape == &shape_new[n_new_batch..shape_new_rank]
156 && shape_new[0..n_new_batch] == vec![1; n_new_batch]
157 {
158 return ReshapeAnalysis::Broadcasted;
159 }
160 }
161
162 false => {
163 if shape == shape_new {
164 return ReshapeAnalysis::NoChange;
165 } else {
166 let is_contiguous = match strides {
167 Some(strides) => is_contiguous(shape, strides),
168 None => false,
169 };
170 if is_contiguous {
171 return ReshapeAnalysis::IsContiguous;
172 }
173 }
174 }
175 };
176
177 ReshapeAnalysis::HighlyPermutated
178}