1use super::functions::*;
6use oxilean_kernel::{Declaration, Expr, Name, ReducibilityHint};
7use oxilean_parse::AttributeKind;
8use std::collections::{HashMap, HashSet};
9
/// Classification of an argument expression at a call site, as produced by
/// `CallGraph::classify_arg`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ArgRelation {
    /// The argument is a bare bound variable.
    Equal,
    /// The argument is a projection out of a bound variable, or an
    /// application whose function part is a bound variable.
    Smaller,
    /// Any other argument shape; nothing is known about it.
    Unknown,
}
20#[derive(Clone, Debug, Default)]
22pub struct CallGraph {
23 calls: HashMap<Name, Vec<RecursiveCall>>,
25 names: Vec<Name>,
27}
28impl CallGraph {
29 #[allow(dead_code)]
34 pub fn build_from_block(block: &MutualBlock) -> Self {
35 let mut calls: HashMap<Name, Vec<RecursiveCall>> = HashMap::new();
36 let block_names: HashSet<Name> = block.names.iter().cloned().collect();
37 for name in &block.names {
38 let mut func_calls = Vec::new();
39 if let Some(body) = block.get_body(name) {
40 Self::collect_calls(name, body, &block_names, &mut func_calls);
41 }
42 calls.insert(name.clone(), func_calls);
43 }
44 Self {
45 calls,
46 names: block.names.clone(),
47 }
48 }
49 fn peel_app(expr: &Expr) -> (&Expr, Vec<&Expr>) {
52 let mut args = Vec::new();
53 let mut cur = expr;
54 while let Expr::App(f, a) = cur {
55 args.push(a.as_ref());
56 cur = f.as_ref();
57 }
58 args.reverse();
59 (cur, args)
60 }
61 fn collect_calls(
62 caller: &Name,
63 expr: &Expr,
64 block_names: &HashSet<Name>,
65 out: &mut Vec<RecursiveCall>,
66 ) {
67 match expr {
68 Expr::App(func, arg) => {
69 let (head, all_args) = Self::peel_app(expr);
70 if let Some(callee_name) = Self::get_const_head(head) {
71 if block_names.contains(&callee_name) {
72 let relations: Vec<ArgRelation> =
73 all_args.iter().map(|a| Self::classify_arg(a)).collect();
74 out.push(RecursiveCall {
75 caller: caller.clone(),
76 callee: callee_name,
77 args: relations,
78 });
79 for a in &all_args {
80 Self::collect_calls(caller, a, block_names, out);
81 }
82 return;
83 }
84 }
85 Self::collect_calls(caller, func, block_names, out);
86 Self::collect_calls(caller, arg, block_names, out);
87 }
88 Expr::Lam(_, _, ty, body) => {
89 Self::collect_calls(caller, ty, block_names, out);
90 Self::collect_calls(caller, body, block_names, out);
91 }
92 Expr::Pi(_, _, ty, body) => {
93 Self::collect_calls(caller, ty, block_names, out);
94 Self::collect_calls(caller, body, block_names, out);
95 }
96 Expr::Let(_, ty, val, body) => {
97 Self::collect_calls(caller, ty, block_names, out);
98 Self::collect_calls(caller, val, block_names, out);
99 Self::collect_calls(caller, body, block_names, out);
100 }
101 Expr::Proj(_, _, base) => {
102 Self::collect_calls(caller, base, block_names, out);
103 }
104 _ => {}
105 }
106 }
107 fn get_const_head(expr: &Expr) -> Option<Name> {
109 match expr {
110 Expr::Const(name, _) => Some(name.clone()),
111 Expr::App(func, _) => Self::get_const_head(func),
112 _ => None,
113 }
114 }
115 fn classify_arg(expr: &Expr) -> ArgRelation {
117 match expr {
118 Expr::BVar(_) => ArgRelation::Equal,
119 Expr::Proj(_, _, base) => {
120 if matches!(base.as_ref(), Expr::BVar(_)) {
121 ArgRelation::Smaller
122 } else {
123 ArgRelation::Unknown
124 }
125 }
126 Expr::App(func, _) => {
127 if matches!(func.as_ref(), Expr::BVar(_)) {
128 ArgRelation::Smaller
129 } else {
130 ArgRelation::Unknown
131 }
132 }
133 _ => ArgRelation::Unknown,
134 }
135 }
136 #[allow(dead_code)]
139 pub fn is_structurally_decreasing(&self, name: &Name, arg_idx: usize) -> bool {
140 if let Some(func_calls) = self.calls.get(name) {
141 if func_calls.is_empty() {
142 return true;
143 }
144 let mut has_smaller = false;
145 for call in func_calls {
146 if call.callee == *name {
147 match call.args.get(arg_idx) {
148 Some(ArgRelation::Smaller) => has_smaller = true,
149 Some(ArgRelation::Equal) => {}
150 _ => return false,
151 }
152 }
153 }
154 has_smaller || func_calls.iter().all(|c| c.callee != *name)
155 } else {
156 false
157 }
158 }
159 #[allow(dead_code)]
161 pub fn find_decreasing_arg(&self, name: &Name) -> Option<usize> {
162 let max_args = self
163 .calls
164 .get(name)
165 .map(|calls| {
166 calls
167 .iter()
168 .filter(|c| &c.callee == name)
169 .map(|c| c.args.len())
170 .max()
171 .unwrap_or(0)
172 })
173 .unwrap_or(0);
174 let n = if max_args == 0 { 1 } else { max_args };
175 (0..n).find(|&idx| self.is_structurally_decreasing(name, idx))
176 }
177 #[allow(dead_code)]
179 pub fn is_mutually_recursive(&self) -> bool {
180 for (caller, func_calls) in &self.calls {
181 for call in func_calls {
182 if &call.callee != caller {
183 return true;
184 }
185 }
186 }
187 false
188 }
189 #[allow(dead_code)]
191 pub fn is_self_recursive(&self, name: &Name) -> bool {
192 self.calls
193 .get(name)
194 .map(|cs| cs.iter().any(|c| c.callee == *name))
195 .unwrap_or(false)
196 }
197 #[allow(dead_code)]
199 pub fn is_recursive(&self) -> bool {
200 for func_calls in self.calls.values() {
201 if !func_calls.is_empty() {
202 return true;
203 }
204 }
205 false
206 }
207 #[allow(dead_code)]
209 pub fn get_calls(&self, name: &Name) -> &[RecursiveCall] {
210 self.calls.get(name).map(|v| v.as_slice()).unwrap_or(&[])
211 }
212 #[allow(dead_code)]
214 pub fn strongly_connected_components(&self) -> Vec<Vec<Name>> {
215 let n = self.names.len();
216 if n == 0 {
217 return Vec::new();
218 }
219 let name_to_idx: HashMap<Name, usize> = self
220 .names
221 .iter()
222 .enumerate()
223 .map(|(i, name)| (name.clone(), i))
224 .collect();
225 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
226 for (caller, func_calls) in &self.calls {
227 if let Some(&ci) = name_to_idx.get(caller) {
228 for call in func_calls {
229 if let Some(&cj) = name_to_idx.get(&call.callee) {
230 if !adj[ci].contains(&cj) {
231 adj[ci].push(cj);
232 }
233 }
234 }
235 }
236 }
237 let mut index_counter: usize = 0;
238 let mut stack: Vec<usize> = Vec::new();
239 let mut on_stack = vec![false; n];
240 let mut indices: Vec<Option<usize>> = vec![None; n];
241 let mut lowlinks = vec![0usize; n];
242 let mut result: Vec<Vec<Name>> = Vec::new();
243 for v in 0..n {
244 if indices[v].is_none() {
245 Self::tarjan_visit(
246 v,
247 &adj,
248 &mut index_counter,
249 &mut stack,
250 &mut on_stack,
251 &mut indices,
252 &mut lowlinks,
253 &mut result,
254 &self.names,
255 );
256 }
257 }
258 result
259 }
260 #[allow(clippy::too_many_arguments)]
262 fn tarjan_visit(
263 v: usize,
264 adj: &[Vec<usize>],
265 index_counter: &mut usize,
266 stack: &mut Vec<usize>,
267 on_stack: &mut Vec<bool>,
268 indices: &mut Vec<Option<usize>>,
269 lowlinks: &mut Vec<usize>,
270 result: &mut Vec<Vec<Name>>,
271 names: &[Name],
272 ) {
273 indices[v] = Some(*index_counter);
274 lowlinks[v] = *index_counter;
275 *index_counter += 1;
276 stack.push(v);
277 on_stack[v] = true;
278 for &w in &adj[v] {
279 if indices[w].is_none() {
280 Self::tarjan_visit(
281 w,
282 adj,
283 index_counter,
284 stack,
285 on_stack,
286 indices,
287 lowlinks,
288 result,
289 names,
290 );
291 lowlinks[v] = lowlinks[v].min(lowlinks[w]);
292 } else if on_stack[w] {
293 lowlinks[v] =
294 lowlinks[v].min(indices[w].expect("w is on stack so indices[w] is set"));
295 }
296 }
297 if lowlinks[v] == indices[v].expect("v was just assigned an index above") {
298 let mut component = Vec::new();
299 loop {
300 let w = stack
301 .pop()
302 .expect("stack is non-empty: v is always on it when we reach the SCC root");
303 on_stack[w] = false;
304 component.push(names[w].clone());
305 if w == v {
306 break;
307 }
308 }
309 result.push(component);
310 }
311 }
312}
313#[derive(Clone, Debug)]
318pub struct WellFoundedRecursion {
319 pub block: MutualBlock,
321 pub measure: Option<Name>,
323 pub rel: Option<Expr>,
325 pub decreasing_args: HashMap<Name, Vec<usize>>,
327}
328impl WellFoundedRecursion {
329 #[allow(dead_code)]
331 pub fn new(block: MutualBlock) -> Self {
332 Self {
333 block,
334 measure: None,
335 rel: None,
336 decreasing_args: HashMap::new(),
337 }
338 }
339 #[allow(dead_code)]
341 pub fn set_measure(&mut self, name: Name) {
342 self.measure = Some(name);
343 }
344 #[allow(dead_code)]
346 pub fn set_relation(&mut self, rel: Expr) {
347 self.rel = Some(rel);
348 }
349 #[allow(dead_code)]
355 pub fn detect_decreasing_args(&mut self) -> Result<(), MutualElabError> {
356 let call_graph = CallGraph::build_from_block(&self.block);
357 for name in &self.block.names {
358 if call_graph.is_self_recursive(name) || call_graph.is_mutually_recursive() {
359 let dec_idx = call_graph.find_decreasing_arg(name).unwrap_or(0);
360 self.decreasing_args
361 .entry(name.clone())
362 .or_default()
363 .push(dec_idx);
364 }
365 }
366 Ok(())
367 }
368 #[allow(dead_code)]
373 pub fn encode_as_wf_recursion(&self) -> Result<MutualBlock, MutualElabError> {
374 if self.measure.is_none() && self.rel.is_none() {
375 return Err(MutualElabError::TerminationFailure(
376 "well-founded recursion requires a measure or relation".to_string(),
377 ));
378 }
379 let mut result = self.block.clone();
380 let wf_rel: Expr = match (&self.measure, &self.rel) {
381 (Some(m), _) => Expr::App(
382 Box::new(Expr::Const(Name::str("Measure"), vec![])),
383 Box::new(Expr::Const(m.clone(), vec![])),
384 ),
385 (None, Some(r)) => r.clone(),
386 (None, None) => unreachable!("checked above"),
387 };
388 let wf_proof = Expr::App(
389 Box::new(Expr::Const(Name::str("WellFounded.wf"), vec![])),
390 Box::new(wf_rel.clone()),
391 );
392 let call_graph = CallGraph::build_from_block(&self.block);
393 for name in &self.block.names {
394 if !call_graph.is_self_recursive(name) {
395 continue;
396 }
397 if let Some(body) = self.block.get_body(name) {
398 let dec_idx = self
399 .decreasing_args
400 .get(name)
401 .and_then(|v| v.first())
402 .copied()
403 .unwrap_or(0);
404 let rec_ty = self
405 .block
406 .types
407 .get(name)
408 .cloned()
409 .unwrap_or(Expr::Const(Name::str("_"), vec![]));
410 let step = Expr::Lam(
411 oxilean_kernel::BinderInfo::Default,
412 name.clone(),
413 Box::new(rec_ty),
414 Box::new(body.clone()),
415 );
416 let init_arg = Expr::BVar(dec_idx as u32);
417 let wrapped = Expr::App(
418 Box::new(Expr::App(
419 Box::new(Expr::App(
420 Box::new(Expr::Const(Name::str("WellFounded.fix"), vec![])),
421 Box::new(wf_proof.clone()),
422 )),
423 Box::new(step),
424 )),
425 Box::new(init_arg),
426 );
427 result.bodies.insert(name.clone(), wrapped);
428 }
429 result
430 .attrs
431 .entry(name.clone())
432 .or_default()
433 .push(AttributeKind::Custom("_wf_rec".to_string()));
434 }
435 Ok(result)
436 }
437 #[allow(dead_code)]
442 pub fn generate_termination_proof(&self) -> Result<Expr, MutualElabError> {
443 if self.measure.is_some() || self.rel.is_some() {
444 Ok(Expr::Const(Name::str("sorry"), vec![]))
445 } else {
446 Err(MutualElabError::TerminationFailure(
447 "no measure or relation provided".to_string(),
448 ))
449 }
450 }
451}
452#[derive(Clone, Debug, Default)]
454pub struct DeclDependencyGraph {
455 names: Vec<Name>,
457 edges: Vec<Vec<usize>>,
459}
460impl DeclDependencyGraph {
461 pub fn new() -> Self {
463 Self::default()
464 }
465 pub fn add_node(&mut self, name: Name) -> usize {
467 let idx = self.names.len();
468 self.names.push(name);
469 self.edges.push(Vec::new());
470 idx
471 }
472 pub fn add_edge(&mut self, from: usize, to: usize) {
474 if !self.edges[from].contains(&to) {
475 self.edges[from].push(to);
476 }
477 }
478 pub fn index_of(&self, name: &Name) -> Option<usize> {
480 self.names.iter().position(|n| n == name)
481 }
482 pub fn sccs(&self) -> Vec<Vec<Name>> {
484 let raw = tarjan_scc(self.names.len(), &self.edges);
485 raw.into_iter()
486 .map(|scc| scc.iter().map(|&i| self.names[i].clone()).collect())
487 .collect()
488 }
489 pub fn has_cycle(&self) -> bool {
491 self.sccs().iter().any(|scc| scc.len() > 1)
492 }
493 pub fn cyclic_sccs(&self) -> Vec<Vec<Name>> {
495 self.sccs().into_iter().filter(|s| s.len() > 1).collect()
496 }
497 pub fn topological_order(&self) -> Vec<Name> {
499 let sccs = self.sccs();
500 sccs.into_iter().flatten().collect()
501 }
502 pub fn num_nodes(&self) -> usize {
504 self.names.len()
505 }
506}
507#[derive(Debug, Default)]
513pub struct MutualDefCycleDetector {
514 graph: DeclDependencyGraph,
515}
516impl MutualDefCycleDetector {
517 pub fn new() -> Self {
519 Self::default()
520 }
521 pub fn register(&mut self, name: Name) -> usize {
523 self.graph.add_node(name)
524 }
525 pub fn add_dependency(&mut self, caller: &Name, callee: &Name) -> bool {
529 match (self.graph.index_of(caller), self.graph.index_of(callee)) {
530 (Some(from), Some(to)) => {
531 self.graph.add_edge(from, to);
532 true
533 }
534 _ => false,
535 }
536 }
537 pub fn has_mutual_recursion(&self) -> bool {
539 self.graph.has_cycle()
540 }
541 pub fn mutual_groups(&self) -> Vec<Vec<Name>> {
543 self.graph.cyclic_sccs()
544 }
545 pub fn elaboration_order(&self) -> Vec<Name> {
548 self.graph.topological_order()
549 }
550 pub fn num_decls(&self) -> usize {
552 self.graph.num_nodes()
553 }
554}
/// Stages of mutual-block elaboration, declared in the order in which
/// `MutualElabProgress::advance` steps through them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum MutualElabStage {
    /// Initial stage of a fresh `MutualElabProgress`.
    SigCollection,
    /// Follows `SigCollection`.
    DependencyAnalysis,
    /// Follows `DependencyAnalysis`.
    BodyElab,
    /// Follows `BodyElab`.
    TerminationCheck,
    /// Follows `TerminationCheck`.
    PostProcess,
    /// Terminal stage; also entered directly by `MutualElabProgress::fail`.
    Done,
}
571#[allow(dead_code)]
573#[derive(Clone, Debug)]
574pub struct PartialSig {
575 pub name: Name,
577 pub declared_type: Option<Expr>,
579 pub inferred_type: Option<Expr>,
581 pub resolved: bool,
583}
584#[allow(dead_code)]
585impl PartialSig {
586 pub fn new(name: Name) -> Self {
588 Self {
589 name,
590 declared_type: None,
591 inferred_type: None,
592 resolved: false,
593 }
594 }
595 pub fn resolve(&mut self, ty: Expr) {
597 self.inferred_type = Some(ty);
598 self.resolved = true;
599 }
600 pub fn best_type(&self) -> Option<&Expr> {
602 self.inferred_type.as_ref().or(self.declared_type.as_ref())
603 }
604}
605#[allow(dead_code)]
607#[derive(Clone, Debug)]
608pub struct MutualElabBudget {
609 pub max_scc_size: usize,
611 pub max_termination_depth: usize,
613 pub max_structural_args: usize,
615 pub max_refinements: usize,
617}
618#[allow(dead_code)]
619impl MutualElabBudget {
620 pub fn new() -> Self {
622 Self::default()
623 }
624 pub fn liberal() -> Self {
626 Self {
627 max_scc_size: 256,
628 max_termination_depth: 1024,
629 max_structural_args: 64,
630 max_refinements: 32,
631 }
632 }
633 pub fn strict() -> Self {
635 Self {
636 max_scc_size: 8,
637 max_termination_depth: 32,
638 max_structural_args: 4,
639 max_refinements: 2,
640 }
641 }
642 pub fn allows_scc_size(&self, n: usize) -> bool {
644 n <= self.max_scc_size
645 }
646 pub fn allows_termination_depth(&self, d: usize) -> bool {
648 d <= self.max_termination_depth
649 }
650}
/// A single call to a block member, as recorded by `CallGraph`.
#[derive(Clone, Debug)]
pub struct RecursiveCall {
    /// Definition whose body contains the call.
    pub caller: Name,
    /// Block member being called.
    pub callee: Name,
    /// Classification of each argument at the call site, in argument order.
    pub args: Vec<ArgRelation>,
}
/// Per-node bookkeeping record for Tarjan's SCC algorithm.
///
/// NOTE(review): nothing in this part of the file reads it; presumably it is
/// used by `tarjan_scc` in the sibling `functions` module — confirm.
#[derive(Debug, Clone, Default)]
pub struct TarjanNode {
    /// Discovery index of the node.
    pub index: usize,
    /// Lowlink value of the node.
    pub lowlink: usize,
    /// Whether the node is on the algorithm's stack.
    pub on_stack: bool,
    /// Whether the node has been visited.
    pub discovered: bool,
}
669#[allow(dead_code)]
671#[derive(Clone, Debug, Default)]
672pub struct MutualSigCollection {
673 sigs: Vec<PartialSig>,
674}
675#[allow(dead_code)]
676impl MutualSigCollection {
677 pub fn new() -> Self {
679 Self::default()
680 }
681 pub fn add(&mut self, sig: PartialSig) {
683 self.sigs.push(sig);
684 }
685 pub fn len(&self) -> usize {
687 self.sigs.len()
688 }
689 pub fn is_empty(&self) -> bool {
691 self.sigs.is_empty()
692 }
693 pub fn num_resolved(&self) -> usize {
695 self.sigs.iter().filter(|s| s.resolved).count()
696 }
697 pub fn all_resolved(&self) -> bool {
699 self.sigs.iter().all(|s| s.resolved)
700 }
701 pub fn get(&self, name: &Name) -> Option<&PartialSig> {
703 self.sigs.iter().find(|s| &s.name == name)
704 }
705 pub fn get_mut(&mut self, name: &Name) -> Option<&mut PartialSig> {
707 self.sigs.iter_mut().find(|s| &s.name == name)
708 }
709 pub fn iter(&self) -> impl Iterator<Item = &PartialSig> {
711 self.sigs.iter()
712 }
713}
714pub struct MutualChecker {
716 current_block: Option<MutualBlock>,
718}
719impl MutualChecker {
720 pub fn new() -> Self {
722 Self {
723 current_block: None,
724 }
725 }
726 pub fn start_block(&mut self) {
728 self.current_block = Some(MutualBlock::new());
729 }
730 pub fn add_def(&mut self, name: Name, ty: Expr, body: Expr) -> Result<(), String> {
732 if let Some(block) = &mut self.current_block {
733 block.add(name, ty, body);
734 Ok(())
735 } else {
736 Err("No mutual block started".to_string())
737 }
738 }
739 pub fn finish_block(&mut self) -> Result<MutualBlock, String> {
741 self.current_block
742 .take()
743 .ok_or_else(|| "No mutual block to finish".to_string())
744 }
745 pub fn current_block(&self) -> Option<&MutualBlock> {
747 self.current_block.as_ref()
748 }
749 #[allow(dead_code)]
756 pub fn check_well_formedness(block: &MutualBlock) -> Result<(), MutualElabError> {
757 block.validate()?;
758 let block_names: HashSet<Name> = block.names.iter().cloned().collect();
759 for name in &block.names {
760 if let Some(ty) = block.get_type(name) {
761 Self::check_no_external_forward_refs(ty, &block_names)?;
762 }
763 }
764 Ok(())
765 }
766 fn check_no_external_forward_refs(
769 expr: &Expr,
770 block_names: &HashSet<Name>,
771 ) -> Result<(), MutualElabError> {
772 match expr {
773 Expr::Const(name, _) => {
774 let _ = block_names.contains(name);
775 Ok(())
776 }
777 Expr::App(f, a) => {
778 Self::check_no_external_forward_refs(f, block_names)?;
779 Self::check_no_external_forward_refs(a, block_names)?;
780 Ok(())
781 }
782 Expr::Lam(_, _, ty, body) | Expr::Pi(_, _, ty, body) => {
783 Self::check_no_external_forward_refs(ty, block_names)?;
784 Self::check_no_external_forward_refs(body, block_names)?;
785 Ok(())
786 }
787 Expr::Let(_, ty, val, body) => {
788 Self::check_no_external_forward_refs(ty, block_names)?;
789 Self::check_no_external_forward_refs(val, block_names)?;
790 Self::check_no_external_forward_refs(body, block_names)?;
791 Ok(())
792 }
793 Expr::Proj(_, _, base) => {
794 Self::check_no_external_forward_refs(base, block_names)?;
795 Ok(())
796 }
797 _ => Ok(()),
798 }
799 }
800 #[allow(dead_code)]
807 pub fn check_termination(block: &MutualBlock) -> Result<TerminationKind, MutualElabError> {
808 let call_graph = CallGraph::build_from_block(block);
809 if !call_graph.is_recursive() {
810 return Ok(TerminationKind::NonRecursive);
811 }
812 let mut structural_args = HashMap::new();
813 let mut all_structural = true;
814 for name in &block.names {
815 if call_graph.is_self_recursive(name) || call_graph.is_mutually_recursive() {
816 match call_graph.find_decreasing_arg(name) {
817 Some(idx) => {
818 structural_args.insert(name.clone(), idx);
819 }
820 None => {
821 all_structural = false;
822 break;
823 }
824 }
825 }
826 }
827 if all_structural && !structural_args.is_empty() {
828 return Ok(TerminationKind::Structural(structural_args));
829 }
830 Ok(TerminationKind::WellFounded)
831 }
832 #[allow(dead_code)]
840 pub fn elaborate_mutual_defs(
841 names: &[Name],
842 types: &[Expr],
843 bodies: &[Expr],
844 ) -> Result<MutualBlock, MutualElabError> {
845 if names.len() != types.len() || names.len() != bodies.len() {
846 return Err(MutualElabError::Other(
847 "mismatched lengths for names, types, and bodies".to_string(),
848 ));
849 }
850 if names.is_empty() {
851 return Err(MutualElabError::Other("empty mutual block".to_string()));
852 }
853 let mut block = MutualBlock::new();
854 for i in 0..names.len() {
855 block.add(names[i].clone(), types[i].clone(), bodies[i].clone());
856 }
857 block.validate()?;
858 Ok(block)
859 }
860 #[allow(dead_code)]
862 pub fn encode_recursion(
863 block: MutualBlock,
864 kind: &TerminationKind,
865 ) -> Result<MutualBlock, MutualElabError> {
866 match kind {
867 TerminationKind::NonRecursive => Ok(block),
868 TerminationKind::Structural(_args) => {
869 let mut sr = StructuralRecursion::new(block);
870 sr.detect_structural_recursion()?;
871 sr.encode_as_recursor_application()
872 }
873 TerminationKind::WellFounded => {
874 let mut wfr = WellFoundedRecursion::new(block);
875 wfr.detect_decreasing_args()?;
876 if wfr.measure.is_none() && wfr.rel.is_none() {
877 wfr.set_measure(Name::str("Nat.lt"));
878 }
879 wfr.encode_as_wf_recursion()
880 }
881 }
882 }
883 #[allow(dead_code)]
885 pub fn split_mutual_block(block: &MutualBlock) -> Vec<Declaration> {
886 let mut decls = Vec::new();
887 for name in &block.names {
888 if let (Some(ty), Some(val)) = (block.get_type(name), block.get_body(name)) {
889 decls.push(Declaration::Definition {
890 name: name.clone(),
891 univ_params: block.univ_params.clone(),
892 ty: ty.clone(),
893 val: val.clone(),
894 hint: ReducibilityHint::Regular(100),
895 });
896 }
897 }
898 decls
899 }
900}
/// Errors reported while elaborating a mutual block. Every variant carries a
/// human-readable message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MutualElabError {
    /// Not produced in this part of the file; presumably a type-checking
    /// mismatch — confirm against callers.
    TypeMismatch(String),
    /// Not produced in this part of the file; presumably an unsupported form
    /// of recursion — confirm against callers.
    InvalidRecursion(String),
    /// A name listed in a block has no type or no body
    /// (see `MutualBlock::validate`).
    MissingDefinition(String),
    /// Not produced in this part of the file; presumably a cyclic type
    /// dependency — confirm against callers.
    CyclicType(String),
    /// Termination could not be established, e.g. no decreasing argument was
    /// found or no measure/relation was supplied.
    TerminationFailure(String),
    /// Any other failure, e.g. an empty block or a duplicate name.
    Other(String),
}
/// Kind of ordering used to justify termination.
///
/// NOTE(review): the payloads are not interpreted anywhere in this part of
/// the file; they are presumably argument positions — confirm against users.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WellFoundedOrder {
    /// Lexicographic ordering over several positions.
    Lexicographic(Vec<usize>),
    /// Ordering by a measure on a single position.
    Measure(usize),
    /// Structural ordering on a single position.
    Structural(usize),
    /// Multiset ordering over several positions.
    Multiset(Vec<usize>),
    /// No ordering is known.
    Unknown,
}
931#[derive(Debug, Clone)]
933pub struct MutualElabProgress {
934 pub names: Vec<Name>,
936 pub stage: MutualElabStage,
938 pub completed: Vec<MutualElabStage>,
940 pub error: Option<MutualElabError>,
942}
943impl MutualElabProgress {
944 pub fn new(names: Vec<Name>) -> Self {
946 Self {
947 names,
948 stage: MutualElabStage::SigCollection,
949 completed: Vec::new(),
950 error: None,
951 }
952 }
953 pub fn advance(&mut self) {
955 let next = match self.stage {
956 MutualElabStage::SigCollection => MutualElabStage::DependencyAnalysis,
957 MutualElabStage::DependencyAnalysis => MutualElabStage::BodyElab,
958 MutualElabStage::BodyElab => MutualElabStage::TerminationCheck,
959 MutualElabStage::TerminationCheck => MutualElabStage::PostProcess,
960 MutualElabStage::PostProcess => MutualElabStage::Done,
961 MutualElabStage::Done => MutualElabStage::Done,
962 };
963 self.completed.push(self.stage);
964 self.stage = next;
965 }
966 pub fn fail(&mut self, err: MutualElabError) {
968 self.error = Some(err);
969 self.stage = MutualElabStage::Done;
970 }
971 pub fn is_done(&self) -> bool {
973 self.stage == MutualElabStage::Done
974 }
975 pub fn is_success(&self) -> bool {
977 self.is_done() && self.error.is_none()
978 }
979}
980#[derive(Clone, Debug)]
982pub struct MutualRecursionSummary {
983 pub names: Vec<Name>,
985 pub is_mutually_recursive: bool,
987 pub mutual_groups: Vec<Vec<Name>>,
989 pub termination_measure: Option<TerminationMeasure>,
991 pub diagnostics: Vec<String>,
993}
994impl MutualRecursionSummary {
995 pub fn from_detector(
997 detector: &MutualDefCycleDetector,
998 measure: Option<TerminationMeasure>,
999 ) -> Self {
1000 let groups = detector.mutual_groups();
1001 let is_mutual = !groups.is_empty();
1002 Self {
1003 names: (0..detector.num_decls())
1004 .filter_map(|i| detector.graph.names.get(i).cloned())
1005 .collect(),
1006 is_mutually_recursive: is_mutual,
1007 mutual_groups: groups,
1008 termination_measure: measure,
1009 diagnostics: Vec::new(),
1010 }
1011 }
1012 pub fn add_diagnostic(&mut self, msg: impl Into<String>) {
1014 self.diagnostics.push(msg.into());
1015 }
1016 pub fn has_diagnostics(&self) -> bool {
1018 !self.diagnostics.is_empty()
1019 }
1020}
1021#[derive(Clone, Debug)]
1023pub struct TerminationMeasure {
1024 pub order: WellFoundedOrder,
1026 pub confidence: f64,
1028 pub justification: String,
1030}
1031impl TerminationMeasure {
1032 pub fn certain(order: WellFoundedOrder, justification: impl Into<String>) -> Self {
1034 Self {
1035 order,
1036 confidence: 1.0,
1037 justification: justification.into(),
1038 }
1039 }
1040 pub fn heuristic(
1042 order: WellFoundedOrder,
1043 confidence: f64,
1044 justification: impl Into<String>,
1045 ) -> Self {
1046 Self {
1047 order,
1048 confidence,
1049 justification: justification.into(),
1050 }
1051 }
1052 pub fn is_reliable(&self) -> bool {
1054 self.confidence >= 0.8
1055 }
1056}
/// Result of `MutualChecker::check_termination`: how the recursion in a
/// block is to be justified.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TerminationKind {
    /// A decreasing argument was found for every definition considered; the
    /// map gives the chosen argument position per definition.
    Structural(HashMap<Name, usize>),
    /// The block is recursive but no structural justification was found.
    WellFounded,
    /// The call graph records no call between block members.
    NonRecursive,
}
1067#[derive(Debug, Clone)]
1069pub struct MutualBlock {
1070 pub names: Vec<Name>,
1072 pub types: HashMap<Name, Expr>,
1074 pub bodies: HashMap<Name, Expr>,
1076 pub univ_params: Vec<Name>,
1078 pub attrs: HashMap<Name, Vec<AttributeKind>>,
1080 pub is_noncomputable: HashMap<Name, bool>,
1082}
1083impl MutualBlock {
1084 pub fn new() -> Self {
1086 Self {
1087 names: Vec::new(),
1088 types: HashMap::new(),
1089 bodies: HashMap::new(),
1090 univ_params: Vec::new(),
1091 attrs: HashMap::new(),
1092 is_noncomputable: HashMap::new(),
1093 }
1094 }
1095 pub fn add(&mut self, name: Name, ty: Expr, body: Expr) {
1097 self.names.push(name.clone());
1098 self.types.insert(name.clone(), ty);
1099 self.bodies.insert(name, body);
1100 }
1101 #[allow(dead_code)]
1103 pub fn add_with_attrs(
1104 &mut self,
1105 name: Name,
1106 ty: Expr,
1107 body: Expr,
1108 attrs: Vec<AttributeKind>,
1109 noncomputable: bool,
1110 ) {
1111 self.names.push(name.clone());
1112 self.types.insert(name.clone(), ty);
1113 self.bodies.insert(name.clone(), body);
1114 self.attrs.insert(name.clone(), attrs);
1115 self.is_noncomputable.insert(name, noncomputable);
1116 }
1117 pub fn get_type(&self, name: &Name) -> Option<&Expr> {
1119 self.types.get(name)
1120 }
1121 pub fn get_body(&self, name: &Name) -> Option<&Expr> {
1123 self.bodies.get(name)
1124 }
1125 pub fn size(&self) -> usize {
1127 self.names.len()
1128 }
1129 pub fn contains(&self, name: &Name) -> bool {
1131 self.names.contains(name)
1132 }
1133 #[allow(dead_code)]
1135 pub fn names_in_order(&self) -> &[Name] {
1136 &self.names
1137 }
1138 #[allow(dead_code)]
1140 pub fn get_all_bodies(&self) -> Vec<(&Name, &Expr)> {
1141 self.names
1142 .iter()
1143 .filter_map(|name| self.bodies.get(name).map(|body| (name, body)))
1144 .collect()
1145 }
1146 #[allow(dead_code)]
1148 pub fn validate(&self) -> Result<(), MutualElabError> {
1149 if self.names.is_empty() {
1150 return Err(MutualElabError::Other("empty mutual block".to_string()));
1151 }
1152 let mut seen = HashSet::new();
1153 for name in &self.names {
1154 if !seen.insert(name.clone()) {
1155 return Err(MutualElabError::Other(format!(
1156 "duplicate name in mutual block: {:?}",
1157 name
1158 )));
1159 }
1160 }
1161 for name in &self.names {
1162 if !self.types.contains_key(name) {
1163 return Err(MutualElabError::MissingDefinition(format!(
1164 "no type for '{:?}'",
1165 name
1166 )));
1167 }
1168 }
1169 for name in &self.names {
1170 if !self.bodies.contains_key(name) {
1171 return Err(MutualElabError::MissingDefinition(format!(
1172 "no body for '{:?}'",
1173 name
1174 )));
1175 }
1176 }
1177 Ok(())
1178 }
1179 #[allow(dead_code)]
1181 pub fn set_univ_params(&mut self, params: Vec<Name>) {
1182 self.univ_params = params;
1183 }
1184 #[allow(dead_code)]
1186 pub fn set_attrs(&mut self, name: &Name, attrs: Vec<AttributeKind>) {
1187 self.attrs.insert(name.clone(), attrs);
1188 }
1189 #[allow(dead_code)]
1191 pub fn set_noncomputable(&mut self, name: &Name, noncomputable: bool) {
1192 self.is_noncomputable.insert(name.clone(), noncomputable);
1193 }
1194 #[allow(dead_code)]
1196 pub fn is_def_noncomputable(&self, name: &Name) -> bool {
1197 self.is_noncomputable.get(name).copied().unwrap_or(false)
1198 }
1199 #[allow(dead_code)]
1201 pub fn get_attrs(&self, name: &Name) -> &[AttributeKind] {
1202 self.attrs.get(name).map(|v| v.as_slice()).unwrap_or(&[])
1203 }
1204}
1205#[derive(Clone, Debug)]
1210pub struct StructuralRecursion {
1211 pub block: MutualBlock,
1213 pub recursive_args: HashMap<Name, Vec<usize>>,
1215}
1216impl StructuralRecursion {
1217 #[allow(dead_code)]
1219 pub fn new(block: MutualBlock) -> Self {
1220 Self {
1221 block,
1222 recursive_args: HashMap::new(),
1223 }
1224 }
1225 #[allow(dead_code)]
1227 pub fn detect_structural_recursion(&mut self) -> Result<(), MutualElabError> {
1228 let call_graph = CallGraph::build_from_block(&self.block);
1229 for name in &self.block.names {
1230 if call_graph.is_self_recursive(name) {
1231 match call_graph.find_decreasing_arg(name) {
1232 Some(idx) => {
1233 self.recursive_args
1234 .entry(name.clone())
1235 .or_default()
1236 .push(idx);
1237 }
1238 None => {
1239 return Err(MutualElabError::TerminationFailure(format!(
1240 "could not find structurally decreasing argument for '{:?}'",
1241 name
1242 )));
1243 }
1244 }
1245 }
1246 }
1247 Ok(())
1248 }
1249 #[allow(dead_code)]
1254 pub fn encode_as_recursor_application(&self) -> Result<MutualBlock, MutualElabError> {
1255 let mut result = self.block.clone();
1256 let call_graph = CallGraph::build_from_block(&self.block);
1257 for name in &self.block.names {
1258 if call_graph.is_self_recursive(name) && !self.recursive_args.contains_key(name) {
1259 return Err(MutualElabError::TerminationFailure(format!(
1260 "no structural recursion info for '{:?}'",
1261 name
1262 )));
1263 }
1264 }
1265 for (name, args) in &self.recursive_args {
1266 let attr_name = format!(
1267 "_rec_arg_{}",
1268 args.iter()
1269 .map(|i| i.to_string())
1270 .collect::<Vec<_>>()
1271 .join("_")
1272 );
1273 result
1274 .attrs
1275 .entry(name.clone())
1276 .or_default()
1277 .push(AttributeKind::Custom(attr_name));
1278 }
1279 Ok(result)
1280 }
1281 #[allow(dead_code)]
1283 pub fn get_recursive_args(&self) -> &HashMap<Name, Vec<usize>> {
1284 &self.recursive_args
1285 }
1286}