use crate::term::{Term, TermId, TermKind};
use std::collections::HashMap;
/// Hash-consing store for terms: structurally equal terms are stored once and
/// shared through a [`TermId`] index.
pub struct Arena {
    /// Running counters for interning activity; see [`ArenaStats`].
    stats: ArenaStats,
    /// Backing storage; a `TermId`'s raw value is its index into this vector.
    terms: Vec<Term>,
    /// Previously interned ids bucketed by `Term::hash()`. Candidates in a bucket
    /// are compared by kind before being reused.
    cache: HashMap<u64, Vec<TermId>>,
}
/// Counters describing how [`Arena::intern`] has been used since the arena was
/// created or since the last [`Arena::clear_stats`].
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare snapshots.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ArenaStats {
    /// Number of new terms pushed into the arena.
    pub allocated: usize,
    /// Number of `intern` calls answered with an already-interned term.
    pub cache_hits: usize,
    /// Number of `intern` calls that had to allocate a new term.
    pub cache_misses: usize,
}
impl Arena {
    /// Creates an empty arena with zeroed statistics.
    pub fn new() -> Self {
        Self {
            terms: Vec::new(),
            cache: HashMap::new(),
            stats: ArenaStats::default(),
        }
    }

    /// Returns the id of a term with the given `kind`.
    ///
    /// If a term with an equal (`==`) kind was interned before, its existing id is
    /// returned and `cache_hits` is incremented. Otherwise a new term is appended,
    /// and `cache_misses` and `allocated` are incremented.
    ///
    /// # Panics
    ///
    /// Panics if the index of the new term would not fit in the `u32` passed to
    /// `TermId::new`.
    pub fn intern(&mut self, kind: TermKind) -> TermId {
        let term = Term::new(kind);
        let hash = term.hash();

        // Different terms may land in the same hash bucket, so each candidate is
        // confirmed with a structural `==` on its kind before being reused.
        let terms = &self.terms;
        let hit = self.cache.get(&hash).and_then(|candidates| {
            candidates.iter().copied().find(|&id| {
                matches!(
                    terms.get(id.raw() as usize),
                    Some(existing) if existing.kind == term.kind
                )
            })
        });
        if let Some(id) = hit {
            self.stats.cache_hits += 1;
            return id;
        }

        let index = self.terms.len();
        // `TermId::new` takes a `u32`. A silent `as` truncation would produce an id
        // that aliases an existing term, so fail loudly instead. The check runs
        // before any state is mutated.
        assert!(
            index <= u32::MAX as usize,
            "Arena::intern: term count exceeds u32::MAX"
        );
        let id = TermId::new(index as u32);
        self.stats.cache_misses += 1;
        self.stats.allocated += 1;
        self.terms.push(term);
        self.cache.entry(hash).or_default().push(id);
        id
    }

    /// Looks up the term for `id`; ids whose index is out of range yield `None`.
    pub fn get(&self, id: TermId) -> Option<&Term> {
        self.terms.get(id.raw() as usize)
    }

    /// Alias of [`Arena::get`].
    pub fn get_term(&self, id: TermId) -> Option<&Term> {
        self.get(id)
    }

    /// Returns the kind of the term for `id`, or `None` if `id` is out of range.
    pub fn kind(&self, id: TermId) -> Option<&TermKind> {
        self.get(id).map(|t| &t.kind)
    }

    /// Number of distinct terms currently stored in the arena.
    ///
    /// Unlike `stats().allocated`, this count is not affected by `clear_stats`.
    pub fn terms(&self) -> usize {
        self.terms.len()
    }

    /// Current interning counters.
    pub fn stats(&self) -> &ArenaStats {
        &self.stats
    }

    /// Fraction of `intern` calls served from the cache since the last reset.
    ///
    /// Returns `0.0` when no calls have been recorded.
    pub fn cache_hit_rate(&self) -> f64 {
        let total = self.stats.cache_hits + self.stats.cache_misses;
        if total == 0 {
            0.0
        } else {
            self.stats.cache_hits as f64 / total as f64
        }
    }

    /// Resets all counters to zero. Interned terms and the lookup cache are kept.
    pub fn clear_stats(&mut self) {
        self.stats = ArenaStats::default();
    }

    /// Interns a `Sort` term at the given universe level.
    pub fn mk_sort(&mut self, level: crate::level::LevelId) -> TermId {
        self.intern(TermKind::Sort(level))
    }

    /// Interns a constant reference `name` instantiated with `levels`.
    pub fn mk_const(
        &mut self,
        name: crate::symbol::SymbolId,
        levels: Vec<crate::level::LevelId>,
    ) -> TermId {
        self.intern(TermKind::Const(name, levels))
    }

    /// Interns a bound variable with the given index.
    pub fn mk_var(&mut self, index: u32) -> TermId {
        self.intern(TermKind::Var(index))
    }

    /// Interns the application `func arg`.
    pub fn mk_app(&mut self, func: TermId, arg: TermId) -> TermId {
        self.intern(TermKind::App(func, arg))
    }

    /// Interns a lambda abstraction with the given binder and body.
    pub fn mk_lam(&mut self, binder: crate::term::Binder, body: TermId) -> TermId {
        self.intern(TermKind::Lam(binder, body))
    }

    /// Interns a Pi type with the given binder and body.
    pub fn mk_pi(&mut self, binder: crate::term::Binder, body: TermId) -> TermId {
        self.intern(TermKind::Pi(binder, body))
    }

    /// Interns a `let` term binding `value` under `binder` in `body`.
    pub fn mk_let(
        &mut self,
        binder: crate::term::Binder,
        value: TermId,
        body: TermId,
    ) -> TermId {
        self.intern(TermKind::Let(binder, value, body))
    }

    /// Interns a metavariable reference.
    pub fn mk_mvar(&mut self, id: crate::term::MetaVarId) -> TermId {
        self.intern(TermKind::MVar(id))
    }

    /// Interns a natural-number literal.
    pub fn mk_nat(&mut self, n: u64) -> TermId {
        self.intern(TermKind::Lit(crate::term::Literal::Nat(n)))
    }

    /// Builds a left-nested application spine.
    ///
    /// `mk_app_spine(f, &[a, b])` yields `App(App(f, a), b)`. An empty `args`
    /// returns `func` unchanged.
    pub fn mk_app_spine(&mut self, func: TermId, args: &[TermId]) -> TermId {
        args.iter().fold(func, |acc, &arg| self.mk_app(acc, arg))
    }

    /// Returns `LevelId::new(0)` without touching the arena.
    ///
    /// NOTE(review): presumably id 0 is reserved for the zero universe level;
    /// confirm against `crate::level`.
    pub fn mk_level_zero(&mut self) -> crate::level::LevelId {
        crate::level::LevelId::new(0)
    }
}
impl Default for Arena {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::level::LevelId;

    #[test]
    fn test_hash_consing() {
        let mut arena = Arena::new();
        let var0_1 = arena.mk_var(0);
        let var0_2 = arena.mk_var(0);
        assert_eq!(var0_1, var0_2);
        assert_eq!(arena.terms(), 1);
        // The first call allocates; the second is served from the cache.
        assert_eq!(arena.stats().cache_misses, 1);
        assert_eq!(arena.stats().cache_hits, 1);
    }

    #[test]
    fn test_different_terms() {
        let mut arena = Arena::new();
        let var0 = arena.mk_var(0);
        let var1 = arena.mk_var(1);
        assert_ne!(var0, var1);
        assert_eq!(arena.terms(), 2);
    }

    #[test]
    fn test_sort_hash_consing() {
        let mut arena = Arena::new();
        let zero = arena.mk_level_zero();
        let s1 = arena.mk_sort(zero);
        let s2 = arena.mk_sort(LevelId::new(0));
        assert_eq!(s1, s2);
        assert_eq!(arena.terms(), 1);
    }

    #[test]
    fn test_app_spine() {
        let mut arena = Arena::new();
        let f = arena.mk_var(0);
        let x = arena.mk_var(1);
        let y = arena.mk_var(2);
        let app = arena.mk_app_spine(f, &[x, y]);

        // The spine must associate to the left: `f x y` is `App(App(f, x), y)`.
        let (head, last) = match arena.kind(app) {
            Some(TermKind::App(head, last)) => (*head, *last),
            _ => panic!("Expected application"),
        };
        assert_eq!(last, y);
        let (func, first) = match arena.kind(head) {
            Some(TermKind::App(func, first)) => (*func, *first),
            _ => panic!("Expected nested application"),
        };
        assert_eq!(func, f);
        assert_eq!(first, x);

        // Rebuilding the same spine by hand must reuse the interned terms.
        let inner = arena.mk_app(f, x);
        assert_eq!(inner, head);
        assert_eq!(arena.mk_app(inner, y), app);
    }

    #[test]
    fn test_app_spine_empty_args() {
        let mut arena = Arena::new();
        let f = arena.mk_var(0);
        assert_eq!(arena.mk_app_spine(f, &[]), f);
        assert_eq!(arena.terms(), 1);
    }

    #[test]
    fn test_cache_efficiency() {
        let mut arena = Arena::new();
        for _ in 0..100 {
            arena.mk_var(0);
        }
        assert!(arena.cache_hit_rate() > 0.95);
        assert_eq!(arena.stats().cache_hits, 99);
        assert_eq!(arena.terms(), 1);
    }

    #[test]
    fn test_cache_hit_rate_and_clear_stats() {
        let mut arena = Arena::new();
        assert_eq!(arena.cache_hit_rate(), 0.0);
        arena.mk_var(0);
        arena.mk_var(0);
        assert_eq!(arena.cache_hit_rate(), 0.5);

        arena.clear_stats();
        assert_eq!(arena.cache_hit_rate(), 0.0);
        assert_eq!(arena.stats().allocated, 0);
        // Resetting counters must not drop interned terms.
        assert_eq!(arena.terms(), 1);
    }
}