nice_dice/
analysis.rs

//! Well-formedness checking (semantic analysis) of dice expressions.
//!
//! We need to make sure that:
//! - all symbols are defined when used;
//! - symbols do not conflict;
//! - keep-highest / keep-lowest are always valid;
//! - division never divides by zero.
//!
//! TODO: Conditionals are nonlinear, so we can't generate min/max for them;
//! we have to handle the whole path instead.

13use std::{collections::HashSet, fmt::Display};
14
15use crate::{
16    Error,
17    parse::RawExpression,
18    symbolic::{Constant, ExpressionTree, ExpressionWrapper, Symbol},
19};
20
21impl TryFrom<RawExpression> for Closed {
22    type Error = Error;
23
24    fn try_from(value: RawExpression) -> Result<Self, Self::Error> {
25        let tree = value.inner();
26        closed_under(&AvailableBinding::Root, tree).map_err(Error::UnboundSymbols)
27    }
28}
29
30impl std::str::FromStr for Closed {
31    type Err = Error;
32
33    fn from_str(s: &str) -> Result<Self, Self::Err> {
34        let raw: RawExpression = s.parse()?;
35        raw.try_into()
36    }
37}
38
39/// An expression which is closed: no unbound symbols from its environment.
40//
41// Note that this really only applies at the top level: the sub-tree can't safely be extracted.
42#[derive(Debug, PartialEq, Eq, Clone, Hash)]
43pub struct Closed(ExpressionTree<Closed>);
44
45impl ExpressionWrapper for Closed {
46    fn inner(&self) -> &ExpressionTree<Self> {
47        &self.0
48    }
49}
50
51impl Display for Closed {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        self.0.fmt(f)
54    }
55}
56
57impl Closed {
58    /// Return a copy of the expression tree where the given expression is replaced with the
59    /// provided constant.
60    ///
61    /// Note that any Closed expression can be given a substitution with a fixed value and remained
62    /// Closed, as the substitution results in strictly fewer unbound variables.
63    pub(crate) fn substitute(&self, sym: &Symbol, value: isize) -> Closed {
64        let expr = if value < 0 {
65            ExpressionTree::Negated(Box::new(Closed(ExpressionTree::Modifier(Constant(
66                value.unsigned_abs(),
67            )))))
68        } else {
69            ExpressionTree::Modifier(Constant(value.unsigned_abs()))
70        };
71        self.substitute_inner(sym, &Closed(expr))
72    }
73
74    fn substitute_inner(&self, sym: &Symbol, expr: &Closed) -> Closed {
75        match self.inner() {
76            ExpressionTree::Symbol(symbol) if symbol == sym => expr.clone(),
77            ExpressionTree::Modifier(_) | ExpressionTree::Die(_) | ExpressionTree::Symbol(_) => {
78                self.clone()
79            }
80            ExpressionTree::Negated(e) => Closed(ExpressionTree::Negated(Box::new(
81                e.substitute_inner(sym, expr),
82            ))),
83            ExpressionTree::Repeated {
84                count,
85                value,
86                ranker,
87            } => {
88                let count = Box::new(count.substitute_inner(sym, expr));
89                let value = Box::new(value.substitute_inner(sym, expr));
90                Closed(ExpressionTree::Repeated {
91                    count,
92                    value,
93                    ranker: *ranker,
94                })
95            }
96            ExpressionTree::Product(a, b) => {
97                let a = Box::new(a.substitute_inner(sym, expr));
98                let b = Box::new(b.substitute_inner(sym, expr));
99                Closed(ExpressionTree::Product(a, b))
100            }
101            ExpressionTree::Floor(a, b) => {
102                let a = Box::new(a.substitute_inner(sym, expr));
103                let b = Box::new(b.substitute_inner(sym, expr));
104                Closed(ExpressionTree::Floor(a, b))
105            }
106            ExpressionTree::Comparison { a, b, op } => {
107                let a = Box::new(a.substitute_inner(sym, expr));
108                let b = Box::new(b.substitute_inner(sym, expr));
109                Closed(ExpressionTree::Comparison { a, b, op: *op })
110            }
111            ExpressionTree::Sum(items) => Closed(ExpressionTree::Sum(
112                items
113                    .iter()
114                    .map(|v| v.substitute_inner(sym, expr))
115                    .collect(),
116            )),
117            ExpressionTree::Binding {
118                symbol,
119                value,
120                tail,
121            } => {
122                let value = Box::new(value.substitute_inner(sym, expr));
123                let tail = Box::new(tail.substitute_inner(sym, expr));
124                Closed(ExpressionTree::Binding {
125                    symbol: symbol.clone(),
126                    value,
127                    tail,
128                })
129            }
130        }
131    }
132}
133
134type ClosureResult = Result<Closed, HashSet<Symbol>>;
135
136fn combine_close_results(
137    a: ClosureResult,
138    b: ClosureResult,
139) -> Result<(Closed, Closed), HashSet<Symbol>> {
140    match (a, b) {
141        (Ok(a), Ok(b)) => Ok((a, b)),
142        (Err(a), Err(b)) => Err(a.into_iter().chain(b).collect()),
143        (Err(a), _) => Err(a),
144        (_, Err(b)) => Err(b),
145    }
146}
147
148/// Evaluate the expression under the provided bindings.
149///
150/// If the expression is closed under those bindings, return Ok();
151/// otherwise, return the unbound symbol(s).
152fn closed_under(
153    bindings: &AvailableBinding<Closed>,
154    tree: &ExpressionTree<RawExpression>,
155) -> ClosureResult {
156    match tree {
157        ExpressionTree::Modifier(a) => Ok(Closed(ExpressionTree::Modifier(*a))),
158        ExpressionTree::Die(a) => Ok(Closed(ExpressionTree::Die(*a))),
159        ExpressionTree::Symbol(symbol) => {
160            if bindings.search(symbol).is_some() {
161                Ok(Closed(ExpressionTree::Symbol(symbol.to_owned())))
162            } else {
163                Err([symbol.clone()].into_iter().collect::<HashSet<_>>())
164            }
165        }
166        ExpressionTree::Negated(n) => Ok(Closed(ExpressionTree::Negated(Box::new(closed_under(
167            bindings,
168            n.inner(),
169        )?)))),
170        ExpressionTree::Repeated {
171            count,
172            value,
173            ranker,
174        } => {
175            let (count, value) = combine_close_results(
176                closed_under(bindings, count.inner()),
177                closed_under(bindings, value.inner()),
178            )?;
179            let count = Box::new(count);
180            let value = Box::new(value);
181            Ok(Closed(ExpressionTree::Repeated {
182                count,
183                value,
184                ranker: *ranker,
185            }))
186        }
187        ExpressionTree::Product(a, b) => {
188            let (a, b) = combine_close_results(
189                closed_under(bindings, a.inner()),
190                closed_under(bindings, b.inner()),
191            )?;
192            Ok(Closed(ExpressionTree::Product(Box::new(a), Box::new(b))))
193        }
194        ExpressionTree::Sum(items) => {
195            let mut unbound: HashSet<Symbol> = Default::default();
196            let items: Vec<Closed> = items
197                .iter()
198                .filter_map(|item| match closed_under(bindings, item.inner()) {
199                    Ok(v) => Some(v),
200                    Err(e) => {
201                        for e in e {
202                            unbound.insert(e);
203                        }
204                        None
205                    }
206                })
207                .collect();
208            if unbound.is_empty() {
209                Ok(Closed(ExpressionTree::Sum(items)))
210            } else {
211                Err(unbound)
212            }
213        }
214        ExpressionTree::Floor(a, b) => {
215            let (a, b) = combine_close_results(
216                closed_under(bindings, a.inner()),
217                closed_under(bindings, b.inner()),
218            )?;
219            Ok(Closed(ExpressionTree::Floor(Box::new(a), Box::new(b))))
220        }
221        ExpressionTree::Comparison { a, b, op } => {
222            let (a, b) = combine_close_results(
223                closed_under(bindings, a.inner()),
224                closed_under(bindings, b.inner()),
225            )?;
226            Ok(Closed(ExpressionTree::Comparison {
227                a: Box::new(a),
228                b: Box::new(b),
229                op: *op,
230            }))
231        }
232
233        ExpressionTree::Binding {
234            symbol,
235            value,
236            tail,
237        } => {
238            let value = closed_under(bindings, value.inner())?;
239            let tail = closed_under(
240                &AvailableBinding::Chain {
241                    defined: symbol,
242                    definition: &value,
243                    prev: bindings,
244                },
245                tail.inner(),
246            )?;
247
248            Ok(Closed(ExpressionTree::Binding {
249                symbol: symbol.clone(),
250                value: Box::new(value),
251                tail: Box::new(tail),
252            }))
253        }
254    }
255}
256
257/// Linked list (on the stack) of current symbol bindings.
258#[derive(Copy, Clone)]
259enum AvailableBinding<'a, T: ExpressionWrapper> {
260    Root,
261    Chain {
262        defined: &'a Symbol,
263        definition: &'a T,
264        prev: &'a AvailableBinding<'a, T>,
265    },
266}
267
268impl<T: ExpressionWrapper> AvailableBinding<'_, T> {
269    /// Search the stack of bindings for the provided symbol.
270    fn search(&self, needle: &Symbol) -> Option<&T> {
271        let mut current: &AvailableBinding<T> = self;
272        while let AvailableBinding::Chain {
273            defined,
274            prev,
275            definition,
276        } = current
277        {
278            if *defined == needle {
279                return Some(definition);
280            } else {
281                current = *prev;
282            }
283        }
284        None
285    }
286}
287
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use proptest::strategy::Union;

    use super::*;
    use crate::parse::RawExpression;
    use crate::properties;
    use crate::symbolic::{Constant, Die};

    #[test]
    fn open_symbols() {
        const CASES: &[(&str, &[&str])] = &[
            ("ATK", &["ATK"]),
            ("2(ATK+CHA)", &["ATK", "CHA"]),
            ("[AC: 10] [ATK: 1d20] (ATK + CHA) > AC", &["CHA"]),
        ];
        for (expr, symbols) in CASES {
            let raw: RawExpression = expr.parse().unwrap();
            let symbols: HashSet<Symbol> = symbols.iter().map(|v| v.parse().unwrap()).collect();
            let unclosed: Result<Closed, _> = raw.try_into();
            let Err(Error::UnboundSymbols(unbound)) = unclosed else {
                panic!("case {expr}: expected unbound symbols")
            };
            assert_eq!(symbols, unbound, "case: {expr}");
        }
    }

    #[test]
    fn closed_symbols() {
        const CASES: &[&str] = &["[AC: 10] 2([ATK: 1d20] (ATK + 3) > AC)"];
        for expr in CASES {
            let raw: RawExpression = expr.parse().unwrap();
            let closed: Closed = raw.clone().try_into().unwrap();
            // Closing an expression must not change how it renders.
            assert_eq!(closed.to_string(), raw.to_string());
        }
    }

    /// Depth-first search of `tree`, the root included, for the first sub-tree matching
    /// `predicate`.
    fn search_for<'a, T, F>(
        tree: &'a ExpressionTree<T>,
        predicate: &mut F,
    ) -> Option<&'a ExpressionTree<T>>
    where
        F: FnMut(&ExpressionTree<T>) -> bool,
        T: ExpressionWrapper,
    {
        if predicate(tree) {
            return Some(tree);
        }
        match tree {
            ExpressionTree::Modifier(_) | ExpressionTree::Die(_) | ExpressionTree::Symbol(_) => {
                None
            }
            ExpressionTree::Negated(e) => search_for(e.inner(), predicate),
            ExpressionTree::Repeated {
                count,
                value,
                ranker: _,
            } => search_for(count.inner(), predicate)
                .or_else(|| search_for(value.inner(), predicate)),
            ExpressionTree::Product(a, b)
            | ExpressionTree::Floor(a, b)
            | ExpressionTree::Comparison { a, b, op: _ } => {
                search_for(a.inner(), predicate).or_else(|| search_for(b.inner(), predicate))
            }
            ExpressionTree::Sum(items) => items
                .iter()
                .find_map(|item| search_for(item.inner(), predicate)),
            ExpressionTree::Binding {
                symbol: _,
                value,
                tail,
            } => search_for(value.inner(), predicate)
                .or_else(|| search_for(tail.inner(), predicate)),
        }
    }

    // TODO: These generators don't work correctly,
    // and they don't shrink well, which hurts too.
    //
    // I think the "recurse in both directions" is a problem:
    // one of our "recursive strategies" is generating something of arbitrary depth,
    // which is a problem for the shrinkage.
    //
    // Not sure what the right way to put these together is.

    /// Inner generator: produces a strategy for expressions closed under the given set of symbols.
    /// Preserves the set of available symbols.
    /// Does not introduce any bindings.
    fn expression_closed_under(
        symbols: HashSet<Symbol>,
    ) -> impl Strategy<Value = (RawExpression, HashSet<Symbol>)> {
        let symbols_final = symbols.clone();

        let static_leaf = Union::new([
            any::<Die>().prop_map(ExpressionTree::Die).boxed(),
            any::<Constant>().prop_map(ExpressionTree::Modifier).boxed(),
        ]);

        // If any symbols are available, only use those symbols.
        // This guarantees that symbols show up when in use.
        let leaf = if symbols.is_empty() {
            static_leaf.boxed()
        } else {
            (0..symbols.len())
                .prop_map(move |v| {
                    let s = symbols.iter().nth(v).unwrap();
                    ExpressionTree::Symbol(s.clone())
                })
                .boxed()
        };

        let leaf = leaf.prop_map(RawExpression::from);
        let closure = leaf.prop_recursive(2, 2, 2, |strat| {
            prop_oneof![
                properties::negated(&strat),
                properties::repeated(&strat),
                properties::product(&strat),
                properties::floor(&strat),
                properties::sum(&strat),
                properties::comparison(&strat),
            ]
            .prop_map(RawExpression::from)
        });
        closure.prop_map(move |v| (v, symbols_final.clone()))
    }

    proptest! {
        #[test]
        fn identify_open_symbols(
            (symbols, (exp, _)) in
            proptest::collection::hash_set(properties::symbol(), 1..4)
            .prop_flat_map(|symbols| (Just(symbols.clone()), expression_closed_under(symbols)))
        ) {
            let result: Result<Closed, _> = exp.clone().try_into();
            let got = match result {
                Ok(_) => HashSet::new(),
                Err(Error::UnboundSymbols(got)) => got,
                Err(e) => panic!("unexpected error: {e:?}"),
            };

            // The generator introduces no bindings, so a symbol is reported as unbound exactly
            // when the expression uses it.
            for symbol in &symbols {
                let used = search_for(exp.inner(), &mut |s| {
                    matches!(s, ExpressionTree::Symbol(sym) if sym == symbol)
                })
                .is_some();
                prop_assert_eq!(got.contains(symbol), used, "symbol {:?}", symbol);
            }
            // Nothing outside the generated symbol set is ever reported.
            prop_assert!(got.is_subset(&symbols));
        }
    }

    /// Generate an Expression that is intended to have valid bindings.
    ///
    /// NOTE: this does not currently guarantee a closed expression. The symbol set returned
    /// alongside `definition` names symbols bound *inside* that definition. That set is then used
    /// to generate the tail, where those inner bindings are out of scope. This is why
    /// `generate_valid_bindings` only checks the error case.
    fn closed_expression() -> impl Strategy<Value = RawExpression> {
        let leaf = expression_closed_under(HashSet::new());
        let syms = leaf.prop_recursive(2, 2, 2, |strat| {
            (properties::symbol(), strat.clone()).prop_flat_map(
                |(symbol, (definition, mut symbols))| {
                    // In the symbol-recursive case, we:
                    // - select a symbol;
                    // - generate a binding from our _existing_ strategy, which takes into
                    //   account symbols already defined;
                    // - create a _new_ strategy including our symbol and the others, to
                    //   generate the tail with.
                    symbols.insert(symbol.clone());
                    expression_closed_under(symbols).prop_map(move |(tail, new_symbols)| {
                        (
                            RawExpression::from(ExpressionTree::Binding {
                                symbol: symbol.clone(),
                                value: Box::new(definition.clone()),
                                tail: Box::new(tail),
                            }),
                            new_symbols,
                        )
                    })
                },
            )
        });
        syms.prop_map(|(tree, _syms)| tree)
    }

    /// Find a use of `symbol` in `tree` that no binding within `tree` covers.
    fn unbound_tree<'a, W>(
        symbol: &Symbol,
        tree: &'a ExpressionTree<W>,
    ) -> Option<&'a ExpressionTree<W>>
    where
        W: ExpressionWrapper,
    {
        match tree {
            ExpressionTree::Binding {
                symbol: sym,
                value,
                tail,
            } => {
                // The binding's own definition is outside its scope.
                let value = unbound_tree(symbol, value.inner());
                if sym == symbol {
                    // Symbol is bound in the tail, we don't need to inspect it.
                    value
                } else {
                    // Symbol is also unbound in the tail, look there.
                    value.or_else(|| unbound_tree(symbol, tail.inner()))
                }
            }
            ExpressionTree::Modifier(_) => None,
            ExpressionTree::Die(_) => None,
            ExpressionTree::Symbol(sym) if sym == symbol => Some(tree),
            ExpressionTree::Symbol(_) => None,
            ExpressionTree::Negated(e) => unbound_tree(symbol, e.inner()),
            ExpressionTree::Repeated {
                count,
                value,
                ranker: _,
            } => {
                unbound_tree(symbol, count.inner()).or_else(|| unbound_tree(symbol, value.inner()))
            }
            ExpressionTree::Product(a, b) => {
                unbound_tree(symbol, a.inner()).or_else(|| unbound_tree(symbol, b.inner()))
            }
            ExpressionTree::Floor(a, b) => {
                unbound_tree(symbol, a.inner()).or_else(|| unbound_tree(symbol, b.inner()))
            }
            ExpressionTree::Comparison { a, b, op: _ } => {
                unbound_tree(symbol, a.inner()).or_else(|| unbound_tree(symbol, b.inner()))
            }
            ExpressionTree::Sum(items) => items
                .iter()
                .find_map(|v| unbound_tree(symbol, v.inner())),
        }
    }

    proptest! {
        // TODO: This tests that all _detected_ unbound variables are in fact unbound.
        // It doesn't check that all unbound variables are detected.
        #[test]
        fn generate_valid_bindings(exp in closed_expression()) {
            let exp = exp.simplify();
            let result: Result<Closed, _> = exp.clone().try_into();
            if let Err(Error::UnboundSymbols(got)) = result {
                for symbol in got {
                    // The symbol is used:
                    let used = search_for(exp.inner(), &mut |s| {
                        matches!(s, ExpressionTree::Symbol(sym) if sym == &symbol)
                    })
                    .is_some();
                    prop_assert!(used, "reported symbol {:?} is never used", symbol);

                    // And there is some sub-tree where:
                    // - there is no binding for the symbol, and
                    // - the symbol is used.
                    prop_assert!(
                        unbound_tree(&symbol, exp.inner()).is_some(),
                        "reported symbol {:?} is always bound",
                        symbol
                    );
                }
            }
        }
    }
}