use std::borrow::Cow;
use egg::*;
pub use egg_recursive_derive::*;
pub use functo_rs::data::Functor;
use functo_rs::data::{ArrayFunctor, Identity, UndetVec};
/// One layer of the recursive type `L` whose direct children have type `T`.
pub type Signature<L, T> = <L as Functor>::Container<T>;
/// The egg node type of `L`: one layer whose children are e-class `Id`s.
pub type AsLanguage<L> = Signature<L, Id>;
/// A recursive term type that can be peeled apart into, and rebuilt from, a
/// single layer of its signature functor.
///
/// `Signature<Self, T>` is one layer of the term whose direct children have
/// type `T`. Instantiating `T = Id` gives `AsLanguage<Self>`, which is required
/// to be an egg `Language` so that terms can be stored in a `RecExpr`.
/// (Implementations are presumably generated by the derive macros re-exported
/// from `egg_recursive_derive` — confirm there.)
pub trait Recursive: Functor + Sized
where
    AsLanguage<Self>: Language,
{
    /// Exposes the outermost layer of `self`, with its direct subterms as children.
    fn unwrap(self) -> Signature<Self, Self>;

    /// Rebuilds a term from a single layer whose children are already terms.
    fn wrap(inner: Signature<Self, Self>) -> Self;

    /// Returns an owned copy of a signature layer whose children are `Clone`.
    fn sclone<T: Clone>(sig: &Signature<Self, T>) -> Signature<Self, T>;

    /// Turns a borrowed layer into a layer of borrowed children.
    fn sig_each_ref<T>(refs: &Signature<Self, T>) -> Signature<Self, &T>;

    /// Appends `self` to `expr` and returns the `Id` of the node representing `self`.
    ///
    /// Subterms are added before the node that refers to them.
    fn add_into_rec_expr(self, expr: &mut RecExpr<Signature<Self, Id>>) -> Id {
        // Recurse first so every child `Id` already exists when the parent is added.
        let node = Self::fmap(|child| child.add_into_rec_expr(expr), self.unwrap());
        expr.add(node)
    }

    /// Converts `self` into a fresh `RecExpr`, returned together with the root's `Id`.
    fn into_rec_expr(self) -> (RecExpr<Signature<Self, Id>>, Id) {
        let mut expr = RecExpr::default();
        let id = self.add_into_rec_expr(&mut expr);
        (expr, id)
    }

    /// Rebuilds the term rooted at `id` inside `expr`.
    ///
    /// Every child `Id` is rebuilt independently, so a node referenced from
    /// several parents becomes several separate copies in the resulting tree.
    fn from_rec_expr(expr: &RecExpr<Signature<Self, Id>>, id: Id) -> Self {
        Self::wrap(Self::fmap(
            |child| Self::from_rec_expr(expr, child),
            expr[id].clone(),
        ))
    }
}
/// A pattern over the recursive term type `L`.
///
/// A `Pat` is either a pattern variable or one layer of `L`'s signature whose
/// children are boxed sub-patterns. It converts to and from egg's `Pattern` /
/// `PatternAst`, and its egg `Searcher` / `Applier` impls delegate to `Pattern`.
pub enum Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
{
/// A pattern variable such as `?x`.
PatVar(Var),
/// One layer of `L`'s signature whose children are sub-patterns.
Wrap(Signature<L, Box<Pat<L>>>),
}
impl<L> Pat<L>
where
    L: Recursive,
    AsLanguage<L>: Language,
{
    /// Appends each pattern variable of `self` to `vars`, skipping variables
    /// that are already present. Variables are visited in the order in which
    /// `L::fmap` traverses the children.
    fn add_vars_to(&self, vars: &mut Vec<Var>) {
        match self {
            Self::PatVar(var) if vars.contains(var) => {}
            Self::PatVar(var) => vars.push(*var),
            Self::Wrap(children) => {
                // Only the side effect on `vars` matters; the mapped layer is discarded.
                let _ = L::fmap(|child| child.add_vars_to(vars), L::sig_each_ref(children));
            }
        }
    }
}
impl<L, N> egg::Searcher<AsLanguage<L>, N> for Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
N: Analysis<AsLanguage<L>>,
{
fn search_eclass_with_limit(
&self,
egraph: &EGraph<AsLanguage<L>, N>,
eclass: Id,
limit: usize,
) -> Option<::egg::SearchMatches<'_, AsLanguage<L>>> {
use ::egg::*;
let pat: Pattern<AsLanguage<L>> = Pattern::from(self);
let SearchMatches {
eclass,
substs,
ast,
} = pat.search_eclass_with_limit(egraph, eclass, limit)?;
Some(SearchMatches {
eclass,
substs,
ast: ast.map(|cow| Cow::Owned(cow.into_owned())),
})
}
fn vars(&self) -> Vec<Var> {
let mut vars = Vec::new();
self.add_vars_to(&mut vars);
vars
}
}
impl<L, N> egg::Applier<AsLanguage<L>, N> for Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
N: Analysis<AsLanguage<L>>,
{
fn apply_one(
&self,
egraph: &mut EGraph<AsLanguage<L>, N>,
eclass: Id,
subst: &Subst,
searcher_ast: Option<&PatternAst<AsLanguage<L>>>,
rule_name: Symbol,
) -> Vec<Id> {
Pattern::from(self).apply_one(egraph, eclass, subst, searcher_ast, rule_name)
}
}
impl<L> Clone for Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
{
fn clone(&self) -> Self {
match self {
Self::PatVar(arg0) => Self::PatVar(*arg0),
Self::Wrap(arg0) => Self::Wrap(L::sclone(arg0)),
}
}
}
impl<L> Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
{
pub fn pat_var<'a, S: Into<Cow<'a, str>>>(v: S) -> Self {
Self::PatVar(format!("?{}", v.into()).parse().unwrap())
}
pub fn unwrap(self) -> ENodeOrVar<Signature<L, Self>> {
match self {
Pat::PatVar(v) => ENodeOrVar::Var(v),
Pat::Wrap(w) => ENodeOrVar::ENode(L::fmap(|e| *e, w)),
}
}
pub fn wrap(inner: Signature<L, Self>) -> Self {
Pat::Wrap(L::fmap(Box::new, inner))
}
}
impl<L> From<Var> for Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
{
fn from(value: Var) -> Self {
Self::PatVar(value)
}
}
impl<L> From<L> for Pat<L>
where
L: Recursive,
AsLanguage<L>: Language,
{
fn from(value: L) -> Self {
Self::Wrap(L::fmap(|e| Box::new(e.into()), value.unwrap()))
}
}
impl<L> Pat<L>
where
    L: Recursive,
    AsLanguage<L>: Language,
{
    /// Flattens this pattern into a fresh `PatternAst`, consuming it.
    ///
    /// The pattern's root is the last node added.
    pub fn into_pattern_ast(self) -> PatternAst<AsLanguage<L>> {
        let mut out = PatternAst::default();
        let _root = self.add_into_pattern_ast(&mut out);
        out
    }

    /// Like [`Pat::into_pattern_ast`], but borrows `self`.
    pub fn to_pattern_ast(&self) -> PatternAst<AsLanguage<L>> {
        let mut out = PatternAst::default();
        let _root = self.add_to_pattern_ast(&mut out);
        out
    }

    /// Rebuilds the pattern rooted at `id` inside `ast`.
    ///
    /// Every child is rebuilt independently, so nodes shared in `ast` become
    /// separate copies.
    pub fn from_pattern_ast(ast: &PatternAst<AsLanguage<L>>, id: Id) -> Self {
        match &ast[id] {
            ENodeOrVar::Var(var) => Pat::PatVar(*var),
            ENodeOrVar::ENode(node) => {
                let rebuild = |child: Id| Box::new(Self::from_pattern_ast(ast, child));
                Pat::Wrap(L::fmap(rebuild, node.clone()))
            }
        }
    }

    /// Appends this pattern to `ast` and returns the `Id` of its root.
    ///
    /// Children are added before their parent.
    pub fn add_into_pattern_ast(self, ast: &mut PatternAst<AsLanguage<L>>) -> Id {
        let node = match self {
            Pat::PatVar(var) => ENodeOrVar::Var(var),
            Pat::Wrap(children) => {
                ENodeOrVar::ENode(L::fmap(|child| child.add_into_pattern_ast(ast), children))
            }
        };
        ast.add(node)
    }

    /// Like [`Pat::add_into_pattern_ast`], but borrows `self`.
    pub fn add_to_pattern_ast(&self, ast: &mut PatternAst<AsLanguage<L>>) -> Id {
        let node = match self {
            Pat::PatVar(var) => ENodeOrVar::Var(*var),
            Pat::Wrap(children) => ENodeOrVar::ENode(L::fmap(
                |child| child.add_to_pattern_ast(ast),
                L::sig_each_ref(children),
            )),
        };
        ast.add(node)
    }
}
impl<L> From<Pat<L>> for PatternAst<AsLanguage<L>>
where
L: Recursive,
AsLanguage<L>: Language,
{
fn from(value: Pat<L>) -> Self {
value.into_pattern_ast()
}
}
impl<L> From<&Pat<L>> for PatternAst<AsLanguage<L>>
where
L: Recursive,
AsLanguage<L>: Language,
{
fn from(value: &Pat<L>) -> Self {
value.to_pattern_ast()
}
}
impl<L: Recursive> From<&Pat<L>> for Pattern<AsLanguage<L>>
where
    AsLanguage<L>: Language,
{
    /// Builds an egg `Pattern` from a borrowed `Pat`.
    fn from(pat: &Pat<L>) -> Self {
        let ast = pat.to_pattern_ast();
        Self::new(ast)
    }
}
impl<L: Recursive> From<PatternAst<AsLanguage<L>>> for Pat<L>
where
    AsLanguage<L>: Language,
{
    /// Rebuilds a `Pat` from the root node of `ast`.
    fn from(ast: PatternAst<AsLanguage<L>>) -> Self {
        Self::from_pattern_ast(&ast, ast.root())
    }
}
impl<L: Recursive> From<Pat<L>> for Pattern<AsLanguage<L>>
where
    AsLanguage<L>: Language,
{
    /// Builds an egg `Pattern` from an owned `Pat`.
    fn from(pat: Pat<L>) -> Self {
        let ast = pat.into_pattern_ast();
        Self::new(ast)
    }
}
impl<L: Recursive> From<Pattern<AsLanguage<L>>> for Pat<L>
where
    AsLanguage<L>: Language,
{
    /// Rebuilds a `Pat` from the AST stored in an egg `Pattern`.
    fn from(pattern: Pattern<AsLanguage<L>>) -> Self {
        Self::from(pattern.ast)
    }
}
/// The raw child container of `T` holding e-class `Id`s.
pub type LangChildren<T> = RawData<T, Id>;
/// The raw child container of `T`, with elements of type `U`.
pub type RawData<T, U> = <<T as IntoLanguageChildren>::RawData as Functor>::Container<U>;
/// The view child container of `T`, with elements of type `U`.
pub type View<T, U> = <<T as IntoLanguageChildren>::View as Functor>::Container<U>;
/// Connects a child container used in a language definition (`Vec<T>`,
/// `[T; N]` or `Id` in this module) with the `Functor` machinery.
///
/// - `Param` is the element type held by `Self`.
/// - `View` is a functor whose `Container<Param>` is `Self` (required by the
///   where clause below).
/// - `RawData` is a functor whose `Container<Id>` must implement egg's
///   `LanguageChildren`.
pub trait IntoLanguageChildren: Sized
where
Self::View: Functor<Container<Self::Param> = Self>,
Self::RawData: Functor,
<Self::RawData as Functor>::Container<Id>: LanguageChildren,
{
/// Element type stored in `Self`.
type Param;
/// Functor describing `Self` as a container of `Param`.
type View;
/// Functor describing the raw (egg-facing) child representation.
type RawData;
/// Converts raw children into the view representation.
/// (Every impl in this module is the identity.)
fn view<T>(children: RawData<Self, T>) -> View<Self, T>;
/// Converts view children back into the raw representation.
/// Presumably the inverse of `view`; every impl in this module is the identity.
fn unview<T>(children: View<Self, T>) -> RawData<Self, T>;
/// Maps `f` over every element of `self`, keeping the view shape.
fn map<U>(self, f: impl FnMut(Self::Param) -> U) -> View<Self, U> {
<Self::View as Functor>::fmap(f, self)
}
/// Borrows each element of raw children in place.
fn raw_as_refs<T>(refs: &RawData<Self, T>) -> RawData<Self, &T>;
}
impl<T> IntoLanguageChildren for Vec<T> {
    type Param = T;
    type View = UndetVec;
    type RawData = UndetVec;

    /// A `Vec` is already its own view: returns `children` unchanged.
    #[inline(always)]
    fn view<U>(children: Vec<U>) -> Vec<U> {
        children
    }

    /// Inverse of `view`; also the identity.
    #[inline(always)]
    fn unview<U>(children: Vec<U>) -> Vec<U> {
        children
    }

    /// Borrows every element, preserving order.
    #[inline(always)]
    fn raw_as_refs<U>(refs: &Vec<U>) -> Vec<&U> {
        let mut borrowed = Vec::with_capacity(refs.len());
        borrowed.extend(refs.iter());
        borrowed
    }
}
impl<T, const N: usize> IntoLanguageChildren for [T; N] {
    type Param = T;
    type View = ArrayFunctor<N>;
    type RawData = ArrayFunctor<N>;

    /// A fixed-size array is already its own view: returns `children` unchanged.
    #[inline(always)]
    fn view<U>(children: [U; N]) -> [U; N] {
        children
    }

    /// Inverse of `view`; also the identity.
    #[inline(always)]
    fn unview<U>(children: [U; N]) -> [U; N] {
        children
    }

    /// Borrows every element in place, producing an array of references.
    #[inline(always)]
    fn raw_as_refs<U>(refs: &[U; N]) -> [&U; N] {
        <[U; N]>::each_ref(refs)
    }
}
impl IntoLanguageChildren for Id {
    type Param = Id;
    type View = Identity;
    type RawData = Identity;

    /// A single child is its own view: returns `children` unchanged.
    #[inline(always)]
    fn view<U>(children: U) -> U {
        children
    }

    /// Inverse of `view`; also the identity.
    #[inline(always)]
    fn unview<U>(children: U) -> U {
        children
    }

    /// With the identity functor, "borrowing each element" is just the reference itself.
    #[inline(always)]
    fn raw_as_refs<U>(refs: &U) -> &U {
        refs
    }
}
/// Builds an `egg::Rewrite` from a rule name and two values that convert into
/// `egg::Pattern` (for example `Pat<L>` values).
///
/// ```ignore
/// rewrite!("name"; lhs => rhs)
/// rewrite!("name"; lhs => rhs; if cond1 if cond2)
/// ```
///
/// Both sides are converted with `::egg::Pattern::from`, so the rule's
/// searcher and applier are egg `Pattern`s rather than `Pat`s.
///
/// # Panics
///
/// Both forms `.unwrap()` the result of `egg::Rewrite::new`, so they panic
/// when egg rejects the rule.
///
/// NOTE(review): the expansion names `::egg` directly, so the calling crate must
/// depend on `egg` itself.
#[macro_export]
macro_rules! rewrite {
// Unconditional form: `rewrite!(name; lhs => rhs)`.
(
$name:expr;
$lhs:expr => $rhs:expr
) => {{
let searcher = ::egg::Pattern::from($lhs);
let applier = ::egg::Pattern::from($rhs);
::egg::Rewrite::new($name.to_string(), searcher, applier).unwrap()
}};
// Conditional form: `rewrite!(name; lhs => rhs; if c1 if c2 ...)`.
// The conditions are layered onto the applier by egg's hidden `__rewrite!`
// helper (presumably as `ConditionalApplier`s — see egg's `rewrite!` source).
(
$name:expr;
$lhs:expr => $rhs:expr;
$(if $cond:expr)*
) => {{
let searcher = ::egg::Pattern::from($lhs);
let core_applier = ::egg::Pattern::from($rhs);
let applier = ::egg::__rewrite!(@applier core_applier; $($cond,)*);
::egg::Rewrite::new($name.to_string(), searcher, applier).unwrap()
}};
}
#[cfg(test)]
mod tests {
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Collects every `*.rs` file directly inside `dir`, sorted by file name so
    /// that trybuild processes the cases in a deterministic order.
    ///
    /// Panics if `dir` or any of its entries cannot be read. This way a broken
    /// test directory fails loudly instead of silently dropping cases.
    fn rust_sources(dir: &Path) -> Vec<PathBuf> {
        let entries = fs::read_dir(dir)
            .unwrap_or_else(|err| panic!("cannot read test directory {}: {}", dir.display(), err));
        let mut files: Vec<PathBuf> = entries
            .map(|entry| {
                entry
                    .unwrap_or_else(|err| {
                        panic!("cannot read entry in {}: {}", dir.display(), err)
                    })
                    .path()
            })
            .filter(|path| path.extension().is_some_and(|ext| ext == "rs"))
            .collect();
        // All paths share the same parent, so this is an ordering by file name.
        files.sort();
        files
    }

    /// Every derivation under `tests/invalid/language` must fail to compile.
    #[test]
    fn test_invalid_language_derivations() {
        let t = trybuild::TestCases::new();
        for path in rust_sources(&Path::new("tests").join("invalid").join("language")) {
            t.compile_fail(path);
        }
    }

    /// Every case under `tests/success` must compile and run successfully.
    #[test]
    fn test_success() {
        let t = trybuild::TestCases::new();
        for path in rust_sources(&Path::new("tests").join("success")) {
            t.pass(path);
        }
    }
}