Skip to main content

minicas_core/ast/
ac_collect.rs

1use crate::ast::{BinaryOp, NodeInner, Path};
2
/// Describes why [`ac_collect`] could not collect terms.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AcError {
    /// The node passed in was not a binary operator node.
    BadVariant,
    /// The root operator is not both associative and commutative.
    NotAssociativeCommutative,
    /// The recursive term walk reported a failure.
    Internal,
}

impl std::fmt::Display for AcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::BadVariant => "expected a binary operator node",
            Self::NotAssociativeCommutative => {
                "operator is not both associative and commutative"
            }
            Self::Internal => "internal error while collecting terms",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate `AcError` through `Box<dyn Error>` / `?` like any
// other library error.
impl std::error::Error for AcError {}
10
11fn ac_collect_recursive(
12    n: &NodeInner,
13    op: BinaryOp,
14    path: Path,
15    terms: &mut Vec<Path>,
16) -> Result<(), ()> {
17    match n {
18        NodeInner::Binary(b) => {
19            if b.op() == op {
20                let (mut lp, mut rp) = (path.clone(), path.clone());
21                lp.push(0);
22                ac_collect_recursive(b.lhs(), op, lp, terms)?;
23                rp.push(1);
24                ac_collect_recursive(b.rhs(), op, rp, terms)?;
25            } else {
26                let self_path = path.clone();
27                terms.push(self_path);
28            }
29        }
30
31        NodeInner::Const(_) | NodeInner::Var(_) | NodeInner::Unary(_) | NodeInner::Piecewise(_) => {
32            let self_path = path.clone();
33            terms.push(self_path);
34        }
35    }
36
37    Ok(())
38}
39
40/// Collects paths to all terms below this node which share the same
41/// associative + commutative operator.
42pub fn ac_collect(n: &NodeInner, terms: &mut Vec<Path>) -> Result<(), AcError> {
43    if let Some(b) = n.as_binary() {
44        let op = b.op();
45        if !op.associative() || !op.commutative() {
46            return Err(AcError::NotAssociativeCommutative);
47        }
48
49        ac_collect_recursive(n, op, Path::default(), terms).map_err(|_| AcError::Internal)
50    } else {
51        Err(AcError::BadVariant)
52    }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Node;

    #[test]
    fn simple() {
        let n = Node::try_from("5 + 3").unwrap();
        let mut output = Vec::new();
        assert_eq!(ac_collect(&n, &mut output), Ok(()));
        assert_eq!(
            output
                .into_iter()
                .map(|p| p.into())
                .collect::<Vec<Vec<usize>>>(),
            vec![vec![0], vec![1]]
        );
    }

    #[test]
    fn tri() {
        let n = Node::try_from("5 + 3 + 2x").unwrap();
        let mut output = Vec::new();
        assert_eq!(ac_collect(&n, &mut output), Ok(()));
        assert_eq!(
            output
                .into_iter()
                .map(|p| p.into())
                .collect::<Vec<Vec<usize>>>(),
            vec![vec![0, 0], vec![0, 1], vec![1]]
        );
    }

    #[test]
    fn mul() {
        let n = Node::try_from("(2 + 1) * (5 * 3)").unwrap();
        let mut output = Vec::new();
        assert_eq!(ac_collect(&n, &mut output), Ok(()));
        assert_eq!(
            output
                .into_iter()
                .map(|p| n.get(&mut p.iter()).unwrap().to_string())
                .collect::<Vec<_>>(),
            vec!["2 + 1", "5", "3"],
        );
    }

    #[test]
    fn non_binary_root() {
        // A bare variable has no binary operator to collect over, so the
        // call must be rejected without touching the output.
        let n = Node::try_from("x").unwrap();
        let mut output = Vec::new();
        assert_eq!(ac_collect(&n, &mut output), Err(AcError::BadVariant));
        assert!(output.is_empty());
    }
}