//! `burn_std/tensor`: tensor metadata utilities — dtypes, quantization,
//! shapes, slicing — plus stride/reshape analysis helpers.

pub mod dtype;
pub mod quantization;
pub mod shape;
pub mod slice;

pub use dtype::*;
pub use quantization::*;
pub use shape::*;
pub use slice::*;

pub use cubecl_zspace::indexing::{self, *};
pub use cubecl_zspace::{Strides, metadata::Metadata, strides};
13
14/// Check if the current tensor is contiguous.
15///
16/// A tensor is considered contiguous if its elements are stored in memory
17/// such that the stride at position `k` is equal to the product of the shapes
18/// of all dimensions greater than `k`.
19///
20/// This means that strides increase as you move from the rightmost to the leftmost dimension.
21pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
22    if shape.is_empty() {
23        return true;
24    }
25
26    for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) {
27        if expected != stride {
28            return false;
29        }
30    }
31
32    true
33}
34
35/// Computes the strides for a contiguous tensor with the given shape.
36///
37/// In a contiguous row-major tensor, the stride for each dimension
38/// equals the product of all dimension sizes to its right.
39pub fn contiguous_strides(shape: &[usize]) -> Strides {
40    let mut strides = strides![0; shape.len()];
41    let mut current = 1;
42
43    for (i, &dim) in shape.iter().enumerate().rev() {
44        strides[i] = current;
45        current *= dim;
46    }
47
48    strides
49}
50
/// The action to take for a reshape operation.
///
/// Returned by `ReshapeAnalysis::action` / `reshape_action`; tells the caller
/// whether the reshape can be expressed as a pure stride (metadata) change or
/// requires materializing a new buffer.
#[derive(Debug)]
pub enum ReshapeAction {
    /// Updating the strides is sufficient to handle the reshape.
    UpdateStrides {
        /// The new strides.
        strides: Strides,
    },
    /// The strides are not compatible, we should recompute the buffer.
    Recompute,
    /// The strides are already correct.
    NoChange,
}
64
/// The reshape kind.
///
/// Classification of how `shape_new` relates to the original `shape`;
/// each variant maps to a `ReshapeAction` via `ReshapeAnalysis::action`.
#[derive(Debug)]
pub enum ReshapeAnalysis {
    /// Original tensor is contiguous, can update the strides.
    IsContiguous,
    /// Original tensor is highly permuted, can't update the strides.
    HighlyPermuted,
    /// Only batch dimensions are added, can update the strides.
    Broadcasted,
    /// Dimensions are only split, can update the strides.
    Split,
    /// Original tensor is bigger than output shape.
    SmallerRank,
    /// New shape is the same.
    NoChange,
}
81
82impl ReshapeAnalysis {
83    /// Returns the proper action to take for the current analysis.
84    fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
85        match self {
86            ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
87                strides: contiguous_strides(shape_new),
88            },
89            ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
90            ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {
91                ReshapeAction::Recompute
92            }
93            ReshapeAnalysis::Broadcasted => {
94                let shape_rank = shape.len();
95                let shape_new_rank = shape_new.len();
96                let n_new_batch = shape_new_rank - shape_rank;
97                let num_elems = shape.iter().product::<usize>();
98                let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
99
100                ReshapeAction::UpdateStrides {
101                    strides: strides_new,
102                }
103            }
104            ReshapeAnalysis::Split => {
105                let strides_new = split_strides(shape, strides, shape_new);
106
107                ReshapeAction::UpdateStrides {
108                    strides: strides_new,
109                }
110            }
111        }
112    }
113}
114
115/// Returns the proper action to take when reshaping a tensor.
116pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
117    reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
118}
119
120/// Calculate the new strides given added batch dimensions.
121pub fn broadcast_strides(
122    n_new_batch: usize,
123    rank_prev: usize,
124    num_elems: usize,
125    strides: &[usize],
126) -> Strides {
127    let mut strides_new = strides![num_elems; rank_prev + n_new_batch];
128
129    for (i, s) in strides.iter().enumerate() {
130        strides_new[i + n_new_batch] = *s;
131    }
132
133    strides_new
134}
135
/// Calculate the new strides given added split dimensions.
///
/// `shape_new` is expected to decompose each dimension of `shape` into one or
/// more consecutive dimensions (possibly interleaved with unit dims) whose
/// sizes multiply back to the original — the layout accepted by the `Split`
/// analysis in `reshape_analysis`.
///
/// NOTE(review): `shape.len() - 1` underflows for an empty `shape`; callers
/// presumably never pass rank-0 here — confirm.
pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides {
    let mut strides_new = strides![1; shape_new.len()];

    // Walk right-to-left, tracking which original dimension is currently
    // being split (`old_idx`), the stride of the next piece to emit
    // (`current_stride`), and the product of new dims consumed so far within
    // the current original dimension (`dim_prod`).
    let mut old_idx = shape.len() - 1;
    let mut current_stride = strides[old_idx];
    let mut dim_prod = 1;

    for (i, dim) in shape_new.iter().enumerate().rev() {
        dim_prod *= *dim;
        strides_new[i] = current_stride;
        if *dim == 1 {
            // Unit dims take the current stride but don't advance the split.
            continue;
        } else if dim_prod == shape[old_idx] {
            // The current original dimension is fully covered: move one
            // dimension left (saturating at 0 once the leftmost is reached).
            old_idx = old_idx.saturating_sub(1);
            current_stride = strides[old_idx];
            dim_prod = 1;
        } else {
            // Still inside the same original dimension: the next (left) split
            // piece steps over `dim` elements of the current piece.
            current_stride *= *dim;
        }
    }

    strides_new
}
160
/// Returns the analysis of a reshape operation.
///
/// `strides` may be `None` when the layout is unknown; the tensor is then
/// treated as non-contiguous and only the structural checks below apply.
pub fn reshape_analysis(
    shape: &[usize],
    strides: Option<&[usize]>,
    shape_new: &[usize],
) -> ReshapeAnalysis {
    let shape_rank = shape.len();
    let shape_new_rank = shape_new.len();

    let is_contiguous = match strides {
        Some(strides) => is_contiguous(shape, strides),
        None => false,
    };

    // A contiguous buffer can be reshaped to anything by rewriting strides.
    if is_contiguous {
        return ReshapeAnalysis::IsContiguous;
    }

    // Rank reduction (merging dimensions) is not analyzed further here.
    if shape_new_rank < shape_rank {
        return ReshapeAnalysis::SmallerRank;
    }

    let n_new_batch = shape_new_rank - shape_rank;

    match n_new_batch > 0 {
        true => {
            // Pure broadcast: trailing dims match exactly and every added
            // leading dim is unit-sized.
            if shape == &shape_new[n_new_batch..shape_new_rank]
                && shape_new[0..n_new_batch].iter().all(|it| *it == 1)
            {
                return ReshapeAnalysis::Broadcasted;
            } else {
                // Otherwise check whether `shape_new` only splits original
                // dims: accumulate new dims left-to-right until their product
                // reaches the current original dim, then advance.
                let mut dim_prod = 1;
                let mut old_idx = 0;
                for dim in shape_new {
                    dim_prod *= *dim;

                    // We need to ignore unit dims because they don't affect analysis and break
                    // things because they match the default `dim_prod`. If we don't do this,
                    // reshapes like [2, 3] to [2, 3, 1] will panic from out of bounds access.
                    if *dim == 1 {
                        continue;
                    } else if dim_prod == shape[old_idx] {
                        dim_prod = 1;
                        old_idx += 1;
                    } else if dim_prod > shape[old_idx] {
                        // The accumulated product overshoots the original dim:
                        // this is a merge/permute, not a clean split.
                        return ReshapeAnalysis::HighlyPermuted;
                    }
                }
                return ReshapeAnalysis::Split;
            }
        }

        false => {
            if shape == shape_new {
                return ReshapeAnalysis::NoChange;
            }
        }
    };

    // Same rank but different shape on a non-contiguous tensor: recompute.
    ReshapeAnalysis::HighlyPermuted
}