burn_std/tensor/
mod.rs

1pub mod dtype;
2pub mod indexing;
3pub mod quantization;
4pub mod shape;
5pub mod slice;
6
7pub use dtype::*;
8pub use indexing::*;
9pub use quantization::*;
10pub use shape::*;
11pub use slice::*;
12
13use alloc::{vec, vec::Vec};
14
/// Checks whether `strides` describe a contiguous (row-major) layout for `shape`.
///
/// A layout is contiguous when the stride of every dimension `k` equals the product of the
/// sizes of all dimensions after `k`; in particular the last dimension has stride 1.
///
/// An empty `shape` (a scalar) is always considered contiguous. For a non-empty `shape`,
/// `strides` must contain exactly one entry per dimension, otherwise the layout is reported
/// as not contiguous.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    // `zip` would silently stop at the shorter slice and accept e.g. shape [2, 3] with
    // strides [3]; a rank mismatch can never describe a valid contiguous layout.
    if shape.len() != strides.len() {
        return false;
    }

    // Walk from the last dimension towards the first, accumulating the expected stride
    // instead of allocating the full vector of contiguous strides.
    let mut expected = 1;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim;
    }

    true
}
35
/// Returns the row-major strides of a densely packed tensor with the given `shape`.
///
/// Each entry is the product of every dimension size to its right, so the last dimension
/// always has stride 1 and an empty shape yields an empty vector.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut out: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1, |running, &dim| {
            let stride = *running;
            *running *= dim;
            Some(stride)
        })
        .collect();
    // The strides were produced starting from the last dimension; flip them back into
    // dimension order.
    out.reverse();
    out
}
52
/// What a reshape needs to do to the underlying tensor.
#[derive(Debug)]
pub enum ReshapeAction {
    /// The data can stay where it is; only the strides change.
    UpdateStrides {
        /// Strides to use for the reshaped tensor.
        strides: Vec<usize>,
    },
    /// The current strides cannot express the new shape, so the buffer must be rebuilt.
    Recompute,
    /// Neither the data nor the strides need to change.
    NoChange,
}
66
/// Classification of a reshape request, produced by `reshape_analysis`.
#[derive(Debug)]
pub enum ReshapeAnalysis {
    /// The source tensor is contiguous, so contiguous strides for the new shape can be used.
    IsContiguous,
    /// Fallback when no cheaper strategy was identified; the data must be recomputed.
    HighlyPermutated,
    /// Only leading size-1 (batch) dimensions are prepended; the existing strides are kept.
    Broadcasted,
    /// The new shape has fewer dimensions and the tensor is not known to be contiguous.
    SmallerRank,
    /// The new shape is identical to the current one.
    NoChange,
}
81
82impl ReshapeAnalysis {
83    /// Returns the proper action to take for the current analysis.
84    fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
85        match self {
86            ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
87                strides: contiguous_strides(shape_new),
88            },
89            ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
90            ReshapeAnalysis::HighlyPermutated | ReshapeAnalysis::SmallerRank => {
91                ReshapeAction::Recompute
92            }
93            ReshapeAnalysis::Broadcasted => {
94                let shape_rank = shape.len();
95                let shape_new_rank = shape_new.len();
96                let n_new_batch = shape_new_rank - shape_rank;
97                let num_elems = shape.iter().product::<usize>();
98                let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);
99
100                ReshapeAction::UpdateStrides {
101                    strides: strides_new,
102                }
103            }
104        }
105    }
106}
107
108/// Returns the proper action to take when reshaping a tensor.
109pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
110    reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
111}
112
/// Builds the strides of a tensor after `n_new_batch` leading batch dimensions are prepended.
///
/// The result has `rank_prev + n_new_batch` entries. Every new leading dimension gets a
/// stride of `num_elems`, and `strides` is copied into the positions that follow. Any trailing
/// position not covered by `strides` also keeps the value `num_elems`.
///
/// # Panics
///
/// Panics if `strides.len() > rank_prev`.
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let mut result = vec![num_elems; rank_prev + n_new_batch];
    let tail = &mut result[n_new_batch..];
    tail[..strides.len()].copy_from_slice(strides);
    result
}
128
129/// Returns the analysis of a reshape operation.
130pub fn reshape_analysis(
131    shape: &[usize],
132    strides: Option<&[usize]>,
133    shape_new: &[usize],
134) -> ReshapeAnalysis {
135    let shape_rank = shape.len();
136    let shape_new_rank = shape_new.len();
137
138    if shape_new_rank < shape_rank {
139        let is_contiguous = match strides {
140            Some(strides) => is_contiguous(shape, strides),
141            None => false,
142        };
143        return match is_contiguous {
144            true => ReshapeAnalysis::IsContiguous,
145            false => ReshapeAnalysis::SmallerRank,
146        };
147    }
148
149    let n_new_batch = shape_new_rank - shape_rank;
150
151    match n_new_batch > 0 {
152        true => {
153            if shape == &shape_new[n_new_batch..shape_new_rank]
154                && shape_new[0..n_new_batch] == vec![1; n_new_batch]
155            {
156                return ReshapeAnalysis::Broadcasted;
157            }
158        }
159
160        false => {
161            if shape == shape_new {
162                return ReshapeAnalysis::NoChange;
163            } else {
164                let is_contiguous = match strides {
165                    Some(strides) => is_contiguous(shape, strides),
166                    None => false,
167                };
168                if is_contiguous {
169                    return ReshapeAnalysis::IsContiguous;
170                }
171            }
172        }
173    };
174
175    ReshapeAnalysis::HighlyPermutated
176}