bimm_contracts/contracts.rs
1//! # Shape Contracts.
2//!
3//! `bimm-contracts` is built around the [`ShapeContract`] interface.
4//! - A [`ShapeContract`] is a sequence of [`DimMatcher`]s.
//! - A [`DimMatcher`] matches zero or more dimensions of a shape:
//!   - [`DimMatcher::Any`] matches a single dimension of any size.
//!   - [`DimMatcher::Ellipsis`] matches a variable number (possibly zero) of dimensions.
//!   - [`DimMatcher::Expr`] matches a single dimension whose size must equal a dimension expression.
9//!
10//! A [`ShapeContract`] should usually be constructed using the [`crate::shape_contract`] macro.
11//!
12//! ## Example
13//!
14//! ```rust
15//! use bimm_contracts::{shape_contract, ShapeContract};
16//!
17//! static CONTRACT : ShapeContract = shape_contract![
18//! ...,
19//! "height" = "h_wins" * "window",
20//! "width" = "w_wins" * "window",
21//! "channels",
22//! ];
23//!
24//! let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
25//!
26//! // Assert the shape, given the bindings.
27//! let [h_wins, w_wins] = CONTRACT.unpack_shape(
28//! &shape,
29//! &["h_wins", "w_wins"],
30//! &[("window", 8)]
31//! );
32//! assert_eq!(h_wins, 2);
33//! assert_eq!(w_wins, 3);
34//! ```
35
36use crate::StackEnvironment;
37use crate::expressions::{DimExpr, ExprDisplayAdapter, MatchResult};
38use crate::shape_argument::ShapeArgument;
39use alloc::string::{String, ToString};
40use alloc::vec::Vec;
41use alloc::{format, vec};
42use core::fmt::{Display, Formatter};
43use core::panic::Location;
44
45/// A term in a shape pattern.
46///
47/// Users should generally use [`crate::shape_contract`] to construct patterns.
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum DimMatcher<'a> {
50 /// Matches any dimension size.
51 Any {
52 /// An optional label for the matcher.
53 label_id: Option<usize>,
54 },
55
56 /// Matches a variable number of dimensions (ellipsis).
57 Ellipsis {
58 /// An optional label for the matcher.
59 label_id: Option<usize>,
60 },
61
62 /// A dimension size expression that must match a specific value.
63 Expr {
64 /// An optional label for the matcher.
65 label_id: Option<usize>,
66
67 /// The dimension expression that must match a specific value.
68 expr: DimExpr<'a>,
69 },
70}
71
72impl<'a> DimMatcher<'a> {
73 /// Create a new `DimMatcher` that matches any dimension size.
74 pub const fn any() -> Self {
75 DimMatcher::Any { label_id: None }
76 }
77
78 /// Create a new `DimMatcher` that matches a variable number of dimensions (ellipsis).
79 pub const fn ellipsis() -> Self {
80 DimMatcher::Ellipsis { label_id: None }
81 }
82
83 /// Create a new `DimMatcher` from a dimension expression.
84 ///
85 /// ## Arguments
86 ///
87 /// - `expr`: a dimension expression that must match a specific value.
88 ///
89 /// ## Returns
90 ///
91 /// A new `DimMatcher` that matches the given expression.
92 pub const fn expr(expr: DimExpr<'a>) -> Self {
93 DimMatcher::Expr {
94 label_id: None,
95 expr,
96 }
97 }
98
99 /// Get the label of the matcher, if any.
100 pub const fn label_id(&self) -> Option<usize> {
101 match self {
102 DimMatcher::Any { label_id } => *label_id,
103 DimMatcher::Ellipsis { label_id } => *label_id,
104 DimMatcher::Expr { label_id, .. } => *label_id,
105 }
106 }
107
108 /// Attach a label to the matcher.
109 ///
110 /// ## Arguments
111 ///
112 /// - `label_id`: an optional label to attach to the matcher.
113 ///
114 /// ## Returns
115 ///
116 /// A new `DimMatcher` with the label attached.
117 pub const fn with_label_id(self, label_id: Option<usize>) -> Self {
118 match self {
119 DimMatcher::Any { .. } => DimMatcher::Any { label_id },
120 DimMatcher::Ellipsis { .. } => DimMatcher::Ellipsis { label_id },
121 DimMatcher::Expr { expr, .. } => DimMatcher::Expr { label_id, expr },
122 }
123 }
124}
125
126/// Display Adapter to format `DimMatchers` with a `Index`.
127pub struct MatcherDisplayAdapter<'a> {
128 index: &'a [&'a str],
129 matcher: &'a DimMatcher<'a>,
130}
131
132impl<'a> Display for MatcherDisplayAdapter<'a> {
133 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
134 if let Some(label_id) = self.matcher.label_id() {
135 write!(f, "{}=", self.index[label_id])?;
136 }
137 match self.matcher {
138 DimMatcher::Any { .. } => write!(f, "_"),
139 DimMatcher::Ellipsis { .. } => write!(f, "..."),
140 DimMatcher::Expr { expr, .. } => write!(
141 f,
142 "{}",
143 ExprDisplayAdapter {
144 index: self.index,
145 expr
146 }
147 ),
148 }
149 }
150}
151
152/// A shape pattern, which is a sequence of terms that can match a shape.
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub struct ShapeContract<'a> {
155 /// The slot index of the contract.
156 pub index: &'a [&'a str],
157
158 /// The terms in the pattern.
159 pub terms: &'a [DimMatcher<'a>],
160
161 /// The position of the ellipsis in the pattern, if any.
162 pub ellipsis_pos: Option<usize>,
163}
164
165impl Display for ShapeContract<'_> {
166 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
167 write!(f, "[")?;
168 for (idx, term) in self.terms.iter().enumerate() {
169 if idx > 0 {
170 write!(f, ", ")?;
171 }
172 write!(
173 f,
174 "{}",
175 MatcherDisplayAdapter {
176 index: self.index,
177 matcher: term
178 }
179 )?;
180 }
181 write!(f, "]")
182 }
183}
184
185impl<'a> ShapeContract<'a> {
186 /// Create a new shape pattern from a slice of terms.
187 ///
188 /// ## Arguments
189 ///
190 /// - `terms`: a slice of `ShapePatternTerm` that defines the pattern.
191 ///
192 /// ## Returns
193 ///
194 /// A new `ShapePattern` instance.
195 ///
196 /// ## Macro Support
197 ///
198 /// Consider using the [`crate::shape_contract`] macro instead.
199 ///
200 /// ```
201 /// use bimm_contracts::{shape_contract, ShapeContract};
202 ///
203 /// static CONTRACT : ShapeContract = shape_contract![
204 /// ...,
205 /// "height" = "h_wins" * "window",
206 /// "width" = "w_wins" * "window",
207 /// "channels",
208 /// ];
209 /// ```
210 pub const fn new(index: &'a [&'a str], terms: &'a [DimMatcher<'a>]) -> Self {
211 let mut i = 0;
212 let mut ellipsis_pos: Option<usize> = None;
213
214 while i < terms.len() {
215 if matches!(terms[i], DimMatcher::Ellipsis { .. }) {
216 match ellipsis_pos {
217 Some(_) => panic!("Multiple ellipses in pattern"),
218 None => ellipsis_pos = Some(i),
219 }
220 }
221 i += 1;
222 }
223
224 ShapeContract {
225 index,
226 terms,
227 ellipsis_pos,
228 }
229 }
230
231 /// Convert a key to an index.
232 pub fn maybe_key_to_index(&self, key: &str) -> Option<usize> {
233 self.index.iter().position(|&s| s == key)
234 }
235
236 /// Assert that the shape matches the pattern.
237 ///
238 /// ## Arguments
239 ///
240 /// - `shape`: the shape to match.
241 /// - `env`: the params which are already bound.
242 ///
243 /// ## Panics
244 ///
245 /// If the shape does not match the pattern, or if there is a conflict in the bindings.
246 ///
247 /// ## Example
248 ///
249 /// ```rust
250 /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
251 ///
252 /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
253 ///
254 /// // Run under backoff amortization.
255 /// run_periodically! {{
256 /// // Statically allocated contract.
257 /// static CONTRACT : ShapeContract = shape_contract![
258 /// ...,
259 /// "height" = "h_wins" * "window",
260 /// "width" = "w_wins" * "window",
261 /// "channels",
262 /// ];
263 ///
264 /// // Assert the shape, given the bindings.
265 /// CONTRACT.assert_shape(
266 /// &shape,
267 /// &[("h_wins", 2), ("w_wins", 3), ("channels", 4)]
268 /// );
269 /// }}
270 /// ```
271 #[track_caller]
272 pub fn assert_shape<S>(&'a self, shape: S, env: StackEnvironment<'a>)
273 where
274 S: ShapeArgument,
275 {
276 match self._loc_try_assert_shape(shape, env, Location::caller()) {
277 Ok(()) => (),
278 Err(msg) => panic!("{}", msg),
279 }
280 }
281
282 /// Assert that the shape matches the pattern.
283 ///
284 /// ## Arguments
285 ///
286 /// - `shape`: the shape to match.
287 /// - `env`: the params which are already bound.
288 ///
289 /// ## Returns
290 ///
291 /// - `Ok(())`: if the shape matches the pattern.
292 /// - `Err(String)`: if the shape does not match the pattern, with an error message.
293 ///
294 /// ## Example
295 ///
296 /// ```rust
297 /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
298 ///
299 /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
300 ///
301 /// // Statically allocated contract.
302 /// static CONTRACT : ShapeContract = shape_contract![
303 /// ...,
304 /// "height" = "h_wins" * "window",
305 /// "width" = "w_wins" * "window",
306 /// "channels",
307 /// ];
308 ///
309 /// // Assert the shape, given the bindings; or throw.
310 /// CONTRACT.try_assert_shape(
311 /// &shape,
312 /// &[("h_wins", 2), ("w_wins", 3), ("channels", 4)]
313 /// ).unwrap();
314 /// ```
315 #[track_caller]
316 pub fn try_assert_shape<S>(&'a self, shape: S, env: StackEnvironment<'a>) -> Result<(), String>
317 where
318 S: ShapeArgument,
319 {
320 self._loc_try_assert_shape(shape, env, Location::caller())
321 }
322
323 fn _loc_try_assert_shape<S>(
324 &'a self,
325 shape: S,
326 env: StackEnvironment<'a>,
327 loc: &Location<'a>,
328 ) -> Result<(), String>
329 where
330 S: ShapeArgument,
331 {
332 let mut scratch: Vec<Option<isize>> = vec![None; self.index.len()];
333 for (k, v) in env.iter() {
334 let v = *v as isize;
335 match self.maybe_key_to_index(k) {
336 Some(param_id) => scratch[param_id] = Some(v),
337 None => {
338 return Err(
339 format!("The key \"{k}\" is not indexed in the contract:\n{self}")
340 .to_string(),
341 );
342 }
343 }
344 }
345
346 self.format_resolve(shape, scratch.as_mut_slice(), loc)
347 }
348
349 /// Match and unpack `K` keys from a shape pattern.
350 ///
351 /// Wraps `try_unpack_shape` and panics if the shape does not match.
352 ///
353 /// ## Generics
354 ///
355 /// - `K`: the length of the `keys` array.
356 ///
357 /// ## Arguments
358 ///
359 /// - `shape`: the shape to match.
360 /// - `keys`: the bound keys to export.
361 /// - `env`: the params which are already bound.
362 ///
363 /// ## Returns
364 ///
365 /// An `[usize; K]` of the unpacked `keys` values.
366 ///
367 /// ## Panics
368 ///
369 /// If the shape does not match the pattern, or if there is a conflict in the bindings.
370 ///
371 /// ## Example
372 ///
373 /// ```rust
374 /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
375 ///
376 /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
377 ///
378 /// // Statically allocated contract.
379 /// static CONTRACT : ShapeContract = shape_contract![
380 /// ...,
381 /// "height" = "h_wins" * "window",
382 /// "width" = "w_wins" * "window",
383 /// "channels",
384 /// ];
385 ///
386 /// // Unpack the shape, given the bindings.
387 /// let [h, w, c] = CONTRACT.unpack_shape(
388 /// &shape,
389 /// &["h_wins", "w_wins", "channels"],
390 /// &[("window", 8)]
391 /// );
392 /// assert_eq!(h, 2);
393 /// assert_eq!(w, 3);
394 /// assert_eq!(c, 4);
395 /// ```
396 #[must_use]
397 #[track_caller]
398 pub fn unpack_shape<S, const K: usize>(
399 &'a self,
400 shape: S,
401 keys: &[&'a str; K],
402 env: StackEnvironment<'a>,
403 ) -> [usize; K]
404 where
405 S: ShapeArgument,
406 {
407 self._loc_unpack_shape(shape, keys, env, Location::caller())
408 }
409
410 fn _loc_unpack_shape<S, const K: usize>(
411 &'a self,
412 shape: S,
413 keys: &[&'a str; K],
414 env: StackEnvironment<'a>,
415 loc: &Location<'a>,
416 ) -> [usize; K]
417 where
418 S: ShapeArgument,
419 {
420 match self._loc_try_unpack_shape(shape, keys, env, loc) {
421 Ok(values) => values,
422 Err(msg) => panic!("{msg}"),
423 }
424 }
425
426 /// Try and match and unpack `K` keys from a shape pattern.
427 ///
428 /// ## Generics
429 ///
430 /// - `K`: the length of the `keys` array.
431 ///
432 /// ## Arguments
433 ///
434 /// - `shape`: the shape to match.
435 /// - `keys`: the bound keys to export.
436 /// - `env`: the params which are already bound.
437 ///
438 /// ## Returns
439 ///
440 /// A `Result<[usize; K], String>` of the unpacked `keys` values.
441 ///
442 /// ## Example
443 ///
444 /// ```rust
445 /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
446 ///
447 /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
448 ///
449 /// // Statically allocated contract.
450 /// static CONTRACT : ShapeContract = shape_contract![
451 /// ...,
452 /// "height" = "h_wins" * "window",
453 /// "width" = "w_wins" * "window",
454 /// "channels",
455 /// ];
456 ///
457 /// // Unpack the shape, given the bindings; or throw.
458 /// let [h, w, c] = CONTRACT.try_unpack_shape(
459 /// &shape,
460 /// &["h_wins", "w_wins", "channels"],
461 /// &[("window", 8)]
462 /// ).unwrap();
463 /// assert_eq!(h, 2);
464 /// assert_eq!(w, 3);
465 /// assert_eq!(c, 4);
466 /// ```
467 #[track_caller]
468 pub fn try_unpack_shape<S, const K: usize>(
469 &'a self,
470 shape: S,
471 keys: &[&'a str; K],
472 env: StackEnvironment<'a>,
473 ) -> Result<[usize; K], String>
474 where
475 S: ShapeArgument,
476 {
477 self._loc_try_unpack_shape(shape, keys, env, Location::caller())
478 }
479
480 fn _loc_try_unpack_shape<S, const K: usize>(
481 &'a self,
482 shape: S,
483 keys: &[&'a str; K],
484 env: StackEnvironment<'a>,
485 loc: &Location<'a>,
486 ) -> Result<[usize; K], String>
487 where
488 S: ShapeArgument,
489 {
490 let selection = self.expect_keys_to_selection(keys);
491
492 let mut scratch: Vec<Option<isize>> = vec![None; self.index.len()];
493 for (k, v) in env.iter() {
494 let v = *v as isize;
495 match self.maybe_key_to_index(k) {
496 Some(param_id) => scratch[param_id] = Some(v),
497 None => {
498 return Err(
499 format!("The key \"{k}\" is not indexed in the contract:\n{self}")
500 .to_string(),
501 );
502 }
503 }
504 }
505
506 let selected: [isize; K] =
507 self._loc_try_select(shape, &selection, scratch.as_mut_slice(), loc)?;
508
509 let result: [usize; K] = selected
510 .into_iter()
511 .map(|v| v as usize)
512 .collect::<Vec<usize>>()
513 .try_into()
514 .unwrap();
515
516 Ok(result)
517 }
518
519 /// Convert a list of keys to a selection.
520 pub fn expect_keys_to_selection<const D: usize>(&'a self, keys: &[&'a str; D]) -> [usize; D] {
521 let mut selection = [0; D];
522 for (i, key) in keys.iter().enumerate() {
523 match self.maybe_key_to_index(key) {
524 Some(param_id) => selection[i] = param_id,
525 None => panic!("The key \"{key}\" is not indexed in the contract:\n{self}"),
526 }
527 }
528 selection
529 }
530
531 fn _loc_try_select<S, const K: usize>(
532 &'a self,
533 shape: S,
534 selection: &[usize; K],
535 env: &mut [Option<isize>],
536 loc: &Location<'a>,
537 ) -> Result<[isize; K], String>
538 where
539 S: ShapeArgument,
540 {
541 let num_slots = self.index.len();
542 assert_eq!(env.len(), num_slots);
543
544 self.format_resolve(shape, env, loc)?;
545
546 let mut out = [0; K];
547 for (i, &k) in selection.iter().enumerate() {
548 out[i] = env[k].unwrap();
549 }
550 Ok(out)
551 }
552
553 /// Resolve the match for the shape against the pattern.
554 ///
555 /// ## Arguments
556 ///
557 /// - `shape`: the shape to match.
558 /// - `env`: the mutable environment to bind parameters.
559 /// - `location`: the location reference from ``#[track_caller]``.
560 ///
561 /// ## Returns
562 ///
563 /// - `Ok(())`: if the shape matches the pattern; will update the `env`.
564 /// - `Err(&str)`: if the shape does not match the pattern, with an error message.
565 pub(crate) fn format_resolve<S>(
566 &'a self,
567 shape: S,
568 env: &mut [Option<isize>],
569 location: &Location,
570 ) -> Result<(), String>
571 where
572 S: ShapeArgument,
573 {
574 let shape = shape.get_shape_vec();
575 match self._resolve(&shape, env) {
576 Ok(()) => Ok(()),
577 Err(msg) => Err(format!(
578 "at {}:{}: Shape Error\n {msg}\nActual:\n {shape:?}\nContract:\n {self}\nBindings:\n {{{}}}",
579 location.file(),
580 location.line(),
581 self.index
582 .iter()
583 .zip(env.iter())
584 .filter(|(_, v)| v.is_some())
585 .map(|(k, v)| format!("\"{}\": {}", *k, v.unwrap()))
586 .collect::<Vec<_>>()
587 .join(", ")
588 )),
589 }
590 }
591
592 /// Low-level resolver.
593 pub fn _resolve(&'a self, shape: &[usize], env: &mut [Option<isize>]) -> Result<(), String> {
594 let rank = shape.len();
595
596 let fail_at = |shape_idx: usize, term_idx: usize, msg: &str| -> String {
597 format!(
598 "{} !~ {} :: {msg}",
599 shape[shape_idx],
600 MatcherDisplayAdapter {
601 index: self.index,
602 matcher: &self.terms[term_idx]
603 }
604 )
605 };
606
607 let (e_start, e_size) = match self.try_ellipsis_split(rank) {
608 Ok((e_start, e_size)) => (e_start, e_size),
609 Err(msg) => return Err(msg),
610 };
611
612 for (shape_idx, &dim_size) in shape.iter().enumerate() {
613 let dim_size = dim_size as isize;
614
615 let term_idx = if shape_idx < e_start {
616 shape_idx
617 } else if shape_idx < (e_start + e_size) {
618 continue;
619 } else {
620 shape_idx + 1 - e_size
621 };
622
623 let matcher = &self.terms[term_idx];
624 if let Some(label_id) = matcher.label_id() {
625 match env[label_id] {
626 Some(value) => {
627 if value != dim_size {
628 return Err(fail_at(shape_idx, term_idx, "Value MissMatch."));
629 }
630 }
631 None => {
632 env[label_id] = Some(dim_size);
633 }
634 }
635 }
636
637 let expr = match matcher {
638 DimMatcher::Any { .. } => continue,
639 DimMatcher::Expr { expr, .. } => expr,
640 DimMatcher::Ellipsis { .. } => {
641 unreachable!("Ellipsis should have been handled before")
642 }
643 };
644
645 match expr.try_match(dim_size, env) {
646 Ok(MatchResult::Match) => continue,
647 Ok(MatchResult::Conflict) => {
648 return Err(fail_at(shape_idx, term_idx, "Value MissMatch."));
649 }
650 Ok(MatchResult::ParamConstraint { id, value }) => {
651 env[id] = Some(value);
652 }
653 Err(msg) => return Err(fail_at(shape_idx, term_idx, msg)),
654 }
655 }
656
657 Ok(())
658 }
659
660 /// Check if the pattern has an ellipsis.
661 ///
662 /// ## Arguments
663 ///
664 /// - `rank`: the number of dims of the shape to match.
665 ///
666 /// ## Returns
667 ///
668 /// - `Ok((usize, usize))`: the position of the ellipsis and the number of dimensions it matches.
669 /// - `Err(String)`: an error message if the pattern does not match the expected size.
670 fn try_ellipsis_split(&self, rank: usize) -> Result<(usize, usize), String> {
671 let k = self.terms.len();
672 match self.ellipsis_pos {
673 None => {
674 if rank != k {
675 Err(format!("Shape rank {rank} != pattern dim count {k}",))
676 } else {
677 Ok((k, 0))
678 }
679 }
680 Some(pos) => {
681 let non_ellipsis_terms = k - 1;
682 if rank < non_ellipsis_terms {
683 return Err(format!(
684 "Shape rank {rank} < non-ellipsis pattern term count {non_ellipsis_terms}",
685 ));
686 }
687 Ok((pos, rank - non_ellipsis_terms))
688 }
689 }
690 }
691}
692
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DimExpr;

    /// `[_, b, ..., h*p, w*p, z^3, c]` over the slots `b, h, w, p, z, c`, built with the raw
    /// constructors rather than the `shape_contract!` macro.
    static CONTRACT: ShapeContract = ShapeContract::new(
        &["b", "h", "w", "p", "z", "c"],
        &[
            DimMatcher::any(),
            DimMatcher::expr(DimExpr::Param { id: 0 }),
            DimMatcher::ellipsis(),
            DimMatcher::expr(DimExpr::Prod {
                children: &[DimExpr::Param { id: 1 }, DimExpr::Param { id: 3 }],
            }),
            DimMatcher::expr(DimExpr::Prod {
                children: &[DimExpr::Param { id: 2 }, DimExpr::Param { id: 3 }],
            }),
            DimMatcher::expr(DimExpr::Pow {
                base: &DimExpr::Param { id: 4 },
                exp: 3,
            }),
            DimMatcher::expr(DimExpr::Param { id: 5 }),
        ],
    );

    #[test]
    fn test_unpack_shape() {
        let b = 2;
        let h = 3;
        let w = 2;
        let p = 4;
        let c = 5;
        let z = 4;

        let shape = [12, b, 1, 2, 3, h * p, w * p, z * z * z, c];
        let env = [("p", p), ("c", c)];

        CONTRACT.assert_shape(&shape, &env);

        let [u_b, u_h, u_w, u_z] = CONTRACT.unpack_shape(&shape, &["b", "h", "w", "z"], &env);

        assert_eq!(u_b, b);
        assert_eq!(u_h, h);
        assert_eq!(u_w, w);
        assert_eq!(u_z, z);
    }

    #[test]
    fn test_rank_too_small_is_reported() {
        // 7 terms including one ellipsis => at least 6 dims are required.
        let short: [usize; 4] = [12, 2, 1, 2];
        let err = CONTRACT.try_assert_shape(&short, &[]).unwrap_err();
        assert!(
            err.contains("Shape rank 4 < non-ellipsis pattern term count 6"),
            "{err}"
        );
    }

    #[test]
    fn test_unknown_env_key_is_reported() {
        let shape: [usize; 9] = [12, 2, 1, 2, 3, 12, 8, 64, 5];
        let err = CONTRACT
            .try_assert_shape(&shape, &[("nope", 1)])
            .unwrap_err();
        assert!(err.contains("The key \"nope\" is not indexed"), "{err}");
    }

    #[test]
    fn test_binding_conflict_is_reported() {
        // Dim 1 is `b` = 2 in the shape, but the environment pins `b` to 3.
        let shape: [usize; 9] = [12, 2, 1, 2, 3, 12, 8, 64, 5];
        let err = CONTRACT.try_assert_shape(&shape, &[("b", 3)]).unwrap_err();
        assert!(err.contains("Shape Error"), "{err}");
    }
}