burn_contracts/
lib.rs

1#![deprecated(since = "0.1.0", note = "Use bimm-contracts instead")]
2
3pub mod shapes;
4#[cfg(any(test, feature = "testing"))]
5pub mod testing;
6
7use crate::shapes::ShapePatternError;
8use burn::prelude::{Backend, Float};
9use burn::tensor::{BasicOps, Tensor};
10use shapes::ShapePattern;
11
12/// A wrapper around a Tensor that provides additional assertions.
13#[derive(Clone, Debug)]
14pub struct TensorWrapper<'a, B, const D: usize, K = Float>
15where
16    B: Backend,
17    K: BasicOps<B>,
18{
19    inner: &'a Tensor<B, D, K>,
20}
21
22/// Wrap a Tensor for test assertions.
23pub fn assert_tensor<B, const D: usize, K>(tensor: &Tensor<B, D, K>) -> TensorWrapper<'_, B, D, K>
24where
25    B: Backend,
26    K: BasicOps<B>,
27{
28    TensorWrapper { inner: tensor }
29}
30
31impl<B, const D: usize, K> TensorWrapper<'_, B, D, K>
32where
33    B: Backend,
34    K: BasicOps<B>,
35{
36    /// Assert that the wrapped tensor has the expected dimensions.
37    ///
38    /// ## Parameters
39    ///
40    /// - `dims`: The expected dimensions of the tensor.
41    ///
42    /// ## Panics
43    ///
44    /// Panics if the tensor does not have the expected dimensions.
45    ///
46    /// ## Example:
47    /// ```
48    /// use burn::backend::NdArray;
49    /// use burn::tensor::Tensor;
50    /// use burn_contracts::assert_tensor;
51    ///
52    /// let device = Default::default();
53    /// let tensor = Tensor::<NdArray, 2>::from_data([[2., 3.], [4., 5.]], &device);
54    ///
55    /// assert_tensor(&tensor).has_dims([2, 2]);
56    /// ```
57    #[allow(clippy::must_use_candidate)]
58    pub fn has_dims(
59        &self,
60        dims: [usize; D],
61    ) -> &Self {
62        // Example assertion
63        assert_eq!(
64            self.inner.dims(),
65            dims,
66            "Expected tensor to have dimensions {:?}, but got {:?}",
67            dims,
68            self.inner.dims()
69        );
70        self
71    }
72
73    /// Unpacks components of the shape of the tensor according to a pattern.
74    ///
75    /// ## Parameters
76    ///
77    /// - `keys`: The keys to select from the unpacked shape.
78    /// - `pattern`: The pattern to unpack the shape.
79    /// - `bindings`: The bindings to use for the unpacking.
80    ///
81    /// ## Returns
82    ///
83    /// The unpacked shape.
84    ///
85    /// ## Errors
86    ///
87    /// Returns an error if the pattern is invalid or the bindings are not found,
88    /// or do not match the shape.
89    ///
90    /// ## Example
91    ///
92    /// ```rust
93    /// #[cfg(test)]
94    /// mod tests {
95    ///    use burn::backend::NdArray;
96    ///    use burn::tensor::Tensor;
97    ///    use burn_contracts::assert_tensor;
98    ///
99    ///    #[test]
100    ///    fn example() -> Result<(), Box<dyn std::error::Error>> {
101    ///        let device = Default::default();
102    ///        let tensor = Tensor::<NdArray, 6>::zeros([2, 2, 2, 5 * 4, 4 * 4, 3], &device);
103    ///
104    ///        let [b, h, w] = assert_tensor(&tensor).unpacks_shape(
105    ///           ["b", "h", "w"],
106    ///           "b ... (h p) (w p) c",
107    ///           &[("p", 4), ("c", 3)],
108    ///        )?;
109    ///
110    ///        assert_eq!(b, 2);
111    ///        assert_eq!(h, 5);
112    ///        assert_eq!(w, 4);
113    ///
114    ///        Ok(())
115    ///    }
116    /// }
117    /// ```
118    pub fn unpacks_shape<const S: usize, C: shapes::ShapeBindingSource>(
119        &self,
120        keys: [&str; S],
121        pattern: &str,
122        bindings: C,
123    ) -> Result<[usize; S], ShapePatternError> {
124        Ok(ShapePattern::cached_parse(pattern)?
125            .match_bindings(&self.inner.dims(), bindings)?
126            .select(keys))
127    }
128
129    /// Assert that the wrapped tensor has the expected named dimensions.
130    ///
131    /// ## Parameters
132    ///
133    /// - `dims`: The expected named dimensions of the tensor.
134    ///
135    /// ## Panics
136    ///
137    /// Panics if the tensor does not have the expected named dimensions.
138    ///
139    /// ## Example:
140    /// ```
141    /// use burn::backend::NdArray;
142    /// use burn::tensor::Tensor;
143    /// use burn_contracts::assert_tensor;
144    ///
145    /// let device = Default::default();
146    /// let tensor = Tensor::<NdArray, 2>::from_data([[2., 3.], [4., 5.]], &device);
147    ///
148    /// assert_tensor(&tensor).has_named_dims([("rows", 2), ("cols", 2)]);
149    /// ```
150    #[allow(clippy::must_use_candidate)]
151    pub fn has_named_dims(
152        &self,
153        dims: [(&str, usize); D],
154    ) -> &Self {
155        if self
156            .inner
157            .dims()
158            .iter()
159            .zip(dims.iter())
160            .all(|(&a, &(_, b))| a == b)
161        {
162            return self;
163        }
164
165        let actual = self
166            .inner
167            .dims()
168            .iter()
169            .zip(dims.iter())
170            .map(|(&d, &(n, _))| format!("{n}={d}"))
171            .collect::<Vec<String>>()
172            .join(", ");
173
174        let expected = dims
175            .iter()
176            .map(|&(n, d)| format!("{n}={d}"))
177            .collect::<Vec<String>>()
178            .join(", ");
179
180        panic!("Expected dims [{expected}], found [{actual}]")
181    }
182}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::prelude::Backend;
    use burn::tensor::Tensor;
    use std::error::Error;

    /// Builds the 2x1 tensor `[[2.], [3.]]` used by the dimension tests.
    fn column_tensor<B: Backend>() -> Tensor<B, 2> {
        let device = Default::default();
        Tensor::<B, 2>::from_data([[2.], [3.]], &device)
    }

    // Each `#[test]` runs a backend-generic helper with the NdArray backend.

    #[test]
    fn test_unpacks_shape() {
        impl_test_unpacks_shape::<NdArray>().unwrap();
    }

    fn impl_test_unpacks_shape<B: Backend>() -> Result<(), Box<dyn Error>> {
        let device = Default::default();
        let tensor = Tensor::<B, 6>::zeros([2, 2, 2, 5 * 4, 4 * 4, 3], &device);

        let [batch, height, width] = assert_tensor(&tensor).unpacks_shape(
            ["b", "h", "w"],
            "b ... (h p) (w p) c",
            &[("p", 4), ("c", 3)],
        )?;

        assert_eq!([batch, height, width], [2, 5, 4]);

        Ok(())
    }

    #[test]
    fn test_has_dims_passing() {
        impl_has_dims_passing::<NdArray>();
    }

    fn impl_has_dims_passing<B: Backend>() {
        let tensor = column_tensor::<B>();
        assert_tensor(&tensor).has_dims([2, 1]);
    }

    #[test]
    #[should_panic(expected = "Expected tensor to have dimensions [1, 2], but got [2, 1]")]
    fn test_has_dims_failing() {
        impl_has_dims_failing::<NdArray>();
    }

    fn impl_has_dims_failing<B: Backend>() {
        let tensor = column_tensor::<B>();
        assert_tensor(&tensor).has_dims([1, 2]);
    }

    #[test]
    fn test_has_named_dims_passing() {
        impl_has_named_dims_passing::<NdArray>();
    }

    fn impl_has_named_dims_passing<B: Backend>() {
        let tensor = column_tensor::<B>();
        assert_tensor(&tensor).has_named_dims([("rows", 2), ("cols", 1)]);
    }

    #[test]
    #[should_panic(expected = "Expected dims [rows=1, cols=2], found [rows=2, cols=1]")]
    fn test_has_named_dims_failing() {
        impl_has_named_dims_failing::<NdArray>();
    }

    fn impl_has_named_dims_failing<B: Backend>() {
        let tensor = column_tensor::<B>();
        assert_tensor(&tensor).has_named_dims([("rows", 1), ("cols", 2)]);
    }
}