1use crate::shapes::bindings::{ShapeBindingSource, collect_binding_map, lookup_binding};
2use crate::shapes::parser::{cached_parse_shape_pattern, parse_shape_pattern};
3use std::collections::HashMap;
4use std::fmt::Display;
5
/// A shape pattern such as `b ... (h w) c`: an ordered list of components.
///
/// [`ShapePattern::new`] validates that at most one ellipsis is present and records
/// its position.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShapePattern {
    /// Index into `components` of the ellipsis, if any (computed by `new` via
    /// `check_ellipsis_pos`).
    ellipsis_pos: Option<usize>,
    /// The pattern's components, in order.
    components: Vec<PatternComponent>,
}
11
/// One element of a [`ShapePattern`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PatternComponent {
    /// A named dimension matching exactly one axis; rendered as the bare name.
    Dim(String),
    /// `...`: absorbs the axes not claimed by the other components.
    Ellipsis,
    /// `(a b ...)`: a single axis described as the product of the named factors.
    /// When matching, at most one factor may be left unbound; it is solved by division.
    Composite(Vec<String>),
}
18
/// Errors raised while building or matching a [`ShapePattern`].
#[derive(thiserror::Error, Debug, PartialEq, Eq, Hash)]
pub enum ShapePatternError {
    /// The pattern text could not be parsed.
    /// NOTE(review): presumably produced by the parser module — confirm.
    #[error("Parse error for \"{pattern}\"")]
    ParseError { pattern: String },

    /// The components are structurally invalid (e.g. more than one ellipsis).
    #[error("Invalid pattern \"{pattern}\": {message}")]
    InvalidPattern { pattern: String, message: String },

    /// A concrete shape failed to match the pattern. Carries the shape, the rendered
    /// pattern, the caller-supplied bindings, and a description of the mismatch.
    #[error("Shape \"{shape:?}\" !~= \"{pattern}\" with {bindings:?}: {message}")]
    MatchError {
        shape: Vec<usize>,
        pattern: String,
        bindings: Vec<(String, usize)>,
        message: String,
    },
}
35
/// The result of a successful [`ShapePattern::match_bindings`].
#[derive(Debug, Clone)]
pub struct ShapeMatch {
    /// A copy of the shape that was matched.
    pub shape: Vec<usize>,
    /// Size of every identifier referenced by the pattern. Caller bindings that the
    /// pattern does not mention are not included.
    pub bindings: HashMap<String, usize>,
    /// Range of `shape` indices absorbed by the ellipsis; `None` if the pattern has none.
    pub ellipsis_range: Option<std::ops::Range<usize>>,
}
42
43impl ShapeMatch {
44 #[must_use]
58 pub fn select<const D: usize>(
59 &self,
60 keys: [&str; D],
61 ) -> [usize; D] {
62 let mut result = [0; D];
63 for (i, key) in keys.iter().enumerate() {
64 result[i] = lookup_binding(&self.bindings, key).unwrap();
65 }
66 result
67 }
68}
69
70impl Display for ShapePattern {
71 fn fmt(
72 &self,
73 f: &mut std::fmt::Formatter<'_>,
74 ) -> std::fmt::Result {
75 for (idx, comp) in self.components.iter().enumerate() {
76 if idx > 0 {
77 write!(f, " ")?;
78 }
79 write!(f, "{comp}")?;
80 }
81 Ok(())
82 }
83}
84
85impl Display for PatternComponent {
86 fn fmt(
87 &self,
88 f: &mut std::fmt::Formatter<'_>,
89 ) -> std::fmt::Result {
90 match self {
91 PatternComponent::Dim(id) => write!(f, "{id}"),
92 PatternComponent::Ellipsis => write!(f, "..."),
93 PatternComponent::Composite(ids) => {
94 write!(f, "(")?;
95 for (idx, id) in ids.iter().enumerate() {
96 if idx > 0 {
97 write!(f, " ")?;
98 }
99 write!(f, "{id}")?;
100 }
101 write!(f, ")")
102 }
103 }
104 }
105}
106
107fn check_ellipsis_pos(components: &[PatternComponent]) -> Result<Option<usize>, ShapePatternError> {
108 let mut ellipsis_pos = None;
109 for (i, component) in components.iter().enumerate() {
110 if let PatternComponent::Ellipsis = component {
111 if ellipsis_pos.is_some() {
112 return Err(ShapePatternError::InvalidPattern {
113 pattern: components
114 .iter()
115 .map(std::string::ToString::to_string)
116 .collect(),
117 message: "Only one ellipsis is allowed".to_string(),
118 });
119 }
120 ellipsis_pos = Some(i);
121 }
122 }
123 Ok(ellipsis_pos)
124}
125
126impl ShapePattern {
127 pub fn new(components: Vec<PatternComponent>) -> Result<Self, ShapePatternError> {
137 Ok(Self {
138 ellipsis_pos: check_ellipsis_pos(components.as_slice())?,
139 components,
140 })
141 }
142
143 pub fn parse(input: &str) -> Result<Self, ShapePatternError> {
154 parse_shape_pattern(input)
155 }
156
157 pub fn cached_parse(input: &str) -> Result<Self, ShapePatternError> {
168 cached_parse_shape_pattern(input)
169 }
170
171 #[must_use]
173 pub fn components(&self) -> &[PatternComponent] {
174 &self.components
175 }
176
177 #[must_use]
179 pub fn ellipsis_pos(&self) -> Option<usize> {
180 self.components
181 .iter()
182 .position(|c| matches!(c, PatternComponent::Ellipsis))
183 }
184
185 #[must_use]
187 pub fn has_ellipsis(&self) -> bool {
188 self.ellipsis_pos().is_some()
189 }
190
191 #[allow(clippy::missing_panics_doc)]
206 pub fn match_bindings<B: ShapeBindingSource>(
207 &self,
208 shape: &[usize],
209 bindings: B,
210 ) -> Result<ShapeMatch, ShapePatternError> {
211 let bindings: HashMap<String, usize> = collect_binding_map(bindings);
219
220 let dims = shape.len();
221 let ellipsis_pos = self.ellipsis_pos();
222 let non_e_comps = match ellipsis_pos {
223 Some(_) => self.components.len() - 1,
224 None => self.components.len(),
225 };
226 if non_e_comps > dims {
227 return Err(ShapePatternError::MatchError {
228 shape: shape.to_vec(),
229 pattern: self.to_string(),
230 bindings: bindings.iter().map(|(k, v)| (k.clone(), *v)).collect(),
231 message: "Too few dimensions".to_string(),
232 });
233 }
234 let ellipsis_range = ellipsis_pos.map(|pos| pos..pos + dims - non_e_comps);
235
236 let mut export = HashMap::new();
237
238 fn readthrough_lookup(
239 bindings: &HashMap<String, usize>,
240 target: &mut HashMap<String, usize>,
241 id: &str,
242 ) -> Option<usize> {
243 match target.get(id) {
244 Some(value) => Some(*value),
245 None => match bindings.get(id) {
246 Some(value) => {
247 target.insert(id.to_string(), *value);
248 Some(*value)
249 }
250 None => None,
251 },
252 }
253 }
254
255 let mut i = 0;
256 for component in &self.components {
257 let dim_shape = shape[i];
258 match component {
259 PatternComponent::Ellipsis => {
260 i = ellipsis_range.clone().unwrap().end;
261 }
262 PatternComponent::Dim(id) => {
263 match readthrough_lookup(&bindings, &mut export, id) {
264 Some(bound_value) => {
265 if bound_value != dim_shape {
266 let message = format!(
267 "Constraint Mismatch @{id}: {bound_value} != {dim_shape}"
268 );
269
270 return Err(ShapePatternError::MatchError {
271 shape: shape.to_vec(),
272 pattern: self.to_string(),
273 bindings: bindings
274 .iter()
275 .map(|(k, v)| (k.clone(), *v))
276 .collect(),
277 message,
278 });
279 }
280 }
281 None => {
282 export.insert(id.clone(), dim_shape);
283 }
284 }
285 i += 1;
286 }
287 PatternComponent::Composite(ids) => {
288 let mut acc = 1;
289 let mut unbound: Option<String> = None;
290 for factor in ids {
291 if let Some(value) = readthrough_lookup(&bindings, &mut export, factor) {
292 acc *= value;
293 } else {
294 if unbound.is_some() {
295 return Err(ShapePatternError::MatchError {
296 shape: shape.to_vec(),
297 pattern: self.to_string(),
298 bindings: bindings
299 .iter()
300 .map(|(k, v)| (k.clone(), *v))
301 .collect(),
302 message: "Multiple unbound factors in composite".to_string(),
303 });
304 }
305 unbound = Some(factor.clone());
306 }
307 }
308 if let Some(factor) = unbound {
309 if dim_shape % acc != 0 {
310 return Err(ShapePatternError::MatchError {
311 shape: shape.to_vec(),
312 pattern: self.to_string(),
313 bindings: bindings.iter().map(|(k, v)| (k.clone(), *v)).collect(),
314 message: format!(
315 "Composite factor \"{factor}\" * {acc} != shape {dim_shape}",
316 ),
317 });
318 }
319 export.insert(factor, dim_shape / acc);
320 }
321 i += 1;
322 }
323 }
324 }
325
326 Ok(ShapeMatch {
327 shape: shape.to_vec(),
328 bindings: export,
329 ellipsis_range,
330 })
331 }
332}
333
#[cfg(test)]
mod test {
    use super::*;
    use std::error::Error;

    /// `Display` joins components with single spaces and parenthesises composites.
    #[test]
    fn test_display_pattern() {
        let components = vec![
            PatternComponent::Dim("b".into()),
            PatternComponent::Ellipsis,
            PatternComponent::Composite(vec!["h".into(), "w".into()]),
            PatternComponent::Dim("c".into()),
        ];
        let pattern = ShapePattern::new(components).unwrap();

        assert_eq!(pattern.to_string(), "b ... (h w) c");
    }

    /// Parse, match against partial bindings, and select the solved sizes.
    #[test]
    #[allow(clippy::many_single_char_names)]
    fn test_parser_example() -> Result<(), Box<dyn Error>> {
        let shape = [2, 9, 9, 20 * 4, 10 * 4, 3];

        let pattern = ShapePattern::cached_parse("b ... (h p) (w p) c")?;
        let matched = pattern.match_bindings(&shape, &[("b", 2), ("p", 4)])?;
        let [b, h, w, c] = matched.select(["b", "h", "w", "c"]);

        assert_eq!([b, h, w, c], [2, 20, 10, 3]);
        Ok(())
    }

    /// Caller bindings (including an unused extra) constrain the match. Solved and
    /// pre-bound sizes are both reported, along with the ellipsis range.
    #[test]
    #[allow(clippy::many_single_char_names)]
    fn test_assert() -> Result<(), Box<dyn Error>> {
        let (b, h, w, p, c) = (2, 3, 4, 2, 3);
        let extra = 7;
        let shape = [b, 9, 9, h * p, w * p, c];

        let bindings = HashMap::from([
            ("b".to_string(), b),
            ("p".to_string(), p),
            ("extra".to_string(), extra),
        ]);

        let m = ShapePattern::cached_parse("b ... (h p) (w p) c")?
            .match_bindings(shape.as_ref(), &bindings)?;

        assert_eq!(m.shape, shape);
        assert_eq!(m.ellipsis_range, Some(1..3));
        for (key, expected) in [("b", b), ("h", h), ("w", w), ("p", p), ("c", c)] {
            assert_eq!(m.bindings[key], expected, "binding {key}");
        }

        assert_eq!(m.select(["b", "h", "w"]), [b, h, w]);
        Ok(())
    }
}