1use anyhow::{Result, anyhow};
27use fxhash::{FxHashMap, FxHashSet};
28use mangle_ir::{Inst, InstId, Ir, NameId};
29
30use crate::name_trie::NameTrie;
31use crate::type_expr::{self, TypeContext};
32
/// Checks every clause of a lowered program against the `bound` declarations
/// of its head predicate.
///
/// Facts are checked argument-by-argument with `type_expr::has_type`; rules are
/// checked by inferring a type for each head argument from the rule body and
/// testing it with `type_expr::set_conforms` against each declared alternative.
pub struct BoundsChecker<'a> {
    /// The program being checked. Held mutably because type expressions (names,
    /// fresh type variables, list/map/struct types) are interned into it.
    ir: &'a mut Ir,
    /// Names collected from declared bound terms; `bound_of_arg` maps a name
    /// constant to a type through `prefix_name`.
    name_trie: NameTrie,
    /// Declared bounds per predicate: one `Vec<InstId>` of type expressions per
    /// `bound [...]` alternative.
    rel_type_map: FxHashMap<NameId, Vec<Vec<InstId>>>,
    /// Rules with a non-empty body (premises or transforms), keyed by head
    /// predicate, stored as `(head, premises, transforms)`.
    rules_map: FxHashMap<NameId, Vec<(InstId, Vec<InstId>, Vec<InstId>)>>,
    /// Cache of alternatives inferred from rules for predicates that have no
    /// declared bounds.
    inferred: FxHashMap<NameId, Vec<Vec<InstId>>>,
    /// Predicates whose alternatives are currently being inferred; used to cut
    /// recursion on recursive predicates.
    visiting: FxHashSet<NameId>,
    /// Counter for generating unique fresh type-variable names (`?X0`, `?X1`, ...).
    fresh_var_counter: usize,
}
49
50impl<'a> BoundsChecker<'a> {
51 pub fn new(ir: &'a mut Ir) -> Self {
52 Self {
53 ir,
54 name_trie: NameTrie::new(),
55 rel_type_map: FxHashMap::default(),
56 rules_map: FxHashMap::default(),
57 inferred: FxHashMap::default(),
58 visiting: FxHashSet::default(),
59 fresh_var_counter: 0,
60 }
61 }
62
63 pub fn check(&mut self) -> Result<()> {
65 self.collect_declarations()?;
66 self.build_rules_map();
67 self.check_all_clauses()
68 }
69
70 fn fresh_var(&mut self) -> NameId {
72 let name = format!("?X{}", self.fresh_var_counter);
73 self.fresh_var_counter += 1;
74 self.ir.intern_name(&name)
75 }
76
77 fn collect_declarations(&mut self) -> Result<()> {
79 let insts: Vec<Inst> = self.ir.insts.clone();
80 for inst in &insts {
81 if let Inst::Decl { atom, bounds, .. } = inst {
82 let pred_name = self.atom_predicate(*atom);
83 if let Some(pred) = pred_name {
84 let mut alternatives = Vec::new();
85 for bound_id in bounds {
86 if let Inst::BoundDecl { base_terms } = self.ir.get(*bound_id) {
87 let base_terms = base_terms.clone();
88 for term in &base_terms {
90 self.name_trie.collect(self.ir, *term);
91 }
92 let any = type_expr::find_or_create_name(self.ir, "/any");
94 let mut ctx = TypeContext::default();
95 for term in &base_terms {
96 let mut vars = FxHashSet::default();
97 type_expr::collect_vars(self.ir, *term, &mut vars);
98 for v in vars {
99 ctx.entry(v).or_insert(any);
100 }
101 }
102 for term in &base_terms {
104 type_expr::wellformed_type(self.ir, &ctx, *term)?;
105 }
106 alternatives.push(base_terms);
107 }
108 }
109 if !alternatives.is_empty() {
110 self.rel_type_map.insert(pred, alternatives);
111 }
112 }
113 }
114 }
115 Ok(())
116 }
117
118 fn build_rules_map(&mut self) {
120 let insts: Vec<Inst> = self.ir.insts.clone();
121 for inst in &insts {
122 if let Inst::Rule {
123 head,
124 premises,
125 transform,
126 } = inst
127 {
128 if !premises.is_empty() || !transform.is_empty() {
130 if let Some(pred) = self.atom_predicate(*head) {
131 self.rules_map
132 .entry(pred)
133 .or_default()
134 .push((*head, premises.clone(), transform.clone()));
135 }
136 }
137 }
138 }
139 }
140
141 fn check_all_clauses(&mut self) -> Result<()> {
143 let insts: Vec<Inst> = self.ir.insts.clone();
144 for inst in &insts {
145 match inst {
146 Inst::Rule {
147 head,
148 premises,
149 transform,
150 } => {
151 let head = *head;
152 let premises = premises.clone();
153 let transform = transform.clone();
154 if let Some(pred) = self.atom_predicate(head) {
155 if let Some(alternatives) = self.rel_type_map.get(&pred).cloned() {
156 if premises.is_empty() && transform.is_empty() {
157 self.check_fact(head, &alternatives)?;
158 } else {
159 self.check_rule(head, &premises, &transform, &alternatives)?;
160 }
161 }
162 }
163 }
164 _ => {}
165 }
166 }
167 Ok(())
168 }
169
170 fn check_fact(&self, head: InstId, alternatives: &[Vec<InstId>]) -> Result<()> {
172 let args = self.atom_args(head);
173 let pred = self.atom_predicate(head).unwrap();
174 if args.is_empty() && alternatives.is_empty() {
175 return Ok(());
176 }
177
178 let mut errors = Vec::new();
179 for alt in alternatives {
180 match self.check_fact_against_bound(pred, &args, alt) {
181 Ok(()) => return Ok(()),
182 Err(e) => errors.push(e.to_string()),
183 }
184 }
185
186 if errors.is_empty() {
187 return Ok(());
188 }
189
190 let pred_name = self
191 .atom_predicate(head)
192 .map(|p| self.ir.resolve_name(p).to_string())
193 .unwrap_or_else(|| "?".to_string());
194 Err(anyhow!(
195 "fact {}(...) matches none of the bound decls: {}",
196 pred_name,
197 errors.join("; ")
198 ))
199 }
200
201 fn check_fact_against_bound(
203 &self,
204 pred: NameId,
205 args: &[InstId],
206 bound: &[InstId],
207 ) -> Result<()> {
208 let is_temporal = self.ir.temporal_predicates.contains(&pred);
209 let expected_args = if is_temporal {
210 bound.len() + 2
211 } else {
212 bound.len()
213 };
214 if args.len() != expected_args {
215 return Err(anyhow!(
216 "arity mismatch: fact has {} args, bound has {}{}",
217 args.len(),
218 bound.len(),
219 if is_temporal { " (+2 temporal)" } else { "" }
220 ));
221 }
222 for (i, (arg, type_expr)) in args.iter().zip(bound.iter()).enumerate() {
223 if !type_expr::has_type(self.ir, *type_expr, *arg) {
224 let arg_desc = self.describe_inst(*arg);
225 let type_desc = self.describe_inst(*type_expr);
226 return Err(anyhow!(
227 "argument {} ({}) does not have type {}",
228 i,
229 arg_desc,
230 type_desc
231 ));
232 }
233 }
234 Ok(())
235 }
236
237 fn check_rule(
242 &mut self,
243 head: InstId,
244 premises: &[InstId],
245 transforms: &[InstId],
246 alternatives: &[Vec<InstId>],
247 ) -> Result<()> {
248 let head_args = self.atom_args(head);
249 let pred = self.atom_predicate(head).unwrap();
250 let is_temporal = self.ir.temporal_predicates.contains(&pred);
251
252 let mut state = InferState::new();
254 for premise_id in premises {
255 state = self.infer_from_premise(*premise_id, state)?;
256 }
257
258 for transform_id in transforms {
260 if let Inst::Transform { var, app } = self.ir.get(*transform_id) {
261 let var = *var;
262 let app = *app;
263 if let Some(v) = var {
264 let tpe = self.bound_of_arg(app, &state.as_map());
265 state.add_or_refine_with_ir(self.ir, v, tpe);
266 }
267 }
268 }
269
270 let var_ranges = state.as_map();
272 let inferred: Vec<InstId> = head_args
273 .iter()
274 .map(|arg| self.bound_of_arg(*arg, &var_ranges))
275 .collect();
276
277 let check_len = if is_temporal && inferred.len() >= 2 {
279 inferred.len() - 2
280 } else {
281 inferred.len()
282 };
283 let inferred_trimmed = &inferred[..check_len];
284
285 let mut errors = Vec::new();
287 for alt in alternatives {
288 if alt.len() != inferred_trimmed.len() {
289 errors.push(format!(
290 "arity mismatch: head has {} args, bound has {}",
291 inferred_trimmed.len(),
292 alt.len()
293 ));
294 continue;
295 }
296 let any = type_expr::find_or_create_name(self.ir, "/any");
298 let mut ctx = TypeContext::default();
299 for t in alt.iter() {
300 let mut vars = FxHashSet::default();
301 type_expr::collect_vars(self.ir, *t, &mut vars);
302 for v in vars {
303 ctx.entry(v).or_insert(any);
304 }
305 }
306 let all_conform = inferred_trimmed
307 .iter()
308 .zip(alt.iter())
309 .all(|(inf, decl)| type_expr::set_conforms(self.ir, &ctx, *inf, *decl));
310 if all_conform {
311 return Ok(());
312 }
313 errors.push(format!(
314 "inferred [{}] does not conform to declared [{}]",
315 inferred_trimmed
316 .iter()
317 .map(|i| self.describe_inst(*i))
318 .collect::<Vec<_>>()
319 .join(", "),
320 alt.iter()
321 .map(|i| self.describe_inst(*i))
322 .collect::<Vec<_>>()
323 .join(", "),
324 ));
325 }
326
327 if errors.is_empty() {
328 return Ok(());
329 }
330
331 let pred_name = self
332 .atom_predicate(head)
333 .map(|p| self.ir.resolve_name(p).to_string())
334 .unwrap_or_else(|| "?".to_string());
335 Err(anyhow!(
336 "rule for {}(...) does not conform to declared bounds: {}",
337 pred_name,
338 errors.join("; ")
339 ))
340 }
341
342 fn infer_from_premise(
344 &mut self,
345 premise_id: InstId,
346 mut state: InferState,
347 ) -> Result<InferState> {
348 match self.ir.get(premise_id) {
349 Inst::Atom { predicate, args } => {
350 let pred = *predicate;
351 let args = args.clone();
352
353 let pred_name = self.ir.resolve_name(pred).to_string();
355 if pred_name == ":match_prefix" {
356 return self.infer_match_prefix(&args, state);
357 }
358 if pred_name == ":match_field" {
359 return self.infer_match_field(&args, state);
360 }
361 if pred_name == ":match_entry" {
362 return self.infer_match_entry(&args, state);
363 }
364 if pred_name == ":list:member" {
365 return self.infer_list_member(&args, state);
366 }
367
368 let var_ranges = state.as_map();
370 let feasible =
371 self.get_or_infer_alternatives(pred, &args, &var_ranges);
372
373 if !feasible.is_empty() {
374 let first = &feasible[0].clone();
376 for (arg, type_id) in args.iter().zip(first.iter()) {
377 if let Inst::Var(v) = self.ir.get(*arg) {
378 let v = *v;
379 state.add_or_refine_with_ir(self.ir, v, *type_id);
380 }
381 }
382 } else if let Some(alternatives) = self.rel_type_map.get(&pred).cloned() {
383 if let Some(first_alt) = alternatives.first() {
385 for (arg, type_id) in args.iter().zip(first_alt.iter()) {
386 if let Inst::Var(v) = self.ir.get(*arg) {
387 let v = *v;
388 state.add_or_refine_with_ir(self.ir, v, *type_id);
389 }
390 }
391 }
392 }
393 Ok(state)
394 }
395 Inst::NegAtom(inner) => {
396 let inner = *inner;
397 if let Inst::Atom { predicate, args } = self.ir.get(inner) {
400 let pred = *predicate;
401 let args = args.clone();
402 let pred_name = self.ir.resolve_name(pred).to_string();
403
404 if pred_name == ":match_prefix" && args.len() >= 2 {
405 if let Inst::Var(v) = self.ir.get(args[0]) {
407 let v = *v;
408 let bound = self.bound_of_arg(args[1], &state.as_map());
409 if let Some(existing) = state.as_map().get(&v).copied() {
410 if type_expr::is_union_type(self.ir, existing) {
411 let refined =
412 type_expr::remove_from_union_type(self.ir, bound, existing);
413 if !type_expr::is_empty_type(self.ir, refined) {
414 state.set_var(v, refined);
415 }
416 }
417 }
418 }
419 }
420 }
422 Ok(state)
423 }
424 Inst::Eq(left, right) => {
425 let left = *left;
426 let right = *right;
427 let var_ranges = state.as_map();
428
429 if let Inst::Var(lv) = self.ir.get(left) {
430 let lv = *lv;
431 let tpe = self.bound_of_arg(right, &var_ranges);
432 state.add_or_refine_with_ir(self.ir, lv, tpe);
433 }
434 if let Inst::Var(rv) = self.ir.get(right) {
435 let rv = *rv;
436 let tpe = self.bound_of_arg(left, &state.as_map());
437 state.add_or_refine_with_ir(self.ir, rv, tpe);
438 }
439 Ok(state)
440 }
441 Inst::Ineq(left, right) => {
442 let left = *left;
443 let right = *right;
444 let var_ranges = state.as_map();
445
446 let left_tpe = self.bound_of_arg(left, &var_ranges);
448 let right_tpe = self.bound_of_arg(right, &var_ranges);
449 let ctx = TypeContext::default();
450 let meet = type_expr::lower_bound(self.ir, &ctx, &[left_tpe, right_tpe]);
451 if !type_expr::is_empty_type(self.ir, meet) {
452 if let Inst::Var(lv) = self.ir.get(left) {
453 let lv = *lv;
454 state.add_or_refine_with_ir(self.ir, lv, meet);
455 }
456 if let Inst::Var(rv) = self.ir.get(right) {
457 let rv = *rv;
458 state.add_or_refine_with_ir(self.ir, rv, meet);
459 }
460 }
461 Ok(state)
462 }
463 _ => Ok(state),
464 }
465 }
466
467 fn feasible_alternatives(
475 &mut self,
476 alternatives: &[Vec<InstId>],
477 args: &[InstId],
478 var_ranges: &FxHashMap<NameId, InstId>,
479 ) -> Vec<Vec<InstId>> {
480 let mut feasible = Vec::new();
481
482 for alt in alternatives {
483 if alt.len() != args.len() {
484 continue;
485 }
486
487 let mut arg_bound = Vec::new();
491 for (i, arg) in args.iter().enumerate() {
492 if let Inst::Var(v) = self.ir.get(*arg) {
493 let v = *v;
494 if let Some(&range) = var_ranges.get(&v) {
495 arg_bound.push(range);
496 } else {
497 arg_bound.push(alt[i]);
499 }
500 } else {
501 arg_bound.push(self.bound_of_arg(*arg, var_ranges));
502 }
503 }
504
505 let mut type_vars = FxHashSet::default();
507 for t in alt {
508 type_expr::collect_vars(self.ir, *t, &mut type_vars);
509 }
510
511 let mut subst: FxHashMap<NameId, InstId> = FxHashMap::default();
513 if !type_vars.is_empty() {
514 for v in &type_vars {
515 let fresh = self.fresh_var();
516 let fresh_id = self.ir.add_inst(Inst::Var(fresh));
517 subst.insert(*v, fresh_id);
518 }
519 }
520
521 let arg_bound_subst: Vec<InstId> = arg_bound
523 .iter()
524 .map(|t| type_expr::apply_subst(self.ir, *t, &subst))
525 .collect();
526 let alt_subst: Vec<InstId> = alt
527 .iter()
528 .map(|t| type_expr::apply_subst(self.ir, *t, &subst))
529 .collect();
530
531 let any = type_expr::find_or_create_name(self.ir, "/any");
533 let mut ctx = TypeContext::default();
534 for fresh_id in subst.values() {
535 if let Inst::Var(v) = self.ir.get(*fresh_id) {
536 ctx.insert(*v, any);
537 }
538 }
539
540 let mut is_feasible = true;
542 let mut result_types = Vec::new();
543 for (ab, at) in arg_bound_subst.iter().zip(alt_subst.iter()) {
544 let meet = type_expr::lower_bound(self.ir, &ctx, &[*ab, *at]);
545 if type_expr::is_empty_type(self.ir, meet) {
546 is_feasible = false;
547 break;
548 }
549 result_types.push(meet);
550 }
551
552 if is_feasible {
553 feasible.push(result_types);
554 }
555 }
556 feasible
557 }
558
559 fn get_or_infer_alternatives(
564 &mut self,
565 pred: NameId,
566 args: &[InstId],
567 var_ranges: &FxHashMap<NameId, InstId>,
568 ) -> Vec<Vec<InstId>> {
569 if let Some(alts) = self.rel_type_map.get(&pred).cloned() {
571 return self.feasible_alternatives(&alts, args, var_ranges);
572 }
573
574 if let Some(alts) = self.inferred.get(&pred).cloned() {
576 return self.feasible_alternatives(&alts, args, var_ranges);
577 }
578
579 if self.visiting.contains(&pred) {
582 let any = type_expr::find_or_create_name(self.ir, "/any");
583 return vec![vec![any; args.len()]];
584 }
585
586 self.visiting.insert(pred);
588 let inferred = self.infer_rel_types(pred);
589 self.visiting.remove(&pred);
590
591 if !inferred.is_empty() {
592 self.inferred.insert(pred, inferred.clone());
593 return self.feasible_alternatives(&inferred, args, var_ranges);
594 }
595
596 Vec::new()
597 }
598
599 fn infer_rel_types(&mut self, pred: NameId) -> Vec<Vec<InstId>> {
604 let rules = match self.rules_map.get(&pred) {
605 Some(r) => r.clone(),
606 None => return Vec::new(),
607 };
608
609 let mut alternatives: Vec<Vec<InstId>> = Vec::new();
610
611 for (head, premises, transforms) in &rules {
612 if let Some(inferred) = self.infer_clause(*head, premises, transforms) {
614 alternatives.push(inferred);
615 }
616 }
617
618 alternatives
619 }
620
621 fn infer_clause(
623 &mut self,
624 head: InstId,
625 premises: &[InstId],
626 transforms: &[InstId],
627 ) -> Option<Vec<InstId>> {
628 let head_args = self.atom_args(head);
629 let mut state = InferState::new();
630
631 for premise_id in premises {
632 match self.infer_from_premise(*premise_id, state) {
633 Ok(new_state) => state = new_state,
634 Err(_) => return None,
635 }
636 }
637
638 for transform_id in transforms {
640 if let Inst::Transform { var, app } = self.ir.get(*transform_id) {
641 let var = *var;
642 let app = *app;
643 if let Some(v) = var {
644 let tpe = self.bound_of_arg(app, &state.as_map());
645 state.add_or_refine_with_ir(self.ir, v, tpe);
646 }
647 }
648 }
649
650 let var_ranges = state.as_map();
652 let inferred: Vec<InstId> = head_args
653 .iter()
654 .map(|arg| self.bound_of_arg(*arg, &var_ranges))
655 .collect();
656
657 Some(inferred)
658 }
659
660 fn infer_match_prefix(
662 &mut self,
663 args: &[InstId],
664 mut state: InferState,
665 ) -> Result<InferState> {
666 if args.len() != 2 {
667 return Ok(state);
668 }
669 let var_ranges = state.as_map();
670 let tpe = self.bound_of_arg(args[0], &var_ranges);
671 let prefix = self.bound_of_arg(args[1], &var_ranges);
672
673 let ctx = TypeContext::default();
674 let meet = type_expr::lower_bound(self.ir, &ctx, &[tpe, prefix]);
675 if !type_expr::is_empty_type(self.ir, meet) {
676 if let Inst::Var(v) = self.ir.get(args[0]) {
677 let v = *v;
678 state.add_or_refine_with_ir(self.ir, v, meet);
679 }
680 let name_type = type_expr::find_or_create_name(self.ir, "/name");
682 if let Inst::Var(v) = self.ir.get(args[1]) {
683 let v = *v;
684 state.add_or_refine_with_ir(self.ir, v, name_type);
685 }
686 }
687 Ok(state)
688 }
689
690 fn infer_match_field(
692 &mut self,
693 args: &[InstId],
694 mut state: InferState,
695 ) -> Result<InferState> {
696 if args.len() != 3 {
697 return Ok(state);
698 }
699 let var_ranges = state.as_map();
700 let scrutinee_type = self.bound_of_arg(args[0], &var_ranges);
701
702 let field_name_id = match self.ir.get(args[1]) {
704 Inst::Name(n) => Some(*n),
705 _ => None,
706 };
707
708 if let Some(field) = field_name_id {
709 if type_expr::is_struct_type(self.ir, scrutinee_type)
710 || type_expr::is_tagged_union_type(self.ir, scrutinee_type)
711 || type_expr::is_union_type(self.ir, scrutinee_type)
712 {
713 if let Some(field_type) =
714 type_expr::struct_type_field_deep(self.ir, scrutinee_type, field)
715 {
716 let ctx = TypeContext::default();
718 let value_bound = self.bound_of_arg(args[2], &state.as_map());
719 let meet =
720 type_expr::lower_bound(self.ir, &ctx, &[value_bound, field_type]);
721 if !type_expr::is_empty_type(self.ir, meet) {
722 if let Inst::Var(v) = self.ir.get(args[2]) {
723 let v = *v;
724 state.add_or_refine_with_ir(self.ir, v, meet);
725 }
726 }
727 }
728 }
729 }
730 let any = type_expr::find_or_create_name(self.ir, "/any");
732 if let Inst::Var(v) = self.ir.get(args[0]) {
733 let v = *v;
734 state.add_or_refine_with_ir(self.ir, v, any);
735 }
736 let name_type = type_expr::find_or_create_name(self.ir, "/name");
738 if let Inst::Var(v) = self.ir.get(args[1]) {
739 let v = *v;
740 state.add_or_refine_with_ir(self.ir, v, name_type);
741 }
742 Ok(state)
743 }
744
745 fn infer_match_entry(
747 &mut self,
748 args: &[InstId],
749 mut state: InferState,
750 ) -> Result<InferState> {
751 if args.len() != 3 {
752 return Ok(state);
753 }
754 let var_ranges = state.as_map();
755 let map_type = self.bound_of_arg(args[0], &var_ranges);
756
757 if type_expr::is_map_type(self.ir, map_type) {
758 if let Some((key_type, val_type)) = type_expr::map_type_args(self.ir, map_type) {
759 let ctx = TypeContext::default();
760
761 let key_bound = self.bound_of_arg(args[1], &state.as_map());
763 let key_meet =
764 type_expr::lower_bound(self.ir, &ctx, &[key_bound, key_type]);
765 if !type_expr::is_empty_type(self.ir, key_meet) {
766 if let Inst::Var(v) = self.ir.get(args[1]) {
767 let v = *v;
768 state.add_or_refine_with_ir(self.ir, v, key_meet);
769 }
770 }
771
772 let val_bound = self.bound_of_arg(args[2], &state.as_map());
774 let val_meet =
775 type_expr::lower_bound(self.ir, &ctx, &[val_bound, val_type]);
776 if !type_expr::is_empty_type(self.ir, val_meet) {
777 if let Inst::Var(v) = self.ir.get(args[2]) {
778 let v = *v;
779 state.add_or_refine_with_ir(self.ir, v, val_meet);
780 }
781 }
782 }
783 }
784 Ok(state)
785 }
786
787 fn infer_list_member(
789 &mut self,
790 args: &[InstId],
791 mut state: InferState,
792 ) -> Result<InferState> {
793 if args.len() != 2 {
794 return Ok(state);
795 }
796 let var_ranges = state.as_map();
797 let list_type = self.bound_of_arg(args[1], &var_ranges);
798
799 if type_expr::is_list_type(self.ir, list_type) {
800 if let Some(elem_type) = type_expr::list_type_arg(self.ir, list_type) {
801 let ctx = TypeContext::default();
802 let elem_bound = self.bound_of_arg(args[0], &state.as_map());
803 let meet =
804 type_expr::lower_bound(self.ir, &ctx, &[elem_bound, elem_type]);
805 if !type_expr::is_empty_type(self.ir, meet) {
806 if let Inst::Var(v) = self.ir.get(args[0]) {
807 let v = *v;
808 state.add_or_refine_with_ir(self.ir, v, meet);
809 }
810 }
811 }
812 }
813 Ok(state)
814 }
815
816 fn bound_of_arg(
818 &mut self,
819 arg: InstId,
820 var_ranges: &FxHashMap<NameId, InstId>,
821 ) -> InstId {
822 match self.ir.get(arg) {
823 Inst::Var(v) => {
824 let v = *v;
825 if let Some(&range) = var_ranges.get(&v) {
826 range
827 } else {
828 type_expr::find_or_create_name(self.ir, "/any")
829 }
830 }
831 Inst::Number(_) => type_expr::find_or_create_name(self.ir, "/number"),
832 Inst::Float(_) => type_expr::find_or_create_name(self.ir, "/float64"),
833 Inst::String(_) => type_expr::find_or_create_name(self.ir, "/string"),
834 Inst::Bool(_) => type_expr::find_or_create_name(self.ir, "/bool"),
835 Inst::Time(_) => type_expr::find_or_create_name(self.ir, "/time"),
836 Inst::Duration(_) => type_expr::find_or_create_name(self.ir, "/duration"),
837 Inst::Bytes(_) => type_expr::find_or_create_name(self.ir, "/bytes"),
838 Inst::Name(n) => {
839 let name = self.ir.resolve_name(*n).to_string();
840 let prefix = self.name_trie.prefix_name(&name);
841 type_expr::find_or_create_name(self.ir, &prefix)
842 }
843 Inst::List(elems) => {
844 let elems = elems.clone();
845 if elems.is_empty() {
846 let bot = type_expr::find_or_create_name(self.ir, "/bot");
847 return type_expr::new_list_type(self.ir, bot);
848 }
849 let ctx = TypeContext::default();
850 let elem_types: Vec<InstId> = elems
851 .iter()
852 .map(|e| self.bound_of_arg(*e, var_ranges))
853 .collect();
854 let elem_type = type_expr::upper_bound(self.ir, &ctx, &elem_types);
855 type_expr::new_list_type(self.ir, elem_type)
856 }
857 Inst::Map { keys, values } => {
858 let keys = keys.clone();
859 let values = values.clone();
860 let ctx = TypeContext::default();
861 let key_types: Vec<InstId> = keys
862 .iter()
863 .map(|k| self.bound_of_arg(*k, var_ranges))
864 .collect();
865 let val_types: Vec<InstId> = values
866 .iter()
867 .map(|v| self.bound_of_arg(*v, var_ranges))
868 .collect();
869 let kt = type_expr::upper_bound(self.ir, &ctx, &key_types);
870 let vt = type_expr::upper_bound(self.ir, &ctx, &val_types);
871 type_expr::new_map_type(self.ir, kt, vt)
872 }
873 Inst::Struct { fields, values } => {
874 let fields = fields.clone();
875 let values = values.clone();
876 let mut args = Vec::new();
877 for (f, v) in fields.iter().zip(values.iter()) {
878 let fname = self.ir.resolve_name(*f).to_string();
879 let fname_id = type_expr::find_or_create_name(self.ir, &fname);
880 let vtype = self.bound_of_arg(*v, var_ranges);
881 args.push(fname_id);
882 args.push(vtype);
883 }
884 type_expr::new_struct_type(self.ir, args)
885 }
886 Inst::ApplyFn { function, args } => {
887 let fname = self.ir.resolve_name(*function).to_string();
888 let args = args.clone();
889 self.bound_of_apply_fn(&fname, &args, var_ranges)
890 }
891 _ => type_expr::find_or_create_name(self.ir, "/any"),
892 }
893 }
894
895 fn bound_of_apply_fn(
897 &mut self,
898 fname: &str,
899 args: &[InstId],
900 var_ranges: &FxHashMap<NameId, InstId>,
901 ) -> InstId {
902 match fname {
903 "fn:list" => {
904 if args.is_empty() {
905 let bot = type_expr::find_or_create_name(self.ir, "/bot");
906 return type_expr::new_list_type(self.ir, bot);
907 }
908 let ctx = TypeContext::default();
909 let arg_types: Vec<InstId> = args
910 .iter()
911 .map(|a| self.bound_of_arg(*a, var_ranges))
912 .collect();
913 let elem = type_expr::upper_bound(self.ir, &ctx, &arg_types);
914 type_expr::new_list_type(self.ir, elem)
915 }
916 "fn:map" => {
917 let ctx = TypeContext::default();
918 let mut key_types = Vec::new();
919 let mut val_types = Vec::new();
920 let mut i = 0;
921 while i + 1 < args.len() {
922 key_types.push(self.bound_of_arg(args[i], var_ranges));
923 val_types.push(self.bound_of_arg(args[i + 1], var_ranges));
924 i += 2;
925 }
926 let kt = type_expr::upper_bound(self.ir, &ctx, &key_types);
927 let vt = type_expr::upper_bound(self.ir, &ctx, &val_types);
928 type_expr::new_map_type(self.ir, kt, vt)
929 }
930 "fn:struct" => {
931 let mut struct_args = Vec::new();
932 let mut i = 0;
933 while i + 1 < args.len() {
934 struct_args.push(args[i]); struct_args.push(self.bound_of_arg(args[i + 1], var_ranges));
936 i += 2;
937 }
938 type_expr::new_struct_type(self.ir, struct_args)
939 }
940 "fn:tuple" => {
941 let arg_types: Vec<InstId> = args
942 .iter()
943 .map(|a| self.bound_of_arg(*a, var_ranges))
944 .collect();
945 type_expr::new_tuple_type(self.ir, arg_types)
946 }
947 "fn:struct_get" if args.len() == 2 => {
948 let struct_type = self.bound_of_arg(args[0], var_ranges);
949 if let Inst::Name(n) = self.ir.get(args[1]) {
950 let field = *n;
951 if let Some(ft) =
952 type_expr::struct_type_field_deep(self.ir, struct_type, field)
953 {
954 return ft;
955 }
956 }
957 type_expr::find_or_create_name(self.ir, "/any")
958 }
959 "fn:plus" | "fn:minus" | "fn:mult" | "fn:div" => {
960 type_expr::find_or_create_name(self.ir, "/number")
961 }
962 "fn:float_plus" | "fn:float_mult" | "fn:float_div" => {
963 type_expr::find_or_create_name(self.ir, "/float64")
964 }
965 "fn:string:concat" | "fn:string:replace" => {
966 type_expr::find_or_create_name(self.ir, "/string")
967 }
968 "fn:count" | "fn:sum" | "fn:max" | "fn:min" => {
969 type_expr::find_or_create_name(self.ir, "/number")
970 }
971 "fn:collect" | "fn:collect_distinct" => {
972 if args.len() == 1 {
973 let elem_type = self.bound_of_arg(args[0], var_ranges);
974 type_expr::new_list_type(self.ir, elem_type)
975 } else {
976 let any = type_expr::find_or_create_name(self.ir, "/any");
977 type_expr::new_list_type(self.ir, any)
978 }
979 }
980 _ => type_expr::find_or_create_name(self.ir, "/any"),
981 }
982 }
983
984 fn atom_predicate(&self, atom_id: InstId) -> Option<NameId> {
987 if let Inst::Atom { predicate, .. } = self.ir.get(atom_id) {
988 Some(*predicate)
989 } else {
990 None
991 }
992 }
993
994 fn atom_args(&self, atom_id: InstId) -> Vec<InstId> {
995 if let Inst::Atom { args, .. } = self.ir.get(atom_id) {
996 args.clone()
997 } else {
998 Vec::new()
999 }
1000 }
1001
1002 fn describe_inst(&self, id: InstId) -> String {
1004 match self.ir.get(id) {
1005 Inst::Name(n) => self.ir.resolve_name(*n).to_string(),
1006 Inst::Number(n) => n.to_string(),
1007 Inst::Float(f) => f.to_string(),
1008 Inst::String(s) => format!("{:?}", self.ir.resolve_string(*s)),
1009 Inst::Bool(b) => b.to_string(),
1010 Inst::Var(v) => self.ir.resolve_name(*v).to_string(),
1011 Inst::ApplyFn { function, args } => {
1012 let fname = self.ir.resolve_name(*function);
1013 let arg_strs: Vec<String> =
1014 args.iter().map(|a| self.describe_inst(*a)).collect();
1015 format!("{}({})", fname, arg_strs.join(", "))
1016 }
1017 _ => format!("inst#{}", id.index()),
1018 }
1019 }
1020}
1021
1022struct InferState {
1030 used_vars: Vec<NameId>,
1032 var_types: Vec<InstId>,
1034}
1035
1036impl InferState {
1037 fn new() -> Self {
1038 Self {
1039 used_vars: Vec::new(),
1040 var_types: Vec::new(),
1041 }
1042 }
1043
1044 fn add_or_refine_with_ir(&mut self, ir: &mut Ir, var: NameId, tpe: InstId) {
1046 if let Some(idx) = self.used_vars.iter().position(|v| *v == var) {
1047 let existing = self.var_types[idx];
1049 let ctx = TypeContext::default();
1050 let meet = type_expr::lower_bound(ir, &ctx, &[existing, tpe]);
1051 if !type_expr::is_empty_type(ir, meet) {
1052 self.var_types[idx] = meet;
1053 }
1054 } else {
1056 self.used_vars.push(var);
1057 self.var_types.push(tpe);
1058 }
1059 }
1060
1061 fn set_var(&mut self, var: NameId, tpe: InstId) {
1063 if let Some(idx) = self.used_vars.iter().position(|v| *v == var) {
1064 self.var_types[idx] = tpe;
1065 }
1066 }
1067
1068 fn as_map(&self) -> FxHashMap<NameId, InstId> {
1070 self.used_vars
1071 .iter()
1072 .zip(self.var_types.iter())
1073 .map(|(v, t)| (*v, *t))
1074 .collect()
1075 }
1076}
1077
1078#[cfg(test)]
1079mod tests {
1080 use super::*;
1081 use crate::LoweringContext;
1082 use mangle_ast as ast;
1083 use mangle_parse::Parser;
1084
1085 fn check(source: &str) -> Result<()> {
1087 let arena = ast::Arena::new_with_global_interner();
1088 let mut parser = Parser::new(&arena, source.as_bytes(), "test");
1089 parser.next_token().unwrap();
1090 let unit = parser.parse_unit().unwrap();
1091 let ctx = LoweringContext::new(&arena);
1092 let mut ir = ctx.lower_unit(&unit);
1093 let mut checker = BoundsChecker::new(&mut ir);
1094 checker.check()
1095 }
1096
1097 #[test]
1102 fn check_valid_fact() {
1103 let arena = ast::Arena::new_with_global_interner();
1104
1105 let foo_sym = arena.predicate_sym("foo", Some(1));
1107 let var_x = arena.variable("X");
1108 let atom_foo_x = arena.atom(foo_sym, &[var_x]);
1109 let num_type = arena.const_(arena.name("/number"));
1110 let bound_decl = ast::BoundDecl {
1111 base_terms: arena.alloc_slice_copy(&[num_type]),
1112 };
1113 let decl = ast::Decl {
1114 atom: atom_foo_x,
1115 descr: &[],
1116 bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
1117 constraints: None,
1118 is_temporal: false,
1119 };
1120
1121 let const_42 = arena.const_(ast::Const::Number(42));
1123 let atom_foo_42 = arena.atom(foo_sym, &[const_42]);
1124 let clause = ast::Clause {
1125 head: atom_foo_42,
1126 head_time: None,
1127 premises: &[],
1128 transform: &[],
1129 };
1130
1131 let unit = ast::Unit {
1132 decls: arena.alloc_slice_copy(&[&decl]),
1133 clauses: arena.alloc_slice_copy(&[&clause]),
1134 };
1135
1136 let ctx = LoweringContext::new(&arena);
1137 let mut ir = ctx.lower_unit(&unit);
1138 let mut checker = BoundsChecker::new(&mut ir);
1139 assert!(checker.check().is_ok());
1140 }
1141
1142 #[test]
1143 fn check_invalid_fact_type_mismatch() {
1144 let arena = ast::Arena::new_with_global_interner();
1145
1146 let foo_sym = arena.predicate_sym("foo", Some(1));
1148 let var_x = arena.variable("X");
1149 let atom_foo_x = arena.atom(foo_sym, &[var_x]);
1150 let num_type = arena.const_(arena.name("/number"));
1151 let bound_decl = ast::BoundDecl {
1152 base_terms: arena.alloc_slice_copy(&[num_type]),
1153 };
1154 let decl = ast::Decl {
1155 atom: atom_foo_x,
1156 descr: &[],
1157 bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
1158 constraints: None,
1159 is_temporal: false,
1160 };
1161
1162 let const_str = arena.const_(ast::Const::String("hello"));
1164 let atom_foo_bad = arena.atom(foo_sym, &[const_str]);
1165 let clause = ast::Clause {
1166 head: atom_foo_bad,
1167 head_time: None,
1168 premises: &[],
1169 transform: &[],
1170 };
1171
1172 let unit = ast::Unit {
1173 decls: arena.alloc_slice_copy(&[&decl]),
1174 clauses: arena.alloc_slice_copy(&[&clause]),
1175 };
1176
1177 let ctx = LoweringContext::new(&arena);
1178 let mut ir = ctx.lower_unit(&unit);
1179 let mut checker = BoundsChecker::new(&mut ir);
1180 let result = checker.check();
1181 assert!(result.is_err(), "expected type mismatch error");
1182 }
1183
1184 #[test]
1185 fn check_valid_rule() {
1186 let arena = ast::Arena::new_with_global_interner();
1187
1188 let src_sym = arena.predicate_sym("src", Some(1));
1190 let var_x = arena.variable("X");
1191 let atom_src_x = arena.atom(src_sym, &[var_x]);
1192 let num_type = arena.const_(arena.name("/number"));
1193 let bound_decl = ast::BoundDecl {
1194 base_terms: arena.alloc_slice_copy(&[num_type]),
1195 };
1196 let decl_src = ast::Decl {
1197 atom: atom_src_x,
1198 descr: &[],
1199 bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl)])),
1200 constraints: None,
1201 is_temporal: false,
1202 };
1203
1204 let dst_sym = arena.predicate_sym("dst", Some(1));
1206 let var_y = arena.variable("Y");
1207 let atom_dst_y = arena.atom(dst_sym, &[var_y]);
1208 let num_type2 = arena.const_(arena.name("/number"));
1209 let bound_decl2 = ast::BoundDecl {
1210 base_terms: arena.alloc_slice_copy(&[num_type2]),
1211 };
1212 let decl_dst = ast::Decl {
1213 atom: atom_dst_y,
1214 descr: &[],
1215 bounds: Some(arena.alloc_slice_copy(&[arena.alloc(bound_decl2)])),
1216 constraints: None,
1217 is_temporal: false,
1218 };
1219
1220 let var_x2 = arena.variable("X");
1222 let head = arena.atom(dst_sym, &[var_x2]);
1223 let var_x3 = arena.variable("X");
1224 let body = arena.atom(src_sym, &[var_x3]);
1225 let clause = ast::Clause {
1226 head,
1227 head_time: None,
1228 premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(body))]),
1229 transform: &[],
1230 };
1231
1232 let unit = ast::Unit {
1233 decls: arena.alloc_slice_copy(&[&decl_src, &decl_dst]),
1234 clauses: arena.alloc_slice_copy(&[&clause]),
1235 };
1236
1237 let ctx = LoweringContext::new(&arena);
1238 let mut ir = ctx.lower_unit(&unit);
1239 let mut checker = BoundsChecker::new(&mut ir);
1240 assert!(checker.check().is_ok());
1241 }
1242
/// `foo` is declared with a single argument, but the only clause is the
/// binary fact `foo(42, 43)`; the bounds checker must reject the unit.
#[test]
fn check_arity_mismatch() {
    let arena = ast::Arena::new_with_global_interner();

    // Decl foo(X) bound [/number].
    let foo = arena.predicate_sym("foo", Some(1));
    let x = arena.variable("X");
    let declared_atom = arena.atom(foo, &[x]);
    let number_ty = arena.const_(arena.name("/number"));
    let foo_decl = ast::Decl {
        atom: declared_atom,
        descr: &[],
        bounds: Some(arena.alloc_slice_copy(&[arena.alloc(ast::BoundDecl {
            base_terms: arena.alloc_slice_copy(&[number_ty]),
        })])),
        constraints: None,
        is_temporal: false,
    };

    // foo(42, 43).  Two arguments where the declaration has one.
    let forty_two = arena.const_(ast::Const::Number(42));
    let forty_three = arena.const_(ast::Const::Number(43));
    let binary_fact = ast::Clause {
        head: arena.atom(foo, &[forty_two, forty_three]),
        head_time: None,
        premises: &[],
        transform: &[],
    };

    let unit = ast::Unit {
        decls: arena.alloc_slice_copy(&[&foo_decl]),
        clauses: arena.alloc_slice_copy(&[&binary_fact]),
    };

    let mut ir = LoweringContext::new(&arena).lower_unit(&unit);
    let outcome = BoundsChecker::new(&mut ir).check();
    assert!(outcome.is_err());
}
1285
/// A fact that satisfies the first of two `bound` alternatives is accepted.
#[test]
fn multiple_alternatives_first_matches() {
    let program = r#"
Decl pair(X, Y) bound [/number, /number] bound [/string, /string].
pair(42, 99).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("pair(42, 99) should satisfy the [/number, /number] alternative");
}
1298
/// A fact that only satisfies the second of two `bound` alternatives is still
/// accepted.
#[test]
fn multiple_alternatives_second_matches() {
    let program = r#"
Decl pair(X, Y) bound [/number, /number] bound [/string, /string].
pair("a", "b").
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("pair(\"a\", \"b\") should satisfy the [/string, /string] alternative");
}
1307
/// A mixed `(number, string)` fact fits neither `[/number, /number]` nor
/// `[/string, /string]`, so checking must fail.
#[test]
fn multiple_alternatives_none_matches() {
    let src = r#"
Decl pair(X, Y) bound [/number, /number] bound [/string, /string].
pair(42, "b").
"#;
    let verdict = check(src);
    assert!(verdict.is_err());
}
1316
/// `X` gets its type from the `/number` premise `src(X)`, and that type
/// matches the `/number` bound on the head `dst(X)`.
#[test]
fn rule_infers_type_from_premise() {
    let program = r#"
Decl src(X) bound [/number].
Decl dst(X) bound [/number].
dst(X) :- src(X).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("dst(X) :- src(X) with /number on both sides should type-check");
}
1330
/// The premise supplies a `/string` for `X`, but the head requires a
/// `/number`, so checking must fail.
#[test]
fn rule_type_mismatch_from_premise() {
    let src = r#"
Decl src(X) bound [/string].
Decl dst(X) bound [/number].
dst(X) :- src(X).
"#;
    let verdict = check(src);
    assert!(verdict.is_err());
}
1340
/// `X` is a `fn:Union(/number, /string)` via `wide` and a `/number` via
/// `narrow`. The test expects the pair of premises to be enough for the
/// `/number` bound on `result(X)`.
#[test]
fn two_premises_refine_variable() {
    let program = r#"
Decl wide(X) bound [fn:Union(/number, /string)].
Decl narrow(X) bound [/number].
Decl result(X) bound [/number].
result(X) :- wide(X), narrow(X).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("narrow(X) should refine wide(X)'s union to /number");
}
1356
/// `X` is constrained to `/string` by `src1` and to `/number` by `src2`.
/// The test expects this combination to be rejected.
#[test]
fn two_premises_refine_to_incompatible() {
    let src = r#"
Decl src1(X) bound [/string].
Decl src2(X) bound [/number].
Decl dst(X) bound [/number].
dst(X) :- src1(X), src2(X).
"#;
    let verdict = check(src);
    assert!(verdict.is_err());
}
1369
/// Both arguments of `pair` are bounded by the same variable `T`; two numbers
/// satisfy that bound.
#[test]
fn polymorphic_identity_number() {
    let program = r#"
Decl pair(X, Y) bound [T, T].
pair(42, 99).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("pair(42, 99) should satisfy bound [T, T]");
}
1382
/// Both arguments of `pair` are bounded by the same variable `T`; two strings
/// satisfy that bound.
#[test]
fn polymorphic_identity_string() {
    let program = r#"
Decl pair(X, Y) bound [T, T].
pair("a", "b").
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("pair(\"a\", \"b\") should satisfy bound [T, T]");
}
1391
/// A head bounded by the variable `T` accepts the `/number` that flows in
/// from the premise `src(X)`.
#[test]
fn polymorphic_rule_with_inferred_type() {
    let program = r#"
Decl src(X) bound [/number].
Decl dst(X) bound [T].
dst(X) :- src(X).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("dst(X) bound [T] should accept /number from src(X)");
}
1402
/// The undeclared intermediate `helper` carries `src`'s `/number` type
/// through to `dst`, whose `/number` bound must then be satisfied.
#[test]
fn cross_predicate_inference_basic() {
    let program = r#"
Decl src(X) bound [/number].
Decl dst(X) bound [/number].
helper(X) :- src(X).
dst(X) :- helper(X).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("/number should propagate from src through helper to dst");
}
1418
/// A `/string` flowing from `src` through the undeclared `helper` reaches
/// `dst`, which is bounded by `/number`, so checking must fail.
#[test]
fn cross_predicate_inference_type_mismatch() {
    let src = r#"
Decl src(X) bound [/string].
Decl dst(X) bound [/number].
helper(X) :- src(X).
dst(X) :- helper(X).
"#;
    let verdict = check(src);
    assert!(verdict.is_err());
}
1429
/// Like the basic cross-predicate case, but routed through an intermediate
/// predicate named `mid`; `/number` must reach `dst` intact.
#[test]
fn cross_predicate_inference_chain() {
    let program = r#"
Decl src(X) bound [/number].
Decl dst(X) bound [/number].
mid(X) :- src(X).
dst(X) :- mid(X).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("/number should propagate from src through mid to dst");
}
1440
/// An equality premise `X = "hello"` alongside a `/string` premise must
/// still satisfy the `/string` head bound.
#[test]
fn equality_binds_variable() {
    let program = r#"
Decl src(X) bound [/string].
Decl dst(X) bound [/string].
dst(X) :- src(X), X = "hello".
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("string equality on a /string variable should type-check");
}
1454
/// An inequality premise `X != "bad"` on a `/string` variable must not break
/// the `/string` head bound.
#[test]
fn inequality_refines_variable() {
    let program = r#"
Decl src(X) bound [/string].
Decl dst(X) bound [/string].
dst(X) :- src(X), X != "bad".
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("string inequality on a /string variable should type-check");
}
1464
/// `Y` is introduced by a `let` transform using `fn:plus` on a `/number`.
/// The head's `/number` bound on `Y` must be satisfied.
#[test]
fn transform_arithmetic() {
    let program = r#"
Decl src(X) bound [/number].
Decl dst(X, Y) bound [/number, /number].
dst(X, Y) :- src(X) |> let Y = fn:plus(X, 1).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("fn:plus over /number should yield a /number for Y");
}
1478
/// `Y` is introduced by a `let` transform using `fn:string:concat` on a
/// `/string`. The head's `/string` bound on `Y` must be satisfied.
#[test]
fn transform_string_concat() {
    let program = r#"
Decl src(X) bound [/string].
Decl dst(X, Y) bound [/string, /string].
dst(X, Y) :- src(X) |> let Y = fn:string:concat(X, "!").
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("fn:string:concat over /string should yield a /string for Y");
}
1488
/// `Y` comes from `fn:plus` over a `/number`, but the head bounds `Y` by
/// `/string`, so checking must fail.
#[test]
fn transform_type_mismatch() {
    let src = r#"
Decl src(X) bound [/number].
Decl dst(X, Y) bound [/number, /string].
dst(X, Y) :- src(X) |> let Y = fn:plus(X, 1).
"#;
    let verdict = check(src);
    assert!(verdict.is_err());
}
1498
/// A program with no `Decl` at all, only a fact and a rule over undeclared
/// predicates, is accepted.
#[test]
fn undeclared_predicate_passes() {
    let program = r#"
foo(1).
bar(X) :- foo(X).
"#;
    // `expect` surfaces the checker's diagnostic if this unexpectedly fails.
    check(program).expect("programs without declarations should pass the bounds check");
}
1511}