bimm_contracts/
contracts.rs

//! # Shape Contracts.
//!
//! `bimm-contracts` is built around the [`ShapeContract`] type.
//! - A [`ShapeContract`] is a sequence of [`DimMatcher`]s.
//! - A [`DimMatcher`] matches zero or more dimensions of a shape:
//!   - [`DimMatcher::Any`] matches exactly one dimension, of any size.
//!   - [`DimMatcher::Ellipsis`] matches a variable number (zero or more) of dimensions;
//!     a contract may contain at most one ellipsis.
//!   - [`DimMatcher::Expr`] matches exactly one dimension, whose size must equal the
//!     value of a dimension expression.
//!
//! A [`ShapeContract`] should usually be constructed using the [`crate::shape_contract`] macro.
//!
//! ## Example
//!
//! ```rust
//! use bimm_contracts::{shape_contract, ShapeContract};
//!
//! static CONTRACT : ShapeContract = shape_contract![
//!    ...,
//!    "height" = "h_wins" * "window",
//!    "width" = "w_wins" * "window",
//!    "channels",
//! ];
//!
//! let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
//!
//! // Match the shape, given the `window` binding, and unpack the window counts.
//! let [h_wins, w_wins] = CONTRACT.unpack_shape(
//!     &shape,
//!     &["h_wins", "w_wins"],
//!     &[("window", 8)]
//! );
//! assert_eq!(h_wins, 2);
//! assert_eq!(w_wins, 3);
//! ```
35
36use crate::StackEnvironment;
37use crate::expressions::{DimExpr, ExprDisplayAdapter, MatchResult};
38use crate::shape_argument::ShapeArgument;
39use alloc::string::{String, ToString};
40use alloc::vec::Vec;
41use alloc::{format, vec};
42use core::fmt::{Display, Formatter};
43use core::panic::Location;
44
45/// A term in a shape pattern.
46///
47/// Users should generally use [`crate::shape_contract`] to construct patterns.
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub enum DimMatcher<'a> {
50    /// Matches any dimension size.
51    Any {
52        /// An optional label for the matcher.
53        label_id: Option<usize>,
54    },
55
56    /// Matches a variable number of dimensions (ellipsis).
57    Ellipsis {
58        /// An optional label for the matcher.
59        label_id: Option<usize>,
60    },
61
62    /// A dimension size expression that must match a specific value.
63    Expr {
64        /// An optional label for the matcher.
65        label_id: Option<usize>,
66
67        /// The dimension expression that must match a specific value.
68        expr: DimExpr<'a>,
69    },
70}
71
72impl<'a> DimMatcher<'a> {
73    /// Create a new `DimMatcher` that matches any dimension size.
74    pub const fn any() -> Self {
75        DimMatcher::Any { label_id: None }
76    }
77
78    /// Create a new `DimMatcher` that matches a variable number of dimensions (ellipsis).
79    pub const fn ellipsis() -> Self {
80        DimMatcher::Ellipsis { label_id: None }
81    }
82
83    /// Create a new `DimMatcher` from a dimension expression.
84    ///
85    /// ## Arguments
86    ///
87    /// - `expr`: a dimension expression that must match a specific value.
88    ///
89    /// ## Returns
90    ///
91    /// A new `DimMatcher` that matches the given expression.
92    pub const fn expr(expr: DimExpr<'a>) -> Self {
93        DimMatcher::Expr {
94            label_id: None,
95            expr,
96        }
97    }
98
99    /// Get the label of the matcher, if any.
100    pub const fn label_id(&self) -> Option<usize> {
101        match self {
102            DimMatcher::Any { label_id } => *label_id,
103            DimMatcher::Ellipsis { label_id } => *label_id,
104            DimMatcher::Expr { label_id, .. } => *label_id,
105        }
106    }
107
108    /// Attach a label to the matcher.
109    ///
110    /// ## Arguments
111    ///
112    /// - `label_id`: an optional label to attach to the matcher.
113    ///
114    /// ## Returns
115    ///
116    /// A new `DimMatcher` with the label attached.
117    pub const fn with_label_id(self, label_id: Option<usize>) -> Self {
118        match self {
119            DimMatcher::Any { .. } => DimMatcher::Any { label_id },
120            DimMatcher::Ellipsis { .. } => DimMatcher::Ellipsis { label_id },
121            DimMatcher::Expr { expr, .. } => DimMatcher::Expr { label_id, expr },
122        }
123    }
124}
125
126/// Display Adapter to format `DimMatchers` with a `Index`.
127pub struct MatcherDisplayAdapter<'a> {
128    index: &'a [&'a str],
129    matcher: &'a DimMatcher<'a>,
130}
131
132impl<'a> Display for MatcherDisplayAdapter<'a> {
133    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
134        if let Some(label_id) = self.matcher.label_id() {
135            write!(f, "{}=", self.index[label_id])?;
136        }
137        match self.matcher {
138            DimMatcher::Any { .. } => write!(f, "_"),
139            DimMatcher::Ellipsis { .. } => write!(f, "..."),
140            DimMatcher::Expr { expr, .. } => write!(
141                f,
142                "{}",
143                ExprDisplayAdapter {
144                    index: self.index,
145                    expr
146                }
147            ),
148        }
149    }
150}
151
152/// A shape pattern, which is a sequence of terms that can match a shape.
153#[derive(Debug, Clone, PartialEq, Eq)]
154pub struct ShapeContract<'a> {
155    /// The slot index of the contract.
156    pub index: &'a [&'a str],
157
158    /// The terms in the pattern.
159    pub terms: &'a [DimMatcher<'a>],
160
161    /// The position of the ellipsis in the pattern, if any.
162    pub ellipsis_pos: Option<usize>,
163}
164
165impl Display for ShapeContract<'_> {
166    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
167        write!(f, "[")?;
168        for (idx, term) in self.terms.iter().enumerate() {
169            if idx > 0 {
170                write!(f, ", ")?;
171            }
172            write!(
173                f,
174                "{}",
175                MatcherDisplayAdapter {
176                    index: self.index,
177                    matcher: term
178                }
179            )?;
180        }
181        write!(f, "]")
182    }
183}
184
185impl<'a> ShapeContract<'a> {
186    /// Create a new shape pattern from a slice of terms.
187    ///
188    /// ## Arguments
189    ///
190    /// - `terms`: a slice of `ShapePatternTerm` that defines the pattern.
191    ///
192    /// ## Returns
193    ///
194    /// A new `ShapePattern` instance.
195    ///
196    /// ## Macro Support
197    ///
198    /// Consider using the [`crate::shape_contract`] macro instead.
199    ///
200    /// ```
201    /// use bimm_contracts::{shape_contract, ShapeContract};
202    ///
203    /// static CONTRACT : ShapeContract = shape_contract![
204    ///    ...,
205    ///    "height" = "h_wins" * "window",
206    ///    "width" = "w_wins" * "window",
207    ///    "channels",
208    /// ];
209    /// ```
210    pub const fn new(index: &'a [&'a str], terms: &'a [DimMatcher<'a>]) -> Self {
211        let mut i = 0;
212        let mut ellipsis_pos: Option<usize> = None;
213
214        while i < terms.len() {
215            if matches!(terms[i], DimMatcher::Ellipsis { .. }) {
216                match ellipsis_pos {
217                    Some(_) => panic!("Multiple ellipses in pattern"),
218                    None => ellipsis_pos = Some(i),
219                }
220            }
221            i += 1;
222        }
223
224        ShapeContract {
225            index,
226            terms,
227            ellipsis_pos,
228        }
229    }
230
231    /// Convert a key to an index.
232    pub fn maybe_key_to_index(&self, key: &str) -> Option<usize> {
233        self.index.iter().position(|&s| s == key)
234    }
235
236    /// Assert that the shape matches the pattern.
237    ///
238    /// ## Arguments
239    ///
240    /// - `shape`: the shape to match.
241    /// - `env`: the params which are already bound.
242    ///
243    /// ## Panics
244    ///
245    /// If the shape does not match the pattern, or if there is a conflict in the bindings.
246    ///
247    /// ## Example
248    ///
249    /// ```rust
250    /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
251    ///
252    /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
253    ///
254    /// // Run under backoff amortization.
255    /// run_periodically! {{
256    ///     // Statically allocated contract.
257    ///     static CONTRACT : ShapeContract = shape_contract![
258    ///        ...,
259    ///        "height" = "h_wins" * "window",
260    ///        "width" = "w_wins" * "window",
261    ///        "channels",
262    ///     ];
263    ///
264    ///     // Assert the shape, given the bindings.
265    ///     CONTRACT.assert_shape(
266    ///         &shape,
267    ///         &[("h_wins", 2), ("w_wins", 3), ("channels", 4)]
268    ///     );
269    /// }}
270    /// ```
271    #[track_caller]
272    pub fn assert_shape<S>(&'a self, shape: S, env: StackEnvironment<'a>)
273    where
274        S: ShapeArgument,
275    {
276        match self._loc_try_assert_shape(shape, env, Location::caller()) {
277            Ok(()) => (),
278            Err(msg) => panic!("{}", msg),
279        }
280    }
281
282    /// Assert that the shape matches the pattern.
283    ///
284    /// ## Arguments
285    ///
286    /// - `shape`: the shape to match.
287    /// - `env`: the params which are already bound.
288    ///
289    /// ## Returns
290    ///
291    /// - `Ok(())`: if the shape matches the pattern.
292    /// - `Err(String)`: if the shape does not match the pattern, with an error message.
293    ///
294    /// ## Example
295    ///
296    /// ```rust
297    /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
298    ///
299    /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
300    ///
301    /// // Statically allocated contract.
302    /// static CONTRACT : ShapeContract = shape_contract![
303    ///    ...,
304    ///    "height" = "h_wins" * "window",
305    ///    "width" = "w_wins" * "window",
306    ///    "channels",
307    /// ];
308    ///
309    /// // Assert the shape, given the bindings; or throw.
310    /// CONTRACT.try_assert_shape(
311    ///     &shape,
312    ///     &[("h_wins", 2), ("w_wins", 3), ("channels", 4)]
313    /// ).unwrap();
314    /// ```
315    #[track_caller]
316    pub fn try_assert_shape<S>(&'a self, shape: S, env: StackEnvironment<'a>) -> Result<(), String>
317    where
318        S: ShapeArgument,
319    {
320        self._loc_try_assert_shape(shape, env, Location::caller())
321    }
322
323    fn _loc_try_assert_shape<S>(
324        &'a self,
325        shape: S,
326        env: StackEnvironment<'a>,
327        loc: &Location<'a>,
328    ) -> Result<(), String>
329    where
330        S: ShapeArgument,
331    {
332        let mut scratch: Vec<Option<isize>> = vec![None; self.index.len()];
333        for (k, v) in env.iter() {
334            let v = *v as isize;
335            match self.maybe_key_to_index(k) {
336                Some(param_id) => scratch[param_id] = Some(v),
337                None => {
338                    return Err(
339                        format!("The key \"{k}\" is not indexed in the contract:\n{self}")
340                            .to_string(),
341                    );
342                }
343            }
344        }
345
346        self.format_resolve(shape, scratch.as_mut_slice(), loc)
347    }
348
349    /// Match and unpack `K` keys from a shape pattern.
350    ///
351    /// Wraps `try_unpack_shape` and panics if the shape does not match.
352    ///
353    /// ## Generics
354    ///
355    /// - `K`: the length of the `keys` array.
356    ///
357    /// ## Arguments
358    ///
359    /// - `shape`: the shape to match.
360    /// - `keys`: the bound keys to export.
361    /// - `env`: the params which are already bound.
362    ///
363    /// ## Returns
364    ///
365    /// An `[usize; K]` of the unpacked `keys` values.
366    ///
367    /// ## Panics
368    ///
369    /// If the shape does not match the pattern, or if there is a conflict in the bindings.
370    ///
371    /// ## Example
372    ///
373    /// ```rust
374    /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
375    ///
376    /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
377    ///
378    /// // Statically allocated contract.
379    /// static CONTRACT : ShapeContract = shape_contract![
380    ///    ...,
381    ///    "height" = "h_wins" * "window",
382    ///    "width" = "w_wins" * "window",
383    ///    "channels",
384    /// ];
385    ///
386    /// // Unpack the shape, given the bindings.
387    /// let [h, w, c] = CONTRACT.unpack_shape(
388    ///     &shape,
389    ///     &["h_wins", "w_wins", "channels"],
390    ///     &[("window", 8)]
391    /// );
392    /// assert_eq!(h, 2);
393    /// assert_eq!(w, 3);
394    /// assert_eq!(c, 4);
395    /// ```
396    #[must_use]
397    #[track_caller]
398    pub fn unpack_shape<S, const K: usize>(
399        &'a self,
400        shape: S,
401        keys: &[&'a str; K],
402        env: StackEnvironment<'a>,
403    ) -> [usize; K]
404    where
405        S: ShapeArgument,
406    {
407        self._loc_unpack_shape(shape, keys, env, Location::caller())
408    }
409
410    fn _loc_unpack_shape<S, const K: usize>(
411        &'a self,
412        shape: S,
413        keys: &[&'a str; K],
414        env: StackEnvironment<'a>,
415        loc: &Location<'a>,
416    ) -> [usize; K]
417    where
418        S: ShapeArgument,
419    {
420        match self._loc_try_unpack_shape(shape, keys, env, loc) {
421            Ok(values) => values,
422            Err(msg) => panic!("{msg}"),
423        }
424    }
425
426    /// Try and match and unpack `K` keys from a shape pattern.
427    ///
428    /// ## Generics
429    ///
430    /// - `K`: the length of the `keys` array.
431    ///
432    /// ## Arguments
433    ///
434    /// - `shape`: the shape to match.
435    /// - `keys`: the bound keys to export.
436    /// - `env`: the params which are already bound.
437    ///
438    /// ## Returns
439    ///
440    /// A `Result<[usize; K], String>` of the unpacked `keys` values.
441    ///
442    /// ## Example
443    ///
444    /// ```rust
445    /// use bimm_contracts::{shape_contract, run_periodically, ShapeContract};
446    ///
447    /// let shape = [1, 2, 3, 2 * 8, 3 * 8, 4];
448    ///
449    /// // Statically allocated contract.
450    /// static CONTRACT : ShapeContract = shape_contract![
451    ///    ...,
452    ///    "height" = "h_wins" * "window",
453    ///    "width" = "w_wins" * "window",
454    ///    "channels",
455    /// ];
456    ///
457    /// // Unpack the shape, given the bindings; or throw.
458    /// let [h, w, c] = CONTRACT.try_unpack_shape(
459    ///     &shape,
460    ///     &["h_wins", "w_wins", "channels"],
461    ///     &[("window", 8)]
462    /// ).unwrap();
463    /// assert_eq!(h, 2);
464    /// assert_eq!(w, 3);
465    /// assert_eq!(c, 4);
466    /// ```
467    #[track_caller]
468    pub fn try_unpack_shape<S, const K: usize>(
469        &'a self,
470        shape: S,
471        keys: &[&'a str; K],
472        env: StackEnvironment<'a>,
473    ) -> Result<[usize; K], String>
474    where
475        S: ShapeArgument,
476    {
477        self._loc_try_unpack_shape(shape, keys, env, Location::caller())
478    }
479
480    fn _loc_try_unpack_shape<S, const K: usize>(
481        &'a self,
482        shape: S,
483        keys: &[&'a str; K],
484        env: StackEnvironment<'a>,
485        loc: &Location<'a>,
486    ) -> Result<[usize; K], String>
487    where
488        S: ShapeArgument,
489    {
490        let selection = self.expect_keys_to_selection(keys);
491
492        let mut scratch: Vec<Option<isize>> = vec![None; self.index.len()];
493        for (k, v) in env.iter() {
494            let v = *v as isize;
495            match self.maybe_key_to_index(k) {
496                Some(param_id) => scratch[param_id] = Some(v),
497                None => {
498                    return Err(
499                        format!("The key \"{k}\" is not indexed in the contract:\n{self}")
500                            .to_string(),
501                    );
502                }
503            }
504        }
505
506        let selected: [isize; K] =
507            self._loc_try_select(shape, &selection, scratch.as_mut_slice(), loc)?;
508
509        let result: [usize; K] = selected
510            .into_iter()
511            .map(|v| v as usize)
512            .collect::<Vec<usize>>()
513            .try_into()
514            .unwrap();
515
516        Ok(result)
517    }
518
519    /// Convert a list of keys to a selection.
520    pub fn expect_keys_to_selection<const D: usize>(&'a self, keys: &[&'a str; D]) -> [usize; D] {
521        let mut selection = [0; D];
522        for (i, key) in keys.iter().enumerate() {
523            match self.maybe_key_to_index(key) {
524                Some(param_id) => selection[i] = param_id,
525                None => panic!("The key \"{key}\" is not indexed in the contract:\n{self}"),
526            }
527        }
528        selection
529    }
530
531    fn _loc_try_select<S, const K: usize>(
532        &'a self,
533        shape: S,
534        selection: &[usize; K],
535        env: &mut [Option<isize>],
536        loc: &Location<'a>,
537    ) -> Result<[isize; K], String>
538    where
539        S: ShapeArgument,
540    {
541        let num_slots = self.index.len();
542        assert_eq!(env.len(), num_slots);
543
544        self.format_resolve(shape, env, loc)?;
545
546        let mut out = [0; K];
547        for (i, &k) in selection.iter().enumerate() {
548            out[i] = env[k].unwrap();
549        }
550        Ok(out)
551    }
552
553    /// Resolve the match for the shape against the pattern.
554    ///
555    /// ## Arguments
556    ///
557    /// - `shape`: the shape to match.
558    /// - `env`: the mutable environment to bind parameters.
559    /// - `location`: the location reference from ``#[track_caller]``.
560    ///
561    /// ## Returns
562    ///
563    /// - `Ok(())`: if the shape matches the pattern; will update the `env`.
564    /// - `Err(&str)`: if the shape does not match the pattern, with an error message.
565    pub(crate) fn format_resolve<S>(
566        &'a self,
567        shape: S,
568        env: &mut [Option<isize>],
569        location: &Location,
570    ) -> Result<(), String>
571    where
572        S: ShapeArgument,
573    {
574        let shape = shape.get_shape_vec();
575        match self._resolve(&shape, env) {
576            Ok(()) => Ok(()),
577            Err(msg) => Err(format!(
578                "at {}:{}: Shape Error\n  {msg}\nActual:\n  {shape:?}\nContract:\n  {self}\nBindings:\n  {{{}}}",
579                location.file(),
580                location.line(),
581                self.index
582                    .iter()
583                    .zip(env.iter())
584                    .filter(|(_, v)| v.is_some())
585                    .map(|(k, v)| format!("\"{}\": {}", *k, v.unwrap()))
586                    .collect::<Vec<_>>()
587                    .join(", ")
588            )),
589        }
590    }
591
592    /// Low-level resolver.
593    pub fn _resolve(&'a self, shape: &[usize], env: &mut [Option<isize>]) -> Result<(), String> {
594        let rank = shape.len();
595
596        let fail_at = |shape_idx: usize, term_idx: usize, msg: &str| -> String {
597            format!(
598                "{} !~ {} :: {msg}",
599                shape[shape_idx],
600                MatcherDisplayAdapter {
601                    index: self.index,
602                    matcher: &self.terms[term_idx]
603                }
604            )
605        };
606
607        let (e_start, e_size) = match self.try_ellipsis_split(rank) {
608            Ok((e_start, e_size)) => (e_start, e_size),
609            Err(msg) => return Err(msg),
610        };
611
612        for (shape_idx, &dim_size) in shape.iter().enumerate() {
613            let dim_size = dim_size as isize;
614
615            let term_idx = if shape_idx < e_start {
616                shape_idx
617            } else if shape_idx < (e_start + e_size) {
618                continue;
619            } else {
620                shape_idx + 1 - e_size
621            };
622
623            let matcher = &self.terms[term_idx];
624            if let Some(label_id) = matcher.label_id() {
625                match env[label_id] {
626                    Some(value) => {
627                        if value != dim_size {
628                            return Err(fail_at(shape_idx, term_idx, "Value MissMatch."));
629                        }
630                    }
631                    None => {
632                        env[label_id] = Some(dim_size);
633                    }
634                }
635            }
636
637            let expr = match matcher {
638                DimMatcher::Any { .. } => continue,
639                DimMatcher::Expr { expr, .. } => expr,
640                DimMatcher::Ellipsis { .. } => {
641                    unreachable!("Ellipsis should have been handled before")
642                }
643            };
644
645            match expr.try_match(dim_size, env) {
646                Ok(MatchResult::Match) => continue,
647                Ok(MatchResult::Conflict) => {
648                    return Err(fail_at(shape_idx, term_idx, "Value MissMatch."));
649                }
650                Ok(MatchResult::ParamConstraint { id, value }) => {
651                    env[id] = Some(value);
652                }
653                Err(msg) => return Err(fail_at(shape_idx, term_idx, msg)),
654            }
655        }
656
657        Ok(())
658    }
659
660    /// Check if the pattern has an ellipsis.
661    ///
662    /// ## Arguments
663    ///
664    /// - `rank`: the number of dims of the shape to match.
665    ///
666    /// ## Returns
667    ///
668    /// - `Ok((usize, usize))`: the position of the ellipsis and the number of dimensions it matches.
669    /// - `Err(String)`: an error message if the pattern does not match the expected size.
670    fn try_ellipsis_split(&self, rank: usize) -> Result<(usize, usize), String> {
671        let k = self.terms.len();
672        match self.ellipsis_pos {
673            None => {
674                if rank != k {
675                    Err(format!("Shape rank {rank} != pattern dim count {k}",))
676                } else {
677                    Ok((k, 0))
678                }
679            }
680            Some(pos) => {
681                let non_ellipsis_terms = k - 1;
682                if rank < non_ellipsis_terms {
683                    return Err(format!(
684                        "Shape rank {rank} < non-ellipsis pattern term count {non_ellipsis_terms}",
685                    ));
686                }
687                Ok((pos, rank - non_ellipsis_terms))
688            }
689        }
690    }
691}
692
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::DimExpr;

    /// Exercises `assert_shape` and `unpack_shape` on a contract mixing `Any`,
    /// a leading parameter, an ellipsis, product terms, and a power term.
    #[test]
    fn test_unpack_shape() {
        static CONTRACT: ShapeContract = ShapeContract::new(
            &["b", "h", "w", "p", "z", "c"],
            &[
                DimMatcher::any(),
                DimMatcher::expr(DimExpr::Param { id: 0 }),
                DimMatcher::ellipsis(),
                DimMatcher::expr(DimExpr::Prod {
                    children: &[DimExpr::Param { id: 1 }, DimExpr::Param { id: 3 }],
                }),
                DimMatcher::expr(DimExpr::Prod {
                    children: &[DimExpr::Param { id: 2 }, DimExpr::Param { id: 3 }],
                }),
                DimMatcher::expr(DimExpr::Pow {
                    base: &DimExpr::Param { id: 4 },
                    exp: 3,
                }),
                DimMatcher::expr(DimExpr::Param { id: 5 }),
            ],
        );

        let (b, h, w, p, z, c) = (2, 3, 2, 4, 4, 5);

        // Leading `12` hits `Any`; `1, 2, 3` are swallowed by the ellipsis.
        let shape = [12, b, 1, 2, 3, h * p, w * p, z * z * z, c];
        let bindings = [("p", p), ("c", c)];

        CONTRACT.assert_shape(&shape, &bindings);

        let unpacked = CONTRACT.unpack_shape(&shape, &["b", "h", "w", "z"], &bindings);
        assert_eq!(unpacked, [b, h, w, z]);
    }
}