pub mod ast;
#[cfg(feature = "bin")]
mod cli;
mod command_macro;
pub mod constraint;
mod core;
pub mod extract;
pub mod prelude;
pub mod scheduler;
mod serialize;
pub mod sort;
mod term_encoding;
mod termdag;
mod typechecking;
pub mod util;
pub use command_macro::{CommandMacro, CommandMacroRegistry};
extern crate self as egglog;
pub use ast::{ResolvedExpr, ResolvedFact, ResolvedVar};
#[cfg(feature = "bin")]
pub use cli::*;
use constraint::{Constraint, Problem, SimpleTypeConstraint, TypeConstraint};
pub use core::{Atom, AtomTerm};
use core::{CoreActionContext, ResolvedAtomTerm};
pub use core::{ResolvedCall, SpecializedPrimitive};
pub use core_relations::{BaseValue, ContainerValue, ExecutionState, Value};
use core_relations::{ExternalFunctionId, make_external_func};
use csv::Writer;
pub use egglog_add_primitive::add_primitive;
use egglog_ast::generic_ast::{Change, GenericExpr, Literal};
use egglog_ast::span::Span;
use egglog_ast::util::ListDisplay;
pub use egglog_bridge::FunctionRow;
use egglog_bridge::{ColumnTy, QueryEntry};
use egglog_core_relations as core_relations;
use egglog_numeric_id as numeric_id;
use egglog_reports::{ReportLevel, RunReport};
use extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel};
use indexmap::map::Entry;
use log::{Level, log_enabled};
use numeric_id::DenseIdMap;
use prelude::*;
use scheduler::{SchedulerId, SchedulerRecord};
pub use serialize::{SerializeConfig, SerializeOutput, SerializedNode};
use sort::*;
use std::fmt::{Debug, Display, Formatter};
use std::fs::File;
use std::hash::Hash;
use std::io::{Read, Write as _};
use std::iter::once;
use std::ops::Deref;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
pub use termdag::{Term, TermDag, TermId};
use thiserror::Error;
pub use typechecking::TypeError;
pub use typechecking::TypeInfo;
use util::*;
use crate::ast::desugar::desugar_command;
use crate::ast::*;
use crate::core::{GenericActionsExt, ResolvedRuleExt};
pub use crate::term_encoding::file_supports_proofs;
use crate::term_encoding::{EncodingState, TermState, command_supports_proof_encoding};
/// Prefix that names of globals are expected to start with.
///
/// `EGraph::ensure_global_name_prefix` checks names against this prefix and
/// either warns or (in strict mode) errors when it is missing.
pub const GLOBAL_NAME_PREFIX: &str = "$";
/// Shared, reference-counted handle to a sort implementation.
pub type ArcSort = Arc<dyn Sort>;
/// An operation implemented in Rust that exposes a name, the type constraints
/// used when typechecking calls to it, and an evaluation function.
pub trait Primitive {
/// The primitive's name.
fn name(&self) -> &str;
/// Builds the type constraints for a call to this primitive located at `span`.
fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint>;
/// Evaluates the primitive on `args`.
///
/// NOTE(review): `None` presumably means "no result / failure" for these
/// arguments — confirm against the backend's external-function contract.
fn apply(&self, exec_state: &mut ExecutionState, args: &[Value]) -> Option<Value>;
}
/// Output value returned by a user-defined command (`CommandOutput::UserDefined`).
///
/// Any type that is `Debug + Display + Send + Sync` qualifies via the blanket
/// impl on the following line; there is nothing to implement by hand.
pub trait UserDefinedCommandOutput: Debug + std::fmt::Display + Send + Sync {}
impl<T> UserDefinedCommandOutput for T where T: Debug + std::fmt::Display + Send + Sync {}
/// Result produced by running a single command; rendered for users through
/// its `Display` impl.
#[derive(Clone, Debug)]
// Several variants carry a `TermDag`/`RunReport` inline; NOTE(review): the
// size disparity between variants was presumably accepted instead of boxing.
#[allow(clippy::large_enum_variant)]
pub enum CommandOutput {
/// Number of rows in one function's table.
PrintFunctionSize(usize),
/// `(function name, row count)` for every declared function, sorted by name.
PrintAllFunctionsSize(Vec<(String, usize)>),
/// The cheapest extracted term for a value, with its cost.
ExtractBest(TermDag, DefaultCost, TermId),
/// Several extracted terms (variants) for one value.
ExtractVariants(TermDag, Vec<TermId>),
/// The run report accumulated across all schedules run so far.
OverallStatistics(RunReport),
/// Rows of a function as `(input term, output term)` pairs plus the print mode.
PrintFunction(Function, TermDag, Vec<(TermId, TermId)>, PrintFunctionMode),
/// Report for a single `run-schedule`; renders as empty text.
RunSchedule(RunReport),
/// Output produced by a user-defined command.
UserDefined(Arc<dyn UserDefinedCommandOutput>),
}
impl std::fmt::Display for CommandOutput {
    /// Renders command output the way egglog prints it to users.
    ///
    /// `PrintFunction` honours its `PrintFunctionMode`: CSV mode emits one
    /// record per row through the `csv` writer; any other mode prints a
    /// parenthesised list with one row per line. `RunSchedule` renders as
    /// nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CommandOutput::PrintFunctionSize(size) => writeln!(f, "{}", size),
            CommandOutput::PrintAllFunctionsSize(names_and_sizes) => names_and_sizes
                .iter()
                .try_for_each(|(name, size)| writeln!(f, "{}: {}", name, size)),
            CommandOutput::ExtractBest(termdag, _cost, term) => {
                writeln!(f, "{}", termdag.to_string(*term))
            }
            CommandOutput::ExtractVariants(termdag, terms) => {
                writeln!(f, "(")?;
                terms
                    .iter()
                    .try_for_each(|expr| writeln!(f, " {}", termdag.to_string(*expr)))?;
                writeln!(f, ")")
            }
            CommandOutput::OverallStatistics(run_report) => {
                write!(f, "Overall statistics:\n{}", run_report)
            }
            CommandOutput::PrintFunction(function, termdag, terms_and_outputs, mode) => {
                // Unit-valued functions (relations) have no interesting output column.
                let out_is_unit = function.schema.output.name() == UnitSort.name();
                if *mode != PrintFunctionMode::CSV {
                    writeln!(f, "(")?;
                    for (term, output) in terms_and_outputs {
                        write!(f, " {}", termdag.to_string(*term))?;
                        if !out_is_unit {
                            write!(f, " -> {}", termdag.to_string(*output))?;
                        }
                        writeln!(f)?;
                    }
                    return writeln!(f, ")");
                }
                let mut wtr = Writer::from_writer(vec![]);
                for (term, output) in terms_and_outputs {
                    let Term::App(name, children) = termdag.get(*term) else {
                        panic!("Expect function_to_dag to return a list of apps.");
                    };
                    // Record layout: function name, each argument, then the output (if any).
                    let mut record = vec![name.clone()];
                    record.extend(children.iter().map(|child| termdag.to_string(*child)));
                    if !out_is_unit {
                        record.push(termdag.to_string(*output));
                    }
                    wtr.write_record(&record).map_err(|_| std::fmt::Error)?;
                }
                let bytes = wtr.into_inner().map_err(|_| std::fmt::Error)?;
                let text = String::from_utf8(bytes).map_err(|_| std::fmt::Error)?;
                f.write_str(&text)
            }
            CommandOutput::RunSchedule(_report) => Ok(()),
            CommandOutput::UserDefined(output) => write!(f, "{}", *output),
        }
    }
}
/// The top-level egglog engine: declared functions, rulesets, type
/// information and the backend database, plus the push/pop snapshot stack.
#[derive(Clone)]
pub struct EGraph {
/// Backend database; tables are created with `add_table` and rules run with `run_rules`.
backend: egglog_bridge::EGraph,
/// Parser for egglog source; its `symbol_gen` also supplies fresh internal names.
pub parser: Parser,
/// Shadowing checker applied to resolved commands.
names: check_shadowing::Names,
/// Snapshot saved by `push` and restored by `pop`; each snapshot links to the
/// previous one, forming a stack.
pushed_egraph: Option<Box<Self>>,
/// Declared functions keyed by name.
functions: IndexMap<String, Function>,
/// Rulesets keyed by name; `Default` inserts the empty-named ruleset `""`.
rulesets: IndexMap<String, Ruleset>,
/// Base directory used to resolve `input`/`output` file names.
pub fact_directory: Option<PathBuf>,
/// Passed to the backend when building rules in `add_rule`.
pub seminaive: bool,
/// Sorts, primitives and function type information used during typechecking.
type_info: TypeInfo,
/// Report accumulated over every `run-schedule`; preserved across `pop`.
overall_run_report: RunReport,
/// NOTE(review): scheduler registry; not used in the visible part of this file.
schedulers: DenseIdMap<SchedulerId, SchedulerRecord>,
/// User-defined commands registered through `add_command`.
commands: IndexMap<String, Arc<dyn UserDefinedCommand>>,
/// When true, global-name prefix violations are errors instead of warnings.
strict_mode: bool,
/// Set after the first prefix warning so later ones are suppressed.
/// NOTE(review): shared by both the missing-prefix and the prefixed-non-global warnings.
warned_about_global_prefix: bool,
/// Registry exposed through `command_macros()` / `command_macros_mut()`.
command_macros: CommandMacroRegistry,
/// Term-encoding state; `original_typechecking` is `Some` when term encoding is enabled.
proof_state: EncodingState,
}
/// A command registered with `EGraph::add_command`.
pub trait UserDefinedCommand: Send + Sync {
/// Runs the command against `egraph` with the command's argument expressions,
/// optionally producing output to report back to the caller.
fn update(&self, egraph: &mut EGraph, args: &[Expr]) -> Result<Option<CommandOutput>, Error>;
}
/// A declared egglog function together with its backend table.
#[derive(Clone)]
pub struct Function {
/// The resolved declaration this function was created from.
decl: ResolvedFunctionDecl,
/// Input and output sorts resolved from the declaration.
schema: ResolvedSchema,
/// Whether the backend table was configured to allow subsumption
/// (true for constructors and relations, false for custom functions).
can_subsume: bool,
/// Identifier of the backing table in `egglog_bridge`.
backend_id: egglog_bridge::FunctionId,
}
impl Function {
/// The function's declared name.
pub fn name(&self) -> &str {
&self.decl.name
}
/// The resolved input/output sorts.
pub fn schema(&self) -> &ResolvedSchema {
&self.schema
}
/// Whether rows of this function may be subsumed.
pub fn can_subsume(&self) -> bool {
self.can_subsume
}
}
/// A function signature with sort names resolved to sort implementations.
#[derive(Clone, Debug)]
pub struct ResolvedSchema {
/// Argument sorts, in order.
pub input: Vec<ArcSort>,
/// Result sort.
pub output: ArcSort,
}
impl ResolvedSchema {
    /// Returns the sort of column `index`, treating the output as the column
    /// right after the last input.
    ///
    /// Indices `0..input.len()` yield input sorts, `input.len()` yields the
    /// output sort, and anything larger yields `None`.
    pub fn get_by_pos(&self, index: usize) -> Option<&ArcSort> {
        self.input
            .get(index)
            .or_else(|| (index == self.input.len()).then_some(&self.output))
    }
}
impl Debug for Function {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Function")
.field("decl", &self.decl)
.field("schema", &self.schema)
.finish()
}
}
impl Default for EGraph {
/// Builds an e-graph with the built-in base sorts, the container presorts,
/// a few generic primitives, and the default empty-named ruleset `""`.
///
/// # Panics
/// Panics if registering any built-in sort fails (each call is `unwrap`ped).
fn default() -> Self {
let mut eg = Self {
backend: Default::default(),
parser: Default::default(),
names: Default::default(),
pushed_egraph: Default::default(),
functions: Default::default(),
rulesets: Default::default(),
fact_directory: None,
seminaive: true,
overall_run_report: Default::default(),
type_info: Default::default(),
schedulers: Default::default(),
commands: Default::default(),
strict_mode: false,
warned_about_global_prefix: false,
command_macros: Default::default(),
proof_state: Default::default(),
};
// Base sorts. NOTE(review): registration order may determine internal sort
// ids — keep it stable unless confirmed otherwise.
add_base_sort(&mut eg, UnitSort, span!()).unwrap();
add_base_sort(&mut eg, StringSort, span!()).unwrap();
add_base_sort(&mut eg, BoolSort, span!()).unwrap();
add_base_sort(&mut eg, I64Sort, span!()).unwrap();
add_base_sort(&mut eg, F64Sort, span!()).unwrap();
add_base_sort(&mut eg, BigIntSort, span!()).unwrap();
add_base_sort(&mut eg, BigRatSort, span!()).unwrap();
// Parameterised container sorts are registered as presorts.
eg.type_info.add_presort::<MapSort>(span!()).unwrap();
eg.type_info.add_presort::<SetSort>(span!()).unwrap();
eg.type_info.add_presort::<VecSort>(span!()).unwrap();
eg.type_info.add_presort::<FunctionSort>(span!()).unwrap();
eg.type_info.add_presort::<MultiSetSort>(span!()).unwrap();
// Generic comparison primitives. NOTE(review): `#` appears to denote an
// argument of any sort and `-?>` a primitive whose body may yield no result;
// confirm against the `add_primitive!` documentation.
// `!=` succeeds only when the two values differ.
add_primitive!(&mut eg, "!=" = |a: #, b: #| -?> () {
(a != b).then_some(())
});
// `value-eq` succeeds only when the two values are equal.
add_primitive!(&mut eg, "value-eq" = |a: #, b: #| -?> () {
(a == b).then_some(())
});
add_primitive!(&mut eg, "ordering-min" = |a: #, b: #| -> # {
if a < b { a } else { b }
});
add_primitive!(&mut eg, "ordering-max" = |a: #, b: #| -> # {
if a > b { a } else { b }
});
// The empty-named ruleset exists from the start.
eg.rulesets
.insert("".into(), Ruleset::Rules(Default::default()));
eg
}
}
/// Error carrying the name of something that could not be found; displays as
/// `Not found: <name>`.
#[derive(Debug, Error)]
#[error("Not found: {0}")]
pub struct NotFoundError(String);
impl EGraph {
pub fn new_with_term_encoding() -> Self {
let mut egraph = EGraph::default();
egraph.proof_state.original_typechecking = Some(Box::new(egraph.clone()));
egraph
}
/// Enables term encoding on this e-graph.
///
/// A snapshot of the current e-graph is stored in
/// `proof_state.original_typechecking`; while it is present,
/// `resolve_command` typechecks programs against that snapshot.
pub fn with_term_encoding_enabled(mut self) -> Self {
    let snapshot = self.clone();
    self.proof_state.original_typechecking = Some(Box::new(snapshot));
    self
}
/// Mutable access to the type information (sorts, primitives, function types).
pub fn type_info(&mut self) -> &mut TypeInfo {
&mut self.type_info
}
/// Shared access to the command-macro registry.
pub fn command_macros(&self) -> &CommandMacroRegistry {
&self.command_macros
}
/// Mutable access to the command-macro registry.
pub fn command_macros_mut(&mut self) -> &mut CommandMacroRegistry {
&mut self.command_macros
}
/// Registers a user-defined command under `name`.
///
/// The name must not already be used by another user-defined command, a
/// declared function, or a primitive.
///
/// # Errors
/// Returns `Error::CommandAlreadyExists` when the name is taken, or any error
/// produced by the parser while registering the new keyword.
pub fn add_command(
    &mut self,
    name: String,
    command: Arc<dyn UserDefinedCommand>,
) -> Result<(), Error> {
    let already_taken = self.commands.contains_key(&name)
        || self.functions.contains_key(&name)
        || self.type_info.get_prims(&name).is_some();
    if already_taken {
        return Err(Error::CommandAlreadyExists(name, span!()));
    }
    // Register with the parser first: if that fails, `commands` stays
    // untouched instead of holding a command the parser cannot recognise.
    self.parser.add_user_defined(name.clone())?;
    self.commands.insert(name, command);
    Ok(())
}
/// Enables or disables strict mode; when enabled, global-name prefix
/// violations are reported as errors instead of warnings.
pub fn set_strict_mode(&mut self, strict_mode: bool) {
self.strict_mode = strict_mode;
}
/// Whether strict mode is enabled.
pub fn strict_mode(&self) -> bool {
self.strict_mode
}
/// Checks that a global's `name` starts with `GLOBAL_NAME_PREFIX`.
///
/// A missing prefix is an error in strict mode and a (one-shot) warning
/// otherwise.
///
/// # Errors
/// `TypeError::GlobalMissingPrefix` in strict mode when the prefix is absent.
fn ensure_global_name_prefix(&mut self, span: &Span, name: &str) -> Result<(), TypeError> {
    if !name.starts_with(GLOBAL_NAME_PREFIX) {
        if self.strict_mode {
            return Err(TypeError::GlobalMissingPrefix {
                name: name.to_owned(),
                span: span.clone(),
            });
        }
        self.warn_missing_global_prefix(span, name)?;
    }
    Ok(())
}
/// Reports a global whose name lacks `GLOBAL_NAME_PREFIX`.
///
/// In strict mode this is an error naming the prefixed form. Otherwise a
/// warning is logged the first time only; later calls are silent.
///
/// # Errors
/// `TypeError::GlobalMissingPrefix` in strict mode.
fn warn_missing_global_prefix(
    &mut self,
    span: &Span,
    canonical_name: &str,
) -> Result<(), TypeError> {
    if self.strict_mode {
        return Err(TypeError::GlobalMissingPrefix {
            name: format!("{}{}", GLOBAL_NAME_PREFIX, canonical_name),
            span: span.clone(),
        });
    }
    // Flip the one-shot flag and learn whether it was already set.
    let already_warned = std::mem::replace(&mut self.warned_about_global_prefix, true);
    if !already_warned {
        log::warn!(
            "{}\nGlobal `{}` should start with `{}`. Enable `--strict-mode` to turn this warning into an error. Suppressing additional warnings of this type.",
            span,
            canonical_name,
            GLOBAL_NAME_PREFIX
        );
    }
    Ok(())
}
/// Reports a non-global whose name starts with `GLOBAL_NAME_PREFIX`.
///
/// In strict mode this is an error. Otherwise a warning is logged the first
/// time only.
///
/// NOTE(review): the one-shot flag is shared with
/// `warn_missing_global_prefix`, so whichever warning fires first suppresses
/// the other kind too, despite the "of this type" wording.
///
/// # Errors
/// `TypeError::NonGlobalPrefixed` in strict mode.
fn warn_prefixed_non_globals(
    &mut self,
    span: &Span,
    canonical_name: &str,
) -> Result<(), TypeError> {
    if self.strict_mode {
        return Err(TypeError::NonGlobalPrefixed {
            name: format!("{}{}", GLOBAL_NAME_PREFIX, canonical_name),
            span: span.clone(),
        });
    }
    let already_warned = std::mem::replace(&mut self.warned_about_global_prefix, true);
    if !already_warned {
        log::warn!(
            "{}\nNon-global `{}` should not start with `{}`. Enable `--strict-mode` to turn this warning into an error. Suppressing additional warnings of this type.",
            span,
            canonical_name,
            GLOBAL_NAME_PREFIX
        );
    }
    Ok(())
}
pub fn push(&mut self) {
let prev_prev: Option<Box<Self>> = self.pushed_egraph.take();
let mut prev = self.clone();
prev.pushed_egraph = prev_prev;
self.pushed_egraph = Some(Box::new(prev));
}
pub fn pop(&mut self) -> Result<(), Error> {
match self.pushed_egraph.take() {
Some(e) => {
let overall_run_report = self.overall_run_report.clone();
*self = *e;
self.overall_run_report = overall_run_report;
Ok(())
}
None => Err(Error::Pop(span!())),
}
}
fn translate_expr_to_mergefn(
&self,
expr: &ResolvedExpr,
) -> Result<egglog_bridge::MergeFn, Error> {
match expr {
GenericExpr::Lit(_, literal) => {
let val = literal_to_value(&self.backend, literal);
Ok(egglog_bridge::MergeFn::Const(val))
}
GenericExpr::Var(span, resolved_var) => match resolved_var.name.as_str() {
"old" => Ok(egglog_bridge::MergeFn::Old),
"new" => Ok(egglog_bridge::MergeFn::New),
_ => Err(TypeError::Unbound(resolved_var.name.clone(), span.clone()).into()),
},
GenericExpr::Call(_, ResolvedCall::Func(f), args) => {
let translated_args = args
.iter()
.map(|arg| self.translate_expr_to_mergefn(arg))
.collect::<Result<Vec<_>, _>>()?;
Ok(egglog_bridge::MergeFn::Function(
self.functions[&f.name].backend_id,
translated_args,
))
}
GenericExpr::Call(_, ResolvedCall::Primitive(p), args) => {
let translated_args = args
.iter()
.map(|arg| self.translate_expr_to_mergefn(arg))
.collect::<Result<Vec<_>, _>>()?;
Ok(egglog_bridge::MergeFn::Primitive(
p.external_id(),
translated_args,
))
}
}
}
fn declare_function(&mut self, decl: &ResolvedFunctionDecl) -> Result<(), Error> {
let get_sort = |name: &String| match self.type_info.get_sort_by_name(name) {
Some(sort) => Ok(sort.clone()),
None => Err(Error::TypeError(TypeError::UndefinedSort(
name.to_owned(),
decl.span.clone(),
))),
};
let input = decl
.schema
.input
.iter()
.map(get_sort)
.collect::<Result<Vec<_>, _>>()?;
let output = get_sort(&decl.schema.output)?;
let can_subsume = match decl.subtype {
FunctionSubtype::Constructor => true,
FunctionSubtype::Relation => true,
FunctionSubtype::Custom => false,
};
use egglog_bridge::{DefaultVal, MergeFn};
let backend_id = self.backend.add_table(egglog_bridge::FunctionConfig {
schema: input
.iter()
.chain([&output])
.map(|sort| sort.column_ty(&self.backend))
.collect(),
default: match decl.subtype {
FunctionSubtype::Constructor => DefaultVal::FreshId,
FunctionSubtype::Custom => DefaultVal::Fail,
FunctionSubtype::Relation => DefaultVal::Const(self.backend.base_values().get(())),
},
merge: match decl.subtype {
FunctionSubtype::Constructor => MergeFn::UnionId,
FunctionSubtype::Relation => MergeFn::AssertEq,
FunctionSubtype::Custom => match &decl.merge {
None => MergeFn::AssertEq,
Some(expr) => self.translate_expr_to_mergefn(expr)?,
},
},
name: decl.name.to_string(),
can_subsume,
});
let function = Function {
decl: decl.clone(),
schema: ResolvedSchema { input, output },
can_subsume,
backend_id,
};
let old = self.functions.insert(decl.name.clone(), function);
if old.is_some() {
panic!(
"Typechecking should have caught function already bound: {}",
decl.name
);
}
Ok(())
}
/// Extracts up to `n` rows of function `sym` as terms in a fresh `TermDag`.
///
/// Each row becomes an application term `sym(args…)` whose arguments are the
/// cheapest extractions of the row's input values. When `include_output` is
/// set, the row's output value is extracted too and returned in a parallel
/// vector. Values with no extraction are replaced by the variable
/// `Unextractable`.
///
/// # Errors
/// `TypeError::UnboundFunction` if `sym` is not a declared function.
pub fn function_to_dag(
    &self,
    sym: &str,
    n: usize,
    include_output: bool,
) -> Result<(Vec<TermId>, Option<Vec<TermId>>, TermDag), Error> {
    let func = self
        .functions
        .get(sym)
        .ok_or_else(|| TypeError::UnboundFunction(sym.to_owned(), span!()))?;
    let arity = func.schema.input.len();
    let mut root_sorts = func.schema.input.clone();
    if include_output {
        root_sorts.push(func.schema.output.clone());
    }
    let extractor = Extractor::compute_costs_from_rootsorts(
        Some(root_sorts),
        self,
        TreeAdditiveCostModel::default(),
    );
    let mut termdag = TermDag::default();
    let mut inputs: Vec<TermId> = Vec::new();
    let mut outputs: Option<Vec<TermId>> = include_output.then(Vec::new);
    self.backend
        .for_each_while(func.backend_id, |row: egglog_bridge::FunctionRow| {
            // Returning `false` stops the backend scan once `n` rows are collected.
            if inputs.len() >= n {
                return false;
            }
            let mut children: Vec<TermId> = Vec::with_capacity(arity);
            for (value, sort) in row.vals.iter().zip(&func.schema.input) {
                let (_, child) = extractor
                    .extract_best_with_sort(self, &mut termdag, *value, sort.clone())
                    .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
                children.push(child);
            }
            inputs.push(termdag.app(sym.to_owned(), children));
            if let Some(out_terms) = outputs.as_mut() {
                let (_, out) = extractor
                    .extract_best_with_sort(
                        self,
                        &mut termdag,
                        row.vals[arity],
                        func.schema.output.clone(),
                    )
                    .unwrap_or_else(|| (0, termdag.var("Unextractable".into())));
                out_terms.push(out);
            }
            true
        });
    Ok((inputs, outputs, termdag))
}
/// Prints up to `n` rows of function `sym` (all rows when `n` is `None`).
///
/// When `file` is given the rendered output is written there and `Ok(None)`
/// is returned; otherwise the `CommandOutput::PrintFunction` is returned for
/// the caller to display.
///
/// # Errors
/// `TypeError::UnboundFunction` if `sym` is not declared.
///
/// # Panics
/// If writing to `file` fails. NOTE(review): this could become an
/// `Error::IoError`, but the file's path is not available here.
pub fn print_function(
    &mut self,
    sym: &str,
    n: Option<usize>,
    file: Option<File>,
    mode: PrintFunctionMode,
) -> Result<Option<CommandOutput>, Error> {
    let n = if let Some(n) = n {
        log::info!("Printing up to {n} tuples of function {sym} as {mode}");
        n
    } else {
        log::info!("Printing all tuples of function {sym} as {mode}");
        usize::MAX
    };
    let (terms, outputs, termdag) = self.function_to_dag(sym, n, true)?;
    // `function_to_dag` succeeded, so `sym` is declared; `include_output = true`
    // guarantees the outputs vector is present.
    let function = self.functions.get(sym).unwrap().clone();
    let rows: Vec<_> = terms.into_iter().zip(outputs.unwrap()).collect();
    let output = CommandOutput::PrintFunction(function, termdag, rows, mode);
    let Some(mut file) = file else {
        return Ok(Some(output));
    };
    log::info!("Writing output to file");
    file.write_all(output.to_string().as_bytes())
        .expect("Error writing to file");
    Ok(None)
}
/// Reports table sizes: for the single function `sym` when given, otherwise
/// for every declared function sorted by name.
///
/// # Errors
/// `TypeError::UnboundFunction` if `sym` names no declared function.
pub fn print_size(&self, sym: Option<&str>) -> Result<CommandOutput, Error> {
    if let Some(sym) = sym {
        // `ok_or_else` only builds the error (and its String) on failure.
        let f = self
            .functions
            .get(sym)
            .ok_or_else(|| TypeError::UnboundFunction(sym.to_owned(), span!()))?;
        let size = self.backend.table_size(f.backend_id);
        log::info!("Function {} has size {}", sym, size);
        Ok(CommandOutput::PrintFunctionSize(size))
    } else {
        let mut lens = self
            .functions
            .iter()
            .map(|(sym, f)| (sym.clone(), self.backend.table_size(f.backend_id)))
            .collect::<Vec<_>>();
        // Compare names by reference; `sort_by_key` with a cloned key
        // allocated a fresh String for every comparison.
        lens.sort_by(|(a, _), (b, _)| a.cmp(b));
        if log_enabled!(Level::Info) {
            for (sym, len) in &lens {
                log::info!("Function {} has size {}", sym, len);
            }
        }
        Ok(CommandOutput::PrintAllFunctionsSize(lens))
    }
}
fn run_schedule(&mut self, sched: &ResolvedSchedule) -> Result<RunReport, Error> {
match sched {
ResolvedSchedule::Run(span, config) => self.run_rules(span, config),
ResolvedSchedule::Repeat(_span, limit, sched) => {
let mut report = RunReport::default();
for _i in 0..*limit {
let rec = self.run_schedule(sched)?;
let updated = rec.updated;
report.union(rec);
if !updated {
break;
}
}
Ok(report)
}
ResolvedSchedule::Saturate(_span, sched) => {
let mut report = RunReport::default();
loop {
let rec = self.run_schedule(sched)?;
let updated = rec.updated;
report.union(rec);
if !updated {
break;
}
}
Ok(report)
}
ResolvedSchedule::Sequence(_span, scheds) => {
let mut report = RunReport::default();
for sched in scheds {
report.union(self.run_schedule(sched)?);
}
Ok(report)
}
}
}
/// Extracts the cheapest term for `value` of `sort` using the default
/// tree-additive cost model.
///
/// # Errors
/// See `extract_value_with_cost_model`.
pub fn extract_value(
    &self,
    sort: &ArcSort,
    value: Value,
) -> Result<(TermDag, TermId, DefaultCost), Error> {
    let cost_model = TreeAdditiveCostModel::default();
    self.extract_value_with_cost_model(sort, value, cost_model)
}
pub fn extract_value_with_cost_model<CM: CostModel<DefaultCost> + 'static>(
&self,
sort: &ArcSort,
value: Value,
cost_model: CM,
) -> Result<(TermDag, TermId, DefaultCost), Error> {
let extractor =
Extractor::compute_costs_from_rootsorts(Some(vec![sort.clone()]), self, cost_model);
let mut termdag = TermDag::default();
let (cost, term) = extractor.extract_best(self, &mut termdag, value).unwrap();
Ok((termdag, term, cost))
}
/// Like `extract_value`, but renders the extracted term as a string.
///
/// # Errors
/// Propagates any error from `extract_value`.
pub fn extract_value_to_string(
    &self,
    sort: &ArcSort,
    value: Value,
) -> Result<(String, DefaultCost), Error> {
    self.extract_value(sort, value)
        .map(|(termdag, term, cost)| (termdag.to_string(term), cost))
}
fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> Result<RunReport, Error> {
log::debug!("Running ruleset: {}", config.ruleset);
let mut report: RunReport = Default::default();
let GenericRunConfig { ruleset, until } = config;
if let Some(facts) = until {
if self.check_facts(span, facts).is_ok() {
log::info!(
"Breaking early because of facts:\n {}!",
ListDisplay(facts, "\n")
);
return Ok(report);
}
}
let subreport = self.step_rules(ruleset)?;
report.union(subreport);
if log_enabled!(Level::Debug) {
log::debug!("database size: {}", self.num_tuples());
}
Ok(report)
}
pub fn step_rules(&mut self, ruleset: &str) -> Result<RunReport, Error> {
fn collect_rule_ids(
ruleset: &str,
rulesets: &IndexMap<String, Ruleset>,
ids: &mut Vec<egglog_bridge::RuleId>,
) {
match &rulesets[ruleset] {
Ruleset::Rules(rules) => {
for (_, id) in rules.values() {
ids.push(*id);
}
}
Ruleset::Combined(sub_rulesets) => {
for sub_ruleset in sub_rulesets {
collect_rule_ids(sub_ruleset, rulesets, ids);
}
}
}
}
let mut rule_ids = Vec::new();
collect_rule_ids(ruleset, &self.rulesets, &mut rule_ids);
let iteration_report = self
.backend
.run_rules(&rule_ids)
.map_err(|e| Error::BackendError(e.to_string()))?;
Ok(RunReport::singleton(ruleset, iteration_report))
}
fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result<String, Error> {
let core_rule = rule.to_canonicalized_core_rule(
&self.type_info,
&mut self.parser.symbol_gen,
self.proof_state.original_typechecking.is_none(),
)?;
let (query, actions) = (&core_rule.body, &core_rule.head);
let rule_id = {
let mut translator = BackendRule::new(
self.backend.new_rule(&rule.name, self.seminaive),
&self.functions,
&self.type_info,
);
translator.query(query, false);
translator.actions(actions)?;
translator.build()
};
if let Some(rules) = self.rulesets.get_mut(&rule.ruleset) {
match rules {
Ruleset::Rules(rules) => {
match rules.entry(rule.name.clone()) {
indexmap::map::Entry::Occupied(_) => {
let name = rule.name;
panic!("Rule '{name}' was already present")
}
indexmap::map::Entry::Vacant(e) => e.insert((core_rule, rule_id)),
};
Ok(rule.name)
}
Ruleset::Combined(_) => Err(Error::CombinedRulesetError(rule.ruleset, rule.span)),
}
} else {
Err(Error::NoSuchRuleset(rule.ruleset, rule.span))
}
}
fn eval_actions(&mut self, actions: &ResolvedActions) -> Result<(), Error> {
let mut binding = IndexSet::default();
let mut ctx = CoreActionContext::new(
&self.type_info,
&mut binding,
&mut self.parser.symbol_gen,
self.proof_state.original_typechecking.is_none(),
);
let (actions, _) = actions.to_core_actions(&mut ctx)?;
let mut translator = BackendRule::new(
self.backend.new_rule("eval_actions", false),
&self.functions,
&self.type_info,
);
translator.actions(&actions)?;
let id = translator.build();
let result = self.backend.run_rules(&[id]);
self.backend.free_rule(id);
match result {
Ok(_) => Ok(()),
Err(e) => Err(Error::BackendError(e.to_string())),
}
}
/// Typechecks and evaluates a single expression, returning its sort and value.
///
/// # Errors
/// Any resolution/typechecking error, or an evaluation failure from the
/// backend.
///
/// # Panics
/// If resolution does not yield exactly one core `Expr` action.
pub fn eval_expr(&mut self, expr: &Expr) -> Result<(ArcSort, Value), Error> {
    let span = expr.span();
    let mut resolved =
        self.resolve_command(Command::Action(Action::Expr(span.clone(), expr.clone())))?;
    assert_eq!(resolved.len(), 1);
    let Some(ResolvedNCommand::CoreAction(ResolvedAction::Expr(_, resolved_expr))) =
        resolved.pop()
    else {
        unreachable!()
    };
    let sort = resolved_expr.output_type();
    let value = self.eval_resolved_expr(span, &resolved_expr)?;
    Ok((sort, value))
}
/// Evaluates an already-resolved expression and returns its value.
///
/// Works by building a temporary rule that binds the expression to a fresh
/// variable and passes it to a temporary external function, which stores the
/// value in a side channel. Both the rule and the external function are freed
/// after the run.
///
/// # Errors
/// Errors from lowering/translating the action, or `Error::BackendError` if
/// the backend run fails.
///
/// # Panics
/// If the run succeeds but the external function was never called (the side
/// channel is still empty when unwrapped).
fn eval_resolved_expr(&mut self, span: Span, expr: &ResolvedExpr) -> Result<Value, Error> {
let unit_id = self.backend.base_values().get_ty::<()>();
let unit_val = self.backend.base_values().get(());
// Side channel the external function writes the evaluated value into.
let result: egglog_bridge::SideChannel<Value> = Default::default();
let result_ref = result.clone();
let ext_id = self
.backend
.register_external_func(Box::new(make_external_func(move |_es, vals| {
debug_assert!(vals.len() == 1);
*result_ref.lock().unwrap() = Some(vals[0]);
Some(unit_val)
})));
let mut translator = BackendRule::new(
self.backend.new_rule("eval_resolved_expr", false),
&self.functions,
&self.type_info,
);
// Fresh variable that `(let <var> expr)` binds the expression's value to.
let result_var = ResolvedVar {
name: self.parser.symbol_gen.fresh("eval_resolved_expr"),
sort: expr.output_type(),
is_global_ref: false,
};
let actions = ResolvedActions::singleton(ResolvedAction::Let(
span.clone(),
result_var.clone(),
expr.clone(),
));
let mut binding = IndexSet::default();
let mut ctx = CoreActionContext::new(
&self.type_info,
&mut binding,
&mut self.parser.symbol_gen,
self.proof_state.original_typechecking.is_none(),
);
// NOTE(review): an early return from either `?` below skips the
// `free_external_func(ext_id)` further down, leaking the registered function.
let actions = actions.to_core_actions(&mut ctx)?.0;
translator.actions(&actions)?;
// Feed the bound variable into the external function so its value is captured.
let arg = translator.entry(&ResolvedAtomTerm::Var(span.clone(), result_var));
translator.rb.call_external_func(
ext_id,
&[arg],
egglog_bridge::ColumnTy::Base(unit_id),
|| "this function will never panic".to_string(),
);
let id = translator.build();
let rule_result = self.backend.run_rules(&[id]);
// Clean up the temporary rule and external function before reporting errors.
self.backend.free_rule(id);
self.backend.free_external_func(ext_id);
let _ = rule_result.map_err(|e| {
Error::BackendError(format!("Failed to evaluate expression '{}': {}", expr, e))
})?;
let result = result.lock().unwrap().unwrap();
Ok(result)
}
/// Declares a combined ruleset `name` that runs the listed member rulesets.
///
/// # Panics
/// If a ruleset called `name` already exists.
fn add_combined_ruleset(&mut self, name: String, rulesets: Vec<String>) {
    let Entry::Vacant(slot) = self.rulesets.entry(name.clone()) else {
        panic!("Ruleset '{name}' was already present");
    };
    slot.insert(Ruleset::Combined(rulesets));
}
/// Declares an empty ruleset `name`.
///
/// # Panics
/// If a ruleset called `name` already exists.
fn add_ruleset(&mut self, name: String) {
    let Entry::Vacant(slot) = self.rulesets.entry(name.clone()) else {
        panic!("Ruleset '{name}' was already present");
    };
    slot.insert(Ruleset::Rules(Default::default()));
}
fn check_facts(&mut self, span: &Span, facts: &[ResolvedFact]) -> Result<(), Error> {
let fresh_name = self.parser.symbol_gen.fresh("check_facts");
let fresh_ruleset = self.parser.symbol_gen.fresh("check_facts_ruleset");
let rule = ast::ResolvedRule {
span: span.clone(),
head: ResolvedActions::default(),
body: facts.to_vec(),
name: fresh_name.clone(),
ruleset: fresh_ruleset.clone(),
};
let core_rule = rule.to_canonicalized_core_rule(
&self.type_info,
&mut self.parser.symbol_gen,
self.proof_state.original_typechecking.is_none(),
)?;
let query = core_rule.body;
let ext_sc = egglog_bridge::SideChannel::default();
let ext_sc_ref = ext_sc.clone();
let ext_id = self
.backend
.register_external_func(Box::new(make_external_func(move |_, _| {
*ext_sc_ref.lock().unwrap() = Some(());
Some(Value::new_const(0))
})));
let mut translator = BackendRule::new(
self.backend.new_rule("check_facts", false),
&self.functions,
&self.type_info,
);
translator.query(&query, true);
translator
.rb
.call_external_func(ext_id, &[], egglog_bridge::ColumnTy::Id, || {
"this function will never panic".to_string()
});
let id = translator.build();
let _ = self.backend.run_rules(&[id]).unwrap();
self.backend.free_rule(id);
self.backend.free_external_func(ext_id);
let ext_sc_val = ext_sc.lock().unwrap().take();
let matched = matches!(ext_sc_val, Some(()));
if !matched {
Err(Error::CheckError(
facts.iter().map(|f| f.clone().make_unresolved()).collect(),
span.clone(),
))
} else {
Ok(())
}
}
fn run_command(&mut self, command: ResolvedNCommand) -> Result<Option<CommandOutput>, Error> {
match command {
ResolvedNCommand::Sort(_span, name, _presort_and_args) => {
log::info!("Declared sort {}.", name)
}
ResolvedNCommand::Function(fdecl) => {
self.declare_function(&fdecl)?;
log::info!("Declared {} {}.", fdecl.subtype, fdecl.name)
}
ResolvedNCommand::AddRuleset(_span, name) => {
self.add_ruleset(name.clone());
log::info!("Declared ruleset {name}.");
}
ResolvedNCommand::UnstableCombinedRuleset(_span, name, others) => {
self.add_combined_ruleset(name.clone(), others);
log::info!("Declared ruleset {name}.");
}
ResolvedNCommand::NormRule { rule } => {
let name = rule.name.clone();
self.add_rule(rule)?;
log::info!("Declared rule {name}.")
}
ResolvedNCommand::RunSchedule(sched) => {
let report = self.run_schedule(&sched)?;
log::info!("Ran schedule {}.", sched);
log::info!("Report: {}", report);
self.overall_run_report.union(report.clone());
return Ok(Some(CommandOutput::RunSchedule(report)));
}
ResolvedNCommand::PrintOverallStatistics(span, file) => match file {
None => {
log::info!("Printed overall statistics");
return Ok(Some(CommandOutput::OverallStatistics(
self.overall_run_report.clone(),
)));
}
Some(path) => {
let mut file = std::fs::File::create(&path)
.map_err(|e| Error::IoError(path.clone().into(), e, span.clone()))?;
log::info!("Printed overall statistics to json file {}", path);
serde_json::to_writer(&mut file, &self.overall_run_report)
.expect("error serializing to json");
}
},
ResolvedNCommand::Check(span, facts) => {
self.check_facts(&span, &facts)?;
log::info!("Checked fact {:?}.", facts);
}
ResolvedNCommand::CoreAction(action) => match &action {
ResolvedAction::Let(_, name, contents) => {
panic!("Globals should have been desugared away: {name} = {contents}")
}
_ => {
self.eval_actions(&ResolvedActions::new(vec![action.clone()]))?;
}
},
ResolvedNCommand::Extract(span, expr, variants) => {
let sort = expr.output_type();
let x = self.eval_resolved_expr(span.clone(), &expr)?;
let n = self.eval_resolved_expr(span, &variants)?;
let n: i64 = self.backend.base_values().unwrap(n);
let mut termdag = TermDag::default();
let extractor = Extractor::compute_costs_from_rootsorts(
Some(vec![sort]),
self,
TreeAdditiveCostModel::default(),
);
return if n == 0 {
if let Some((cost, term)) = extractor.extract_best(self, &mut termdag, x) {
if log_enabled!(Level::Info) {
log::info!("extracted with cost {cost}: {}", termdag.to_string(term));
}
Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
} else {
Err(Error::ExtractError(
"Unable to find any valid extraction (likely due to subsume or delete)"
.to_string(),
))
}
} else {
if n < 0 {
panic!("Cannot extract negative number of variants");
}
let terms: Vec<TermId> = extractor
.extract_variants(self, &mut termdag, x, n as usize)
.iter()
.map(|e| e.1)
.collect();
if log_enabled!(Level::Info) {
let expr_str = expr.to_string();
log::info!("extracted {} variants for {expr_str}", terms.len());
}
Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
};
}
ResolvedNCommand::Push(n) => {
(0..n).for_each(|_| self.push());
log::info!("Pushed {n} levels.")
}
ResolvedNCommand::Pop(span, n) => {
for _ in 0..n {
self.pop().map_err(|err| {
if let Error::Pop(_) = err {
Error::Pop(span.clone())
} else {
err
}
})?;
}
log::info!("Popped {n} levels.")
}
ResolvedNCommand::PrintFunction(span, f, n, file, mode) => {
let file = file
.map(|file| {
std::fs::File::create(&file)
.map_err(|e| Error::IoError(file.into(), e, span.clone()))
})
.transpose()?;
return self.print_function(&f, n, file, mode).map_err(|e| match e {
Error::TypeError(TypeError::UnboundFunction(f, _)) => {
Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
}
_ => e,
});
}
ResolvedNCommand::PrintSize(span, f) => {
let res = self.print_size(f.as_deref()).map_err(|e| match e {
Error::TypeError(TypeError::UnboundFunction(f, _)) => {
Error::TypeError(TypeError::UnboundFunction(f, span.clone()))
}
_ => e,
})?;
return Ok(Some(res));
}
ResolvedNCommand::Fail(span, c) => {
let result = self.run_command(*c);
if let Err(e) = result {
log::info!("Command failed as expected: {e}");
} else {
return Err(Error::ExpectFail(span));
}
}
ResolvedNCommand::Input {
span: _,
name,
file,
} => {
self.input_file(&name, file)?;
}
ResolvedNCommand::Output { span, file, exprs } => {
let mut filename = self.fact_directory.clone().unwrap_or_default();
filename.push(file.as_str());
let mut f = File::options()
.append(true)
.create(true)
.open(&filename)
.map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
let extractor = Extractor::compute_costs_from_rootsorts(
None,
self,
TreeAdditiveCostModel::default(),
);
let mut termdag: TermDag = Default::default();
use std::io::Write;
for expr in exprs {
let value = self.eval_resolved_expr(span.clone(), &expr)?;
let expr_type = expr.output_type();
let term = extractor
.extract_best_with_sort(self, &mut termdag, value, expr_type)
.unwrap()
.1;
writeln!(f, "{}", termdag.to_string(term))
.map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
}
log::info!("Output to '{filename:?}'.")
}
ResolvedNCommand::UserDefined(_span, name, exprs) => {
let command = self.commands.swap_remove(&name).unwrap_or_else(|| {
panic!("Unrecognized user-defined command: {}", name);
});
let res = command.update(self, &exprs);
self.commands.insert(name, command);
return res;
}
};
Ok(None)
}
/// Loads tab-separated rows from `file` into the table backing function `func_name`.
///
/// `file` is resolved relative to `self.fact_directory` when that is set. Each line is
/// split on `'\t'` and every field is trimmed. The expected columns are the function's
/// input sorts, followed by its output sort when the function subtype is `Custom`.
/// Rows of non-constructor functions are inserted into the table; for constructors each
/// row is passed to `TableAction::lookup` instead.
///
/// # Errors
/// Returns `Error::InputFileFormatError(file)` if a field does not parse as its
/// column's sort, or if a line has fewer or more fields than there are columns.
///
/// # Panics
/// Panics if `func_name` is not a known function, if an input sort is not
/// `i64`/`f64`/`String`, if a non-constructor output sort is not
/// `i64`/`String`/`Unit`, or if the file cannot be opened or read.
fn input_file(&mut self, func_name: &str, file: String) -> Result<(), Error> {
let function_type = self
.type_info
.get_func_type(func_name)
.unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
let func = self.functions.get_mut(func_name).unwrap();
let mut filename = self.fact_directory.clone().unwrap_or_default();
filename.push(file.as_str());
// Only primitive sorts that can be parsed from text are accepted as columns.
for t in &func.schema.input {
match t.name() {
"i64" | "f64" | "String" => {}
s => panic!("Unsupported type {} for input", s),
}
}
if function_type.subtype != FunctionSubtype::Constructor {
match func.schema.output.name() {
"i64" | "String" | "Unit" => {}
s => panic!("Unsupported type {} for input", s),
}
}
log::info!("Opening file '{:?}'...", filename);
let mut f = File::open(filename).unwrap();
let mut contents = String::new();
f.read_to_string(&mut contents).unwrap();
let mut parsed_contents: Vec<Vec<Value>> = Vec::with_capacity(contents.lines().count());
// Columns expected on each line: inputs, plus the output for `Custom` functions.
let mut row_schema = func.schema.input.clone();
if function_type.subtype == FunctionSubtype::Custom {
row_schema.push(func.schema.output.clone());
}
log::debug!("{:?}", row_schema);
let unit_val = self.backend.base_values().get(());
for line in contents.lines() {
let mut it = line.split('\t').map(|s| s.trim());
let mut row: Vec<Value> = Vec::with_capacity(row_schema.len());
// Parse at most one field per column; stop early if the line runs out of fields.
for sort in row_schema.iter() {
if let Some(raw) = it.next() {
let val = match sort.name() {
"i64" => {
if let Ok(i) = raw.parse::<i64>() {
self.backend.base_values().get(i)
} else {
return Err(Error::InputFileFormatError(file));
}
}
"f64" => {
if let Ok(f) = raw.parse::<f64>() {
self.backend
.base_values()
.get::<F>(core_relations::Boxed::new(f.into()))
} else {
return Err(Error::InputFileFormatError(file));
}
}
"String" => self.backend.base_values().get::<S>(raw.to_string().into()),
"Unit" => unit_val,
// Column sorts were validated above, so no other name can reach here.
_ => panic!("Unreachable"),
};
row.push(val);
} else {
break;
}
}
// NOTE(review): `"".split('\t')` still yields one empty field, so a blank line only
// produces an empty row (and is skipped) when `row_schema` itself is empty.
if row.is_empty() {
continue;
}
// Too few fields (row is short) or too many (fields left over) is a format error.
if row.len() != row_schema.len() || it.next().is_some() {
return Err(Error::InputFileFormatError(file));
}
parsed_contents.push(row);
}
log::debug!("Successfully loaded file.");
let num_facts = parsed_contents.len();
let mut table_action = egglog_bridge::TableAction::new(&self.backend, func.backend_id);
if function_type.subtype != FunctionSubtype::Constructor {
// Non-constructors: write each parsed row into the table.
self.backend.with_execution_state(|es| {
for row in parsed_contents.iter() {
table_action.insert(es, row.iter().copied());
}
Some(unit_val)
});
} else {
// Constructors: look up each row of arguments (presumably creating the term when it
// is absent — confirm against `TableAction::lookup`).
self.backend.with_execution_state(|es| {
for row in parsed_contents.iter() {
table_action.lookup(es, row);
}
Some(unit_val)
});
}
// NOTE(review): presumably applies the staged writes above to the tables — confirm.
self.backend.flush_updates();
log::info!("Read {num_facts} facts into {func_name} from '{file}'.");
Ok(())
}
/// Desugars, typechecks, and global-eliminates a single `command`.
///
/// Without proof term encoding, the command is typechecked against the live type info,
/// globals are removed, and every result is checked for shadowing.
///
/// With proof term encoding enabled, the command is first typechecked against the
/// original (pre-encoding) type info and validated as shadow-free and encodable. It is
/// then rewritten by the term encoding. Each generated command is desugared,
/// typechecked, and global-eliminated in turn.
///
/// # Errors
/// Propagates desugaring, typechecking, and shadowing errors. Returns
/// `Error::UnsupportedProofCommand` when proofs are enabled and a command cannot be
/// encoded.
fn resolve_command(&mut self, command: Command) -> Result<Vec<ResolvedNCommand>, Error> {
    let desugared = desugar_command(command, &mut self.parser)?;

    let Some(proof_typechecker) = self.proof_state.original_typechecking.as_mut() else {
        // Plain path: no proof term encoding.
        let checked = self.typecheck_program(&desugared)?;
        let resolved = remove_globals::remove_globals(checked, &mut self.parser.symbol_gen);
        for resolved_cmd in &resolved {
            self.names.check_shadowing(resolved_cmd)?;
        }
        return Ok(resolved);
    };

    // Proof path: validate against the original type info before encoding.
    let checked = proof_typechecker.typecheck_program(&desugared)?;
    let checked = proof_global_remover::remove_globals(checked, &mut self.parser.symbol_gen);
    for checked_cmd in &checked {
        self.names.check_shadowing(checked_cmd)?;
        let as_command = checked_cmd.to_command();
        if !command_supports_proof_encoding(&as_command) {
            return Err(Error::UnsupportedProofCommand {
                command: format!("{}", as_command),
            });
        }
    }

    // Re-resolve every command produced by the term encoding.
    let encoded_commands = TermState::add_term_encoding(self, checked);
    let mut resolved = vec![];
    for encoded in encoded_commands {
        let encoded_desugared = desugar_command(encoded, &mut self.parser)?;
        let encoded_checked = self.typecheck_program(&encoded_desugared)?;
        resolved.extend(remove_globals::remove_globals(
            encoded_checked,
            &mut self.parser.symbol_gen,
        ));
    }
    Ok(resolved)
}
fn process_program_internal(
&mut self,
program: Vec<Command>,
run_commands: bool,
) -> Result<(Vec<CommandOutput>, Vec<ResolvedCommand>), Error> {
let mut outputs = Vec::new();
let mut desugared_commands = Vec::new();
for before_expanded_command in program {
let macro_expanded = self.command_macros.apply(
before_expanded_command,
&mut self.parser.symbol_gen,
&self.type_info,
)?;
for command in macro_expanded {
if let Command::Include(span, file) = &command {
let s = std::fs::read_to_string(file)
.unwrap_or_else(|_| panic!("{span} Failed to read file {file}"));
let included_program = self
.parser
.get_program_from_string(Some(file.clone()), &s)?;
let (included_outputs, included_desugared) =
self.process_program_internal(included_program, run_commands)?;
outputs.extend(included_outputs);
desugared_commands.extend(included_desugared);
} else {
for processed in self.resolve_command(command)? {
desugared_commands.push(processed.to_command());
if run_commands
|| matches!(
processed,
ResolvedNCommand::Push(_) | ResolvedNCommand::Pop(_, _)
)
{
let result = self.run_command(processed)?;
if let Some(output) = result {
outputs.push(output);
}
}
}
}
}
}
Ok((outputs, desugared_commands))
}
/// Resolves and executes every command in `program`.
///
/// Returns the outputs of the commands that produced one, in program order.
pub fn run_program(&mut self, program: Vec<Command>) -> Result<Vec<CommandOutput>, Error> {
    self.process_program_internal(program, true)
        .map(|(outputs, _resolved)| outputs)
}

/// Parses `input` (labelled with `filename` for diagnostics) and returns its resolved,
/// desugared commands without executing them. `push`/`pop` are still executed.
pub fn desugar_program(
    &mut self,
    filename: Option<String>,
    input: &str,
) -> Result<Vec<ResolvedCommand>, Error> {
    let program = self.parser.get_program_from_string(filename, input)?;
    let (_outputs, resolved) = self.process_program_internal(program, false)?;
    Ok(resolved)
}

/// Parses `input` (labelled with `filename` for diagnostics) and runs it.
///
/// This is [`Self::run_program`] applied to the parsed program.
pub fn parse_and_run_program(
    &mut self,
    filename: Option<String>,
    input: &str,
) -> Result<Vec<CommandOutput>, Error> {
    match self.parser.get_program_from_string(filename, input) {
        Ok(program) => self.run_program(program),
        Err(err) => Err(err.into()),
    }
}
/// Returns the total number of rows across every function's backend table.
pub fn num_tuples(&self) -> usize {
    let mut total = 0;
    for function in self.functions.values() {
        total += self.backend.table_size(function.backend_id);
    }
    total
}
/// Returns the registered sort of concrete type `S`.
///
/// Delegates to `TypeInfo::get_sort`. NOTE(review): presumably panics when no (or more
/// than one) such sort is registered — confirm in `TypeInfo`.
pub fn get_sort<S: Sort>(&self) -> Arc<S> {
self.type_info.get_sort()
}
/// Returns the registered sort of concrete type `S` that satisfies predicate `f`.
///
/// Delegates to `TypeInfo::get_sort_by`.
pub fn get_sort_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Arc<S> {
self.type_info.get_sort_by(f)
}
/// Returns every registered sort of concrete type `S`.
pub fn get_sorts<S: Sort>(&self) -> Vec<Arc<S>> {
self.type_info.get_sorts()
}
/// Returns every registered sort of concrete type `S` that satisfies predicate `f`.
pub fn get_sorts_by<S: Sort>(&self, f: impl Fn(&Arc<S>) -> bool) -> Vec<Arc<S>> {
self.type_info.get_sorts_by(f)
}
/// Returns the type-erased sort that satisfies predicate `f`.
///
/// Delegates to `TypeInfo::get_arcsort_by`.
pub fn get_arcsort_by(&self, f: impl Fn(&ArcSort) -> bool) -> ArcSort {
self.type_info.get_arcsort_by(f)
}
/// Returns every type-erased sort that satisfies predicate `f`.
pub fn get_arcsorts_by(&self, f: impl Fn(&ArcSort) -> bool) -> Vec<ArcSort> {
self.type_info.get_arcsorts_by(f)
}
/// Looks up a sort by its declared name, returning `None` if it is unknown.
pub fn get_sort_by_name(&self, sym: &str) -> Option<&ArcSort> {
self.type_info.get_sort_by_name(sym)
}
/// Returns the run report stored in `self.overall_run_report`.
pub fn get_overall_run_report(&self) -> &RunReport {
&self.overall_run_report
}
/// Converts a backend `Value` into the base value of type `T` it represents.
///
/// NOTE(review): behavior when `x` is not a `T` is determined by the backend's
/// `base_values().unwrap` — likely a panic; confirm.
pub fn value_to_base<T: BaseValue>(&self, x: Value) -> T {
self.backend.base_values().unwrap::<T>(x)
}
/// Interns base value `x` in the backend and returns its `Value` handle.
pub fn base_to_value<T: BaseValue>(&self, x: T) -> Value {
self.backend.base_values().get::<T>(x)
}
/// Looks up the container of type `T` referenced by `x`, if any.
pub fn value_to_container<T: ContainerValue>(
&self,
x: Value,
) -> Option<impl Deref<Target = T>> {
self.backend.container_values().get_val::<T>(x)
}
/// Registers container `x` with the backend and returns its `Value` handle.
///
/// Registration happens inside `with_execution_state`, which supplies the execution
/// state that `register_val` requires.
pub fn container_to_value<T: ContainerValue>(&mut self, x: T) -> Value {
self.backend.with_execution_state(|state| {
self.backend.container_values().register_val::<T>(x, state)
})
}
/// Returns the number of rows in the table for function `func`.
///
/// # Panics
/// Panics if `func` is not a declared function.
pub fn get_size(&self, func: &str) -> usize {
let function_id = self.functions.get(func).unwrap().backend_id;
self.backend.table_size(function_id)
}
/// Looks up the output of function `name` for input row `key`.
///
/// Returns `None` if the backend has no entry for `key`.
///
/// # Panics
/// Panics if `name` is not a declared function.
pub fn lookup_function(&self, name: &str, key: &[Value]) -> Option<Value> {
let func = self.functions.get(name).unwrap().backend_id;
self.backend.lookup_id(func, key)
}
/// Returns the function named `name`, if declared.
pub fn get_function(&self, name: &str) -> Option<&Function> {
self.functions.get(name)
}
/// Sets the backend's report level.
pub fn set_report_level(&mut self, level: ReportLevel) {
self.backend.set_report_level(level);
}
/// Asks the backend to dump its debug information.
pub fn dump_debug_info(&self) {
self.backend.dump_debug_info();
}
/// Returns the canonical representative of `val`, interpreted using the column type
/// of `sort`.
pub fn get_canonical_value(&self, val: Value, sort: &ArcSort) -> Value {
self.backend
.get_canon_repr(val, sort.column_ty(&self.backend))
}
}
/// Translates resolved queries and core actions into a backend rule through an
/// `egglog_bridge::RuleBuilder`.
struct BackendRule<'a> {
/// The backend rule builder being populated.
rb: egglog_bridge::RuleBuilder<'a>,
/// Cache of query entries already created for atom terms, so each term maps to a
/// single entry.
entries: HashMap<core::ResolvedAtomTerm, QueryEntry>,
/// Declared functions, used to find a function's backend table id by name.
functions: &'a IndexMap<String, Function>,
/// Type information, used to resolve the target of `unstable-fn`.
type_info: &'a TypeInfo,
}
impl<'a> BackendRule<'a> {
fn new(
rb: egglog_bridge::RuleBuilder<'a>,
functions: &'a IndexMap<String, Function>,
type_info: &'a TypeInfo,
) -> BackendRule<'a> {
BackendRule {
rb,
functions,
type_info,
entries: Default::default(),
}
}
fn entry(&mut self, x: &core::ResolvedAtomTerm) -> QueryEntry {
self.entries
.entry(x.clone())
.or_insert_with(|| match x {
core::GenericAtomTerm::Var(_, v) => self
.rb
.new_var_named(v.sort.column_ty(self.rb.egraph()), &v.name),
core::GenericAtomTerm::Literal(_, l) => literal_to_entry(self.rb.egraph(), l),
core::GenericAtomTerm::Global(..) => {
panic!("Globals should have been desugared")
}
})
.clone()
}
fn func(&self, f: &typechecking::FuncType) -> egglog_bridge::FunctionId {
self.functions[&f.name].backend_id
}
fn prim(
&mut self,
prim: &core::SpecializedPrimitive,
args: &[core::ResolvedAtomTerm],
) -> (ExternalFunctionId, Vec<QueryEntry>, ColumnTy) {
let mut qe_args = self.args(args);
if prim.name() == "unstable-fn" {
let core::ResolvedAtomTerm::Literal(_, Literal::String(ref name)) = args[0] else {
panic!("expected string literal after `unstable-fn`")
};
let id = if let Some(f) = self.type_info.get_func_type(name) {
ResolvedFunctionId::Lookup(egglog_bridge::TableAction::new(
self.rb.egraph(),
self.func(f),
))
} else if let Some(possible) = self.type_info.get_prims(name) {
let mut ps: Vec<_> = possible.iter().collect();
ps.retain(|p| {
self.type_info
.get_sorts::<FunctionSort>()
.into_iter()
.any(|f| {
let types: Vec<_> = prim
.input()
.iter()
.skip(1)
.chain(f.inputs())
.chain([&f.output()])
.cloned()
.collect();
p.accept(&types, self.type_info)
})
});
assert!(ps.len() == 1, "options for {name}: {ps:?}");
ResolvedFunctionId::Prim(ps.into_iter().next().unwrap().1)
} else {
panic!("no callable for {name}");
};
let partial_arcsorts = prim.input().iter().skip(1).cloned().collect();
qe_args[0] = self.rb.egraph().base_value_constant(ResolvedFunction {
id,
partial_arcsorts,
name: name.clone(),
});
}
(
prim.external_id(),
qe_args,
prim.output().column_ty(self.rb.egraph()),
)
}
fn args<'b>(
&mut self,
args: impl IntoIterator<Item = &'b core::ResolvedAtomTerm>,
) -> Vec<QueryEntry> {
args.into_iter().map(|x| self.entry(x)).collect()
}
fn query(&mut self, query: &core::Query<ResolvedCall, ResolvedVar>, include_subsumed: bool) {
for atom in &query.atoms {
match &atom.head {
ResolvedCall::Func(f) => {
let f = self.func(f);
let args = self.args(&atom.args);
let is_subsumed = match include_subsumed {
true => None,
false => Some(false),
};
self.rb.query_table(f, &args, is_subsumed).unwrap();
}
ResolvedCall::Primitive(p) => {
let (p, args, ty) = self.prim(p, &atom.args);
self.rb.query_prim(p, &args, ty).unwrap()
}
}
}
}
fn actions(&mut self, actions: &core::ResolvedCoreActions) -> Result<(), Error> {
for action in &actions.0 {
match action {
core::GenericCoreAction::Let(span, v, f, args) => {
let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
let y = match f {
ResolvedCall::Func(f) => {
let name = f.name.clone();
let f = self.func(f);
let args = self.args(args);
let span = span.clone();
self.rb.lookup(f, &args, move || {
format!("{span}: lookup of function {name} failed")
})
}
ResolvedCall::Primitive(p) => {
let name = p.name().to_owned();
let (p, args, ty) = self.prim(p, args);
let span = span.clone();
self.rb.call_external_func(p, &args, ty, move || {
format!("{span}: call of primitive {name} failed")
})
}
};
self.entries.insert(v, y.into());
}
core::GenericCoreAction::LetAtomTerm(span, v, x) => {
let v = core::GenericAtomTerm::Var(span.clone(), v.clone());
let x = self.entry(x);
self.entries.insert(v, x);
}
core::GenericCoreAction::Set(_, f, xs, y) => match f {
ResolvedCall::Primitive(..) => panic!("runtime primitive set!"),
ResolvedCall::Func(f) => {
let f = self.func(f);
let args = self.args(xs.iter().chain([y]));
self.rb.set(f, &args)
}
},
core::GenericCoreAction::Change(span, change, f, args) => match f {
ResolvedCall::Primitive(..) => panic!("runtime primitive change!"),
ResolvedCall::Func(f) => {
let name = f.name.clone();
let can_subsume = self.functions[&f.name].can_subsume;
let f = self.func(f);
let args = self.args(args);
match change {
Change::Delete => self.rb.remove(f, &args),
Change::Subsume if can_subsume => self.rb.subsume(f, &args),
Change::Subsume => {
return Err(Error::SubsumeMergeError(name, span.clone()));
}
}
}
},
core::GenericCoreAction::Union(_, x, y) => {
let x = self.entry(x);
let y = self.entry(y);
self.rb.union(x, y)
}
core::GenericCoreAction::Panic(_, message) => self.rb.panic(message.clone()),
}
}
Ok(())
}
fn build(self) -> egglog_bridge::RuleId {
self.rb.build()
}
}
/// Converts literal `l` into a constant query entry for `egraph`, choosing the base
/// value type that matches the literal's variant.
fn literal_to_entry(egraph: &egglog_bridge::EGraph, l: &Literal) -> QueryEntry {
    match l {
        Literal::Unit => egraph.base_value_constant::<()>(()),
        Literal::Bool(flag) => egraph.base_value_constant::<bool>(*flag),
        Literal::Int(int) => egraph.base_value_constant::<i64>(*int),
        Literal::Float(float) => egraph.base_value_constant::<sort::F>(float.into()),
        Literal::String(text) => {
            egraph.base_value_constant::<sort::S>(sort::S::new(text.clone()))
        }
    }
}
/// Interns literal `l` in `egraph`'s base-value store and returns its `Value`.
fn literal_to_value(egraph: &egglog_bridge::EGraph, l: &Literal) -> Value {
    let store = egraph.base_values();
    match l {
        Literal::Unit => store.get::<()>(()),
        Literal::Bool(flag) => store.get::<bool>(*flag),
        Literal::Int(int) => store.get::<i64>(*int),
        Literal::Float(float) => store.get::<sort::F>(float.into()),
        Literal::String(text) => store.get::<sort::S>(sort::S::new(text.clone())),
    }
}
/// Errors reported while parsing, typechecking, or running an egglog program.
#[derive(Debug, Error)]
pub enum Error {
/// A parse error, forwarded unchanged.
#[error(transparent)]
ParseError(#[from] ParseError),
/// A not-found error, forwarded unchanged.
#[error(transparent)]
NotFoundError(#[from] NotFoundError),
/// A single type error, forwarded unchanged.
#[error(transparent)]
TypeError(#[from] TypeError),
/// Several type errors, printed one per line.
#[error("Errors:\n{}", ListDisplay(.0, "\n"))]
TypeErrors(Vec<TypeError>),
/// A `check` command failed; holds the facts that were checked and the command span.
#[error("{}\nCheck failed: \n{}", .1, ListDisplay(.0, "\n"))]
CheckError(Vec<Fact>, Span),
/// A ruleset name was used that has not been declared.
#[error("{1}\nNo such ruleset: {0}")]
NoSuchRuleset(String, Span),
/// A rule was added to a combined ruleset, which may only contain other rulesets.
#[error(
"{1}\nAttempted to add a rule to combined ruleset {0}. Combined rulesets may only depend on other rulesets."
)]
CombinedRulesetError(String, Span),
/// An error message produced by the backend.
#[error("{0}")]
BackendError(String),
/// A `pop` was issued with nothing left to pop.
#[error("{0}\nTried to pop too much")]
Pop(Span),
/// A command that was expected to fail succeeded.
#[error("{0}\nCommand should have failed.")]
ExpectFail(Span),
/// An I/O failure on the given path, with the span of the command that caused it.
#[error("{2}\nIO error: {0}: {1}")]
IoError(PathBuf, std::io::Error, Span),
/// A `subsume` targeted a function that does not allow subsumption.
#[error("{1}\nCannot subsume function with merge: {0}")]
SubsumeMergeError(String, Span),
/// Extraction failed; the payload is printed with `Debug` formatting.
#[error("extraction failure: {:?}", .0)]
ExtractError(String),
/// A name was redefined; holds the name and the spans of both definitions.
#[error("{1}\n{2}\nShadowing is not allowed, but found {0}")]
Shadowing(String, Span, Span),
/// A user-defined command with this name is already registered.
#[error("{1}\nCommand already exists: {0}")]
CommandAlreadyExists(String, Span),
/// An `input` file line had a malformed field or the wrong number of fields.
#[error("Incorrect format in file '{0}'.")]
InputFileFormatError(String),
/// A command cannot be represented by the proof term encoding; `command` is its text.
#[error(
"Command is not supported by the current proof term encoding implementation.\n\
This typically means the command uses constructs that cannot yet be represented as proof terms.\n\
Consider disabling proof term encoding for this run or rewriting the command to avoid unsupported features.\n\
Offending command: {command}"
)]
UnsupportedProofCommand { command: String },
}
// Unit tests for the public `EGraph` API: user-defined primitives, Send/Sync, and
// extraction in the presence of subsumed terms after rebuilding.
#[cfg(test)]
mod tests {
use crate::constraint::SimpleTypeConstraint;
use crate::sort::*;
use crate::*;
// Test primitive `(inner-product v1 v2)`: the dot product of two i64 vectors.
#[derive(Clone)]
struct InnerProduct {
// The vector sort both arguments must have.
vec: ArcSort,
}
impl Primitive for InnerProduct {
fn name(&self) -> &str {
"inner-product"
}
// Signature: (vec, vec) -> i64.
fn get_type_constraints(&self, span: &Span) -> Box<dyn crate::constraint::TypeConstraint> {
SimpleTypeConstraint::new(
self.name(),
vec![self.vec.clone(), self.vec.clone(), I64Sort.to_arcsort()],
span.clone(),
)
.into_box()
}
// Sums the pairwise products of the two vectors' elements; panics on length mismatch.
fn apply(&self, exec_state: &mut ExecutionState<'_>, args: &[Value]) -> Option<Value> {
let mut sum = 0;
let vec1 = exec_state
.container_values()
.get_val::<VecContainer>(args[0])
.unwrap();
let vec2 = exec_state
.container_values()
.get_val::<VecContainer>(args[1])
.unwrap();
assert_eq!(vec1.data.len(), vec2.data.len());
for (a, b) in vec1.data.iter().zip(vec2.data.iter()) {
let a = exec_state.base_values().unwrap::<i64>(*a);
let b = exec_state.base_values().unwrap::<i64>(*b);
sum += a * b;
}
Some(exec_state.base_values().get::<i64>(sum))
}
}
// A Rust-defined primitive registered via `add_primitive` is callable from egglog code.
#[test]
fn test_user_defined_primitive() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(None, "(sort IntVec (Vec i64))")
.unwrap();
// Find the `Vec i64` sort just declared, to parameterize the primitive.
let int_vec_sort = egraph.get_arcsort_by(|s| {
s.value_type() == Some(std::any::TypeId::of::<VecContainer>())
&& s.inner_sorts()[0].name() == I64Sort.name()
});
egraph.add_primitive(InnerProduct { vec: int_vec_sort });
egraph
.parse_and_run_program(
None,
"
(let a (vec-of 1 2 3 4 5 6))
(let b (vec-of 6 5 4 3 2 1))
(check (= (inner-product a b) 56))
",
)
.unwrap();
}
// Compile-time check that `EGraph` is both `Send` and `Sync`.
#[test]
fn test_egraph_send_sync() {
fn is_send<T: Send>(_t: &T) -> bool {
true
}
fn is_sync<T: Sync>(_t: &T) -> bool {
true
}
let egraph = EGraph::default();
assert!(is_send(&egraph) && is_sync(&egraph));
}
// Returns a clone of the function named `name`; panics if it does not exist.
fn get_function(egraph: &EGraph, name: &str) -> Function {
egraph.functions.get(name).unwrap().clone()
}
// Returns `vals[0]` of the last row visited in `name`'s table; panics if the table is
// empty. Intended for nullary constructors, whose tables have a single row.
fn get_value(egraph: &EGraph, name: &str) -> Value {
let mut out = None;
let id = get_function(egraph, name).backend_id;
egraph.backend.for_each(id, |row| out = Some(row.vals[0]));
out.unwrap()
}
// A subsumed term whose argument changes id during rebuild must stay non-extractable.
#[test]
fn test_subsumed_unextractable_rebuild_arg() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor container (Math) Math)
(constructor exp () Math :cost 100)
(constructor cheap () Math)
(constructor cheap-1 () Math)
; we make the container cheap so that it will be extracted if possible, but then we mark it as subsumed
; so the (exp) expr should be extracted instead
(let res (container (cheap)))
(union res (exp))
(cheap)
(cheap-1)
(subsume (container (cheap)))
"#,
).unwrap();
let orig_cheap_value = get_value(&egraph, "cheap");
let orig_cheap_1_value = get_value(&egraph, "cheap-1");
assert_ne!(orig_cheap_value, orig_cheap_1_value);
// Merge the two classes so at least one of them gets a new canonical id.
egraph
.parse_and_run_program(
None,
r#"
(union (cheap-1) (cheap))
"#,
)
.unwrap();
let new_cheap_value = get_value(&egraph, "cheap");
let new_cheap_1_value = get_value(&egraph, "cheap-1");
assert_eq!(new_cheap_value, new_cheap_1_value);
assert!(new_cheap_value != orig_cheap_value || new_cheap_1_value != orig_cheap_1_value);
// The rebuilt `container` row must still be subsumed, so the costly `(exp)` wins.
let outputs = egraph
.parse_and_run_program(
None,
r#"
(extract res)
"#,
)
.unwrap();
assert_eq!(outputs[0].to_string(), "(exp)\n");
}
// A subsumed term whose own id changes during rebuild must stay non-extractable.
#[test]
fn test_subsumed_unextractable_rebuild_self() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor container (Math) Math)
(constructor exp () Math :cost 100)
(constructor cheap () Math)
(exp)
(let x (cheap))
(subsume (cheap))
"#,
)
.unwrap();
let orig_cheap_value = get_value(&egraph, "cheap");
// Union with `(exp)` so the `cheap` row is rewritten to a new canonical id.
egraph
.parse_and_run_program(
None,
r#"
(union (exp) x)
"#,
)
.unwrap();
let new_cheap_value = get_value(&egraph, "cheap");
assert_ne!(new_cheap_value, orig_cheap_value);
// The rewritten `cheap` row must remain subsumed, so `(exp)` is extracted.
let res = egraph
.parse_and_run_program(
None,
r#"
(extract x)
"#,
)
.unwrap();
assert_eq!(res[0].to_string(), "(exp)\n");
}
}