1pub mod dtype;
2pub mod quantization;
3pub mod shape;
4pub mod slice;
5
6pub use dtype::*;
7pub use quantization::*;
8pub use shape::*;
9pub use slice::*;
10
11pub use cubecl_zspace::indexing::{self, *};
12pub use cubecl_zspace::{Strides, metadata::Metadata, strides};
13
/// Checks whether `strides` describe a contiguous row-major layout for `shape`.
///
/// A layout is contiguous when the stride of each dimension equals the product of the sizes
/// of all dimensions after it, i.e. the strides match `contiguous_strides(shape)` exactly
/// (size-1 dimensions included).
///
/// An empty `shape` (a scalar) is always considered contiguous. If `strides` does not have
/// one entry per dimension of `shape`, the layout is reported as non-contiguous. Previously,
/// mismatched lengths were silently truncated by `zip` and could be reported as contiguous.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    if shape.len() != strides.len() {
        return false;
    }

    // Walk from the innermost dimension outwards, tracking the stride a contiguous layout
    // would have at each position (the product of all sizes to its right).
    let mut expected = 1;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim;
    }

    true
}
34
35pub fn contiguous_strides(shape: &[usize]) -> Strides {
40 let mut strides = strides![0; shape.len()];
41 let mut current = 1;
42
43 for (i, &dim) in shape.iter().enumerate().rev() {
44 strides[i] = current;
45 current *= dim;
46 }
47
48 strides
49}
50
/// The concrete operation needed to give a tensor a new shape, as returned by
/// [`reshape_action`].
#[derive(Debug)]
pub enum ReshapeAction {
    /// The reshape can be performed by replacing the strides only; no data movement is
    /// requested.
    UpdateStrides {
        /// The strides to use together with the new shape.
        strides: Strides,
    },
    /// The new shape cannot be expressed by a stride update alone.
    /// NOTE(review): presumably the caller must materialize the data in a new layout — confirm.
    Recompute,
    /// The new shape is identical to the current one; nothing needs to change.
    NoChange,
}
64
/// Classification of a reshape, as computed by [`reshape_analysis`].
///
/// Each variant is turned into a [`ReshapeAction`] by `ReshapeAnalysis::action`.
#[derive(Debug)]
pub enum ReshapeAnalysis {
    /// The current strides are contiguous for the current shape, so fresh contiguous strides
    /// can be computed for the new shape.
    IsContiguous,
    /// None of the stride-only rules apply; the reshape requires recomputing the data.
    HighlyPermuted,
    /// The new shape is the old shape with extra leading dimensions of size 1 prepended.
    Broadcasted,
    /// Each non-unit dimension of the old shape is split into consecutive dimensions of the
    /// new shape whose product equals it (size-1 dimensions may appear in between).
    Split,
    /// The new shape has fewer dimensions than the old one (and the layout is not contiguous).
    SmallerRank,
    /// The new shape is identical to the old one.
    NoChange,
}
81
82impl ReshapeAnalysis {
83 fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
85 match self {
86 ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
87 strides: contiguous_strides(shape_new),
88 },
89 ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
90 ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {
91 ReshapeAction::Recompute
92 }
93 ReshapeAnalysis::Broadcasted => {
94 let shape_rank = shape.len();
95 let shape_new_rank = shape_new.len();
96 let n_new_batch = shape_new_rank - shape_rank;
97 let num_elems = shape.iter().product::<usize>();
98 let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
99
100 ReshapeAction::UpdateStrides {
101 strides: strides_new,
102 }
103 }
104 ReshapeAnalysis::Split => {
105 let strides_new = split_strides(shape, strides, shape_new);
106
107 ReshapeAction::UpdateStrides {
108 strides: strides_new,
109 }
110 }
111 }
112 }
113}
114
115pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
117 reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
118}
119
120pub fn broadcast_strides(
122 n_new_batch: usize,
123 rank_prev: usize,
124 num_elems: usize,
125 strides: &[usize],
126) -> Strides {
127 let mut strides_new = strides![num_elems; rank_prev + n_new_batch];
128
129 for (i, s) in strides.iter().enumerate() {
130 strides_new[i + n_new_batch] = *s;
131 }
132
133 strides_new
134}
135
136pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides {
138 let mut strides_new = strides![1; shape_new.len()];
139
140 let mut old_idx = shape.len() - 1;
141 let mut current_stride = strides[old_idx];
142 let mut dim_prod = 1;
143
144 for (i, dim) in shape_new.iter().enumerate().rev() {
145 dim_prod *= *dim;
146 strides_new[i] = current_stride;
147 if *dim == 1 {
148 continue;
149 } else if dim_prod == shape[old_idx] {
150 old_idx = old_idx.saturating_sub(1);
151 current_stride = strides[old_idx];
152 dim_prod = 1;
153 } else {
154 current_stride *= *dim;
155 }
156 }
157
158 strides_new
159}
160
161pub fn reshape_analysis(
163 shape: &[usize],
164 strides: Option<&[usize]>,
165 shape_new: &[usize],
166) -> ReshapeAnalysis {
167 let shape_rank = shape.len();
168 let shape_new_rank = shape_new.len();
169
170 let is_contiguous = match strides {
171 Some(strides) => is_contiguous(shape, strides),
172 None => false,
173 };
174
175 if is_contiguous {
176 return ReshapeAnalysis::IsContiguous;
177 }
178
179 if shape_new_rank < shape_rank {
180 return ReshapeAnalysis::SmallerRank;
181 }
182
183 let n_new_batch = shape_new_rank - shape_rank;
184
185 match n_new_batch > 0 {
186 true => {
187 if shape == &shape_new[n_new_batch..shape_new_rank]
188 && shape_new[0..n_new_batch].iter().all(|it| *it == 1)
189 {
190 return ReshapeAnalysis::Broadcasted;
191 } else {
192 let mut dim_prod = 1;
193 let mut old_idx = 0;
194 for dim in shape_new {
195 dim_prod *= *dim;
196
197 if *dim == 1 {
201 continue;
202 } else if dim_prod == shape[old_idx] {
203 dim_prod = 1;
204 old_idx += 1;
205 } else if dim_prod > shape[old_idx] {
206 return ReshapeAnalysis::HighlyPermuted;
207 }
208 }
209 return ReshapeAnalysis::Split;
210 }
211 }
212
213 false => {
214 if shape == shape_new {
215 return ReshapeAnalysis::NoChange;
216 }
217 }
218 };
219
220 ReshapeAnalysis::HighlyPermuted
221}