// lean_agentic/context.rs

//! Type context for local variables.
//!
//! Manages the local context (Γ) in typing judgments `Γ ⊢ t : T`.

5use crate::symbol::SymbolId;
6use crate::term::TermId;
7
8/// Entry in the local context
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct ContextEntry {
11    /// Name of the variable (for pretty-printing)
12    pub name: SymbolId,
13
14    /// Type of the variable
15    pub ty: TermId,
16
17    /// Optional value (for let bindings)
18    pub value: Option<TermId>,
19}
20
21impl ContextEntry {
22    /// Create a new context entry
23    pub fn new(name: SymbolId, ty: TermId) -> Self {
24        Self {
25            name,
26            ty,
27            value: None,
28        }
29    }
30
31    /// Create a context entry with a value
32    pub fn with_value(name: SymbolId, ty: TermId, value: TermId) -> Self {
33        Self {
34            name,
35            ty,
36            value: Some(value),
37        }
38    }
39}
40
41/// Local typing context
42///
43/// Uses de Bruijn indices: variable #0 is the most recently bound,
44/// #1 is the one before that, etc.
45#[derive(Debug, Clone, Default)]
46pub struct Context {
47    /// Stack of local bindings (most recent at the end)
48    entries: Vec<ContextEntry>,
49}
50
51impl Context {
52    /// Create a new empty context
53    pub fn new() -> Self {
54        Self {
55            entries: Vec::new(),
56        }
57    }
58
59    /// Push a new binding onto the context
60    pub fn push(&mut self, entry: ContextEntry) {
61        self.entries.push(entry);
62    }
63
64    /// Push a simple variable binding
65    pub fn push_var(&mut self, name: SymbolId, ty: TermId) {
66        self.push(ContextEntry::new(name, ty));
67    }
68
69    /// Pop the most recent binding
70    pub fn pop(&mut self) -> Option<ContextEntry> {
71        self.entries.pop()
72    }
73
74    /// Get the number of entries in the context
75    pub fn len(&self) -> usize {
76        self.entries.len()
77    }
78
79    /// Check if the context is empty
80    pub fn is_empty(&self) -> bool {
81        self.entries.is_empty()
82    }
83
84    /// Look up a variable by de Bruijn index
85    ///
86    /// Index 0 refers to the most recently bound variable
87    pub fn lookup(&self, index: u32) -> Option<&ContextEntry> {
88        let pos = self.entries.len().checked_sub(index as usize + 1)?;
89        self.entries.get(pos)
90    }
91
92    /// Get the type of a variable by de Bruijn index
93    pub fn type_of(&self, index: u32) -> Option<TermId> {
94        self.lookup(index).map(|e| e.ty)
95    }
96
97    /// Get the value of a variable (if it's a let binding)
98    pub fn value_of(&self, index: u32) -> Option<TermId> {
99        self.lookup(index).and_then(|e| e.value)
100    }
101
102    /// Extend the context with multiple entries
103    pub fn extend(&mut self, entries: impl IntoIterator<Item = ContextEntry>) {
104        self.entries.extend(entries);
105    }
106
107    /// Create a new context by extending this one
108    pub fn with_entries(&self, entries: impl IntoIterator<Item = ContextEntry>) -> Self {
109        let mut new_ctx = self.clone();
110        new_ctx.extend(entries);
111        new_ctx
112    }
113
114    /// Get all entries (for iteration)
115    pub fn entries(&self) -> &[ContextEntry] {
116        &self.entries
117    }
118
119    /// Clear the context
120    pub fn clear(&mut self) {
121        self.entries.clear();
122    }
123
124    /// Truncate the context to a specific length
125    pub fn truncate(&mut self, len: usize) {
126        self.entries.truncate(len);
127    }
128
129    /// Save the current context length (for later restoration)
130    pub fn mark(&self) -> usize {
131        self.len()
132    }
133
134    /// Restore context to a previous mark
135    pub fn restore(&mut self, mark: usize) {
136        self.truncate(mark);
137    }
138}
139
140/// RAII guard for context management
141///
142/// Automatically pops entries when dropped
143pub struct ContextGuard<'a> {
144    context: &'a mut Context,
145    mark: usize,
146}
147
148impl<'a> ContextGuard<'a> {
149    /// Create a new context guard
150    pub fn new(context: &'a mut Context) -> Self {
151        let mark = context.mark();
152        Self { context, mark }
153    }
154
155    /// Push an entry within this guard
156    pub fn push(&mut self, entry: ContextEntry) {
157        self.context.push(entry);
158    }
159
160    /// Get a reference to the context
161    pub fn context(&self) -> &Context {
162        self.context
163    }
164}
165
166impl<'a> Drop for ContextGuard<'a> {
167    fn drop(&mut self) {
168        self.context.restore(self.mark);
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_basic() {
        let mut ctx = Context::new();

        let name = SymbolId::new(0);
        let ty = TermId::new(0);

        ctx.push_var(name, ty);
        assert_eq!(ctx.len(), 1);

        let entry = ctx.lookup(0).unwrap();
        assert_eq!(entry.name, name);
        assert_eq!(entry.ty, ty);
    }

    #[test]
    fn test_debruijn_indices() {
        let mut ctx = Context::new();

        let x_ty = TermId::new(0);
        let y_ty = TermId::new(1);
        let z_ty = TermId::new(2);

        ctx.push_var(SymbolId::new(0), x_ty); // #2 after all pushes
        ctx.push_var(SymbolId::new(1), y_ty); // #1
        ctx.push_var(SymbolId::new(2), z_ty); // #0 (most recent)

        assert_eq!(ctx.type_of(0), Some(z_ty)); // Most recent
        assert_eq!(ctx.type_of(1), Some(y_ty));
        assert_eq!(ctx.type_of(2), Some(x_ty)); // Oldest
        assert_eq!(ctx.type_of(3), None); // Out of bounds
    }

    #[test]
    fn test_context_guard() {
        let mut ctx = Context::new();

        ctx.push_var(SymbolId::new(0), TermId::new(0));
        assert_eq!(ctx.len(), 1);

        {
            let mut guard = ContextGuard::new(&mut ctx);
            guard.push(ContextEntry::new(SymbolId::new(1), TermId::new(1)));
            assert_eq!(guard.context().len(), 2);
        } // Guard dropped, context restored

        assert_eq!(ctx.len(), 1); // Back to original size
    }

    #[test]
    fn test_let_binding() {
        let mut ctx = Context::new();

        let name = SymbolId::new(0);
        let ty = TermId::new(0);
        let val = TermId::new(1);

        ctx.push(ContextEntry::with_value(name, ty, val));

        assert_eq!(ctx.type_of(0), Some(ty));
        assert_eq!(ctx.value_of(0), Some(val));
    }

    #[test]
    fn test_lookup_on_empty_and_out_of_range() {
        let mut ctx = Context::new();
        assert!(ctx.is_empty());
        assert!(ctx.lookup(0).is_none());
        assert!(ctx.lookup(u32::MAX).is_none());

        ctx.push_var(SymbolId::new(0), TermId::new(0));
        assert!(ctx.lookup(1).is_none());
        assert!(ctx.lookup(u32::MAX).is_none());
    }

    #[test]
    fn test_value_of_plain_binding_is_none() {
        let mut ctx = Context::new();
        ctx.push_var(SymbolId::new(0), TermId::new(0));

        assert_eq!(ctx.value_of(0), None); // Not a let binding
        assert_eq!(ctx.value_of(1), None); // Out of bounds
    }

    #[test]
    fn test_mark_restore_pop_and_clear() {
        let mut ctx = Context::new();
        ctx.push_var(SymbolId::new(0), TermId::new(0));

        let mark = ctx.mark();
        ctx.push_var(SymbolId::new(1), TermId::new(1));
        ctx.push_var(SymbolId::new(2), TermId::new(2));
        assert_eq!(ctx.len(), 3);

        ctx.restore(mark);
        assert_eq!(ctx.len(), 1);

        // Restoring to a mark above the current length leaves the context unchanged.
        ctx.restore(5);
        assert_eq!(ctx.len(), 1);

        let popped = ctx.pop().expect("one binding should remain");
        assert_eq!(popped.name, SymbolId::new(0));
        assert!(ctx.pop().is_none());

        ctx.push_var(SymbolId::new(3), TermId::new(3));
        ctx.clear();
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_with_entries_leaves_original_untouched() {
        let mut base = Context::new();
        base.push_var(SymbolId::new(0), TermId::new(0));

        let extended = base.with_entries(vec![
            ContextEntry::new(SymbolId::new(1), TermId::new(1)),
            ContextEntry::with_value(SymbolId::new(2), TermId::new(2), TermId::new(3)),
        ]);

        assert_eq!(base.len(), 1);
        assert_eq!(extended.len(), 3);
        assert_eq!(extended.value_of(0), Some(TermId::new(3)));
        assert_eq!(extended.type_of(2), Some(TermId::new(0)));
    }
}