Skip to main content

mangle_analysis/
type_check.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{Result, anyhow};
16use fxhash::FxHashMap;
17use mangle_ir::{Inst, InstId, Ir, NameId};
18
/// Static type inferred for a predicate argument or a term during type checking.
///
/// `Any` is the unconstrained type: unification with `Any` always succeeds and
/// yields the other operand.
// TODO: More precise types
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    /// Unconstrained; compatible with every other type.
    Any,
    /// Resolved from the `/bool` type name.
    Bool,
    /// Resolved from the `/number` type name.
    Number,
    /// Resolved from the `/float` type name.
    Float,
    /// Resolved from the `/string` type name.
    String,
    /// Resolved from the `/bytes` type name.
    Bytes,
    /// List whose elements all have the boxed type.
    List(Box<Type>),
    /// Map type; the two boxed types are presumably key and value — TODO confirm.
    // `dead_code` is allowed so this variant can exist before every path constructs it.
    #[allow(dead_code)]
    Map(Box<Type>, Box<Type>),
    /// Struct type, simplified: field names and field types are not tracked.
    // `dead_code` is allowed so this variant can exist before every path constructs it.
    #[allow(dead_code)]
    Struct,
}
34
35pub struct TypeChecker<'a> {
36    ir: &'a Ir,
37    // Predicate name -> Arg types
38    signatures: FxHashMap<NameId, Vec<Type>>,
39}
40
41impl<'a> TypeChecker<'a> {
42    pub fn new(ir: &'a Ir) -> Self {
43        Self {
44            ir,
45            signatures: FxHashMap::default(),
46        }
47    }
48
49    pub fn check(&mut self) -> Result<()> {
50        // Pass 1: Collect signatures from Decls
51        for inst in &self.ir.insts {
52            if let Inst::Decl { atom, bounds, .. } = inst {
53                self.collect_signature(*atom, bounds)?;
54            }
55        }
56
57        // Pass 2: Check Rules
58        for inst in &self.ir.insts {
59            if let Inst::Rule {
60                head,
61                premises,
62                transform,
63            } = inst
64            {
65                self.check_rule(*head, premises, transform)?;
66            }
67        }
68        Ok(())
69    }
70
71    fn collect_signature(&mut self, atom_id: InstId, bounds: &[InstId]) -> Result<()> {
72        let atom = self.ir.get(atom_id);
73        if let Inst::Atom { predicate, args } = atom {
74            let mut types = Vec::new();
75            if !bounds.is_empty() {
76                // Bounds map to args?
77                // Mangle syntax: bound [/type1, /type2]
78                // If there are multiple bound decls, it's intersection or union?
79                // Usually one bound decl per rule/pred?
80                // Assuming first bound decl defines signature for now.
81                if let Some(first_bound_id) = bounds.first()
82                    && let Inst::BoundDecl { base_terms } = self.ir.get(*first_bound_id)
83                {
84                    for term_id in base_terms {
85                        types.push(self.resolve_type(*term_id)?);
86                    }
87                }
88            } else {
89                // Default to Any
90                for _ in args {
91                    types.push(Type::Any);
92                }
93            }
94            self.signatures.insert(*predicate, types);
95        }
96        Ok(())
97    }
98
99    fn resolve_type(&self, type_term_id: InstId) -> Result<Type> {
100        let inst = self.ir.get(type_term_id);
101        match inst {
102            Inst::Name(s) => match self.ir.resolve_name(*s) {
103                "/string" => Ok(Type::String),
104                "/number" => Ok(Type::Number),
105                "/float" => Ok(Type::Float),
106                "/bool" => Ok(Type::Bool),
107                "/bytes" => Ok(Type::Bytes),
108                _ => Ok(Type::Any), // Unknown type name
109            },
110            Inst::ApplyFn { function, args } => {
111                match self.ir.resolve_name(*function) {
112                    "fn:List" | "fn:list" => {
113                        let inner = if let Some(arg) = args.first() {
114                            self.resolve_type(*arg)?
115                        } else {
116                            Type::Any
117                        };
118                        Ok(Type::List(Box::new(inner)))
119                    }
120                    // TODO: Map, Struct
121                    _ => Ok(Type::Any),
122                }
123            }
124            _ => Ok(Type::Any),
125        }
126    }
127
128    fn check_rule(&self, head: InstId, premises: &[InstId], _transform: &[InstId]) -> Result<()> {
129        let mut var_types: FxHashMap<NameId, Type> = FxHashMap::default();
130
131        // Check premises
132        for premise in premises {
133            self.check_premise(*premise, &mut var_types)?;
134        }
135
136        // Check head
137        self.check_atom(head, &mut var_types)?;
138
139        Ok(())
140    }
141
142    fn check_premise(
143        &self,
144        premise: InstId,
145        var_types: &mut FxHashMap<NameId, Type>,
146    ) -> Result<()> {
147        match self.ir.get(premise) {
148            Inst::Atom { .. } => self.check_atom(premise, var_types),
149            Inst::NegAtom(a) => self.check_atom(*a, var_types),
150            Inst::Eq(l, r) => {
151                // Unify types of l and r
152                let t_l = self.infer_type(*l, var_types)?;
153                let t_r = self.infer_type(*r, var_types)?;
154                self.unify(t_l, t_r).map(|_| ())
155            }
156            // ...
157            _ => Ok(()),
158        }
159    }
160
161    fn check_atom(&self, atom_id: InstId, var_types: &mut FxHashMap<NameId, Type>) -> Result<()> {
162        if let Inst::Atom { predicate, args } = self.ir.get(atom_id)
163            && let Some(sig) = self.signatures.get(predicate)
164        {
165            if sig.len() != args.len() {
166                return Err(anyhow!(
167                    "Arity mismatch for {}: expected {}, got {}",
168                    self.ir.resolve_name(*predicate),
169                    sig.len(),
170                    args.len()
171                ));
172            }
173            for (i, arg) in args.iter().enumerate() {
174                let expected_type = &sig[i];
175                // Update variable type if it was Any/Unknown
176                // Or check if it matches
177                self.unify_arg(*arg, expected_type.clone(), var_types)?;
178            }
179        }
180        Ok(())
181    }
182
183    fn infer_type(&self, term: InstId, var_types: &FxHashMap<NameId, Type>) -> Result<Type> {
184        match self.ir.get(term) {
185            Inst::Var(name) => Ok(var_types.get(name).cloned().unwrap_or(Type::Any)),
186            Inst::Number(_) => Ok(Type::Number),
187            Inst::String(_) => Ok(Type::String),
188            Inst::Bool(_) => Ok(Type::Bool),
189            Inst::Float(_) => Ok(Type::Float),
190            Inst::Bytes(_) => Ok(Type::Bytes),
191            // ...
192            _ => Ok(Type::Any),
193        }
194    }
195
196    fn unify_arg(
197        &self,
198        term: InstId,
199        expected: Type,
200        var_types: &mut FxHashMap<NameId, Type>,
201    ) -> Result<()> {
202        if let Inst::Var(name) = self.ir.get(term) {
203            if let Some(current) = var_types.get(name) {
204                let new_type = self.unify(current.clone(), expected)?;
205                var_types.insert(*name, new_type);
206            } else {
207                var_types.insert(*name, expected);
208            }
209        } else {
210            let actual = self.infer_type(term, var_types)?;
211            self.unify(actual, expected)?;
212        }
213        Ok(())
214    }
215
216    fn unify(&self, t1: Type, t2: Type) -> Result<Type> {
217        match (t1, t2) {
218            (Type::Any, t) => Ok(t),
219            (t, Type::Any) => Ok(t),
220            (t1, t2) if t1 == t2 => Ok(t1),
221            (t1, t2) => Err(anyhow!("Type mismatch: {:?} vs {:?}", t1, t2)),
222        }
223    }
224}