1use egg::{
9 define_language, rewrite, AstSize, EGraph, Extractor, Id, RecExpr, Rewrite, Runner, Symbol,
10};
11
12use refrain_core::{Op, Pattern, Refrain, RefrainError, Result};
13
14define_language! {
15 pub enum RefrainLang {
17 "note" = Note([Id; 2]), "loop" = Loop([Id; 2]), "dy/dx" = Diff([Id; 2]), "quotient" = Quotient(Box<[Id]>),
21 "seq" = Seq(Box<[Id]>),
22 Num(u32),
23 Sym(Symbol),
24 }
25}
26
27fn rules() -> Vec<Rewrite<RefrainLang, ()>> {
29 vec![
30 rewrite!("loop-1-identity"; "(loop 1 ?x)" => "?x"),
31 rewrite!("seq-singleton-identity"; "(seq ?x)" => "?x"),
32 ]
33}
34
35pub struct Egraph {
36 rules: Vec<Rewrite<RefrainLang, ()>>,
37 node_limit: usize,
38 iter_limit: usize,
39}
40
41impl Egraph {
42 pub fn new() -> Self {
43 Self {
44 rules: rules(),
45 node_limit: 10_000,
46 iter_limit: 32,
47 }
48 }
49
50 pub fn with_limits(node_limit: usize, iter_limit: usize) -> Self {
51 Self {
52 rules: rules(),
53 node_limit,
54 iter_limit,
55 }
56 }
57
58 pub fn normalize(&self, r: &Refrain) -> Result<Refrain> {
59 let mut out = Refrain::new(r.name.clone());
60 out.territorialize = match &r.territorialize {
61 Some(p) => Some(self.normalize_pattern(p)?),
62 None => None,
63 };
64 out.deterritorialize = match &r.deterritorialize {
65 Some(p) => Some(self.normalize_pattern(p)?),
66 None => None,
67 };
68 out.reterritorialize = match &r.reterritorialize {
69 Some(p) => Some(self.normalize_pattern(p)?),
70 None => None,
71 };
72 Ok(out)
73 }
74
75 pub fn normalize_pattern(&self, p: &Pattern) -> Result<Pattern> {
76 let mut expr = RecExpr::default();
77 let _ = pattern_to_expr(p, &mut expr);
78 let runner: Runner<RefrainLang, ()> = Runner::default()
79 .with_node_limit(self.node_limit)
80 .with_iter_limit(self.iter_limit)
81 .with_expr(&expr);
82 let runner = runner.run(&self.rules);
83 let extractor = Extractor::new(&runner.egraph, AstSize);
84 let root_id = runner.roots[0];
85 let (_cost, best) = extractor.find_best(root_id);
86 expr_to_pattern(&best)
87 }
88}
89
90impl Default for Egraph {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96fn pattern_to_expr(p: &Pattern, b: &mut RecExpr<RefrainLang>) -> Id {
97 match p {
98 Pattern::Op(Op::Note { pitch, dur }) => {
99 let ps = b.add(RefrainLang::Sym(Symbol::from(pitch.as_str())));
100 let ds = b.add(RefrainLang::Sym(Symbol::from(dur.as_str())));
101 b.add(RefrainLang::Note([ps, ds]))
102 }
103 Pattern::Op(Op::Loop { count, body }) => {
104 let n = b.add(RefrainLang::Num(*count));
105 let body_id = pattern_to_expr(body, b);
106 b.add(RefrainLang::Loop([n, body_id]))
107 }
108 Pattern::Op(Op::Diff { x, t }) => {
109 let xs = b.add(RefrainLang::Sym(Symbol::from(x.as_str())));
110 let ts = b.add(RefrainLang::Sym(Symbol::from(t.as_str())));
111 b.add(RefrainLang::Diff([xs, ts]))
112 }
113 Pattern::Op(Op::Quotient { rels }) => {
114 let ids: Vec<Id> = rels
115 .iter()
116 .map(|s| b.add(RefrainLang::Sym(Symbol::from(s.as_str()))))
117 .collect();
118 b.add(RefrainLang::Quotient(ids.into_boxed_slice()))
119 }
120 Pattern::Op(Op::Sym(s)) => b.add(RefrainLang::Sym(Symbol::from(s.as_str()))),
121 Pattern::Op(Op::Call { head, args }) => {
122 let h = b.add(RefrainLang::Sym(Symbol::from(head.as_str())));
123 let mut ids = vec![h];
124 for a in args {
125 ids.push(pattern_to_expr(a, b));
126 }
127 b.add(RefrainLang::Seq(ids.into_boxed_slice()))
128 }
129 Pattern::Seq(items) => {
130 let ids: Vec<Id> = items.iter().map(|p| pattern_to_expr(p, b)).collect();
131 b.add(RefrainLang::Seq(ids.into_boxed_slice()))
132 }
133 }
134}
135
136fn expr_to_pattern(expr: &RecExpr<RefrainLang>) -> Result<Pattern> {
137 let nodes = expr.as_ref();
138 let root = Id::from(nodes.len() - 1);
139 node_to_pattern(nodes, root)
140}
141
142fn node_to_pattern(nodes: &[RefrainLang], id: Id) -> Result<Pattern> {
143 let n = &nodes[usize::from(id)];
144 match n {
145 RefrainLang::Note([p, d]) => {
146 let pitch = sym_at(nodes, *p)?.to_string();
147 let dur = sym_at(nodes, *d)?.to_string();
148 Ok(Pattern::Op(Op::Note { pitch, dur }))
149 }
150 RefrainLang::Loop([c, body]) => {
151 let count = num_at(nodes, *c)?;
152 let body_pat = node_to_pattern(nodes, *body)?;
153 Ok(Pattern::Op(Op::Loop {
154 count,
155 body: Box::new(body_pat),
156 }))
157 }
158 RefrainLang::Diff([x, t]) => {
159 let xs = sym_at(nodes, *x)?.to_string();
160 let ts = sym_at(nodes, *t)?.to_string();
161 Ok(Pattern::Op(Op::Diff { x: xs, t: ts }))
162 }
163 RefrainLang::Quotient(ids) => {
164 let mut rels = Vec::with_capacity(ids.len());
165 for i in ids.iter() {
166 rels.push(sym_at(nodes, *i)?.to_string());
167 }
168 Ok(Pattern::Op(Op::Quotient { rels }))
169 }
170 RefrainLang::Seq(ids) => {
171 let mut items = Vec::with_capacity(ids.len());
172 for i in ids.iter() {
173 items.push(node_to_pattern(nodes, *i)?);
174 }
175 if items.len() == 1 {
176 Ok(items.into_iter().next().unwrap())
177 } else {
178 Ok(Pattern::Seq(items))
179 }
180 }
181 RefrainLang::Sym(s) => Ok(Pattern::Op(Op::Sym(s.as_str().to_string()))),
182 RefrainLang::Num(n) => Ok(Pattern::Op(Op::Sym(n.to_string()))),
183 }
184}
185
186fn sym_at(nodes: &[RefrainLang], id: Id) -> Result<&str> {
187 match &nodes[usize::from(id)] {
188 RefrainLang::Sym(s) => Ok(s.as_str()),
189 other => Err(RefrainError::Rewrite(format!(
190 "expected symbol, got {:?}",
191 other
192 ))),
193 }
194}
195
196fn num_at(nodes: &[RefrainLang], id: Id) -> Result<u32> {
197 match &nodes[usize::from(id)] {
198 RefrainLang::Num(n) => Ok(*n),
199 other => Err(RefrainError::Rewrite(format!(
200 "expected number, got {:?}",
201 other
202 ))),
203 }
204}
205
206pub fn empty_egraph() -> EGraph<RefrainLang, ()> {
209 EGraph::default()
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use refrain_core::parse;

    /// Parses `src` and returns `(original, normalized)`.
    fn parse_and_normalize(src: &str) -> (Refrain, Refrain) {
        let original = parse(src).unwrap();
        let normalized = Egraph::default().normalize(&original).unwrap();
        (original, normalized)
    }

    /// Asserts that normalizing `src` leaves the refrain unchanged.
    fn assert_fixed_point(src: &str) {
        let (original, normalized) = parse_and_normalize(src);
        assert_eq!(normalized, original);
    }

    #[test]
    fn loop_one_collapses_to_body() {
        let (_, n) = parse_and_normalize("(refrain a (territorialize (loop 1 (note C4 q))))");
        let pat = n.territorialize.as_ref().unwrap();
        if let Pattern::Op(Op::Note { pitch, dur }) = pat {
            assert_eq!(pitch, "C4");
            assert_eq!(dur, "q");
        } else {
            panic!("expected Note, got {:?}", pat);
        }
    }

    #[test]
    fn loop_two_stays() {
        let (_, n) = parse_and_normalize("(refrain a (territorialize (loop 2 (note C4 q))))");
        let pat = n.territorialize.as_ref().unwrap();
        if let Pattern::Op(Op::Loop { count, .. }) = pat {
            assert_eq!(*count, 2);
        } else {
            panic!("expected Loop, got {:?}", pat);
        }
    }

    #[test]
    fn note_stays_unchanged() {
        assert_fixed_point("(refrain a (territorialize (note G4 e)))");
    }

    #[test]
    fn diff_stays_unchanged() {
        assert_fixed_point("(refrain a (deterritorialize (dy/dx intensity time)))");
    }

    #[test]
    fn quotient_stays_unchanged() {
        assert_fixed_point("(refrain a (reterritorialize (quotient ~a ~b)))");
    }

    #[test]
    fn empty_refrain_normalizes_to_itself() {
        assert_fixed_point("(refrain empty)");
    }

    #[test]
    fn normalize_is_idempotent() {
        let (_, once) =
            parse_and_normalize("(refrain a (territorialize (loop 1 (loop 1 (note C4 q)))))");
        let twice = Egraph::default().normalize(&once).unwrap();
        assert_eq!(once, twice);
    }
}
275}