1use crate::arena::Arena;
7use crate::context::Context;
8use crate::environment::Environment;
9use crate::term::{TermId, TermKind};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::sync::RwLock;
13
/// Fuel handed out by `Converter::new` and restored by `Converter::reset_fuel`;
/// every `whnf` step that misses the cache consumes one unit of it.
const DEFAULT_FUEL: u32 = 10_000;
16
17type WhnfCache = Arc<RwLock<HashMap<(TermId, usize), TermId>>>;
19
20pub struct Converter {
22 fuel: u32,
24
25 cache: WhnfCache,
27
28 stats: ConversionStats,
30}
31
/// Counters collected by a `Converter` while it checks and normalizes terms.
#[derive(Clone, Debug, Default)]
pub struct ConversionStats {
    /// Number of `is_def_eq` calls, including the recursive ones on sub-terms.
    pub checks: usize,
    /// Number of `is_def_eq` calls that answered `true`.
    pub successes: usize,
    /// Number of `whnf` calls that missed the cache and performed a step.
    pub reductions: usize,
    /// Number of `whnf` calls answered directly from the cache.
    pub cache_hits: usize,
}
47
48impl Converter {
49 pub fn new() -> Self {
51 Self {
52 fuel: DEFAULT_FUEL,
53 cache: Arc::new(RwLock::new(HashMap::new())),
54 stats: ConversionStats::default(),
55 }
56 }
57
58 pub fn with_fuel(fuel: u32) -> Self {
60 Self {
61 fuel,
62 cache: Arc::new(RwLock::new(HashMap::new())),
63 stats: ConversionStats::default(),
64 }
65 }
66
67 pub fn is_def_eq(
69 &mut self,
70 arena: &mut Arena,
71 env: &Environment,
72 ctx: &Context,
73 t1: TermId,
74 t2: TermId,
75 ) -> crate::Result<bool> {
76 self.stats.checks += 1;
77
78 if t1 == t2 {
80 self.stats.successes += 1;
81 return Ok(true);
82 }
83
84 let whnf1 = self.whnf(arena, env, ctx, t1)?;
86 let whnf2 = self.whnf(arena, env, ctx, t2)?;
87
88 if whnf1 == whnf2 {
89 self.stats.successes += 1;
90 return Ok(true);
91 }
92
93 let result = self.is_def_eq_whnf(arena, env, ctx, whnf1, whnf2)?;
95 if result {
96 self.stats.successes += 1;
97 }
98
99 Ok(result)
100 }
101
102 pub fn whnf(
104 &mut self,
105 arena: &mut Arena,
106 env: &Environment,
107 ctx: &Context,
108 term: TermId,
109 ) -> crate::Result<TermId> {
110 if self.fuel == 0 {
111 return Err(crate::Error::Internal(
112 "Out of fuel during normalization".to_string(),
113 ));
114 }
115
116 let cache_key = (term, ctx.len());
118 {
119 let cache = self.cache.read().unwrap();
120 if let Some(&cached) = cache.get(&cache_key) {
121 self.stats.cache_hits += 1;
122 return Ok(cached);
123 }
124 }
125
126 self.fuel -= 1;
127 self.stats.reductions += 1;
128
129 let kind = arena.kind(term).ok_or_else(|| {
130 crate::Error::Internal(format!("Invalid term ID: {:?}", term))
131 })?.clone();
132
133 let result = match kind {
134 TermKind::Var(idx) => {
136 if let Some(value) = ctx.value_of(idx) {
137 self.whnf(arena, env, ctx, value)?
138 } else {
139 term
140 }
141 }
142
143 TermKind::Const(name, _levels) => {
145 if let Some(decl) = env.get_decl(name) {
146 if decl.is_reducible() {
147 if let Some(body) = decl.value {
148 self.whnf(arena, env, ctx, body)?
151 } else {
152 term
153 }
154 } else {
155 term
156 }
157 } else {
158 term
159 }
160 }
161
162 TermKind::App(func, arg) => {
164 let func_whnf = self.whnf(arena, env, ctx, func)?;
165
166 if let Some(TermKind::Lam(_binder, body)) = arena.kind(func_whnf).cloned() {
167 let subst = self.substitute(arena, body, 0, arg)?;
169 self.whnf(arena, env, ctx, subst)?
170 } else {
171 if func_whnf != func {
173 let new_app = arena.mk_app(func_whnf, arg);
174 self.whnf(arena, env, ctx, new_app)?
175 } else {
176 term
177 }
178 }
179 }
180
181 TermKind::Let(_binder, value, body) => {
183 let subst = self.substitute(arena, body, 0, value)?;
185 self.whnf(arena, env, ctx, subst)?
186 }
187
188 TermKind::Sort(_) | TermKind::Pi(_, _) | TermKind::Lam(_, _) => term,
190
191 TermKind::MVar(_) | TermKind::Lit(_) => term,
193 };
194
195 {
197 let mut cache = self.cache.write().unwrap();
198 cache.insert(cache_key, result);
199 }
200
201 Ok(result)
202 }
203
204 fn is_def_eq_whnf(
206 &mut self,
207 arena: &mut Arena,
208 env: &Environment,
209 ctx: &Context,
210 t1: TermId,
211 t2: TermId,
212 ) -> crate::Result<bool> {
213 if t1 == t2 {
214 return Ok(true);
215 }
216
217 let kind1 = arena.kind(t1).ok_or_else(|| {
218 crate::Error::Internal(format!("Invalid term ID: {:?}", t1))
219 })?.clone();
220
221 let kind2 = arena.kind(t2).ok_or_else(|| {
222 crate::Error::Internal(format!("Invalid term ID: {:?}", t2))
223 })?.clone();
224
225 match (kind1, kind2) {
226 (TermKind::Sort(l1), TermKind::Sort(l2)) => Ok(l1 == l2),
228
229 (TermKind::Var(i1), TermKind::Var(i2)) => Ok(i1 == i2),
231
232 (TermKind::Const(n1, lvls1), TermKind::Const(n2, lvls2)) => {
234 Ok(n1 == n2 && lvls1 == lvls2)
235 }
236
237 (TermKind::App(f1, a1), TermKind::App(f2, a2)) => {
239 let funcs_eq = self.is_def_eq(arena, env, ctx, f1, f2)?;
240 let args_eq = self.is_def_eq(arena, env, ctx, a1, a2)?;
241 Ok(funcs_eq && args_eq)
242 }
243
244 (TermKind::Lam(b1, body1), TermKind::Lam(b2, body2)) => {
246 let types_eq = self.is_def_eq(arena, env, ctx, b1.ty, b2.ty)?;
248 if !types_eq {
249 return Ok(false);
250 }
251
252 let mut new_ctx = ctx.clone();
254 new_ctx.push_var(b1.name, b1.ty);
255 self.is_def_eq(arena, env, &new_ctx, body1, body2)
256 }
257
258 (TermKind::Pi(b1, body1), TermKind::Pi(b2, body2)) => {
260 let types_eq = self.is_def_eq(arena, env, ctx, b1.ty, b2.ty)?;
262 if !types_eq {
263 return Ok(false);
264 }
265
266 let mut new_ctx = ctx.clone();
268 new_ctx.push_var(b1.name, b1.ty);
269 self.is_def_eq(arena, env, &new_ctx, body1, body2)
270 }
271
272 (TermKind::Lit(l1), TermKind::Lit(l2)) => Ok(l1 == l2),
274
275 _ => Ok(false),
277 }
278 }
279
280 pub fn substitute(
283 &mut self,
284 arena: &mut Arena,
285 term: TermId,
286 idx: u32,
287 replacement: TermId,
288 ) -> crate::Result<TermId> {
289 let kind = arena.kind(term).ok_or_else(|| {
290 crate::Error::Internal(format!("Invalid term ID: {:?}", term))
291 })?.clone();
292
293 let result = match kind {
294 TermKind::Var(i) => {
295 if i == idx {
296 replacement
297 } else {
298 term
299 }
300 }
301
302 TermKind::App(func, arg) => {
303 let new_func = self.substitute(arena, func, idx, replacement)?;
304 let new_arg = self.substitute(arena, arg, idx, replacement)?;
305 if new_func == func && new_arg == arg {
306 term
307 } else {
308 arena.mk_app(new_func, new_arg)
309 }
310 }
311
312 TermKind::Lam(binder, body) => {
313 let old_ty = binder.ty;
314 let new_ty = self.substitute(arena, binder.ty, idx, replacement)?;
315 let new_body = self.substitute(arena, body, idx + 1, replacement)?;
316 if new_ty == old_ty && new_body == body {
317 term
318 } else {
319 let new_binder = crate::term::Binder { ty: new_ty, ..binder };
320 arena.mk_lam(new_binder, new_body)
321 }
322 }
323
324 TermKind::Pi(binder, body) => {
325 let old_ty = binder.ty;
326 let new_ty = self.substitute(arena, binder.ty, idx, replacement)?;
327 let new_body = self.substitute(arena, body, idx + 1, replacement)?;
328 if new_ty == old_ty && new_body == body {
329 term
330 } else {
331 let new_binder = crate::term::Binder { ty: new_ty, ..binder };
332 arena.mk_pi(new_binder, new_body)
333 }
334 }
335
336 TermKind::Let(binder, value, body) => {
337 let old_ty = binder.ty;
338 let new_ty = self.substitute(arena, binder.ty, idx, replacement)?;
339 let new_val = self.substitute(arena, value, idx, replacement)?;
340 let new_body = self.substitute(arena, body, idx + 1, replacement)?;
341 if new_ty == old_ty && new_val == value && new_body == body {
342 term
343 } else {
344 let new_binder = crate::term::Binder { ty: new_ty, ..binder };
345 arena.mk_let(new_binder, new_val, new_body)
346 }
347 }
348
349 TermKind::Sort(_) | TermKind::Const(_, _) | TermKind::Lit(_) | TermKind::MVar(_) => term,
351 };
352
353 Ok(result)
354 }
355
356 pub fn stats(&self) -> &ConversionStats {
358 &self.stats
359 }
360
361 pub fn clear_cache(&self) {
363 let mut cache = self.cache.write().unwrap();
364 cache.clear();
365 }
366
367 pub fn reset_fuel(&mut self) {
369 self.fuel = DEFAULT_FUEL;
370 }
371}
372
373impl Default for Converter {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymbolId;
    use crate::term::Binder;

    /// Two separately allocated copies of `#0` are definitionally equal.
    #[test]
    fn test_simple_conversion() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut conv = Converter::new();

        let var0 = arena.mk_var(0);
        let var0_2 = arena.mk_var(0);

        assert!(conv.is_def_eq(&mut arena, &env, &ctx, var0, var0_2).unwrap());
    }

    /// `(λ. #0) y` beta-reduces to `y` itself.
    #[test]
    fn test_beta_reduction() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut conv = Converter::new();

        let x = arena.mk_var(0);
        let binder = Binder::new(SymbolId::new(0), TermId::new(0));
        let lam = arena.mk_lam(binder, x);
        let y = arena.mk_var(1);
        let app = arena.mk_app(lam, y);

        let result = conv.whnf(&mut arena, &env, &ctx, app).unwrap();

        assert_ne!(result, app);
        assert_eq!(result, y);
    }

    /// One unit of fuel suffices for a term already in WHNF but not for a beta
    /// redex, which needs a second step to normalize its head.
    #[test]
    fn test_fuel_exhaustion() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();

        let var = arena.mk_var(0);
        let mut conv = Converter::with_fuel(1);
        assert!(conv.whnf(&mut arena, &env, &ctx, var).is_ok());

        let binder = Binder::new(SymbolId::new(0), var);
        let lam = arena.mk_lam(binder, var);
        let arg = arena.mk_var(1);
        let app = arena.mk_app(lam, arg);
        let mut starved = Converter::with_fuel(1);
        assert!(starved.whnf(&mut arena, &env, &ctx, app).is_err());
    }
}