use std::hash::Hasher;
use std::marker::PhantomData;
use std::ops::AddAssign;
use crate::{constraint::grounded_check, *};
use egglog_ast::generic_ast::{Change, GenericAction, GenericActions, GenericExpr};
use egglog_ast::span::Span;
use egglog_ast::util::ListDisplay;
use typechecking::{FuncType, PrimitiveWithId, TypeError};
/// The head of a query atom: either a regular head or the built-in equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum HeadOrEq<Head> {
    /// A regular atom head.
    Head(Head),
    /// An equality constraint between the atom's arguments.
    Eq,
}

/// [`HeadOrEq`] specialised to unresolved, string-named heads.
pub(crate) type StringOrEq = HeadOrEq<String>;

impl From<String> for StringOrEq {
    /// Wraps a head name as a regular (non-equality) head.
    fn from(name: String) -> Self {
        Self::Head(name)
    }
}

impl<Head> HeadOrEq<Head> {
    /// Returns `true` exactly for the [`HeadOrEq::Eq`] variant.
    pub fn is_eq(&self) -> bool {
        match self {
            Self::Eq => true,
            Self::Head(_) => false,
        }
    }
}
/// A primitive together with the concrete input and output sorts it is used at.
///
/// Equality and hashing look only at the primitive's external id; the
/// specialised sorts do not participate.
#[derive(Debug, Clone)]
pub struct SpecializedPrimitive {
    primitive: PrimitiveWithId,
    input: Vec<ArcSort>,
    output: ArcSort,
}

impl SpecializedPrimitive {
    /// Name of the underlying primitive.
    pub fn name(&self) -> &str {
        let prim = &self.primitive.0;
        prim.name()
    }

    /// Output sort this primitive was specialised to.
    pub fn output(&self) -> &ArcSort {
        &self.output
    }

    /// Input sorts this primitive was specialised to.
    pub fn input(&self) -> &[ArcSort] {
        self.input.as_slice()
    }

    /// External id of the underlying primitive (the identity used by `Eq`/`Hash`).
    pub(crate) fn external_id(&self) -> ExternalFunctionId {
        self.primitive.1
    }
}

impl PartialEq for SpecializedPrimitive {
    fn eq(&self, other: &Self) -> bool {
        self.external_id() == other.external_id()
    }
}

impl Eq for SpecializedPrimitive {}

impl Hash for SpecializedPrimitive {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Must agree with `PartialEq`: hash the external id only.
        self.external_id().hash(state);
    }
}
/// A call head after type-based resolution: either a declared function
/// (see [`FuncType`]) or a [`SpecializedPrimitive`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ResolvedCall {
    Func(FuncType),
    Primitive(SpecializedPrimitive),
}

impl ResolvedCall {
    /// Name of the called function or primitive.
    pub fn name(&self) -> &str {
        match self {
            ResolvedCall::Func(func) => &func.name,
            ResolvedCall::Primitive(prim) => prim.name(),
        }
    }

    /// Output sort of the call.
    pub fn output(&self) -> &ArcSort {
        match self {
            ResolvedCall::Func(func) => &func.output,
            ResolvedCall::Primitive(prim) => prim.output(),
        }
    }

    /// Column sorts of this call.
    ///
    /// For functions this is the input sorts followed by the output sort.
    /// NOTE(review): for primitives only the input sorts are returned (no
    /// output sort), unlike functions. Confirm this asymmetry is intended.
    pub(crate) fn view_types(&self) -> Vec<ArcSort> {
        match self {
            ResolvedCall::Func(func) => {
                let mut types = func.input.clone();
                types.push(func.output.clone());
                types
            }
            ResolvedCall::Primitive(prim) => prim.input().to_vec(),
        }
    }

    /// Resolves `head` to a declared function whose *input* sorts match `types` by name.
    ///
    /// Unlike [`ResolvedCall::from_resolution`], `types` excludes the output sort
    /// and primitives are not considered. Returns `None` if there is no function
    /// named `head` or if its input sorts differ.
    pub fn from_resolution_func_types(
        head: &str,
        types: &[ArcSort],
        typeinfo: &TypeInfo,
    ) -> Option<ResolvedCall> {
        let ty = typeinfo.get_func_type(head)?;
        let expected = ty.input.iter().map(|s| s.name());
        let actual = types.iter().map(|s| s.name());
        expected.eq(actual).then(|| ResolvedCall::Func(ty.clone()))
    }

    /// Resolves `head` against both declared functions and primitives.
    ///
    /// `types` lists the input sorts followed by the output sort. A function
    /// matches when its `input ++ [output]` sort names equal `types`. A primitive
    /// matches when its `accept` check succeeds.
    ///
    /// # Panics
    /// Panics if no candidate matches, or if more than one does.
    pub fn from_resolution(head: &str, types: &[ArcSort], typeinfo: &TypeInfo) -> ResolvedCall {
        let mut candidates = Vec::with_capacity(1);
        if let Some(ty) = typeinfo.get_func_type(head) {
            let expected = ty.input.iter().chain(once(&ty.output)).map(|s| s.name());
            let actual = types.iter().map(|s| s.name());
            if expected.eq(actual) {
                candidates.push(ResolvedCall::Func(ty.clone()));
            }
        }
        if let Some(primitives) = typeinfo.get_prims(head) {
            for primitive in primitives {
                if primitive.accept(types, typeinfo) {
                    let (out, inp) = types
                        .split_last()
                        .expect("a primitive accepted an empty sort list; the output sort is missing");
                    candidates.push(ResolvedCall::Primitive(SpecializedPrimitive {
                        primitive: primitive.clone(),
                        input: inp.to_vec(),
                        output: out.clone(),
                    }));
                }
            }
        }
        // Report "no match" and "several matches" separately. The old single
        // assertion reported "Ambiguous" in both cases, which misled debugging.
        match candidates.len() {
            1 => candidates.pop().expect("exactly one candidate"),
            0 => panic!(
                "No resolution for {head:?} with sorts {:?}",
                types.iter().map(|s| s.name()).collect::<Vec<_>>()
            ),
            n => panic!("Ambiguous resolution for {head:?} ({n} candidates)"),
        }
    }
}

impl Display for ResolvedCall {
    /// Displays the call as its bare name.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
/// Lets lowering ask whether a call head denotes a constructor.
pub trait IsFunc {
    /// Returns `true` if this head is a constructor according to `type_info`.
    fn is_constructor(&self, type_info: &TypeInfo) -> bool;
}

impl IsFunc for ResolvedCall {
    fn is_constructor(&self, type_info: &TypeInfo) -> bool {
        // Only resolved functions can be constructors; primitives never are.
        if let ResolvedCall::Func(func) = self {
            type_info.is_constructor(&func.name)
        } else {
            false
        }
    }
}

impl IsFunc for String {
    fn is_constructor(&self, type_info: &TypeInfo) -> bool {
        // An unresolved head is looked up by name.
        let name = self;
        type_info.is_constructor(name)
    }
}
/// A leaf argument of an atom: a variable, a literal, or a global.
///
/// Each variant carries the [`Span`] it came from.
#[derive(Debug, Clone)]
pub enum GenericAtomTerm<Leaf> {
    Var(Span, Leaf),
    Literal(Span, Literal),
    Global(Span, Leaf),
}

// Equality and hashing ignore spans. A `Var` and a `Global` with the same leaf
// are unequal, although they hash identically.
impl<Leaf> PartialEq for GenericAtomTerm<Leaf>
where
    Leaf: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Var(_, lhs), Self::Var(_, rhs))
            | (Self::Global(_, lhs), Self::Global(_, rhs)) => lhs == rhs,
            (Self::Literal(_, lhs), Self::Literal(_, rhs)) => lhs == rhs,
            _ => false,
        }
    }
}

impl<Leaf> Eq for GenericAtomTerm<Leaf> where Leaf: Eq {}

impl<Leaf> Hash for GenericAtomTerm<Leaf>
where
    Leaf: Hash,
{
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Self::Var(_, leaf) | Self::Global(_, leaf) => leaf.hash(state),
            Self::Literal(_, lit) => lit.hash(state),
        }
    }
}

/// Atom term over unresolved, string-named variables.
pub type AtomTerm = GenericAtomTerm<String>;
/// Atom term over type-resolved variables.
pub type ResolvedAtomTerm = GenericAtomTerm<ResolvedVar>;

impl<Leaf> GenericAtomTerm<Leaf> {
    /// Source span of this term.
    pub fn span(&self) -> &Span {
        match self {
            Self::Var(span, _) | Self::Literal(span, _) | Self::Global(span, _) => span,
        }
    }
}

impl<Leaf: Clone> GenericAtomTerm<Leaf> {
    /// Converts this term back into an expression.
    ///
    /// Both variables and globals become `GenericExpr::Var`.
    pub fn to_expr<Head>(&self) -> GenericExpr<Head, Leaf> {
        match self {
            Self::Var(span, leaf) | Self::Global(span, leaf) => {
                GenericExpr::Var(span.clone(), leaf.clone())
            }
            Self::Literal(span, lit) => GenericExpr::Lit(span.clone(), lit.clone()),
        }
    }
}

impl ResolvedAtomTerm {
    /// Sort of the value this term denotes.
    pub fn output(&self) -> ArcSort {
        match self {
            ResolvedAtomTerm::Var(_, var) | ResolvedAtomTerm::Global(_, var) => var.sort.clone(),
            ResolvedAtomTerm::Literal(_, lit) => literal_sort(lit),
        }
    }
}

impl std::fmt::Display for AtomTerm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Var(_, name) | Self::Global(_, name) => write!(f, "{name}"),
            Self::Literal(_, lit) => write!(f, "{lit}"),
        }
    }
}
/// An atom `head(args...)` in a query body.
///
/// By the convention used in this file, the last argument is the atom's output.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GenericAtom<Head, Leaf> {
    pub span: Span,
    pub head: Head,
    pub args: Vec<GenericAtomTerm<Leaf>>,
}

/// Atom over string-named variables.
pub type Atom<T> = GenericAtom<T, String>;

impl<T: std::fmt::Display> std::fmt::Display for Atom<T> {
    /// Writes `(head arg1 arg2 ...) ` (note the trailing space).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let args = ListDisplay(&self.args, " ");
        write!(f, "({} {}) ", self.head, args)
    }
}

impl<Head, Leaf> GenericAtom<Head, Leaf>
where
    Leaf: Clone + Eq + Hash,
    Head: Clone,
{
    /// Iterates over the variable arguments, in order and with repetitions.
    pub fn vars(&self) -> impl Iterator<Item = Leaf> + '_ {
        self.args.iter().filter_map(|term| {
            if let GenericAtomTerm::Var(_, leaf) = term {
                Some(leaf.clone())
            } else {
                None
            }
        })
    }

    /// Replaces each variable argument that has an entry in `subst` with the mapped term.
    fn subst(&mut self, subst: &HashMap<Leaf, GenericAtomTerm<Leaf>>) {
        for arg in &mut self.args {
            let replacement = match arg {
                GenericAtomTerm::Var(_, v) => subst.get(v).cloned(),
                GenericAtomTerm::Literal(..) | GenericAtomTerm::Global(..) => None,
            };
            if let Some(term) = replacement {
                *arg = term;
            }
        }
    }
}

impl Atom<String> {
    /// Rebuilds the call expression `head(inputs...)`, dropping the final (output) argument.
    ///
    /// Panics if the atom has no arguments.
    pub(crate) fn to_expr(&self) -> Expr {
        let inputs = &self.args[..self.args.len() - 1];
        let children: Vec<Expr> = inputs.iter().map(|arg| arg.to_expr()).collect();
        Expr::Call(self.span.clone(), self.head.clone(), children)
    }
}
#[derive(Debug, Clone)]
pub struct Query<Head, Leaf> {
pub atoms: Vec<GenericAtom<Head, Leaf>>,
}
impl<Head, Leaf> Default for Query<Head, Leaf> {
fn default() -> Self {
Self {
atoms: Default::default(),
}
}
}
impl Query<StringOrEq, String> {
pub fn get_constraints(
&self,
type_info: &TypeInfo,
) -> Result<Vec<Box<dyn Constraint<AtomTerm, ArcSort>>>, TypeError> {
let mut constraints = vec![];
for atom in self.atoms.iter() {
constraints.extend(atom.get_constraints(type_info)?.into_iter());
}
Ok(constraints)
}
pub(crate) fn atom_terms(&self) -> HashSet<AtomTerm> {
self.atoms
.iter()
.flat_map(|atom| atom.args.iter().cloned())
.collect()
}
}
impl<Head, Leaf> Query<Head, Leaf>
where
Leaf: Eq + Clone + Hash,
Head: Clone,
{
pub(crate) fn get_vars(&self) -> IndexSet<Leaf> {
self.atoms
.iter()
.flat_map(|atom| atom.vars())
.collect::<IndexSet<_>>()
}
}
impl<Head, Leaf> AddAssign for Query<Head, Leaf> {
fn add_assign(&mut self, rhs: Self) {
self.atoms.extend(rhs.atoms);
}
}
impl std::fmt::Display for Query<String, String> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for atom in &self.atoms {
writeln!(f, "{atom}")?;
}
Ok(())
}
}
impl std::fmt::Display for Query<ResolvedCall, String> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for atom in self.funcs() {
writeln!(f, "{atom}")?;
}
let filters: Vec<_> = self.filters().collect();
if !filters.is_empty() {
writeln!(f, "where ")?;
for filter in &filters {
writeln!(
f,
"({} {})",
filter.head.primitive.0.name(),
ListDisplay(&filter.args, " ")
)?;
}
}
Ok(())
}
}
impl<Leaf: Clone> Query<ResolvedCall, Leaf> {
pub fn filters(&self) -> impl Iterator<Item = GenericAtom<SpecializedPrimitive, Leaf>> + '_ {
self.atoms.iter().filter_map(|atom| match &atom.head {
ResolvedCall::Func(_) => None,
ResolvedCall::Primitive(head) => Some(GenericAtom {
span: atom.span.clone(),
head: head.clone(),
args: atom.args.clone(),
}),
})
}
pub fn funcs(&self) -> impl Iterator<Item = GenericAtom<String, Leaf>> + '_ {
self.atoms.iter().filter_map(|atom| match &atom.head {
ResolvedCall::Func(head) => Some(GenericAtom {
span: atom.span.clone(),
head: head.name.clone(),
args: atom.args.clone(),
}),
ResolvedCall::Primitive(_) => None,
})
}
}
/// A single flattened action whose arguments are all atom terms.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum GenericCoreAction<Head, Leaf> {
    /// Binds the variable to the result of applying the head to the arguments.
    Let(Span, Leaf, Head, Vec<GenericAtomTerm<Leaf>>),
    /// Binds the variable to an existing atom term.
    LetAtomTerm(Span, Leaf, GenericAtomTerm<Leaf>),
    /// Sets `head(args...)` to the given value.
    Set(
        Span,
        Head,
        Vec<GenericAtomTerm<Leaf>>,
        GenericAtomTerm<Leaf>,
    ),
    /// Applies a [`Change`] to `head(args...)`.
    Change(Span, Change, Head, Vec<GenericAtomTerm<Leaf>>),
    /// Unions two terms.
    Union(Span, GenericAtomTerm<Leaf>, GenericAtomTerm<Leaf>),
    /// Panics with the given message.
    Panic(Span, String),
}

/// Core action over string heads and variables.
pub type CoreAction = GenericCoreAction<String, String>;

/// An ordered list of core actions.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GenericCoreActions<Head, Leaf>(pub(crate) Vec<GenericCoreAction<Head, Leaf>>);

pub(crate) type ResolvedCoreActions = GenericCoreActions<ResolvedCall, ResolvedVar>;

impl<Head, Leaf> Default for GenericCoreActions<Head, Leaf> {
    fn default() -> Self {
        Self(Vec::new())
    }
}

impl<Head, Leaf> GenericCoreActions<Head, Leaf>
where
    Leaf: Clone + Eq + Hash,
{
    /// Prepends a `LetAtomTerm` binding for every `(var, term)` in `subst`.
    ///
    /// Existing actions are kept unchanged, so their uses of `var` now refer to
    /// the bound term. The bindings appear in `subst`'s iteration order.
    pub(crate) fn subst(&mut self, subst: &HashMap<Leaf, GenericAtomTerm<Leaf>>) {
        let mut rebuilt: Vec<_> = subst
            .iter()
            .map(|(var, term)| {
                GenericCoreAction::LetAtomTerm(term.span().clone(), var.clone(), term.clone())
            })
            .collect();
        rebuilt.append(&mut self.0);
        self.0 = rebuilt;
    }

    fn new(actions: Vec<GenericCoreAction<Head, Leaf>>) -> GenericCoreActions<Head, Leaf> {
        GenericCoreActions(actions)
    }

    /// Variables used by these actions before any action in the list binds them.
    ///
    /// Computed by a backward scan: uses are added, then the variable bound by a
    /// `Let`/`LetAtomTerm` is removed.
    pub(crate) fn get_free_vars(&self) -> HashSet<Leaf> {
        // Records `term` as used when it is a variable.
        fn note_use<Leaf: Clone + Eq + Hash>(
            free: &mut HashSet<Leaf>,
            term: &GenericAtomTerm<Leaf>,
        ) {
            if let GenericAtomTerm::Var(_, v) = term {
                free.insert(v.clone());
            }
        }

        let mut free = HashSet::default();
        for action in self.0.iter().rev() {
            match action {
                GenericCoreAction::Let(_, bound, _, args) => {
                    for term in args {
                        note_use(&mut free, term);
                    }
                    free.remove(bound);
                }
                GenericCoreAction::LetAtomTerm(_, bound, term) => {
                    note_use(&mut free, term);
                    free.remove(bound);
                }
                GenericCoreAction::Set(_, _, args, value) => {
                    for term in args {
                        note_use(&mut free, term);
                    }
                    note_use(&mut free, value);
                }
                GenericCoreAction::Change(_, _, _, args) => {
                    for term in args {
                        note_use(&mut free, term);
                    }
                }
                GenericCoreAction::Union(_, lhs, rhs) => {
                    note_use(&mut free, lhs);
                    note_use(&mut free, rhs);
                }
                GenericCoreAction::Panic(..) => {}
            }
        }
        free
    }
}
/// State threaded through the lowering of actions and expressions.
pub(crate) struct CoreActionContext<'a, Head, Leaf, FG> {
    /// Type information (used to look up globals and constructors).
    pub typeinfo: &'a TypeInfo,
    /// Variables currently in scope; lowering inserts the names it binds.
    pub binding: &'a mut IndexSet<Leaf>,
    /// Source of fresh variable names.
    pub fresh_gen: &'a mut FG,
    /// Whether `union` of a variable with a constructor call may be lowered as a `set`.
    pub union_to_set_optimization: bool,
    // Mentions `Head` without storing one.
    _marker: PhantomData<fn() -> Head>,
}

impl<'a, Head, Leaf, FG> CoreActionContext<'a, Head, Leaf, FG> {
    /// Bundles the lowering inputs into a context.
    pub fn new(
        typeinfo: &'a TypeInfo,
        binding: &'a mut IndexSet<Leaf>,
        fresh_gen: &'a mut FG,
        union_to_set_optimization: bool,
    ) -> Self {
        CoreActionContext {
            _marker: PhantomData,
            union_to_set_optimization,
            fresh_gen,
            binding,
            typeinfo,
        }
    }
}
pub(crate) trait GenericActionsExt<Head, Leaf> {
#[allow(clippy::type_complexity)]
fn to_core_actions<FG>(
&self,
ctx: &mut CoreActionContext<'_, Head, Leaf, FG>,
) -> Result<(GenericCoreActions<Head, Leaf>, MappedActions<Head, Leaf>), TypeError>
where
Head: Clone + Display + IsFunc,
Leaf: Clone + PartialEq + Eq + Display + Hash,
FG: FreshGen<Head, Leaf>;
}
impl<Head, Leaf> GenericActionsExt<Head, Leaf> for GenericActions<Head, Leaf>
where
Head: Clone + Display + IsFunc,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
#[allow(clippy::type_complexity)]
fn to_core_actions<FG>(
&self,
ctx: &mut CoreActionContext<'_, Head, Leaf, FG>,
) -> Result<(GenericCoreActions<Head, Leaf>, MappedActions<Head, Leaf>), TypeError>
where
Head: Clone + Display + IsFunc,
Leaf: Clone + PartialEq + Eq + Display + Hash,
FG: FreshGen<Head, Leaf>,
{
let mut norm_actions = vec![];
let mut mapped_actions: MappedActions<Head, Leaf> = GenericActions(vec![]);
let typeinfo = ctx.typeinfo;
let union_to_set_optimization = ctx.union_to_set_optimization;
for action in self.0.iter() {
match action {
GenericAction::Let(span, var, expr) => {
if ctx.binding.contains(var) {
return Err(TypeError::AlreadyDefined(var.to_string(), span.clone()));
}
let mapped_expr = expr.to_core_actions(ctx, &mut norm_actions)?;
norm_actions.push(GenericCoreAction::LetAtomTerm(
span.clone(),
var.clone(),
mapped_expr.get_corresponding_var_or_lit(typeinfo),
));
mapped_actions.0.push(GenericAction::Let(
span.clone(),
var.clone(),
mapped_expr,
));
ctx.binding.insert(var.clone());
}
GenericAction::Set(span, head, args, expr) => {
let mut mapped_args = vec![];
for arg in args {
let mapped_arg = arg.to_core_actions(ctx, &mut norm_actions)?;
mapped_args.push(mapped_arg);
}
let mapped_expr = expr.to_core_actions(ctx, &mut norm_actions)?;
norm_actions.push(GenericCoreAction::Set(
span.clone(),
head.clone(),
mapped_args
.iter()
.map(|e| e.get_corresponding_var_or_lit(typeinfo))
.collect(),
mapped_expr.get_corresponding_var_or_lit(typeinfo),
));
let v = ctx.fresh_gen.fresh(head);
mapped_actions.0.push(GenericAction::Set(
span.clone(),
CorrespondingVar::new(head.clone(), v),
mapped_args,
mapped_expr,
));
}
GenericAction::Change(span, change, head, args) => {
let mut mapped_args = vec![];
for arg in args {
let mapped_arg = arg.to_core_actions(ctx, &mut norm_actions)?;
mapped_args.push(mapped_arg);
}
norm_actions.push(GenericCoreAction::Change(
span.clone(),
*change,
head.clone(),
mapped_args
.iter()
.map(|e| e.get_corresponding_var_or_lit(typeinfo))
.collect(),
));
let v = ctx.fresh_gen.fresh(head);
mapped_actions.0.push(GenericAction::Change(
span.clone(),
*change,
CorrespondingVar::new(head.clone(), v),
mapped_args,
));
}
GenericAction::Union(span, e1, e2) => {
match (e1, e2) {
(var @ GenericExpr::Var(..), GenericExpr::Call(_, f, args))
| (GenericExpr::Call(_, f, args), var @ GenericExpr::Var(..))
if f.is_constructor(typeinfo) && union_to_set_optimization =>
{
let head = f;
let expr = var;
let mut mapped_args = vec![];
for arg in args {
let mapped_arg = arg.to_core_actions(ctx, &mut norm_actions)?;
mapped_args.push(mapped_arg);
}
let mapped_expr = expr.to_core_actions(ctx, &mut norm_actions)?;
norm_actions.push(GenericCoreAction::Set(
span.clone(),
head.clone(),
mapped_args
.iter()
.map(|e| e.get_corresponding_var_or_lit(typeinfo))
.collect(),
mapped_expr.get_corresponding_var_or_lit(typeinfo),
));
let v = ctx.fresh_gen.fresh(head);
mapped_actions.0.push(GenericAction::Set(
span.clone(),
CorrespondingVar::new(head.clone(), v),
mapped_args,
mapped_expr,
));
}
_ => {
let mapped_e1 = e1.to_core_actions(ctx, &mut norm_actions)?;
let mapped_e2 = e2.to_core_actions(ctx, &mut norm_actions)?;
norm_actions.push(GenericCoreAction::Union(
span.clone(),
mapped_e1.get_corresponding_var_or_lit(typeinfo),
mapped_e2.get_corresponding_var_or_lit(typeinfo),
));
mapped_actions.0.push(GenericAction::Union(
span.clone(),
mapped_e1,
mapped_e2,
));
}
};
}
GenericAction::Panic(span, string) => {
norm_actions.push(GenericCoreAction::Panic(span.clone(), string.clone()));
mapped_actions
.0
.push(GenericAction::Panic(span.clone(), string.clone()));
}
GenericAction::Expr(span, expr) => {
let mapped_expr = expr.to_core_actions(ctx, &mut norm_actions)?;
mapped_actions
.0
.push(GenericAction::Expr(span.clone(), mapped_expr));
}
}
}
Ok((GenericCoreActions::new(norm_actions), mapped_actions))
}
}
pub(crate) trait GenericExprExt<Head, Leaf>
where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
fn to_query(
&self,
typeinfo: &TypeInfo,
fresh_gen: &mut impl FreshGen<Head, Leaf>,
) -> (
Vec<GenericAtom<HeadOrEq<Head>, Leaf>>,
MappedExpr<Head, Leaf>,
);
fn to_core_actions<FG: FreshGen<Head, Leaf>>(
&self,
ctx: &mut CoreActionContext<'_, Head, Leaf, FG>,
out_actions: &mut Vec<GenericCoreAction<Head, Leaf>>,
) -> Result<MappedExpr<Head, Leaf>, TypeError>;
}
impl<Head, Leaf> GenericExprExt<Head, Leaf> for GenericExpr<Head, Leaf>
where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
fn to_query(
&self,
typeinfo: &TypeInfo,
fresh_gen: &mut impl FreshGen<Head, Leaf>,
) -> (
Vec<GenericAtom<HeadOrEq<Head>, Leaf>>,
MappedExpr<Head, Leaf>,
)
where
Head: Clone + Display,
Leaf: Clone + PartialEq + Eq + Display + Hash,
{
match self {
GenericExpr::Lit(span, lit) => (vec![], GenericExpr::Lit(span.clone(), lit.clone())),
GenericExpr::Var(span, v) => (vec![], GenericExpr::Var(span.clone(), v.clone())),
GenericExpr::Call(span, f, children) => {
let fresh = fresh_gen.fresh(f);
let mut new_children = vec![];
let mut atoms = vec![];
let mut child_exprs = vec![];
for child in children {
let (child_atoms, child_expr) = child.to_query(typeinfo, fresh_gen);
let child_atomterm = child_expr.get_corresponding_var_or_lit(typeinfo);
new_children.push(child_atomterm);
atoms.extend(child_atoms);
child_exprs.push(child_expr);
}
let args = {
new_children.push(GenericAtomTerm::Var(span.clone(), fresh.clone()));
new_children
};
atoms.push(GenericAtom {
span: span.clone(),
head: HeadOrEq::Head(f.clone()),
args,
});
(
atoms,
GenericExpr::Call(
span.clone(),
CorrespondingVar::new(f.clone(), fresh),
child_exprs,
),
)
}
}
}
fn to_core_actions<FG: FreshGen<Head, Leaf>>(
&self,
ctx: &mut CoreActionContext<'_, Head, Leaf, FG>,
out_actions: &mut Vec<GenericCoreAction<Head, Leaf>>,
) -> Result<MappedExpr<Head, Leaf>, TypeError> {
let typeinfo = ctx.typeinfo;
match self {
GenericExpr::Lit(span, lit) => Ok(GenericExpr::Lit(span.clone(), lit.clone())),
GenericExpr::Var(span, v) => {
let sym = v.to_string();
if ctx.binding.contains(v) || typeinfo.is_global(&sym) {
Ok(GenericExpr::Var(span.clone(), v.clone()))
} else {
Err(TypeError::Unbound(sym, span.clone()))
}
}
GenericExpr::Call(span, f, args) => {
let mut norm_args = vec![];
let mut mapped_args = vec![];
for arg in args {
let mapped_arg = arg.to_core_actions(ctx, out_actions)?;
norm_args.push(mapped_arg.get_corresponding_var_or_lit(typeinfo));
mapped_args.push(mapped_arg);
}
let var = ctx.fresh_gen.fresh(f);
ctx.binding.insert(var.clone());
out_actions.push(GenericCoreAction::Let(
span.clone(),
var.clone(),
f.clone(),
norm_args,
));
Ok(GenericExpr::Call(
span.clone(),
CorrespondingVar::new(f.clone(), var),
mapped_args,
))
}
}
}
}
/// A rule in core form: a query body plus a list of core actions.
///
/// `HeadQ` is the head type of body atoms and `HeadA` that of the actions.
#[derive(Debug, Clone)]
pub struct GenericCoreRule<HeadQ, HeadA, Leaf> {
    pub span: Span,
    pub body: Query<HeadQ, Leaf>,
    pub head: GenericCoreActions<HeadA, Leaf>,
}

pub(crate) type CoreRule = GenericCoreRule<StringOrEq, String, String>;
pub(crate) type ResolvedCoreRule = GenericCoreRule<ResolvedCall, ResolvedCall, ResolvedVar>;

impl<Head1, Head2, Leaf> GenericCoreRule<Head1, Head2, Leaf>
where
    Head1: Clone,
    Head2: Clone,
    Leaf: Clone + Eq + Hash,
{
    /// Applies `subst` to the rule.
    ///
    /// Body atoms have matching variables replaced in place. The actions get
    /// `LetAtomTerm` bindings prepended via `GenericCoreActions::subst`.
    pub fn subst(&mut self, subst: &HashMap<Leaf, GenericAtomTerm<Leaf>>) {
        self.body
            .atoms
            .iter_mut()
            .for_each(|atom| atom.subst(subst));
        self.head.subst(subst);
    }
}
impl<Head, Leaf> GenericCoreRule<HeadOrEq<Head>, Head, Leaf>
where
    Leaf: Eq + Clone + Hash + Debug,
    Head: Clone,
{
    /// Removes every `HeadOrEq::Eq` atom from the body, producing a rule whose
    /// body atoms all carry a plain `Head`.
    ///
    /// The method works in two phases:
    /// 1. While some `Eq` atom has two different arguments, at least one of
    ///    them a variable, substitute that variable with the other argument
    ///    across the whole rule (see `GenericCoreRule::subst`). This is repeated
    ///    to a fixpoint.
    /// 2. Rewrite the remaining `Eq` atoms:
    ///    - trivial `x = x` variable equalities are dropped;
    ///    - equalities between two identical non-variable terms are dropped;
    ///    - equalities between two different non-variable terms become
    ///      `value_eq(a, b)(a, b, ())`, using a `Literal::Unit` as output.
    ///
    /// # Panics
    /// - If an `Eq` atom has fewer than two arguments (indexing), or not exactly two (`assert_eq!`).
    /// - If a variable/non-variable equality survives phase 1, which would be an internal bug.
    pub(crate) fn canonicalize(
        self,
        value_eq: impl Fn(&GenericAtomTerm<Leaf>, &GenericAtomTerm<Leaf>) -> Head,
    ) -> GenericCoreRule<Head, Head, Leaf> {
        let mut result_rule = self;
        // Phase 1: substitute variables away until no non-trivial `Eq` atom
        // involving a variable remains.
        loop {
            let mut to_subst = None;
            for atom in result_rule.body.atoms.iter() {
                if atom.head.is_eq() && atom.args[0] != atom.args[1] {
                    // Prefer substituting the first argument when both are variables.
                    match &atom.args[..] {
                        [GenericAtomTerm::Var(_, x), y] | [y, GenericAtomTerm::Var(_, x)] => {
                            to_subst = Some((x, y));
                            break;
                        }
                        // Both sides are non-variables: left for phase 2.
                        _ => (),
                    }
                }
            }
            if let Some((x, y)) = to_subst {
                let subst = HashMap::from_iter([(x.clone(), y.clone())]);
                result_rule.subst(&subst);
            } else {
                break;
            }
        }
        // Phase 2: drop or convert the remaining `Eq` atoms; keep real atoms.
        let atoms = result_rule
            .body
            .atoms
            .into_iter()
            .filter_map(|atom| match atom.head {
                HeadOrEq::Eq => {
                    assert_eq!(atom.args.len(), 2);
                    match (&atom.args[0], &atom.args[1]) {
                        // After phase 1 only `x = x` can remain between variables.
                        (GenericAtomTerm::Var(_, v1), GenericAtomTerm::Var(_, v2)) => {
                            assert_eq!(v1, v2);
                            None
                        }
                        (GenericAtomTerm::Var(..), _) | (_, GenericAtomTerm::Var(..)) => {
                            panic!("equalities between variable and non-variable arguments should have been canonicalized")
                        }
                        // Literal/global equalities: drop when syntactically equal,
                        // otherwise check them at runtime via `value_eq`.
                        (at1, at2) => {
                            if at1 == at2 {
                                None
                            } else {
                                Some(GenericAtom {
                                    span: atom.span.clone(),
                                    head: value_eq(&atom.args[0], &atom.args[1]),
                                    args: vec![
                                        atom.args[0].clone(),
                                        atom.args[1].clone(),
                                        GenericAtomTerm::Literal(atom.span.clone(), Literal::Unit),
                                    ],
                                })
                            }
                        },
                    }
                }
                HeadOrEq::Head(symbol) => Some(GenericAtom {
                    span: atom.span.clone(),
                    head: symbol,
                    args: atom.args,
                }),
            })
            .collect();
        GenericCoreRule {
            span: result_rule.span,
            body: Query { atoms },
            head: result_rule.head,
        }
    }
}
/// Turns groups of terms known to be equal into `Eq` atoms.
///
/// For each group, the first term is equated with every later term that is not
/// already equal to it. Every produced atom uses `span`. Output follows the map's
/// iteration order, then group order.
///
/// Panics if a group is empty.
fn equiv_groups_to_eq_constraints<Head, Leaf>(
    groups: &HashMap<(Head, Vec<GenericAtomTerm<Leaf>>), Vec<GenericAtomTerm<Leaf>>>,
    span: &Span,
) -> Vec<GenericAtom<HeadOrEq<Head>, Leaf>>
where
    Leaf: Eq + Clone + Hash + Debug,
    Head: Clone,
{
    groups
        .values()
        .flat_map(move |group| {
            let representative = &group[0];
            group[1..]
                .iter()
                .filter(move |member| *member != representative)
                .map(move |member| GenericAtom {
                    span: span.clone(),
                    head: HeadOrEq::Eq,
                    args: vec![representative.clone(), member.clone()],
                })
        })
        .collect()
}
impl<Head, Leaf> GenericCoreRule<Head, Head, Leaf>
where
    Leaf: Eq + Clone + Hash + Debug,
    Head: Clone + Eq + Hash,
{
    /// Merges body atoms that share the same head and inputs (all arguments but the last).
    ///
    /// Within each such group the first atom is kept and the others are dropped.
    /// Their outputs are equated with the kept atom's output through `Eq`
    /// constraints, which `canonicalize` (with `value_eq`) then eliminates. The
    /// process repeats until no duplicates remain, and the rule is returned as
    /// is once that holds.
    ///
    /// Panics if a body atom has no arguments.
    pub(crate) fn remove_dup_vars(
        mut self,
        value_eq: impl Fn(&GenericAtomTerm<Leaf>, &GenericAtomTerm<Leaf>) -> Head,
    ) -> Self {
        let mut outputs_by_key: HashMap<
            (Head, Vec<GenericAtomTerm<Leaf>>),
            Vec<GenericAtomTerm<Leaf>>,
        > = HashMap::default();
        self.body.atoms.retain(|atom| {
            let (output, inputs) = atom.args.split_last().unwrap();
            let outputs = outputs_by_key
                .entry((atom.head.clone(), inputs.to_owned()))
                .or_default();
            outputs.push(output.clone());
            // Keep only the first atom seen for each (head, inputs) key.
            outputs.len() == 1
        });

        let equalities = equiv_groups_to_eq_constraints(&outputs_by_key, &self.span);
        if equalities.is_empty() {
            return self;
        }

        let GenericCoreRule { span, body, head } = self;
        let atoms: Vec<GenericAtom<HeadOrEq<Head>, Leaf>> = equalities
            .into_iter()
            .chain(body.atoms.into_iter().map(|atom| GenericAtom {
                span: atom.span,
                head: HeadOrEq::Head(atom.head),
                args: atom.args,
            }))
            .collect();
        GenericCoreRule {
            span,
            body: Query { atoms },
            head,
        }
        .canonicalize(&value_eq)
        .remove_dup_vars(value_eq)
    }
}
/// Conversion of a surface rule into a (not yet canonicalized) core rule.
pub(crate) trait GenericRuleExt<Head, Leaf> {
    /// Lowers the body to a query and the head to core actions.
    ///
    /// Head actions may use every variable of the body query.
    ///
    /// # Errors
    /// Propagates any [`TypeError`] from lowering the head actions.
    fn to_core_rule(
        &self,
        typeinfo: &TypeInfo,
        fresh_gen: &mut impl FreshGen<Head, Leaf>,
        union_to_set_optimization: bool,
    ) -> Result<GenericCoreRule<HeadOrEq<Head>, Head, Leaf>, TypeError>
    where
        Head: Clone + Display + IsFunc,
        Leaf: Clone + PartialEq + Eq + Display + Hash + Debug;
}

impl<Head, Leaf> GenericRuleExt<Head, Leaf> for GenericRule<Head, Leaf>
where
    Head: Clone + Display + IsFunc,
    Leaf: Clone + PartialEq + Eq + Display + Hash + Debug,
{
    fn to_core_rule(
        &self,
        typeinfo: &TypeInfo,
        fresh_gen: &mut impl FreshGen<Head, Leaf>,
        union_to_set_optimization: bool,
    ) -> Result<GenericCoreRule<HeadOrEq<Head>, Head, Leaf>, TypeError>
    where
        Head: Clone + Display + IsFunc,
        Leaf: Clone + PartialEq + Eq + Display + Hash + Debug,
    {
        // The body is lowered first, so its fresh names come before the head's.
        let (body, _) = Facts(self.body.clone()).to_query(typeinfo, fresh_gen);
        let mut bound_vars = body.get_vars();
        let (head, _) = {
            let mut ctx = CoreActionContext::new(
                typeinfo,
                &mut bound_vars,
                fresh_gen,
                union_to_set_optimization,
            );
            self.head.to_core_actions(&mut ctx)?
        };
        Ok(GenericCoreRule {
            span: self.span.clone(),
            body,
            head,
        })
    }
}
/// Full lowering pipeline for resolved rules.
pub(crate) trait ResolvedRuleExt {
    /// Lowers the rule to core form, checks groundedness, canonicalizes equalities,
    /// and merges duplicate atoms.
    ///
    /// # Errors
    /// Propagates errors from lowering and from `grounded_check`.
    ///
    /// # Panics
    /// Panics if `typeinfo` has no `value-eq` primitive.
    fn to_canonicalized_core_rule(
        &self,
        typeinfo: &TypeInfo,
        fresh_gen: &mut SymbolGen,
        union_to_set_optimization: bool,
    ) -> Result<ResolvedCoreRule, TypeError>;
}

impl ResolvedRuleExt for ResolvedRule {
    fn to_canonicalized_core_rule(
        &self,
        typeinfo: &TypeInfo,
        fresh_gen: &mut SymbolGen,
        union_to_set_optimization: bool,
    ) -> Result<ResolvedCoreRule, TypeError> {
        // Literal/global equalities left after canonicalization are checked at
        // runtime via `value-eq`. Its absence is an invariant violation, so state
        // it instead of a bare `unwrap()[0]`.
        let value_eq_prim = typeinfo
            .get_prims("value-eq")
            .and_then(|prims| prims.first())
            .expect("the `value-eq` primitive is missing from TypeInfo");
        // Specializes `value-eq` to the sorts of the two compared terms, with a unit output.
        let value_eq = |lhs: &ResolvedAtomTerm, rhs: &ResolvedAtomTerm| {
            ResolvedCall::Primitive(SpecializedPrimitive {
                primitive: value_eq_prim.clone(),
                input: vec![lhs.output(), rhs.output()],
                output: UnitSort.to_arcsort(),
            })
        };
        let rule = self.to_core_rule(typeinfo, fresh_gen, union_to_set_optimization)?;
        grounded_check(&rule)?;
        let rule = rule.canonicalize(&value_eq);
        let rule = rule.remove_dup_vars(value_eq);
        Ok(rule)
    }
}
// Unit tests for `remove_dup_vars` over plain `String` heads and variables.
#[cfg(test)]
mod tests {
    use super::*;

    // Core rule with string body heads, string action heads, and string variables.
    type TestCoreRule = GenericCoreRule<String, String, String>;

    // Builds a variable atom term with a placeholder span.
    fn make_var(name: &str) -> GenericAtomTerm<String> {
        GenericAtomTerm::Var(span!(), name.to_string())
    }

    // Builds `head(args...)` where every argument is a variable.
    fn make_atom(head: &str, args: Vec<&str>) -> GenericAtom<String, String> {
        GenericAtom {
            span: span!(),
            head: head.to_string(),
            args: args.into_iter().map(make_var).collect(),
        }
    }

    // Stand-in for the `value-eq` head builder; these tests never produce
    // literal/global equalities, so its result is not inspected.
    fn value_eq_string(_at1: &GenericAtomTerm<String>, _at2: &GenericAtomTerm<String>) -> String {
        "value-eq".to_string()
    }

    // `R(x, y, z1)` and `R(x, y, z2)` share head and inputs, so one is removed.
    // Atoms with different inputs or arity are kept.
    #[test]
    fn test_remove_dup_vars_basic() {
        let rule = TestCoreRule {
            span: span!(),
            body: Query {
                atoms: vec![
                    make_atom("R", vec!["x", "y", "z1"]),
                    make_atom("R", vec!["x", "y", "z2"]),
                    make_atom("R", vec!["x", "z3"]),
                    make_atom("R", vec!["a", "b", "z4"]),
                    make_atom("R", vec!["c", "d", "z5"]),
                ],
            },
            head: GenericCoreActions::default(),
        };
        let result = rule.remove_dup_vars(value_eq_string);
        assert_eq!(result.body.atoms.len(), 4);
        assert_eq!(result.body.atoms[0].head, "R");
        assert_eq!(result.body.atoms[0].args[0], make_var("x"));
        assert_eq!(result.body.atoms[0].args[1], make_var("y"));
        assert_eq!(result.body.atoms[1].args.len(), 2);
    }

    // Merging outputs can make other atoms duplicates, so the process must
    // iterate to a fixpoint.
    #[test]
    fn test_remove_dup_vars_fixpoint() {
        // z1 = z2 = z3 from the `R` atoms. The two `S` atoms then share inputs
        // `(z, z)` and merge as well, leaving one `R` and one `S` atom.
        let rule = TestCoreRule {
            span: span!(),
            body: Query {
                atoms: vec![
                    make_atom("R", vec!["x", "y", "z1"]),
                    make_atom("R", vec!["x", "y", "z2"]),
                    make_atom("R", vec!["x", "y", "z3"]),
                    make_atom("S", vec!["z1", "z2", "z3"]),
                    make_atom("S", vec!["z2", "z1", "z4"]),
                ],
            },
            head: GenericCoreActions::default(),
        };
        let result = rule.remove_dup_vars(value_eq_string);
        assert_eq!(result.body.atoms.len(), 2);
        assert_eq!(result.body.atoms[0].head, "R");
        assert_eq!(result.body.atoms[0].args[0], make_var("x"));
        assert_eq!(result.body.atoms[0].args[1], make_var("y"));
        assert_eq!(result.body.atoms[1].args[0], result.body.atoms[1].args[1]);
        // The merging chain collapses x, y and all z's into a single variable,
        // leaving one `R(v, v, v)` atom.
        let rule = TestCoreRule {
            span: span!(),
            body: Query {
                atoms: vec![
                    make_atom("R", vec!["x", "y", "z1"]),
                    make_atom("R", vec!["x", "y", "z2"]),
                    make_atom("R", vec!["x", "y", "z3"]),
                    make_atom("R", vec!["z1", "z2", "z1"]),
                    make_atom("R", vec!["z2", "z1", "x"]),
                    make_atom("R", vec!["z2", "z2", "y"]),
                ],
            },
            head: GenericCoreActions::default(),
        };
        let result = rule.remove_dup_vars(value_eq_string);
        assert_eq!(result.body.atoms.len(), 1);
        assert_eq!(result.body.atoms[0].head, "R");
        assert_eq!(result.body.atoms[0].args[0], result.body.atoms[0].args[1]);
        assert_eq!(result.body.atoms[0].args[0], result.body.atoms[0].args[2]);
    }

    // If an action uses a variable that was substituted away, a `LetAtomTerm`
    // binding for it is prepended to the head actions.
    #[test]
    fn test_remove_dup_vars_with_actions_using_removed_var() {
        let rule = TestCoreRule {
            span: span!(),
            body: Query {
                atoms: vec![
                    make_atom("R", vec!["x", "y", "z1"]),
                    make_atom("R", vec!["x", "y", "z2"]),
                ],
            },
            head: GenericCoreActions(vec![GenericCoreAction::Union(
                span!(),
                make_var("z2"),
                make_var("z1"),
            )]),
        };
        let result = rule.remove_dup_vars(value_eq_string);
        assert_eq!(result.body.atoms.len(), 1);
        assert!(matches!(
            &result.head.0.as_slice(),
            &[
                GenericCoreAction::LetAtomTerm(_, _, _),
                GenericCoreAction::Union(_, _, _)
            ]
        ));
    }
}