burn_contracts/
lib.rs

1pub mod shapes;
2#[cfg(any(test, feature = "testing"))]
3pub mod testing;
4
5use crate::shapes::ShapePatternError;
6use burn::prelude::{Backend, Float};
7use burn::tensor::{BasicOps, Tensor};
8use shapes::ShapePattern;
9
10/// A wrapper around a Tensor that provides additional assertions.
11#[derive(Clone, Debug)]
12pub struct TensorWrapper<'a, B, const D: usize, K = Float>
13where
14    B: Backend,
15    K: BasicOps<B>,
16{
17    inner: &'a Tensor<B, D, K>,
18}
19
20/// Wrap a Tensor for test assertions.
21pub fn assert_tensor<B, const D: usize, K>(tensor: &Tensor<B, D, K>) -> TensorWrapper<B, D, K>
22where
23    B: Backend,
24    K: BasicOps<B>,
25{
26    TensorWrapper { inner: tensor }
27}
28
29impl<B, const D: usize, K> TensorWrapper<'_, B, D, K>
30where
31    B: Backend,
32    K: BasicOps<B>,
33{
34    /// Assert that the wrapped tensor has the expected dimensions.
35    ///
36    /// ## Parameters
37    ///
38    /// - `dims`: The expected dimensions of the tensor.
39    ///
40    /// ## Panics
41    ///
42    /// Panics if the tensor does not have the expected dimensions.
43    ///
44    /// ## Example:
45    /// ```
46    /// use burn::backend::NdArray;
47    /// use burn::tensor::Tensor;
48    /// use burn_contracts::assert_tensor;
49    ///
50    /// let device = Default::default();
51    /// let tensor = Tensor::<NdArray, 2>::from_data([[2., 3.], [4., 5.]], &device);
52    ///
53    /// assert_tensor(&tensor).has_dims([2, 2]);
54    /// ```
55    #[allow(clippy::must_use_candidate)]
56    pub fn has_dims(
57        &self,
58        dims: [usize; D],
59    ) -> &Self {
60        // Example assertion
61        assert_eq!(
62            self.inner.dims(),
63            dims,
64            "Expected tensor to have dimensions {:?}, but got {:?}",
65            dims,
66            self.inner.dims()
67        );
68        self
69    }
70
71    /// Unpacks components of the shape of the tensor according to a pattern.
72    ///
73    /// ## Parameters
74    ///
75    /// - `keys`: The keys to select from the unpacked shape.
76    /// - `pattern`: The pattern to unpack the shape.
77    /// - `bindings`: The bindings to use for the unpacking.
78    ///
79    /// ## Returns
80    ///
81    /// The unpacked shape.
82    ///
83    /// ## Errors
84    ///
85    /// Returns an error if the pattern is invalid or the bindings are not found,
86    /// or do not match the shape.
87    ///
88    /// ## Example
89    ///
90    /// ```rust
91    /// #[cfg(test)]
92    /// mod tests {
93    ///    use burn::backend::NdArray;
94    ///    use burn::tensor::Tensor;
95    ///    use burn_contracts::assert_tensor;
96    ///
97    ///    #[test]
98    ///    fn example() -> Result<(), Box<dyn std::error::Error>> {
99    ///        let device = Default::default();
100    ///        let tensor = Tensor::<NdArray, 6>::zeros([2, 2, 2, 5 * 4, 4 * 4, 3], &device);
101    ///
102    ///        let [b, h, w] = assert_tensor(&tensor).unpacks_shape(
103    ///           ["b", "h", "w"],
104    ///           "b ... (h p) (w p) c",
105    ///           &[("p", 4), ("c", 3)],
106    ///        )?;
107    ///
108    ///        assert_eq!(b, 2);
109    ///        assert_eq!(h, 5);
110    ///        assert_eq!(w, 4);
111    ///
112    ///        Ok(())
113    ///    }
114    /// }
115    /// ```
116    pub fn unpacks_shape<const S: usize, C: shapes::ShapeBindingSource>(
117        &self,
118        keys: [&str; S],
119        pattern: &str,
120        bindings: C,
121    ) -> Result<[usize; S], ShapePatternError> {
122        Ok(ShapePattern::cached_parse(pattern)?
123            .match_bindings(&self.inner.dims(), bindings)?
124            .select(keys))
125    }
126
127    /// Assert that the wrapped tensor has the expected named dimensions.
128    ///
129    /// ## Parameters
130    ///
131    /// - `dims`: The expected named dimensions of the tensor.
132    ///
133    /// ## Panics
134    ///
135    /// Panics if the tensor does not have the expected named dimensions.
136    ///
137    /// ## Example:
138    /// ```
139    /// use burn::backend::NdArray;
140    /// use burn::tensor::Tensor;
141    /// use burn_contracts::assert_tensor;
142    ///
143    /// let device = Default::default();
144    /// let tensor = Tensor::<NdArray, 2>::from_data([[2., 3.], [4., 5.]], &device);
145    ///
146    /// assert_tensor(&tensor).has_named_dims([("rows", 2), ("cols", 2)]);
147    /// ```
148    #[allow(clippy::must_use_candidate)]
149    pub fn has_named_dims(
150        &self,
151        dims: [(&str, usize); D],
152    ) -> &Self {
153        if self
154            .inner
155            .dims()
156            .iter()
157            .zip(dims.iter())
158            .all(|(&a, &(_, b))| a == b)
159        {
160            return self;
161        }
162
163        let actual = self
164            .inner
165            .dims()
166            .iter()
167            .zip(dims.iter())
168            .map(|(&d, &(n, _))| format!("{n}={d}"))
169            .collect::<Vec<String>>()
170            .join(", ");
171
172        let expected = dims
173            .iter()
174            .map(|&(n, d)| format!("{n}={d}"))
175            .collect::<Vec<String>>()
176            .join(", ");
177
178        panic!("Expected dims [{expected}], found [{actual}]")
179    }
180}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::prelude::Backend;
    use burn::tensor::Tensor;
    use std::error::Error;

    /// Builds the 2x1 tensor `[[2.], [3.]]` shared by the `has_*dims` cases.
    fn column_fixture<B: Backend>() -> Tensor<B, 2> {
        let device = Default::default();
        Tensor::<B, 2>::from_data([[2.], [3.]], &device)
    }

    #[test]
    fn test_unpacks_shape() {
        unpacks_shape_case::<NdArray>().unwrap();
    }

    // `b`, `h` and `w` deliberately mirror the axis names used in the pattern.
    #[allow(clippy::many_single_char_names)]
    fn unpacks_shape_case<B: Backend>() -> Result<(), Box<dyn Error>> {
        let device = Default::default();
        let tensor = Tensor::<B, 6>::zeros([2, 2, 2, 5 * 4, 4 * 4, 3], &device);

        let [b, h, w] = assert_tensor(&tensor).unpacks_shape(
            ["b", "h", "w"],
            "b ... (h p) (w p) c",
            &[("p", 4), ("c", 3)],
        )?;

        assert_eq!((b, h, w), (2, 5, 4));
        Ok(())
    }

    #[test]
    fn test_has_dims_passing() {
        has_dims_passing_case::<NdArray>();
    }

    fn has_dims_passing_case<B: Backend>() {
        let fixture = column_fixture::<B>();
        assert_tensor(&fixture).has_dims([2, 1]);
    }

    #[test]
    #[should_panic(expected = "Expected tensor to have dimensions [1, 2], but got [2, 1]")]
    fn test_has_dims_failing() {
        has_dims_failing_case::<NdArray>();
    }

    fn has_dims_failing_case<B: Backend>() {
        let fixture = column_fixture::<B>();
        assert_tensor(&fixture).has_dims([1, 2]);
    }

    #[test]
    fn test_has_named_dims_passing() {
        has_named_dims_passing_case::<NdArray>();
    }

    fn has_named_dims_passing_case<B: Backend>() {
        let fixture = column_fixture::<B>();
        assert_tensor(&fixture).has_named_dims([("rows", 2), ("cols", 1)]);
    }

    #[test]
    #[should_panic(expected = "Expected dims [rows=1, cols=2], found [rows=2, cols=1]")]
    fn test_has_named_dims_failing() {
        has_named_dims_failing_case::<NdArray>();
    }

    fn has_named_dims_failing_case<B: Backend>() {
        let fixture = column_fixture::<B>();
        assert_tensor(&fixture).has_named_dims([("rows", 1), ("cols", 2)]);
    }
}