1#![deprecated(since = "0.1.0", note = "Use bimm-contracts instead")]
2
3pub mod shapes;
4#[cfg(any(test, feature = "testing"))]
5pub mod testing;
6
7use crate::shapes::ShapePatternError;
8use burn::prelude::{Backend, Float};
9use burn::tensor::{BasicOps, Tensor};
10use shapes::ShapePattern;
11
12#[derive(Clone, Debug)]
14pub struct TensorWrapper<'a, B, const D: usize, K = Float>
15where
16 B: Backend,
17 K: BasicOps<B>,
18{
19 inner: &'a Tensor<B, D, K>,
20}
21
22pub fn assert_tensor<B, const D: usize, K>(tensor: &Tensor<B, D, K>) -> TensorWrapper<'_, B, D, K>
24where
25 B: Backend,
26 K: BasicOps<B>,
27{
28 TensorWrapper { inner: tensor }
29}
30
31impl<B, const D: usize, K> TensorWrapper<'_, B, D, K>
32where
33 B: Backend,
34 K: BasicOps<B>,
35{
36 #[allow(clippy::must_use_candidate)]
58 pub fn has_dims(
59 &self,
60 dims: [usize; D],
61 ) -> &Self {
62 assert_eq!(
64 self.inner.dims(),
65 dims,
66 "Expected tensor to have dimensions {:?}, but got {:?}",
67 dims,
68 self.inner.dims()
69 );
70 self
71 }
72
73 pub fn unpacks_shape<const S: usize, C: shapes::ShapeBindingSource>(
119 &self,
120 keys: [&str; S],
121 pattern: &str,
122 bindings: C,
123 ) -> Result<[usize; S], ShapePatternError> {
124 Ok(ShapePattern::cached_parse(pattern)?
125 .match_bindings(&self.inner.dims(), bindings)?
126 .select(keys))
127 }
128
129 #[allow(clippy::must_use_candidate)]
151 pub fn has_named_dims(
152 &self,
153 dims: [(&str, usize); D],
154 ) -> &Self {
155 if self
156 .inner
157 .dims()
158 .iter()
159 .zip(dims.iter())
160 .all(|(&a, &(_, b))| a == b)
161 {
162 return self;
163 }
164
165 let actual = self
166 .inner
167 .dims()
168 .iter()
169 .zip(dims.iter())
170 .map(|(&d, &(n, _))| format!("{n}={d}"))
171 .collect::<Vec<String>>()
172 .join(", ");
173
174 let expected = dims
175 .iter()
176 .map(|&(n, d)| format!("{n}={d}"))
177 .collect::<Vec<String>>()
178 .join(", ");
179
180 panic!("Expected dims [{expected}], found [{actual}]")
181 }
182}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::prelude::Backend;
    use burn::tensor::Tensor;
    use std::error::Error;

    /// Builds the 2x1 tensor `[[2.], [3.]]` shared by the dimension checks.
    fn two_by_one<B: Backend>() -> Tensor<B, 2> {
        let device = Default::default();
        Tensor::<B, 2>::from_data([[2.], [3.]], &device)
    }

    #[test]
    fn test_unpacks_shape() {
        check_unpacks_shape::<NdArray>().unwrap();
    }

    /// Unpacks `b`, `h`, `w` from a 6-d tensor whose patch size and channel count are bound.
    fn check_unpacks_shape<B: Backend>() -> Result<(), Box<dyn Error>> {
        let device = Default::default();
        let tensor = Tensor::<B, 6>::zeros([2, 2, 2, 5 * 4, 4 * 4, 3], &device);

        let unpacked = assert_tensor(&tensor).unpacks_shape(
            ["b", "h", "w"],
            "b ... (h p) (w p) c",
            &[("p", 4), ("c", 3)],
        )?;
        assert_eq!(unpacked, [2, 5, 4]);

        Ok(())
    }

    #[test]
    fn test_has_dims_passing() {
        check_has_dims_matches::<NdArray>();
    }

    fn check_has_dims_matches<B: Backend>() {
        assert_tensor(&two_by_one::<B>()).has_dims([2, 1]);
    }

    #[test]
    #[should_panic(expected = "Expected tensor to have dimensions [1, 2], but got [2, 1]")]
    fn test_has_dims_failing() {
        check_has_dims_mismatch::<NdArray>();
    }

    fn check_has_dims_mismatch<B: Backend>() {
        assert_tensor(&two_by_one::<B>()).has_dims([1, 2]);
    }

    #[test]
    fn test_has_named_dims_passing() {
        check_named_dims_matches::<NdArray>();
    }

    fn check_named_dims_matches<B: Backend>() {
        assert_tensor(&two_by_one::<B>()).has_named_dims([("rows", 2), ("cols", 1)]);
    }

    #[test]
    #[should_panic(expected = "Expected dims [rows=1, cols=2], found [rows=2, cols=1]")]
    fn test_has_named_dims_failing() {
        check_named_dims_mismatch::<NdArray>();
    }

    fn check_named_dims_mismatch<B: Backend>() {
        assert_tensor(&two_by_one::<B>()).has_named_dims([("rows", 1), ("cols", 2)]);
    }
}
263}