Skip to main content

refrain_egraph/
lib.rs

//! refrain-egraph: equality-saturation normalization for Refrain ASTs.
//!
//! Uses the `egg` crate to define a `RefrainLang` term sort, applies a small
//! set of Refrain-specific rewrite rules, runs the e-graph toward a fixpoint
//! (bounded by a node limit and an iteration cap), and extracts the
//! lowest-cost representative under the `AstSize` cost model.
7
8use egg::{
9    define_language, rewrite, AstSize, EGraph, Extractor, Id, RecExpr, Rewrite, Runner, Symbol,
10};
11
12use refrain_core::{Op, Pattern, Refrain, RefrainError, Result};
13
14define_language! {
15    /// Term sort for Refrain ASTs inside egg's e-graph.
16    pub enum RefrainLang {
17        "note"     = Note([Id; 2]),     // [pitch_sym, dur_sym]
18        "loop"     = Loop([Id; 2]),     // [count_num, body]
19        "dy/dx"    = Diff([Id; 2]),     // [x_sym, t_sym]
20        "quotient" = Quotient(Box<[Id]>),
21        "seq"      = Seq(Box<[Id]>),
22        Num(u32),
23        Sym(Symbol),
24    }
25}
26
27/// The standard rewrite rule set for Refrain normalization.
28fn rules() -> Vec<Rewrite<RefrainLang, ()>> {
29    vec![
30        rewrite!("loop-1-identity"; "(loop 1 ?x)" => "?x"),
31        rewrite!("seq-singleton-identity"; "(seq ?x)" => "?x"),
32    ]
33}
34
35pub struct Egraph {
36    rules: Vec<Rewrite<RefrainLang, ()>>,
37    node_limit: usize,
38    iter_limit: usize,
39}
40
41impl Egraph {
42    pub fn new() -> Self {
43        Self {
44            rules: rules(),
45            node_limit: 10_000,
46            iter_limit: 32,
47        }
48    }
49
50    pub fn with_limits(node_limit: usize, iter_limit: usize) -> Self {
51        Self {
52            rules: rules(),
53            node_limit,
54            iter_limit,
55        }
56    }
57
58    pub fn normalize(&self, r: &Refrain) -> Result<Refrain> {
59        let mut out = Refrain::new(r.name.clone());
60        out.territorialize = match &r.territorialize {
61            Some(p) => Some(self.normalize_pattern(p)?),
62            None => None,
63        };
64        out.deterritorialize = match &r.deterritorialize {
65            Some(p) => Some(self.normalize_pattern(p)?),
66            None => None,
67        };
68        out.reterritorialize = match &r.reterritorialize {
69            Some(p) => Some(self.normalize_pattern(p)?),
70            None => None,
71        };
72        Ok(out)
73    }
74
75    pub fn normalize_pattern(&self, p: &Pattern) -> Result<Pattern> {
76        let mut expr = RecExpr::default();
77        let _ = pattern_to_expr(p, &mut expr);
78        let runner: Runner<RefrainLang, ()> = Runner::default()
79            .with_node_limit(self.node_limit)
80            .with_iter_limit(self.iter_limit)
81            .with_expr(&expr);
82        let runner = runner.run(&self.rules);
83        let extractor = Extractor::new(&runner.egraph, AstSize);
84        let root_id = runner.roots[0];
85        let (_cost, best) = extractor.find_best(root_id);
86        expr_to_pattern(&best)
87    }
88}
89
90impl Default for Egraph {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96fn pattern_to_expr(p: &Pattern, b: &mut RecExpr<RefrainLang>) -> Id {
97    match p {
98        Pattern::Op(Op::Note { pitch, dur }) => {
99            let ps = b.add(RefrainLang::Sym(Symbol::from(pitch.as_str())));
100            let ds = b.add(RefrainLang::Sym(Symbol::from(dur.as_str())));
101            b.add(RefrainLang::Note([ps, ds]))
102        }
103        Pattern::Op(Op::Loop { count, body }) => {
104            let n = b.add(RefrainLang::Num(*count));
105            let body_id = pattern_to_expr(body, b);
106            b.add(RefrainLang::Loop([n, body_id]))
107        }
108        Pattern::Op(Op::Diff { x, t }) => {
109            let xs = b.add(RefrainLang::Sym(Symbol::from(x.as_str())));
110            let ts = b.add(RefrainLang::Sym(Symbol::from(t.as_str())));
111            b.add(RefrainLang::Diff([xs, ts]))
112        }
113        Pattern::Op(Op::Quotient { rels }) => {
114            let ids: Vec<Id> = rels
115                .iter()
116                .map(|s| b.add(RefrainLang::Sym(Symbol::from(s.as_str()))))
117                .collect();
118            b.add(RefrainLang::Quotient(ids.into_boxed_slice()))
119        }
120        Pattern::Op(Op::Sym(s)) => b.add(RefrainLang::Sym(Symbol::from(s.as_str()))),
121        Pattern::Op(Op::Call { head, args }) => {
122            let h = b.add(RefrainLang::Sym(Symbol::from(head.as_str())));
123            let mut ids = vec![h];
124            for a in args {
125                ids.push(pattern_to_expr(a, b));
126            }
127            b.add(RefrainLang::Seq(ids.into_boxed_slice()))
128        }
129        Pattern::Seq(items) => {
130            let ids: Vec<Id> = items.iter().map(|p| pattern_to_expr(p, b)).collect();
131            b.add(RefrainLang::Seq(ids.into_boxed_slice()))
132        }
133    }
134}
135
136fn expr_to_pattern(expr: &RecExpr<RefrainLang>) -> Result<Pattern> {
137    let nodes = expr.as_ref();
138    let root = Id::from(nodes.len() - 1);
139    node_to_pattern(nodes, root)
140}
141
142fn node_to_pattern(nodes: &[RefrainLang], id: Id) -> Result<Pattern> {
143    let n = &nodes[usize::from(id)];
144    match n {
145        RefrainLang::Note([p, d]) => {
146            let pitch = sym_at(nodes, *p)?.to_string();
147            let dur = sym_at(nodes, *d)?.to_string();
148            Ok(Pattern::Op(Op::Note { pitch, dur }))
149        }
150        RefrainLang::Loop([c, body]) => {
151            let count = num_at(nodes, *c)?;
152            let body_pat = node_to_pattern(nodes, *body)?;
153            Ok(Pattern::Op(Op::Loop {
154                count,
155                body: Box::new(body_pat),
156            }))
157        }
158        RefrainLang::Diff([x, t]) => {
159            let xs = sym_at(nodes, *x)?.to_string();
160            let ts = sym_at(nodes, *t)?.to_string();
161            Ok(Pattern::Op(Op::Diff { x: xs, t: ts }))
162        }
163        RefrainLang::Quotient(ids) => {
164            let mut rels = Vec::with_capacity(ids.len());
165            for i in ids.iter() {
166                rels.push(sym_at(nodes, *i)?.to_string());
167            }
168            Ok(Pattern::Op(Op::Quotient { rels }))
169        }
170        RefrainLang::Seq(ids) => {
171            let mut items = Vec::with_capacity(ids.len());
172            for i in ids.iter() {
173                items.push(node_to_pattern(nodes, *i)?);
174            }
175            if items.len() == 1 {
176                Ok(items.into_iter().next().unwrap())
177            } else {
178                Ok(Pattern::Seq(items))
179            }
180        }
181        RefrainLang::Sym(s) => Ok(Pattern::Op(Op::Sym(s.as_str().to_string()))),
182        RefrainLang::Num(n) => Ok(Pattern::Op(Op::Sym(n.to_string()))),
183    }
184}
185
186fn sym_at(nodes: &[RefrainLang], id: Id) -> Result<&str> {
187    match &nodes[usize::from(id)] {
188        RefrainLang::Sym(s) => Ok(s.as_str()),
189        other => Err(RefrainError::Rewrite(format!(
190            "expected symbol, got {:?}",
191            other
192        ))),
193    }
194}
195
196fn num_at(nodes: &[RefrainLang], id: Id) -> Result<u32> {
197    match &nodes[usize::from(id)] {
198        RefrainLang::Num(n) => Ok(*n),
199        other => Err(RefrainError::Rewrite(format!(
200            "expected number, got {:?}",
201            other
202        ))),
203    }
204}
205
206/// Build a raw `EGraph` (without rules) for advanced use; the standard
207/// pipeline goes through `Egraph::normalize`.
208pub fn empty_egraph() -> EGraph<RefrainLang, ()> {
209    EGraph::default()
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use refrain_core::parse;

    /// Parses `src`, normalizes it with a default `Egraph`, and returns
    /// `(original, normalized)`.
    fn parse_and_normalize(src: &str) -> (Refrain, Refrain) {
        let original = parse(src).unwrap();
        let normalized = Egraph::default().normalize(&original).unwrap();
        (original, normalized)
    }

    /// Asserts that normalizing `src` yields exactly the parsed input.
    fn assert_unchanged(src: &str) {
        let (original, normalized) = parse_and_normalize(src);
        assert_eq!(normalized, original);
    }

    #[test]
    fn loop_one_collapses_to_body() {
        let (_, normalized) =
            parse_and_normalize("(refrain a (territorialize (loop 1 (note C4 q))))");
        let phase = normalized.territorialize.as_ref().unwrap();
        if let Pattern::Op(Op::Note { pitch, dur }) = phase {
            assert_eq!(pitch, "C4");
            assert_eq!(dur, "q");
        } else {
            panic!("expected Note, got {:?}", phase);
        }
    }

    #[test]
    fn loop_two_stays() {
        let (_, normalized) =
            parse_and_normalize("(refrain a (territorialize (loop 2 (note C4 q))))");
        let phase = normalized.territorialize.as_ref().unwrap();
        if let Pattern::Op(Op::Loop { count, .. }) = phase {
            assert_eq!(*count, 2);
        } else {
            panic!("expected Loop, got {:?}", phase);
        }
    }

    #[test]
    fn note_stays_unchanged() {
        assert_unchanged("(refrain a (territorialize (note G4 e)))");
    }

    #[test]
    fn diff_stays_unchanged() {
        assert_unchanged("(refrain a (deterritorialize (dy/dx intensity time)))");
    }

    #[test]
    fn quotient_stays_unchanged() {
        assert_unchanged("(refrain a (reterritorialize (quotient ~a ~b)))");
    }

    #[test]
    fn empty_refrain_normalizes_to_itself() {
        assert_unchanged("(refrain empty)");
    }

    #[test]
    fn normalize_is_idempotent() {
        let (_, once) =
            parse_and_normalize("(refrain a (territorialize (loop 1 (loop 1 (note C4 q)))))");
        let twice = Egraph::default().normalize(&once).unwrap();
        assert_eq!(once, twice);
    }
}