egglog/ast/
mod.rs

1pub mod check_shadowing;
2pub mod desugar;
3mod expr;
4mod parse;
5pub mod remove_globals;
6
7use crate::core::{GenericAtom, GenericAtomTerm, HeadOrEq, Query, ResolvedCall};
8use crate::*;
9pub use expr::*;
10pub use parse::*;
11pub use symbol_table::GlobalSymbol as Symbol;
12
#[derive(Clone, Debug)]
/// The egglog internal representation of already compiled rules.
///
/// A ruleset either holds compiled rules directly ([`Ruleset::Rules`]) or
/// refers to other rulesets by symbol ([`Ruleset::Combined`]).
pub(crate) enum Ruleset {
    /// Represents a ruleset with a set of rules.
    /// Use an [`IndexMap`] to ensure egglog is deterministic.
    /// Rules added to the [`IndexMap`] first apply their
    /// actions first.
    ///
    /// The leading [`Symbol`] is presumably the ruleset's own name, and the
    /// map keys presumably the rule names (TODO confirm against callers).
    Rules(Symbol, IndexMap<Symbol, CompiledRule>),
    /// A combined ruleset may contain other rulesets.
    ///
    /// The contained rulesets are referenced by [`Symbol`] rather than stored inline.
    Combined(Symbol, Vec<Symbol>),
}
24
/// [`NCommand`] is the specialization of [`GenericNCommand`] that uses plain
/// [`Symbol`]s for both heads and leaves.
pub type NCommand = GenericNCommand<Symbol, Symbol>;
/// [`ResolvedNCommand`] is another specialization of [`GenericNCommand`], which
/// adds the type information to heads and leaves of commands.
/// [`TypeInfo::typecheck_command`] turns an [`NCommand`] into a [`ResolvedNCommand`].
pub(crate) type ResolvedNCommand = GenericNCommand<ResolvedCall, ResolvedVar>;
30
/// A [`NCommand`] is a desugared [`Command`], where syntactic sugars
/// like [`Command::Datatype`] and [`Command::Rewrite`]
/// are eliminated.
/// Most of the heavy lifting in egglog is done over [`NCommand`]s.
///
/// [`GenericNCommand`] is a generalization of [`NCommand`], like how [`GenericCommand`]
/// is a generalization of [`Command`], allowing annotations over `Head` and `Leaf`.
///
/// Each variant can be turned back into a [`GenericCommand`] with
/// [`GenericNCommand::to_command`]; the per-variant notes below name the
/// command variant it maps to.
///
/// TODO: The name "NCommand" used to denote normalized command, but this
/// meaning is obsolete. A future PR should rename this type to something
/// like "DCommand".
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum GenericNCommand<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Maps to [`GenericCommand::SetOption`].
    SetOption {
        name: Symbol,
        value: GenericExpr<Head, Leaf>,
    },
    /// Maps to [`GenericCommand::Sort`].
    /// Note that the optional sort arguments always use [`Symbol`] heads and
    /// leaves, independent of `Head` and `Leaf`.
    Sort(
        Span,
        Symbol,
        Option<(Symbol, Vec<GenericExpr<Symbol, Symbol>>)>,
    ),
    /// A function declaration of any [`FunctionSubtype`]. Maps to
    /// [`GenericCommand::Constructor`], [`GenericCommand::Relation`] or
    /// [`GenericCommand::Function`] depending on the declaration's subtype.
    Function(GenericFunctionDecl<Head, Leaf>),
    /// Maps to [`GenericCommand::AddRuleset`].
    AddRuleset(Span, Symbol),
    /// Maps to [`GenericCommand::UnstableCombinedRuleset`].
    UnstableCombinedRuleset(Span, Symbol, Vec<Symbol>),
    /// Maps to [`GenericCommand::Rule`].
    NormRule {
        name: Symbol,
        ruleset: Symbol,
        rule: GenericRule<Head, Leaf>,
    },
    /// Maps to [`GenericCommand::Action`].
    CoreAction(GenericAction<Head, Leaf>),
    /// Maps to [`GenericCommand::RunSchedule`].
    RunSchedule(GenericSchedule<Head, Leaf>),
    /// Maps to [`GenericCommand::PrintOverallStatistics`].
    PrintOverallStatistics,
    /// Maps to [`GenericCommand::Check`].
    Check(Span, Vec<GenericFact<Head, Leaf>>),
    /// Maps to [`GenericCommand::PrintFunction`].
    PrintTable(Span, Symbol, usize),
    /// Maps to [`GenericCommand::PrintSize`].
    PrintSize(Span, Option<Symbol>),
    /// Maps to [`GenericCommand::Output`].
    Output {
        span: Span,
        file: String,
        exprs: Vec<GenericExpr<Head, Leaf>>,
    },
    /// Maps to [`GenericCommand::Push`].
    Push(usize),
    /// Maps to [`GenericCommand::Pop`].
    Pop(Span, usize),
    /// Wraps another desugared command. Maps to [`GenericCommand::Fail`].
    Fail(Span, Box<GenericNCommand<Head, Leaf>>),
    /// Maps to [`GenericCommand::Input`].
    Input {
        span: Span,
        name: Symbol,
        file: String,
    },
}
85
86impl<Head, Leaf> GenericNCommand<Head, Leaf>
87where
88    Head: Clone + Display,
89    Leaf: Clone + PartialEq + Eq + Display + Hash,
90{
91    pub fn to_command(&self) -> GenericCommand<Head, Leaf> {
92        match self {
93            GenericNCommand::SetOption { name, value } => GenericCommand::SetOption {
94                name: *name,
95                value: value.clone(),
96            },
97            GenericNCommand::Sort(span, name, params) => {
98                GenericCommand::Sort(span.clone(), *name, params.clone())
99            }
100            // This is awkward for the three subtypes change
101            GenericNCommand::Function(f) => match f.subtype {
102                FunctionSubtype::Constructor => GenericCommand::Constructor {
103                    span: f.span.clone(),
104                    name: f.name,
105                    schema: f.schema.clone(),
106                    cost: f.cost,
107                    unextractable: f.unextractable,
108                },
109                FunctionSubtype::Relation => GenericCommand::Relation {
110                    span: f.span.clone(),
111                    name: f.name,
112                    inputs: f.schema.input.clone(),
113                },
114                FunctionSubtype::Custom => GenericCommand::Function {
115                    span: f.span.clone(),
116                    schema: f.schema.clone(),
117                    name: f.name,
118                    merge: f.merge.clone(),
119                },
120            },
121            GenericNCommand::AddRuleset(span, name) => {
122                GenericCommand::AddRuleset(span.clone(), *name)
123            }
124            GenericNCommand::UnstableCombinedRuleset(span, name, others) => {
125                GenericCommand::UnstableCombinedRuleset(span.clone(), *name, others.clone())
126            }
127            GenericNCommand::NormRule {
128                name,
129                ruleset,
130                rule,
131            } => GenericCommand::Rule {
132                name: *name,
133                ruleset: *ruleset,
134                rule: rule.clone(),
135            },
136            GenericNCommand::RunSchedule(schedule) => GenericCommand::RunSchedule(schedule.clone()),
137            GenericNCommand::PrintOverallStatistics => GenericCommand::PrintOverallStatistics,
138            GenericNCommand::CoreAction(action) => GenericCommand::Action(action.clone()),
139            GenericNCommand::Check(span, facts) => {
140                GenericCommand::Check(span.clone(), facts.clone())
141            }
142            GenericNCommand::PrintTable(span, name, n) => {
143                GenericCommand::PrintFunction(span.clone(), *name, *n)
144            }
145            GenericNCommand::PrintSize(span, name) => {
146                GenericCommand::PrintSize(span.clone(), *name)
147            }
148            GenericNCommand::Output { span, file, exprs } => GenericCommand::Output {
149                span: span.clone(),
150                file: file.to_string(),
151                exprs: exprs.clone(),
152            },
153            GenericNCommand::Push(n) => GenericCommand::Push(*n),
154            GenericNCommand::Pop(span, n) => GenericCommand::Pop(span.clone(), *n),
155            GenericNCommand::Fail(span, cmd) => {
156                GenericCommand::Fail(span.clone(), Box::new(cmd.to_command()))
157            }
158            GenericNCommand::Input { span, name, file } => GenericCommand::Input {
159                span: span.clone(),
160                name: *name,
161                file: file.clone(),
162            },
163        }
164    }
165
166    pub fn visit_exprs(
167        self,
168        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
169    ) -> Self {
170        match self {
171            GenericNCommand::SetOption { name, value } => GenericNCommand::SetOption {
172                name,
173                value: f(value.clone()),
174            },
175            GenericNCommand::Sort(span, name, params) => GenericNCommand::Sort(span, name, params),
176            GenericNCommand::Function(func) => GenericNCommand::Function(func.visit_exprs(f)),
177            GenericNCommand::AddRuleset(span, name) => GenericNCommand::AddRuleset(span, name),
178            GenericNCommand::UnstableCombinedRuleset(span, name, rulesets) => {
179                GenericNCommand::UnstableCombinedRuleset(span, name, rulesets)
180            }
181            GenericNCommand::NormRule {
182                name,
183                ruleset,
184                rule,
185            } => GenericNCommand::NormRule {
186                name,
187                ruleset,
188                rule: rule.visit_exprs(f),
189            },
190            GenericNCommand::RunSchedule(schedule) => {
191                GenericNCommand::RunSchedule(schedule.visit_exprs(f))
192            }
193            GenericNCommand::PrintOverallStatistics => GenericNCommand::PrintOverallStatistics,
194            GenericNCommand::CoreAction(action) => {
195                GenericNCommand::CoreAction(action.visit_exprs(f))
196            }
197            GenericNCommand::Check(span, facts) => GenericNCommand::Check(
198                span,
199                facts.into_iter().map(|fact| fact.visit_exprs(f)).collect(),
200            ),
201            GenericNCommand::PrintTable(span, name, n) => {
202                GenericNCommand::PrintTable(span, name, n)
203            }
204            GenericNCommand::PrintSize(span, name) => GenericNCommand::PrintSize(span, name),
205            GenericNCommand::Output { span, file, exprs } => GenericNCommand::Output {
206                span,
207                file,
208                exprs: exprs.into_iter().map(f).collect(),
209            },
210            GenericNCommand::Push(n) => GenericNCommand::Push(n),
211            GenericNCommand::Pop(span, n) => GenericNCommand::Pop(span, n),
212            GenericNCommand::Fail(span, cmd) => {
213                GenericNCommand::Fail(span, Box::new(cmd.visit_exprs(f)))
214            }
215            GenericNCommand::Input { span, name, file } => {
216                GenericNCommand::Input { span, name, file }
217            }
218        }
219    }
220}
221
/// A [`GenericSchedule`] whose heads and leaves are plain [`Symbol`]s.
pub type Schedule = GenericSchedule<Symbol, Symbol>;
/// A [`GenericSchedule`] whose heads are [`ResolvedCall`]s and whose leaves
/// are [`ResolvedVar`]s.
pub(crate) type ResolvedSchedule = GenericSchedule<ResolvedCall, ResolvedVar>;
224
/// A tree describing how rulesets are run.
///
/// Every variant carries a [`Span`]. The surface syntax noted on each variant
/// is the one produced by this type's `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GenericSchedule<Head, Leaf> {
    /// Displayed as `(saturate <schedule>)`.
    Saturate(Span, Box<GenericSchedule<Head, Leaf>>),
    /// Displayed as `(repeat <n> <schedule>)`.
    Repeat(Span, usize, Box<GenericSchedule<Head, Leaf>>),
    /// A single run configuration; displayed using the
    /// [`GenericRunConfig`]'s own `Display` implementation.
    Run(Span, GenericRunConfig<Head, Leaf>),
    /// Displayed as `(seq <schedule> ...)`.
    Sequence(Span, Vec<GenericSchedule<Head, Leaf>>),
}
232
233impl<Head, Leaf> GenericSchedule<Head, Leaf>
234where
235    Head: Clone + Display,
236    Leaf: Clone + PartialEq + Eq + Display + Hash,
237{
238    fn visit_exprs(
239        self,
240        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
241    ) -> Self {
242        match self {
243            GenericSchedule::Saturate(span, sched) => {
244                GenericSchedule::Saturate(span, Box::new(sched.visit_exprs(f)))
245            }
246            GenericSchedule::Repeat(span, size, sched) => {
247                GenericSchedule::Repeat(span, size, Box::new(sched.visit_exprs(f)))
248            }
249            GenericSchedule::Run(span, config) => GenericSchedule::Run(span, config.visit_exprs(f)),
250            GenericSchedule::Sequence(span, scheds) => GenericSchedule::Sequence(
251                span,
252                scheds.into_iter().map(|s| s.visit_exprs(f)).collect(),
253            ),
254        }
255    }
256}
257
258impl<Head: Display, Leaf: Display> Display for GenericSchedule<Head, Leaf> {
259    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
260        match self {
261            GenericSchedule::Saturate(_ann, sched) => write!(f, "(saturate {sched})"),
262            GenericSchedule::Repeat(_ann, size, sched) => write!(f, "(repeat {size} {sched})"),
263            GenericSchedule::Run(_ann, config) => write!(f, "{config}"),
264            GenericSchedule::Sequence(_ann, scheds) => {
265                write!(f, "(seq {})", ListDisplay(scheds, " "))
266            }
267        }
268    }
269}
270
/// A [`GenericCommand`] whose heads and leaves are plain [`Symbol`]s.
pub type Command = GenericCommand<Symbol, Symbol>;

/// Boolean flag carried by [`GenericCommand::Rewrite`]; see the `:subsume`
/// notes in that variant's documentation.
pub type Subsume = bool;
274
/// The body of one entry in a [`GenericCommand::Datatypes`] command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Subdatatypes {
    /// The entry lists variants; displayed as `(<name> <variant> ...)`.
    Variants(Vec<Variant>),
    /// The entry declares a sort from a head symbol and argument expressions;
    /// displayed as `(sort <name> (<head> <arg> ...))`.
    NewSort(Symbol, Vec<Expr>),
}
280
/// A [`Command`] is the top-level construct in egglog.
/// It includes defining rules, declaring functions,
/// adding to tables, and running rules (via a [`Schedule`]).
#[derive(Debug, Clone)]
pub enum GenericCommand<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Egglog supports several *experimental* options
    /// that can be set using the `set-option` command.
    ///
    /// Options supported include:
    /// - "interactive_mode" (default: false): when enabled, egglog prints "(done)" after each command, allowing an external
    ///   tool to know when each command has finished running.
    SetOption {
        name: Symbol,
        value: GenericExpr<Head, Leaf>,
    },

    /// Create a new user-defined sort, which can then
    /// be used in new [`Command::Function`] declarations.
    /// The [`Command::Datatype`] command desugars directly to this command, with one [`Command::Function`]
    /// per constructor.
    /// The main use of this command (as opposed to using [`Command::Datatype`]) is for forward-declaring a sort for mutually-recursive datatypes.
    ///
    /// It can also be used to create
    /// a container sort.
    /// For example, here's how to make a sort for vectors
    /// of some user-defined sort `Math`:
    /// ```text
    /// (sort MathVec (Vec Math))
    /// ```
    ///
    /// Now `MathVec` can be used as an input or output sort.
    Sort(Span, Symbol, Option<(Symbol, Vec<Expr>)>),

    /// Egglog supports three types of functions:
    ///
    /// - A constructor models an egg-style user-defined datatype.
    ///   It can only be defined through the `datatype`/`datatype*` command
    ///   or the `constructor` command.
    /// - A relation models a datalog-style mathematical relation.
    ///   It can only be defined through the `relation` command.
    /// - A custom function is a dictionary.
    ///   It can only be defined through the `function` command.
    ///
    /// The `datatype` command declares a user-defined datatype.
    /// Datatypes can be unioned with [`Action::Union`] either
    /// at the top level or in the actions of a rule.
    /// This makes them equal in the implicit, global equality relation.
    ///
    /// Example:
    /// ```text
    /// (datatype Math
    ///   (Num i64)
    ///   (Var String)
    ///   (Add Math Math)
    ///   (Mul Math Math))
    /// ```
    ///
    /// defines a simple `Math` datatype with variants for numbers, named variables, addition and multiplication.
    ///
    /// Datatypes desugar directly to a [`Command::Sort`] and a [`Command::Constructor`] for each constructor.
    /// The code above becomes:
    /// ```text
    /// (sort Math)
    /// (constructor Num (i64) Math)
    /// (constructor Var (String) Math)
    /// (constructor Add (Math Math) Math)
    /// (constructor Mul (Math Math) Math)
    /// ```
    ///
    /// Datatypes are also known as algebraic data types, tagged unions and sum types.
    Datatype {
        span: Span,
        name: Symbol,
        variants: Vec<Variant>,
    },
    /// Declares several datatypes in a single command; displayed as
    /// `(datatype* ...)`.
    /// Each entry pairs a name with either a list of variants or a new sort
    /// (see [`Subdatatypes`]).
    Datatypes {
        span: Span,
        datatypes: Vec<(Span, Symbol, Subdatatypes)>,
    },

    /// The `constructor` command defines a new constructor for a user-defined datatype.
    /// Example:
    /// ```text
    /// (constructor Add (i64 i64) Math)
    /// ```
    ///
    Constructor {
        span: Span,
        name: Symbol,
        schema: Schema,
        cost: Option<usize>,
        unextractable: bool,
    },

    /// The `relation` command declares a named relation.
    /// Example:
    /// ```text
    /// (relation path (i64 i64))
    /// (relation edge (i64 i64))
    /// ```
    Relation {
        span: Span,
        name: Symbol,
        inputs: Vec<Symbol>,
    },

    /// The `function` command declares an egglog custom function, which is a database table with
    /// a functional dependency (also called a primary key) on its inputs to one output.
    ///
    /// ```text
    /// (function <name:Ident> <schema:Schema> <cost:Cost>
    ///        (:on_merge <List<Action>>)?
    ///        (:merge <Expr>)?)
    /// ```
    /// A function can have a `cost` for extraction.
    ///
    /// Finally, it can have a `merge` and `on_merge`, which are triggered when
    /// the functional dependency is violated.
    /// In this case, the merge expression determines which of the two outputs
    /// for the same input is used.
    /// The `on_merge` actions are run after the merge expression is evaluated.
    ///
    /// NOTE(review): the grammar above mentions `<cost>` and `:on_merge`, but
    /// this variant only carries `name`, `schema` and `merge` — confirm that
    /// this documentation is still up to date.
    ///
    /// Note that the `:merge` expression must be monotonic
    /// for the behavior of the egglog program to be consistent and defined.
    /// In other words, the merge function must define a lattice on the output of the function.
    /// If values are merged in different orders, they should still result in the same output.
    /// If the merge expression is not monotonic, the behavior can vary as
    /// actions may be applied more than once with different results.
    ///
    /// ```text
    /// (function LowerBound (Math) i64 :merge (max old new))
    /// ```
    ///
    /// Specifically, a custom function can also have an EqSort output type:
    ///
    /// ```text
    /// (function Add (i64 i64) Math)
    /// ```
    ///
    /// All functions can be `set`
    /// with [`Action::Set`].
    ///
    /// Output of a function, if being the EqSort type, can be unioned with [`Action::Union`]
    /// with another datatype of the same `sort`.
    ///
    Function {
        span: Span,
        name: Symbol,
        schema: Schema,
        merge: Option<GenericExpr<Head, Leaf>>,
    },

    /// Using the `ruleset` command, defines a new
    /// ruleset that can be added to in [`Command::Rule`]s.
    /// Rulesets are used to group rules together
    /// so that they can be run together in a [`Schedule`].
    ///
    /// Ruleset allows users to define a ruleset - a set of rules.
    ///
    /// Example:
    /// ```text
    /// (ruleset myrules)
    /// (rule ((edge x y))
    ///       ((path x y))
    ///       :ruleset myrules)
    /// (run myrules 2)
    /// ```
    AddRuleset(Span, Symbol),
    /// Using the `unstable-combined-ruleset` command, construct another ruleset
    /// which runs all the rules in the given rulesets.
    /// This is useful for running multiple rulesets together.
    /// The combined ruleset also inherits any rules added to the individual rulesets
    /// after the combined ruleset is declared.
    ///
    /// Example:
    /// ```text
    /// (ruleset myrules1)
    /// (rule ((edge x y))
    ///       ((path x y))
    ///      :ruleset myrules1)
    /// (ruleset myrules2)
    /// (rule ((path x y) (edge y z))
    ///       ((path x z))
    ///       :ruleset myrules2)
    /// (unstable-combined-ruleset myrules-combined myrules1 myrules2)
    /// ```
    UnstableCombinedRuleset(Span, Symbol, Vec<Symbol>),
    /// ```text
    /// (rule <body:List<Fact>> <head:List<Action>>)
    /// ```
    ///
    /// defines an egglog rule.
    /// The rule matches a list of facts with respect to
    /// the global database, and runs the list of actions
    /// for each match.
    /// The matches are done *modulo equality*, meaning
    /// equal datatypes in the database are considered
    /// equal.
    ///
    /// Example:
    /// ```text
    /// (rule ((edge x y))
    ///       ((path x y)))
    ///
    /// (rule ((path x y) (edge y z))
    ///       ((path x z)))
    /// ```
    Rule {
        name: Symbol,
        ruleset: Symbol,
        rule: GenericRule<Head, Leaf>,
    },
    /// `rewrite` is syntactic sugar for a specific form of `rule`
    /// which simply unions the left and right hand sides.
    ///
    /// Example:
    /// ```text
    /// (rewrite (Add a b)
    ///          (Add b a))
    /// ```
    ///
    /// Desugars to:
    /// ```text
    /// (rule ((= lhs (Add a b)))
    ///       ((union lhs (Add b a))))
    /// ```
    ///
    /// Additionally, additional facts can be specified
    /// using a `:when` clause.
    /// For example, the same rule can be run only
    /// when `a` is zero:
    ///
    /// ```text
    /// (rewrite (Add a b)
    ///          (Add b a)
    ///          :when ((= a (Num 0))))
    /// ```
    ///
    /// Add the `:subsume` flag to cause the left hand side to be subsumed after matching, which means it can
    /// no longer be matched in a rule, but can still be checked against (See [`Change`] for more details.)
    ///
    /// ```text
    /// (rewrite (Mul a 2) (bitshift-left a 1) :subsume)
    /// ```
    ///
    /// Desugars to:
    /// ```text
    /// (rule ((= lhs (Mul a 2)))
    ///       ((union lhs (bitshift-left a 1))
    ///        (subsume (Mul a 2))))
    /// ```
    Rewrite(Symbol, GenericRewrite<Head, Leaf>, Subsume),
    /// Similar to [`Command::Rewrite`], but
    /// generates two rules, one for each direction.
    ///
    /// Example:
    /// ```text
    /// (bi-rewrite (Mul (Var x) (Num 0))
    ///             (Var x))
    /// ```
    ///
    /// Becomes:
    /// ```text
    /// (rule ((= lhs (Mul (Var x) (Num 0))))
    ///       ((union lhs (Var x))))
    /// (rule ((= lhs (Var x)))
    ///       ((union lhs (Mul (Var x) (Num 0)))))
    /// ```
    BiRewrite(Symbol, GenericRewrite<Head, Leaf>),
    /// Perform an [`Action`] on the global database
    /// (see documentation for [`Action`] for more details).
    /// Example:
    /// ```text
    /// (let xplusone (Add (Var "x") (Num 1)))
    /// ```
    Action(GenericAction<Head, Leaf>),
    /// Runs a [`Schedule`], which specifies
    /// rulesets and the number of times to run them.
    ///
    /// Example:
    /// ```text
    /// (run-schedule
    ///     (saturate my-ruleset-1)
    ///     (run my-ruleset-2 4))
    /// ```
    ///
    /// Runs `my-ruleset-1` until saturation,
    /// then runs `my-ruleset-2` four times.
    ///
    /// See [`Schedule`] for more details.
    RunSchedule(GenericSchedule<Head, Leaf>),
    /// Print runtime statistics about rules
    /// and rulesets so far.
    PrintOverallStatistics,
    // TODO provide simplify docs
    /// Displayed as `(simplify <schedule> <expr>)`.
    Simplify {
        span: Span,
        expr: GenericExpr<Head, Leaf>,
        schedule: GenericSchedule<Head, Leaf>,
    },
    /// The `query-extract` command runs a query,
    /// extracting the result for each match that it finds.
    /// For a simpler extraction command, use [`Action::Extract`] instead.
    ///
    /// Example:
    /// ```text
    /// (query-extract (Add a b))
    /// ```
    ///
    /// Extracts every `Add` term in the database, once
    /// for each class of equivalent `a` and `b`.
    ///
    /// The resulting datatype is chosen from the egraph
    /// as the smallest term by size (taking into account
    /// the `:cost` annotations for each constructor).
    /// This cost does *not* take into account common sub-expressions.
    /// For example, the following term has cost 5:
    /// ```text
    /// (Add
    ///     (Num 1)
    ///     (Num 1))
    /// ```
    ///
    /// Under the hood, this command is implemented with the [`EGraph::extract`]
    /// function.
    QueryExtract {
        span: Span,
        variants: usize,
        expr: GenericExpr<Head, Leaf>,
    },
    /// The `check` command checks that the given facts
    /// match at least once in the current database.
    /// The list of facts is matched in the same way a [`Command::Rule`] is matched.
    ///
    /// Example:
    ///
    /// ```text
    /// (check (= (+ 1 2) 3))
    /// (check (<= 0 3) (>= 3 0))
    /// (fail (check (= 1 2)))
    /// ```
    ///
    /// prints
    ///
    /// ```text
    /// [INFO ] Checked.
    /// [INFO ] Checked.
    /// [ERROR] Check failed
    /// [INFO ] Command failed as expected.
    /// ```
    Check(Span, Vec<GenericFact<Head, Leaf>>),
    /// Print out rows of a given function, extracting each of the elements of the function.
    /// Example:
    /// ```text
    /// (print-function Add 20)
    /// ```
    /// prints the first 20 rows of the `Add` function.
    ///
    PrintFunction(Span, Symbol, usize),
    /// Print out the number of rows in a function or all functions.
    PrintSize(Span, Option<Symbol>),
    /// Input a CSV file directly into a function.
    Input {
        span: Span,
        name: Symbol,
        file: String,
    },
    /// Extract and output a set of expressions to a file.
    Output {
        span: Span,
        file: String,
        exprs: Vec<GenericExpr<Head, Leaf>>,
    },
    /// `push` the current egraph `n` times so that it is saved.
    /// Later, the current database and rules can be restored using `pop`.
    Push(usize),
    /// `pop` the current egraph, restoring the previous one.
    /// The argument specifies how many egraphs to pop.
    Pop(Span, usize),
    /// Assert that a command fails with an error.
    Fail(Span, Box<GenericCommand<Head, Leaf>>),
    /// Include another egglog file directly as text and run it.
    Include(Span, String),
}
669
670impl<Head, Leaf> Display for GenericCommand<Head, Leaf>
671where
672    Head: Clone + Display,
673    Leaf: Clone + PartialEq + Eq + Display + Hash,
674{
675    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
676        match self {
677            GenericCommand::SetOption { name, value } => write!(f, "(set-option {name} {value})"),
678            GenericCommand::Rewrite(name, rewrite, subsume) => {
679                rewrite.fmt_with_ruleset(f, *name, false, *subsume)
680            }
681            GenericCommand::BiRewrite(name, rewrite) => {
682                rewrite.fmt_with_ruleset(f, *name, true, false)
683            }
684            GenericCommand::Datatype {
685                span: _,
686                name,
687                variants,
688            } => write!(f, "(datatype {name} {})", ListDisplay(variants, " ")),
689            GenericCommand::Action(a) => write!(f, "{a}"),
690            GenericCommand::Sort(_span, name, None) => write!(f, "(sort {name})"),
691            GenericCommand::Sort(_span, name, Some((name2, args))) => {
692                write!(f, "(sort {name} ({name2} {}))", ListDisplay(args, " "))
693            }
694            GenericCommand::Function {
695                span: _,
696                name,
697                schema,
698                merge,
699            } => {
700                write!(f, "(function {name} {schema}")?;
701                if let Some(merge) = &merge {
702                    write!(f, " :merge {merge}")?;
703                } else {
704                    write!(f, " :no-merge")?;
705                }
706                write!(f, ")")
707            }
708            GenericCommand::Constructor {
709                span: _,
710                name,
711                schema,
712                cost,
713                unextractable,
714            } => {
715                write!(f, "(constructor {name} {schema}")?;
716                if let Some(cost) = cost {
717                    write!(f, " :cost {cost}")?;
718                }
719                if *unextractable {
720                    write!(f, " :unextractable")?;
721                }
722                write!(f, ")")
723            }
724            GenericCommand::Relation {
725                span: _,
726                name,
727                inputs,
728            } => {
729                write!(f, "(relation {name} ({}))", ListDisplay(inputs, " "))
730            }
731            GenericCommand::AddRuleset(_span, name) => write!(f, "(ruleset {name})"),
732            GenericCommand::UnstableCombinedRuleset(_span, name, others) => {
733                write!(
734                    f,
735                    "(unstable-combined-ruleset {name} {})",
736                    ListDisplay(others, " ")
737                )
738            }
739            GenericCommand::Rule {
740                ruleset,
741                name,
742                rule,
743            } => rule.fmt_with_ruleset(f, *ruleset, *name),
744            GenericCommand::RunSchedule(sched) => write!(f, "(run-schedule {sched})"),
745            GenericCommand::PrintOverallStatistics => write!(f, "(print-stats)"),
746            GenericCommand::QueryExtract {
747                span: _,
748                variants,
749                expr,
750            } => {
751                write!(f, "(query-extract :variants {variants} {expr})")
752            }
753            GenericCommand::Check(_ann, facts) => {
754                write!(f, "(check {})", ListDisplay(facts, "\n"))
755            }
756            GenericCommand::Push(n) => write!(f, "(push {n})"),
757            GenericCommand::Pop(_span, n) => write!(f, "(pop {n})"),
758            GenericCommand::PrintFunction(_span, name, n) => {
759                write!(f, "(print-function {name} {n})")
760            }
761            GenericCommand::PrintSize(_span, name) => {
762                write!(f, "(print-size {})", ListDisplay(name, " "))
763            }
764            GenericCommand::Input {
765                span: _,
766                name,
767                file,
768            } => write!(f, "(input {name} {file:?})"),
769            GenericCommand::Output {
770                span: _,
771                file,
772                exprs,
773            } => write!(f, "(output {file:?} {})", ListDisplay(exprs, " ")),
774            GenericCommand::Fail(_span, cmd) => write!(f, "(fail {cmd})"),
775            GenericCommand::Include(_span, file) => write!(f, "(include {file:?})"),
776            GenericCommand::Simplify {
777                span: _,
778                expr,
779                schedule,
780            } => write!(f, "(simplify {schedule} {expr})"),
781            GenericCommand::Datatypes { span: _, datatypes } => {
782                let datatypes: Vec<_> = datatypes
783                    .iter()
784                    .map(|(_, name, variants)| match variants {
785                        Subdatatypes::Variants(variants) => {
786                            format!("({name} {})", ListDisplay(variants, " "))
787                        }
788                        Subdatatypes::NewSort(head, args) => {
789                            format!("(sort {name} ({head} {}))", ListDisplay(args, " "))
790                        }
791                    })
792                    .collect();
793                write!(f, "(datatype* {})", ListDisplay(datatypes, " "))
794            }
795        }
796    }
797}
798
/// A pair of symbols: an identifier and a sort.
///
/// Displayed as `(<ident> <sort>)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IdentSort {
    /// The identifier half of the pair.
    pub ident: Symbol,
    /// The sort half of the pair.
    pub sort: Symbol,
}
804
805impl Display for IdentSort {
806    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
807        write!(f, "({} {})", self.ident, self.sort)
808    }
809}
810
/// A [`GenericRunConfig`] whose heads and leaves are plain [`Symbol`]s.
pub type RunConfig = GenericRunConfig<Symbol, Symbol>;
/// A [`GenericRunConfig`] whose heads are [`ResolvedCall`]s and whose leaves
/// are [`ResolvedVar`]s.
pub(crate) type ResolvedRunConfig = GenericRunConfig<ResolvedCall, ResolvedVar>;
813
/// The configuration carried by [`GenericSchedule::Run`].
///
/// Displayed as `(run [<ruleset>] [:until <fact> ...])`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GenericRunConfig<Head, Leaf> {
    /// The ruleset symbol. The empty symbol is omitted when the
    /// configuration is displayed.
    pub ruleset: Symbol,
    /// Optional facts printed after `:until`.
    pub until: Option<Vec<GenericFact<Head, Leaf>>>,
}
819
820impl<Head, Leaf> GenericRunConfig<Head, Leaf>
821where
822    Head: Clone + Display,
823    Leaf: Clone + PartialEq + Eq + Display + Hash,
824{
825    pub fn visit_exprs(
826        self,
827        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
828    ) -> Self {
829        Self {
830            ruleset: self.ruleset,
831            until: self
832                .until
833                .map(|until| until.into_iter().map(|fact| fact.visit_exprs(f)).collect()),
834        }
835    }
836}
837
838impl<Head: Display, Leaf: Display> Display for GenericRunConfig<Head, Leaf>
839where
840    Head: Display,
841    Leaf: Display,
842{
843    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
844        write!(f, "(run")?;
845        if self.ruleset != "".into() {
846            write!(f, " {}", self.ruleset)?;
847        }
848        if let Some(until) = &self.until {
849            write!(f, " :until {}", ListDisplay(until, " "))?;
850        }
851        write!(f, ")")
852    }
853}
854
855pub type FunctionDecl = GenericFunctionDecl<Symbol, Symbol>;
856pub(crate) type ResolvedFunctionDecl = GenericFunctionDecl<ResolvedCall, ResolvedVar>;
857
/// The kind of a function declaration.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FunctionSubtype {
    /// Displayed as `Constructor`.
    Constructor,
    /// Displayed as `Relation`.
    Relation,
    /// Displayed as `CustomFunction`.
    Custom,
}

impl Display for FunctionSubtype {
    /// Writes the human-readable name of the subtype.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str(match self {
            FunctionSubtype::Constructor => "Constructor",
            FunctionSubtype::Relation => "Relation",
            FunctionSubtype::Custom => "CustomFunction",
        })
    }
}
874
875/// Represents the declaration of a function
876/// directly parsed from source syntax.
877#[derive(Clone, Debug, PartialEq, Eq, Hash)]
878pub struct GenericFunctionDecl<Head, Leaf>
879where
880    Head: Clone + Display,
881    Leaf: Clone + PartialEq + Eq + Display + Hash,
882{
883    pub name: Symbol,
884    pub subtype: FunctionSubtype,
885    pub schema: Schema,
886    pub merge: Option<GenericExpr<Head, Leaf>>,
887    pub cost: Option<usize>,
888    pub unextractable: bool,
889    /// Globals are desugared to functions, with this flag set to true.
890    /// This is used by visualization to handle globals differently.
891    pub ignore_viz: bool,
892    pub span: Span,
893}
894
895#[derive(Clone, Debug, PartialEq, Eq, Hash)]
896pub struct Variant {
897    pub span: Span,
898    pub name: Symbol,
899    pub types: Vec<Symbol>,
900    pub cost: Option<usize>,
901}
902
903impl Display for Variant {
904    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
905        write!(f, "({}", self.name)?;
906        if !self.types.is_empty() {
907            write!(f, " {}", ListDisplay(&self.types, " "))?;
908        }
909        if let Some(cost) = self.cost {
910            write!(f, " :cost {cost}")?;
911        }
912        write!(f, ")")
913    }
914}
915
916#[derive(Clone, Debug, PartialEq, Eq, Hash)]
917pub struct Schema {
918    pub input: Vec<Symbol>,
919    pub output: Symbol,
920}
921
922impl Display for Schema {
923    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
924        write!(f, "({}) {}", ListDisplay(&self.input, " "), self.output)
925    }
926}
927
928impl Schema {
929    pub fn new(input: Vec<Symbol>, output: Symbol) -> Self {
930        Self { input, output }
931    }
932}
933
934impl FunctionDecl {
935    pub fn function(
936        span: Span,
937        name: Symbol,
938        schema: Schema,
939        merge: Option<GenericExpr<Symbol, Symbol>>,
940    ) -> Self {
941        Self {
942            name,
943            subtype: FunctionSubtype::Custom,
944            schema,
945            merge,
946            cost: None,
947            unextractable: true,
948            ignore_viz: false,
949            span,
950        }
951    }
952
953    pub fn constructor(
954        span: Span,
955        name: Symbol,
956        schema: Schema,
957        cost: Option<usize>,
958        unextractable: bool,
959    ) -> Self {
960        Self {
961            name,
962            subtype: FunctionSubtype::Constructor,
963            schema,
964            merge: None,
965            cost,
966            unextractable,
967            ignore_viz: false,
968            span,
969        }
970    }
971
972    pub fn relation(span: Span, name: Symbol, input: Vec<Symbol>) -> Self {
973        Self {
974            name,
975            subtype: FunctionSubtype::Relation,
976            schema: Schema {
977                input,
978                output: Symbol::from("Unit"),
979            },
980            merge: None,
981            cost: None,
982            unextractable: true,
983            ignore_viz: false,
984            span,
985        }
986    }
987}
988
989impl<Head, Leaf> GenericFunctionDecl<Head, Leaf>
990where
991    Head: Clone + Display,
992    Leaf: Clone + PartialEq + Eq + Display + Hash,
993{
994    pub fn visit_exprs(
995        self,
996        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
997    ) -> GenericFunctionDecl<Head, Leaf> {
998        GenericFunctionDecl {
999            name: self.name,
1000            subtype: self.subtype,
1001            schema: self.schema,
1002            merge: self.merge.map(|expr| expr.visit_exprs(f)),
1003            cost: self.cost,
1004            unextractable: self.unextractable,
1005            ignore_viz: self.ignore_viz,
1006            span: self.span,
1007        }
1008    }
1009}
1010
1011pub type Fact = GenericFact<Symbol, Symbol>;
1012pub(crate) type ResolvedFact = GenericFact<ResolvedCall, ResolvedVar>;
1013pub(crate) type MappedFact<Head, Leaf> = GenericFact<CorrespondingVar<Head, Leaf>, Leaf>;
1014
1015/// Facts are the left-hand side of a [`Command::Rule`].
1016/// They represent a part of a database query.
1017/// Facts can be expressions or equality constraints between expressions.
1018///
1019/// Note that primitives such as  `!=` are partial.
1020/// When two things are equal, it returns nothing and the query does not match.
1021/// For example, the following egglog code runs:
1022/// ```text
1023/// (fail (check (!= 1 1)))
1024/// ```
1025#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1026pub enum GenericFact<Head, Leaf> {
1027    Eq(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
1028    Fact(GenericExpr<Head, Leaf>),
1029}
1030
1031pub struct Facts<Head, Leaf>(pub Vec<GenericFact<Head, Leaf>>);
1032
1033impl<Head, Leaf> Facts<Head, Leaf>
1034where
1035    Head: Clone + Display,
1036    Leaf: Clone + PartialEq + Eq + Display + Hash,
1037{
1038    /// Flattens a list of facts into a Query.
1039    /// For typechecking, we need the correspondence between the original ast
1040    /// and the flattened one, so that we can annotate the original with types.
1041    /// That's why this function produces a corresponding list of facts, annotated with
1042    /// the variable names in the flattened Query.
1043    /// (Typechecking preserves the original AST this way,
1044    /// and allows terms and proof instrumentation to do the same).
1045    pub(crate) fn to_query(
1046        &self,
1047        typeinfo: &TypeInfo,
1048        fresh_gen: &mut impl FreshGen<Head, Leaf>,
1049    ) -> (Query<HeadOrEq<Head>, Leaf>, Vec<MappedFact<Head, Leaf>>)
1050    where
1051        Leaf: SymbolLike,
1052    {
1053        let mut atoms = vec![];
1054        let mut new_body = vec![];
1055
1056        for fact in self.0.iter() {
1057            match fact {
1058                GenericFact::Eq(span, e1, e2) => {
1059                    let mut to_equate = vec![];
1060                    let mut process = |expr: &GenericExpr<Head, Leaf>| {
1061                        let (child_atoms, expr) = expr.to_query(typeinfo, fresh_gen);
1062                        atoms.extend(child_atoms);
1063                        to_equate.push(expr.get_corresponding_var_or_lit(typeinfo));
1064                        expr
1065                    };
1066                    let e1 = process(e1);
1067                    let e2 = process(e2);
1068                    atoms.push(GenericAtom {
1069                        span: span.clone(),
1070                        head: HeadOrEq::Eq,
1071                        args: to_equate,
1072                    });
1073                    new_body.push(GenericFact::Eq(span.clone(), e1, e2));
1074                }
1075                GenericFact::Fact(expr) => {
1076                    let (child_atoms, expr) = expr.to_query(typeinfo, fresh_gen);
1077                    atoms.extend(child_atoms);
1078                    new_body.push(GenericFact::Fact(expr));
1079                }
1080            }
1081        }
1082        (Query { atoms }, new_body)
1083    }
1084}
1085
1086impl<Head: Display, Leaf: Display> Display for GenericFact<Head, Leaf> {
1087    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
1088        match self {
1089            GenericFact::Eq(_, e1, e2) => write!(f, "(= {e1} {e2})"),
1090            GenericFact::Fact(expr) => write!(f, "{expr}"),
1091        }
1092    }
1093}
1094
1095impl<Head, Leaf> GenericFact<Head, Leaf>
1096where
1097    Head: Clone + Display,
1098    Leaf: Clone + PartialEq + Eq + Display + Hash,
1099{
1100    pub(crate) fn visit_exprs(
1101        self,
1102        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1103    ) -> GenericFact<Head, Leaf> {
1104        match self {
1105            GenericFact::Eq(span, e1, e2) => {
1106                GenericFact::Eq(span, e1.visit_exprs(f), e2.visit_exprs(f))
1107            }
1108            GenericFact::Fact(expr) => GenericFact::Fact(expr.visit_exprs(f)),
1109        }
1110    }
1111
1112    pub(crate) fn map_exprs<Head2, Leaf2>(
1113        &self,
1114        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head2, Leaf2>,
1115    ) -> GenericFact<Head2, Leaf2> {
1116        match self {
1117            GenericFact::Eq(span, e1, e2) => GenericFact::Eq(span.clone(), f(e1), f(e2)),
1118            GenericFact::Fact(expr) => GenericFact::Fact(f(expr)),
1119        }
1120    }
1121
1122    pub(crate) fn subst<Leaf2, Head2>(
1123        &self,
1124        subst_leaf: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head2, Leaf2>,
1125        subst_head: &mut impl FnMut(&Head) -> Head2,
1126    ) -> GenericFact<Head2, Leaf2> {
1127        self.map_exprs(&mut |e| e.subst(subst_leaf, subst_head))
1128    }
1129}
1130
1131impl<Head, Leaf> GenericFact<Head, Leaf>
1132where
1133    Leaf: Clone + PartialEq + Eq + Display + Hash,
1134    Head: Clone + Display,
1135{
1136    pub(crate) fn make_unresolved(self) -> GenericFact<Symbol, Symbol>
1137    where
1138        Leaf: SymbolLike,
1139        Head: SymbolLike,
1140    {
1141        self.subst(
1142            &mut |span, v| GenericExpr::Var(span.clone(), v.to_symbol()),
1143            &mut |h| h.to_symbol(),
1144        )
1145    }
1146}
1147
/// A head paired with the leaf it corresponds to.
///
/// Displayed as `head -> to`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// The original head.
    pub head: Head,
    /// The leaf associated with `head`.
    pub to: Leaf,
}

impl<Head, Leaf> CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Pairs `head` with `leaf`; `leaf` is stored in the `to` field.
    pub fn new(head: Head, leaf: Leaf) -> Self {
        Self { head, to: leaf }
    }
}

impl<Head, Leaf> Display for CorrespondingVar<Head, Leaf>
where
    Head: Clone + Display,
    Leaf: Clone + PartialEq + Eq + Display + Hash,
{
    /// Writes the pair as `head -> to`.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let Self { head, to } = self;
        write!(f, "{head} -> {to}")
    }
}
1177
/// Change a function entry.
#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
pub enum Change {
    /// `delete` this entry from a function.
    /// Be wary! Only delete entries that are guaranteed to be not useful.
    Delete,
    /// `subsume` this entry so that it cannot be queried or extracted, but still can be checked.
    /// Note that this is currently forbidden for functions with custom merges.
    Subsume,
}
1188
1189pub type Action = GenericAction<Symbol, Symbol>;
1190pub(crate) type MappedAction = GenericAction<CorrespondingVar<Symbol, Symbol>, Symbol>;
1191pub(crate) type ResolvedAction = GenericAction<ResolvedCall, ResolvedVar>;
1192
1193#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1194pub enum GenericAction<Head, Leaf>
1195where
1196    Head: Clone + Display,
1197    Leaf: Clone + PartialEq + Eq + Display + Hash,
1198{
1199    /// Bind a variable to a particular datatype or primitive.
1200    /// At the top level (in a [`Command::Action`]), this defines a global variable.
1201    /// In a [`Command::Rule`], this defines a local variable in the actions.
1202    Let(Span, Leaf, GenericExpr<Head, Leaf>),
1203    /// `set` a function to a particular result.
1204    /// `set` should not be used on datatypes-
1205    /// instead, use `union`.
1206    Set(
1207        Span,
1208        Head,
1209        Vec<GenericExpr<Head, Leaf>>,
1210        GenericExpr<Head, Leaf>,
1211    ),
1212    /// Delete or subsume (mark as hidden from future rewrites and unextractable) an entry from a function.
1213    Change(Span, Change, Head, Vec<GenericExpr<Head, Leaf>>),
1214    /// `union` two datatypes, making them equal
1215    /// in the implicit, global equality relation
1216    /// of egglog.
1217    /// All rules match modulo this equality relation.
1218    ///
1219    /// Example:
1220    /// ```text
1221    /// (datatype Math (Num i64))
1222    /// (union (Num 1) (Num 2)); Define that Num 1 and Num 2 are equivalent
1223    /// (extract (Num 1)); Extracts Num 1
1224    /// (extract (Num 2)); Extracts Num 1
1225    /// ```
1226    Union(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
1227    /// `extract` a datatype from the egraph, choosing
1228    /// the smallest representative.
1229    /// By default, each constructor costs 1 to extract
1230    /// (common subexpressions are not shared in the cost
1231    /// model).
1232    /// The second argument is the number of variants to
1233    /// extract, picking different terms in the
1234    /// same equivalence class.
1235    Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
1236    Panic(Span, String),
1237    Expr(Span, GenericExpr<Head, Leaf>),
1238    // If(Expr, Action, Action),
1239}
1240
1241#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1242
1243pub struct GenericActions<Head: Clone + Display, Leaf: Clone + PartialEq + Eq + Display + Hash>(
1244    pub Vec<GenericAction<Head, Leaf>>,
1245);
1246pub type Actions = GenericActions<Symbol, Symbol>;
1247pub(crate) type ResolvedActions = GenericActions<ResolvedCall, ResolvedVar>;
1248pub(crate) type MappedActions<Head, Leaf> = GenericActions<CorrespondingVar<Head, Leaf>, Leaf>;
1249
1250impl<Head, Leaf> Default for GenericActions<Head, Leaf>
1251where
1252    Head: Clone + Display,
1253    Leaf: Clone + PartialEq + Eq + Display + Hash,
1254{
1255    fn default() -> Self {
1256        Self(vec![])
1257    }
1258}
1259
1260impl<Head, Leaf> GenericActions<Head, Leaf>
1261where
1262    Head: Clone + Display,
1263    Leaf: Clone + PartialEq + Eq + Display + Hash,
1264{
1265    pub(crate) fn len(&self) -> usize {
1266        self.0.len()
1267    }
1268
1269    pub(crate) fn iter(&self) -> impl Iterator<Item = &GenericAction<Head, Leaf>> {
1270        self.0.iter()
1271    }
1272
1273    pub(crate) fn visit_exprs(
1274        self,
1275        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1276    ) -> Self {
1277        Self(self.0.into_iter().map(|a| a.visit_exprs(f)).collect())
1278    }
1279}
1280
1281impl<Head, Leaf> Display for GenericAction<Head, Leaf>
1282where
1283    Head: Clone + Display,
1284    Leaf: Clone + PartialEq + Eq + Display + Hash,
1285{
1286    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
1287        match self {
1288            GenericAction::Let(_ann, lhs, rhs) => write!(f, "(let {lhs} {rhs})"),
1289            GenericAction::Set(_ann, lhs, args, rhs) => {
1290                write!(f, "(set ({lhs} {}) {rhs})", ListDisplay(args, " "))
1291            }
1292            GenericAction::Union(_ann, lhs, rhs) => write!(f, "(union {lhs} {rhs})"),
1293            GenericAction::Change(_ann, change, lhs, args) => {
1294                let change = match change {
1295                    Change::Delete => "delete",
1296                    Change::Subsume => "subsume",
1297                };
1298                write!(f, "({change} ({lhs} {}))", ListDisplay(args, " "))
1299            }
1300            GenericAction::Extract(_ann, expr, variants) => {
1301                write!(f, "(extract {expr} {variants})")
1302            }
1303            GenericAction::Panic(_ann, msg) => write!(f, "(panic {msg:?})"),
1304            GenericAction::Expr(_ann, e) => write!(f, "{e}"),
1305        }
1306    }
1307}
1308
1309impl<Head, Leaf> GenericAction<Head, Leaf>
1310where
1311    Head: Clone + Display,
1312    Leaf: Clone + Eq + Display + Hash,
1313{
1314    // Applys `f` to all expressions in the action.
1315    pub fn map_exprs(
1316        &self,
1317        f: &mut impl FnMut(&GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1318    ) -> Self {
1319        match self {
1320            GenericAction::Let(span, lhs, rhs) => {
1321                GenericAction::Let(span.clone(), lhs.clone(), f(rhs))
1322            }
1323            GenericAction::Set(span, lhs, args, rhs) => {
1324                let right = f(rhs);
1325                GenericAction::Set(
1326                    span.clone(),
1327                    lhs.clone(),
1328                    args.iter().map(f).collect(),
1329                    right,
1330                )
1331            }
1332            GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
1333                span.clone(),
1334                *change,
1335                lhs.clone(),
1336                args.iter().map(f).collect(),
1337            ),
1338            GenericAction::Union(span, lhs, rhs) => {
1339                GenericAction::Union(span.clone(), f(lhs), f(rhs))
1340            }
1341            GenericAction::Extract(span, expr, variants) => {
1342                GenericAction::Extract(span.clone(), f(expr), f(variants))
1343            }
1344            GenericAction::Panic(span, msg) => GenericAction::Panic(span.clone(), msg.clone()),
1345            GenericAction::Expr(span, e) => GenericAction::Expr(span.clone(), f(e)),
1346        }
1347    }
1348
1349    /// Applys `f` to all sub-expressions (including `self`)
1350    /// bottom-up, collecting the results.
1351    pub fn visit_exprs(
1352        self,
1353        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1354    ) -> Self {
1355        match self {
1356            GenericAction::Let(span, lhs, rhs) => {
1357                GenericAction::Let(span, lhs.clone(), rhs.visit_exprs(f))
1358            }
1359            // TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
1360            // This seems more natural to oflatt
1361            // Currently, visit_exprs does not apply f to the first argument of Set.
1362            GenericAction::Set(span, lhs, args, rhs) => {
1363                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
1364                GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
1365            }
1366            GenericAction::Change(span, change, lhs, args) => {
1367                let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
1368                GenericAction::Change(span, change, lhs.clone(), args)
1369            }
1370            GenericAction::Union(span, lhs, rhs) => {
1371                GenericAction::Union(span, lhs.visit_exprs(f), rhs.visit_exprs(f))
1372            }
1373            GenericAction::Extract(span, expr, variants) => {
1374                GenericAction::Extract(span, expr.visit_exprs(f), variants.visit_exprs(f))
1375            }
1376            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
1377            GenericAction::Expr(span, e) => GenericAction::Expr(span, e.visit_exprs(f)),
1378        }
1379    }
1380
1381    pub fn subst(&self, subst: &mut impl FnMut(&Span, &Leaf) -> GenericExpr<Head, Leaf>) -> Self {
1382        self.map_exprs(&mut |e| e.subst_leaf(subst))
1383    }
1384
1385    pub fn map_def_use(self, fvar: &mut impl FnMut(Leaf, bool) -> Leaf) -> Self {
1386        macro_rules! fvar_expr {
1387            () => {
1388                |span, s: _| GenericExpr::Var(span.clone(), fvar(s.clone(), false))
1389            };
1390        }
1391        match self {
1392            GenericAction::Let(span, lhs, rhs) => {
1393                let lhs = fvar(lhs, true);
1394                let rhs = rhs.subst_leaf(&mut fvar_expr!());
1395                GenericAction::Let(span, lhs, rhs)
1396            }
1397            GenericAction::Set(span, lhs, args, rhs) => {
1398                let args = args
1399                    .into_iter()
1400                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
1401                    .collect();
1402                let rhs = rhs.subst_leaf(&mut fvar_expr!());
1403                GenericAction::Set(span, lhs.clone(), args, rhs)
1404            }
1405            GenericAction::Change(span, change, lhs, args) => {
1406                let args = args
1407                    .into_iter()
1408                    .map(|e| e.subst_leaf(&mut fvar_expr!()))
1409                    .collect();
1410                GenericAction::Change(span, change, lhs.clone(), args)
1411            }
1412            GenericAction::Union(span, lhs, rhs) => {
1413                let lhs = lhs.subst_leaf(&mut fvar_expr!());
1414                let rhs = rhs.subst_leaf(&mut fvar_expr!());
1415                GenericAction::Union(span, lhs, rhs)
1416            }
1417            GenericAction::Extract(span, expr, variants) => {
1418                let expr = expr.subst_leaf(&mut fvar_expr!());
1419                let variants = variants.subst_leaf(&mut fvar_expr!());
1420                GenericAction::Extract(span, expr, variants)
1421            }
1422            GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
1423            GenericAction::Expr(span, e) => {
1424                GenericAction::Expr(span, e.subst_leaf(&mut fvar_expr!()))
1425            }
1426        }
1427    }
1428}
1429
1430#[derive(Clone, Debug)]
1431pub(crate) struct CompiledRule {
1432    pub(crate) query: CompiledQuery,
1433    pub(crate) program: Program,
1434}
1435
1436pub type Rule = GenericRule<Symbol, Symbol>;
1437pub(crate) type ResolvedRule = GenericRule<ResolvedCall, ResolvedVar>;
1438
1439#[derive(Clone, Debug, PartialEq, Eq, Hash)]
1440pub struct GenericRule<Head, Leaf>
1441where
1442    Head: Clone + Display,
1443    Leaf: Clone + PartialEq + Eq + Display + Hash,
1444{
1445    pub span: Span,
1446    pub head: GenericActions<Head, Leaf>,
1447    pub body: Vec<GenericFact<Head, Leaf>>,
1448}
1449
1450impl<Head, Leaf> GenericRule<Head, Leaf>
1451where
1452    Head: Clone + Display,
1453    Leaf: Clone + PartialEq + Eq + Display + Hash,
1454{
1455    pub(crate) fn visit_exprs(
1456        self,
1457        f: &mut impl FnMut(GenericExpr<Head, Leaf>) -> GenericExpr<Head, Leaf>,
1458    ) -> Self {
1459        Self {
1460            span: self.span,
1461            head: self.head.visit_exprs(f),
1462            body: self
1463                .body
1464                .into_iter()
1465                .map(|bexpr| bexpr.visit_exprs(f))
1466                .collect(),
1467        }
1468    }
1469}
1470
1471impl<Head, Leaf> GenericRule<Head, Leaf>
1472where
1473    Head: Clone + Display,
1474    Leaf: Clone + PartialEq + Eq + Display + Hash,
1475{
1476    pub(crate) fn fmt_with_ruleset(
1477        &self,
1478        f: &mut Formatter,
1479        ruleset: Symbol,
1480        name: Symbol,
1481    ) -> std::fmt::Result {
1482        let indent = " ".repeat(7);
1483        write!(f, "(rule (")?;
1484        for (i, fact) in self.body.iter().enumerate() {
1485            if i > 0 {
1486                write!(f, "{}", indent)?;
1487            }
1488
1489            if i != self.body.len() - 1 {
1490                writeln!(f, "{}", fact)?;
1491            } else {
1492                write!(f, "{}", fact)?;
1493            }
1494        }
1495        write!(f, ")\n      (")?;
1496        for (i, action) in self.head.0.iter().enumerate() {
1497            if i > 0 {
1498                write!(f, "{}", indent)?;
1499            }
1500            if i != self.head.0.len() - 1 {
1501                writeln!(f, "{}", action)?;
1502            } else {
1503                write!(f, "{}", action)?;
1504            }
1505        }
1506        let ruleset = if ruleset != "".into() {
1507            format!(":ruleset {}", ruleset)
1508        } else {
1509            "".into()
1510        };
1511        let name = if name != "".into() {
1512            format!(":name \"{}\"", name)
1513        } else {
1514            "".into()
1515        };
1516        write!(f, ")\n{} {} {})", indent, ruleset, name)
1517    }
1518}
1519
1520pub type Rewrite = GenericRewrite<Symbol, Symbol>;
1521
1522#[derive(Clone, Debug)]
1523pub struct GenericRewrite<Head, Leaf> {
1524    pub span: Span,
1525    pub lhs: GenericExpr<Head, Leaf>,
1526    pub rhs: GenericExpr<Head, Leaf>,
1527    pub conditions: Vec<GenericFact<Head, Leaf>>,
1528}
1529
1530impl<Head: Display, Leaf: Display> GenericRewrite<Head, Leaf> {
1531    /// Converts the rewrite into an s-expression.
1532    pub fn fmt_with_ruleset(
1533        &self,
1534        f: &mut Formatter,
1535        ruleset: Symbol,
1536        is_bidirectional: bool,
1537        subsume: bool,
1538    ) -> std::fmt::Result {
1539        let direction = if is_bidirectional {
1540            "birewrite"
1541        } else {
1542            "rewrite"
1543        };
1544        write!(f, "({direction} {} {}", self.lhs, self.rhs)?;
1545        if subsume {
1546            write!(f, " :subsume")?;
1547        }
1548        if !self.conditions.is_empty() {
1549            write!(f, " :when ({})", ListDisplay(&self.conditions, " "))?;
1550        }
1551        if ruleset != "".into() {
1552            write!(f, " :ruleset {ruleset}")?;
1553        }
1554        write!(f, ")")
1555    }
1556}
1557
1558impl<Head, Leaf: Clone> MappedExpr<Head, Leaf>
1559where
1560    Head: Clone + Display,
1561    Leaf: Clone + PartialEq + Eq + Display + Hash,
1562{
1563    pub(crate) fn get_corresponding_var_or_lit(&self, typeinfo: &TypeInfo) -> GenericAtomTerm<Leaf>
1564    where
1565        Leaf: SymbolLike,
1566    {
1567        // Note: need typeinfo to resolve whether a symbol is a global or not
1568        // This is error-prone and the complexities can be avoided by treating globals
1569        // as nullary functions.
1570        match self {
1571            GenericExpr::Var(span, v) => {
1572                if typeinfo.is_global(v.to_symbol()) {
1573                    GenericAtomTerm::Global(span.clone(), v.clone())
1574                } else {
1575                    GenericAtomTerm::Var(span.clone(), v.clone())
1576                }
1577            }
1578            GenericExpr::Lit(span, lit) => GenericAtomTerm::Literal(span.clone(), lit.clone()),
1579            GenericExpr::Call(span, head, _) => GenericAtomTerm::Var(span.clone(), head.to.clone()),
1580        }
1581    }
1582}
1583
1584impl<Head, Leaf> GenericActions<Head, Leaf>
1585where
1586    Head: Clone + Display,
1587    Leaf: Clone + PartialEq + Eq + Display + Hash,
1588{
1589    pub fn new(actions: Vec<GenericAction<Head, Leaf>>) -> Self {
1590        Self(actions)
1591    }
1592
1593    pub fn singleton(action: GenericAction<Head, Leaf>) -> Self {
1594        Self(vec![action])
1595    }
1596}