burn_std/tensor/
mod.rs

pub mod dtype;
pub mod index_conversion;
pub mod indexing;
pub mod quantization;
pub mod shape;
pub mod slice;

pub use dtype::*;
pub use index_conversion::*;
pub use indexing::*;
pub use quantization::*;
pub use shape::*;
pub use slice::*;

use alloc::{vec, vec::Vec};

/// Check if the current tensor is contiguous.
///
/// A tensor is considered contiguous if its elements are stored in memory
/// such that the stride at position `k` is equal to the product of the shapes
/// of all dimensions greater than `k`.
///
/// This means that strides never decrease as you move from the rightmost to the leftmost dimension.
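///
/// A minimal usage sketch (the `burn_std::tensor` import path is an assumption
/// based on this file's location):
///
/// ```
/// use burn_std::tensor::is_contiguous;
///
/// // A [2, 3] row-major tensor has strides [3, 1] and is contiguous.
/// assert!(is_contiguous(&[2, 3], &[3, 1]));
/// // Its transposed view, with strides [1, 3], is not.
/// assert!(!is_contiguous(&[2, 3], &[1, 3]));
/// ```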
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

/// Computes the strides for a contiguous tensor with the given shape.
///
/// In a contiguous row-major tensor, the stride for each dimension
/// equals the product of all dimension sizes to its right.
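///
/// A minimal sketch (the `burn_std::tensor` import path is an assumption based
/// on this file's location):
///
/// ```
/// use burn_std::tensor::contiguous_strides;
///
/// // For a [2, 3, 4] tensor, stepping along dim 0 skips 3 * 4 = 12 elements,
/// // along dim 1 skips 4, and along dim 2 skips 1.
/// assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
/// ```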
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut current = 1;

    for &dim in shape.iter().rev() {
        strides.push(current);
        current *= dim;
    }

    strides.reverse();
    strides
}

/// The action to take for a reshape operation.
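///
/// A sketch of how a caller might dispatch on the action; the handling in each
/// arm is a placeholder, not part of this crate, and the `burn_std::tensor`
/// import path is an assumption based on this file's location:
///
/// ```
/// use burn_std::tensor::{reshape_action, ReshapeAction};
///
/// // Flattening a contiguous [2, 3] tensor into [6] only needs new strides.
/// match reshape_action(&[2, 3], &[3, 1], &[6]) {
///     // Reuse the existing buffer with updated layout metadata.
///     ReshapeAction::UpdateStrides { strides } => assert_eq!(strides, vec![1]),
///     // The layout is incompatible: copy into a contiguous buffer first.
///     ReshapeAction::Recompute => unreachable!(),
///     // Shape and strides already match the request: nothing to do.
///     ReshapeAction::NoChange => unreachable!(),
/// }
/// ```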
#[derive(Debug)]
pub enum ReshapeAction {
    /// Updating the strides is sufficient to handle the reshape.
    UpdateStrides {
        /// The new strides.
        strides: Vec<usize>,
    },
    /// The strides are not compatible; the buffer must be recomputed.
    Recompute,
    /// The strides are already correct.
    NoChange,
}

/// The kind of reshape, as determined by [`reshape_analysis`].
#[derive(Debug)]
pub enum ReshapeAnalysis {
    /// Original tensor is contiguous; the strides can simply be updated.
    IsContiguous,
    /// Original tensor is highly permuted; the strides can't simply be updated.
    HighlyPermutated,
    /// Only batch dimensions are added; the strides can be updated.
    Broadcasted,
    /// New shape has a smaller rank than the original tensor.
    SmallerRank,
    /// New shape is the same as the original.
    NoChange,
}

impl ReshapeAnalysis {
    /// Returns the proper action to take for the current analysis.
    fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
        match self {
            ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
                strides: contiguous_strides(shape_new),
            },
            ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
            ReshapeAnalysis::HighlyPermutated | ReshapeAnalysis::SmallerRank => {
                ReshapeAction::Recompute
            }
            ReshapeAnalysis::Broadcasted => {
                let shape_rank = shape.len();
                let shape_new_rank = shape_new.len();
                let n_new_batch = shape_new_rank - shape_rank;
                let num_elems = shape.iter().product::<usize>();
                let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);

                ReshapeAction::UpdateStrides {
                    strides: strides_new,
                }
            }
        }
    }
}

/// Returns the proper action to take when reshaping a tensor.
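///
/// A minimal sketch (the `burn_std::tensor` import path is an assumption based
/// on this file's location):
///
/// ```
/// use burn_std::tensor::{reshape_action, ReshapeAction};
///
/// // Adding a leading batch dimension to a [2, 3] tensor is a pure stride
/// // update: the new dimension gets stride 6 (the number of elements) and the
/// // original strides are kept.
/// match reshape_action(&[2, 3], &[3, 1], &[1, 2, 3]) {
///     ReshapeAction::UpdateStrides { strides } => assert_eq!(strides, vec![6, 3, 1]),
///     other => panic!("unexpected action: {other:?}"),
/// }
/// ```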
pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
    reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
}

/// Calculate the new strides given added batch dimensions.
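///
/// A minimal sketch (the `burn_std::tensor` import path is an assumption based
/// on this file's location):
///
/// ```
/// use burn_std::tensor::broadcast_strides;
///
/// // A rank-2 tensor with strides [3, 1] and 6 elements gains two leading
/// // batch dimensions; each new dimension gets a stride equal to the total
/// // number of elements, and the original strides are preserved.
/// assert_eq!(broadcast_strides(2, 2, 6, &[3, 1]), vec![6, 6, 3, 1]);
/// ```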
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let mut strides_new = vec![num_elems; rank_prev + n_new_batch];

    for (i, s) in strides.iter().enumerate() {
        strides_new[i + n_new_batch] = *s;
    }

    strides_new
}

/// Returns the analysis of a reshape operation.
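///
/// A minimal sketch (the `burn_std::tensor` import path is an assumption based
/// on this file's location):
///
/// ```
/// use burn_std::tensor::{reshape_analysis, ReshapeAnalysis};
///
/// // Flattening a contiguous [2, 3] tensor only needs new strides.
/// assert!(matches!(
///     reshape_analysis(&[2, 3], Some(&[3, 1]), &[6]),
///     ReshapeAnalysis::IsContiguous
/// ));
///
/// // The same reshape on a transposed view (strides [1, 3]) cannot be done
/// // with a stride update alone.
/// assert!(matches!(
///     reshape_analysis(&[2, 3], Some(&[1, 3]), &[6]),
///     ReshapeAnalysis::SmallerRank
/// ));
/// ```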
pub fn reshape_analysis(
    shape: &[usize],
    strides: Option<&[usize]>,
    shape_new: &[usize],
) -> ReshapeAnalysis {
    let shape_rank = shape.len();
    let shape_new_rank = shape_new.len();

    if shape_new_rank < shape_rank {
        let is_contiguous = match strides {
            Some(strides) => is_contiguous(shape, strides),
            None => false,
        };
        return match is_contiguous {
            true => ReshapeAnalysis::IsContiguous,
            false => ReshapeAnalysis::SmallerRank,
        };
    }

    let n_new_batch = shape_new_rank - shape_rank;

    match n_new_batch > 0 {
        true => {
            if shape == &shape_new[n_new_batch..shape_new_rank]
                && shape_new[0..n_new_batch] == vec![1; n_new_batch]
            {
                return ReshapeAnalysis::Broadcasted;
            }
        }

        false => {
            if shape == shape_new {
                return ReshapeAnalysis::NoChange;
            } else {
                let is_contiguous = match strides {
                    Some(strides) => is_contiguous(shape, strides),
                    None => false,
                };
                if is_contiguous {
                    return ReshapeAnalysis::IsContiguous;
                }
            }
        }
    };

    ReshapeAnalysis::HighlyPermutated
}