use std::sync::Arc;
/// A term tree: either a named variable or an operator applied to argument terms.
///
/// Names are stored as `Arc<str>`, so cloning a name does not copy the string data.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Term {
    /// A variable, identified by its name.
    Var(Arc<str>),
    /// An operator applied to an ordered list of argument terms.
    App {
        /// Name of the operator being applied.
        op: Arc<str>,
        /// Argument terms, in order; may be empty.
        args: Vec<Term>,
    },
}
impl Term {
    /// Builds a variable term with the given name.
    #[must_use]
    pub fn var(name: impl Into<Arc<str>>) -> Self {
        Self::Var(name.into())
    }

    /// Builds an application of operator `op` to the argument terms `args`.
    #[must_use]
    pub fn app(op: impl Into<Arc<str>>, args: Vec<Self>) -> Self {
        let op = op.into();
        Self::App { op, args }
    }

    /// Builds an application of `op` with an empty argument list.
    #[must_use]
    pub fn constant(op: impl Into<Arc<str>>) -> Self {
        Self::app(op, Vec::new())
    }

    /// Returns a copy of this term with variables replaced according to `subst`.
    ///
    /// A variable whose name is a key of `subst` becomes a clone of the mapped
    /// term; any other variable is kept as-is. Operator names are never changed.
    /// Replacement terms are inserted verbatim: they are not substituted into again.
    #[must_use]
    pub fn substitute(&self, subst: &rustc_hash::FxHashMap<Arc<str>, Self>) -> Self {
        match self {
            Self::Var(name) => match subst.get(name) {
                Some(replacement) => replacement.clone(),
                None => self.clone(),
            },
            Self::App { op, args } => {
                let rewritten = args.iter().map(|arg| arg.substitute(subst)).collect();
                Self::app(Arc::clone(op), rewritten)
            }
        }
    }

    /// Returns the set of names of every variable occurring anywhere in this term.
    #[must_use]
    pub fn free_vars(&self) -> rustc_hash::FxHashSet<Arc<str>> {
        let mut found = rustc_hash::FxHashSet::default();
        self.collect_vars(&mut found);
        found
    }

    /// Recursively inserts the name of each variable in this term into `vars`.
    fn collect_vars(&self, vars: &mut rustc_hash::FxHashSet<Arc<str>>) {
        match self {
            Self::Var(name) => {
                vars.insert(Arc::clone(name));
            }
            Self::App { args, .. } => args.iter().for_each(|child| child.collect_vars(vars)),
        }
    }

    /// Returns a copy of this term with operator names rewritten through `op_map`.
    ///
    /// An operator whose name is a key of `op_map` takes the mapped name; other
    /// operators keep theirs. The mapped name is not looked up again, and
    /// variables are left untouched.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        match self {
            Self::App { op, args } => {
                // Fall back to the current name when the map has no entry for it.
                let renamed_op = op_map.get(op).unwrap_or(op);
                Self::App {
                    op: Arc::clone(renamed_op),
                    args: args.iter().map(|arg| arg.rename_ops(op_map)).collect(),
                }
            }
            Self::Var(_) => self.clone(),
        }
    }
}
/// A named equation made of a left-hand side term and a right-hand side term.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct Equation {
    /// Name identifying this equation.
    pub name: Arc<str>,
    /// Left-hand side term.
    pub lhs: Term,
    /// Right-hand side term.
    pub rhs: Term,
}
impl Equation {
    /// Builds an equation called `name` with the given left- and right-hand sides.
    #[must_use]
    pub fn new(name: impl Into<Arc<str>>, lhs: Term, rhs: Term) -> Self {
        let name = name.into();
        Self { name, lhs, rhs }
    }

    /// Returns a copy of this equation with operator names on both sides
    /// rewritten through `op_map`; the equation's own name is kept.
    #[must_use]
    pub fn rename_ops(&self, op_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let lhs = self.lhs.rename_ops(op_map);
        let rhs = self.rhs.rename_ops(op_map);
        Self::new(Arc::clone(&self.name), lhs, rhs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Substituting `x := y` rewrites the variable and leaves the constant alone.
    #[test]
    fn term_substitution() {
        let mut subst = rustc_hash::FxHashMap::default();
        subst.insert(Arc::from("x"), Term::var("y"));

        let input = Term::app("add", vec![Term::var("x"), Term::constant("zero")]);
        let expected = Term::app("add", vec![Term::var("y"), Term::constant("zero")]);

        assert_eq!(input.substitute(&subst), expected);
    }

    /// Both variables of `f(x, y)` are reported, and nothing else.
    #[test]
    fn free_variables() {
        let vars = Term::app("f", vec![Term::var("x"), Term::var("y")]).free_vars();

        for name in ["x", "y"] {
            assert!(vars.contains(name));
        }
        assert_eq!(vars.len(), 2);
    }
}