//! Tensor shape, stride, and reshape utilities (`burn_std/tensor/mod.rs`).
pub mod dtype;
pub mod index_conversion;
pub mod indexing;
pub mod quantization;
pub mod shape;
pub mod slice;

pub use dtype::*;
pub use index_conversion::*;
pub use indexing::*;
pub use quantization::*;
pub use shape::*;
pub use slice::*;

use alloc::{vec, vec::Vec};
/// Check if the current tensor is contiguous.
///
/// A tensor is considered contiguous if its elements are stored in memory
/// such that the stride at position `k` is equal to the product of the shapes
/// of all dimensions greater than `k`.
///
/// This means that strides increase as you move from the rightmost to the leftmost dimension.
///
/// A rank-0 (scalar) layout is always contiguous. A shape/strides pair of
/// mismatched lengths is malformed and reported as non-contiguous.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    // Malformed layout: every dimension must have exactly one stride.
    if shape.len() != strides.len() {
        return false;
    }

    // Walk from the innermost dimension outward, tracking the expected
    // row-major stride as a running product. This avoids allocating the
    // full expected-strides vector just to compare it.
    let mut expected = 1;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim;
    }

    true
}
37
/// Computes the strides for a contiguous tensor with the given shape.
///
/// In a contiguous row-major tensor, the stride for each dimension
/// equals the product of all dimension sizes to its right.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1;

    // Build the strides innermost-first, then flip into row-major order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let stride = running;
            running *= dim;
            stride
        })
        .collect();

    strides.reverse();
    strides
}
54
/// The action to take for a reshape operation.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can compare and store
/// the decision without pattern-matching every time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReshapeAction {
    /// Updating the strides is sufficient to handle the reshape.
    UpdateStrides {
        /// The new strides.
        strides: Vec<usize>,
    },
    /// The strides are not compatible, we should recompute the buffer.
    Recompute,
    /// The strides are already correct.
    NoChange,
}
68
/// The reshape kind.
///
/// Derives `Clone`/`PartialEq`/`Eq` so analyses can be compared directly
/// (useful in tests and caller-side dispatch).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReshapeAnalysis {
    /// Original tensor is contiguous, can update the strides.
    IsContiguous,
    /// Original tensor is highly permutated, can't update the strides.
    HighlyPermuted,
    /// Only batch dimensions are added, can update the strides.
    Broadcasted,
    /// Dimensions are only split, can update the strides.
    Split,
    /// Original tensor is bigger than output shape.
    SmallerRank,
    /// New shape is the same.
    NoChange,
}
85
86impl ReshapeAnalysis {
87    /// Returns the proper action to take for the current analysis.
88    fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
89        match self {
90            ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
91                strides: contiguous_strides(shape_new),
92            },
93            ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
94            ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {
95                ReshapeAction::Recompute
96            }
97            ReshapeAnalysis::Broadcasted => {
98                let shape_rank = shape.len();
99                let shape_new_rank = shape_new.len();
100                let n_new_batch = shape_new_rank - shape_rank;
101                let num_elems = shape.iter().product::<usize>();
102                let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
103
104                ReshapeAction::UpdateStrides {
105                    strides: strides_new,
106                }
107            }
108            ReshapeAnalysis::Split => {
109                let strides_new = split_strides(shape, strides, shape_new);
110
111                ReshapeAction::UpdateStrides {
112                    strides: strides_new,
113                }
114            }
115        }
116    }
117}
118
119/// Returns the proper action to take when reshaping a tensor.
120pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
121    reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
122}
123
/// Calculate the new strides given added batch dimensions.
///
/// The `n_new_batch` leading (broadcast) dimensions receive a stride of
/// `num_elems`; the original `strides` are copied into the trailing slots.
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let rank_new = rank_prev + n_new_batch;
    let mut strides_new = vec![num_elems; rank_new];

    // Copy the original strides after the new batch dimensions.
    strides_new[n_new_batch..n_new_batch + strides.len()].copy_from_slice(strides);

    strides_new
}
139
/// Calculate the new strides given added split dimensions.
///
/// Assumes `shape_new` was obtained from `shape` by only splitting dimensions
/// (and possibly inserting unit dims), as detected by `reshape_analysis`.
///
/// NOTE(review): indexes `strides[shape.len() - 1]` up front, so an empty
/// `shape` would underflow/panic — presumably callers guarantee rank >= 1;
/// confirm against call sites.
pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Vec<usize> {
    let mut strides_new = vec![1; shape_new.len()];

    // Walk the original dims from the innermost (rightmost) one outward,
    // tracking which original dim the current group of new dims covers.
    let mut old_idx = shape.len() - 1;
    let mut current_stride = strides[old_idx];
    let mut dim_prod = 1;

    // Iterate the new shape right-to-left; each new dim takes the stride
    // accumulated so far BEFORE the bookkeeping below updates it.
    for (i, dim) in shape_new.iter().enumerate().rev() {
        dim_prod *= *dim;
        strides_new[i] = current_stride;
        if *dim == 1 {
            // Unit dims never change the stride progression; skipping them
            // also avoids spuriously matching `dim_prod` against shape[old_idx].
            continue;
        } else if dim_prod == shape[old_idx] {
            // The split dims seen so far exactly cover the original dim:
            // advance to the next original dim (saturating at 0 so the
            // outermost dim can be revisited safely on the final steps).
            old_idx = old_idx.saturating_sub(1);
            current_stride = strides[old_idx];
            dim_prod = 1;
        } else {
            // Still inside a split group: grow the stride by this factor.
            current_stride *= *dim;
        }
    }

    strides_new
}
164
165/// Returns the analysis of a reshape operation.
166pub fn reshape_analysis(
167    shape: &[usize],
168    strides: Option<&[usize]>,
169    shape_new: &[usize],
170) -> ReshapeAnalysis {
171    let shape_rank = shape.len();
172    let shape_new_rank = shape_new.len();
173
174    let is_contiguous = match strides {
175        Some(strides) => is_contiguous(shape, strides),
176        None => false,
177    };
178
179    if is_contiguous {
180        return ReshapeAnalysis::IsContiguous;
181    }
182
183    if shape_new_rank < shape_rank {
184        return ReshapeAnalysis::SmallerRank;
185    }
186
187    let n_new_batch = shape_new_rank - shape_rank;
188
189    match n_new_batch > 0 {
190        true => {
191            if shape == &shape_new[n_new_batch..shape_new_rank]
192                && shape_new[0..n_new_batch] == vec![1; n_new_batch]
193            {
194                return ReshapeAnalysis::Broadcasted;
195            } else {
196                let mut dim_prod = 1;
197                let mut old_idx = 0;
198                for dim in shape_new {
199                    dim_prod *= *dim;
200
201                    // We need to ignore unit dims because they don't affect analysis and break
202                    // things because they match the default `dim_prod`. If we don't do this,
203                    // reshapes like [2, 3] to [2, 3, 1] will panic from out of bounds access.
204                    if *dim == 1 {
205                        continue;
206                    } else if dim_prod == shape[old_idx] {
207                        dim_prod = 1;
208                        old_idx += 1;
209                    } else if dim_prod > shape[old_idx] {
210                        return ReshapeAnalysis::HighlyPermuted;
211                    }
212                }
213                return ReshapeAnalysis::Split;
214            }
215        }
216
217        false => {
218            if shape == shape_new {
219                return ReshapeAnalysis::NoChange;
220            }
221        }
222    };
223
224    ReshapeAnalysis::HighlyPermuted
225}