1pub mod shapes;
2#[cfg(any(test, feature = "testing"))]
3pub mod testing;
4
5use crate::shapes::ShapePatternError;
6use burn::prelude::{Backend, Float};
7use burn::tensor::{BasicOps, Tensor};
8use shapes::ShapePattern;
9
10#[derive(Clone, Debug)]
12pub struct TensorWrapper<'a, B, const D: usize, K = Float>
13where
14 B: Backend,
15 K: BasicOps<B>,
16{
17 inner: &'a Tensor<B, D, K>,
18}
19
20pub fn assert_tensor<B, const D: usize, K>(tensor: &Tensor<B, D, K>) -> TensorWrapper<B, D, K>
22where
23 B: Backend,
24 K: BasicOps<B>,
25{
26 TensorWrapper { inner: tensor }
27}
28
29impl<B, const D: usize, K> TensorWrapper<'_, B, D, K>
30where
31 B: Backend,
32 K: BasicOps<B>,
33{
34 #[allow(clippy::must_use_candidate)]
56 pub fn has_dims(
57 &self,
58 dims: [usize; D],
59 ) -> &Self {
60 assert_eq!(
62 self.inner.dims(),
63 dims,
64 "Expected tensor to have dimensions {:?}, but got {:?}",
65 dims,
66 self.inner.dims()
67 );
68 self
69 }
70
71 pub fn unpacks_shape<const S: usize, C: shapes::ShapeBindingSource>(
117 &self,
118 keys: [&str; S],
119 pattern: &str,
120 bindings: C,
121 ) -> Result<[usize; S], ShapePatternError> {
122 Ok(ShapePattern::cached_parse(pattern)?
123 .match_bindings(&self.inner.dims(), bindings)?
124 .select(keys))
125 }
126
127 #[allow(clippy::must_use_candidate)]
149 pub fn has_named_dims(
150 &self,
151 dims: [(&str, usize); D],
152 ) -> &Self {
153 if self
154 .inner
155 .dims()
156 .iter()
157 .zip(dims.iter())
158 .all(|(&a, &(_, b))| a == b)
159 {
160 return self;
161 }
162
163 let actual = self
164 .inner
165 .dims()
166 .iter()
167 .zip(dims.iter())
168 .map(|(&d, &(n, _))| format!("{n}={d}"))
169 .collect::<Vec<String>>()
170 .join(", ");
171
172 let expected = dims
173 .iter()
174 .map(|&(n, d)| format!("{n}={d}"))
175 .collect::<Vec<String>>()
176 .join(", ");
177
178 panic!("Expected dims [{expected}], found [{actual}]")
179 }
180}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::prelude::Backend;
    use burn::tensor::Tensor;
    use std::error::Error;

    /// Builds the 2x1 float tensor used by the `has_dims` / `has_named_dims` tests.
    fn column_tensor<B: Backend>() -> Tensor<B, 2> {
        let device: B::Device = Default::default();
        Tensor::<B, 2>::from_data([[2.], [3.]], &device)
    }

    #[test]
    fn test_unpacks_shape() {
        impl_test_unpacks_shape::<NdArray>().unwrap();
    }

    /// Unpacks `b`, `h` and `w` from a six-dimensional tensor using a pattern
    /// with an ellipsis and two composite dimensions.
    fn impl_test_unpacks_shape<B: Backend>() -> Result<(), Box<dyn Error>> {
        let device: B::Device = Default::default();
        let tensor = Tensor::<B, 6>::zeros([2, 2, 2, 5 * 4, 4 * 4, 3], &device);

        let unpacked = assert_tensor(&tensor).unpacks_shape(
            ["b", "h", "w"],
            "b ... (h p) (w p) c",
            &[("p", 4), ("c", 3)],
        )?;

        // Order follows the requested keys: b, h, w.
        assert_eq!(unpacked, [2, 5, 4]);

        Ok(())
    }

    #[test]
    fn test_has_dims_passing() {
        impl_has_dims_passing::<NdArray>();
    }

    /// Matching dimensions must not panic.
    fn impl_has_dims_passing<B: Backend>() {
        let tensor = column_tensor::<B>();

        assert_tensor(&tensor).has_dims([2, 1]);
    }

    #[test]
    #[should_panic(expected = "Expected tensor to have dimensions [1, 2], but got [2, 1]")]
    fn test_has_dims_failing() {
        impl_has_dims_failing::<NdArray>();
    }

    /// Transposed expectations must panic with the documented message.
    fn impl_has_dims_failing<B: Backend>() {
        let tensor = column_tensor::<B>();

        assert_tensor(&tensor).has_dims([1, 2]);
    }

    #[test]
    fn test_has_named_dims_passing() {
        impl_has_named_dims_passing::<NdArray>();
    }

    /// Matching named dimensions must not panic.
    fn impl_has_named_dims_passing<B: Backend>() {
        let tensor = column_tensor::<B>();

        assert_tensor(&tensor).has_named_dims([("rows", 2), ("cols", 1)]);
    }

    #[test]
    #[should_panic(expected = "Expected dims [rows=1, cols=2], found [rows=2, cols=1]")]
    fn test_has_named_dims_failing() {
        impl_has_named_dims_failing::<NdArray>();
    }

    /// Transposed named expectations must panic with the labelled message.
    fn impl_has_named_dims_failing<B: Backend>() {
        let tensor = column_tensor::<B>();

        assert_tensor(&tensor).has_named_dims([("rows", 1), ("cols", 2)]);
    }
}