1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5pub mod id;
11
12pub use cubecl_common::*;
13
14#[cfg(feature = "rayon")]
15pub use rayon;
16
17extern crate alloc;
18
19#[cfg(feature = "network")]
21pub mod network;
22
23pub mod tensor {
25 use alloc::vec;
26 use alloc::vec::Vec;
27
28 pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
36 if shape.is_empty() {
37 return true;
38 }
39
40 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
41 if expected != stride {
42 return false;
43 }
44 }
45
46 true
47 }
48
49 pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
54 let mut strides = Vec::with_capacity(shape.len());
55 let mut current = 1;
56
57 for &dim in shape.iter().rev() {
58 strides.push(current);
59 current *= dim;
60 }
61
62 strides.reverse();
63 strides
64 }
65
    /// Action to take so that a tensor's strides stay valid after a reshape.
    #[derive(Debug)]
    pub enum ReshapeAction {
        /// The underlying data can be kept as-is; only the strides need to
        /// be replaced by the ones provided.
        UpdateStrides {
            /// Strides to use for the reshaped tensor.
            strides: Vec<usize>,
        },
        /// Valid strides cannot be derived from the current layout; the
        /// caller has to recompute the tensor layout (e.g. make the data
        /// contiguous) before reshaping.
        Recompute,
        /// The current shape and strides are already valid; nothing to do.
        NoChange,
    }
79
    /// Result of analyzing how a tensor's current shape and strides relate
    /// to a new shape requested by a reshape (see [`reshape_analysis`]).
    #[derive(Debug)]
    pub enum ReshapeAnalysis {
        /// The current layout is contiguous, so strides for the new shape
        /// can be computed directly.
        IsContiguous,
        /// The layout/shape relationship fits none of the simpler cases;
        /// treated as a non-trivial permutation.
        HighlyPermutated,
        /// The new shape only prepends size-1 batch axes to the old shape.
        Broadcasted,
        /// The new shape has a lower rank than the old one (and the layout
        /// is not known to be contiguous).
        SmallerRank,
        /// The new shape is identical to the old one.
        NoChange,
    }
94
95 impl ReshapeAnalysis {
96 fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
98 match self {
99 ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
100 strides: contiguous_strides(shape_new),
101 },
102 ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
103 ReshapeAnalysis::HighlyPermutated | ReshapeAnalysis::SmallerRank => {
104 ReshapeAction::Recompute
105 }
106 ReshapeAnalysis::Broadcasted => {
107 let shape_rank = shape.len();
108 let shape_new_rank = shape_new.len();
109 let n_new_batch = shape_new_rank - shape_rank;
110 let num_elems = shape.iter().product::<usize>();
111 let strides_new =
112 broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
113
114 ReshapeAction::UpdateStrides {
115 strides: strides_new,
116 }
117 }
118 }
119 }
120 }
121
122 pub fn reshape_action(
124 shape: &[usize],
125 strides: &[usize],
126 shape_new: &[usize],
127 ) -> ReshapeAction {
128 reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
129 }
130
131 pub fn broadcast_strides(
133 n_new_batch: usize,
134 rank_prev: usize,
135 num_elems: usize,
136 strides: &[usize],
137 ) -> Vec<usize> {
138 let mut strides_new = vec![num_elems; rank_prev + n_new_batch];
139
140 for (i, s) in strides.iter().enumerate() {
141 strides_new[i + n_new_batch] = *s;
142 }
143
144 strides_new
145 }
146
147 pub fn reshape_analysis(
149 shape: &[usize],
150 strides: Option<&[usize]>,
151 shape_new: &[usize],
152 ) -> ReshapeAnalysis {
153 let shape_rank = shape.len();
154 let shape_new_rank = shape_new.len();
155
156 if shape_new_rank < shape_rank {
157 let is_contiguous = match strides {
158 Some(strides) => is_contiguous(shape, strides),
159 None => false,
160 };
161 return match is_contiguous {
162 true => ReshapeAnalysis::IsContiguous,
163 false => ReshapeAnalysis::SmallerRank,
164 };
165 }
166
167 let n_new_batch = shape_new_rank - shape_rank;
168
169 match n_new_batch > 0 {
170 true => {
171 if shape == &shape_new[n_new_batch..shape_new_rank]
172 && shape_new[0..n_new_batch] == vec![1; n_new_batch]
173 {
174 return ReshapeAnalysis::Broadcasted;
175 }
176 }
177
178 false => {
179 if shape == shape_new {
180 return ReshapeAnalysis::NoChange;
181 } else {
182 let is_contiguous = match strides {
183 Some(strides) => is_contiguous(shape, strides),
184 None => false,
185 };
186 if is_contiguous {
187 return ReshapeAnalysis::IsContiguous;
188 }
189 }
190 }
191 };
192
193 ReshapeAnalysis::HighlyPermutated
194 }
195}