burn_contracts/shapes/exp.rs

1use crate::shapes::bindings::{ShapeBindingSource, collect_binding_map, lookup_binding};
2use crate::shapes::parser::{cached_parse_shape_pattern, parse_shape_pattern};
3use std::collections::HashMap;
4use std::fmt::Display;
5
6#[derive(Debug, Clone, PartialEq, Eq, Hash)]
7pub struct ShapePattern {
8    ellipsis_pos: Option<usize>,
9    components: Vec<PatternComponent>,
10}
11
/// One element of a `ShapePattern`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum PatternComponent {
    /// A single named dimension, e.g. `b`.
    Dim(String),
    /// The `...` wildcard; during matching it covers however many dimensions
    /// the other components leave over (possibly none).
    Ellipsis,
    /// A parenthesised group such as `(h w)`, matching one dimension whose
    /// size is the product of the named factors.
    Composite(Vec<String>),
}
18
19#[derive(thiserror::Error, Debug, PartialEq, Eq, Hash)]
20pub enum ShapePatternError {
21    #[error("Parse error for \"{pattern}\"")]
22    ParseError { pattern: String },
23
24    #[error("Invalid pattern \"{pattern}\": {message}")]
25    InvalidPattern { pattern: String, message: String },
26
27    #[error("Shape \"{shape:?}\" !~= \"{pattern}\" with {bindings:?}: {message}")]
28    MatchError {
29        shape: Vec<usize>,
30        pattern: String,
31        bindings: Vec<(String, usize)>,
32        message: String,
33    },
34}
35
/// The result of successfully matching a shape against a `ShapePattern`.
#[derive(Clone, Debug)]
pub struct ShapeMatch {
    /// The shape that was matched.
    pub shape: Vec<usize>,
    /// A value for every dimension id referenced by the pattern, whether it
    /// was supplied by the caller or inferred from the shape. Caller bindings
    /// the pattern never mentions are not included.
    pub bindings: HashMap<String, usize>,
    /// The span of `shape` indices covered by the ellipsis, if the pattern has one.
    pub ellipsis_range: Option<std::ops::Range<usize>>,
}
42
43impl ShapeMatch {
44    /// Select a subset of the bindings.
45    ///
46    /// ## Parameters
47    ///
48    /// - `keys`: The keys to select.
49    ///
50    /// ## Returns
51    ///
52    /// Returns the selected bindings.
53    ///
54    /// ## Panics
55    ///
56    /// Panics if a key is not found in the bindings.
57    #[must_use]
58    pub fn select<const D: usize>(
59        &self,
60        keys: [&str; D],
61    ) -> [usize; D] {
62        let mut result = [0; D];
63        for (i, key) in keys.iter().enumerate() {
64            result[i] = lookup_binding(&self.bindings, key).unwrap();
65        }
66        result
67    }
68}
69
70impl Display for ShapePattern {
71    fn fmt(
72        &self,
73        f: &mut std::fmt::Formatter<'_>,
74    ) -> std::fmt::Result {
75        for (idx, comp) in self.components.iter().enumerate() {
76            if idx > 0 {
77                write!(f, " ")?;
78            }
79            write!(f, "{comp}")?;
80        }
81        Ok(())
82    }
83}
84
85impl Display for PatternComponent {
86    fn fmt(
87        &self,
88        f: &mut std::fmt::Formatter<'_>,
89    ) -> std::fmt::Result {
90        match self {
91            PatternComponent::Dim(id) => write!(f, "{id}"),
92            PatternComponent::Ellipsis => write!(f, "..."),
93            PatternComponent::Composite(ids) => {
94                write!(f, "(")?;
95                for (idx, id) in ids.iter().enumerate() {
96                    if idx > 0 {
97                        write!(f, " ")?;
98                    }
99                    write!(f, "{id}")?;
100                }
101                write!(f, ")")
102            }
103        }
104    }
105}
106
107fn check_ellipsis_pos(components: &[PatternComponent]) -> Result<Option<usize>, ShapePatternError> {
108    let mut ellipsis_pos = None;
109    for (i, component) in components.iter().enumerate() {
110        if let PatternComponent::Ellipsis = component {
111            if ellipsis_pos.is_some() {
112                return Err(ShapePatternError::InvalidPattern {
113                    pattern: components
114                        .iter()
115                        .map(std::string::ToString::to_string)
116                        .collect(),
117                    message: "Only one ellipsis is allowed".to_string(),
118                });
119            }
120            ellipsis_pos = Some(i);
121        }
122    }
123    Ok(ellipsis_pos)
124}
125
126impl ShapePattern {
127    /// Create a new `ShapePattern` from a list of `DimPatterns`
128    ///
129    /// ## Parameters
130    ///
131    /// - `components`: A list of `DimPatterns`
132    ///
133    /// ## Errors
134    ///
135    /// Returns an error if there are too many ellipses
136    pub fn new(components: Vec<PatternComponent>) -> Result<Self, ShapePatternError> {
137        Ok(Self {
138            ellipsis_pos: check_ellipsis_pos(components.as_slice())?,
139            components,
140        })
141    }
142
143    /// Parse a `ShapePattern` from a string
144    ///
145    /// ## Parameters
146    ///
147    /// - `input`: A string representation of the `ShapePattern`
148    ///
149    /// ## Errors
150    ///
151    /// Returns an error if the input string cannot be parsed;
152    /// or the pattern is invalid.
153    pub fn parse(input: &str) -> Result<Self, ShapePatternError> {
154        parse_shape_pattern(input)
155    }
156
157    /// Parse a `ShapePattern` from a string, using a cache
158    ///
159    /// ## Parameters
160    ///
161    /// - `input`: A string representation of the `ShapePattern`
162    ///
163    /// ## Errors
164    ///
165    /// Returns an error if the input string cannot be parsed;
166    /// or the pattern is invalid.
167    pub fn cached_parse(input: &str) -> Result<Self, ShapePatternError> {
168        cached_parse_shape_pattern(input)
169    }
170
171    /// Get the components of the `ShapePattern`.
172    #[must_use]
173    pub fn components(&self) -> &[PatternComponent] {
174        &self.components
175    }
176
177    /// Get the position of the ellipsis in the `ShapePattern`; if any.
178    #[must_use]
179    pub fn ellipsis_pos(&self) -> Option<usize> {
180        self.components
181            .iter()
182            .position(|c| matches!(c, PatternComponent::Ellipsis))
183    }
184
185    /// Check if the `ShapePattern` has an ellipsis.
186    #[must_use]
187    pub fn has_ellipsis(&self) -> bool {
188        self.ellipsis_pos().is_some()
189    }
190
191    /// Assert that the `ShapeEx` matches a given shape.
192    ///
193    /// ## Parameters
194    ///
195    /// - `shape`: The shape to match against.
196    /// - `bindings`: The bindings to use for matching.
197    ///
198    /// ## Errors
199    ///
200    /// Returns an error if the shape does not match the pattern.
201    ///
202    /// ## Returns
203    ///
204    /// Returns a `ShapeMatch` if the shape matches the pattern.
205    #[allow(clippy::missing_panics_doc)]
206    pub fn match_bindings<B: ShapeBindingSource>(
207        &self,
208        shape: &[usize],
209        bindings: B,
210    ) -> Result<ShapeMatch, ShapePatternError> {
211        // FIXME: Reconsider result contents.
212        // - We can skip returning the source shape.
213        // - returned bindings should be an assoc vec OR fixed array?
214        //   - alloc size vs speed considerations
215        // - return ellipsis dims, locations; both?
216        // - multi-pass to resolve composite bindings?
217
218        let bindings: HashMap<String, usize> = collect_binding_map(bindings);
219
220        let dims = shape.len();
221        let ellipsis_pos = self.ellipsis_pos();
222        let non_e_comps = match ellipsis_pos {
223            Some(_) => self.components.len() - 1,
224            None => self.components.len(),
225        };
226        if non_e_comps > dims {
227            return Err(ShapePatternError::MatchError {
228                shape: shape.to_vec(),
229                pattern: self.to_string(),
230                bindings: bindings.iter().map(|(k, v)| (k.clone(), *v)).collect(),
231                message: "Too few dimensions".to_string(),
232            });
233        }
234        let ellipsis_range = ellipsis_pos.map(|pos| pos..pos + dims - non_e_comps);
235
236        let mut export = HashMap::new();
237
238        fn readthrough_lookup(
239            bindings: &HashMap<String, usize>,
240            target: &mut HashMap<String, usize>,
241            id: &str,
242        ) -> Option<usize> {
243            match target.get(id) {
244                Some(value) => Some(*value),
245                None => match bindings.get(id) {
246                    Some(value) => {
247                        target.insert(id.to_string(), *value);
248                        Some(*value)
249                    }
250                    None => None,
251                },
252            }
253        }
254
255        let mut i = 0;
256        for component in &self.components {
257            let dim_shape = shape[i];
258            match component {
259                PatternComponent::Ellipsis => {
260                    i = ellipsis_range.clone().unwrap().end;
261                }
262                PatternComponent::Dim(id) => {
263                    match readthrough_lookup(&bindings, &mut export, id) {
264                        Some(bound_value) => {
265                            if bound_value != dim_shape {
266                                let message = format!(
267                                    "Constraint Mismatch @{id}: {bound_value} != {dim_shape}"
268                                );
269
270                                return Err(ShapePatternError::MatchError {
271                                    shape: shape.to_vec(),
272                                    pattern: self.to_string(),
273                                    bindings: bindings
274                                        .iter()
275                                        .map(|(k, v)| (k.clone(), *v))
276                                        .collect(),
277                                    message,
278                                });
279                            }
280                        }
281                        None => {
282                            export.insert(id.clone(), dim_shape);
283                        }
284                    }
285                    i += 1;
286                }
287                PatternComponent::Composite(ids) => {
288                    let mut acc = 1;
289                    let mut unbound: Option<String> = None;
290                    for factor in ids {
291                        if let Some(value) = readthrough_lookup(&bindings, &mut export, factor) {
292                            acc *= value;
293                        } else {
294                            if unbound.is_some() {
295                                return Err(ShapePatternError::MatchError {
296                                    shape: shape.to_vec(),
297                                    pattern: self.to_string(),
298                                    bindings: bindings
299                                        .iter()
300                                        .map(|(k, v)| (k.clone(), *v))
301                                        .collect(),
302                                    message: "Multiple unbound factors in composite".to_string(),
303                                });
304                            }
305                            unbound = Some(factor.clone());
306                        }
307                    }
308                    if let Some(factor) = unbound {
309                        if dim_shape % acc != 0 {
310                            return Err(ShapePatternError::MatchError {
311                                shape: shape.to_vec(),
312                                pattern: self.to_string(),
313                                bindings: bindings.iter().map(|(k, v)| (k.clone(), *v)).collect(),
314                                message: format!(
315                                    "Composite factor \"{factor}\" * {acc} != shape {dim_shape}",
316                                ),
317                            });
318                        }
319                        export.insert(factor, dim_shape / acc);
320                    }
321                    i += 1;
322                }
323            }
324        }
325
326        Ok(ShapeMatch {
327            shape: shape.to_vec(),
328            bindings: export,
329            ellipsis_range,
330        })
331    }
332}
333
#[cfg(test)]
mod test {
    use super::*;
    use std::error::Error;

    #[test]
    fn test_display_pattern() -> Result<(), Box<dyn Error>> {
        let pattern = ShapePattern::new(vec![
            PatternComponent::Dim("b".to_string()),
            PatternComponent::Ellipsis,
            PatternComponent::Composite(vec!["h".to_string(), "w".to_string()]),
            PatternComponent::Dim("c".to_string()),
        ])?;

        assert_eq!(pattern.to_string(), "b ... (h w) c");
        Ok(())
    }

    #[test]
    fn test_parser_example() -> Result<(), Box<dyn Error>> {
        let shape = [2, 9, 9, 20 * 4, 10 * 4, 3];

        let pattern = ShapePattern::cached_parse("b ... (h p) (w p) c")?;
        let matched = pattern.match_bindings(&shape, &[("b", 2), ("p", 4)])?;

        // Values come back in the order requested: b, h, w, c.
        assert_eq!(matched.select(["b", "h", "w", "c"]), [2, 20, 10, 3]);
        Ok(())
    }

    #[test]
    // Single-letter names mirror the dimension ids used in the pattern.
    #[allow(clippy::many_single_char_names)]
    fn test_assert() -> Result<(), Box<dyn Error>> {
        let (b, h, w, p, c) = (2, 3, 4, 2, 3);
        let extra = 7;

        let shape = [b, 9, 9, h * p, w * p, c];
        let bindings = HashMap::from([
            ("b".to_string(), b),
            ("p".to_string(), p),
            ("extra".to_string(), extra),
        ]);

        let m = ShapePattern::cached_parse("b ... (h p) (w p) c")?
            .match_bindings(shape.as_ref(), &bindings)?;

        assert_eq!(m.shape, shape);
        assert_eq!(m.ellipsis_range, Some(1..3));
        for (id, expected) in [("b", b), ("h", h), ("w", w), ("p", p), ("c", c)] {
            assert_eq!(m.bindings[id], expected, "binding for {id}");
        }

        assert_eq!(m.select(["b", "h", "w"]), [b, h, w]);
        Ok(())
    }
}