egglog_ast/
generic_ast_helpers.rs

1use std::fmt::{Display, Formatter};
2use std::hash::Hash;
3
4use ordered_float::OrderedFloat;
5
6use super::util::ListDisplay;
7use crate::generic_ast::*;
8use crate::span::Span;
9
/// Generates the pair of `From` conversions between `Literal` and the payload
/// type of one of its variants.
///
/// `impl_from!(Variant(Type))` expands to:
/// - `impl From<Type> for Literal`, which wraps the value in `Literal::Variant`;
/// - `impl From<Literal> for Type`, which unwraps `Literal::Variant` and panics
///   when handed any other variant.
macro_rules! impl_from {
    ($variant:ident($payload:ty)) => {
        impl From<$payload> for Literal {
            fn from(value: $payload) -> Self {
                Literal::$variant(value)
            }
        }

        impl From<Literal> for $payload {
            /// # Panics
            ///
            /// Panics if `literal` is not the variant this conversion expects.
            fn from(literal: Literal) -> Self {
                match literal {
                    Literal::$variant(value) => value,
                    // The lint is silenced so the expansion stays warning-free
                    // even if this catch-all arm can never be reached.
                    #[allow(unreachable_patterns)]
                    other => panic!("Expected {}, got {other}", stringify!($variant)),
                }
            }
        }
    };
}
30
31impl<Head: Display, Leaf: Display> Display for GenericRule<Head, Leaf>
32where
33    Head: Clone + Display,
34    Leaf: Clone + PartialEq + Eq + Display + Hash,
35{
36    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
37        let indent = " ".repeat(7);
38        write!(f, "(rule (")?;
39        for (i, fact) in self.body.iter().enumerate() {
40            if i > 0 {
41                write!(f, "{}", indent)?;
42            }
43
44            if i != self.body.len() - 1 {
45                writeln!(f, "{}", fact)?;
46            } else {
47                write!(f, "{}", fact)?;
48            }
49        }
50        write!(f, ")\n      (")?;
51        for (i, action) in self.head.0.iter().enumerate() {
52            if i > 0 {
53                write!(f, "{}", indent)?;
54            }
55            if i != self.head.0.len() - 1 {
56                writeln!(f, "{}", action)?;
57            } else {
58                write!(f, "{}", action)?;
59            }
60        }
61        let ruleset = if !self.ruleset.is_empty() {
62            format!(":ruleset {}", self.ruleset)
63        } else {
64            "".into()
65        };
66        let name = if !self.name.is_empty() {
67            format!(":name \"{}\"", self.name)
68        } else {
69            "".into()
70        };
71        write!(f, ")\n{} {} {})", indent, ruleset, name)
72    }
73}
74
// Generate `From` conversions in both directions between `Literal` and the
// payloads of its `Int`, `Float`, and `String` variants. Converting a `Literal`
// of a different variant into one of these payload types panics.
impl_from!(Int(i64));
impl_from!(Float(OrderedFloat<f64>));
impl_from!(String(String));
79
80impl<Head: Display, Leaf: Display> Display for GenericFact<Head, Leaf> {
81    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
82        match self {
83            GenericFact::Eq(_, e1, e2) => write!(f, "(= {e1} {e2})"),
84            GenericFact::Fact(expr) => write!(f, "{expr}"),
85        }
86    }
87}
88
89// Implement Display for GenericAction
90impl<Head: Display, Leaf: Display> Display for GenericAction<Head, Leaf>
91where
92    Head: Clone + Display,
93    Leaf: Clone + PartialEq + Eq + Display + Hash,
94{
95    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
96        match self {
97            GenericAction::Let(_, lhs, rhs) => write!(f, "(let {} {})", lhs, rhs),
98            GenericAction::Set(_, lhs, args, rhs) => write!(
99                f,
100                "(set ({} {}) {})",
101                lhs,
102                args.iter()
103                    .map(|a| format!("{}", a))
104                    .collect::<Vec<_>>()
105                    .join(" "),
106                rhs
107            ),
108            GenericAction::Union(_, lhs, rhs) => write!(f, "(union {} {})", lhs, rhs),
109            GenericAction::Change(_, change, lhs, args) => {
110                let change_str = match change {
111                    Change::Delete => "delete",
112                    Change::Subsume => "subsume",
113                };
114                write!(
115                    f,
116                    "({} ({} {}))",
117                    change_str,
118                    lhs,
119                    args.iter()
120                        .map(|a| format!("{}", a))
121                        .collect::<Vec<_>>()
122                        .join(" ")
123                )
124            }
125            GenericAction::Panic(_, msg) => write!(f, "(panic \"{}\")", msg),
126            GenericAction::Expr(_, e) => write!(f, "{}", e),
127        }
128    }
129}
130
131impl<Head, Leaf> Display for GenericExpr<Head, Leaf>
132where
133    Head: Display,
134    Leaf: Display,
135{
136    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
137        match self {
138            GenericExpr::Lit(_ann, lit) => write!(f, "{lit}"),
139            GenericExpr::Var(_ann, var) => write!(f, "{var}"),
140            GenericExpr::Call(_ann, op, children) => {
141                write!(f, "({} {})", op, ListDisplay(children, " "))
142            }
143        }
144    }
145}
146
147impl<Head, Leaf> Default for GenericActions<Head, Leaf>
148where
149    Head: Clone + Display,
150    Leaf: Clone + PartialEq + Eq + Display + Hash,
151{
152    fn default() -> Self {
153        Self(vec![])
154    }
155}
156
157impl<Head, Leaf> GenericRule<Head, Leaf>
158where
159    Head: Clone + Display,
160    Leaf: Clone + PartialEq + Eq + Display + Hash,
161{
162    pub fn visit_exprs(
163        self,
164        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
165    ) -> Self {
166        Self {
167            span: self.span,
168            head: self.head.visit_exprs(f),
169            body: self
170                .body
171                .into_iter()
172                .map(|bexpr| bexpr.visit_exprs(f))
173                .collect(),
174            name: self.name.clone(),
175            ruleset: self.ruleset.clone(),
176        }
177    }
178}
179
180impl<Head, Leaf> GenericActions<Head, Leaf>
181where
182    Head: Clone + Display,
183    Leaf: Clone + PartialEq + Eq + Display + Hash,
184{
185    pub fn len(&self) -> usize {
186        self.0.len()
187    }
188
189    pub fn is_empty(&self) -> bool {
190        self.0.is_empty()
191    }
192
193    pub fn iter(&self) -> impl Iterator<Item = &GenericAction<Head, Leaf>> {
194        self.0.iter()
195    }
196
197    pub fn visit_exprs(
198        self,
199        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
200    ) -> Self {
201        Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect())
202    }
203
204    pub fn new(actions: Vec<GenericAction<Head, Leaf>>) -> Self {
205        Self(actions)
206    }
207
208    pub fn singleton(action: GenericAction<Head, Leaf>) -> Self {
209        Self(vec![action])
210    }
211}
212
213impl<Head, Leaf> GenericAction<Head, Leaf>
214where
215    Head: Clone + Display,
216    Leaf: Clone + Eq + Display + Hash,
217{
218    // Applys `f` to all expressions in the action.
219    pub fn map_exprs(
220        &self,
221        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
222    ) -> Self {
223        match self {
224            GenericAction::Let(span, lhs, rhs) => {
225                GenericAction::Let(span.clone(), lhs.clone(), f(rhs))
226            }
227            GenericAction::Set(span, lhs, args, rhs) => {
228                let right = f(rhs);
229                GenericAction::Set(
230                    span.clone(),
231                    lhs.clone(),
232                    args.iter().map(f).collect(),
233                    right,
234                )
235            }
236            GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
237                span.clone(),
238                *change,
239                lhs.clone(),
240                args.iter().map(f).collect(),
241            ),
242            GenericAction::Union(span, lhs, rhs) => {
243                GenericAction::Union(span.clone(), f(lhs), f(rhs))
244            }
245            GenericAction::Panic(span, msg) => GenericAction::Panic(span.clone(), msg.clone()),
246            GenericAction::Expr(span, e) => GenericAction::Expr(span.clone(), f(e)),
247        }
248    }
249
250    /// Applys `f` to all sub-expressions (including `self`)
251    /// bottom-up, collecting the results.
252    pub fn visit_exprs(
253        self,
254        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
255    ) -> Self {
256        match self {
257            GenericAction::Let(span, lhs, rhs) => {
258                GenericAction::Let(span, lhs.clone(), rhs.visit_exprs(f))
259            }
260            // TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
261            // This seems more natural to oflatt
262            // Currently, visit_exprs does not apply f to the first argument of Set.
263            GenericAction::Set(span, lhs, args, rhs) => {
264                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
265                GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
266            }
267            GenericAction::Change(span, change, lhs, args) => {
268                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
269                GenericAction::Change(span, change, lhs.clone(), args)
270            }
271            GenericAction::Union(span, lhs, rhs) => {
272                GenericAction::Union(span, lhs.visit_exprs(f), rhs.visit_exprs(f))
273            }
274            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
275            GenericAction::Expr(span, e) => GenericAction::Expr(span, e.visit_exprs(f)),
276        }
277    }
278
279    pub fn subst(&self, subst: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf>) -> Self {
280        self.map_exprs(&mut |e| e.subst_leaf(subst))
281    }
282
283    pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self {
284        macro_rules! fvar_expr {
285            () => {
286                |span, s: _| GenericExpr::Var(span.clone(), fvar(s.clone(), false))
287            };
288        }
289        match self {
290            GenericAction::Let(span, lhs, rhs) => {
291                let lhs = fvar(lhs, true);
292                let rhs = rhs.subst_leaf(&mut fvar_expr!());
293                GenericAction::Let(span, lhs, rhs)
294            }
295            GenericAction::Set(span, lhs, args, rhs) => {
296                let args = args
297                    .into_iter()
298                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
299                    .collect();
300                let rhs = rhs.subst_leaf(&mut fvar_expr!());
301                GenericAction::Set(span, lhs.clone(), args, rhs)
302            }
303            GenericAction::Change(span, change, lhs, args) => {
304                let args = args
305                    .into_iter()
306                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
307                    .collect();
308                GenericAction::Change(span, change, lhs.clone(), args)
309            }
310            GenericAction::Union(span, lhs, rhs) => {
311                let lhs = lhs.subst_leaf(&mut fvar_expr!());
312                let rhs = rhs.subst_leaf(&mut fvar_expr!());
313                GenericAction::Union(span, lhs, rhs)
314            }
315            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
316            GenericAction::Expr(span, e) => {
317                GenericAction::Expr(span, e.subst_leaf(&mut fvar_expr!()))
318            }
319        }
320    }
321}
322
323impl<Head, Leaf> GenericFact<Head, Leaf>
324where
325    Head: Clone + Display,
326    Leaf: Clone + PartialEq + Eq + Display + Hash,
327{
328    pub fn visit_exprs(
329        self,
330        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
331    ) -> GenericFact<Head, Leaf> {
332        match self {
333            GenericFact::Eq(span, e1, e2) => {
334                GenericFact::Eq(span, e1.visit_exprs(f), e2.visit_exprs(f))
335            }
336            GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)),
337        }
338    }
339
340    pub fn map_exprs<Head2, Leaf2>(
341        &self,
342        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head2, Leaf2>,
343    ) -> GenericFact<Head2, Leaf2> {
344        match self {
345            GenericFact::Eq(span, e1, e2) => GenericFact::Eq(span.clone(), f(e1), f(e2)),
346            GenericFact::Fact(expr) => GenericFact::Fact(f(expr)),
347        }
348    }
349
350    pub fn subst<Leaf2, Head2>(
351        &self,
352        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
353        subst_head: &mut impl FnMut(&Head) -> Head2,
354    ) -> GenericFact<Head2, Leaf2> {
355        self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head))
356    }
357}
358
359impl<Head, Leaf> GenericFact<Head, Leaf>
360where
361    Leaf: Clone + PartialEq + Eq + Display + Hash,
362    Head: Clone + Display,
363{
364    pub fn make_unresolved(self) -> GenericFact<String, String> {
365        self.subst(
366            &mut |span, v| GenericExpr::Var(span.clone(), v.to_string()),
367            &mut |h| h.to_string(),
368        )
369    }
370}
371
372impl<Head: Clone + Display, Leaf: Hash + Clone + Display + Eq> GenericExpr<Head, Leaf> {
373    pub fn span(&self) -> Span {
374        match self {
375            GenericExpr::Lit(span, _) => span.clone(),
376            GenericExpr::Var(span, _) => span.clone(),
377            GenericExpr::Call(span, _, _) => span.clone(),
378        }
379    }
380
381    pub fn is_var(&self) -> bool {
382        matches!(self, GenericExpr::Var(_, _))
383    }
384
385    pub fn get_var(&self) -> Option<Leaf> {
386        match self {
387            GenericExpr::Var(_ann, v) => Some(v.clone()),
388            _ => None,
389        }
390    }
391
392    fn children(&self) -> &[Self] {
393        match self {
394            GenericExpr::Var(_, _) | GenericExpr::Lit(_, _) => &[],
395            GenericExpr::Call(_, _, children) => children,
396        }
397    }
398
399    pub fn ast_size(&self) -> usize {
400        let mut size = 0;
401        self.walk(&mut |_e| size += 1, &mut |_| {});
402        size
403    }
404
405    pub fn walk(&self, pre: &mut impl FnMut(&Self), post: &mut impl FnMut(&Self)) {
406        pre(self);
407        self.children()
408            .iter()
409            .for_each(|child| child.walk(pre, post));
410        post(self);
411    }
412
413    pub fn fold<Out>(&self, f: &mut impl FnMut(&Self, Vec<Out>) -> Out) -> Out {
414        let ts = self.children().iter().map(|child| child.fold(f)).collect();
415        f(self, ts)
416    }
417
418    /// Applys `f` to all sub-expressions (including `self`)
419    /// bottom-up, collecting the results.
420    pub fn visit_exprs(self, f: &mut impl FnMut(Self) -> Self) -> Self {
421        match self {
422            GenericExpr::Lit(..) => f(self),
423            GenericExpr::Var(..) => f(self),
424            GenericExpr::Call(span, op, children) => {
425                let children = children.into_iter().map(|c| c.visit_exprs(f)).collect();
426                f(GenericExpr::Call(span, op.clone(), children))
427            }
428        }
429    }
430
431    /// `subst` replaces occurrences of variables and head symbols in the expression.
432    pub fn subst<Head2, Leaf2>(
433        &self,
434        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
435        subst_head: &mut impl FnMut(&Head) -> Head2,
436    ) -> GenericExpr<Head2, Leaf2> {
437        match self {
438            GenericExpr::Lit(span, lit) => GenericExpr::Lit(span.clone(), lit.clone()),
439            GenericExpr::Var(span, v) => subst_leaf(span, v),
440            GenericExpr::Call(span, op, children) => {
441                let children = children
442                    .iter()
443                    .map(|c| c.subst(subst_leaf, subst_head))
444                    .collect();
445                GenericExpr::Call(span.clone(), subst_head(op), children)
446            }
447        }
448    }
449
450    pub fn subst_leaf<Leaf2>(
451        &self,
452        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf2>,
453    ) -> GenericExpr<Head, Leaf2> {
454        self.subst(subst_leaf, &mut |x| x.clone())
455    }
456
457    pub fn vars(&self) -> impl Iterator<Item = Leaf> + '_ {
458        let iterator: Box<dyn Iterator<Item = Leaf>> = match self {
459            GenericExpr::Lit(_ann, _l) => Box::new(std::iter::empty()),
460            GenericExpr::Var(_ann, v) => Box::new(std::iter::once(v.clone())),
461            GenericExpr::Call(_ann, _head, exprs) => Box::new(exprs.iter().flat_map(|e| e.vars())),
462        };
463        iterator
464    }
465}
466
467impl Display for Literal {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        match &self {
470            Literal::Int(i) => Display::fmt(i, f),
471            Literal::Float(n) => {
472                // need to display with decimal if there is none
473                let str = n.to_string();
474                if let Ok(_num) = str.parse::<i64>() {
475                    write!(f, "{}.0", str)
476                } else {
477                    write!(f, "{}", str)
478                }
479            }
480            Literal::Bool(b) => Display::fmt(b, f),
481            Literal::String(s) => write!(f, "\"{}\"", s),
482            Literal::Unit => write!(f, "()"),
483        }
484    }
485}