// burn_common/lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5//! # Burn Common Library
6//!
7//! This library contains common types used by other Burn crates that must be shared.
8
9/// Id module contains types for unique identifiers.
10pub mod id;
11
12pub use cubecl_common::*;
13
14#[cfg(feature = "rayon")]
15pub use rayon;
16
17extern crate alloc;
18
19/// Network utilities.
20#[cfg(feature = "network")]
21pub mod network;
22
/// Tensor utilities.
pub mod tensor {
    use alloc::vec;
    use alloc::vec::Vec;

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is considered contiguous if its elements are stored in memory
    /// such that the stride at position `k` is equal to the product of the shapes
    /// of all dimensions greater than `k`.
    ///
    /// This means that strides increase as you move from the rightmost to the leftmost dimension.
    ///
    /// `shape` and `strides` are expected to have one entry per dimension.
    pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
        if shape.is_empty() {
            return true;
        }

        // Walk from the innermost dimension outwards, accumulating the expected
        // contiguous stride on the fly instead of allocating a stride vector.
        let mut expected = 1;

        for (&dim, &stride) in shape.iter().rev().zip(strides.iter().rev()) {
            if stride != expected {
                return false;
            }
            expected *= dim;
        }

        true
    }

    /// Computes the strides for a contiguous tensor with the given shape.
    ///
    /// In a contiguous row-major tensor, the stride for each dimension
    /// equals the product of all dimension sizes to its right.
    pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = Vec::with_capacity(shape.len());
        let mut current = 1;

        // Build the strides right-to-left, then flip them into row-major order.
        for &dim in shape.iter().rev() {
            strides.push(current);
            current *= dim;
        }

        strides.reverse();
        strides
    }

    /// The action to take for a reshape operation.
    #[derive(Debug, PartialEq, Eq)]
    pub enum ReshapeAction {
        /// Updating the strides is sufficient to handle the reshape.
        UpdateStrides {
            /// The new strides.
            strides: Vec<usize>,
        },
        /// The strides are not compatible, we should recompute the buffer.
        Recompute,
        /// The strides are already correct.
        NoChange,
    }

    /// The reshape kind.
    #[derive(Debug, PartialEq, Eq)]
    pub enum ReshapeAnalysis {
        /// Original tensor is contiguous, can update the strides.
        IsContiguous,
        /// Original tensor is highly permutated, can't update the strides.
        HighlyPermutated,
        /// Only batch dimensions are added, can update the strides.
        Broadcasted,
        /// Original tensor is bigger than output shape.
        SmallerRank,
        /// New shape is the same.
        NoChange,
    }

    impl ReshapeAnalysis {
        /// Returns the proper action to take for the current analysis.
        fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
            match self {
                ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {
                    strides: contiguous_strides(shape_new),
                },
                ReshapeAnalysis::NoChange => ReshapeAction::NoChange,
                ReshapeAnalysis::HighlyPermutated | ReshapeAnalysis::SmallerRank => {
                    ReshapeAction::Recompute
                }
                ReshapeAnalysis::Broadcasted => {
                    // Keep the original strides and prepend one stride per new
                    // batch dimension; the value chosen for those strides never
                    // matters because the corresponding dimensions have size 1.
                    let n_new_batch = shape_new.len() - shape.len();
                    let num_elems = shape.iter().product::<usize>();
                    let strides_new =
                        broadcast_strides(n_new_batch, shape.len(), num_elems, strides);

                    ReshapeAction::UpdateStrides {
                        strides: strides_new,
                    }
                }
            }
        }
    }

    /// Returns the proper action to take when reshaping a tensor.
    pub fn reshape_action(
        shape: &[usize],
        strides: &[usize],
        shape_new: &[usize],
    ) -> ReshapeAction {
        reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)
    }

    /// Calculate the new strides given added batch dimensions.
    ///
    /// * `n_new_batch` - number of leading batch dimensions to add.
    /// * `rank_prev` - rank of the original tensor.
    /// * `num_elems` - filler stride stored in the added batch positions.
    /// * `strides` - original strides, copied into the trailing positions.
    pub fn broadcast_strides(
        n_new_batch: usize,
        rank_prev: usize,
        num_elems: usize,
        strides: &[usize],
    ) -> Vec<usize> {
        let mut strides_new = vec![num_elems; rank_prev + n_new_batch];

        for (i, s) in strides.iter().enumerate() {
            strides_new[i + n_new_batch] = *s;
        }

        strides_new
    }

    /// Returns the analysis of a reshape operation.
    ///
    /// When `strides` is `None`, the layout of the original tensor is unknown
    /// and it is conservatively treated as non-contiguous.
    pub fn reshape_analysis(
        shape: &[usize],
        strides: Option<&[usize]>,
        shape_new: &[usize],
    ) -> ReshapeAnalysis {
        // Contiguity can only be confirmed when the strides are known.
        let contiguous = || match strides {
            Some(strides) => is_contiguous(shape, strides),
            None => false,
        };

        if shape_new.len() < shape.len() {
            // A contiguous tensor can always collapse to a smaller rank by
            // recomputing its strides; otherwise the buffer must be rebuilt.
            return match contiguous() {
                true => ReshapeAnalysis::IsContiguous,
                false => ReshapeAnalysis::SmallerRank,
            };
        }

        let n_new_batch = shape_new.len() - shape.len();

        if n_new_batch == 0 {
            if shape == shape_new {
                return ReshapeAnalysis::NoChange;
            }
        } else if shape == &shape_new[n_new_batch..]
            && shape_new[..n_new_batch].iter().all(|&dim| dim == 1)
        {
            // Only leading size-1 batch dimensions were added: a stride update
            // is enough regardless of the original layout.
            return ReshapeAnalysis::Broadcasted;
        }

        // Bug fix: contiguity now also rescues rank-increasing reshapes that
        // are not pure broadcasts (e.g. `[6]` -> `[2, 3]`); previously those
        // always triggered a buffer recompute even though a contiguous tensor
        // can be reshaped to any rank by recomputing its strides.
        if contiguous() {
            return ReshapeAnalysis::IsContiguous;
        }

        ReshapeAnalysis::HighlyPermutated
    }
}