1use std::{collections::HashSet, fmt::Display};
14
15use crate::{
16 Error,
17 parse::RawExpression,
18 symbolic::{Constant, ExpressionTree, ExpressionWrapper, Symbol},
19};
20
21impl TryFrom<RawExpression> for Closed {
22 type Error = Error;
23
24 fn try_from(value: RawExpression) -> Result<Self, Self::Error> {
25 let tree = value.inner();
26 closed_under(&AvailableBinding::Root, tree).map_err(Error::UnboundSymbols)
27 }
28}
29
30impl std::str::FromStr for Closed {
31 type Err = Error;
32
33 fn from_str(s: &str) -> Result<Self, Self::Err> {
34 let raw: RawExpression = s.parse()?;
35 raw.try_into()
36 }
37}
38
/// An expression tree in which every symbol reference is bound.
///
/// The wrapped tree is private to this module, so code elsewhere obtains a
/// `Closed` through `TryFrom<RawExpression>` or `FromStr`, both of which reject
/// expressions containing unbound symbols.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct Closed(ExpressionTree<Closed>);
44
45impl ExpressionWrapper for Closed {
46 fn inner(&self) -> &ExpressionTree<Self> {
47 &self.0
48 }
49}
50
51impl Display for Closed {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 self.0.fmt(f)
54 }
55}
56
57impl Closed {
58 pub(crate) fn substitute(&self, sym: &Symbol, value: isize) -> Closed {
64 let expr = if value < 0 {
65 ExpressionTree::Negated(Box::new(Closed(ExpressionTree::Modifier(Constant(
66 value.unsigned_abs(),
67 )))))
68 } else {
69 ExpressionTree::Modifier(Constant(value.unsigned_abs()))
70 };
71 self.substitute_inner(sym, &Closed(expr))
72 }
73
74 fn substitute_inner(&self, sym: &Symbol, expr: &Closed) -> Closed {
75 match self.inner() {
76 ExpressionTree::Symbol(symbol) if symbol == sym => expr.clone(),
77 ExpressionTree::Modifier(_) | ExpressionTree::Die(_) | ExpressionTree::Symbol(_) => {
78 self.clone()
79 }
80 ExpressionTree::Negated(e) => Closed(ExpressionTree::Negated(Box::new(
81 e.substitute_inner(sym, expr),
82 ))),
83 ExpressionTree::Repeated {
84 count,
85 value,
86 ranker,
87 } => {
88 let count = Box::new(count.substitute_inner(sym, expr));
89 let value = Box::new(value.substitute_inner(sym, expr));
90 Closed(ExpressionTree::Repeated {
91 count,
92 value,
93 ranker: *ranker,
94 })
95 }
96 ExpressionTree::Product(a, b) => {
97 let a = Box::new(a.substitute_inner(sym, expr));
98 let b = Box::new(b.substitute_inner(sym, expr));
99 Closed(ExpressionTree::Product(a, b))
100 }
101 ExpressionTree::Floor(a, b) => {
102 let a = Box::new(a.substitute_inner(sym, expr));
103 let b = Box::new(b.substitute_inner(sym, expr));
104 Closed(ExpressionTree::Floor(a, b))
105 }
106 ExpressionTree::Comparison { a, b, op } => {
107 let a = Box::new(a.substitute_inner(sym, expr));
108 let b = Box::new(b.substitute_inner(sym, expr));
109 Closed(ExpressionTree::Comparison { a, b, op: *op })
110 }
111 ExpressionTree::Sum(items) => Closed(ExpressionTree::Sum(
112 items
113 .iter()
114 .map(|v| v.substitute_inner(sym, expr))
115 .collect(),
116 )),
117 ExpressionTree::Binding {
118 symbol,
119 value,
120 tail,
121 } => {
122 let value = Box::new(value.substitute_inner(sym, expr));
123 let tail = Box::new(tail.substitute_inner(sym, expr));
124 Closed(ExpressionTree::Binding {
125 symbol: symbol.clone(),
126 value,
127 tail,
128 })
129 }
130 }
131 }
132}
133
/// Outcome of checking a subtree for closure: either the converted closed tree,
/// or the set of symbols that were referenced without a visible binding.
type ClosureResult = Result<Closed, HashSet<Symbol>>;
135
136fn combine_close_results(
137 a: ClosureResult,
138 b: ClosureResult,
139) -> Result<(Closed, Closed), HashSet<Symbol>> {
140 match (a, b) {
141 (Ok(a), Ok(b)) => Ok((a, b)),
142 (Err(a), Err(b)) => Err(a.into_iter().chain(b).collect()),
143 (Err(a), _) => Err(a),
144 (_, Err(b)) => Err(b),
145 }
146}
147
148fn closed_under(
153 bindings: &AvailableBinding<Closed>,
154 tree: &ExpressionTree<RawExpression>,
155) -> ClosureResult {
156 match tree {
157 ExpressionTree::Modifier(a) => Ok(Closed(ExpressionTree::Modifier(*a))),
158 ExpressionTree::Die(a) => Ok(Closed(ExpressionTree::Die(*a))),
159 ExpressionTree::Symbol(symbol) => {
160 if bindings.search(symbol).is_some() {
161 Ok(Closed(ExpressionTree::Symbol(symbol.to_owned())))
162 } else {
163 Err([symbol.clone()].into_iter().collect::<HashSet<_>>())
164 }
165 }
166 ExpressionTree::Negated(n) => Ok(Closed(ExpressionTree::Negated(Box::new(closed_under(
167 bindings,
168 n.inner(),
169 )?)))),
170 ExpressionTree::Repeated {
171 count,
172 value,
173 ranker,
174 } => {
175 let (count, value) = combine_close_results(
176 closed_under(bindings, count.inner()),
177 closed_under(bindings, value.inner()),
178 )?;
179 let count = Box::new(count);
180 let value = Box::new(value);
181 Ok(Closed(ExpressionTree::Repeated {
182 count,
183 value,
184 ranker: *ranker,
185 }))
186 }
187 ExpressionTree::Product(a, b) => {
188 let (a, b) = combine_close_results(
189 closed_under(bindings, a.inner()),
190 closed_under(bindings, b.inner()),
191 )?;
192 Ok(Closed(ExpressionTree::Product(Box::new(a), Box::new(b))))
193 }
194 ExpressionTree::Sum(items) => {
195 let mut unbound: HashSet<Symbol> = Default::default();
196 let items: Vec<Closed> = items
197 .iter()
198 .filter_map(|item| match closed_under(bindings, item.inner()) {
199 Ok(v) => Some(v),
200 Err(e) => {
201 for e in e {
202 unbound.insert(e);
203 }
204 None
205 }
206 })
207 .collect();
208 if unbound.is_empty() {
209 Ok(Closed(ExpressionTree::Sum(items)))
210 } else {
211 Err(unbound)
212 }
213 }
214 ExpressionTree::Floor(a, b) => {
215 let (a, b) = combine_close_results(
216 closed_under(bindings, a.inner()),
217 closed_under(bindings, b.inner()),
218 )?;
219 Ok(Closed(ExpressionTree::Floor(Box::new(a), Box::new(b))))
220 }
221 ExpressionTree::Comparison { a, b, op } => {
222 let (a, b) = combine_close_results(
223 closed_under(bindings, a.inner()),
224 closed_under(bindings, b.inner()),
225 )?;
226 Ok(Closed(ExpressionTree::Comparison {
227 a: Box::new(a),
228 b: Box::new(b),
229 op: *op,
230 }))
231 }
232
233 ExpressionTree::Binding {
234 symbol,
235 value,
236 tail,
237 } => {
238 let value = closed_under(bindings, value.inner())?;
239 let tail = closed_under(
240 &AvailableBinding::Chain {
241 defined: symbol,
242 definition: &value,
243 prev: bindings,
244 },
245 tail.inner(),
246 )?;
247
248 Ok(Closed(ExpressionTree::Binding {
249 symbol: symbol.clone(),
250 value: Box::new(value),
251 tail: Box::new(tail),
252 }))
253 }
254 }
255}
256
257#[derive(Copy, Clone)]
259enum AvailableBinding<'a, T: ExpressionWrapper> {
260 Root,
261 Chain {
262 defined: &'a Symbol,
263 definition: &'a T,
264 prev: &'a AvailableBinding<'a, T>,
265 },
266}
267
268impl<T: ExpressionWrapper> AvailableBinding<'_, T> {
269 fn search(&self, needle: &Symbol) -> Option<&T> {
271 let mut current: &AvailableBinding<T> = self;
272 while let AvailableBinding::Chain {
273 defined,
274 prev,
275 definition,
276 } = current
277 {
278 if *defined == needle {
279 return Some(definition);
280 } else {
281 current = *prev;
282 }
283 }
284 None
285 }
286}
287
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use proptest::strategy::Union;

    use super::*;
    use crate::parse::RawExpression;
    use crate::properties;
    use crate::symbolic::{Constant, Die};

    /// Converting an expression with unbound symbols must fail and report
    /// exactly those symbols.
    #[test]
    fn open_symbols() {
        const CASES: &[(&str, &[&str])] = &[
            ("ATK", &["ATK"]),
            ("2(ATK+CHA)", &["ATK", "CHA"]),
            ("[AC: 10] [ATK: 1d20] (ATK + CHA) > AC", &["CHA"]),
        ];
        for (expr, symbols) in CASES {
            let raw: RawExpression = expr.parse().unwrap();
            let symbols: HashSet<Symbol> = symbols.iter().map(|v| v.parse().unwrap()).collect();
            let unclosed: Result<Closed, _> = raw.try_into();
            let Err(Error::UnboundSymbols(unbound)) = unclosed else {
                panic!("got closed expression")
            };
            assert_eq!(symbols, unbound, "case: {expr}");
        }
    }

    /// Fully bound expressions convert successfully and display unchanged.
    #[test]
    fn closed_symbols() {
        const CASES: &[&str] = &["[AC: 10] 2([ATK: 1d20] (ATK + 3) > AC)"];
        for expr in CASES {
            let raw: RawExpression = expr.parse().unwrap();
            let closed: Closed = raw.clone().try_into().unwrap();
            assert_eq!(closed.to_string(), raw.to_string());
        }
    }

    /// Depth-first search for the first subtree satisfying `predicate`.
    ///
    /// The current node is tested before its children. Children are searched
    /// left to right, and later siblings are visited only if earlier ones
    /// produced no match.
    fn search_for<'a, T, F>(
        tree: &'a ExpressionTree<T>,
        predicate: &mut F,
    ) -> Option<&'a ExpressionTree<T>>
    where
        F: FnMut(&ExpressionTree<T>) -> bool,
        T: ExpressionWrapper,
    {
        if predicate(tree) {
            return Some(tree);
        }
        match tree {
            ExpressionTree::Negated(e) => search_for(e.inner(), predicate),
            ExpressionTree::Repeated {
                count,
                value,
                ranker: _,
            } => search_for(count.inner(), predicate)
                .or_else(|| search_for(value.inner(), predicate)),
            ExpressionTree::Product(a, b) => {
                search_for(a.inner(), predicate).or_else(|| search_for(b.inner(), predicate))
            }
            ExpressionTree::Floor(a, b) => {
                search_for(a.inner(), predicate).or_else(|| search_for(b.inner(), predicate))
            }
            ExpressionTree::Comparison { a, b, op: _ } => {
                search_for(a.inner(), predicate).or_else(|| search_for(b.inner(), predicate))
            }
            ExpressionTree::Sum(items) => items
                .iter()
                .find_map(|item| search_for(item.inner(), predicate)),
            ExpressionTree::Binding {
                symbol: _,
                value,
                tail,
            } => search_for(value.inner(), predicate)
                .or_else(|| search_for(tail.inner(), predicate)),
            ExpressionTree::Modifier(_) | ExpressionTree::Die(_) | ExpressionTree::Symbol(_) => {
                None
            }
        }
    }

    /// Generates expressions whose leaves are drawn only from `symbols`, or only
    /// from dice and constants when `symbols` is empty. Each expression is
    /// paired with a copy of `symbols`.
    ///
    /// The recursive cases used here do not construct `Binding` nodes.
    fn expression_closed_under(
        symbols: HashSet<Symbol>,
    ) -> impl Strategy<Value = (RawExpression, HashSet<Symbol>)> {
        let symbols_final = symbols.clone();

        let static_leaf = Union::new([
            any::<Die>().prop_map(ExpressionTree::Die).boxed(),
            any::<Constant>().prop_map(ExpressionTree::Modifier).boxed(),
        ]);

        let leaf = if symbols.is_empty() {
            static_leaf.boxed()
        } else {
            (0..symbols.len())
                .prop_map(move |v| {
                    let s = symbols.iter().nth(v).unwrap();
                    ExpressionTree::Symbol(s.clone())
                })
                .boxed()
        };

        let leaf = leaf.prop_map(RawExpression::from);
        let closure = leaf.prop_recursive(2, 2, 2, |strat| {
            prop_oneof![
                properties::negated(&strat),
                properties::repeated(&strat),
                properties::product(&strat),
                properties::floor(&strat),
                properties::sum(&strat),
                properties::comparison(&strat),
            ]
            .prop_map(RawExpression::from)
        });
        closure.prop_map(move |v| (v, symbols_final.clone()))
    }

    proptest! {
        // Every symbol reported as unbound must actually occur in the expression.
        #[test]
        fn identify_open_symbols(
            (_symbols, (exp, _)) in
                proptest::collection::hash_set(properties::symbol(), 1..4)
                    .prop_flat_map(|symbols| (Just(symbols.clone()), expression_closed_under(symbols)))
        ) {
            let result: Result<Closed, _> = exp.clone().try_into();

            if let Err(Error::UnboundSymbols(got)) = result {
                for symbol in got {
                    assert!(search_for(exp.inner(), &mut |s| matches!(s, ExpressionTree::Symbol(sym) if sym == &symbol)).is_some());
                }
            }
        }
    }

    /// Generates expressions built from nested `Binding` nodes, with the
    /// nesting bounded by `prop_recursive`.
    ///
    /// NOTE(review): a binding's tail may reference symbols that are bound only
    /// inside its definition, so the output appears not to be guaranteed
    /// closed. `generate_valid_bindings` therefore only checks that any reported
    /// unbound symbol really is free.
    fn closed_expression() -> impl Strategy<Value = RawExpression> {
        let leaf = expression_closed_under(HashSet::new());
        let syms = leaf.prop_recursive(2, 2, 2, |strat| {
            (properties::symbol(), strat.clone()).prop_flat_map(
                |(symbol, (definition, mut symbols))| {
                    symbols.insert(symbol.clone());
                    expression_closed_under(symbols).prop_map(move |(tail, new_symbols)| {
                        (
                            RawExpression::from(ExpressionTree::Binding {
                                symbol: symbol.clone(),
                                value: Box::new(definition.clone()),
                                tail: Box::new(tail),
                            }),
                            new_symbols,
                        )
                    })
                },
            )
        });
        syms.prop_map(|(tree, _syms)| tree)
    }

    /// Returns the first free occurrence of `symbol` in `tree`, if any.
    ///
    /// A `Binding` of the same name shadows `symbol` in its tail, so only its
    /// value is searched in that case.
    fn unbound_tree<'a, W>(
        symbol: &Symbol,
        tree: &'a ExpressionTree<W>,
    ) -> Option<&'a ExpressionTree<W>>
    where
        W: ExpressionWrapper,
    {
        match tree {
            ExpressionTree::Binding {
                symbol: sym,
                value,
                tail,
            } => {
                let value = unbound_tree(symbol, value.inner());
                if sym == symbol {
                    value
                } else {
                    value.or_else(|| unbound_tree(symbol, tail.inner()))
                }
            }
            ExpressionTree::Modifier(_) => None,
            ExpressionTree::Die(_) => None,
            ExpressionTree::Symbol(sym) if sym == symbol => Some(tree),
            ExpressionTree::Symbol(_) => None,
            ExpressionTree::Negated(e) => unbound_tree(symbol, e.inner()),
            ExpressionTree::Repeated {
                count,
                value,
                ranker: _,
            } => {
                unbound_tree(symbol, count.inner()).or_else(|| unbound_tree(symbol, value.inner()))
            }
            ExpressionTree::Product(a, b) => {
                unbound_tree(symbol, a.inner()).or_else(|| unbound_tree(symbol, b.inner()))
            }
            ExpressionTree::Floor(a, b) => {
                unbound_tree(symbol, a.inner()).or_else(|| unbound_tree(symbol, b.inner()))
            }
            ExpressionTree::Comparison { a, b, op: _ } => {
                unbound_tree(symbol, a.inner()).or_else(|| unbound_tree(symbol, b.inner()))
            }
            ExpressionTree::Sum(items) => items
                .iter()
                .find_map(|v| unbound_tree(symbol, v.inner())),
        }
    }

    proptest! {
        // Any symbol reported as unbound must occur in the expression and must
        // have at least one occurrence not shadowed by a binding of that name.
        #[test]
        fn generate_valid_bindings(exp in closed_expression()) {
            let exp = exp.simplify();
            let result: Result<Closed, _> = exp.clone().try_into();
            if let Err(Error::UnboundSymbols(got)) = result {
                for symbol in got {
                    assert!(search_for(exp.inner(), &mut |s| matches!(s, ExpressionTree::Symbol(sym) if sym == &symbol)).is_some());

                    assert!(unbound_tree(&symbol, exp.inner()).is_some());
                }
            }
        }
    }
}