1use crate::ty::{Ty, TyValue};
3use crate::Path;
4use std::fmt;
5use std::iter::once;
6use std::ops::{Deref, DerefMut};
7
8mod node_const;
9pub use node_const::Const;
10mod node_binary;
11pub use node_binary::{Binary, BinaryOp, CmpOp};
12mod node_unary;
13pub use node_unary::{Unary, UnaryOp};
14mod node_variable;
15pub use node_variable::Var;
16mod node_piecewise;
17pub use node_piecewise::Piecewise;
18
19mod parse;
20
21mod ac_collect;
22pub use ac_collect::{ac_collect, AcError};
23mod fold;
24pub use fold::fold;
25mod typecheck;
26pub use typecheck::{typecheck, TypeError};
27mod rearrange;
28pub use rearrange::make_subject;
29
/// Errors produced while evaluating an expression tree
/// ([`AstNode::finite_eval`], [`AstNode::eval`], [`AstNode::eval_interval`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EvalError {
    /// Division by zero. NOTE(review): raised by operator code outside this file — confirm.
    DivByZero,
    /// Presumably a non-integer value where an integer was required — confirm
    /// against the operator implementations (not raised in this file).
    NonInteger,
    /// Presumably the operand types an operation did not accept — confirm
    /// (not raised in this file).
    UnexpectedType(Vec<Ty>),
    /// A variable the evaluation context could not resolve; carries the identifier.
    UnknownIdent(String),
    /// NOTE(review): not raised in this file; presumably "more than one value where
    /// exactly one was expected" — confirm.
    Multiple,
    /// NOTE(review): not raised in this file; presumably an interval with an
    /// infinite bound — confirm.
    UnboundedInterval,
    /// NOTE(review): not raised in this file; presumably a condition whose truth
    /// value cannot be decided (e.g. over an interval) — confirm.
    IndeterminatePredicate,
}
41
/// Source of variable bindings for point evaluation
/// ([`AstNode::finite_eval`] and [`AstNode::eval`]).
pub trait EvalContext {
    /// Looks up the value bound to identifier `id`; `None` means unbound, which
    /// evaluation reports as [`EvalError::UnknownIdent`].
    fn resolve_var(&self, id: &str) -> Option<&TyValue>;
}
46
47impl EvalContext for () {
48 fn resolve_var(&self, _id: &str) -> Option<&TyValue> {
49 None
50 }
51}
52
53impl<S: AsRef<str>> EvalContext for Vec<(S, TyValue)> {
54 fn resolve_var(&self, id: &str) -> Option<&TyValue> {
55 for (ident, val) in self {
56 if ident.as_ref() == id {
57 return Some(val);
58 }
59 }
60 None
61 }
62}
63
/// Source of variable bindings for interval evaluation ([`AstNode::eval_interval`]).
pub trait EvalContextInterval {
    /// Looks up the `(lower, upper)` bounds bound to identifier `id`; `None`
    /// means unbound, which evaluation reports as [`EvalError::UnknownIdent`].
    fn resolve_var(&self, id: &str) -> Option<(&TyValue, &TyValue)>;
}
68
69impl EvalContextInterval for () {
70 fn resolve_var(&self, _id: &str) -> Option<(&TyValue, &TyValue)> {
71 None
72 }
73}
74
75impl<S: AsRef<str>> EvalContextInterval for Vec<(S, (TyValue, TyValue))> {
76 fn resolve_var(&self, id: &str) -> Option<(&TyValue, &TyValue)> {
77 for (ident, val) in self {
78 if ident.as_ref() == id {
79 return Some((&val.0, &val.1));
80 }
81 }
82 None
83 }
84}
85
/// Common interface of the expression-tree types in this module
/// ([`Node`], [`HN`] and [`NodeInner`] all implement it).
pub trait AstNode: Clone + Sized + std::fmt::Debug {
    /// The type this node evaluates to, or `None` when it is not known.
    /// NOTE(review): `None` presumably means "untyped / not yet typechecked" — confirm.
    fn returns(&self) -> Option<Ty>;
    /// Return types of this node's immediate operands. In the `NodeInner`
    /// implementation, constants, variables and piecewise nodes yield nothing.
    fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>>;

    /// Evaluates the expression to a single value, resolving variables through
    /// `ctx`. An identifier missing from `ctx` yields [`EvalError::UnknownIdent`].
    fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError>;
    /// Evaluates the expression to every value it can take. Operators such as
    /// `±` contribute more than one result (e.g. `5 ± 1` yields `6`, then `4`).
    fn eval<C: EvalContext>(
        &self,
        ctx: &C,
    ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError>;
    /// Interval evaluation. `ctx` binds each variable to a `(lower, upper)` pair,
    /// and each yielded item is a `(lower, upper)` bound on the result. A constant
    /// `c` evaluates to `(c, c)`.
    fn eval_interval<C: EvalContextInterval>(
        &self,
        ctx: &C,
    ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>;

    /// Visits this node and its descendants with `cb`.
    ///
    /// In the `NodeInner` implementation:
    /// - `depth_first == false` calls `cb` on a node *before* its children, and a
    ///   `false` return skips that node's subtree.
    /// - `depth_first == true` calls `cb` *after* the children and ignores its
    ///   return value.
    fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool);
    /// Mutable counterpart of [`AstNode::walk`], with the same ordering rules.
    fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool);

    /// Borrows the underlying [`NodeInner`].
    fn as_inner(&self) -> &NodeInner;
    /// Iterates over the immediate children. In the `NodeInner` implementation,
    /// piecewise nodes yield no children here.
    fn iter_children(&self) -> impl Iterator<Item = &NodeInner>;

    /// Follows a path of child indices from this node. Returns the node reached,
    /// or `None` if an index is out of range; an empty path returns this node.
    ///
    /// Indexing:
    /// - unary: `0` is the operand;
    /// - binary: `0` is `lhs`, `1` is `rhs`;
    /// - piecewise with `k` branches: `2i` is branch `i`'s expression, `2i + 1`
    ///   its condition, and `2k` the `otherwise` branch.
    fn get<I: Iterator<Item = usize>>(&self, i: I) -> Option<&NodeInner>;
    /// Mutable counterpart of [`AstNode::get`], using the same indexing.
    fn get_mut<I: Iterator<Item = usize>>(&mut self, i: I) -> Option<&mut NodeInner>;

    /// Parsing precedence of this node's operator, or `None` for constants,
    /// variables and piecewise nodes.
    /// NOTE(review): the meaning of the `bool` is defined by the operator types
    /// (perhaps associativity) — confirm before relying on it.
    fn parsing_precedence(&self) -> Option<(bool, usize)>;
}
137
/// An expression-tree node: a thin wrapper around [`NodeInner`] that
/// implements [`AstNode`] by delegating to it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Node {
    n: NodeInner,
}
145
146impl Node {
147 pub fn new(n: NodeInner) -> Self {
148 Self { n }
149 }
150}
151
152impl AstNode for Node {
153 fn returns(&self) -> Option<Ty> {
154 self.n.returns()
155 }
156 fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>> {
157 self.n.descendant_types()
158 }
159 fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError> {
160 self.n.finite_eval(ctx)
161 }
162 fn eval<C: EvalContext>(
163 &self,
164 ctx: &C,
165 ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError> {
166 self.n.eval(ctx)
167 }
168 fn eval_interval<C: EvalContextInterval>(
169 &self,
170 ctx: &C,
171 ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>
172 {
173 self.n.eval_interval(ctx)
174 }
175 fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool) {
176 self.n.walk(depth_first, cb)
177 }
178 fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool) {
179 self.n.walk_mut(depth_first, cb)
180 }
181 fn as_inner(&self) -> &NodeInner {
182 self.n.as_inner()
183 }
184 fn iter_children(&self) -> impl Iterator<Item = &NodeInner> {
185 self.n.iter_children()
186 }
187 fn get<I: Iterator<Item = usize>>(&self, mut i: I) -> Option<&NodeInner> {
188 self.n.get(&mut i)
189 }
190 fn get_mut<I: Iterator<Item = usize>>(&mut self, mut i: I) -> Option<&mut NodeInner> {
191 self.n.get_mut(&mut i)
192 }
193 fn parsing_precedence(&self) -> Option<(bool, usize)> {
194 self.n.parsing_precedence()
195 }
196}
197
198impl Deref for Node {
199 type Target = NodeInner;
200
201 fn deref(&self) -> &NodeInner {
202 &self.n
203 }
204}
205
206impl fmt::Display for Node {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 fmt::Display::fmt(&self.n, f)
209 }
210}
211
212impl From<NodeInner> for Node {
213 fn from(n: NodeInner) -> Self {
214 Self { n }
215 }
216}
217
218impl<'a> TryFrom<parse::ParseNode<'a>> for Node {
219 type Error = String;
220
221 fn try_from(n: parse::ParseNode<'a>) -> Result<Self, Self::Error> {
222 use parse::ParseNode;
223 match n {
224 ParseNode::Bool(b) => Ok(NodeInner::Const(b.into()).into()),
225 ParseNode::Int(i) => Ok(NodeInner::Const(i.into()).into()),
226 ParseNode::Float(f) => {
227 Ok(NodeInner::Const((num::rational::Ratio::from_float(f).unwrap()).into()).into())
228 }
229
230 ParseNode::Ident(i) => Ok(NodeInner::Var(Var::new_untyped(i)).into()),
231 ParseNode::IdentWithCoefficient(co_eff, i) => Ok(NodeInner::Binary(Binary::mul(
234 NodeInner::Const(co_eff.into()),
235 NodeInner::Var(Var::new_untyped(i)),
236 ))
237 .into()),
238
239 ParseNode::Abs(operand) => {
240 let i = Node::try_from(*operand)?;
241 Ok(NodeInner::Unary(Unary::abs(i)).into())
242 }
243 ParseNode::Unary { op, operand } => {
244 let i = Node::try_from(*operand)?;
245 match op {
246 "-" => Ok(NodeInner::Unary(Unary::negate(i)).into()),
247 _ => Err(format!("unknown unary op {}", op)),
248 }
249 }
250
251 ParseNode::Root(operand, base) => {
252 let o = Node::try_from(*operand)?;
253 let b = Node::try_from(*base)?;
254 Ok(NodeInner::Binary(Binary::root(o, b)).into())
255 }
256 ParseNode::Pow(l, r) => {
257 let l = Node::try_from(*l)?;
258 let r = Node::try_from(*r)?;
259 Ok(NodeInner::Binary(Binary::pow(l, r)).into())
260 }
261 ParseNode::Min(l, r) => {
262 let l = Node::try_from(*l)?;
263 let r = Node::try_from(*r)?;
264 Ok(NodeInner::Binary(Binary::min(l, r)).into())
265 }
266 ParseNode::Max(l, r) => {
267 let l = Node::try_from(*l)?;
268 let r = Node::try_from(*r)?;
269 Ok(NodeInner::Binary(Binary::max(l, r)).into())
270 }
271 ParseNode::Binary { op, lhs, rhs } => {
272 let (l, r) = (Node::try_from(*lhs)?, Node::try_from(*rhs)?);
273 match op {
274 "-" => Ok(NodeInner::Binary(Binary::sub(l, r)).into()),
275 "+" => Ok(NodeInner::Binary(Binary::add(l, r)).into()),
276 "±" => Ok(NodeInner::Binary(Binary::plus_or_minus(l, r)).into()),
277 "*" => Ok(NodeInner::Binary(Binary::mul(l, r)).into()),
278 "/" => Ok(NodeInner::Binary(Binary::div(l, r)).into()),
279 "==" => Ok(NodeInner::Binary(Binary::equals(l, r)).into()),
280 "<" => Ok(NodeInner::Binary(Binary::lt(l, r)).into()),
281 "<=" => Ok(NodeInner::Binary(Binary::lte(l, r)).into()),
282 ">" => Ok(NodeInner::Binary(Binary::gt(l, r)).into()),
283 ">=" => Ok(NodeInner::Binary(Binary::gte(l, r)).into()),
284 _ => Err(format!("unknown binary op {}", op)),
285 }
286 }
287 ParseNode::Piecewise { arms, otherwise } => {
288 let otherwise = Node::try_from(*otherwise)?;
289 Ok(NodeInner::from(Piecewise::new(
290 arms.into_iter()
291 .map(|(e, c)| Ok((Node::try_from(*e)?.into(), Node::try_from(*c)?.into())))
292 .collect::<Result<Vec<_>, String>>()?,
293 otherwise.into(),
294 ))
295 .into())
296 }
297 }
298 }
299}
300
301impl<'a> TryFrom<&'a str> for Node {
302 type Error = String;
303
304 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
305 match parse::parse(s) {
306 Ok((_, pn)) => Node::try_from(pn),
307 Err(e) => Err(format!("parse err: {}", e)),
308 }
309 }
310}
311
/// Heap-allocated [`Node`] handle used for child links (e.g. operator operands),
/// so that the recursive node types have a finite size.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HN(Box<Node>);
315
316impl HN {
317 pub fn new(n: Node) -> HN {
318 HN(Box::new(n))
319 }
320
321 pub fn make(n: NodeInner) -> HN {
322 Self::new(Node { n })
323 }
324
325 pub fn and_then<F>(self, f: F) -> Node
328 where
329 F: FnOnce(Node) -> Node,
330 {
331 f(*self.0)
332 }
333
334 pub fn map<F>(mut self, f: F) -> Self
336 where
337 F: FnOnce(Node) -> Node,
338 {
339 let x = f(*self.0);
340 *self.0 = x;
341
342 self
343 }
344
345 pub fn replace_with(&mut self, n: Node) {
347 *self.0 = n;
348 }
349 pub fn swap(&mut self, n: Node) -> Node {
350 std::mem::replace(&mut self.0, n)
351 }
352}
353
354impl AstNode for HN {
355 fn returns(&self) -> Option<Ty> {
356 self.0.returns()
357 }
358 fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>> {
359 self.0.descendant_types()
360 }
361 fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError> {
362 self.0.finite_eval(ctx)
363 }
364 fn eval<C: EvalContext>(
365 &self,
366 ctx: &C,
367 ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError> {
368 self.0.eval(ctx)
369 }
370 fn eval_interval<C: EvalContextInterval>(
371 &self,
372 ctx: &C,
373 ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>
374 {
375 self.0.eval_interval(ctx)
376 }
377 fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool) {
378 self.0.walk(depth_first, cb)
379 }
380 fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool) {
381 self.0.walk_mut(depth_first, cb)
382 }
383 fn as_inner(&self) -> &NodeInner {
384 self.0.as_inner()
385 }
386 fn iter_children(&self) -> impl Iterator<Item = &NodeInner> {
387 self.0.iter_children()
388 }
389 fn get<I: Iterator<Item = usize>>(&self, i: I) -> Option<&NodeInner> {
390 self.0.get(i)
391 }
392 fn get_mut<I: Iterator<Item = usize>>(&mut self, i: I) -> Option<&mut NodeInner> {
393 self.0.get_mut(i)
394 }
395 fn parsing_precedence(&self) -> Option<(bool, usize)> {
396 self.0.parsing_precedence()
397 }
398}
399
400impl Deref for HN {
401 type Target = Node;
402
403 fn deref(&self) -> &Node {
404 &self.0
405 }
406}
407
408impl DerefMut for HN {
409 fn deref_mut(&mut self) -> &mut Node {
410 &mut self.0
411 }
412}
413
414impl From<Node> for HN {
415 fn from(n: Node) -> Self {
416 Self(Box::new(n))
417 }
418}
419
420impl From<NodeInner> for HN {
421 fn from(n: NodeInner) -> Self {
422 Self(Box::new(Node { n }))
423 }
424}
425
426impl From<TyValue> for HN {
427 fn from(v: TyValue) -> Self {
428 let n = Const::new(v).into();
429 Self(Box::new(Node { n }))
430 }
431}
432
433impl fmt::Display for HN {
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 fmt::Display::fmt(&self.0, f)
436 }
437}
438
/// The concrete variants of an expression-tree node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NodeInner {
    /// A literal value.
    Const(Const),
    /// A one-operand operation (e.g. negation, absolute value).
    Unary(Unary),
    /// A two-operand operation (arithmetic, `±`, comparison, min/max, root, power).
    Binary(Binary),
    /// A named variable, resolved through the evaluation context.
    Var(Var),
    /// A conditional expression, written `{expr if cond; …; otherwise expr}`.
    Piecewise(Piecewise),
}
453
454impl NodeInner {
455 pub fn new_const<V: Into<TyValue>>(v: V) -> Self {
457 Self::Const(Const::new(v.into()))
458 }
459 pub fn new_var<S: Into<String>>(ident: S) -> Self {
461 Self::Var(Var::new_untyped(ident))
462 }
463
464 pub fn as_const(&self) -> Option<&Const> {
466 match self {
467 Self::Const(c) => Some(c),
468 _ => None,
469 }
470 }
471 pub fn as_unary(&self) -> Option<&Unary> {
473 match self {
474 Self::Unary(c) => Some(c),
475 _ => None,
476 }
477 }
478 pub fn as_binary(&self) -> Option<&Binary> {
480 match self {
481 Self::Binary(b) => Some(b),
482 _ => None,
483 }
484 }
485 pub fn as_var(&self) -> Option<&Var> {
487 match self {
488 Self::Var(b) => Some(b),
489 _ => None,
490 }
491 }
492
493 fn get<I: Iterator<Item = usize>>(&self, i: &mut I) -> Option<&NodeInner> {
499 match i.next() {
500 Some(idx) => match (self, idx) {
501 (Self::Unary(u), 0) => (*u.operand()).n.get(i),
502 (Self::Binary(b), 0) => (*b.lhs()).n.get(i),
503 (Self::Binary(b), 1) => (*b.rhs()).n.get(i),
504 (Self::Piecewise(p), idx) => {
505 let num_branches = p.r#if.len();
506 if idx == num_branches * 2 {
507 p.r#else.n.get(i)
508 } else if idx < num_branches * 2 {
509 if idx % 2 == 1 {
510 p.r#if[idx / 2].1.n.get(i)
511 } else {
512 p.r#if[idx / 2].0.n.get(i)
513 }
514 } else {
515 None
516 }
517 }
518 _ => None,
519 },
520 None => Some(self),
521 }
522 }
523
524 fn get_mut<I: Iterator<Item = usize>>(&mut self, i: &mut I) -> Option<&mut NodeInner> {
530 match i.next() {
531 Some(idx) => match (self, idx) {
532 (Self::Unary(u), 0) => (*u.operand_mut()).n.get_mut(i),
533 (Self::Binary(b), 0) => (*b.lhs_mut()).n.get_mut(i),
534 (Self::Binary(b), 1) => (*b.rhs_mut()).n.get_mut(i),
535 (Self::Piecewise(p), idx) => {
536 let num_branches = p.r#if.len();
537 if idx == num_branches * 2 {
538 p.r#else.n.get_mut(i)
539 } else if idx < num_branches * 2 {
540 if idx % 2 == 1 {
541 p.r#if[idx / 2].1.n.get_mut(i)
542 } else {
543 p.r#if[idx / 2].0.n.get_mut(i)
544 }
545 } else {
546 None
547 }
548 }
549 _ => None,
550 },
551 None => Some(self),
552 }
553 }
554
555 pub fn pretty_str(&self, parent_precedence: Option<usize>) -> String {
557 match self {
558 Self::Const(c) => format!("{}", c),
559 Self::Unary(u) => u.pretty_str(parent_precedence),
560 Self::Binary(b) => b.pretty_str(parent_precedence),
561 Self::Var(v) => format!("{}", v),
562 Self::Piecewise(p) => format!("{}", p),
563 }
564 }
565}
566
567impl AstNode for NodeInner {
568 fn returns(&self) -> Option<Ty> {
569 match self {
570 Self::Const(c) => Some(c.returns()),
571 Self::Unary(u) => u.returns(),
572 Self::Binary(b) => b.returns(),
573 Self::Var(v) => v.returns(),
574 Self::Piecewise(p) => p.returns(),
575 }
576 }
577
578 fn descendant_types(&self) -> impl Iterator<Item = Option<Ty>> {
579 match self {
580 Self::Const(_) | Self::Var(_) | Self::Piecewise(_) => {
581 [None, None].into_iter().flatten()
582 }
583 Self::Unary(u) => [Some(u.operand().returns()), None].into_iter().flatten(),
584 Self::Binary(b) => [Some(b.lhs().returns()), Some(b.lhs().returns())]
585 .into_iter()
586 .flatten(),
587 }
588 }
589 fn finite_eval<C: EvalContext>(&self, ctx: &C) -> Result<TyValue, EvalError> {
590 match self {
591 Self::Const(c) => Ok(c.value().clone()),
592 Self::Unary(u) => u.finite_eval(ctx),
593 Self::Binary(b) => b.finite_eval(ctx),
594 Self::Piecewise(p) => p.finite_eval(ctx),
595 Self::Var(v) => match ctx.resolve_var(v.ident()) {
596 Some(v) => Ok(v.clone()),
597 None => Err(EvalError::UnknownIdent(v.ident().to_string())),
598 },
599 }
600 }
601 fn eval<C: EvalContext>(
602 &self,
603 ctx: &C,
604 ) -> Result<Box<dyn Iterator<Item = Result<TyValue, EvalError>> + '_>, EvalError> {
605 match self {
606 Self::Const(c) => Ok(Box::new(once(Ok(c.value().clone())))),
607 Self::Unary(u) => u.eval(ctx),
608 Self::Binary(b) => b.eval(ctx),
609 Self::Piecewise(p) => p.eval(ctx),
610 Self::Var(v) => match ctx.resolve_var(v.ident()) {
611 Some(v) => Ok(Box::new(once(Ok(v.clone())))),
612 None => Err(EvalError::UnknownIdent(v.ident().to_string())),
613 },
614 }
615 }
616 fn eval_interval<C: EvalContextInterval>(
617 &self,
618 ctx: &C,
619 ) -> Result<Box<dyn Iterator<Item = Result<(TyValue, TyValue), EvalError>> + '_>, EvalError>
620 {
621 match self {
622 Self::Const(c) => Ok(Box::new(once(Ok((c.value().clone(), c.value().clone()))))),
623 Self::Unary(u) => u.eval_interval(ctx),
624 Self::Binary(b) => b.eval_interval(ctx),
625 Self::Piecewise(p) => p.eval_interval(ctx),
626 Self::Var(v) => match ctx.resolve_var(v.ident()) {
627 Some((v_min, v_max)) => Ok(Box::new(once(Ok((v_min.clone(), v_max.clone()))))),
628 None => Err(EvalError::UnknownIdent(v.ident().to_string())),
629 },
630 }
631 }
632
633 fn walk(&self, depth_first: bool, cb: &mut impl FnMut(&NodeInner) -> bool) {
634 if !depth_first {
635 if !cb(self) {
636 return;
637 }
638 }
639
640 match self {
642 Self::Unary(u) => {
643 u.operand().walk(depth_first, cb);
644 }
645 Self::Binary(b) => {
646 b.lhs().walk(depth_first, cb);
647 b.rhs().walk(depth_first, cb);
648 }
649 Self::Piecewise(p) => {
650 for (e, p) in p.iter_branches() {
651 e.walk(depth_first, cb);
652 p.walk(depth_first, cb)
653 }
654 p.else_branch().walk(depth_first, cb);
655 }
656
657 Self::Const(_) | Self::Var(_) => {}
659 }
660
661 if depth_first {
662 if !cb(self) {
663 return;
664 }
665 }
666 }
667 fn walk_mut(&mut self, depth_first: bool, cb: &mut impl FnMut(&mut NodeInner) -> bool) {
668 if !depth_first {
669 if !cb(self) {
670 return;
671 }
672 }
673
674 match self {
676 Self::Unary(u) => {
677 u.operand_mut().walk_mut(depth_first, cb);
678 }
679 Self::Binary(b) => {
680 b.lhs_mut().walk_mut(depth_first, cb);
681 b.rhs_mut().walk_mut(depth_first, cb);
682 }
683 Self::Piecewise(p) => {
684 for (e, p) in p.iter_branches_mut() {
685 e.walk_mut(depth_first, cb);
686 p.walk_mut(depth_first, cb)
687 }
688 p.else_branch_mut().walk_mut(depth_first, cb);
689 }
690
691 Self::Const(_) | Self::Var(_) => {}
693 }
694
695 if depth_first {
696 if !cb(self) {
697 return;
698 }
699 }
700 }
701
702 fn as_inner(&self) -> &NodeInner {
703 self
704 }
705
706 fn iter_children(&self) -> impl Iterator<Item = &NodeInner> {
707 match self {
708 Self::Const(_) | Self::Var(_) | Self::Piecewise(_) => {
709 [None, None].into_iter().flatten()
710 }
711 Self::Unary(u) => [Some(&u.operand().0.n), None].into_iter().flatten(),
712 Self::Binary(b) => [Some(&b.lhs().0.n), Some(&b.rhs().0.n)]
713 .into_iter()
714 .flatten(),
715 }
716 }
717
718 fn get<I: Iterator<Item = usize>>(&self, mut i: I) -> Option<&NodeInner> {
719 NodeInner::get(self, &mut i)
720 }
721 fn get_mut<I: Iterator<Item = usize>>(&mut self, mut i: I) -> Option<&mut NodeInner> {
722 NodeInner::get_mut(self, &mut i)
723 }
724 fn parsing_precedence(&self) -> Option<(bool, usize)> {
725 match self {
726 Self::Const(_) | Self::Var(_) | Self::Piecewise(_) => None,
727 Self::Unary(u) => u.op.parsing_precedence(),
728 Self::Binary(b) => b.op.parsing_precedence(),
729 }
730 }
731}
732
733impl From<Const> for NodeInner {
734 fn from(n: Const) -> Self {
735 Self::Const(n)
736 }
737}
738impl From<Binary> for NodeInner {
739 fn from(n: Binary) -> Self {
740 Self::Binary(n)
741 }
742}
743impl From<Unary> for NodeInner {
744 fn from(n: Unary) -> Self {
745 Self::Unary(n)
746 }
747}
748impl From<Var> for NodeInner {
749 fn from(n: Var) -> Self {
750 Self::Var(n)
751 }
752}
753impl From<Piecewise> for NodeInner {
754 fn from(p: Piecewise) -> Self {
755 Self::Piecewise(p)
756 }
757}
758impl From<Node> for NodeInner {
759 fn from(n: Node) -> Self {
760 Self::from(n.n)
761 }
762}
763
764impl fmt::Display for NodeInner {
765 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766 f.write_str(&self.pretty_str(None))
767 }
768}
769
/// Unit tests for parsing, formatting, evaluation and path lookup of expression trees.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parsing of literals, unary/binary operators, implicit coefficients and roots.
    #[test]
    fn parse_basic() {
        assert_eq!(
            Node::try_from("3 + 5"),
            Ok(Node::new(
                Binary::add::<TyValue, TyValue>(3.into(), 5.into()).into()
            )),
        );
        assert_eq!(
            Node::try_from("-5"),
            Ok(Node::new(Unary::negate::<TyValue>(5.into()).into())),
        );
        // Binary minus followed by unary negation.
        assert_eq!(
            Node::try_from("3--5"),
            Ok(Node::new(
                Binary::sub::<TyValue, HN>(
                    3.into(),
                    Node::new(Unary::negate::<TyValue>(5.into()).into()).into(),
                )
                .into()
            )),
        );
        assert_eq!(
            Node::try_from("3==5"),
            Ok(Node::new(
                Binary::equals::<TyValue, TyValue>(3.into(), 5.into()).into()
            )),
        );
        assert_eq!(
            Node::try_from("3 > 5"),
            Ok(Node::new(
                Binary::gt::<TyValue, TyValue>(3.into(), 5.into()).into()
            )),
        );
        // A coefficient written directly before an identifier is a multiplication.
        assert_eq!(
            Node::try_from("5x"),
            Ok(Node::new(
                Binary::mul::<TyValue, HN>(
                    5.into(),
                    Node::new(Var::new_untyped("x").into()).into()
                )
                .into()
            )),
        );

        // `*` binds tighter than `±`.
        assert_eq!(
            Node::try_from("x ± 4 * y"),
            Ok(Node::new(
                Binary::plus_or_minus::<HN, HN>(
                    Node::new(Var::new_untyped("x").into()).into(),
                    Node::new(
                        Binary::mul::<TyValue, HN>(
                            4.into(),
                            Node::new(Var::new_untyped("y").into()).into(),
                        )
                        .into()
                    )
                    .into(),
                )
                .into()
            )),
        );

        // `sqrt(x)` is sugar for `root(x, 2)`.
        assert_eq!(
            Node::try_from("sqrt(4)"),
            Ok(Node::new(
                Binary::root::<TyValue, TyValue>(4.into(), 2.into(),).into()
            )),
        );
        assert_eq!(
            Node::try_from("root(8, 3)"),
            Ok(Node::new(
                Binary::root::<TyValue, TyValue>(8.into(), 3.into(),).into()
            )),
        );
    }

    /// Piecewise syntax lowers to `(expression, condition)` arms plus an `otherwise`.
    #[test]
    fn parse_piecewise() {
        assert_eq!(
            Node::try_from("{2x if x == 0; otherwise x}"),
            Ok(Node::new(NodeInner::from(Piecewise::new(
                vec![(
                    Node::try_from("2x").unwrap().into(),
                    Node::try_from("x == 0").unwrap().into(),
                )],
                Node::try_from("x").unwrap().into(),
            )))),
        );
    }

    /// `Display` output, including precedence-driven parenthesisation.
    #[test]
    fn fmt_basic() {
        assert_eq!(
            "3 * (a + b)",
            format!(
                "{}",
                Node::new(
                    Binary::mul::<TyValue, HN>(
                        3.into(),
                        Node::new(
                            Binary::add::<HN, HN>(
                                Node::new(Var::new_untyped("a").into()).into(),
                                Node::new(Var::new_untyped("b").into()).into(),
                            )
                            .into()
                        )
                        .into(),
                    )
                    .into()
                )
            )
        );
        assert_eq!(
            "(a + b) * 3",
            format!(
                "{}",
                Node::new(
                    Binary::mul::<HN, TyValue>(
                        Node::new(
                            Binary::add::<HN, HN>(
                                Node::new(Var::new_untyped("a").into()).into(),
                                Node::new(Var::new_untyped("b").into()).into(),
                            )
                            .into()
                        )
                        .into(),
                        3.into(),
                    )
                    .into()
                )
            )
        );
        assert_eq!(
            "3 + a * b",
            format!(
                "{}",
                Node::new(
                    Binary::add::<TyValue, HN>(
                        3.into(),
                        Node::new(
                            Binary::mul::<HN, HN>(
                                Node::new(Var::new_untyped("a").into()).into(),
                                Node::new(Var::new_untyped("b").into()).into(),
                            )
                            .into()
                        )
                        .into(),
                    )
                    .into()
                )
            )
        );
        assert_eq!(
            "a * b + 3",
            format!(
                "{}",
                Node::new(
                    Binary::add::<HN, TyValue>(
                        Node::new(
                            Binary::mul::<HN, HN>(
                                Node::new(Var::new_untyped("a").into()).into(),
                                Node::new(Var::new_untyped("b").into()).into(),
                            )
                            .into()
                        )
                        .into(),
                        3.into(),
                    )
                    .into()
                )
            )
        );

        assert_eq!(
            "3 - x",
            format!(
                "{}",
                Node::new(
                    Binary::sub::<TyValue, HN>(
                        3.into(),
                        Node::new(Var::new_untyped("x").into()).into(),
                    )
                    .into()
                )
            )
        );
        assert_eq!(
            "3 ± x",
            format!(
                "{}",
                Node::new(
                    Binary::plus_or_minus::<TyValue, HN>(
                        3.into(),
                        Node::new(Var::new_untyped("x").into()).into(),
                    )
                    .into()
                )
            )
        );

        // A negated right operand of `-` is parenthesised.
        assert_eq!(
            "3 - (-5)",
            format!(
                "{}",
                Node::new(
                    Binary::sub::<TyValue, HN>(
                        3.into(),
                        Node::new(Unary::negate::<TyValue>(5.into()).into()).into(),
                    )
                    .into()
                )
            )
        );

        assert_eq!(
            "3 - |5|",
            format!(
                "{}",
                Node::new(
                    Binary::sub::<TyValue, HN>(
                        3.into(),
                        Node::new(Unary::abs::<TyValue>(5.into()).into()).into(),
                    )
                    .into()
                )
            )
        );

        // Piecewise nodes round-trip to the syntax they were parsed from.
        assert_eq!(
            "{2x if x == 0; otherwise x}",
            format!(
                "{}",
                Node::new(NodeInner::from(Piecewise::new(
                    vec![(
                        Node::try_from("2x").unwrap().into(),
                        Node::try_from("x == 0").unwrap().into(),
                    )],
                    Node::try_from("x").unwrap().into(),
                )))
            )
        );
    }

    /// Single-valued evaluation, including variable lookup through a context.
    #[test]
    fn finite_eval_simple() {
        assert_eq!(
            Node::try_from("3.5 + 4.5").unwrap().finite_eval(&()),
            Ok(8.into()),
        );
        assert_eq!(
            Node::try_from("3 - 5").unwrap().finite_eval(&()),
            Ok((-2).into()),
        );
        assert_eq!(
            Node::try_from("9 - 3 * 2").unwrap().finite_eval(&()),
            Ok(3.into()),
        );
        assert_eq!(
            Node::try_from("root(8, 3) + sqrt(4)")
                .unwrap()
                .finite_eval(&()),
            Ok(4.into()),
        );
        assert_eq!(
            Node::try_from("min(2 + 1, max(4, 2))")
                .unwrap()
                .finite_eval(&()),
            Ok(3.into()),
        );

        // The empty context `()` binds nothing.
        assert_eq!(
            Node::try_from("x").unwrap().finite_eval(&()),
            Err(EvalError::UnknownIdent("x".to_string()))
        );
        assert_eq!(
            Node::try_from("x")
                .unwrap()
                .finite_eval(&vec![("x", 69.into())]),
            Ok(69.into()),
        );
    }

    /// Piecewise evaluation selects the first arm whose condition holds, else `otherwise`.
    #[test]
    fn finite_eval_piecewise() {
        assert_eq!(
            Node::try_from("{x if x > y; otherwise y}")
                .unwrap()
                .finite_eval(&vec![("x", 42.into()), ("y", 4.into())]),
            Ok(42.into()),
        );
        assert_eq!(
            Node::try_from("{x if x > y; otherwise y}")
                .unwrap()
                .finite_eval(&vec![("x", 1.into()), ("y", 4.into())]),
            Ok(4.into()),
        );
    }

    /// Multi-valued evaluation: `±` yields the `+` result first, then the `-` result.
    #[test]
    fn eval_simple() {
        assert_eq!(
            Node::try_from("3.5 + 4.5")
                .unwrap()
                .eval(&())
                .unwrap()
                .collect::<Result<Vec<_>, _>>(),
            Ok(vec![8.into()]),
        );
        assert_eq!(
            Node::try_from("9 - 3 * 2")
                .unwrap()
                .eval(&())
                .unwrap()
                .collect::<Result<Vec<_>, _>>(),
            Ok(vec![3.into()]),
        );

        assert_eq!(
            Node::try_from("5 ± 1")
                .unwrap()
                .eval(&())
                .unwrap()
                .collect::<Result<Vec<_>, _>>(),
            Ok(vec![6.into(), 4.into()]),
        );
        assert_eq!(
            Node::try_from("2 * (5 ± 1)")
                .unwrap()
                .eval(&())
                .unwrap()
                .collect::<Result<Vec<_>, _>>(),
            Ok(vec![12.into(), 8.into()]),
        );

        assert_eq!(
            Node::try_from("-{0±x if x > y; otherwise y}")
                .unwrap()
                .eval(&vec![("x", 2.into()), ("y", 1.into())])
                .unwrap()
                .collect::<Result<Vec<_>, _>>(),
            Ok(vec![(-2).into(), 2.into()]),
        );
    }

    /// Interval evaluation with `(lower, upper)` variable bounds.
    #[test]
    fn interval_eval() {
        assert_eq!(
            Node::try_from("x - 2y")
                .unwrap()
                .eval_interval(&vec![
                    ("x", (1.into(), 2.into())),
                    ("y", (5.into(), 6.into()))
                ])
                .unwrap()
                .collect::<Result<Vec<_>, _>>(),
            Ok(vec![((-11).into(), (-8).into())]),
        );

        for x in -8..=8 {
            // Smoke test: evaluating over a grid of 2-wide boxes must not error.
            for y in -8..=8 {
                assert!(
                    Node::try_from("sqrt(pow(-2 - x, 2) + pow(3 - y, 2)) + abs(x+y)")
                        .unwrap()
                        .eval_interval(&vec![
                            ("x", (x.into(), (x + 2).into())),
                            ("y", (y.into(), (y + 2).into())),
                        ])
                        .unwrap()
                        .collect::<Result<Vec<_>, _>>()
                        .is_ok(),
                );
            }
        }
    }

    /// Path lookup: binary `0`/`1` = lhs/rhs; piecewise `0`/`1` = first expression /
    /// condition, `2` = otherwise; an empty path is the root.
    #[test]
    fn get() {
        assert_eq!(
            Node::try_from("3 + 2*5").unwrap().get(vec![0].into_iter()),
            Some(&NodeInner::new_const(3)),
        );
        assert_eq!(
            Node::try_from("3 + 2*5")
                .unwrap()
                .get(vec![1, 0].into_iter()),
            Some(&NodeInner::new_const(2)),
        );
        assert_eq!(
            Node::try_from("3 + 2*5").unwrap().get(vec![1].into_iter()),
            Some(Node::try_from("2 * 5").unwrap().as_inner()),
        );
        assert_eq!(
            Node::try_from("3 + 2*5").unwrap().get(vec![].into_iter()),
            Some(Node::try_from("3 + 2 * 5").unwrap().as_inner()),
        );

        // Out-of-range indices yield `None`.
        assert_eq!(
            Node::try_from("3 + 2*5").unwrap().get(vec![99].into_iter()),
            None,
        );

        let p: Node = NodeInner::from(Piecewise::new(
            vec![(
                Node::try_from("x").unwrap().into(),
                Node::try_from("x == 0").unwrap().into(),
            )],
            Node::try_from("0").unwrap().into(),
        ))
        .into();
        assert_eq!(p.get(vec![99].into_iter()), None);
        assert_eq!(
            p.get(vec![0].into_iter()),
            Some(Node::try_from("x").unwrap().as_inner()),
        );
        assert_eq!(
            p.get(vec![1].into_iter()),
            Some(Node::try_from("x == 0").unwrap().as_inner()),
        );
        assert_eq!(
            p.get(vec![2].into_iter()),
            Some(Node::try_from("0").unwrap().as_inner()),
        );
    }
}
1200}