use crate::symbol::SymbolId;
use crate::term::TermId;
/// One binding stored in a [`Context`].
///
/// Every binding records the name it was introduced under and its type.
/// Bindings created through [`ContextEntry::with_value`] additionally carry
/// the term they are bound to.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct ContextEntry {
    /// Symbol the binding was introduced under.
    pub name: SymbolId,
    /// Type of the bound variable.
    pub ty: TermId,
    /// Bound value, or `None` for a plain variable without a definition.
    pub value: Option<TermId>,
}
impl ContextEntry {
pub fn new(name: SymbolId, ty: TermId) -> Self {
Self {
name,
ty,
value: None,
}
}
pub fn with_value(name: SymbolId, ty: TermId, value: TermId) -> Self {
Self {
name,
ty,
value: Some(value),
}
}
}
/// A stack of [`ContextEntry`] bindings addressed by de Bruijn index.
///
/// Index `0` refers to the most recently pushed binding, index `1` to the
/// binding pushed before it, and so on.
#[derive(Debug, Clone)]
#[derive(Default)]
pub struct Context {
    /// Bindings in push order, so the innermost binding is the last element.
    entries: Vec<ContextEntry>,
}
impl Context {
    /// Creates an empty context.
    pub fn new() -> Self {
        Self::default()
    }

    /// Pushes `entry` as the new innermost binding (de Bruijn index `0`).
    pub fn push(&mut self, entry: ContextEntry) {
        self.entries.push(entry);
    }

    /// Pushes a plain variable `name : ty` with no value.
    pub fn push_var(&mut self, name: SymbolId, ty: TermId) {
        self.push(ContextEntry::new(name, ty));
    }

    /// Removes and returns the innermost binding, or `None` if the context is empty.
    pub fn pop(&mut self) -> Option<ContextEntry> {
        self.entries.pop()
    }

    /// Number of bindings currently in scope.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no bindings are in scope.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns the binding at de Bruijn `index`, where `0` is the most
    /// recently pushed binding. Returns `None` if `index` is out of range.
    pub fn lookup(&self, index: u32) -> Option<&ContextEntry> {
        // `u32 -> usize` is lossless on 32- and 64-bit targets. Subtracting
        // in two checked steps avoids computing `index + 1`, which overflows
        // `usize` on 32-bit targets when `index == u32::MAX`.
        let index = index as usize;
        let pos = self.entries.len().checked_sub(index)?.checked_sub(1)?;
        self.entries.get(pos)
    }

    /// Type of the binding at de Bruijn `index`, if it exists.
    pub fn type_of(&self, index: u32) -> Option<TermId> {
        self.lookup(index).map(|e| e.ty)
    }

    /// Value of the binding at de Bruijn `index`.
    ///
    /// Returns `None` both when the index is out of range and when the
    /// binding has no value.
    pub fn value_of(&self, index: u32) -> Option<TermId> {
        self.lookup(index).and_then(|e| e.value)
    }

    /// Pushes every entry of `entries` in iteration order, so the last one
    /// yielded becomes index `0`.
    pub fn extend(&mut self, entries: impl IntoIterator<Item = ContextEntry>) {
        self.entries.extend(entries);
    }

    /// Returns a copy of this context extended with `entries`.
    ///
    /// `self` is left unchanged.
    #[must_use]
    pub fn with_entries(&self, entries: impl IntoIterator<Item = ContextEntry>) -> Self {
        let mut new_ctx = self.clone();
        new_ctx.extend(entries);
        new_ctx
    }

    /// All bindings in push order (outermost first, innermost last).
    pub fn entries(&self) -> &[ContextEntry] {
        &self.entries
    }

    /// Removes every binding.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Keeps only the first `len` bindings.
    ///
    /// Does nothing if `len` is greater than or equal to the current length.
    pub fn truncate(&mut self, len: usize) {
        self.entries.truncate(len);
    }

    /// Records the current length so it can later be passed to [`Context::restore`].
    #[must_use]
    pub fn mark(&self) -> usize {
        self.len()
    }

    /// Discards every binding pushed after `mark` was taken.
    ///
    /// Does nothing if the context is already no longer than `mark`.
    pub fn restore(&mut self, mark: usize) {
        self.truncate(mark);
    }
}
/// Scope guard over a [`Context`].
///
/// On creation it records the context's current length. When it is dropped
/// it truncates the context back to that length, so bindings pushed while
/// the guard is alive do not outlive it.
#[must_use = "the context is rolled back as soon as the guard is dropped"]
#[derive(Debug)]
pub struct ContextGuard<'a> {
    /// The context being scoped.
    context: &'a mut Context,
    /// Length of `context` when the guard was created.
    mark: usize,
}
impl<'a> ContextGuard<'a> {
pub fn new(context: &'a mut Context) -> Self {
let mark = context.mark();
Self { context, mark }
}
pub fn push(&mut self, entry: ContextEntry) {
self.context.push(entry);
}
pub fn context(&self) -> &Context {
self.context
}
}
impl<'a> Drop for ContextGuard<'a> {
fn drop(&mut self) {
self.context.restore(self.mark);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_basic() {
        let mut ctx = Context::new();
        let name = SymbolId::new(0);
        let ty = TermId::new(0);
        ctx.push_var(name, ty);
        assert_eq!(ctx.len(), 1);
        let entry = ctx.lookup(0).unwrap();
        assert_eq!(entry.name, name);
        assert_eq!(entry.ty, ty);
    }

    #[test]
    fn test_debruijn_indices() {
        let mut ctx = Context::new();
        let x_ty = TermId::new(0);
        let y_ty = TermId::new(1);
        let z_ty = TermId::new(2);
        ctx.push_var(SymbolId::new(0), x_ty);
        ctx.push_var(SymbolId::new(1), y_ty);
        ctx.push_var(SymbolId::new(2), z_ty);
        // Index 0 is the most recently pushed binding.
        assert_eq!(ctx.type_of(0), Some(z_ty));
        assert_eq!(ctx.type_of(1), Some(y_ty));
        assert_eq!(ctx.type_of(2), Some(x_ty));
        assert_eq!(ctx.type_of(3), None);
    }

    #[test]
    fn test_lookup_out_of_range() {
        let mut ctx = Context::new();
        assert!(ctx.lookup(0).is_none());
        ctx.push_var(SymbolId::new(0), TermId::new(0));
        assert!(ctx.lookup(1).is_none());
        assert!(ctx.lookup(u32::MAX).is_none());
    }

    #[test]
    fn test_context_guard() {
        let mut ctx = Context::new();
        ctx.push_var(SymbolId::new(0), TermId::new(0));
        assert_eq!(ctx.len(), 1);
        {
            let mut guard = ContextGuard::new(&mut ctx);
            guard.push(ContextEntry::new(SymbolId::new(1), TermId::new(1)));
            assert_eq!(guard.context().len(), 2);
        }
        // Dropping the guard removes the binding it pushed.
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_mark_restore() {
        let mut ctx = Context::new();
        ctx.push_var(SymbolId::new(0), TermId::new(0));
        let mark = ctx.mark();
        ctx.push_var(SymbolId::new(1), TermId::new(1));
        ctx.push_var(SymbolId::new(2), TermId::new(2));
        ctx.restore(mark);
        assert_eq!(ctx.len(), 1);
        // Restoring to a mark beyond the current length is a no-op.
        ctx.restore(10);
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_with_entries_leaves_original_untouched() {
        let mut ctx = Context::new();
        ctx.push_var(SymbolId::new(0), TermId::new(0));
        let extended = ctx.with_entries(vec![ContextEntry::new(SymbolId::new(1), TermId::new(1))]);
        assert_eq!(extended.len(), 2);
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_let_binding() {
        let mut ctx = Context::new();
        let name = SymbolId::new(0);
        let ty = TermId::new(0);
        let val = TermId::new(1);
        ctx.push(ContextEntry::with_value(name, ty, val));
        assert_eq!(ctx.type_of(0), Some(ty));
        assert_eq!(ctx.value_of(0), Some(val));
    }

    #[test]
    fn test_plain_var_has_no_value() {
        let mut ctx = Context::new();
        ctx.push_var(SymbolId::new(0), TermId::new(0));
        assert_eq!(ctx.value_of(0), None);
    }
}