use crate::{Expr, Name};
use std::collections::HashMap;
use std::fmt;
/// A pattern appearing in a unification hint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HintPattern {
    /// A named pattern variable; displayed as `?name`.
    Var(Name),
    /// A concrete expression; displayed via its `Debug` form.
    Expr(Expr),
}
impl HintPattern {
    /// Builds a [`HintPattern::Var`] from anything convertible into a [`Name`].
    pub fn var(name: impl Into<Name>) -> Self {
        let name: Name = name.into();
        Self::Var(name)
    }
    /// Wraps `e` as a [`HintPattern::Expr`].
    pub fn expr(e: Expr) -> Self {
        Self::Expr(e)
    }
    /// Returns `true` if this pattern is a pattern variable.
    pub fn is_var(&self) -> bool {
        match self {
            Self::Var(_) => true,
            Self::Expr(_) => false,
        }
    }
}
impl fmt::Display for HintPattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Var(var_name) => write!(f, "?{}", var_name),
            Self::Expr(expr) => write!(f, "{:?}", expr),
        }
    }
}
/// A unification hint relating the pattern `lhs` to the pattern `rhs`.
///
/// Constants whose names start with `?` inside `lhs`/`rhs` are treated as pattern variables by
/// [`match_expr_pattern`]. Each hypothesis is a named pair of expressions, rendered by
/// `Display` as `name: lhs ≡ rhs`.
#[derive(Debug, Clone)]
pub struct UnifHint {
    /// Optional label; shown as a `[name] ` prefix by `Display` and used by
    /// [`UnifHintDB::remove_named`].
    pub name: Option<Name>,
    /// Left-hand pattern.
    pub lhs: Expr,
    /// Right-hand pattern.
    pub rhs: Expr,
    /// Named side conditions, each a `(left, right)` expression pair.
    pub hypotheses: Vec<(Name, (Expr, Expr))>,
    /// Ordering key: [`UnifHintDB::add_hint`] keeps higher priorities earlier.
    pub priority: i32,
}
impl UnifHint {
    /// Creates an unnamed hint with no hypotheses and priority 0.
    pub fn new(lhs: Expr, rhs: Expr) -> Self {
        Self::with_hypotheses(lhs, rhs, Vec::new())
    }
    /// Creates an unnamed hint with the given hypotheses and priority 0.
    pub fn with_hypotheses(lhs: Expr, rhs: Expr, hypotheses: Vec<(Name, (Expr, Expr))>) -> Self {
        Self {
            name: None,
            lhs,
            rhs,
            hypotheses,
            priority: 0,
        }
    }
    /// Builder: sets the hint's label.
    pub fn named(self, name: impl Into<Name>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
    /// Builder: sets the hint's priority.
    pub fn with_priority(self, p: i32) -> Self {
        Self { priority: p, ..self }
    }
    /// Number of hypotheses attached to this hint.
    pub fn hypothesis_count(&self) -> usize {
        self.hypotheses.len()
    }
    /// Returns `true` when the hint has no hypotheses.
    pub fn is_unconditional(&self) -> bool {
        self.hypothesis_count() == 0
    }
}
impl fmt::Display for UnifHint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(label) = &self.name {
            write!(f, "[{}] ", label)?;
        }
        write!(f, "{:?} ≡ {:?}", self.lhs, self.rhs)?;
        // The first hypothesis is introduced by " where ", later ones by ", ".
        let mut separator = " where ";
        for (hyp_name, (hyp_lhs, hyp_rhs)) in &self.hypotheses {
            f.write_str(separator)?;
            write!(f, "{}: {:?} ≡ {:?}", hyp_name, hyp_lhs, hyp_rhs)?;
            separator = ", ";
        }
        Ok(())
    }
}
/// Outcome of a hint match: either the resulting substitution or no match.
#[derive(Clone, Debug)]
pub enum HintMatchResult {
    /// The match succeeded with this substitution.
    Matched(PatternSubst),
    /// The match failed.
    NoMatch,
}
impl HintMatchResult {
    /// Returns `true` for [`HintMatchResult::Matched`].
    pub fn is_match(&self) -> bool {
        match self {
            Self::Matched(_) => true,
            Self::NoMatch => false,
        }
    }
    /// Test-only helper that extracts the substitution.
    ///
    /// # Panics
    /// Panics if `self` is [`HintMatchResult::NoMatch`].
    #[cfg(test)]
    pub fn unwrap_subst(self) -> PatternSubst {
        if let Self::Matched(subst) = self {
            subst
        } else {
            panic!("HintMatchResult::unwrap_subst called on NoMatch")
        }
    }
}
/// A substitution from pattern-variable names to expressions.
///
/// Keys are the `to_string()` rendering of the [`Name`] passed to [`PatternSubst::bind`] /
/// [`PatternSubst::get`].
#[derive(Clone, Debug, Default)]
pub struct PatternSubst {
    bindings: HashMap<String, Expr>,
}
impl PatternSubst {
    /// Creates an empty substitution.
    pub fn new() -> Self {
        Self {
            bindings: HashMap::new(),
        }
    }
    /// Binds `name` to `expr`, enforcing consistency with any earlier binding.
    ///
    /// Returns `true` when `name` was unbound (the binding is recorded) or is already bound to
    /// an expression equal to `expr`. Returns `false`, leaving the existing binding in place,
    /// when `name` is already bound to a different expression.
    pub fn bind(&mut self, name: &Name, expr: Expr) -> bool {
        use std::collections::hash_map::Entry;
        // One hash lookup via the entry API instead of `get` followed by `insert`; an equal
        // existing binding is simply kept rather than overwritten with an identical value.
        match self.bindings.entry(name.to_string()) {
            Entry::Occupied(existing) => *existing.get() == expr,
            Entry::Vacant(slot) => {
                slot.insert(expr);
                true
            }
        }
    }
    /// Returns the expression bound to `name`, if any.
    pub fn get(&self, name: &Name) -> Option<&Expr> {
        self.bindings.get(&name.to_string())
    }
    /// Applies this substitution to `expr` (via `apply_subst_expr`), returning a new expression.
    ///
    /// An empty substitution short-circuits to a plain clone of `expr`.
    pub fn apply(&self, expr: &Expr) -> Expr {
        if self.bindings.is_empty() {
            return expr.clone();
        }
        apply_subst_expr(self, expr)
    }
    /// Number of bound variables.
    pub fn len(&self) -> usize {
        self.bindings.len()
    }
    /// Returns `true` when nothing is bound.
    pub fn is_empty(&self) -> bool {
        self.bindings.is_empty()
    }
    /// Iterates over `(key, expression)` pairs in unspecified (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &Expr)> {
        self.bindings.iter().map(|(k, v)| (k.as_str(), v))
    }
}
/// Rebuilds `expr`, replacing every leaf that has a binding in `subst`.
///
/// Leaves that are looked up:
/// - `FVar(id)`: keyed by the rendering of `id.0`.
/// - `Const(n, _)`: keyed by `n.to_string()`; for a name rendered as `?x`, the stripped key `x`
///   is tried as a fallback, because that is the key [`match_expr_pattern`] derives when it
///   binds the pattern variable `?x`.
///
/// `App`, `Lam`, `Pi`, `Let` and `Proj` are rebuilt with substituted children (binder info,
/// binder names and projection indices are copied unchanged). `Sort`, `BVar` and `Lit` are
/// returned as clones. Replacements are inserted verbatim, including beneath binders; no index
/// adjustment is performed.
fn apply_subst_expr(subst: &PatternSubst, expr: &Expr) -> Expr {
    match expr {
        Expr::FVar(id) => {
            // NOTE(review): free variables are looked up by the rendering of their raw id, so
            // they share a key space with pattern-variable names — confirm this is intended.
            match subst.bindings.get(&id.0.to_string()) {
                Some(replacement) => replacement.clone(),
                None => expr.clone(),
            }
        }
        Expr::Const(n, levels) => {
            let key = n.to_string();
            // Exact key first (unchanged behaviour for names bound verbatim), then the
            // `?`-stripped key that `match_expr_pattern` uses for a pattern variable `?x`.
            // Without the fallback, applying a match result to the hint's own sides (or its
            // hypotheses) would leave every `?x` unsubstituted.
            let replacement = subst.bindings.get(&key).or_else(|| {
                key.strip_prefix('?')
                    .and_then(|var_key| subst.bindings.get(var_key))
            });
            match replacement {
                Some(replacement) => replacement.clone(),
                None => Expr::Const(n.clone(), levels.clone()),
            }
        }
        Expr::App(f, a) => Expr::App(
            Box::new(apply_subst_expr(subst, f)),
            Box::new(apply_subst_expr(subst, a)),
        ),
        Expr::Lam(bi, n, ty, body) => Expr::Lam(
            *bi,
            n.clone(),
            Box::new(apply_subst_expr(subst, ty)),
            Box::new(apply_subst_expr(subst, body)),
        ),
        Expr::Pi(bi, n, ty, body) => Expr::Pi(
            *bi,
            n.clone(),
            Box::new(apply_subst_expr(subst, ty)),
            Box::new(apply_subst_expr(subst, body)),
        ),
        Expr::Let(n, ty, val, body) => Expr::Let(
            n.clone(),
            Box::new(apply_subst_expr(subst, ty)),
            Box::new(apply_subst_expr(subst, val)),
            Box::new(apply_subst_expr(subst, body)),
        ),
        Expr::Proj(n, idx, inner) => {
            Expr::Proj(n.clone(), *idx, Box::new(apply_subst_expr(subst, inner)))
        }
        // Listed explicitly (no catch-all) so a new `Expr` variant forces a decision here.
        Expr::Sort(_) | Expr::BVar(_) | Expr::Lit(_) => expr.clone(),
    }
}
/// A priority-ordered collection of [`UnifHint`]s.
///
/// Through [`UnifHintDB::add_hint`], hints are kept in descending `priority` order, and hints
/// sharing a priority stay in insertion order.
#[derive(Clone, Debug, Default)]
pub struct UnifHintDB {
    hints: Vec<UnifHint>,
}
impl UnifHintDB {
    /// Creates an empty database.
    pub fn new() -> Self {
        Self::default()
    }
    /// Inserts `hint` after every existing hint whose priority is greater than or equal to its
    /// own, and before the first hint with a strictly lower priority.
    pub fn add_hint(&mut self, hint: UnifHint) {
        let slot = self
            .hints
            .iter()
            .take_while(|existing| existing.priority >= hint.priority)
            .count();
        self.hints.insert(slot, hint);
    }
    /// Number of stored hints.
    pub fn len(&self) -> usize {
        self.hints.len()
    }
    /// Returns `true` when no hints are stored.
    pub fn is_empty(&self) -> bool {
        self.hints.is_empty()
    }
    /// All hints, in stored (priority) order.
    pub fn all_hints(&self) -> &[UnifHint] {
        self.hints.as_slice()
    }
    /// Finds every hint matching the query pair `(lhs_query, rhs_query)`, in stored order.
    ///
    /// Each result carries the substitution produced by the match and a `swapped` flag:
    /// `false` when `hint.lhs`/`hint.rhs` matched `lhs_query`/`rhs_query` directly, `true` when
    /// the hint matched with its sides exchanged. The swapped orientation is tried only if the
    /// direct one fails, so each hint appears at most once.
    pub fn find_hints<'a>(
        &'a self,
        lhs_query: &Expr,
        rhs_query: &Expr,
    ) -> Vec<(&'a UnifHint, PatternSubst, bool)> {
        // Matches `first` against `lhs_query` and `second` against `rhs_query` with one shared
        // substitution; yields it only if both sides succeed.
        let try_orientation = |first: &Expr, second: &Expr| -> Option<PatternSubst> {
            let mut subst = PatternSubst::new();
            let ok = match_expr_pattern(first, lhs_query, &mut subst)
                && match_expr_pattern(second, rhs_query, &mut subst);
            if ok {
                Some(subst)
            } else {
                None
            }
        };
        self.hints
            .iter()
            .filter_map(|hint| {
                try_orientation(&hint.lhs, &hint.rhs)
                    .map(|subst| (hint, subst, false))
                    .or_else(|| {
                        try_orientation(&hint.rhs, &hint.lhs).map(|subst| (hint, subst, true))
                    })
            })
            .collect()
    }
    /// Removes every hint labelled `name`; unnamed hints are kept.
    pub fn remove_named(&mut self, name: &Name) {
        self.hints
            .retain(|hint| !matches!(&hint.name, Some(label) if label == name));
    }
    /// Removes all hints.
    pub fn clear(&mut self) {
        self.hints.clear();
    }
}
/// Structurally matches `pattern` against `target`, extending `subst` with pattern-variable
/// bindings.
///
/// A pattern variable is a `Const` whose name renders with a leading `?` (e.g. `?x`). It matches
/// any target, which is recorded via [`PatternSubst::bind`] under `Name::str` of the name with
/// the `?` removed. Because `bind` rejects a conflicting rebinding, a variable that occurs twice
/// must match equal subterms. Universe levels on a pattern variable are ignored.
///
/// Every other node must agree in constructor and in its non-expression payload (sort level,
/// bound-variable index, free-variable id, constant name and levels, literal, projection
/// name and index). Sub-expressions are matched recursively. Binder info and binder names of
/// `Lam`/`Pi`/`Let` are not compared.
///
/// Returns `false` on the first mismatch. `subst` may already hold bindings made before the
/// mismatch was found, so callers should discard it on failure.
pub fn match_expr_pattern(pattern: &Expr, target: &Expr, subst: &mut PatternSubst) -> bool {
    match pattern {
        Expr::Const(np, lsp) => {
            // Render the name once; the original rendered it twice for every pattern variable.
            let rendered = np.to_string();
            if let Some(stripped) = rendered.strip_prefix('?') {
                let var_key = stripped.to_string();
                return subst.bind(&Name::str(&var_key), target.clone());
            }
            matches!(target, Expr::Const(nt, lst) if np == nt && lsp == lst)
        }
        Expr::Sort(lp) => matches!(target, Expr::Sort(lt) if lp == lt),
        Expr::BVar(ip) => matches!(target, Expr::BVar(it) if ip == it),
        Expr::FVar(fp) => matches!(target, Expr::FVar(ft) if fp == ft),
        Expr::App(fp, ap) => {
            if let Expr::App(ft, at_) = target {
                match_expr_pattern(fp, ft, subst) && match_expr_pattern(ap, at_, subst)
            } else {
                false
            }
        }
        Expr::Lam(_, _, typ, bodyp) => {
            if let Expr::Lam(_, _, tyt, bodyt) = target {
                match_expr_pattern(typ, tyt, subst) && match_expr_pattern(bodyp, bodyt, subst)
            } else {
                false
            }
        }
        Expr::Pi(_, _, typ, bodyp) => {
            if let Expr::Pi(_, _, tyt, bodyt) = target {
                match_expr_pattern(typ, tyt, subst) && match_expr_pattern(bodyp, bodyt, subst)
            } else {
                false
            }
        }
        Expr::Let(_, typ, valp, bodyp) => {
            if let Expr::Let(_, tyt, valt, bodyt) = target {
                match_expr_pattern(typ, tyt, subst)
                    && match_expr_pattern(valp, valt, subst)
                    && match_expr_pattern(bodyp, bodyt, subst)
            } else {
                false
            }
        }
        Expr::Lit(lp) => matches!(target, Expr::Lit(lt) if lp == lt),
        Expr::Proj(np, ip, ep) => {
            if let Expr::Proj(nt, it, et) = target {
                np == nt && ip == it && match_expr_pattern(ep, et, subst)
            } else {
                false
            }
        }
    }
}