lean_agentic/
unification.rs

1//! Unification and constraint solving for metavariables
2//!
3//! Implements first-order unification for dependent type theory
4//! with occurs check and constraint propagation.
5
6use crate::arena::Arena;
7use crate::context::Context;
8use crate::environment::Environment;
9use crate::term::{MetaVarId, TermId, TermKind};
10use std::collections::HashMap;
11use std::collections::VecDeque;
12
13/// Assignment of metavariables to terms
14#[derive(Debug, Clone)]
15pub struct Substitution {
16    assignments: HashMap<MetaVarId, TermId>,
17}
18
19impl Substitution {
20    /// Create a new empty substitution
21    pub fn new() -> Self {
22        Self {
23            assignments: HashMap::new(),
24        }
25    }
26
27    /// Assign a metavariable to a term
28    pub fn assign(&mut self, mvar: MetaVarId, term: TermId) {
29        self.assignments.insert(mvar, term);
30    }
31
32    /// Look up the assignment for a metavariable
33    pub fn lookup(&self, mvar: MetaVarId) -> Option<TermId> {
34        self.assignments.get(&mvar).copied()
35    }
36
37    /// Check if a metavariable is assigned
38    pub fn is_assigned(&self, mvar: MetaVarId) -> bool {
39        self.assignments.contains_key(&mvar)
40    }
41
42    /// Get all assignments
43    pub fn assignments(&self) -> &HashMap<MetaVarId, TermId> {
44        &self.assignments
45    }
46}
47
48impl Default for Substitution {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
/// A unification constraint.
///
/// Constraints are queued on a `Unifier` and consumed by `Unifier::solve`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Constraint {
    /// Two terms must be equal.
    Unify(TermId, TermId),

    /// A term must be a sort.
    ///
    /// The solver accepts the constraint once the term, after following
    /// metavariable assignments at its head, is a `TermKind::Sort`.
    IsSort(TermId),

    /// A metavariable must have a specific type.
    ///
    /// The solver records the type for the metavariable; it does not check it.
    HasType(MetaVarId, TermId),
}
66
67/// Unification engine with constraint solving
68pub struct Unifier {
69    /// Current substitution
70    subst: Substitution,
71
72    /// Pending constraints
73    constraints: VecDeque<Constraint>,
74
75    /// Metavariable type information
76    mvar_types: HashMap<MetaVarId, TermId>,
77}
78
79impl Unifier {
80    /// Create a new unifier
81    pub fn new() -> Self {
82        Self {
83            subst: Substitution::new(),
84            constraints: VecDeque::new(),
85            mvar_types: HashMap::new(),
86        }
87    }
88
89    /// Add a constraint to the queue
90    pub fn add_constraint(&mut self, constraint: Constraint) {
91        self.constraints.push_back(constraint);
92    }
93
94    /// Unify two terms
95    pub fn unify(&mut self, t1: TermId, t2: TermId) {
96        self.add_constraint(Constraint::Unify(t1, t2));
97    }
98
99    /// Declare a metavariable with its type
100    pub fn declare_mvar(&mut self, mvar: MetaVarId, ty: TermId) {
101        self.mvar_types.insert(mvar, ty);
102    }
103
104    /// Solve all pending constraints
105    pub fn solve(
106        &mut self,
107        arena: &mut Arena,
108        env: &Environment,
109        ctx: &Context,
110    ) -> crate::Result<()> {
111        while let Some(constraint) = self.constraints.pop_front() {
112            match constraint {
113                Constraint::Unify(t1, t2) => {
114                    self.solve_unify(arena, env, ctx, t1, t2)?;
115                }
116                Constraint::IsSort(term) => {
117                    // Check if term is or unifies to a sort
118                    let term = self.apply_subst(arena, term)?;
119                    if let Some(TermKind::Sort(_)) = arena.kind(term) {
120                        // OK
121                    } else if let Some(TermKind::MVar(_mvar)) = arena.kind(term) {
122                        // Defer: we need more information
123                        self.add_constraint(Constraint::IsSort(term));
124                    } else {
125                        return Err(crate::Error::UnificationError(
126                            "Expected sort".to_string(),
127                        ));
128                    }
129                }
130                Constraint::HasType(mvar, ty) => {
131                    // Record the type constraint
132                    self.mvar_types.insert(mvar, ty);
133                }
134            }
135        }
136
137        Ok(())
138    }
139
140    /// Solve a unification constraint
141    fn solve_unify(
142        &mut self,
143        arena: &mut Arena,
144        _env: &Environment,
145        _ctx: &Context,
146        t1: TermId,
147        t2: TermId,
148    ) -> crate::Result<()> {
149        // Fast path: already equal
150        if t1 == t2 {
151            return Ok(());
152        }
153
154        // Apply current substitution
155        let t1 = self.apply_subst(arena, t1)?;
156        let t2 = self.apply_subst(arena, t2)?;
157
158        if t1 == t2 {
159            return Ok(());
160        }
161
162        let kind1 = arena.kind(t1).ok_or_else(|| {
163            crate::Error::Internal(format!("Invalid term ID: {:?}", t1))
164        })?.clone();
165
166        let kind2 = arena.kind(t2).ok_or_else(|| {
167            crate::Error::Internal(format!("Invalid term ID: {:?}", t2))
168        })?.clone();
169
170        match (kind1, kind2) {
171            // ?m = t  or  t = ?m
172            (TermKind::MVar(m), _) => {
173                if !self.subst.is_assigned(m) {
174                    if self.occurs_check(m, t2, arena)? {
175                        return Err(crate::Error::UnificationError(
176                            "Occurs check failed".to_string(),
177                        ));
178                    }
179                    self.subst.assign(m, t2);
180                    Ok(())
181                } else {
182                    let assigned = self.subst.lookup(m).unwrap();
183                    self.solve_unify(arena, _env, _ctx, assigned, t2)
184                }
185            }
186
187            (_, TermKind::MVar(m)) => {
188                if !self.subst.is_assigned(m) {
189                    if self.occurs_check(m, t1, arena)? {
190                        return Err(crate::Error::UnificationError(
191                            "Occurs check failed".to_string(),
192                        ));
193                    }
194                    self.subst.assign(m, t1);
195                    Ok(())
196                } else {
197                    let assigned = self.subst.lookup(m).unwrap();
198                    self.solve_unify(arena, _env, _ctx, t1, assigned)
199                }
200            }
201
202            // Structural unification
203            (TermKind::App(f1, a1), TermKind::App(f2, a2)) => {
204                self.solve_unify(arena, _env, _ctx, f1, f2)?;
205                self.solve_unify(arena, _env, _ctx, a1, a2)?;
206                Ok(())
207            }
208
209            (TermKind::Lam(b1, body1), TermKind::Lam(b2, body2)) => {
210                self.solve_unify(arena, _env, _ctx, b1.ty, b2.ty)?;
211                self.solve_unify(arena, _env, _ctx, body1, body2)?;
212                Ok(())
213            }
214
215            (TermKind::Pi(b1, body1), TermKind::Pi(b2, body2)) => {
216                self.solve_unify(arena, _env, _ctx, b1.ty, b2.ty)?;
217                self.solve_unify(arena, _env, _ctx, body1, body2)?;
218                Ok(())
219            }
220
221            (TermKind::Sort(l1), TermKind::Sort(l2)) if l1 == l2 => Ok(()),
222
223            (TermKind::Var(i1), TermKind::Var(i2)) if i1 == i2 => Ok(()),
224
225            (TermKind::Const(n1, lvls1), TermKind::Const(n2, lvls2))
226                if n1 == n2 && lvls1 == lvls2 =>
227            {
228                Ok(())
229            }
230
231            // Can't unify
232            _ => Err(crate::Error::UnificationError(format!(
233                "Cannot unify {:?} with {:?}",
234                t1, t2
235            ))),
236        }
237    }
238
239    /// Check if a metavariable occurs in a term (occurs check)
240    fn occurs_check(
241        &self,
242        mvar: MetaVarId,
243        term: TermId,
244        arena: &Arena,
245    ) -> crate::Result<bool> {
246        let kind = arena.kind(term).ok_or_else(|| {
247            crate::Error::Internal(format!("Invalid term ID: {:?}", term))
248        })?;
249
250        match kind {
251            TermKind::MVar(m) if *m == mvar => Ok(true),
252
253            TermKind::MVar(m) => {
254                if let Some(assigned) = self.subst.lookup(*m) {
255                    self.occurs_check(mvar, assigned, arena)
256                } else {
257                    Ok(false)
258                }
259            }
260
261            TermKind::App(f, a) => {
262                let in_func = self.occurs_check(mvar, *f, arena)?;
263                let in_arg = self.occurs_check(mvar, *a, arena)?;
264                Ok(in_func || in_arg)
265            }
266
267            TermKind::Lam(b, body) | TermKind::Pi(b, body) => {
268                let in_ty = self.occurs_check(mvar, b.ty, arena)?;
269                let in_body = self.occurs_check(mvar, *body, arena)?;
270                Ok(in_ty || in_body)
271            }
272
273            TermKind::Let(b, val, body) => {
274                let in_ty = self.occurs_check(mvar, b.ty, arena)?;
275                let in_val = self.occurs_check(mvar, *val, arena)?;
276                let in_body = self.occurs_check(mvar, *body, arena)?;
277                Ok(in_ty || in_val || in_body)
278            }
279
280            TermKind::Sort(_) | TermKind::Const(_, _) | TermKind::Var(_) | TermKind::Lit(_) => {
281                Ok(false)
282            }
283        }
284    }
285
286    /// Apply the current substitution to a term
287    fn apply_subst(&self, arena: &Arena, term: TermId) -> crate::Result<TermId> {
288        let kind = arena.kind(term).ok_or_else(|| {
289            crate::Error::Internal(format!("Invalid term ID: {:?}", term))
290        })?;
291
292        match kind {
293            TermKind::MVar(m) => {
294                if let Some(assigned) = self.subst.lookup(*m) {
295                    // Recursively apply substitution
296                    self.apply_subst(arena, assigned)
297                } else {
298                    Ok(term)
299                }
300            }
301            _ => Ok(term),
302        }
303    }
304
305    /// Get the current substitution
306    pub fn substitution(&self) -> &Substitution {
307        &self.subst
308    }
309
310    /// Check if all constraints are solved
311    pub fn is_solved(&self) -> bool {
312        self.constraints.is_empty()
313    }
314
315    /// Get the number of pending constraints
316    pub fn num_constraints(&self) -> usize {
317        self.constraints.len()
318    }
319}
320
321impl Default for Unifier {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh metavariable unified with a variable is assigned that variable.
    #[test]
    fn test_basic_unification() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();

        let var0 = arena.mk_var(0);

        // ?0 = var0
        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        unifier.unify(mvar0, var0);

        unifier.solve(&mut arena, &env, &ctx).unwrap();

        assert!(unifier.is_solved());
        assert!(unifier.substitution().is_assigned(MetaVarId::new(0)));
        assert_eq!(
            unifier.substitution().lookup(MetaVarId::new(0)),
            Some(var0)
        );
    }

    /// A metavariable must not be assigned a term that contains itself.
    #[test]
    fn test_occurs_check() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();

        // ?0 = App(?0, x) -- should fail occurs check
        let mvar0_id = MetaVarId::new(0);
        let mvar0 = arena.mk_mvar(mvar0_id);
        let x = arena.mk_var(0);
        let app = arena.mk_app(mvar0, x);

        unifier.unify(mvar0, app);

        let result = unifier.solve(&mut arena, &env, &ctx);
        assert!(matches!(
            result,
            Err(crate::Error::UnificationError(_))
        ));
    }

    /// Applications are unified component-wise, solving a metavariable in head position.
    #[test]
    fn test_structural_unification() {
        let mut arena = Arena::new();
        let env = Environment::new();
        let ctx = Context::new();
        let mut unifier = Unifier::new();

        // App(?0, x) = App(y, x)  =>  ?0 = y
        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        let x = arena.mk_var(0);
        let y = arena.mk_var(1);

        let app1 = arena.mk_app(mvar0, x);
        let app2 = arena.mk_app(y, x);

        unifier.unify(app1, app2);

        unifier.solve(&mut arena, &env, &ctx).unwrap();

        let assignment = unifier.substitution().lookup(MetaVarId::new(0)).unwrap();
        assert_eq!(assignment, y);
    }
}