use crate::level::LevelId;
use crate::symbol::SymbolId;
use std::fmt;
/// Newtype wrapper around a `u32` that identifies a term.
///
/// Ordering, equality and hashing are those of the wrapped integer.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TermId(u32);

impl TermId {
    /// Wraps a raw id. Crate-private, so ids can only be minted inside this crate.
    pub(crate) fn new(id: u32) -> Self {
        TermId(id)
    }

    /// Returns the wrapped raw id.
    pub fn raw(self) -> u32 {
        let TermId(raw) = self;
        raw
    }
}
/// A named, typed binder, as carried by the `Lam`, `Pi` and `Let` term kinds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Binder {
    /// Name of the bound variable.
    pub name: SymbolId,
    /// Type annotation of the bound variable.
    pub ty: TermId,
    /// Whether the binder is implicit.
    // NOTE(review): this flag duplicates information already in `info`. The
    // `Binder::new` / `Binder::implicit` constructors set the two consistently,
    // but because both fields are public, direct writes can desynchronize them.
    // Consider deriving one from the other.
    pub implicit: bool,
    /// Finer-grained binder annotation (see `BinderInfo`).
    pub info: BinderInfo,
}
/// Annotation describing how a binder's argument is supplied.
// NOTE(review): the variant set looks like Lean-style binder annotations
// (explicit / implicit / strict-implicit / instance-implicit). Confirm the
// intended elaboration semantics before relying on that reading.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum BinderInfo {
    /// Ordinary (explicit) binder; used by `Binder::new`.
    Default,
    /// Implicit binder; used by `Binder::implicit`.
    Implicit,
    /// Strict-implicit binder.
    StrictImplicit,
    /// Instance-implicit binder.
    InstImplicit,
}
impl Binder {
pub fn new(name: SymbolId, ty: TermId) -> Self {
Self {
name,
ty,
implicit: false,
info: BinderInfo::Default,
}
}
pub fn implicit(name: SymbolId, ty: TermId) -> Self {
Self {
name,
ty,
implicit: true,
info: BinderInfo::Implicit,
}
}
}
/// The shape of a term. Sub-terms are referenced by `TermId` handles rather
/// than stored inline.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TermKind {
    /// A sort at the given universe level.
    Sort(LevelId),
    /// A named constant applied to a list of level arguments.
    Const(SymbolId, Vec<LevelId>),
    /// A bound variable, printed as `#idx`.
    // NOTE(review): presumably a de Bruijn index; confirm against the
    // substitution/instantiation code.
    Var(u32),
    /// Application of a function term to an argument term: `(func, arg)`.
    App(TermId, TermId),
    /// Lambda abstraction: binder and body.
    Lam(Binder, TermId),
    /// Dependent function type: binder and codomain/body.
    Pi(Binder, TermId),
    /// Local definition: binder, bound value, body.
    Let(Binder, TermId, TermId),
    /// Metavariable placeholder, printed as `?id`.
    MVar(MetaVarId),
    /// Literal value.
    Lit(Literal),
}
/// Newtype wrapper around a `u32` that identifies a metavariable.
///
/// Ordering, equality and hashing are those of the wrapped integer.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MetaVarId(u32);

impl MetaVarId {
    /// Wraps a raw metavariable id.
    pub fn new(id: u32) -> Self {
        MetaVarId(id)
    }

    /// Returns the wrapped raw id.
    pub fn raw(self) -> u32 {
        let MetaVarId(raw) = self;
        raw
    }
}
/// A literal value embedded in a term.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Literal {
    /// Natural-number literal. Limited to the `u64` range in this representation.
    Nat(u64),
    /// String literal.
    String(String),
}
/// A term together with a cached 64-bit hash of its `kind`.
///
/// The hash is computed once in `Term::new`. It is used for a fast-path
/// inequality check in `PartialEq` and as the sole input to `Hash`.
#[derive(Debug, Clone)]
pub struct Term {
    /// The term's structure.
    // NOTE(review): this field is public while `hash` is derived from it at
    // construction time. Mutating `kind` after `Term::new` leaves a stale
    // cached hash, which breaks the `PartialEq`/`Hash` contract. Consider
    // making it private or read-only.
    pub kind: TermKind,
    /// Hash of `kind`, as computed by `Term::new`.
    hash: u64,
}
impl Term {
pub fn new(kind: TermKind) -> Self {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
kind.hash(&mut hasher);
let hash = hasher.finish();
Self { kind, hash }
}
pub fn hash(&self) -> u64 {
self.hash
}
pub fn is_sort(&self) -> bool {
matches!(self.kind, TermKind::Sort(_))
}
pub fn is_var(&self) -> bool {
matches!(self.kind, TermKind::Var(_))
}
pub fn is_lam(&self) -> bool {
matches!(self.kind, TermKind::Lam(_, _))
}
pub fn is_pi(&self) -> bool {
matches!(self.kind, TermKind::Pi(_, _))
}
pub fn is_app(&self) -> bool {
matches!(self.kind, TermKind::App(_, _))
}
}
/// Structural equality on `kind`, with the cached hash used as a cheap
/// rejection test before the full comparison.
impl PartialEq for Term {
    fn eq(&self, other: &Self) -> bool {
        // Both hashes come from hashing `kind` with identically seeded
        // hashers, so different hashes imply different kinds.
        if self.hash != other.hash {
            return false;
        }
        // Equal hashes can still collide; confirm structurally.
        self.kind == other.kind
    }
}

/// Equality is total because it ultimately defers to `TermKind`'s derived `Eq`.
impl Eq for Term {}
/// Feeds only the cached `u64` hash into `state`. This stays consistent with
/// `PartialEq`, because equal kinds always produce equal cached hashes.
impl std::hash::Hash for Term {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Identical to `u64::hash`, which forwards to `write_u64`.
        state.write_u64(self.hash);
    }
}
/// Renders a [`TermKind`] in a compact, debug-oriented text form.
///
/// Levels, symbols, sub-terms and metavariables are printed by the raw numeric
/// id returned from their `.raw()` accessor; sub-terms are not expanded.
/// String literals are printed with Rust-style escaping (via `{:?}`), so
/// embedded quotes, backslashes or control characters cannot make the output
/// ambiguous.
impl fmt::Display for TermKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TermKind::Sort(l) => write!(f, "Sort({})", l.raw()),
            TermKind::Const(name, levels) => {
                write!(f, "Const({}, [", name.raw())?;
                for (i, l) in levels.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", l.raw())?;
                }
                f.write_str("])")
            }
            TermKind::Var(idx) => write!(f, "#{}", idx),
            TermKind::App(func, arg) => write!(f, "({} {})", func.raw(), arg.raw()),
            TermKind::Lam(binder, body) => {
                write!(f, "(λ {} : {} . {})", binder.name.raw(), binder.ty.raw(), body.raw())
            }
            TermKind::Pi(binder, body) => {
                write!(f, "(Π {} : {} . {})", binder.name.raw(), binder.ty.raw(), body.raw())
            }
            TermKind::Let(binder, val, body) => write!(
                f,
                "(let {} : {} := {} in {})",
                binder.name.raw(),
                binder.ty.raw(),
                val.raw(),
                body.raw()
            ),
            TermKind::MVar(id) => write!(f, "?{}", id.raw()),
            TermKind::Lit(Literal::Nat(n)) => write!(f, "{}", n),
            // `{:?}` on a string emits surrounding double quotes and escapes
            // `"`, `\` and control characters. The former `"\"{}\""` emitted the
            // contents raw, so e.g. a literal containing `"` was ambiguous.
            TermKind::Lit(Literal::String(s)) => write!(f, "{:?}", s),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two terms built from the same kind must carry the same cached hash.
    #[test]
    fn test_term_creation() {
        let (lhs, rhs) = (Term::new(TermKind::Var(0)), Term::new(TermKind::Var(0)));
        assert_eq!(lhs.hash(), rhs.hash());
    }

    /// `Binder::new` yields a non-implicit binder; `Binder::implicit` an implicit one.
    #[test]
    fn test_binder_info() {
        let explicit = Binder::new(SymbolId::new(0), TermId::new(0));
        assert!(!explicit.implicit);

        let implicit = Binder::implicit(SymbolId::new(0), TermId::new(0));
        assert!(implicit.implicit);
    }
}