lean_agentic/
conversion.rs

1//! Definitional equality and weak head normal form evaluation
2//!
//! Implements conversion checking through normalization with
//! beta, delta, and zeta reductions (iota reduction is not implemented in this module).
5
6use crate::arena::Arena;
7use crate::context::Context;
8use crate::environment::Environment;
9use crate::term::{TermId, TermKind};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::sync::RwLock;
13
/// Number of reduction steps a [`Converter`] may take before it gives up;
/// this bounds normalization of terms that would otherwise never terminate.
const DEFAULT_FUEL: u32 = 10_000;
16
17/// Cache for memoizing WHNF computations
18type WhnfCache = Arc<RwLock<HashMap<(TermId, usize), TermId>>>;
19
20/// Conversion checker with WHNF evaluation
21pub struct Converter {
22    /// Fuel remaining to prevent infinite loops
23    fuel: u32,
24
25    /// Cache for WHNF results
26    cache: WhnfCache,
27
28    /// Statistics
29    stats: ConversionStats,
30}
31
/// Counters describing the work performed by a conversion checker.
#[derive(Clone, Debug, Default)]
pub struct ConversionStats {
    /// How many definitional-equality checks have been started.
    pub checks: usize,

    /// How many of those checks concluded that the terms are equal.
    pub successes: usize,

    /// How many WHNF reduction steps have been performed.
    pub reductions: usize,

    /// How many WHNF requests were answered from the cache.
    pub cache_hits: usize,
}
47
48impl Converter {
49    /// Create a new converter with default fuel
50    pub fn new() -> Self {
51        Self {
52            fuel: DEFAULT_FUEL,
53            cache: Arc::new(RwLock::new(HashMap::new())),
54            stats: ConversionStats::default(),
55        }
56    }
57
58    /// Create a converter with custom fuel
59    pub fn with_fuel(fuel: u32) -> Self {
60        Self {
61            fuel,
62            cache: Arc::new(RwLock::new(HashMap::new())),
63            stats: ConversionStats::default(),
64        }
65    }
66
67    /// Check if two terms are definitionally equal
68    pub fn is_def_eq(
69        &mut self,
70        arena: &mut Arena,
71        env: &Environment,
72        ctx: &Context,
73        t1: TermId,
74        t2: TermId,
75    ) -> crate::Result<bool> {
76        self.stats.checks += 1;
77
78        // Fast path: pointer equality
79        if t1 == t2 {
80            self.stats.successes += 1;
81            return Ok(true);
82        }
83
84        // Reduce both to WHNF and compare
85        let whnf1 = self.whnf(arena, env, ctx, t1)?;
86        let whnf2 = self.whnf(arena, env, ctx, t2)?;
87
88        if whnf1 == whnf2 {
89            self.stats.successes += 1;
90            return Ok(true);
91        }
92
93        // Structural comparison
94        let result = self.is_def_eq_whnf(arena, env, ctx, whnf1, whnf2)?;
95        if result {
96            self.stats.successes += 1;
97        }
98
99        Ok(result)
100    }
101
102    /// Reduce a term to weak head normal form
103    pub fn whnf(
104        &mut self,
105        arena: &mut Arena,
106        env: &Environment,
107        ctx: &Context,
108        term: TermId,
109    ) -> crate::Result<TermId> {
110        if self.fuel == 0 {
111            return Err(crate::Error::Internal(
112                "Out of fuel during normalization".to_string(),
113            ));
114        }
115
116        // Check cache
117        let cache_key = (term, ctx.len());
118        {
119            let cache = self.cache.read().unwrap();
120            if let Some(&cached) = cache.get(&cache_key) {
121                self.stats.cache_hits += 1;
122                return Ok(cached);
123            }
124        }
125
126        self.fuel -= 1;
127        self.stats.reductions += 1;
128
129        let kind = arena.kind(term).ok_or_else(|| {
130            crate::Error::Internal(format!("Invalid term ID: {:?}", term))
131        })?.clone();
132
133        let result = match kind {
134            // Variables: look up in context for let-bound values
135            TermKind::Var(idx) => {
136                if let Some(value) = ctx.value_of(idx) {
137                    self.whnf(arena, env, ctx, value)?
138                } else {
139                    term
140                }
141            }
142
143            // Constants: unfold if reducible
144            TermKind::Const(name, _levels) => {
145                if let Some(decl) = env.get_decl(name) {
146                    if decl.is_reducible() {
147                        if let Some(body) = decl.value {
148                            // Instantiate universe parameters if needed
149                            // For now, just reduce the body
150                            self.whnf(arena, env, ctx, body)?
151                        } else {
152                            term
153                        }
154                    } else {
155                        term
156                    }
157                } else {
158                    term
159                }
160            }
161
162            // Application: try beta reduction
163            TermKind::App(func, arg) => {
164                let func_whnf = self.whnf(arena, env, ctx, func)?;
165
166                if let Some(TermKind::Lam(_binder, body)) = arena.kind(func_whnf).cloned() {
167                    // Beta reduction: (λx.body) arg ~> body[x := arg]
168                    let subst = self.substitute(arena, body, 0, arg)?;
169                    self.whnf(arena, env, ctx, subst)?
170                } else {
171                    // Can't reduce further
172                    if func_whnf != func {
173                        let new_app = arena.mk_app(func_whnf, arg);
174                        self.whnf(arena, env, ctx, new_app)?
175                    } else {
176                        term
177                    }
178                }
179            }
180
181            // Let expression: zeta reduction
182            TermKind::Let(_binder, value, body) => {
183                // Substitute value into body
184                let subst = self.substitute(arena, body, 0, value)?;
185                self.whnf(arena, env, ctx, subst)?
186            }
187
188            // Already in WHNF
189            TermKind::Sort(_) | TermKind::Pi(_, _) | TermKind::Lam(_, _) => term,
190
191            // Metavariables and literals are values
192            TermKind::MVar(_) | TermKind::Lit(_) => term,
193        };
194
195        // Cache the result
196        {
197            let mut cache = self.cache.write().unwrap();
198            cache.insert(cache_key, result);
199        }
200
201        Ok(result)
202    }
203
204    /// Compare two terms in WHNF
205    fn is_def_eq_whnf(
206        &mut self,
207        arena: &mut Arena,
208        env: &Environment,
209        ctx: &Context,
210        t1: TermId,
211        t2: TermId,
212    ) -> crate::Result<bool> {
213        if t1 == t2 {
214            return Ok(true);
215        }
216
217        let kind1 = arena.kind(t1).ok_or_else(|| {
218            crate::Error::Internal(format!("Invalid term ID: {:?}", t1))
219        })?.clone();
220
221        let kind2 = arena.kind(t2).ok_or_else(|| {
222            crate::Error::Internal(format!("Invalid term ID: {:?}", t2))
223        })?.clone();
224
225        match (kind1, kind2) {
226            // Sorts
227            (TermKind::Sort(l1), TermKind::Sort(l2)) => Ok(l1 == l2),
228
229            // Variables
230            (TermKind::Var(i1), TermKind::Var(i2)) => Ok(i1 == i2),
231
232            // Constants
233            (TermKind::Const(n1, lvls1), TermKind::Const(n2, lvls2)) => {
234                Ok(n1 == n2 && lvls1 == lvls2)
235            }
236
237            // Applications
238            (TermKind::App(f1, a1), TermKind::App(f2, a2)) => {
239                let funcs_eq = self.is_def_eq(arena, env, ctx, f1, f2)?;
240                let args_eq = self.is_def_eq(arena, env, ctx, a1, a2)?;
241                Ok(funcs_eq && args_eq)
242            }
243
244            // Lambda
245            (TermKind::Lam(b1, body1), TermKind::Lam(b2, body2)) => {
246                // Check binder types
247                let types_eq = self.is_def_eq(arena, env, ctx, b1.ty, b2.ty)?;
248                if !types_eq {
249                    return Ok(false);
250                }
251
252                // Check bodies under extended context
253                let mut new_ctx = ctx.clone();
254                new_ctx.push_var(b1.name, b1.ty);
255                self.is_def_eq(arena, env, &new_ctx, body1, body2)
256            }
257
258            // Pi types
259            (TermKind::Pi(b1, body1), TermKind::Pi(b2, body2)) => {
260                // Check binder types
261                let types_eq = self.is_def_eq(arena, env, ctx, b1.ty, b2.ty)?;
262                if !types_eq {
263                    return Ok(false);
264                }
265
266                // Check bodies under extended context
267                let mut new_ctx = ctx.clone();
268                new_ctx.push_var(b1.name, b1.ty);
269                self.is_def_eq(arena, env, &new_ctx, body1, body2)
270            }
271
272            // Literals
273            (TermKind::Lit(l1), TermKind::Lit(l2)) => Ok(l1 == l2),
274
275            // Different constructors
276            _ => Ok(false),
277        }
278    }
279
280    /// Substitute a term in another term
281    /// subst(term, idx, replacement) replaces variable #idx with replacement
282    pub fn substitute(
283        &mut self,
284        arena: &mut Arena,
285        term: TermId,
286        idx: u32,
287        replacement: TermId,
288    ) -> crate::Result<TermId> {
289        let kind = arena.kind(term).ok_or_else(|| {
290            crate::Error::Internal(format!("Invalid term ID: {:?}", term))
291        })?.clone();
292
293        let result = match kind {
294            TermKind::Var(i) => {
295                if i == idx {
296                    replacement
297                } else {
298                    term
299                }
300            }
301
302            TermKind::App(func, arg) => {
303                let new_func = self.substitute(arena, func, idx, replacement)?;
304                let new_arg = self.substitute(arena, arg, idx, replacement)?;
305                if new_func == func && new_arg == arg {
306                    term
307                } else {
308                    arena.mk_app(new_func, new_arg)
309                }
310            }
311
312            TermKind::Lam(binder, body) => {
313                let old_ty = binder.ty;
314                let new_ty = self.substitute(arena, binder.ty, idx, replacement)?;
315                let new_body = self.substitute(arena, body, idx + 1, replacement)?;
316                if new_ty == old_ty && new_body == body {
317                    term
318                } else {
319                    let new_binder = crate::term::Binder { ty: new_ty, ..binder };
320                    arena.mk_lam(new_binder, new_body)
321                }
322            }
323
324            TermKind::Pi(binder, body) => {
325                let old_ty = binder.ty;
326                let new_ty = self.substitute(arena, binder.ty, idx, replacement)?;
327                let new_body = self.substitute(arena, body, idx + 1, replacement)?;
328                if new_ty == old_ty && new_body == body {
329                    term
330                } else {
331                    let new_binder = crate::term::Binder { ty: new_ty, ..binder };
332                    arena.mk_pi(new_binder, new_body)
333                }
334            }
335
336            TermKind::Let(binder, value, body) => {
337                let old_ty = binder.ty;
338                let new_ty = self.substitute(arena, binder.ty, idx, replacement)?;
339                let new_val = self.substitute(arena, value, idx, replacement)?;
340                let new_body = self.substitute(arena, body, idx + 1, replacement)?;
341                if new_ty == old_ty && new_val == value && new_body == body {
342                    term
343                } else {
344                    let new_binder = crate::term::Binder { ty: new_ty, ..binder };
345                    arena.mk_let(new_binder, new_val, new_body)
346                }
347            }
348
349            // No free variables in these
350            TermKind::Sort(_) | TermKind::Const(_, _) | TermKind::Lit(_) | TermKind::MVar(_) => term,
351        };
352
353        Ok(result)
354    }
355
356    /// Get conversion statistics
357    pub fn stats(&self) -> &ConversionStats {
358        &self.stats
359    }
360
361    /// Clear the WHNF cache
362    pub fn clear_cache(&self) {
363        let mut cache = self.cache.write().unwrap();
364        cache.clear();
365    }
366
367    /// Reset fuel to default
368    pub fn reset_fuel(&mut self) {
369        self.fuel = DEFAULT_FUEL;
370    }
371}
372
373impl Default for Converter {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymbolId;
    use crate::term::Binder;

    /// Two separately created occurrences of the same variable are
    /// definitionally equal, and the check is recorded in the statistics.
    #[test]
    fn test_simple_conversion() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut conv = Converter::new();

        let var0 = arena.mk_var(0);
        let var0_2 = arena.mk_var(0);

        assert!(conv.is_def_eq(&mut arena, &env, &ctx, var0, var0_2).unwrap());
        assert_eq!(conv.stats().checks, 1);
        assert_eq!(conv.stats().successes, 1);
    }

    /// Applying the identity function to `y` reduces to exactly `y`.
    #[test]
    fn test_beta_reduction() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut conv = Converter::new();

        // (λx. x) y should reduce to y
        let x = arena.mk_var(0);
        let binder = Binder::new(SymbolId::new(0), TermId::new(0));
        let lam = arena.mk_lam(binder, x);
        let y = arena.mk_var(1);
        let app = arena.mk_app(lam, y);

        let result = conv.whnf(&mut arena, &env, &ctx, app).unwrap();

        assert_ne!(result, app); // Should have reduced
        assert_eq!(result, y); // ...and to the argument itself
    }

    /// One unit of fuel pays for one reduction step; the next uncached step
    /// fails until the fuel is reset.
    #[test]
    fn test_fuel_exhaustion() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut conv = Converter::with_fuel(1);

        let var = arena.mk_var(0);

        // This should work with minimal fuel
        assert!(conv.whnf(&mut arena, &env, &ctx, var).is_ok());

        // The budget is now spent: a term that has not been reduced before fails.
        let other = arena.mk_var(1);
        assert!(conv.whnf(&mut arena, &env, &ctx, other).is_err());

        // Resetting the fuel makes reduction possible again.
        conv.reset_fuel();
        assert!(conv.whnf(&mut arena, &env, &ctx, other).is_ok());

        // A converter created without any fuel cannot reduce at all.
        let mut empty = Converter::with_fuel(0);
        assert!(empty.whnf(&mut arena, &env, &ctx, var).is_err());
    }
}
431}