1use anyhow::{Result, anyhow};
16use fxhash::FxHashMap;
17use mangle_ir::{Inst, InstId, Ir, NameId};
18
/// The types the checker can assign to a term or to a predicate argument.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    /// Unknown or unconstrained type; the fallback whenever nothing more
    /// specific can be determined.
    Any,
    /// Boolean values (the `/bool` bound, boolean literals).
    Bool,
    /// Numbers (the `/number` bound, number literals). Distinct from [`Type::Float`].
    Number,
    /// Floating-point values (the `/float` bound, float literals).
    Float,
    /// Strings (the `/string` bound, string literals).
    String,
    /// Byte strings (the `/bytes` bound, bytes literals).
    Bytes,
    /// A list whose elements all have the boxed type.
    List(Box<Type>),
    /// A map from the first boxed type (keys) to the second (values).
    // Kept for completeness of the type model; nothing constructs it yet.
    #[allow(dead_code)]
    Map(Box<Type>, Box<Type>),
    /// A struct value; its fields are not modelled.
    // Kept for completeness of the type model; nothing constructs it yet.
    #[allow(dead_code)]
    Struct,
}
34
35pub struct TypeChecker<'a> {
36 ir: &'a Ir,
37 signatures: FxHashMap<NameId, Vec<Type>>,
39}
40
41impl<'a> TypeChecker<'a> {
42 pub fn new(ir: &'a Ir) -> Self {
43 Self {
44 ir,
45 signatures: FxHashMap::default(),
46 }
47 }
48
49 pub fn check(&mut self) -> Result<()> {
50 for inst in &self.ir.insts {
52 if let Inst::Decl { atom, bounds, .. } = inst {
53 self.collect_signature(*atom, bounds)?;
54 }
55 }
56
57 for inst in &self.ir.insts {
59 if let Inst::Rule {
60 head,
61 premises,
62 transform,
63 } = inst
64 {
65 self.check_rule(*head, premises, transform)?;
66 }
67 }
68 Ok(())
69 }
70
71 fn collect_signature(&mut self, atom_id: InstId, bounds: &[InstId]) -> Result<()> {
72 let atom = self.ir.get(atom_id);
73 if let Inst::Atom { predicate, args } = atom {
74 let mut types = Vec::new();
75 if !bounds.is_empty() {
76 if let Some(first_bound_id) = bounds.first()
82 && let Inst::BoundDecl { base_terms } = self.ir.get(*first_bound_id)
83 {
84 for term_id in base_terms {
85 types.push(self.resolve_type(*term_id)?);
86 }
87 }
88 } else {
89 for _ in args {
91 types.push(Type::Any);
92 }
93 }
94 self.signatures.insert(*predicate, types);
95 }
96 Ok(())
97 }
98
99 fn resolve_type(&self, type_term_id: InstId) -> Result<Type> {
100 let inst = self.ir.get(type_term_id);
101 match inst {
102 Inst::Name(s) => match self.ir.resolve_name(*s) {
103 "/string" => Ok(Type::String),
104 "/number" => Ok(Type::Number),
105 "/float" => Ok(Type::Float),
106 "/bool" => Ok(Type::Bool),
107 "/bytes" => Ok(Type::Bytes),
108 _ => Ok(Type::Any), },
110 Inst::ApplyFn { function, args } => {
111 match self.ir.resolve_name(*function) {
112 "fn:List" | "fn:list" => {
113 let inner = if let Some(arg) = args.first() {
114 self.resolve_type(*arg)?
115 } else {
116 Type::Any
117 };
118 Ok(Type::List(Box::new(inner)))
119 }
120 _ => Ok(Type::Any),
122 }
123 }
124 _ => Ok(Type::Any),
125 }
126 }
127
128 fn check_rule(&self, head: InstId, premises: &[InstId], _transform: &[InstId]) -> Result<()> {
129 let mut var_types: FxHashMap<NameId, Type> = FxHashMap::default();
130
131 for premise in premises {
133 self.check_premise(*premise, &mut var_types)?;
134 }
135
136 self.check_atom(head, &mut var_types)?;
138
139 Ok(())
140 }
141
142 fn check_premise(
143 &self,
144 premise: InstId,
145 var_types: &mut FxHashMap<NameId, Type>,
146 ) -> Result<()> {
147 match self.ir.get(premise) {
148 Inst::Atom { .. } => self.check_atom(premise, var_types),
149 Inst::NegAtom(a) => self.check_atom(*a, var_types),
150 Inst::Eq(l, r) => {
151 let t_l = self.infer_type(*l, var_types)?;
153 let t_r = self.infer_type(*r, var_types)?;
154 self.unify(t_l, t_r).map(|_| ())
155 }
156 _ => Ok(()),
158 }
159 }
160
161 fn check_atom(&self, atom_id: InstId, var_types: &mut FxHashMap<NameId, Type>) -> Result<()> {
162 if let Inst::Atom { predicate, args } = self.ir.get(atom_id)
163 && let Some(sig) = self.signatures.get(predicate)
164 {
165 if sig.len() != args.len() {
166 return Err(anyhow!(
167 "Arity mismatch for {}: expected {}, got {}",
168 self.ir.resolve_name(*predicate),
169 sig.len(),
170 args.len()
171 ));
172 }
173 for (i, arg) in args.iter().enumerate() {
174 let expected_type = &sig[i];
175 self.unify_arg(*arg, expected_type.clone(), var_types)?;
178 }
179 }
180 Ok(())
181 }
182
183 fn infer_type(&self, term: InstId, var_types: &FxHashMap<NameId, Type>) -> Result<Type> {
184 match self.ir.get(term) {
185 Inst::Var(name) => Ok(var_types.get(name).cloned().unwrap_or(Type::Any)),
186 Inst::Number(_) => Ok(Type::Number),
187 Inst::String(_) => Ok(Type::String),
188 Inst::Bool(_) => Ok(Type::Bool),
189 Inst::Float(_) => Ok(Type::Float),
190 Inst::Bytes(_) => Ok(Type::Bytes),
191 _ => Ok(Type::Any),
193 }
194 }
195
196 fn unify_arg(
197 &self,
198 term: InstId,
199 expected: Type,
200 var_types: &mut FxHashMap<NameId, Type>,
201 ) -> Result<()> {
202 if let Inst::Var(name) = self.ir.get(term) {
203 if let Some(current) = var_types.get(name) {
204 let new_type = self.unify(current.clone(), expected)?;
205 var_types.insert(*name, new_type);
206 } else {
207 var_types.insert(*name, expected);
208 }
209 } else {
210 let actual = self.infer_type(term, var_types)?;
211 self.unify(actual, expected)?;
212 }
213 Ok(())
214 }
215
216 fn unify(&self, t1: Type, t2: Type) -> Result<Type> {
217 match (t1, t2) {
218 (Type::Any, t) => Ok(t),
219 (t, Type::Any) => Ok(t),
220 (t1, t2) if t1 == t2 => Ok(t1),
221 (t1, t2) => Err(anyhow!("Type mismatch: {:?} vs {:?}", t1, t2)),
222 }
223 }
224}