1use crate::arena::Arena;
7use crate::context::Context;
8use crate::environment::Environment;
9use crate::term::{MetaVarId, TermId, TermKind};
10use std::collections::HashMap;
11use std::collections::VecDeque;
12
13#[derive(Debug, Clone)]
15pub struct Substitution {
16 assignments: HashMap<MetaVarId, TermId>,
17}
18
19impl Substitution {
20 pub fn new() -> Self {
22 Self {
23 assignments: HashMap::new(),
24 }
25 }
26
27 pub fn assign(&mut self, mvar: MetaVarId, term: TermId) {
29 self.assignments.insert(mvar, term);
30 }
31
32 pub fn lookup(&self, mvar: MetaVarId) -> Option<TermId> {
34 self.assignments.get(&mvar).copied()
35 }
36
37 pub fn is_assigned(&self, mvar: MetaVarId) -> bool {
39 self.assignments.contains_key(&mvar)
40 }
41
42 pub fn assignments(&self) -> &HashMap<MetaVarId, TermId> {
44 &self.assignments
45 }
46}
47
48impl Default for Substitution {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
/// A pending obligation queued on the unifier and discharged by its solver.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Constraint {
    /// The two terms must be made equal, by structural comparison and by
    /// assigning metavariables.
    Unify(TermId, TermId),

    /// The term must be a `Sort`. If it is still an unassigned metavariable
    /// when checked, the solver postpones it to the back of the queue.
    IsSort(TermId),

    /// Records the given term as the type of the metavariable.
    HasType(MetaVarId, TermId),
}
66
/// Syntactic (structural) unifier over arena terms.
///
/// Holds the metavariable substitution built so far, a FIFO queue of
/// constraints still to process, and the recorded metavariable types.
pub struct Unifier {
    /// Metavariable assignments accumulated while solving.
    subst: Substitution,

    /// Constraints waiting to be processed; the front is processed first.
    constraints: VecDeque<Constraint>,

    /// Types recorded via `declare_mvar` or `HasType` constraints.
    /// NOTE(review): nothing visible in this file reads these yet.
    mvar_types: HashMap<MetaVarId, TermId>,
}
78
79impl Unifier {
80 pub fn new() -> Self {
82 Self {
83 subst: Substitution::new(),
84 constraints: VecDeque::new(),
85 mvar_types: HashMap::new(),
86 }
87 }
88
89 pub fn add_constraint(&mut self, constraint: Constraint) {
91 self.constraints.push_back(constraint);
92 }
93
94 pub fn unify(&mut self, t1: TermId, t2: TermId) {
96 self.add_constraint(Constraint::Unify(t1, t2));
97 }
98
99 pub fn declare_mvar(&mut self, mvar: MetaVarId, ty: TermId) {
101 self.mvar_types.insert(mvar, ty);
102 }
103
104 pub fn solve(
106 &mut self,
107 arena: &mut Arena,
108 env: &Environment,
109 ctx: &Context,
110 ) -> crate::Result<()> {
111 while let Some(constraint) = self.constraints.pop_front() {
112 match constraint {
113 Constraint::Unify(t1, t2) => {
114 self.solve_unify(arena, env, ctx, t1, t2)?;
115 }
116 Constraint::IsSort(term) => {
117 let term = self.apply_subst(arena, term)?;
119 if let Some(TermKind::Sort(_)) = arena.kind(term) {
120 } else if let Some(TermKind::MVar(_mvar)) = arena.kind(term) {
122 self.add_constraint(Constraint::IsSort(term));
124 } else {
125 return Err(crate::Error::UnificationError(
126 "Expected sort".to_string(),
127 ));
128 }
129 }
130 Constraint::HasType(mvar, ty) => {
131 self.mvar_types.insert(mvar, ty);
133 }
134 }
135 }
136
137 Ok(())
138 }
139
140 fn solve_unify(
142 &mut self,
143 arena: &mut Arena,
144 _env: &Environment,
145 _ctx: &Context,
146 t1: TermId,
147 t2: TermId,
148 ) -> crate::Result<()> {
149 if t1 == t2 {
151 return Ok(());
152 }
153
154 let t1 = self.apply_subst(arena, t1)?;
156 let t2 = self.apply_subst(arena, t2)?;
157
158 if t1 == t2 {
159 return Ok(());
160 }
161
162 let kind1 = arena.kind(t1).ok_or_else(|| {
163 crate::Error::Internal(format!("Invalid term ID: {:?}", t1))
164 })?.clone();
165
166 let kind2 = arena.kind(t2).ok_or_else(|| {
167 crate::Error::Internal(format!("Invalid term ID: {:?}", t2))
168 })?.clone();
169
170 match (kind1, kind2) {
171 (TermKind::MVar(m), _) => {
173 if !self.subst.is_assigned(m) {
174 if self.occurs_check(m, t2, arena)? {
175 return Err(crate::Error::UnificationError(
176 "Occurs check failed".to_string(),
177 ));
178 }
179 self.subst.assign(m, t2);
180 Ok(())
181 } else {
182 let assigned = self.subst.lookup(m).unwrap();
183 self.solve_unify(arena, _env, _ctx, assigned, t2)
184 }
185 }
186
187 (_, TermKind::MVar(m)) => {
188 if !self.subst.is_assigned(m) {
189 if self.occurs_check(m, t1, arena)? {
190 return Err(crate::Error::UnificationError(
191 "Occurs check failed".to_string(),
192 ));
193 }
194 self.subst.assign(m, t1);
195 Ok(())
196 } else {
197 let assigned = self.subst.lookup(m).unwrap();
198 self.solve_unify(arena, _env, _ctx, t1, assigned)
199 }
200 }
201
202 (TermKind::App(f1, a1), TermKind::App(f2, a2)) => {
204 self.solve_unify(arena, _env, _ctx, f1, f2)?;
205 self.solve_unify(arena, _env, _ctx, a1, a2)?;
206 Ok(())
207 }
208
209 (TermKind::Lam(b1, body1), TermKind::Lam(b2, body2)) => {
210 self.solve_unify(arena, _env, _ctx, b1.ty, b2.ty)?;
211 self.solve_unify(arena, _env, _ctx, body1, body2)?;
212 Ok(())
213 }
214
215 (TermKind::Pi(b1, body1), TermKind::Pi(b2, body2)) => {
216 self.solve_unify(arena, _env, _ctx, b1.ty, b2.ty)?;
217 self.solve_unify(arena, _env, _ctx, body1, body2)?;
218 Ok(())
219 }
220
221 (TermKind::Sort(l1), TermKind::Sort(l2)) if l1 == l2 => Ok(()),
222
223 (TermKind::Var(i1), TermKind::Var(i2)) if i1 == i2 => Ok(()),
224
225 (TermKind::Const(n1, lvls1), TermKind::Const(n2, lvls2))
226 if n1 == n2 && lvls1 == lvls2 =>
227 {
228 Ok(())
229 }
230
231 _ => Err(crate::Error::UnificationError(format!(
233 "Cannot unify {:?} with {:?}",
234 t1, t2
235 ))),
236 }
237 }
238
239 fn occurs_check(
241 &self,
242 mvar: MetaVarId,
243 term: TermId,
244 arena: &Arena,
245 ) -> crate::Result<bool> {
246 let kind = arena.kind(term).ok_or_else(|| {
247 crate::Error::Internal(format!("Invalid term ID: {:?}", term))
248 })?;
249
250 match kind {
251 TermKind::MVar(m) if *m == mvar => Ok(true),
252
253 TermKind::MVar(m) => {
254 if let Some(assigned) = self.subst.lookup(*m) {
255 self.occurs_check(mvar, assigned, arena)
256 } else {
257 Ok(false)
258 }
259 }
260
261 TermKind::App(f, a) => {
262 let in_func = self.occurs_check(mvar, *f, arena)?;
263 let in_arg = self.occurs_check(mvar, *a, arena)?;
264 Ok(in_func || in_arg)
265 }
266
267 TermKind::Lam(b, body) | TermKind::Pi(b, body) => {
268 let in_ty = self.occurs_check(mvar, b.ty, arena)?;
269 let in_body = self.occurs_check(mvar, *body, arena)?;
270 Ok(in_ty || in_body)
271 }
272
273 TermKind::Let(b, val, body) => {
274 let in_ty = self.occurs_check(mvar, b.ty, arena)?;
275 let in_val = self.occurs_check(mvar, *val, arena)?;
276 let in_body = self.occurs_check(mvar, *body, arena)?;
277 Ok(in_ty || in_val || in_body)
278 }
279
280 TermKind::Sort(_) | TermKind::Const(_, _) | TermKind::Var(_) | TermKind::Lit(_) => {
281 Ok(false)
282 }
283 }
284 }
285
286 fn apply_subst(&self, arena: &Arena, term: TermId) -> crate::Result<TermId> {
288 let kind = arena.kind(term).ok_or_else(|| {
289 crate::Error::Internal(format!("Invalid term ID: {:?}", term))
290 })?;
291
292 match kind {
293 TermKind::MVar(m) => {
294 if let Some(assigned) = self.subst.lookup(*m) {
295 self.apply_subst(arena, assigned)
297 } else {
298 Ok(term)
299 }
300 }
301 _ => Ok(term),
302 }
303 }
304
305 pub fn substitution(&self) -> &Substitution {
307 &self.subst
308 }
309
310 pub fn is_solved(&self) -> bool {
312 self.constraints.is_empty()
313 }
314
315 pub fn num_constraints(&self) -> usize {
317 self.constraints.len()
318 }
319}
320
321impl Default for Unifier {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixtures every test needs: an empty arena, environment,
    /// context, and unifier.
    fn setup() -> (Arena, Environment, Context, Unifier) {
        (Arena::new(), Environment::new(), Context::new(), Unifier::new())
    }

    #[test]
    fn test_basic_unification() {
        let (mut arena, env, ctx, mut unifier) = setup();

        let var0 = arena.mk_var(0);
        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        unifier.unify(mvar0, var0);

        unifier.solve(&mut arena, &env, &ctx).unwrap();

        assert!(unifier.is_solved());
        // The metavariable must be bound to exactly the rigid side.
        assert_eq!(
            unifier.substitution().lookup(MetaVarId::new(0)),
            Some(var0)
        );
    }

    #[test]
    fn test_occurs_check() {
        let (mut arena, env, ctx, mut unifier) = setup();

        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        let x = arena.mk_var(0);
        // `?0 =?= ?0 x` has no finite solution.
        let app = arena.mk_app(mvar0, x);
        unifier.unify(mvar0, app);

        let result = unifier.solve(&mut arena, &env, &ctx);
        assert!(result.is_err());
    }

    #[test]
    fn test_structural_unification() {
        let (mut arena, env, ctx, mut unifier) = setup();

        let mvar0 = arena.mk_mvar(MetaVarId::new(0));
        let x = arena.mk_var(0);
        let y = arena.mk_var(1);

        let app1 = arena.mk_app(mvar0, x);
        let app2 = arena.mk_app(y, x);
        unifier.unify(app1, app2);

        unifier.solve(&mut arena, &env, &ctx).unwrap();

        let assignment = unifier.substitution().lookup(MetaVarId::new(0)).unwrap();
        assert_eq!(assignment, y);
    }

    #[test]
    fn test_distinct_variables_do_not_unify() {
        let (mut arena, env, ctx, mut unifier) = setup();

        let var0 = arena.mk_var(0);
        let var1 = arena.mk_var(1);
        unifier.unify(var0, var1);

        assert!(unifier.solve(&mut arena, &env, &ctx).is_err());
    }
}