minicas_core/ast/
mod.rs

1//! AST types for representing a math formula.
2use crate::ty::{Ty, TyValue};
3use crate::Path;
4use std::fmt;
5use std::iter::once;
6use std::ops::{Deref, DerefMut};
7
8mod node_const;
9pub use node_const::Const;
10mod node_binary;
11pub use node_binary::{Binary, BinaryOp, CmpOp};
12mod node_unary;
13pub use node_unary::{Unary, UnaryOp};
14mod node_variable;
15pub use node_variable::Var;
16mod node_piecewise;
17pub use node_piecewise::Piecewise;
18
19mod parse;
20
21mod ac_collect;
22pub use ac_collect::{ac_collect, AcError};
23mod fold;
24pub use fold::fold;
25mod typecheck;
26pub use typecheck::{typecheck, TypeError};
27mod rearrange;
28pub use rearrange::make_subject;
29
/// Errors that can occur when evaluating an AST.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EvalError {
    // NOTE(review): apart from `UnknownIdent`, none of these variants is raised
    // by `NodeInner`'s own evaluation code; presumably the individual node types
    // (unary, binary, piecewise) produce them -- confirm and document each.
    DivByZero,
    NonInteger,
    UnexpectedType(Vec<Ty>),
    /// A variable could not be resolved through the evaluation context.
    /// Carries the identifier that was looked up.
    UnknownIdent(String),
    Multiple,
    UnboundedInterval,
    IndeterminatePredicate,
}
41
/// Context that needs to be provided for evaluation.
///
/// Supplies the values of the variables referenced by the AST being evaluated.
pub trait EvalContext {
    /// Looks up the value bound to the identifier `id`.
    ///
    /// Returns `None` when the context has no binding for `id`.
    fn resolve_var(&self, id: &str) -> Option<&TyValue>;
}
46
/// The unit type acts as a context without any variable bindings.
impl EvalContext for () {
    /// Always returns `None`: no identifier can be resolved.
    fn resolve_var(&self, _id: &str) -> Option<&TyValue> {
        None
    }
}
52
53impl<S: AsRef<str>> EvalContext for Vec<(S, TyValue)> {
54    fn resolve_var(&self, id: &str) -> Option<&TyValue> {
55        for (ident, val) in self {
56            if ident.as_ref() == id {
57                return Some(val);
58            }
59        }
60        None
61    }
62}
63
/// [EvalContext] variant for interval arithmetic.
///
/// Each variable resolves to a pair of values (the interval's two bounds)
/// instead of a single value.
pub trait EvalContextInterval {
    /// Looks up the pair of bounds bound to the identifier `id`.
    ///
    /// Returns `None` when the context has no binding for `id`.
    fn resolve_var(&self, id: &str) -> Option<(&TyValue, &TyValue)>;
}
68
/// The unit type acts as an interval context without any variable bindings.
impl EvalContextInterval for () {
    /// Always returns `None`: no identifier can be resolved.
    fn resolve_var(&self, _id: &str) -> Option<(&TyValue, &TyValue)> {
        None
    }
}
74
75impl<S: AsRef<str>> EvalContextInterval for Vec<(S, (TyValue, TyValue))> {
76    fn resolve_var(&self, id: &str) -> Option<(&TyValue, &TyValue)> {
77        for (ident, val) in self {
78            if ident.as_ref() == id {
79                return Some((&val.0, &val.1));
80            }
81        }
82        None
83    }
84}
85
/// Operations shared by every representation of an AST node.
///
/// Implemented by [NodeInner] as well as by the [Node] and [HN] wrappers.
pub trait AstNode: Clone + Sized + std::fmt::Debug {
    /// Returns the type of the value this node yields.
    fn returns(&self) -> Option<Ty>;
    /// Returns the types of the operands of this node.
    fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>>;

    /// Attempts to evaluate the AST to a single finite value.
    ///
    /// Variables are looked up in `ctx`.
    fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError>;
    /// Evaluates all possible values of the AST.
    ///
    /// Variables are looked up in `ctx`. The outer `Result` reports a failure
    /// to start the evaluation; each yielded item carries its own `Result`.
    fn eval<C: EvalContext>(
        &self,
        ctx: &C,
    ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError>;
    /// Interval variant of [AstNode::eval].
    ///
    /// Every yielded item is a pair of values describing one interval.
    fn eval_interval<C: EvalContextInterval>(
        &self,
        ctx: &C,
    ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>;

    /// Recursively executes the given function on every node in the AST.
    ///
    /// When `depth_first` is false the callback runs on a node before its
    /// children, and a `false` result skips that node's children (as
    /// implemented for [NodeInner]; other branches of the tree are still
    /// visited). When `depth_first` is true the callback runs after the
    /// children and its result does not affect the walk.
    fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool);
    /// Mutable variant of [AstNode::walk]: the callback receives each node by
    /// mutable reference.
    ///
    /// The ordering and early-exit rules are the same as for [AstNode::walk].
    fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool);

    /// Returns the concrete variant this AST node represents.
    fn as_inner(&self) -> &NodeInner;
    /// Iterates through the child nodes of this node.
    fn iter_children(&self) -> impl Iterator<Item = &NodeInner>;

    /// Returns the nested child described by the given sequence.
    ///
    /// Each entry describes where to branch when recursively fetching
    /// the referenced node. A value of 0 means the left-hand side or unary
    /// operand, and a value of 1 means the right-hand side operand.
    ///
    /// For a piecewise node, `2 * k` selects the value expression of branch
    /// `k`, `2 * k + 1` selects its condition, and twice the number of
    /// branches selects the otherwise-expression.
    ///
    /// An empty sequence yields this node itself; `None` is returned when the
    /// sequence does not describe an existing node.
    fn get<I: Iterator<Item = usize>>(&self, i: I) -> Option<&NodeInner>;
    /// Returns the mutable [NodeInner] described by the given path sequence.
    ///
    /// Each entry describes where to branch when recursively fetching
    /// the referenced node. A value of 0 means the left-hand side or unary
    /// operand, and a value of 1 means the right-hand side operand.
    ///
    /// Piecewise nodes are indexed the same way as in [AstNode::get].
    fn get_mut<I: Iterator<Item = usize>>(&mut self, i: I) -> Option<&mut NodeInner>;

    /// Returns whether the operator is left-associative, and its precedence.
    ///
    /// None is returned if precedence is unambiguous.
    fn parsing_precedence(&self) -> Option<(bool, usize)>;
}
137
/// High-level type representing any node.
///
/// Wraps a [NodeInner].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Node {
    /// The concrete node variant being wrapped.
    n: NodeInner,
}
145
146impl Node {
147    pub fn new(n: NodeInner) -> Self {
148        Self { n }
149    }
150}
151
152impl AstNode for Node {
153    fn returns(&self) -> Option<Ty> {
154        self.n.returns()
155    }
156    fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>> {
157        self.n.descendant_types()
158    }
159    fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError> {
160        self.n.finite_eval(ctx)
161    }
162    fn eval<C: EvalContext>(
163        &self,
164        ctx: &C,
165    ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError> {
166        self.n.eval(ctx)
167    }
168    fn eval_interval<C: EvalContextInterval>(
169        &self,
170        ctx: &C,
171    ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>
172    {
173        self.n.eval_interval(ctx)
174    }
175    fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool) {
176        self.n.walk(depth_first, cb)
177    }
178    fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool) {
179        self.n.walk_mut(depth_first, cb)
180    }
181    fn as_inner(&self) -> &NodeInner {
182        self.n.as_inner()
183    }
184    fn iter_children(&self) -> impl Iterator<Item = &NodeInner> {
185        self.n.iter_children()
186    }
187    fn get<I: Iterator<Item = usize>>(&self, mut i: I) -> Option<&NodeInner> {
188        self.n.get(&mut i)
189    }
190    fn get_mut<I: Iterator<Item = usize>>(&mut self, mut i: I) -> Option<&mut NodeInner> {
191        self.n.get_mut(&mut i)
192    }
193    fn parsing_precedence(&self) -> Option<(bool, usize)> {
194        self.n.parsing_precedence()
195    }
196}
197
198impl Deref for Node {
199    type Target = NodeInner;
200
201    fn deref(&self) -> &NodeInner {
202        &self.n
203    }
204}
205
206impl fmt::Display for Node {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        fmt::Display::fmt(&self.n, f)
209    }
210}
211
212impl From<NodeInner> for Node {
213    fn from(n: NodeInner) -> Self {
214        Self { n }
215    }
216}
217
218impl<'a> TryFrom<parse::ParseNode<'a>> for Node {
219    type Error = String;
220
221    fn try_from(n: parse::ParseNode<'a>) -> Result<Self, Self::Error> {
222        use parse::ParseNode;
223        match n {
224            ParseNode::Bool(b) => Ok(NodeInner::Const(b.into()).into()),
225            ParseNode::Int(i) => Ok(NodeInner::Const(i.into()).into()),
226            ParseNode::Float(f) => {
227                Ok(NodeInner::Const((num::rational::Ratio::from_float(f).unwrap()).into()).into())
228            }
229
230            ParseNode::Ident(i) => Ok(NodeInner::Var(Var::new_untyped(i)).into()),
231            // TODO: Should probably somehow represent that this is a coefficient, and hence
232            // can only be a numeric quantity (vs say a boolean)
233            ParseNode::IdentWithCoefficient(co_eff, i) => Ok(NodeInner::Binary(Binary::mul(
234                NodeInner::Const(co_eff.into()),
235                NodeInner::Var(Var::new_untyped(i)),
236            ))
237            .into()),
238
239            ParseNode::Abs(operand) => {
240                let i = Node::try_from(*operand)?;
241                Ok(NodeInner::Unary(Unary::abs(i)).into())
242            }
243            ParseNode::Unary { op, operand } => {
244                let i = Node::try_from(*operand)?;
245                match op {
246                    "-" => Ok(NodeInner::Unary(Unary::negate(i)).into()),
247                    _ => Err(format!("unknown unary op {}", op)),
248                }
249            }
250
251            ParseNode::Root(operand, base) => {
252                let o = Node::try_from(*operand)?;
253                let b = Node::try_from(*base)?;
254                Ok(NodeInner::Binary(Binary::root(o, b)).into())
255            }
256            ParseNode::Pow(l, r) => {
257                let l = Node::try_from(*l)?;
258                let r = Node::try_from(*r)?;
259                Ok(NodeInner::Binary(Binary::pow(l, r)).into())
260            }
261            ParseNode::Min(l, r) => {
262                let l = Node::try_from(*l)?;
263                let r = Node::try_from(*r)?;
264                Ok(NodeInner::Binary(Binary::min(l, r)).into())
265            }
266            ParseNode::Max(l, r) => {
267                let l = Node::try_from(*l)?;
268                let r = Node::try_from(*r)?;
269                Ok(NodeInner::Binary(Binary::max(l, r)).into())
270            }
271            ParseNode::Binary { op, lhs, rhs } => {
272                let (l, r) = (Node::try_from(*lhs)?, Node::try_from(*rhs)?);
273                match op {
274                    "-" => Ok(NodeInner::Binary(Binary::sub(l, r)).into()),
275                    "+" => Ok(NodeInner::Binary(Binary::add(l, r)).into()),
276                    "±" => Ok(NodeInner::Binary(Binary::plus_or_minus(l, r)).into()),
277                    "*" => Ok(NodeInner::Binary(Binary::mul(l, r)).into()),
278                    "/" => Ok(NodeInner::Binary(Binary::div(l, r)).into()),
279                    "==" => Ok(NodeInner::Binary(Binary::equals(l, r)).into()),
280                    "<" => Ok(NodeInner::Binary(Binary::lt(l, r)).into()),
281                    "<=" => Ok(NodeInner::Binary(Binary::lte(l, r)).into()),
282                    ">" => Ok(NodeInner::Binary(Binary::gt(l, r)).into()),
283                    ">=" => Ok(NodeInner::Binary(Binary::gte(l, r)).into()),
284                    _ => Err(format!("unknown binary op {}", op)),
285                }
286            }
287            ParseNode::Piecewise { arms, otherwise } => {
288                let otherwise = Node::try_from(*otherwise)?;
289                Ok(NodeInner::from(Piecewise::new(
290                    arms.into_iter()
291                        .map(|(e, c)| Ok((Node::try_from(*e)?.into(), Node::try_from(*c)?.into())))
292                        .collect::<Result<Vec<_>, String>>()?,
293                    otherwise.into(),
294                ))
295                .into())
296            }
297        }
298    }
299}
300
301impl<'a> TryFrom<&'a str> for Node {
302    type Error = String;
303
304    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
305        match parse::parse(s) {
306            Ok((_, pn)) => Node::try_from(pn),
307            Err(e) => Err(format!("parse err: {}", e)),
308        }
309    }
310}
311
/// A [Node] on the heap.
///
/// Thin wrapper around a `Box<Node>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HN(Box<Node>);
315
316impl HN {
317    pub fn new(n: Node) -> HN {
318        HN(Box::new(n))
319    }
320
321    pub fn make(n: NodeInner) -> HN {
322        Self::new(Node { n })
323    }
324
325    /// Move out of the heap.
326    /// Intended for chaining transformations not covered by `map`.
327    pub fn and_then<F>(self, f: F) -> Node
328    where
329        F: FnOnce(Node) -> Node,
330    {
331        f(*self.0)
332    }
333
334    /// Produce a new `HN` from `self` without reallocating.
335    pub fn map<F>(mut self, f: F) -> Self
336    where
337        F: FnOnce(Node) -> Node,
338    {
339        let x = f(*self.0);
340        *self.0 = x;
341
342        self
343    }
344
345    /// Replaces the inner node.
346    pub fn replace_with(&mut self, n: Node) {
347        *self.0 = n;
348    }
349    pub fn swap(&mut self, n: Node) -> Node {
350        std::mem::replace(&mut self.0, n)
351    }
352}
353
354impl AstNode for HN {
355    fn returns(&self) -> Option<Ty> {
356        self.0.returns()
357    }
358    fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>> {
359        self.0.descendant_types()
360    }
361    fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError> {
362        self.0.finite_eval(ctx)
363    }
364    fn eval<C: EvalContext>(
365        &self,
366        ctx: &C,
367    ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError> {
368        self.0.eval(ctx)
369    }
370    fn eval_interval<C: EvalContextInterval>(
371        &self,
372        ctx: &C,
373    ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>
374    {
375        self.0.eval_interval(ctx)
376    }
377    fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool) {
378        self.0.walk(depth_first, cb)
379    }
380    fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool) {
381        self.0.walk_mut(depth_first, cb)
382    }
383    fn as_inner(&self) -> &NodeInner {
384        self.0.as_inner()
385    }
386    fn iter_children(&self) -> impl Iterator<Item = &NodeInner> {
387        self.0.iter_children()
388    }
389    fn get<I: Iterator<Item = usize>>(&self, i: I) -> Option<&NodeInner> {
390        self.0.get(i)
391    }
392    fn get_mut<I: Iterator<Item = usize>>(&mut self, i: I) -> Option<&mut NodeInner> {
393        self.0.get_mut(i)
394    }
395    fn parsing_precedence(&self) -> Option<(bool, usize)> {
396        self.0.parsing_precedence()
397    }
398}
399
400impl Deref for HN {
401    type Target = Node;
402
403    fn deref(&self) -> &Node {
404        &self.0
405    }
406}
407
408impl DerefMut for HN {
409    fn deref_mut(&mut self) -> &mut Node {
410        &mut self.0
411    }
412}
413
414impl From<Node> for HN {
415    fn from(n: Node) -> Self {
416        Self(Box::new(n))
417    }
418}
419
420impl From<NodeInner> for HN {
421    fn from(n: NodeInner) -> Self {
422        Self(Box::new(Node { n }))
423    }
424}
425
426impl From<TyValue> for HN {
427    fn from(v: TyValue) -> Self {
428        let n = Const::new(v).into();
429        Self(Box::new(Node { n }))
430    }
431}
432
433impl fmt::Display for HN {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        fmt::Display::fmt(&self.0, f)
436    }
437}
438
/// Concrete varieties of a node which together compose an AST.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NodeInner {
    /// Some constant value, like a coefficient or offset.
    Const(Const),
    /// An operation which takes a single operand.
    Unary(Unary),
    /// An operation which takes two operands.
    Binary(Binary),
    /// Some unknown value.
    ///
    /// Its value is looked up by identifier in the evaluation context when
    /// the AST is evaluated.
    Var(Var),
    /// A piecewise function.
    ///
    /// Made of a list of branches plus an expression used otherwise.
    Piecewise(Piecewise),
}
453
454impl NodeInner {
455    /// Creates a new constant node with the given value.
456    pub fn new_const<V: Into<TyValue>>(v: V) -> Self {
457        Self::Const(Const::new(v.into()))
458    }
459    /// Creates a new variable node with the given identifier.
460    pub fn new_var<S: Into<String>>(ident: S) -> Self {
461        Self::Var(Var::new_untyped(ident))
462    }
463
464    /// Returns a ref to the inner [Const] if this node is that variant.
465    pub fn as_const(&self) -> Option<&Const> {
466        match self {
467            Self::Const(c) => Some(c),
468            _ => None,
469        }
470    }
471    /// Returns a ref to the inner [Unary] if this node is that variant.
472    pub fn as_unary(&self) -> Option<&Unary> {
473        match self {
474            Self::Unary(c) => Some(c),
475            _ => None,
476        }
477    }
478    /// Returns a ref to the inner [Binary] if this node is that variant.
479    pub fn as_binary(&self) -> Option<&Binary> {
480        match self {
481            Self::Binary(b) => Some(b),
482            _ => None,
483        }
484    }
485    /// Returns a ref to the inner [Var] if this node is that variant.
486    pub fn as_var(&self) -> Option<&Var> {
487        match self {
488            Self::Var(b) => Some(b),
489            _ => None,
490        }
491    }
492
493    /// Returns the nested child described by the given sequence.
494    ///
495    /// Each entry describes where to branch when recursively fetching
496    /// the referenced node. A value of 0 means the left-hand side or unary
497    /// operand, and a value of 1 means the right-hand side operand.
498    fn get<I: Iterator<Item = usize>>(&self, i: &mut I) -> Option<&NodeInner> {
499        match i.next() {
500            Some(idx) => match (self, idx) {
501                (Self::Unary(u), 0) => (*u.operand()).n.get(i),
502                (Self::Binary(b), 0) => (*b.lhs()).n.get(i),
503                (Self::Binary(b), 1) => (*b.rhs()).n.get(i),
504                (Self::Piecewise(p), idx) => {
505                    let num_branches = p.r#if.len();
506                    if idx == num_branches * 2 {
507                        p.r#else.n.get(i)
508                    } else if idx < num_branches * 2 {
509                        if idx % 2 == 1 {
510                            p.r#if[idx / 2].1.n.get(i)
511                        } else {
512                            p.r#if[idx / 2].0.n.get(i)
513                        }
514                    } else {
515                        None
516                    }
517                }
518                _ => None,
519            },
520            None => Some(self),
521        }
522    }
523
524    /// Returns the mutable [NodeInner] described by the given path sequence.
525    ///
526    /// Each entry describes where to branch when recursively fetching
527    /// the referenced node. A value of 0 means the left-hand side or unary
528    /// operand, and a value of 1 means the right-hand side operand.
529    fn get_mut<I: Iterator<Item = usize>>(&mut self, i: &mut I) -> Option<&mut NodeInner> {
530        match i.next() {
531            Some(idx) => match (self, idx) {
532                (Self::Unary(u), 0) => (*u.operand_mut()).n.get_mut(i),
533                (Self::Binary(b), 0) => (*b.lhs_mut()).n.get_mut(i),
534                (Self::Binary(b), 1) => (*b.rhs_mut()).n.get_mut(i),
535                (Self::Piecewise(p), idx) => {
536                    let num_branches = p.r#if.len();
537                    if idx == num_branches * 2 {
538                        p.r#else.n.get_mut(i)
539                    } else if idx < num_branches * 2 {
540                        if idx % 2 == 1 {
541                            p.r#if[idx / 2].1.n.get_mut(i)
542                        } else {
543                            p.r#if[idx / 2].0.n.get_mut(i)
544                        }
545                    } else {
546                        None
547                    }
548                }
549                _ => None,
550            },
551            None => Some(self),
552        }
553    }
554
555    /// Returns a neatly-formatted string representation of the AST.
556    pub fn pretty_str(&self, parent_precedence: Option<usize>) -> String {
557        match self {
558            Self::Const(c) => format!("{}", c),
559            Self::Unary(u) => u.pretty_str(parent_precedence),
560            Self::Binary(b) => b.pretty_str(parent_precedence),
561            Self::Var(v) => format!("{}", v),
562            Self::Piecewise(p) => format!("{}", p),
563        }
564    }
565}
566
567impl AstNode for NodeInner {
568    fn returns(&self) -> Option<Ty> {
569        match self {
570            Self::Const(c) => Some(c.returns()),
571            Self::Unary(u) => u.returns(),
572            Self::Binary(b) => b.returns(),
573            Self::Var(v) => v.returns(),
574            Self::Piecewise(p) => p.returns(),
575        }
576    }
577
578    fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>> {
579        match self {
580            Self::Const(_) | Self::Var(_) | Self::Piecewise(_) => {
581                [None, None].into_iter().flatten()
582            }
583            Self::Unary(u) => [Some(u.operand().returns()), None].into_iter().flatten(),
584            Self::Binary(b) => [Some(b.lhs().returns()), Some(b.lhs().returns())]
585                .into_iter()
586                .flatten(),
587        }
588    }
589    fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError> {
590        match self {
591            Self::Const(c) => Ok(c.value().clone()),
592            Self::Unary(u) => u.finite_eval(ctx),
593            Self::Binary(b) => b.finite_eval(ctx),
594            Self::Piecewise(p) => p.finite_eval(ctx),
595            Self::Var(v) => match ctx.resolve_var(v.ident()) {
596                Some(v) => Ok(v.clone()),
597                None => Err(EvalError::UnknownIdent(v.ident().to_string())),
598            },
599        }
600    }
601    fn eval<C: EvalContext>(
602        &self,
603        ctx: &C,
604    ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError> {
605        match self {
606            Self::Const(c) => Ok(Box::new(once(Ok(c.value().clone())))),
607            Self::Unary(u) => u.eval(ctx),
608            Self::Binary(b) => b.eval(ctx),
609            Self::Piecewise(p) => p.eval(ctx),
610            Self::Var(v) => match ctx.resolve_var(v.ident()) {
611                Some(v) => Ok(Box::new(once(Ok(v.clone())))),
612                None => Err(EvalError::UnknownIdent(v.ident().to_string())),
613            },
614        }
615    }
616    fn eval_interval<C: EvalContextInterval>(
617        &self,
618        ctx: &C,
619    ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>
620    {
621        match self {
622            Self::Const(c) => Ok(Box::new(once(Ok((c.value().clone(), c.value().clone()))))),
623            Self::Unary(u) => u.eval_interval(ctx),
624            Self::Binary(b) => b.eval_interval(ctx),
625            Self::Piecewise(p) => p.eval_interval(ctx),
626            Self::Var(v) => match ctx.resolve_var(v.ident()) {
627                Some((v_min, v_max)) => Ok(Box::new(once(Ok((v_min.clone(), v_max.clone()))))),
628                None => Err(EvalError::UnknownIdent(v.ident().to_string())),
629            },
630        }
631    }
632
633    fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool) {
634        if !depth_first {
635            if !cb(self) {
636                return;
637            }
638        }
639
640        // recurse to sub-expressions
641        match self {
642            Self::Unary(u) => {
643                u.operand().walk(depth_first, cb);
644            }
645            Self::Binary(b) => {
646                b.lhs().walk(depth_first, cb);
647                b.rhs().walk(depth_first, cb);
648            }
649            Self::Piecewise(p) => {
650                for (e, p) in p.iter_branches() {
651                    e.walk(depth_first, cb);
652                    p.walk(depth_first, cb)
653                }
654                p.else_branch().walk(depth_first, cb);
655            }
656
657            // nothing contained to walk
658            Self::Const(_) | Self::Var(_) => {}
659        }
660
661        if depth_first {
662            if !cb(self) {
663                return;
664            }
665        }
666    }
667    fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool) {
668        if !depth_first {
669            if !cb(self) {
670                return;
671            }
672        }
673
674        // recurse to sub-expressions
675        match self {
676            Self::Unary(u) => {
677                u.operand_mut().walk_mut(depth_first, cb);
678            }
679            Self::Binary(b) => {
680                b.lhs_mut().walk_mut(depth_first, cb);
681                b.rhs_mut().walk_mut(depth_first, cb);
682            }
683            Self::Piecewise(p) => {
684                for (e, p) in p.iter_branches_mut() {
685                    e.walk_mut(depth_first, cb);
686                    p.walk_mut(depth_first, cb)
687                }
688                p.else_branch_mut().walk_mut(depth_first, cb);
689            }
690
691            // nothing contained to walk
692            Self::Const(_) | Self::Var(_) => {}
693        }
694
695        if depth_first {
696            if !cb(self) {
697                return;
698            }
699        }
700    }
701
702    fn as_inner(&self) -> &NodeInner {
703        self
704    }
705
706    fn iter_children(&self) -> impl Iterator<Item = &NodeInner> {
707        match self {
708            Self::Const(_) | Self::Var(_) | Self::Piecewise(_) => {
709                [None, None].into_iter().flatten()
710            }
711            Self::Unary(u) => [Some(&u.operand().0.n), None].into_iter().flatten(),
712            Self::Binary(b) => [Some(&b.lhs().0.n), Some(&b.rhs().0.n)]
713                .into_iter()
714                .flatten(),
715        }
716    }
717
718    fn get<I: Iterator<Item = usize>>(&self, mut i: I) -> Option<&NodeInner> {
719        NodeInner::get(self, &mut i)
720    }
721    fn get_mut<I: Iterator<Item = usize>>(&mut self, mut i: I) -> Option<&mut NodeInner> {
722        NodeInner::get_mut(self, &mut i)
723    }
724    fn parsing_precedence(&self) -> Option<(bool, usize)> {
725        match self {
726            Self::Const(_) | Self::Var(_) | Self::Piecewise(_) => None,
727            Self::Unary(u) => u.op.parsing_precedence(),
728            Self::Binary(b) => b.op.parsing_precedence(),
729        }
730    }
731}
732
733impl From<Const> for NodeInner {
734    fn from(n: Const) -> Self {
735        Self::Const(n)
736    }
737}
738impl From<Binary> for NodeInner {
739    fn from(n: Binary) -> Self {
740        Self::Binary(n)
741    }
742}
743impl From<Unary> for NodeInner {
744    fn from(n: Unary) -> Self {
745        Self::Unary(n)
746    }
747}
748impl From<Var> for NodeInner {
749    fn from(n: Var) -> Self {
750        Self::Var(n)
751    }
752}
753impl From<Piecewise> for NodeInner {
754    fn from(p: Piecewise) -> Self {
755        Self::Piecewise(p)
756    }
757}
758impl From<Node> for NodeInner {
759    fn from(n: Node) -> Self {
760        Self::from(n.n)
761    }
762}
763
764impl fmt::Display for NodeInner {
765    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766        f.write_str(&self.pretty_str(None))
767    }
768}
769
770#[cfg(test)]
771mod tests {
772    use super::*;
773
774    #[test]
775    fn parse_basic() {
776        assert_eq!(
777            Node::try_from("3 + 5"),
778            Ok(Node::new(
779                Binary::add::<TyValue, TyValue>(3.into(), 5.into()).into()
780            )),
781        );
782        assert_eq!(
783            Node::try_from("-5"),
784            Ok(Node::new(Unary::negate::<TyValue>(5.into()).into())),
785        );
786        assert_eq!(
787            Node::try_from("3--5"),
788            Ok(Node::new(
789                Binary::sub::<TyValue, HN>(
790                    3.into(),
791                    Node::new(Unary::negate::<TyValue>(5.into()).into()).into(),
792                )
793                .into()
794            )),
795        );
796        assert_eq!(
797            Node::try_from("3==5"),
798            Ok(Node::new(
799                Binary::equals::<TyValue, TyValue>(3.into(), 5.into()).into()
800            )),
801        );
802        assert_eq!(
803            Node::try_from("3 > 5"),
804            Ok(Node::new(
805                Binary::gt::<TyValue, TyValue>(3.into(), 5.into()).into()
806            )),
807        );
808        assert_eq!(
809            Node::try_from("5x"),
810            Ok(Node::new(
811                Binary::mul::<TyValue, HN>(
812                    5.into(),
813                    Node::new(Var::new_untyped("x").into()).into()
814                )
815                .into()
816            )),
817        );
818
819        assert_eq!(
820            Node::try_from("x ± 4 * y"),
821            Ok(Node::new(
822                Binary::plus_or_minus::<HN, HN>(
823                    Node::new(Var::new_untyped("x").into()).into(),
824                    Node::new(
825                        Binary::mul::<TyValue, HN>(
826                            4.into(),
827                            Node::new(Var::new_untyped("y").into()).into(),
828                        )
829                        .into()
830                    )
831                    .into(),
832                )
833                .into()
834            )),
835        );
836
837        assert_eq!(
838            Node::try_from("sqrt(4)"),
839            Ok(Node::new(
840                Binary::root::<TyValue, TyValue>(4.into(), 2.into(),).into()
841            )),
842        );
843        assert_eq!(
844            Node::try_from("root(8, 3)"),
845            Ok(Node::new(
846                Binary::root::<TyValue, TyValue>(8.into(), 3.into(),).into()
847            )),
848        );
849    }
850
851    #[test]
852    fn parse_piecewise() {
853        assert_eq!(
854            Node::try_from("{2x if x == 0; otherwise x}"),
855            Ok(Node::new(NodeInner::from(Piecewise::new(
856                vec![(
857                    Node::try_from("2x").unwrap().into(),
858                    Node::try_from("x == 0").unwrap().into(),
859                )],
860                Node::try_from("x").unwrap().into(),
861            )))),
862        );
863    }
864
865    #[test]
866    fn fmt_basic() {
867        assert_eq!(
868            "3 * (a + b)",
869            format!(
870                "{}",
871                Node::new(
872                    Binary::mul::<TyValue, HN>(
873                        3.into(),
874                        Node::new(
875                            Binary::add::<HN, HN>(
876                                Node::new(Var::new_untyped("a").into()).into(),
877                                Node::new(Var::new_untyped("b").into()).into(),
878                            )
879                            .into()
880                        )
881                        .into(),
882                    )
883                    .into()
884                )
885            )
886        );
887        assert_eq!(
888            "(a + b) * 3",
889            format!(
890                "{}",
891                Node::new(
892                    Binary::mul::<HN, TyValue>(
893                        Node::new(
894                            Binary::add::<HN, HN>(
895                                Node::new(Var::new_untyped("a").into()).into(),
896                                Node::new(Var::new_untyped("b").into()).into(),
897                            )
898                            .into()
899                        )
900                        .into(),
901                        3.into(),
902                    )
903                    .into()
904                )
905            )
906        );
907        assert_eq!(
908            "3 + a * b",
909            format!(
910                "{}",
911                Node::new(
912                    Binary::add::<TyValue, HN>(
913                        3.into(),
914                        Node::new(
915                            Binary::mul::<HN, HN>(
916                                Node::new(Var::new_untyped("a").into()).into(),
917                                Node::new(Var::new_untyped("b").into()).into(),
918                            )
919                            .into()
920                        )
921                        .into(),
922                    )
923                    .into()
924                )
925            )
926        );
927        assert_eq!(
928            "a * b + 3",
929            format!(
930                "{}",
931                Node::new(
932                    Binary::add::<HN, TyValue>(
933                        Node::new(
934                            Binary::mul::<HN, HN>(
935                                Node::new(Var::new_untyped("a").into()).into(),
936                                Node::new(Var::new_untyped("b").into()).into(),
937                            )
938                            .into()
939                        )
940                        .into(),
941                        3.into(),
942                    )
943                    .into()
944                )
945            )
946        );
947
948        assert_eq!(
949            "3 - x",
950            format!(
951                "{}",
952                Node::new(
953                    Binary::sub::<TyValue, HN>(
954                        3.into(),
955                        Node::new(Var::new_untyped("x").into()).into(),
956                    )
957                    .into()
958                )
959            )
960        );
961        assert_eq!(
962            "3 ± x",
963            format!(
964                "{}",
965                Node::new(
966                    Binary::plus_or_minus::<TyValue, HN>(
967                        3.into(),
968                        Node::new(Var::new_untyped("x").into()).into(),
969                    )
970                    .into()
971                )
972            )
973        );
974
975        assert_eq!(
976            "3 - (-5)",
977            format!(
978                "{}",
979                Node::new(
980                    Binary::sub::<TyValue, HN>(
981                        3.into(),
982                        Node::new(Unary::negate::<TyValue>(5.into()).into()).into(),
983                    )
984                    .into()
985                )
986            )
987        );
988
989        assert_eq!(
990            "3 - |5|",
991            format!(
992                "{}",
993                Node::new(
994                    Binary::sub::<TyValue, HN>(
995                        3.into(),
996                        Node::new(Unary::abs::<TyValue>(5.into()).into()).into(),
997                    )
998                    .into()
999                )
1000            )
1001        );
1002
1003        assert_eq!(
1004            "{2x if x == 0; otherwise x}",
1005            format!(
1006                "{}",
1007                Node::new(NodeInner::from(Piecewise::new(
1008                    vec![(
1009                        Node::try_from("2x").unwrap().into(),
1010                        Node::try_from("x == 0").unwrap().into(),
1011                    )],
1012                    Node::try_from("x").unwrap().into(),
1013                )))
1014            )
1015        );
1016    }
1017
1018    #[test]
1019    fn finite_eval_simple() {
1020        assert_eq!(
1021            Node::try_from("3.5 + 4.5").unwrap().finite_eval(&()),
1022            Ok(8.into()),
1023        );
1024        assert_eq!(
1025            Node::try_from("3 - 5").unwrap().finite_eval(&()),
1026            Ok((-2).into()),
1027        );
1028        assert_eq!(
1029            Node::try_from("9 - 3 * 2").unwrap().finite_eval(&()),
1030            Ok(3.into()),
1031        );
1032        assert_eq!(
1033            Node::try_from("root(8, 3) + sqrt(4)")
1034                .unwrap()
1035                .finite_eval(&()),
1036            Ok(4.into()),
1037        );
1038        assert_eq!(
1039            Node::try_from("min(2 + 1, max(4, 2))")
1040                .unwrap()
1041                .finite_eval(&()),
1042            Ok(3.into()),
1043        );
1044
1045        assert_eq!(
1046            Node::try_from("x").unwrap().finite_eval(&()),
1047            Err(EvalError::UnknownIdent("x".to_string()))
1048        );
1049        assert_eq!(
1050            Node::try_from("x")
1051                .unwrap()
1052                .finite_eval(&vec![("x", 69.into())]),
1053            Ok(69.into()),
1054        );
1055    }
1056
1057    #[test]
1058    fn finite_eval_piecewise() {
1059        assert_eq!(
1060            Node::try_from("{x if x > y; otherwise y}")
1061                .unwrap()
1062                .finite_eval(&vec![("x", 42.into()), ("y", 4.into())]),
1063            Ok(42.into()),
1064        );
1065        assert_eq!(
1066            Node::try_from("{x if x > y; otherwise y}")
1067                .unwrap()
1068                .finite_eval(&vec![("x", 1.into()), ("y", 4.into())]),
1069            Ok(4.into()),
1070        );
1071    }
1072
1073    #[test]
1074    fn eval_simple() {
1075        assert_eq!(
1076            Node::try_from("3.5 + 4.5")
1077                .unwrap()
1078                .eval(&())
1079                .unwrap()
1080                .collect::<Result<Vec<_>, _>>(),
1081            Ok(vec![8.into()]),
1082        );
1083        assert_eq!(
1084            Node::try_from("9 - 3 * 2")
1085                .unwrap()
1086                .eval(&())
1087                .unwrap()
1088                .collect::<Result<Vec<_>, _>>(),
1089            Ok(vec![3.into()]),
1090        );
1091
1092        assert_eq!(
1093            Node::try_from("5 ± 1")
1094                .unwrap()
1095                .eval(&())
1096                .unwrap()
1097                .collect::<Result<Vec<_>, _>>(),
1098            Ok(vec![6.into(), 4.into()]),
1099        );
1100        assert_eq!(
1101            Node::try_from("2 * (5 ± 1)")
1102                .unwrap()
1103                .eval(&())
1104                .unwrap()
1105                .collect::<Result<Vec<_>, _>>(),
1106            Ok(vec![12.into(), 8.into()]),
1107        );
1108
1109        assert_eq!(
1110            Node::try_from("-{0±x if x > y; otherwise y}")
1111                .unwrap()
1112                .eval(&vec![("x", 2.into()), ("y", 1.into())])
1113                .unwrap()
1114                .collect::<Result<Vec<_>, _>>(),
1115            Ok(vec![(-2).into(), 2.into()]),
1116        );
1117    }
1118
1119    #[test]
1120    fn interval_eval() {
1121        assert_eq!(
1122            Node::try_from("x - 2y")
1123                .unwrap()
1124                .eval_interval(&vec![
1125                    ("x", (1.into(), 2.into())),
1126                    ("y", (5.into(), 6.into()))
1127                ])
1128                .unwrap()
1129                .collect::<Result<Vec<_>, _>>(),
1130            Ok(vec![((-11).into(), (-8).into())]),
1131        );
1132
1133        // Smoke test over a large range
1134        for x in -8..=8 {
1135            for y in -8..=8 {
1136                assert!(
1137                    Node::try_from("sqrt(pow(-2 - x, 2) + pow(3 - y, 2)) + abs(x+y)")
1138                        .unwrap()
1139                        .eval_interval(&vec![
1140                            ("x", (x.into(), (x + 2).into())),
1141                            ("y", (y.into(), (y + 2).into())),
1142                        ])
1143                        .unwrap()
1144                        .collect::<Result<Vec<_>, _>>()
1145                        .is_ok(),
1146                );
1147            }
1148        }
1149    }
1150
1151    #[test]
1152    fn get() {
1153        assert_eq!(
1154            Node::try_from("3 + 2*5").unwrap().get(vec![0].into_iter()),
1155            Some(&NodeInner::new_const(3)),
1156        );
1157        assert_eq!(
1158            Node::try_from("3 + 2*5")
1159                .unwrap()
1160                .get(vec![1, 0].into_iter()),
1161            Some(&NodeInner::new_const(2)),
1162        );
1163        assert_eq!(
1164            Node::try_from("3 + 2*5").unwrap().get(vec![1].into_iter()),
1165            Some(Node::try_from("2 * 5").unwrap().as_inner()),
1166        );
1167        assert_eq!(
1168            Node::try_from("3 + 2*5").unwrap().get(vec![].into_iter()),
1169            Some(Node::try_from("3 + 2 * 5").unwrap().as_inner()),
1170        );
1171
1172        assert_eq!(
1173            Node::try_from("3 + 2*5").unwrap().get(vec![99].into_iter()),
1174            None,
1175        );
1176
1177        // Piecewise
1178        let p: Node = NodeInner::from(Piecewise::new(
1179            vec![(
1180                Node::try_from("x").unwrap().into(),
1181                Node::try_from("x == 0").unwrap().into(),
1182            )],
1183            Node::try_from("0").unwrap().into(),
1184        ))
1185        .into();
1186        assert_eq!(p.get(vec![99].into_iter()), None);
1187        assert_eq!(
1188            p.get(vec![0].into_iter()),
1189            Some(Node::try_from("x").unwrap().as_inner()),
1190        );
1191        assert_eq!(
1192            p.get(vec![1].into_iter()),
1193            Some(Node::try_from("x == 0").unwrap().as_inner()),
1194        );
1195        assert_eq!(
1196            p.get(vec![2].into_iter()),
1197            Some(Node::try_from("0").unwrap().as_inner()),
1198        );
1199    }
1200}