Skip to main content

mangle_analysis/
bounds_check.rs

// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Bounds checker for Mangle.
//!
//! Validates that facts and rule derivations conform to declared type bounds.
//! Implements a bounds analysis equivalent to the Go implementation's, with:
//!
//! - Inference state tracking with per-variable type accumulation
//! - Feasible alternatives analysis with special cases for built-in predicates
//! - Skolemization of polymorphic type variables
//! - Cross-predicate type inference
//! - UpperBound/LowerBound for type intersection/union
25
26use anyhow::{Result, anyhow};
27use fxhash::{FxHashMap, FxHashSet};
28use mangle_ir::{Inst, InstId, Ir, NameId};
29
30use crate::name_trie::NameTrie;
31use crate::type_expr::{self, TypeContext};
32
33/// Bounds checker state.
34pub struct BoundsChecker<'a> {
35    ir: &'a mut Ir,
36    name_trie: NameTrie,
37    /// Predicate NameId -> declared type alternatives.
38    /// Each alternative is a Vec<InstId> of argument types.
39    rel_type_map: FxHashMap<NameId, Vec<Vec<InstId>>>,
40    /// Predicate NameId -> rules defining it: (head, premises, transforms).
41    rules_map: FxHashMap<NameId, Vec<(InstId, Vec<InstId>, Vec<InstId>)>>,
42    /// Cross-predicate inference: inferred types for predicates without declarations.
43    inferred: FxHashMap<NameId, Vec<Vec<InstId>>>,
44    /// Cycle detection for cross-predicate inference.
45    visiting: FxHashSet<NameId>,
46    /// Counter for generating fresh type variable names.
47    fresh_var_counter: usize,
48}
49
50impl<'a> BoundsChecker<'a> {
51    pub fn new(ir: &'a mut Ir) -> Self {
52        Self {
53            ir,
54            name_trie: NameTrie::new(),
55            rel_type_map: FxHashMap::default(),
56            rules_map: FxHashMap::default(),
57            inferred: FxHashMap::default(),
58            visiting: FxHashSet::default(),
59            fresh_var_counter: 0,
60        }
61    }
62
63    /// Main entry point: collect declarations, build rules map, check all clauses.
64    pub fn check(&mut self) -> Result<()> {
65        self.collect_declarations()?;
66        self.build_rules_map();
67        self.check_all_clauses()
68    }
69
70    /// Generates a fresh type variable NameId (e.g., `?X0`, `?X1`, ...).
71    fn fresh_var(&mut self) -> NameId {
72        let name = format!("?X{}", self.fresh_var_counter);
73        self.fresh_var_counter += 1;
74        self.ir.intern_name(&name)
75    }
76
77    /// Pass 1: Collect declared types from Decl instructions and build name trie.
78    fn collect_declarations(&mut self) -> Result<()> {
79        let insts: Vec<Inst> = self.ir.insts.clone();
80        for inst in &insts {
81            if let Inst::Decl { atom, bounds, .. } = inst {
82                let pred_name = self.atom_predicate(*atom);
83                if let Some(pred) = pred_name {
84                    let mut alternatives = Vec::new();
85                    for bound_id in bounds {
86                        if let Inst::BoundDecl { base_terms } = self.ir.get(*bound_id) {
87                            let base_terms = base_terms.clone();
88                            // Collect name constants into trie.
89                            for term in &base_terms {
90                                self.name_trie.collect(self.ir, *term);
91                            }
92                            // Build type context with any type variables in this bound.
93                            let any = type_expr::find_or_create_name(self.ir, "/any");
94                            let mut ctx = TypeContext::default();
95                            for term in &base_terms {
96                                let mut vars = FxHashSet::default();
97                                type_expr::collect_vars(self.ir, *term, &mut vars);
98                                for v in vars {
99                                    ctx.entry(v).or_insert(any);
100                                }
101                            }
102                            // Validate wellformedness of each type expression.
103                            for term in &base_terms {
104                                type_expr::wellformed_type(self.ir, &ctx, *term)?;
105                            }
106                            alternatives.push(base_terms);
107                        }
108                    }
109                    if !alternatives.is_empty() {
110                        self.rel_type_map.insert(pred, alternatives);
111                    }
112                }
113            }
114        }
115        Ok(())
116    }
117
118    /// Build a map from predicate NameId to rules (head, premises, transforms).
119    fn build_rules_map(&mut self) {
120        let insts: Vec<Inst> = self.ir.insts.clone();
121        for inst in &insts {
122            if let Inst::Rule {
123                head,
124                premises,
125                transform,
126            } = inst
127            {
128                // Only non-unit clauses (actual rules with premises or transforms).
129                if !premises.is_empty() || !transform.is_empty() {
130                    if let Some(pred) = self.atom_predicate(*head) {
131                        self.rules_map
132                            .entry(pred)
133                            .or_default()
134                            .push((*head, premises.clone(), transform.clone()));
135                    }
136                }
137            }
138        }
139    }
140
141    /// Pass 2: Check all unit clauses and rules against declared bounds.
142    fn check_all_clauses(&mut self) -> Result<()> {
143        let insts: Vec<Inst> = self.ir.insts.clone();
144        for inst in &insts {
145            match inst {
146                Inst::Rule {
147                    head,
148                    premises,
149                    transform,
150                } => {
151                    let head = *head;
152                    let premises = premises.clone();
153                    let transform = transform.clone();
154                    if let Some(pred) = self.atom_predicate(head) {
155                        if let Some(alternatives) = self.rel_type_map.get(&pred).cloned() {
156                            if premises.is_empty() && transform.is_empty() {
157                                self.check_fact(head, &alternatives)?;
158                            } else {
159                                self.check_rule(head, &premises, &transform, &alternatives)?;
160                            }
161                        }
162                    }
163                }
164                _ => {}
165            }
166        }
167        Ok(())
168    }
169
170    /// Check a fact (unit clause head) against declared bound alternatives.
171    fn check_fact(&self, head: InstId, alternatives: &[Vec<InstId>]) -> Result<()> {
172        let args = self.atom_args(head);
173        let pred = self.atom_predicate(head).unwrap();
174        if args.is_empty() && alternatives.is_empty() {
175            return Ok(());
176        }
177
178        let mut errors = Vec::new();
179        for alt in alternatives {
180            match self.check_fact_against_bound(pred, &args, alt) {
181                Ok(()) => return Ok(()),
182                Err(e) => errors.push(e.to_string()),
183            }
184        }
185
186        if errors.is_empty() {
187            return Ok(());
188        }
189
190        let pred_name = self
191            .atom_predicate(head)
192            .map(|p| self.ir.resolve_name(p).to_string())
193            .unwrap_or_else(|| "?".to_string());
194        Err(anyhow!(
195            "fact {}(...) matches none of the bound decls: {}",
196            pred_name,
197            errors.join("; ")
198        ))
199    }
200
201    /// Check a single fact against one bound alternative.
202    fn check_fact_against_bound(
203        &self,
204        pred: NameId,
205        args: &[InstId],
206        bound: &[InstId],
207    ) -> Result<()> {
208        let is_temporal = self.ir.temporal_predicates.contains(&pred);
209        let expected_args = if is_temporal {
210            bound.len() + 2
211        } else {
212            bound.len()
213        };
214        if args.len() != expected_args {
215            return Err(anyhow!(
216                "arity mismatch: fact has {} args, bound has {}{}",
217                args.len(),
218                bound.len(),
219                if is_temporal { " (+2 temporal)" } else { "" }
220            ));
221        }
222        for (i, (arg, type_expr)) in args.iter().zip(bound.iter()).enumerate() {
223            if !type_expr::has_type(self.ir, *type_expr, *arg) {
224                let arg_desc = self.describe_inst(*arg);
225                let type_desc = self.describe_inst(*type_expr);
226                return Err(anyhow!(
227                    "argument {} ({}) does not have type {}",
228                    i,
229                    arg_desc,
230                    type_desc
231                ));
232            }
233        }
234        Ok(())
235    }
236
237    /// Check a rule against declared bound alternatives.
238    ///
239    /// Uses the inference pipeline: for each premise, infer variable types
240    /// via feasible alternatives, then check that head args conform.
241    fn check_rule(
242        &mut self,
243        head: InstId,
244        premises: &[InstId],
245        transforms: &[InstId],
246        alternatives: &[Vec<InstId>],
247    ) -> Result<()> {
248        let head_args = self.atom_args(head);
249        let pred = self.atom_predicate(head).unwrap();
250        let is_temporal = self.ir.temporal_predicates.contains(&pred);
251
252        // Run inference pipeline.
253        let mut state = InferState::new();
254        for premise_id in premises {
255            state = self.infer_from_premise(*premise_id, state)?;
256        }
257
258        // Process transforms.
259        for transform_id in transforms {
260            if let Inst::Transform { var, app } = self.ir.get(*transform_id) {
261                let var = *var;
262                let app = *app;
263                if let Some(v) = var {
264                    let tpe = self.bound_of_arg(app, &state.as_map());
265                    state.add_or_refine_with_ir(self.ir, v, tpe);
266                }
267            }
268        }
269
270        // Compute head tuple types.
271        let var_ranges = state.as_map();
272        let inferred: Vec<InstId> = head_args
273            .iter()
274            .map(|arg| self.bound_of_arg(*arg, &var_ranges))
275            .collect();
276
277        // For temporal predicates, trim synthetic time columns.
278        let check_len = if is_temporal && inferred.len() >= 2 {
279            inferred.len() - 2
280        } else {
281            inferred.len()
282        };
283        let inferred_trimmed = &inferred[..check_len];
284
285        // Check inferred types against each declared alternative.
286        let mut errors = Vec::new();
287        for alt in alternatives {
288            if alt.len() != inferred_trimmed.len() {
289                errors.push(format!(
290                    "arity mismatch: head has {} args, bound has {}",
291                    inferred_trimmed.len(),
292                    alt.len()
293                ));
294                continue;
295            }
296            // Build type context: map any type variables in the alt to /any.
297            let any = type_expr::find_or_create_name(self.ir, "/any");
298            let mut ctx = TypeContext::default();
299            for t in alt.iter() {
300                let mut vars = FxHashSet::default();
301                type_expr::collect_vars(self.ir, *t, &mut vars);
302                for v in vars {
303                    ctx.entry(v).or_insert(any);
304                }
305            }
306            let all_conform = inferred_trimmed
307                .iter()
308                .zip(alt.iter())
309                .all(|(inf, decl)| type_expr::set_conforms(self.ir, &ctx, *inf, *decl));
310            if all_conform {
311                return Ok(());
312            }
313            errors.push(format!(
314                "inferred [{}] does not conform to declared [{}]",
315                inferred_trimmed
316                    .iter()
317                    .map(|i| self.describe_inst(*i))
318                    .collect::<Vec<_>>()
319                    .join(", "),
320                alt.iter()
321                    .map(|i| self.describe_inst(*i))
322                    .collect::<Vec<_>>()
323                    .join(", "),
324            ));
325        }
326
327        if errors.is_empty() {
328            return Ok(());
329        }
330
331        let pred_name = self
332            .atom_predicate(head)
333            .map(|p| self.ir.resolve_name(p).to_string())
334            .unwrap_or_else(|| "?".to_string());
335        Err(anyhow!(
336            "rule for {}(...) does not conform to declared bounds: {}",
337            pred_name,
338            errors.join("; ")
339        ))
340    }
341
342    /// Infer variable types from a single premise, updating the state.
343    fn infer_from_premise(
344        &mut self,
345        premise_id: InstId,
346        mut state: InferState,
347    ) -> Result<InferState> {
348        match self.ir.get(premise_id) {
349            Inst::Atom { predicate, args } => {
350                let pred = *predicate;
351                let args = args.clone();
352
353                // Special case: :match_prefix
354                let pred_name = self.ir.resolve_name(pred).to_string();
355                if pred_name == ":match_prefix" {
356                    return self.infer_match_prefix(&args, state);
357                }
358                if pred_name == ":match_field" {
359                    return self.infer_match_field(&args, state);
360                }
361                if pred_name == ":match_entry" {
362                    return self.infer_match_entry(&args, state);
363                }
364                if pred_name == ":list:member" {
365                    return self.infer_list_member(&args, state);
366                }
367
368                // Regular atom: look up or infer alternatives.
369                let var_ranges = state.as_map();
370                let feasible =
371                    self.get_or_infer_alternatives(pred, &args, &var_ranges);
372
373                if !feasible.is_empty() {
374                    // Use the first feasible alternative to bind variables.
375                    let first = &feasible[0].clone();
376                    for (arg, type_id) in args.iter().zip(first.iter()) {
377                        if let Inst::Var(v) = self.ir.get(*arg) {
378                            let v = *v;
379                            state.add_or_refine_with_ir(self.ir, v, *type_id);
380                        }
381                    }
382                } else if let Some(alternatives) = self.rel_type_map.get(&pred).cloned() {
383                    // Fallback: no feasible alternative, use first declared alt.
384                    if let Some(first_alt) = alternatives.first() {
385                        for (arg, type_id) in args.iter().zip(first_alt.iter()) {
386                            if let Inst::Var(v) = self.ir.get(*arg) {
387                                let v = *v;
388                                state.add_or_refine_with_ir(self.ir, v, *type_id);
389                            }
390                        }
391                    }
392                }
393                Ok(state)
394            }
395            Inst::NegAtom(inner) => {
396                let inner = *inner;
397                // Negated atoms: we can refine types via negative information,
398                // but don't add new bindings.
399                if let Inst::Atom { predicate, args } = self.ir.get(inner) {
400                    let pred = *predicate;
401                    let args = args.clone();
402                    let pred_name = self.ir.resolve_name(pred).to_string();
403
404                    if pred_name == ":match_prefix" && args.len() >= 2 {
405                        // Negative :match_prefix: refine away the prefix type.
406                        if let Inst::Var(v) = self.ir.get(args[0]) {
407                            let v = *v;
408                            let bound = self.bound_of_arg(args[1], &state.as_map());
409                            if let Some(existing) = state.as_map().get(&v).copied() {
410                                if type_expr::is_union_type(self.ir, existing) {
411                                    let refined =
412                                        type_expr::remove_from_union_type(self.ir, bound, existing);
413                                    if !type_expr::is_empty_type(self.ir, refined) {
414                                        state.set_var(v, refined);
415                                    }
416                                }
417                            }
418                        }
419                    }
420                    // Other negated atoms: no type refinement.
421                }
422                Ok(state)
423            }
424            Inst::Eq(left, right) => {
425                let left = *left;
426                let right = *right;
427                let var_ranges = state.as_map();
428
429                if let Inst::Var(lv) = self.ir.get(left) {
430                    let lv = *lv;
431                    let tpe = self.bound_of_arg(right, &var_ranges);
432                    state.add_or_refine_with_ir(self.ir, lv, tpe);
433                }
434                if let Inst::Var(rv) = self.ir.get(right) {
435                    let rv = *rv;
436                    let tpe = self.bound_of_arg(left, &state.as_map());
437                    state.add_or_refine_with_ir(self.ir, rv, tpe);
438                }
439                Ok(state)
440            }
441            Inst::Ineq(left, right) => {
442                let left = *left;
443                let right = *right;
444                let var_ranges = state.as_map();
445
446                // For inequality, both sides must have compatible types.
447                let left_tpe = self.bound_of_arg(left, &var_ranges);
448                let right_tpe = self.bound_of_arg(right, &var_ranges);
449                let ctx = TypeContext::default();
450                let meet = type_expr::lower_bound(self.ir, &ctx, &[left_tpe, right_tpe]);
451                if !type_expr::is_empty_type(self.ir, meet) {
452                    if let Inst::Var(lv) = self.ir.get(left) {
453                        let lv = *lv;
454                        state.add_or_refine_with_ir(self.ir, lv, meet);
455                    }
456                    if let Inst::Var(rv) = self.ir.get(right) {
457                        let rv = *rv;
458                        state.add_or_refine_with_ir(self.ir, rv, meet);
459                    }
460                }
461                Ok(state)
462            }
463            _ => Ok(state),
464        }
465    }
466
467    /// Finds feasible alternatives for a subgoal p(e1...eN) with skolemization.
468    ///
469    /// For each declared alternative:
470    /// 1. Builds argument bounds (uses var_ranges for bound vars, declared type for unbound)
471    /// 2. Collects type variables from the alternative, creates fresh substitution
472    /// 3. Applies substitution to both arg bounds and alternative types
473    /// 4. Checks that LowerBound (with extended type context) is non-empty per position
474    fn feasible_alternatives(
475        &mut self,
476        alternatives: &[Vec<InstId>],
477        args: &[InstId],
478        var_ranges: &FxHashMap<NameId, InstId>,
479    ) -> Vec<Vec<InstId>> {
480        let mut feasible = Vec::new();
481
482        for alt in alternatives {
483            if alt.len() != args.len() {
484                continue;
485            }
486
487            // Step 1: Build argument bounds.
488            // For bound vars: use var_ranges. For unbound vars: use declared type.
489            // For constants: use bound_of_arg.
490            let mut arg_bound = Vec::new();
491            for (i, arg) in args.iter().enumerate() {
492                if let Inst::Var(v) = self.ir.get(*arg) {
493                    let v = *v;
494                    if let Some(&range) = var_ranges.get(&v) {
495                        arg_bound.push(range);
496                    } else {
497                        // Unbound variable: use declared type from this alternative.
498                        arg_bound.push(alt[i]);
499                    }
500                } else {
501                    arg_bound.push(self.bound_of_arg(*arg, var_ranges));
502                }
503            }
504
505            // Step 2: Collect type variables from the alternative.
506            let mut type_vars = FxHashSet::default();
507            for t in alt {
508                type_expr::collect_vars(self.ir, *t, &mut type_vars);
509            }
510
511            // Step 3: Skolemize — create fresh variables for each type variable.
512            let mut subst: FxHashMap<NameId, InstId> = FxHashMap::default();
513            if !type_vars.is_empty() {
514                for v in &type_vars {
515                    let fresh = self.fresh_var();
516                    let fresh_id = self.ir.add_inst(Inst::Var(fresh));
517                    subst.insert(*v, fresh_id);
518                }
519            }
520
521            // Step 4: Apply substitution to arg bounds and alternative.
522            let arg_bound_subst: Vec<InstId> = arg_bound
523                .iter()
524                .map(|t| type_expr::apply_subst(self.ir, *t, &subst))
525                .collect();
526            let alt_subst: Vec<InstId> = alt
527                .iter()
528                .map(|t| type_expr::apply_subst(self.ir, *t, &subst))
529                .collect();
530
531            // Step 5: Build extended type context with fresh vars -> /any.
532            let any = type_expr::find_or_create_name(self.ir, "/any");
533            let mut ctx = TypeContext::default();
534            for fresh_id in subst.values() {
535                if let Inst::Var(v) = self.ir.get(*fresh_id) {
536                    ctx.insert(*v, any);
537                }
538            }
539
540            // Step 6: Per-position feasibility check.
541            let mut is_feasible = true;
542            let mut result_types = Vec::new();
543            for (ab, at) in arg_bound_subst.iter().zip(alt_subst.iter()) {
544                let meet = type_expr::lower_bound(self.ir, &ctx, &[*ab, *at]);
545                if type_expr::is_empty_type(self.ir, meet) {
546                    is_feasible = false;
547                    break;
548                }
549                result_types.push(meet);
550            }
551
552            if is_feasible {
553                feasible.push(result_types);
554            }
555        }
556        feasible
557    }
558
559    /// Looks up or infers type alternatives for a predicate.
560    ///
561    /// Checks declared types first, then already-inferred types, then infers
562    /// from rules. Uses cycle detection to handle recursive predicates.
563    fn get_or_infer_alternatives(
564        &mut self,
565        pred: NameId,
566        args: &[InstId],
567        var_ranges: &FxHashMap<NameId, InstId>,
568    ) -> Vec<Vec<InstId>> {
569        // 1. Check declared types.
570        if let Some(alts) = self.rel_type_map.get(&pred).cloned() {
571            return self.feasible_alternatives(&alts, args, var_ranges);
572        }
573
574        // 2. Check already-inferred types.
575        if let Some(alts) = self.inferred.get(&pred).cloned() {
576            return self.feasible_alternatives(&alts, args, var_ranges);
577        }
578
579        // 3. Cycle detection: if we're already visiting this predicate,
580        // return [/any ... /any] to break the cycle.
581        if self.visiting.contains(&pred) {
582            let any = type_expr::find_or_create_name(self.ir, "/any");
583            return vec![vec![any; args.len()]];
584        }
585
586        // 4. Infer from rules defining this predicate.
587        self.visiting.insert(pred);
588        let inferred = self.infer_rel_types(pred);
589        self.visiting.remove(&pred);
590
591        if !inferred.is_empty() {
592            self.inferred.insert(pred, inferred.clone());
593            return self.feasible_alternatives(&inferred, args, var_ranges);
594        }
595
596        Vec::new()
597    }
598
599    /// Infers relation type alternatives for a predicate from its defining rules.
600    ///
601    /// For each rule defining the predicate, runs inference to determine
602    /// the head tuple types, then collects all alternatives.
603    fn infer_rel_types(&mut self, pred: NameId) -> Vec<Vec<InstId>> {
604        let rules = match self.rules_map.get(&pred) {
605            Some(r) => r.clone(),
606            None => return Vec::new(),
607        };
608
609        let mut alternatives: Vec<Vec<InstId>> = Vec::new();
610
611        for (head, premises, transforms) in &rules {
612            // Run inference pipeline on this clause.
613            if let Some(inferred) = self.infer_clause(*head, premises, transforms) {
614                alternatives.push(inferred);
615            }
616        }
617
618        alternatives
619    }
620
621    /// Runs inference on a single clause, returning inferred head tuple types.
622    fn infer_clause(
623        &mut self,
624        head: InstId,
625        premises: &[InstId],
626        transforms: &[InstId],
627    ) -> Option<Vec<InstId>> {
628        let head_args = self.atom_args(head);
629        let mut state = InferState::new();
630
631        for premise_id in premises {
632            match self.infer_from_premise(*premise_id, state) {
633                Ok(new_state) => state = new_state,
634                Err(_) => return None,
635            }
636        }
637
638        // Process transforms.
639        for transform_id in transforms {
640            if let Inst::Transform { var, app } = self.ir.get(*transform_id) {
641                let var = *var;
642                let app = *app;
643                if let Some(v) = var {
644                    let tpe = self.bound_of_arg(app, &state.as_map());
645                    state.add_or_refine_with_ir(self.ir, v, tpe);
646                }
647            }
648        }
649
650        // Compute head tuple types.
651        let var_ranges = state.as_map();
652        let inferred: Vec<InstId> = head_args
653            .iter()
654            .map(|arg| self.bound_of_arg(*arg, &var_ranges))
655            .collect();
656
657        Some(inferred)
658    }
659
660    /// Special case inference for `:match_prefix(Name, Prefix)`.
661    fn infer_match_prefix(
662        &mut self,
663        args: &[InstId],
664        mut state: InferState,
665    ) -> Result<InferState> {
666        if args.len() != 2 {
667            return Ok(state);
668        }
669        let var_ranges = state.as_map();
670        let tpe = self.bound_of_arg(args[0], &var_ranges);
671        let prefix = self.bound_of_arg(args[1], &var_ranges);
672
673        let ctx = TypeContext::default();
674        let meet = type_expr::lower_bound(self.ir, &ctx, &[tpe, prefix]);
675        if !type_expr::is_empty_type(self.ir, meet) {
676            if let Inst::Var(v) = self.ir.get(args[0]) {
677                let v = *v;
678                state.add_or_refine_with_ir(self.ir, v, meet);
679            }
680            // Second arg (prefix) is typically a constant.
681            let name_type = type_expr::find_or_create_name(self.ir, "/name");
682            if let Inst::Var(v) = self.ir.get(args[1]) {
683                let v = *v;
684                state.add_or_refine_with_ir(self.ir, v, name_type);
685            }
686        }
687        Ok(state)
688    }
689
690    /// Special case inference for `:match_field(Struct, FieldName, Value)`.
691    fn infer_match_field(
692        &mut self,
693        args: &[InstId],
694        mut state: InferState,
695    ) -> Result<InferState> {
696        if args.len() != 3 {
697            return Ok(state);
698        }
699        let var_ranges = state.as_map();
700        let scrutinee_type = self.bound_of_arg(args[0], &var_ranges);
701
702        // Get field name from args[1] (must be a name constant).
703        let field_name_id = match self.ir.get(args[1]) {
704            Inst::Name(n) => Some(*n),
705            _ => None,
706        };
707
708        if let Some(field) = field_name_id {
709            if type_expr::is_struct_type(self.ir, scrutinee_type)
710                || type_expr::is_tagged_union_type(self.ir, scrutinee_type)
711                || type_expr::is_union_type(self.ir, scrutinee_type)
712            {
713                if let Some(field_type) =
714                    type_expr::struct_type_field_deep(self.ir, scrutinee_type, field)
715                {
716                    // Bind the value variable.
717                    let ctx = TypeContext::default();
718                    let value_bound = self.bound_of_arg(args[2], &state.as_map());
719                    let meet =
720                        type_expr::lower_bound(self.ir, &ctx, &[value_bound, field_type]);
721                    if !type_expr::is_empty_type(self.ir, meet) {
722                        if let Inst::Var(v) = self.ir.get(args[2]) {
723                            let v = *v;
724                            state.add_or_refine_with_ir(self.ir, v, meet);
725                        }
726                    }
727                }
728            }
729        }
730        // Bind first arg if variable.
731        let any = type_expr::find_or_create_name(self.ir, "/any");
732        if let Inst::Var(v) = self.ir.get(args[0]) {
733            let v = *v;
734            state.add_or_refine_with_ir(self.ir, v, any);
735        }
736        // Bind second arg (field name) if variable.
737        let name_type = type_expr::find_or_create_name(self.ir, "/name");
738        if let Inst::Var(v) = self.ir.get(args[1]) {
739            let v = *v;
740            state.add_or_refine_with_ir(self.ir, v, name_type);
741        }
742        Ok(state)
743    }
744
745    /// Special case inference for `:match_entry(Map, Key, Value)`.
746    fn infer_match_entry(
747        &mut self,
748        args: &[InstId],
749        mut state: InferState,
750    ) -> Result<InferState> {
751        if args.len() != 3 {
752            return Ok(state);
753        }
754        let var_ranges = state.as_map();
755        let map_type = self.bound_of_arg(args[0], &var_ranges);
756
757        if type_expr::is_map_type(self.ir, map_type) {
758            if let Some((key_type, val_type)) = type_expr::map_type_args(self.ir, map_type) {
759                let ctx = TypeContext::default();
760
761                // Bind key.
762                let key_bound = self.bound_of_arg(args[1], &state.as_map());
763                let key_meet =
764                    type_expr::lower_bound(self.ir, &ctx, &[key_bound, key_type]);
765                if !type_expr::is_empty_type(self.ir, key_meet) {
766                    if let Inst::Var(v) = self.ir.get(args[1]) {
767                        let v = *v;
768                        state.add_or_refine_with_ir(self.ir, v, key_meet);
769                    }
770                }
771
772                // Bind value.
773                let val_bound = self.bound_of_arg(args[2], &state.as_map());
774                let val_meet =
775                    type_expr::lower_bound(self.ir, &ctx, &[val_bound, val_type]);
776                if !type_expr::is_empty_type(self.ir, val_meet) {
777                    if let Inst::Var(v) = self.ir.get(args[2]) {
778                        let v = *v;
779                        state.add_or_refine_with_ir(self.ir, v, val_meet);
780                    }
781                }
782            }
783        }
784        Ok(state)
785    }
786
787    /// Special case inference for `:list:member(Elem, List)`.
788    fn infer_list_member(
789        &mut self,
790        args: &[InstId],
791        mut state: InferState,
792    ) -> Result<InferState> {
793        if args.len() != 2 {
794            return Ok(state);
795        }
796        let var_ranges = state.as_map();
797        let list_type = self.bound_of_arg(args[1], &var_ranges);
798
799        if type_expr::is_list_type(self.ir, list_type) {
800            if let Some(elem_type) = type_expr::list_type_arg(self.ir, list_type) {
801                let ctx = TypeContext::default();
802                let elem_bound = self.bound_of_arg(args[0], &state.as_map());
803                let meet =
804                    type_expr::lower_bound(self.ir, &ctx, &[elem_bound, elem_type]);
805                if !type_expr::is_empty_type(self.ir, meet) {
806                    if let Inst::Var(v) = self.ir.get(args[0]) {
807                        let v = *v;
808                        state.add_or_refine_with_ir(self.ir, v, meet);
809                    }
810                }
811            }
812        }
813        Ok(state)
814    }
815
816    /// Infers the type bound for a single argument.
817    fn bound_of_arg(
818        &mut self,
819        arg: InstId,
820        var_ranges: &FxHashMap<NameId, InstId>,
821    ) -> InstId {
822        match self.ir.get(arg) {
823            Inst::Var(v) => {
824                let v = *v;
825                if let Some(&range) = var_ranges.get(&v) {
826                    range
827                } else {
828                    type_expr::find_or_create_name(self.ir, "/any")
829                }
830            }
831            Inst::Number(_) => type_expr::find_or_create_name(self.ir, "/number"),
832            Inst::Float(_) => type_expr::find_or_create_name(self.ir, "/float64"),
833            Inst::String(_) => type_expr::find_or_create_name(self.ir, "/string"),
834            Inst::Bool(_) => type_expr::find_or_create_name(self.ir, "/bool"),
835            Inst::Time(_) => type_expr::find_or_create_name(self.ir, "/time"),
836            Inst::Duration(_) => type_expr::find_or_create_name(self.ir, "/duration"),
837            Inst::Bytes(_) => type_expr::find_or_create_name(self.ir, "/bytes"),
838            Inst::Name(n) => {
839                let name = self.ir.resolve_name(*n).to_string();
840                let prefix = self.name_trie.prefix_name(&name);
841                type_expr::find_or_create_name(self.ir, &prefix)
842            }
843            Inst::List(elems) => {
844                let elems = elems.clone();
845                if elems.is_empty() {
846                    let bot = type_expr::find_or_create_name(self.ir, "/bot");
847                    return type_expr::new_list_type(self.ir, bot);
848                }
849                let ctx = TypeContext::default();
850                let elem_types: Vec<InstId> = elems
851                    .iter()
852                    .map(|e| self.bound_of_arg(*e, var_ranges))
853                    .collect();
854                let elem_type = type_expr::upper_bound(self.ir, &ctx, &elem_types);
855                type_expr::new_list_type(self.ir, elem_type)
856            }
857            Inst::Map { keys, values } => {
858                let keys = keys.clone();
859                let values = values.clone();
860                let ctx = TypeContext::default();
861                let key_types: Vec<InstId> = keys
862                    .iter()
863                    .map(|k| self.bound_of_arg(*k, var_ranges))
864                    .collect();
865                let val_types: Vec<InstId> = values
866                    .iter()
867                    .map(|v| self.bound_of_arg(*v, var_ranges))
868                    .collect();
869                let kt = type_expr::upper_bound(self.ir, &ctx, &key_types);
870                let vt = type_expr::upper_bound(self.ir, &ctx, &val_types);
871                type_expr::new_map_type(self.ir, kt, vt)
872            }
873            Inst::Struct { fields, values } => {
874                let fields = fields.clone();
875                let values = values.clone();
876                let mut args = Vec::new();
877                for (f, v) in fields.iter().zip(values.iter()) {
878                    let fname = self.ir.resolve_name(*f).to_string();
879                    let fname_id = type_expr::find_or_create_name(self.ir, &fname);
880                    let vtype = self.bound_of_arg(*v, var_ranges);
881                    args.push(fname_id);
882                    args.push(vtype);
883                }
884                type_expr::new_struct_type(self.ir, args)
885            }
886            Inst::ApplyFn { function, args } => {
887                let fname = self.ir.resolve_name(*function).to_string();
888                let args = args.clone();
889                self.bound_of_apply_fn(&fname, &args, var_ranges)
890            }
891            _ => type_expr::find_or_create_name(self.ir, "/any"),
892        }
893    }
894
895    /// Infers a type for a function application expression.
896    fn bound_of_apply_fn(
897        &mut self,
898        fname: &str,
899        args: &[InstId],
900        var_ranges: &FxHashMap<NameId, InstId>,
901    ) -> InstId {
902        match fname {
903            "fn:list" => {
904                if args.is_empty() {
905                    let bot = type_expr::find_or_create_name(self.ir, "/bot");
906                    return type_expr::new_list_type(self.ir, bot);
907                }
908                let ctx = TypeContext::default();
909                let arg_types: Vec<InstId> = args
910                    .iter()
911                    .map(|a| self.bound_of_arg(*a, var_ranges))
912                    .collect();
913                let elem = type_expr::upper_bound(self.ir, &ctx, &arg_types);
914                type_expr::new_list_type(self.ir, elem)
915            }
916            "fn:map" => {
917                let ctx = TypeContext::default();
918                let mut key_types = Vec::new();
919                let mut val_types = Vec::new();
920                let mut i = 0;
921                while i + 1 < args.len() {
922                    key_types.push(self.bound_of_arg(args[i], var_ranges));
923                    val_types.push(self.bound_of_arg(args[i + 1], var_ranges));
924                    i += 2;
925                }
926                let kt = type_expr::upper_bound(self.ir, &ctx, &key_types);
927                let vt = type_expr::upper_bound(self.ir, &ctx, &val_types);
928                type_expr::new_map_type(self.ir, kt, vt)
929            }
930            "fn:struct" => {
931                let mut struct_args = Vec::new();
932                let mut i = 0;
933                while i + 1 < args.len() {
934                    struct_args.push(args[i]); // field name
935                    struct_args.push(self.bound_of_arg(args[i + 1], var_ranges));
936                    i += 2;
937                }
938                type_expr::new_struct_type(self.ir, struct_args)
939            }
940            "fn:tuple" => {
941                let arg_types: Vec<InstId> = args
942                    .iter()
943                    .map(|a| self.bound_of_arg(*a, var_ranges))
944                    .collect();
945                type_expr::new_tuple_type(self.ir, arg_types)
946            }
947            "fn:struct_get" if args.len() == 2 => {
948                let struct_type = self.bound_of_arg(args[0], var_ranges);
949                if let Inst::Name(n) = self.ir.get(args[1]) {
950                    let field = *n;
951                    if let Some(ft) =
952                        type_expr::struct_type_field_deep(self.ir, struct_type, field)
953                    {
954                        return ft;
955                    }
956                }
957                type_expr::find_or_create_name(self.ir, "/any")
958            }
959            "fn:plus" | "fn:minus" | "fn:mult" | "fn:div" => {
960                type_expr::find_or_create_name(self.ir, "/number")
961            }
962            "fn:float_plus" | "fn:float_mult" | "fn:float_div" => {
963                type_expr::find_or_create_name(self.ir, "/float64")
964            }
965            "fn:string:concat" | "fn:string:replace" => {
966                type_expr::find_or_create_name(self.ir, "/string")
967            }
968            "fn:count" | "fn:sum" | "fn:max" | "fn:min" => {
969                type_expr::find_or_create_name(self.ir, "/number")
970            }
971            "fn:collect" | "fn:collect_distinct" => {
972                if args.len() == 1 {
973                    let elem_type = self.bound_of_arg(args[0], var_ranges);
974                    type_expr::new_list_type(self.ir, elem_type)
975                } else {
976                    let any = type_expr::find_or_create_name(self.ir, "/any");
977                    type_expr::new_list_type(self.ir, any)
978                }
979            }
980            _ => type_expr::find_or_create_name(self.ir, "/any"),
981        }
982    }
983
984    // -- Helpers --
985
986    fn atom_predicate(&self, atom_id: InstId) -> Option<NameId> {
987        if let Inst::Atom { predicate, .. } = self.ir.get(atom_id) {
988            Some(*predicate)
989        } else {
990            None
991        }
992    }
993
994    fn atom_args(&self, atom_id: InstId) -> Vec<InstId> {
995        if let Inst::Atom { args, .. } = self.ir.get(atom_id) {
996            args.clone()
997        } else {
998            Vec::new()
999        }
1000    }
1001
1002    /// Simple textual description of an IR instruction for error messages.
1003    fn describe_inst(&self, id: InstId) -> String {
1004        match self.ir.get(id) {
1005            Inst::Name(n) => self.ir.resolve_name(*n).to_string(),
1006            Inst::Number(n) => n.to_string(),
1007            Inst::Float(f) => f.to_string(),
1008            Inst::String(s) => format!("{:?}", self.ir.resolve_string(*s)),
1009            Inst::Bool(b) => b.to_string(),
1010            Inst::Var(v) => self.ir.resolve_name(*v).to_string(),
1011            Inst::ApplyFn { function, args } => {
1012                let fname = self.ir.resolve_name(*function);
1013                let arg_strs: Vec<String> =
1014                    args.iter().map(|a| self.describe_inst(*a)).collect();
1015                format!("{}({})", fname, arg_strs.join(", "))
1016            }
1017            _ => format!("inst#{}", id.index()),
1018        }
1019    }
1020}
1021
1022// ---------------------------------------------------------------------------
1023// InferState
1024// ---------------------------------------------------------------------------
1025
/// State of type inference while iterating over premises.
///
/// Tracks variable bindings with their inferred types. Bindings are kept in
/// two parallel vectors: `used_vars[i]` is bound to the type `var_types[i]`.
/// The methods on `InferState` always push to both vectors together, which
/// keeps them the same length. Any code that touches the fields directly
/// must preserve that invariant, because lookups index `var_types` using a
/// position found in `used_vars`.
struct InferState {
    /// Variable names, in the order they were first bound
    /// (parallel with `var_types`).
    used_vars: Vec<NameId>,
    /// Type bound for each variable; `var_types[i]` belongs to `used_vars[i]`.
    var_types: Vec<InstId>,
}
1035
1036impl InferState {
1037    fn new() -> Self {
1038        Self {
1039            used_vars: Vec::new(),
1040            var_types: Vec::new(),
1041        }
1042    }
1043
1044    /// Adds a new variable binding or refines an existing one via LowerBound.
1045    fn add_or_refine_with_ir(&mut self, ir: &mut Ir, var: NameId, tpe: InstId) {
1046        if let Some(idx) = self.used_vars.iter().position(|v| *v == var) {
1047            // Variable already bound: intersect existing type with new type.
1048            let existing = self.var_types[idx];
1049            let ctx = TypeContext::default();
1050            let meet = type_expr::lower_bound(ir, &ctx, &[existing, tpe]);
1051            if !type_expr::is_empty_type(ir, meet) {
1052                self.var_types[idx] = meet;
1053            }
1054            // If intersection is empty, keep the existing type (conservative).
1055        } else {
1056            self.used_vars.push(var);
1057            self.var_types.push(tpe);
1058        }
1059    }
1060
1061    /// Sets a variable's type directly (for negative refinement).
1062    fn set_var(&mut self, var: NameId, tpe: InstId) {
1063        if let Some(idx) = self.used_vars.iter().position(|v| *v == var) {
1064            self.var_types[idx] = tpe;
1065        }
1066    }
1067
1068    /// Converts the state to a HashMap for lookups.
1069    fn as_map(&self) -> FxHashMap<NameId, InstId> {
1070        self.used_vars
1071            .iter()
1072            .zip(self.var_types.iter())
1073            .map(|(v, t)| (*v, *t))
1074            .collect()
1075    }
1076}
1077
// Unit tests for `BoundsChecker`. The first group builds ASTs by hand through
// the arena API; the remaining groups parse Mangle source text via `check`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LoweringContext;
    use mangle_ast as ast;
    use mangle_parse::Parser;

    /// Helper: parse `source`, lower it to IR, and run the bounds checker.
    ///
    /// Returns the checker's result. Panics (via `unwrap`) if the source
    /// cannot be tokenized or parsed, so a test with malformed Mangle fails
    /// loudly instead of satisfying an `is_err()` assertion.
    fn check(source: &str) -> Result<()> {
        let arena = ast::Arena::new_with_global_interner();
        let mut parser = Parser::new(&arena, source.as_bytes(), "test");
        parser.next_token().unwrap();
        let unit = parser.parse_unit().unwrap();
        let ctx = LoweringContext::new(&arena);
        let mut ir = ctx.lower_unit(&unit);
        let mut checker = BoundsChecker::new(&mut ir);
        checker.check()
    }

    // -----------------------------------------------------------------------
    // Basic facts and rules (AST built by hand through the arena API)
    // -----------------------------------------------------------------------

    #[test]
    fn check_valid_fact() {
        let arena = ast::Arena::new_with_global_interner();

        // Decl foo(X) bound [/number].
        let foo_sym = arena.predicate_sym("foo", Some(1));
        let var_x = arena.variable("X");
        let atom_foo_x = arena.atom(foo_sym, &[var_x]);
        let num_type = arena.const_(arena.name("/number"));
        let bound_decl = ast::BoundDecl {
            base_terms: arena.alloc_slice_copy(&[num_type]),
        };
        let decl = ast::Decl {
            atom: atom_foo_x,
            descr: &[],
            bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
            constraints: None,
            is_temporal: false,
        };

        // foo(42).
        let const_42 = arena.const_(ast::Const::Number(42));
        let atom_foo_42 = arena.atom(foo_sym, &[const_42]);
        let clause = ast::Clause {
            head: atom_foo_42,
            head_time: None,
            premises: &[],
            transform: &[],
        };

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[&decl]),
            clauses: arena.alloc_slice_copy(&[&clause]),
        };

        let ctx = LoweringContext::new(&arena);
        let mut ir = ctx.lower_unit(&unit);
        let mut checker = BoundsChecker::new(&mut ir);
        assert!(checker.check().is_ok());
    }

    #[test]
    fn check_invalid_fact_type_mismatch() {
        let arena = ast::Arena::new_with_global_interner();

        // Decl foo(X) bound [/number].
        let foo_sym = arena.predicate_sym("foo", Some(1));
        let var_x = arena.variable("X");
        let atom_foo_x = arena.atom(foo_sym, &[var_x]);
        let num_type = arena.const_(arena.name("/number"));
        let bound_decl = ast::BoundDecl {
            base_terms: arena.alloc_slice_copy(&[num_type]),
        };
        let decl = ast::Decl {
            atom: atom_foo_x,
            descr: &[],
            bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
            constraints: None,
            is_temporal: false,
        };

        // foo("hello"). -> Type mismatch.
        let const_str = arena.const_(ast::Const::String("hello"));
        let atom_foo_bad = arena.atom(foo_sym, &[const_str]);
        let clause = ast::Clause {
            head: atom_foo_bad,
            head_time: None,
            premises: &[],
            transform: &[],
        };

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[&decl]),
            clauses: arena.alloc_slice_copy(&[&clause]),
        };

        let ctx = LoweringContext::new(&arena);
        let mut ir = ctx.lower_unit(&unit);
        let mut checker = BoundsChecker::new(&mut ir);
        let result = checker.check();
        assert!(result.is_err(), "expected type mismatch error");
    }

    #[test]
    fn check_valid_rule() {
        let arena = ast::Arena::new_with_global_interner();

        // Decl src(X) bound [/number].
        let src_sym = arena.predicate_sym("src", Some(1));
        let var_x = arena.variable("X");
        let atom_src_x = arena.atom(src_sym, &[var_x]);
        let num_type = arena.const_(arena.name("/number"));
        let bound_decl = ast::BoundDecl {
            base_terms: arena.alloc_slice_copy(&[num_type]),
        };
        let decl_src = ast::Decl {
            atom: atom_src_x,
            descr: &[],
            bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
            constraints: None,
            is_temporal: false,
        };

        // Decl dst(X) bound [/number].
        let dst_sym = arena.predicate_sym("dst", Some(1));
        let var_y = arena.variable("Y");
        let atom_dst_y = arena.atom(dst_sym, &[var_y]);
        let num_type2 = arena.const_(arena.name("/number"));
        let bound_decl2 = ast::BoundDecl {
            base_terms: arena.alloc_slice_copy(&[num_type2]),
        };
        let decl_dst = ast::Decl {
            atom: atom_dst_y,
            descr: &[],
            bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl2)])),
            constraints: None,
            is_temporal: false,
        };

        // dst(X) :- src(X).
        let var_x2 = arena.variable("X");
        let head = arena.atom(dst_sym, &[var_x2]);
        let var_x3 = arena.variable("X");
        let body = arena.atom(src_sym, &[var_x3]);
        let clause = ast::Clause {
            head,
            head_time: None,
            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(body))]),
            transform: &[],
        };

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[&decl_src, &decl_dst]),
            clauses: arena.alloc_slice_copy(&[&clause]),
        };

        let ctx = LoweringContext::new(&arena);
        let mut ir = ctx.lower_unit(&unit);
        let mut checker = BoundsChecker::new(&mut ir);
        assert!(checker.check().is_ok());
    }

    #[test]
    fn check_arity_mismatch() {
        let arena = ast::Arena::new_with_global_interner();

        // Decl foo(X) bound [/number].
        let foo_sym = arena.predicate_sym("foo", Some(1));
        let var_x = arena.variable("X");
        let atom_foo_x = arena.atom(foo_sym, &[var_x]);
        let num_type = arena.const_(arena.name("/number"));
        let bound_decl = ast::BoundDecl {
            base_terms: arena.alloc_slice_copy(&[num_type]),
        };
        let decl = ast::Decl {
            atom: atom_foo_x,
            descr: &[],
            bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
            constraints: None,
            is_temporal: false,
        };

        // foo(42, 43). -> Arity mismatch.
        let const_42 = arena.const_(ast::Const::Number(42));
        let const_43 = arena.const_(ast::Const::Number(43));
        let atom_foo_bad = arena.atom(foo_sym, &[const_42, const_43]);
        let clause = ast::Clause {
            head: atom_foo_bad,
            head_time: None,
            premises: &[],
            transform: &[],
        };

        let unit = ast::Unit {
            decls: arena.alloc_slice_copy(&[&decl]),
            clauses: arena.alloc_slice_copy(&[&clause]),
        };

        let ctx = LoweringContext::new(&arena);
        let mut ir = ctx.lower_unit(&unit);
        let mut checker = BoundsChecker::new(&mut ir);
        let result = checker.check();
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // Parser-based tests: multiple bound alternatives
    // -----------------------------------------------------------------------

    #[test]
    fn multiple_alternatives_first_matches() {
        // pair(42, 99) matches first alternative [/number, /number].
        assert!(check(r#"
            Decl pair(X, Y) bound [/number, /number] bound [/string, /string].
            pair(42, 99).
        "#).is_ok());
    }

    #[test]
    fn multiple_alternatives_second_matches() {
        // pair("a", "b") matches second alternative [/string, /string].
        assert!(check(r#"
            Decl pair(X, Y) bound [/number, /number] bound [/string, /string].
            pair("a", "b").
        "#).is_ok());
    }

    #[test]
    fn multiple_alternatives_none_matches() {
        // pair(42, "b") matches neither alternative.
        assert!(check(r#"
            Decl pair(X, Y) bound [/number, /number] bound [/string, /string].
            pair(42, "b").
        "#).is_err());
    }

    // -----------------------------------------------------------------------
    // Rule type inference: variable binding from premises
    // -----------------------------------------------------------------------

    #[test]
    fn rule_infers_type_from_premise() {
        // X gets type /number from src, which conforms to dst's bound.
        assert!(check(r#"
            Decl src(X) bound [/number].
            Decl dst(X) bound [/number].
            dst(X) :- src(X).
        "#).is_ok());
    }

    #[test]
    fn rule_type_mismatch_from_premise() {
        // X inferred as /string from src, but dst expects /number.
        assert!(check(r#"
            Decl src(X) bound [/string].
            Decl dst(X) bound [/number].
            dst(X) :- src(X).
        "#).is_err());
    }

    // -----------------------------------------------------------------------
    // Multiple body atoms refining the same variable (LowerBound)
    // -----------------------------------------------------------------------

    #[test]
    fn two_premises_refine_variable() {
        // X starts as fn:Union(/number, /string) from 'wide',
        // then refined to /number from 'narrow'. Should conform to /number.
        assert!(check(r#"
            Decl wide(X) bound [fn:Union(/number, /string)].
            Decl narrow(X) bound [/number].
            Decl result(X) bound [/number].
            result(X) :- wide(X), narrow(X).
        "#).is_ok());
    }

    #[test]
    fn two_premises_refine_to_incompatible() {
        // X inferred as /string from src1, then /number from src2.
        // Intersection is empty, so X keeps /string (conservative).
        // /string does not conform to /number → error.
        assert!(check(r#"
            Decl src1(X) bound [/string].
            Decl src2(X) bound [/number].
            Decl dst(X) bound [/number].
            dst(X) :- src1(X), src2(X).
        "#).is_err());
    }

    // -----------------------------------------------------------------------
    // Polymorphic type declarations (skolemization)
    // -----------------------------------------------------------------------

    #[test]
    fn polymorphic_identity_number() {
        // T is a type variable. pair(42, 99) should pass: T can be /number.
        assert!(check(r#"
            Decl pair(X, Y) bound [T, T].
            pair(42, 99).
        "#).is_ok());
    }

    #[test]
    fn polymorphic_identity_string() {
        // T is a type variable. pair("a", "b") should pass: T can be /string.
        assert!(check(r#"
            Decl pair(X, Y) bound [T, T].
            pair("a", "b").
        "#).is_ok());
    }

    #[test]
    fn polymorphic_rule_with_inferred_type() {
        // T skolemized to fresh var. X inferred as /number from src.
        // /number conforms to ?X0 (mapped to /any in context) → passes.
        assert!(check(r#"
            Decl src(X) bound [/number].
            Decl dst(X) bound [T].
            dst(X) :- src(X).
        "#).is_ok());
    }

    // -----------------------------------------------------------------------
    // Cross-predicate inference
    // -----------------------------------------------------------------------

    #[test]
    fn cross_predicate_inference_basic() {
        // 'helper' has no declaration. Its type is inferred from its rule
        // (which uses 'src' with bound [/number]). Then 'dst' uses 'helper'.
        assert!(check(r#"
            Decl src(X) bound [/number].
            Decl dst(X) bound [/number].
            helper(X) :- src(X).
            dst(X) :- helper(X).
        "#).is_ok());
    }

    #[test]
    fn cross_predicate_inference_type_mismatch() {
        // 'helper' inferred as /string from src. dst expects /number → error.
        assert!(check(r#"
            Decl src(X) bound [/string].
            Decl dst(X) bound [/number].
            helper(X) :- src(X).
            dst(X) :- helper(X).
        "#).is_err());
    }

    #[test]
    fn cross_predicate_inference_chain() {
        // Chain: src → mid → dst, only src and dst declared.
        assert!(check(r#"
            Decl src(X) bound [/number].
            Decl dst(X) bound [/number].
            mid(X) :- src(X).
            dst(X) :- mid(X).
        "#).is_ok());
    }

    // -----------------------------------------------------------------------
    // Equality and inequality premises
    // -----------------------------------------------------------------------

    #[test]
    fn equality_binds_variable() {
        // X = "hello" gives X type /string.
        assert!(check(r#"
            Decl src(X) bound [/string].
            Decl dst(X) bound [/string].
            dst(X) :- src(X), X = "hello".
        "#).is_ok());
    }

    #[test]
    fn inequality_refines_variable() {
        // X from src is /string, X != "bad" should still be /string.
        assert!(check(r#"
            Decl src(X) bound [/string].
            Decl dst(X) bound [/string].
            dst(X) :- src(X), X != "bad".
        "#).is_ok());
    }

    // -----------------------------------------------------------------------
    // Transform (let) expressions
    // -----------------------------------------------------------------------

    #[test]
    fn transform_arithmetic() {
        // let Y = fn:plus(X, 1) → Y inferred as /number.
        assert!(check(r#"
            Decl src(X) bound [/number].
            Decl dst(X, Y) bound [/number, /number].
            dst(X, Y) :- src(X) |> let Y = fn:plus(X, 1).
        "#).is_ok());
    }

    #[test]
    fn transform_string_concat() {
        // let Y = fn:string:concat(X, "!") → Y inferred as /string.
        assert!(check(r#"
            Decl src(X) bound [/string].
            Decl dst(X, Y) bound [/string, /string].
            dst(X, Y) :- src(X) |> let Y = fn:string:concat(X, "!").
        "#).is_ok());
    }

    #[test]
    fn transform_type_mismatch() {
        // Y = fn:plus(X, 1) → /number, but dst expects /string for Y.
        assert!(check(r#"
            Decl src(X) bound [/number].
            Decl dst(X, Y) bound [/number, /string].
            dst(X, Y) :- src(X) |> let Y = fn:plus(X, 1).
        "#).is_err());
    }

    // -----------------------------------------------------------------------
    // No-declaration predicates (no bounds checking needed)
    // -----------------------------------------------------------------------

    #[test]
    fn undeclared_predicate_passes() {
        // Rules with no declarations should pass without error.
        assert!(check(r#"
            foo(1).
            bar(X) :- foo(X).
        "#).is_ok());
    }
}