1pub mod dtype;
2pub mod indexing;
3pub mod quantization;
4pub mod shape;
5pub mod slice;
6
7pub use dtype::*;
8pub use indexing::*;
9pub use quantization::*;
10pub use shape::*;
11pub use slice::*;
12
13use alloc::{vec, vec::Vec};
14
/// Returns `true` when `strides` describe a contiguous, row-major layout for `shape`.
///
/// A layout is contiguous when the last dimension has stride `1` and every
/// other stride equals the product of the sizes of all dimensions to its
/// right, i.e. the strides produced by [`contiguous_strides`].
///
/// An empty `shape` is always reported as contiguous. For a non-empty `shape`,
/// a `strides` slice whose length differs from `shape` is never contiguous.
///
/// The comparison is exact: a dimension of size `1` must still carry its
/// canonical stride to be accepted.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    // Zipping alone would silently ignore missing or extra strides.
    if shape.len() != strides.len() {
        return false;
    }

    // Walk from the innermost dimension outwards, tracking the stride a
    // row-major layout would have at each position.
    let mut expected = 1;
    for (&dim, &stride) in shape.iter().rev().zip(strides.iter().rev()) {
        if stride != expected {
            return false;
        }
        expected *= dim;
    }

    true
}
35
/// Computes the row-major (contiguous) strides for `shape`.
///
/// The last dimension receives stride `1`; every preceding stride is the
/// product of the sizes of all dimensions after it. An empty shape yields an
/// empty vector.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    let mut running = 1;

    // Fill the slots from the innermost dimension outwards so no final
    // reversal is required.
    for (slot, &dim) in strides.iter_mut().zip(shape).rev() {
        *slot = running;
        running *= dim;
    }

    strides
}
52
/// What has to be done to a tensor's layout in order to apply a reshape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReshapeAction {
    /// The reshape can be expressed by giving the tensor new strides.
    UpdateStrides {
        /// The strides to use together with the new shape.
        strides: Vec<usize>,
    },
    /// New strides alone are not sufficient; the caller has to recompute the
    /// tensor for the new shape.
    Recompute,
    /// The new shape is identical to the current one, nothing has to change.
    NoChange,
}
66
/// Classification of a reshape, describing how the original layout relates to
/// the requested shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReshapeAnalysis {
    /// The original layout is contiguous, so contiguous strides for the new
    /// shape are enough to apply the reshape.
    IsContiguous,
    /// Fallback when the reshape could not be classified as any of the other
    /// cases; the tensor has to be recomputed.
    HighlyPermutated,
    /// The new shape is the original shape preceded by one or more
    /// dimensions of size `1`.
    Broadcasted,
    /// The new shape has fewer dimensions than the original one and the
    /// original layout was not found to be contiguous.
    SmallerRank,
    /// The new shape is identical to the original one.
    NoChange,
}
81
82impl ReshapeAnalysis {
83 fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
85 match self {
86 ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
87 strides: contiguous_strides(shape_new),
88 },
89 ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
90 ReshapeAnalysis::HighlyPermutated | ReshapeAnalysis::SmallerRank => {
91 ReshapeAction::Recompute
92 }
93 ReshapeAnalysis::Broadcasted => {
94 let shape_rank = shape.len();
95 let shape_new_rank = shape_new.len();
96 let n_new_batch = shape_new_rank - shape_rank;
97 let num_elems = shape.iter().product::<usize>();
98 let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
99
100 ReshapeAction::UpdateStrides {
101 strides: strides_new,
102 }
103 }
104 }
105 }
106}
107
108pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
110 reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
111}
112
/// Builds the strides for a shape that gained `n_new_batch` leading dimensions.
///
/// The result has `rank_prev + n_new_batch` entries, all initialised to
/// `num_elems`. The given `strides` are then copied in, starting at index
/// `n_new_batch`; entries that are not overwritten keep the value `num_elems`.
///
/// # Panics
///
/// Panics if `strides` holds more than `rank_prev` entries.
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let mut result = vec![num_elems; rank_prev + n_new_batch];

    // Existing strides go right after the new leading dimensions.
    result[n_new_batch..n_new_batch + strides.len()].copy_from_slice(strides);

    result
}
128
129pub fn reshape_analysis(
131 shape: &[usize],
132 strides: Option<&[usize]>,
133 shape_new: &[usize],
134) -> ReshapeAnalysis {
135 let shape_rank = shape.len();
136 let shape_new_rank = shape_new.len();
137
138 if shape_new_rank < shape_rank {
139 let is_contiguous = match strides {
140 Some(strides) => is_contiguous(shape, strides),
141 None => false,
142 };
143 return match is_contiguous {
144 true => ReshapeAnalysis::IsContiguous,
145 false => ReshapeAnalysis::SmallerRank,
146 };
147 }
148
149 let n_new_batch = shape_new_rank - shape_rank;
150
151 match n_new_batch > 0 {
152 true => {
153 if shape == &shape_new[n_new_batch..shape_new_rank]
154 && shape_new[0..n_new_batch] == vec![1; n_new_batch]
155 {
156 return ReshapeAnalysis::Broadcasted;
157 }
158 }
159
160 false => {
161 if shape == shape_new {
162 return ReshapeAnalysis::NoChange;
163 } else {
164 let is_contiguous = match strides {
165 Some(strides) => is_contiguous(shape, strides),
166 None => false,
167 };
168 if is_contiguous {
169 return ReshapeAnalysis::IsContiguous;
170 }
171 }
172 }
173 };
174
175 ReshapeAnalysis::HighlyPermutated
176}