// minicas_core/ast/ac_collect.rs

use std::fmt;

use crate::ast::{BinaryOp, NodeInner, Path};

/// Reasons [`ac_collect`] can refuse to flatten a node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AcError {
    /// The node passed in is not a binary-operator node.
    BadVariant,
    /// The root operator is not both associative and commutative, so its
    /// operands cannot be freely regrouped or reordered.
    NotAssociativeCommutative,
    /// The operand traversal reported an unexpected failure.
    Internal,
}

impl fmt::Display for AcError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            AcError::BadVariant => "expected a binary operator node",
            AcError::NotAssociativeCommutative => {
                "operator is not both associative and commutative"
            }
            AcError::Internal => "internal error while collecting operands",
        };
        f.write_str(msg)
    }
}

// Lets `AcError` be boxed as `Box<dyn std::error::Error>` and used with `?`
// in callers that aggregate heterogeneous errors.
impl std::error::Error for AcError {}

11fn ac_collect_recursive(
12 n: &NodeInner,
13 op: BinaryOp,
14 path: Path,
15 terms: &mut Vec<Path>,
16) -> Result<(), ()> {
17 match n {
18 NodeInner::Binary(b) => {
19 if b.op() == op {
20 let (mut lp, mut rp) = (path.clone(), path.clone());
21 lp.push(0);
22 ac_collect_recursive(b.lhs(), op, lp, terms)?;
23 rp.push(1);
24 ac_collect_recursive(b.rhs(), op, rp, terms)?;
25 } else {
26 let self_path = path.clone();
27 terms.push(self_path);
28 }
29 }
30
31 NodeInner::Const(_) | NodeInner::Var(_) | NodeInner::Unary(_) | NodeInner::Piecewise(_) => {
32 let self_path = path.clone();
33 terms.push(self_path);
34 }
35 }
36
37 Ok(())
38}
39
40pub fn ac_collect(n: &NodeInner, terms: &mut Vec<Path>) -> Result<(), AcError> {
43 if let Some(b) = n.as_binary() {
44 let op = b.op();
45 if !op.associative() || !op.commutative() {
46 return Err(AcError::NotAssociativeCommutative);
47 }
48
49 ac_collect_recursive(n, op, Path::default(), terms).map_err(|_| AcError::Internal)
50 } else {
51 Err(AcError::BadVariant)
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Node;

    /// Parses `src` and runs [`ac_collect`] on it, asserting success.
    /// Returns the parsed node (so paths can be resolved against it) together
    /// with the collected operand paths.
    fn collect(src: &str) -> (Node, Vec<Path>) {
        let node = Node::try_from(src).unwrap();
        let mut paths = Vec::new();
        assert_eq!(ac_collect(&node, &mut paths), Ok(()));
        (node, paths)
    }

    /// Converts paths into plain index vectors for easy comparison.
    fn indices(paths: Vec<Path>) -> Vec<Vec<usize>> {
        paths.into_iter().map(Into::into).collect()
    }

    #[test]
    fn simple() {
        let (_, paths) = collect("5 + 3");
        assert_eq!(indices(paths), vec![vec![0], vec![1]]);
    }

    #[test]
    fn tri() {
        let (_, paths) = collect("5 + 3 + 2x");
        assert_eq!(indices(paths), vec![vec![0, 0], vec![0, 1], vec![1]]);
    }

    #[test]
    fn mul() {
        let (n, paths) = collect("(2 + 1) * (5 * 3)");
        let rendered: Vec<String> = paths
            .iter()
            .map(|p| n.get(&mut p.iter()).unwrap().to_string())
            .collect();
        assert_eq!(rendered, vec!["2 + 1", "5", "3"]);
    }

    #[test]
    fn non_binary_root_is_rejected() {
        let n = Node::try_from("x").unwrap();
        let mut output = Vec::new();
        assert_eq!(ac_collect(&n, &mut output), Err(AcError::BadVariant));
        assert!(output.is_empty());
    }

    #[test]
    fn long_chain_collects_every_term() {
        let src = (1..=64)
            .map(|i| i.to_string())
            .collect::<Vec<_>>()
            .join(" + ");
        let (_, paths) = collect(&src);
        assert_eq!(paths.len(), 64);
    }
}