Skip to main content

sim_shape/algebra/
collection.rs

1//! Collection shapes: `TableShape` with its per-field specs and extra-field
2//! policy, and `RepeatShape` for matching repeated occurrences of a shape.
3
4use std::sync::Arc;
5
6use sim_kernel::{
7    Cx, Diagnostic, Expr, Result, Symbol, Table, Value, force_list_to_vec, shape_is_subshape_of,
8};
9
10use crate::{
11    algebra::{capture_symbol, number_expr, number_value, symbol_list_expr, symbol_list_value},
12    base::{Bindings, MatchScore, Shape, ShapeDoc, ShapeMatch},
13};
14
15/// Shape for table values or map expressions with named field constraints.
16///
17/// Required fields must be present and accepted by their field shapes. Extra
18/// fields are controlled by [`TableExtraPolicy`].
19///
20/// ```rust
21/// # use std::sync::Arc;
22/// # use sim_kernel::{Cx, DefaultFactory, Expr, NoopEvalPolicy, Symbol};
23/// # use sim_shape::{ExprKind, ExprKindShape, Shape, TableShape};
24/// # let mut cx = Cx::new(Arc::new(NoopEvalPolicy), Arc::new(DefaultFactory));
25/// let shape = TableShape::single(
26///     Symbol::new("ok"),
27///     Arc::new(ExprKindShape::new(ExprKind::Bool)),
28/// );
29/// let expr = Expr::Map(vec![(Expr::Symbol(Symbol::new("ok")), Expr::Bool(true))]);
30///
31/// assert!(shape.check_expr(&mut cx, &expr).unwrap().accepted);
32/// ```
33#[derive(Clone)]
34pub struct TableShape {
35    fields: Vec<TableFieldSpec>,
36    extra: TableExtraPolicy,
37}
38
39/// One field constraint inside a [`TableShape`].
40#[derive(Clone)]
41pub struct TableFieldSpec {
42    /// Symbol key to look up in the table or map expression.
43    pub key: Symbol,
44    /// Shape that must accept the field value or expression.
45    pub shape: Arc<dyn Shape>,
46    /// Whether the field must be present.
47    pub required: bool,
48}
49
50/// Policy for keys not listed in a [`TableShape`].
51#[derive(Clone)]
52pub enum TableExtraPolicy {
53    /// Accept extra keys without checking their values.
54    Allow,
55    /// Reject any extra key.
56    Reject,
57    /// Check each extra value with the supplied shape.
58    Shape(Arc<dyn Shape>),
59}
60
61impl TableShape {
62    /// Build a table shape requiring a single named field, allowing extras.
63    pub fn single(key: Symbol, shape: Arc<dyn Shape>) -> Self {
64        Self::new(
65            vec![TableFieldSpec {
66                key,
67                shape,
68                required: true,
69            }],
70            TableExtraPolicy::Allow,
71        )
72    }
73
74    /// Build a table shape from explicit field specs and an extra-key policy.
75    pub fn new(fields: Vec<TableFieldSpec>, extra: TableExtraPolicy) -> Self {
76        Self { fields, extra }
77    }
78
79    /// Return the field constraints in declaration order.
80    pub fn fields(&self) -> &[TableFieldSpec] {
81        &self.fields
82    }
83
84    /// Return the policy applied to keys not listed in the field specs.
85    pub fn extra(&self) -> &TableExtraPolicy {
86        &self.extra
87    }
88}
89
90impl Shape for TableShape {
91    fn is_total(&self) -> bool {
92        self.fields.is_empty() && matches!(self.extra, TableExtraPolicy::Allow)
93    }
94
95    fn is_effectful(&self) -> bool {
96        self.fields.iter().any(|field| field.shape.is_effectful())
97            || matches!(&self.extra, TableExtraPolicy::Shape(shape) if shape.is_effectful())
98    }
99
100    fn is_subshape_of(&self, cx: &mut Cx, parent: &dyn Shape) -> Result<Option<bool>> {
101        let Some(parent) = parent.as_any().downcast_ref::<Self>() else {
102            return Ok(None);
103        };
104
105        for parent_field in parent.fields() {
106            if !parent_field.required {
107                continue;
108            }
109            let Some(field) = self
110                .fields
111                .iter()
112                .find(|candidate| candidate.key == parent_field.key)
113            else {
114                return Ok(None);
115            };
116            if !field.required {
117                return Ok(None);
118            }
119            if !shape_is_subshape_of(cx, field.shape.as_ref(), parent_field.shape.as_ref())? {
120                return Ok(None);
121            }
122        }
123
124        if extra_policy_at_least_as_strict(cx, &self.extra, &parent.extra)? {
125            Ok(Some(true))
126        } else {
127            Ok(None)
128        }
129    }
130
131    fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
132        if let Some(table) = value.object().as_table_impl() {
133            return self.check_table_value(cx, table);
134        }
135
136        let table_value = value.object().as_table(cx)?;
137        let Some(table) = table_value.object().as_table_impl() else {
138            return Ok(ShapeMatch::reject("shape-table: expected table"));
139        };
140        self.check_table_value(cx, table)
141    }
142
143    fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
144        let Expr::Map(entries) = expr else {
145            return Ok(ShapeMatch::reject("shape-table: expected map expression"));
146        };
147
148        let mut parsed = Vec::with_capacity(entries.len());
149        for (key, value) in entries {
150            let Expr::Symbol(key) = key else {
151                return Ok(ShapeMatch::reject("shape-table: map key must be symbol"));
152            };
153            parsed.push((key.clone(), value.clone()));
154        }
155        self.check_map_expr(cx, &parsed)
156    }
157
158    fn describe(&self, cx: &mut Cx) -> Result<ShapeDoc> {
159        let mut doc = ShapeDoc::new("table shape");
160        for field in &self.fields {
161            doc = doc.with_detail(format!("{}: {}", field.key, field.shape.describe(cx)?.name));
162        }
163        Ok(doc)
164    }
165}
166
167/// Shape for homogeneous list-like values and collection expressions.
168///
169/// `RepeatShape` checks every item with the body shape and can enforce minimum
170/// and maximum item counts.
171///
172/// ```rust
173/// # use std::sync::Arc;
174/// # use sim_kernel::{Cx, DefaultFactory, Expr, NoopEvalPolicy};
175/// # use sim_shape::{ExprKind, ExprKindShape, RepeatShape, Shape};
176/// # let mut cx = Cx::new(Arc::new(NoopEvalPolicy), Arc::new(DefaultFactory));
177/// let shape = RepeatShape::with_bounds(
178///     Arc::new(ExprKindShape::new(ExprKind::Bool)),
179///     1,
180///     Some(2),
181/// );
182///
183/// assert!(shape
184///     .check_expr(&mut cx, &Expr::Vector(vec![Expr::Bool(true)]))
185///     .unwrap()
186///     .accepted);
187/// ```
188pub struct RepeatShape {
189    body: Arc<dyn Shape>,
190    min: usize,
191    max: Option<usize>,
192}
193
194impl RepeatShape {
195    /// Build an unbounded repeat over the given body shape.
196    pub fn new(body: Arc<dyn Shape>) -> Self {
197        Self::with_bounds(body, 0, None)
198    }
199
200    /// Build a repeat with a minimum and optional maximum item count.
201    pub fn with_bounds(body: Arc<dyn Shape>, min: usize, max: Option<usize>) -> Self {
202        Self { body, min, max }
203    }
204
205    /// Return the shape applied to each item.
206    pub fn body(&self) -> &Arc<dyn Shape> {
207        &self.body
208    }
209
210    /// Return the minimum required item count.
211    pub fn min(&self) -> usize {
212        self.min
213    }
214
215    /// Return the maximum allowed item count, if bounded.
216    pub fn max(&self) -> Option<usize> {
217        self.max
218    }
219}
220
221impl Shape for RepeatShape {
222    fn is_effectful(&self) -> bool {
223        self.body.is_effectful()
224    }
225
226    fn is_subshape_of(&self, cx: &mut Cx, parent: &dyn Shape) -> Result<Option<bool>> {
227        let Some(parent) = parent.as_any().downcast_ref::<Self>() else {
228            return Ok(None);
229        };
230        if self.min < parent.min {
231            return Ok(None);
232        }
233        if !max_at_most(self.max, parent.max) {
234            return Ok(None);
235        }
236        shape_is_subshape_of(cx, self.body.as_ref(), parent.body.as_ref()).map(Some)
237    }
238
239    fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
240        let Some(list) = value.object().as_list() else {
241            let expr = value.object().as_expr(cx)?;
242            return self.check_expr(cx, &expr);
243        };
244        let items = force_list_to_vec(cx, list, "shape-repeat")?;
245        self.check_values(cx, &items)
246    }
247
248    fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
249        let items = match expr {
250            Expr::List(items) | Expr::Vector(items) | Expr::Set(items) => items,
251            _ => return Ok(ShapeMatch::reject("shape-repeat: expected list expression")),
252        };
253        let mut out = ShapeMatch::accept(MatchScore::exact(20));
254        for item in items {
255            let mut matched = self.body.check_expr(cx, item)?;
256            if !matched.accepted {
257                matched
258                    .diagnostics
259                    .insert(0, Diagnostic::error("shape-repeat: item rejected"));
260                return Ok(matched);
261            }
262            out.captures.extend(matched.captures);
263            out.score += matched.score;
264        }
265        self.finish_expr(out, items.len())
266    }
267
268    fn describe(&self, cx: &mut Cx) -> Result<ShapeDoc> {
269        let max = self
270            .max
271            .map(|max| max.to_string())
272            .unwrap_or_else(|| "unbounded".to_owned());
273        Ok(ShapeDoc::new("repeat shape")
274            .with_detail(self.body.describe(cx)?.name)
275            .with_detail(format!("min {}", self.min))
276            .with_detail(format!("max {max}")))
277    }
278}
279
280impl TableShape {
281    fn check_table_value(&self, cx: &mut Cx, table: &dyn Table) -> Result<ShapeMatch> {
282        let entries = table.entries(cx)?;
283        self.check_value_entries(cx, &entries)
284    }
285
286    fn check_value_entries(&self, cx: &mut Cx, entries: &[(Symbol, Value)]) -> Result<ShapeMatch> {
287        let mut out = ShapeMatch::accept(MatchScore::exact(20));
288        let mut matched_keys = Vec::new();
289        let mut missing_keys = Vec::new();
290
291        for field in &self.fields {
292            let Some((_, value)) = entries.iter().find(|(key, _)| *key == field.key) else {
293                if field.required {
294                    missing_keys.push(field.key.clone());
295                }
296                continue;
297            };
298            let mut matched = field.shape.check_value(cx, value.clone())?;
299            if !matched.accepted {
300                matched
301                    .diagnostics
302                    .insert(0, Diagnostic::error("shape-table: field rejected"));
303                return Ok(matched);
304            }
305            out.captures.extend(matched.captures);
306            out.score += matched.score;
307            matched_keys.push(field.key.clone());
308        }
309
310        if !missing_keys.is_empty() {
311            let mut captures = Bindings::new();
312            captures.bind_value(
313                capture_symbol("missing-keys"),
314                symbol_list_value(cx, &missing_keys)?,
315            );
316            return Ok(ShapeMatch {
317                accepted: false,
318                captures,
319                score: MatchScore::reject(),
320                diagnostics: vec![Diagnostic::error("shape-table: missing keys")],
321            });
322        }
323
324        let field_keys = self
325            .fields
326            .iter()
327            .map(|field| field.key.clone())
328            .collect::<Vec<_>>();
329        for (key, value) in entries {
330            if field_keys.contains(key) {
331                continue;
332            }
333            match &self.extra {
334                TableExtraPolicy::Allow => {}
335                TableExtraPolicy::Reject => {
336                    return Ok(ShapeMatch::reject(format!("shape-table: extra key {key}")));
337                }
338                TableExtraPolicy::Shape(shape) => {
339                    let mut matched = shape.check_value(cx, value.clone())?;
340                    if !matched.accepted {
341                        matched
342                            .diagnostics
343                            .insert(0, Diagnostic::error("shape-table: extra value rejected"));
344                        return Ok(matched);
345                    }
346                    out.captures.extend(matched.captures);
347                    out.score += matched.score;
348                }
349            }
350        }
351
352        out.captures.bind_value(
353            capture_symbol("matched-keys"),
354            symbol_list_value(cx, &matched_keys)?,
355        );
356        Ok(out)
357    }
358
359    fn check_map_expr(&self, cx: &mut Cx, entries: &[(Symbol, Expr)]) -> Result<ShapeMatch> {
360        let mut out = ShapeMatch::accept(MatchScore::exact(20));
361        let mut matched_keys = Vec::new();
362        let mut missing_keys = Vec::new();
363
364        for field in &self.fields {
365            let Some((_, value)) = entries.iter().find(|(key, _)| *key == field.key) else {
366                if field.required {
367                    missing_keys.push(field.key.clone());
368                }
369                continue;
370            };
371            let mut matched = field.shape.check_expr(cx, value)?;
372            if !matched.accepted {
373                matched
374                    .diagnostics
375                    .insert(0, Diagnostic::error("shape-table: field rejected"));
376                return Ok(matched);
377            }
378            out.captures.extend(matched.captures);
379            out.score += matched.score;
380            matched_keys.push(field.key.clone());
381        }
382
383        if !missing_keys.is_empty() {
384            let mut captures = Bindings::new();
385            captures.bind_expr(
386                capture_symbol("missing-keys"),
387                symbol_list_expr(&missing_keys),
388            );
389            return Ok(ShapeMatch {
390                accepted: false,
391                captures,
392                score: MatchScore::reject(),
393                diagnostics: vec![Diagnostic::error("shape-table: missing keys")],
394            });
395        }
396
397        let field_keys = self
398            .fields
399            .iter()
400            .map(|field| field.key.clone())
401            .collect::<Vec<_>>();
402        for (key, value) in entries {
403            if field_keys.contains(key) {
404                continue;
405            }
406            match &self.extra {
407                TableExtraPolicy::Allow => {}
408                TableExtraPolicy::Reject => {
409                    return Ok(ShapeMatch::reject(format!("shape-table: extra key {key}")));
410                }
411                TableExtraPolicy::Shape(shape) => {
412                    let mut matched = shape.check_expr(cx, value)?;
413                    if !matched.accepted {
414                        matched
415                            .diagnostics
416                            .insert(0, Diagnostic::error("shape-table: extra value rejected"));
417                        return Ok(matched);
418                    }
419                    out.captures.extend(matched.captures);
420                    out.score += matched.score;
421                }
422            }
423        }
424
425        out.captures.bind_expr(
426            capture_symbol("matched-keys"),
427            symbol_list_expr(&matched_keys),
428        );
429        Ok(out)
430    }
431}
432
433impl RepeatShape {
434    fn check_values(&self, cx: &mut Cx, items: &[Value]) -> Result<ShapeMatch> {
435        let mut out = ShapeMatch::accept(MatchScore::exact(20));
436        for item in items {
437            let mut matched = self.body.check_value(cx, item.clone())?;
438            if !matched.accepted {
439                matched
440                    .diagnostics
441                    .insert(0, Diagnostic::error("shape-repeat: item rejected"));
442                return Ok(matched);
443            }
444            out.captures.extend(matched.captures);
445            out.score += matched.score;
446        }
447        self.finish_value(cx, out, items.len())
448    }
449
450    fn finish_value(&self, cx: &mut Cx, mut out: ShapeMatch, count: usize) -> Result<ShapeMatch> {
451        if count < self.min {
452            return Ok(ShapeMatch::reject("shape-repeat: too few items"));
453        }
454        if matches!(self.max, Some(max) if count > max) {
455            return Ok(ShapeMatch::reject("shape-repeat: too many items"));
456        }
457        out.captures
458            .bind_value(capture_symbol("repeat-count"), number_value(cx, count)?);
459        Ok(out)
460    }
461
462    fn finish_expr(&self, mut out: ShapeMatch, count: usize) -> Result<ShapeMatch> {
463        if count < self.min {
464            return Ok(ShapeMatch::reject("shape-repeat: too few items"));
465        }
466        if matches!(self.max, Some(max) if count > max) {
467            return Ok(ShapeMatch::reject("shape-repeat: too many items"));
468        }
469        out.captures
470            .bind_expr(capture_symbol("repeat-count"), number_expr(count));
471        Ok(out)
472    }
473}
474
475fn extra_policy_at_least_as_strict(
476    cx: &mut Cx,
477    child: &TableExtraPolicy,
478    parent: &TableExtraPolicy,
479) -> Result<bool> {
480    Ok(match (child, parent) {
481        (_, TableExtraPolicy::Allow) => true,
482        (TableExtraPolicy::Reject, TableExtraPolicy::Reject | TableExtraPolicy::Shape(_)) => true,
483        (TableExtraPolicy::Shape(child), TableExtraPolicy::Shape(parent)) => {
484            shape_is_subshape_of(cx, child.as_ref(), parent.as_ref())?
485        }
486        (TableExtraPolicy::Allow, TableExtraPolicy::Reject | TableExtraPolicy::Shape(_)) => false,
487        (TableExtraPolicy::Shape(_), TableExtraPolicy::Reject) => false,
488    })
489}
490
/// Whether the optional upper bound `child` is no looser than `parent`.
///
/// `None` means unbounded: an unbounded parent admits any child bound, while an
/// unbounded child only fits an unbounded parent.
fn max_at_most(child: Option<usize>, parent: Option<usize>) -> bool {
    let Some(parent_limit) = parent else {
        return true;
    };
    matches!(child, Some(child_limit) if child_limit <= parent_limit)
}
497}