Skip to main content

oxilean_elab/mutual/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::functions::*;
6use oxilean_kernel::{Declaration, Expr, Name, ReducibilityHint};
7use oxilean_parse::AttributeKind;
8use std::collections::{HashMap, HashSet};
9
/// How an argument passed in a recursive call relates to the caller's own
/// parameter in the same position.
///
/// Produced by a purely syntactic classification of call arguments; it is a
/// heuristic input to termination checking, not a proof.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ArgRelation {
    /// The argument is classified as the same as the corresponding parameter.
    Equal,
    /// The argument is classified as structurally smaller than the parameter.
    Smaller,
    /// No relation could be established.
    Unknown,
}
20/// Call graph tracking recursive calls in a mutual block.
21#[derive(Clone, Debug, Default)]
22pub struct CallGraph {
23    /// Map from caller name to list of recursive calls it makes
24    calls: HashMap<Name, Vec<RecursiveCall>>,
25    /// All function names in the mutual block
26    names: Vec<Name>,
27}
28impl CallGraph {
29    /// Build a call graph from a mutual block.
30    ///
31    /// This performs a conservative syntactic analysis of function bodies
32    /// to detect recursive calls to functions in the mutual block.
33    #[allow(dead_code)]
34    pub fn build_from_block(block: &MutualBlock) -> Self {
35        let mut calls: HashMap<Name, Vec<RecursiveCall>> = HashMap::new();
36        let block_names: HashSet<Name> = block.names.iter().cloned().collect();
37        for name in &block.names {
38            let mut func_calls = Vec::new();
39            if let Some(body) = block.get_body(name) {
40                Self::collect_calls(name, body, &block_names, &mut func_calls);
41            }
42            calls.insert(name.clone(), func_calls);
43        }
44        Self {
45            calls,
46            names: block.names.clone(),
47        }
48    }
49    /// Recursively collect calls from an expression.
50    /// Peel a curried application into (head, [arg0, arg1, ...]).
51    fn peel_app(expr: &Expr) -> (&Expr, Vec<&Expr>) {
52        let mut args = Vec::new();
53        let mut cur = expr;
54        while let Expr::App(f, a) = cur {
55            args.push(a.as_ref());
56            cur = f.as_ref();
57        }
58        args.reverse();
59        (cur, args)
60    }
61    fn collect_calls(
62        caller: &Name,
63        expr: &Expr,
64        block_names: &HashSet<Name>,
65        out: &mut Vec<RecursiveCall>,
66    ) {
67        match expr {
68            Expr::App(func, arg) => {
69                let (head, all_args) = Self::peel_app(expr);
70                if let Some(callee_name) = Self::get_const_head(head) {
71                    if block_names.contains(&callee_name) {
72                        let relations: Vec<ArgRelation> =
73                            all_args.iter().map(|a| Self::classify_arg(a)).collect();
74                        out.push(RecursiveCall {
75                            caller: caller.clone(),
76                            callee: callee_name,
77                            args: relations,
78                        });
79                        for a in &all_args {
80                            Self::collect_calls(caller, a, block_names, out);
81                        }
82                        return;
83                    }
84                }
85                Self::collect_calls(caller, func, block_names, out);
86                Self::collect_calls(caller, arg, block_names, out);
87            }
88            Expr::Lam(_, _, ty, body) => {
89                Self::collect_calls(caller, ty, block_names, out);
90                Self::collect_calls(caller, body, block_names, out);
91            }
92            Expr::Pi(_, _, ty, body) => {
93                Self::collect_calls(caller, ty, block_names, out);
94                Self::collect_calls(caller, body, block_names, out);
95            }
96            Expr::Let(_, ty, val, body) => {
97                Self::collect_calls(caller, ty, block_names, out);
98                Self::collect_calls(caller, val, block_names, out);
99                Self::collect_calls(caller, body, block_names, out);
100            }
101            Expr::Proj(_, _, base) => {
102                Self::collect_calls(caller, base, block_names, out);
103            }
104            _ => {}
105        }
106    }
107    /// Extract the constant name from the head of a (possibly nested) application.
108    fn get_const_head(expr: &Expr) -> Option<Name> {
109        match expr {
110            Expr::Const(name, _) => Some(name.clone()),
111            Expr::App(func, _) => Self::get_const_head(func),
112            _ => None,
113        }
114    }
115    /// Classify an argument expression as Equal, Smaller, or Unknown.
116    fn classify_arg(expr: &Expr) -> ArgRelation {
117        match expr {
118            Expr::BVar(_) => ArgRelation::Equal,
119            Expr::Proj(_, _, base) => {
120                if matches!(base.as_ref(), Expr::BVar(_)) {
121                    ArgRelation::Smaller
122                } else {
123                    ArgRelation::Unknown
124                }
125            }
126            Expr::App(func, _) => {
127                if matches!(func.as_ref(), Expr::BVar(_)) {
128                    ArgRelation::Smaller
129                } else {
130                    ArgRelation::Unknown
131                }
132            }
133            _ => ArgRelation::Unknown,
134        }
135    }
136    /// Check if a specific argument position is structurally decreasing
137    /// for a given function.
138    #[allow(dead_code)]
139    pub fn is_structurally_decreasing(&self, name: &Name, arg_idx: usize) -> bool {
140        if let Some(func_calls) = self.calls.get(name) {
141            if func_calls.is_empty() {
142                return true;
143            }
144            let mut has_smaller = false;
145            for call in func_calls {
146                if call.callee == *name {
147                    match call.args.get(arg_idx) {
148                        Some(ArgRelation::Smaller) => has_smaller = true,
149                        Some(ArgRelation::Equal) => {}
150                        _ => return false,
151                    }
152                }
153            }
154            has_smaller || func_calls.iter().all(|c| c.callee != *name)
155        } else {
156            false
157        }
158    }
159    /// Find the first argument index that is structurally decreasing.
160    #[allow(dead_code)]
161    pub fn find_decreasing_arg(&self, name: &Name) -> Option<usize> {
162        let max_args = self
163            .calls
164            .get(name)
165            .map(|calls| {
166                calls
167                    .iter()
168                    .filter(|c| &c.callee == name)
169                    .map(|c| c.args.len())
170                    .max()
171                    .unwrap_or(0)
172            })
173            .unwrap_or(0);
174        let n = if max_args == 0 { 1 } else { max_args };
175        (0..n).find(|&idx| self.is_structurally_decreasing(name, idx))
176    }
177    /// Check if the block is mutually recursive (cross-function calls exist).
178    #[allow(dead_code)]
179    pub fn is_mutually_recursive(&self) -> bool {
180        for (caller, func_calls) in &self.calls {
181            for call in func_calls {
182                if &call.callee != caller {
183                    return true;
184                }
185            }
186        }
187        false
188    }
189    /// Check if any function in the block is self-recursive.
190    #[allow(dead_code)]
191    pub fn is_self_recursive(&self, name: &Name) -> bool {
192        self.calls
193            .get(name)
194            .map(|cs| cs.iter().any(|c| c.callee == *name))
195            .unwrap_or(false)
196    }
197    /// Check if any function in the block is recursive at all.
198    #[allow(dead_code)]
199    pub fn is_recursive(&self) -> bool {
200        for func_calls in self.calls.values() {
201            if !func_calls.is_empty() {
202                return true;
203            }
204        }
205        false
206    }
207    /// Get all calls made by a specific function.
208    #[allow(dead_code)]
209    pub fn get_calls(&self, name: &Name) -> &[RecursiveCall] {
210        self.calls.get(name).map(|v| v.as_slice()).unwrap_or(&[])
211    }
212    /// Compute strongly connected components using Tarjan's algorithm.
213    #[allow(dead_code)]
214    pub fn strongly_connected_components(&self) -> Vec<Vec<Name>> {
215        let n = self.names.len();
216        if n == 0 {
217            return Vec::new();
218        }
219        let name_to_idx: HashMap<Name, usize> = self
220            .names
221            .iter()
222            .enumerate()
223            .map(|(i, name)| (name.clone(), i))
224            .collect();
225        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
226        for (caller, func_calls) in &self.calls {
227            if let Some(&ci) = name_to_idx.get(caller) {
228                for call in func_calls {
229                    if let Some(&cj) = name_to_idx.get(&call.callee) {
230                        if !adj[ci].contains(&cj) {
231                            adj[ci].push(cj);
232                        }
233                    }
234                }
235            }
236        }
237        let mut index_counter: usize = 0;
238        let mut stack: Vec<usize> = Vec::new();
239        let mut on_stack = vec![false; n];
240        let mut indices: Vec<Option<usize>> = vec![None; n];
241        let mut lowlinks = vec![0usize; n];
242        let mut result: Vec<Vec<Name>> = Vec::new();
243        for v in 0..n {
244            if indices[v].is_none() {
245                Self::tarjan_visit(
246                    v,
247                    &adj,
248                    &mut index_counter,
249                    &mut stack,
250                    &mut on_stack,
251                    &mut indices,
252                    &mut lowlinks,
253                    &mut result,
254                    &self.names,
255                );
256            }
257        }
258        result
259    }
260    /// Tarjan DFS visit helper.
261    #[allow(clippy::too_many_arguments)]
262    fn tarjan_visit(
263        v: usize,
264        adj: &[Vec<usize>],
265        index_counter: &mut usize,
266        stack: &mut Vec<usize>,
267        on_stack: &mut Vec<bool>,
268        indices: &mut Vec<Option<usize>>,
269        lowlinks: &mut Vec<usize>,
270        result: &mut Vec<Vec<Name>>,
271        names: &[Name],
272    ) {
273        indices[v] = Some(*index_counter);
274        lowlinks[v] = *index_counter;
275        *index_counter += 1;
276        stack.push(v);
277        on_stack[v] = true;
278        for &w in &adj[v] {
279            if indices[w].is_none() {
280                Self::tarjan_visit(
281                    w,
282                    adj,
283                    index_counter,
284                    stack,
285                    on_stack,
286                    indices,
287                    lowlinks,
288                    result,
289                    names,
290                );
291                lowlinks[v] = lowlinks[v].min(lowlinks[w]);
292            } else if on_stack[w] {
293                lowlinks[v] =
294                    lowlinks[v].min(indices[w].expect("w is on stack so indices[w] is set"));
295            }
296        }
297        if lowlinks[v] == indices[v].expect("v was just assigned an index above") {
298            let mut component = Vec::new();
299            loop {
300                let w = stack
301                    .pop()
302                    .expect("stack is non-empty: v is always on it when we reach the SCC root");
303                on_stack[w] = false;
304                component.push(names[w].clone());
305                if w == v {
306                    break;
307                }
308            }
309            result.push(component);
310        }
311    }
312}
313/// Encoder for well-founded recursive definitions.
314///
315/// Used when structural recursion cannot be established. Requires
316/// a well-founded relation and a measure function.
317#[derive(Clone, Debug)]
318pub struct WellFoundedRecursion {
319    /// The mutual block being processed
320    pub block: MutualBlock,
321    /// Optional measure function name
322    pub measure: Option<Name>,
323    /// Optional well-founded relation expression
324    pub rel: Option<Expr>,
325    /// For each function, which argument indices decrease
326    pub decreasing_args: HashMap<Name, Vec<usize>>,
327}
328impl WellFoundedRecursion {
329    /// Create a new well-founded recursion encoder.
330    #[allow(dead_code)]
331    pub fn new(block: MutualBlock) -> Self {
332        Self {
333            block,
334            measure: None,
335            rel: None,
336            decreasing_args: HashMap::new(),
337        }
338    }
339    /// Set the measure function.
340    #[allow(dead_code)]
341    pub fn set_measure(&mut self, name: Name) {
342        self.measure = Some(name);
343    }
344    /// Set the well-founded relation.
345    #[allow(dead_code)]
346    pub fn set_relation(&mut self, rel: Expr) {
347        self.rel = Some(rel);
348    }
349    /// Detect which arguments are decreasing under the given measure.
350    ///
351    /// Uses structural analysis via `CallGraph` to find which argument positions
352    /// are structurally decreasing. Falls back to argument 0 if none can be
353    /// determined (e.g. for well-founded recursion with an explicit measure).
354    #[allow(dead_code)]
355    pub fn detect_decreasing_args(&mut self) -> Result<(), MutualElabError> {
356        let call_graph = CallGraph::build_from_block(&self.block);
357        for name in &self.block.names {
358            if call_graph.is_self_recursive(name) || call_graph.is_mutually_recursive() {
359                let dec_idx = call_graph.find_decreasing_arg(name).unwrap_or(0);
360                self.decreasing_args
361                    .entry(name.clone())
362                    .or_default()
363                    .push(dec_idx);
364            }
365        }
366        Ok(())
367    }
368    /// Encode the definitions using well-founded recursion.
369    ///
370    /// Transforms the mutual block into a form that uses `WellFounded.fix`
371    /// (or equivalent) to justify termination.
372    #[allow(dead_code)]
373    pub fn encode_as_wf_recursion(&self) -> Result<MutualBlock, MutualElabError> {
374        if self.measure.is_none() && self.rel.is_none() {
375            return Err(MutualElabError::TerminationFailure(
376                "well-founded recursion requires a measure or relation".to_string(),
377            ));
378        }
379        let mut result = self.block.clone();
380        let wf_rel: Expr = match (&self.measure, &self.rel) {
381            (Some(m), _) => Expr::App(
382                Box::new(Expr::Const(Name::str("Measure"), vec![])),
383                Box::new(Expr::Const(m.clone(), vec![])),
384            ),
385            (None, Some(r)) => r.clone(),
386            (None, None) => unreachable!("checked above"),
387        };
388        let wf_proof = Expr::App(
389            Box::new(Expr::Const(Name::str("WellFounded.wf"), vec![])),
390            Box::new(wf_rel.clone()),
391        );
392        let call_graph = CallGraph::build_from_block(&self.block);
393        for name in &self.block.names {
394            if !call_graph.is_self_recursive(name) {
395                continue;
396            }
397            if let Some(body) = self.block.get_body(name) {
398                let dec_idx = self
399                    .decreasing_args
400                    .get(name)
401                    .and_then(|v| v.first())
402                    .copied()
403                    .unwrap_or(0);
404                let rec_ty = self
405                    .block
406                    .types
407                    .get(name)
408                    .cloned()
409                    .unwrap_or(Expr::Const(Name::str("_"), vec![]));
410                let step = Expr::Lam(
411                    oxilean_kernel::BinderInfo::Default,
412                    name.clone(),
413                    Box::new(rec_ty),
414                    Box::new(body.clone()),
415                );
416                let init_arg = Expr::BVar(dec_idx as u32);
417                let wrapped = Expr::App(
418                    Box::new(Expr::App(
419                        Box::new(Expr::App(
420                            Box::new(Expr::Const(Name::str("WellFounded.fix"), vec![])),
421                            Box::new(wf_proof.clone()),
422                        )),
423                        Box::new(step),
424                    )),
425                    Box::new(init_arg),
426                );
427                result.bodies.insert(name.clone(), wrapped);
428            }
429            result
430                .attrs
431                .entry(name.clone())
432                .or_default()
433                .push(AttributeKind::Custom("_wf_rec".to_string()));
434        }
435        Ok(result)
436    }
437    /// Generate a termination proof obligation.
438    ///
439    /// Returns an expression representing the proof obligation
440    /// that all recursive calls decrease under the measure.
441    #[allow(dead_code)]
442    pub fn generate_termination_proof(&self) -> Result<Expr, MutualElabError> {
443        if self.measure.is_some() || self.rel.is_some() {
444            Ok(Expr::Const(Name::str("sorry"), vec![]))
445        } else {
446            Err(MutualElabError::TerminationFailure(
447                "no measure or relation provided".to_string(),
448            ))
449        }
450    }
451}
452/// A dependency graph over declaration names for cycle detection.
453#[derive(Clone, Debug, Default)]
454pub struct DeclDependencyGraph {
455    /// Ordered list of declaration names.
456    names: Vec<Name>,
457    /// Adjacency list: `edges[i]` contains indices j such that decl[i] calls decl[j].
458    edges: Vec<Vec<usize>>,
459}
460impl DeclDependencyGraph {
461    /// Create an empty graph.
462    pub fn new() -> Self {
463        Self::default()
464    }
465    /// Add a declaration node.  Returns its index.
466    pub fn add_node(&mut self, name: Name) -> usize {
467        let idx = self.names.len();
468        self.names.push(name);
469        self.edges.push(Vec::new());
470        idx
471    }
472    /// Add a directed dependency edge `from -> to`.
473    pub fn add_edge(&mut self, from: usize, to: usize) {
474        if !self.edges[from].contains(&to) {
475            self.edges[from].push(to);
476        }
477    }
478    /// Return the index of a declaration by name.
479    pub fn index_of(&self, name: &Name) -> Option<usize> {
480        self.names.iter().position(|n| n == name)
481    }
482    /// Compute all SCCs.
483    pub fn sccs(&self) -> Vec<Vec<Name>> {
484        let raw = tarjan_scc(self.names.len(), &self.edges);
485        raw.into_iter()
486            .map(|scc| scc.iter().map(|&i| self.names[i].clone()).collect())
487            .collect()
488    }
489    /// Return `true` if any SCC has more than one node (true cycle).
490    pub fn has_cycle(&self) -> bool {
491        self.sccs().iter().any(|scc| scc.len() > 1)
492    }
493    /// Return all cyclic SCCs (those with more than one node).
494    pub fn cyclic_sccs(&self) -> Vec<Vec<Name>> {
495        self.sccs().into_iter().filter(|s| s.len() > 1).collect()
496    }
497    /// Topological order of nodes (only valid when `has_cycle()` is false).
498    pub fn topological_order(&self) -> Vec<Name> {
499        let sccs = self.sccs();
500        sccs.into_iter().flatten().collect()
501    }
502    /// Number of nodes in the graph.
503    pub fn num_nodes(&self) -> usize {
504        self.names.len()
505    }
506}
507/// High-level cycle detector for mutual definition groups.
508///
509/// Given a set of definitions and their direct call dependencies, this
510/// struct computes which definitions form genuine mutual recursion cycles
511/// and which are merely co-defined but not mutually recursive.
512#[derive(Debug, Default)]
513pub struct MutualDefCycleDetector {
514    graph: DeclDependencyGraph,
515}
516impl MutualDefCycleDetector {
517    /// Create an empty detector.
518    pub fn new() -> Self {
519        Self::default()
520    }
521    /// Register a declaration name.  Returns its graph index.
522    pub fn register(&mut self, name: Name) -> usize {
523        self.graph.add_node(name)
524    }
525    /// Declare that `caller` directly uses `callee`.
526    ///
527    /// Both names must have been registered first.
528    pub fn add_dependency(&mut self, caller: &Name, callee: &Name) -> bool {
529        match (self.graph.index_of(caller), self.graph.index_of(callee)) {
530            (Some(from), Some(to)) => {
531                self.graph.add_edge(from, to);
532                true
533            }
534            _ => false,
535        }
536    }
537    /// Check whether there are any non-trivial mutual recursion cycles.
538    pub fn has_mutual_recursion(&self) -> bool {
539        self.graph.has_cycle()
540    }
541    /// Return all groups of mutually recursive declarations.
542    pub fn mutual_groups(&self) -> Vec<Vec<Name>> {
543        self.graph.cyclic_sccs()
544    }
545    /// Return the topological order for the declarations (deepest dependencies
546    /// first), only meaningful if no cycles exist.
547    pub fn elaboration_order(&self) -> Vec<Name> {
548        self.graph.topological_order()
549    }
550    /// Number of registered declarations.
551    pub fn num_decls(&self) -> usize {
552        self.graph.num_nodes()
553    }
554}
/// Phases of the mutual-definition elaboration pipeline.
///
/// Variants are declared in pipeline order, so the derived `Ord` compares
/// stages by how far along the pipeline they are.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum MutualElabStage {
    /// Gathering the declared signatures.
    SigCollection,
    /// Building the dependency graph and its SCCs.
    DependencyAnalysis,
    /// Elaborating the definition bodies.
    BodyElab,
    /// Checking that the recursion terminates.
    TerminationCheck,
    /// Post-processing (well-founded encoding, adding to the environment).
    PostProcess,
    /// Pipeline finished.
    Done,
}
571/// A partially-known signature for a mutually recursive function.
572#[allow(dead_code)]
573#[derive(Clone, Debug)]
574pub struct PartialSig {
575    /// Function name.
576    pub name: Name,
577    /// Declared type (if any).
578    pub declared_type: Option<Expr>,
579    /// Inferred type (filled in during elaboration).
580    pub inferred_type: Option<Expr>,
581    /// Whether the signature is fully resolved.
582    pub resolved: bool,
583}
584#[allow(dead_code)]
585impl PartialSig {
586    /// Create a new partial signature with only a name.
587    pub fn new(name: Name) -> Self {
588        Self {
589            name,
590            declared_type: None,
591            inferred_type: None,
592            resolved: false,
593        }
594    }
595    /// Mark the signature as resolved.
596    pub fn resolve(&mut self, ty: Expr) {
597        self.inferred_type = Some(ty);
598        self.resolved = true;
599    }
600    /// Return the best available type (inferred > declared > None).
601    pub fn best_type(&self) -> Option<&Expr> {
602        self.inferred_type.as_ref().or(self.declared_type.as_ref())
603    }
604}
605/// A budget for mutual elaboration (limits recursion/unfolding).
606#[allow(dead_code)]
607#[derive(Clone, Debug)]
608pub struct MutualElabBudget {
609    /// Maximum allowed SCC size.
610    pub max_scc_size: usize,
611    /// Maximum depth for termination checking.
612    pub max_termination_depth: usize,
613    /// Maximum number of structural recursion arguments to check.
614    pub max_structural_args: usize,
615    /// Maximum number of refinement iterations.
616    pub max_refinements: usize,
617}
618#[allow(dead_code)]
619impl MutualElabBudget {
620    /// Create a budget with default limits.
621    pub fn new() -> Self {
622        Self::default()
623    }
624    /// Create a liberal budget (for debugging).
625    pub fn liberal() -> Self {
626        Self {
627            max_scc_size: 256,
628            max_termination_depth: 1024,
629            max_structural_args: 64,
630            max_refinements: 32,
631        }
632    }
633    /// Create a strict budget (for fast pre-checks).
634    pub fn strict() -> Self {
635        Self {
636            max_scc_size: 8,
637            max_termination_depth: 32,
638            max_structural_args: 4,
639            max_refinements: 2,
640        }
641    }
642    /// Check if an SCC of size `n` is within budget.
643    pub fn allows_scc_size(&self, n: usize) -> bool {
644        n <= self.max_scc_size
645    }
646    /// Check if a termination depth `d` is within budget.
647    pub fn allows_termination_depth(&self, d: usize) -> bool {
648        d <= self.max_termination_depth
649    }
650}
651/// A single recursive call found in a function body.
652#[derive(Clone, Debug)]
653pub struct RecursiveCall {
654    /// Name of the calling function
655    pub caller: Name,
656    /// Name of the called function
657    pub callee: Name,
658    /// Relation of each argument to the corresponding parameter
659    pub args: Vec<ArgRelation>,
660}
/// Bookkeeping record for one vertex in Tarjan's SCC algorithm.
#[derive(Clone, Debug, Default)]
pub struct TarjanNode {
    /// DFS discovery index.
    pub index: usize,
    /// Smallest discovery index reachable from this vertex.
    pub lowlink: usize,
    /// Whether the vertex is currently on the Tarjan stack.
    pub on_stack: bool,
    /// Whether the vertex has been visited.
    pub discovered: bool,
}
669/// Collection of partial signatures for a mutually recursive group.
670#[allow(dead_code)]
671#[derive(Clone, Debug, Default)]
672pub struct MutualSigCollection {
673    sigs: Vec<PartialSig>,
674}
675#[allow(dead_code)]
676impl MutualSigCollection {
677    /// Create an empty collection.
678    pub fn new() -> Self {
679        Self::default()
680    }
681    /// Add a partial signature.
682    pub fn add(&mut self, sig: PartialSig) {
683        self.sigs.push(sig);
684    }
685    /// Return the number of signatures.
686    pub fn len(&self) -> usize {
687        self.sigs.len()
688    }
689    /// Return true if the collection is empty.
690    pub fn is_empty(&self) -> bool {
691        self.sigs.is_empty()
692    }
693    /// Return the number of resolved signatures.
694    pub fn num_resolved(&self) -> usize {
695        self.sigs.iter().filter(|s| s.resolved).count()
696    }
697    /// Return true if all signatures are resolved.
698    pub fn all_resolved(&self) -> bool {
699        self.sigs.iter().all(|s| s.resolved)
700    }
701    /// Look up a signature by name.
702    pub fn get(&self, name: &Name) -> Option<&PartialSig> {
703        self.sigs.iter().find(|s| &s.name == name)
704    }
705    /// Mutably look up a signature by name.
706    pub fn get_mut(&mut self, name: &Name) -> Option<&mut PartialSig> {
707        self.sigs.iter_mut().find(|s| &s.name == name)
708    }
709    /// Iterate over all signatures.
710    pub fn iter(&self) -> impl Iterator<Item = &PartialSig> {
711        self.sigs.iter()
712    }
713}
714/// Mutual recursion checker and elaborator.
715pub struct MutualChecker {
716    /// Current mutual block being checked
717    current_block: Option<MutualBlock>,
718}
719impl MutualChecker {
720    /// Create a new mutual checker.
721    pub fn new() -> Self {
722        Self {
723            current_block: None,
724        }
725    }
726    /// Start a new mutual block.
727    pub fn start_block(&mut self) {
728        self.current_block = Some(MutualBlock::new());
729    }
730    /// Add a definition to the current block.
731    pub fn add_def(&mut self, name: Name, ty: Expr, body: Expr) -> Result<(), String> {
732        if let Some(block) = &mut self.current_block {
733            block.add(name, ty, body);
734            Ok(())
735        } else {
736            Err("No mutual block started".to_string())
737        }
738    }
739    /// Finish the current mutual block.
740    pub fn finish_block(&mut self) -> Result<MutualBlock, String> {
741        self.current_block
742            .take()
743            .ok_or_else(|| "No mutual block to finish".to_string())
744    }
745    /// Get the current block (if any).
746    pub fn current_block(&self) -> Option<&MutualBlock> {
747        self.current_block.as_ref()
748    }
749    /// Check well-formedness of a mutual block.
750    ///
751    /// Validates:
752    /// - All types are present
753    /// - All bodies are present
754    /// - No duplicate names
755    #[allow(dead_code)]
756    pub fn check_well_formedness(block: &MutualBlock) -> Result<(), MutualElabError> {
757        block.validate()?;
758        let block_names: HashSet<Name> = block.names.iter().cloned().collect();
759        for name in &block.names {
760            if let Some(ty) = block.get_type(name) {
761                Self::check_no_external_forward_refs(ty, &block_names)?;
762            }
763        }
764        Ok(())
765    }
766    /// Check that an expression does not reference undefined names
767    /// outside the mutual block.
768    fn check_no_external_forward_refs(
769        expr: &Expr,
770        block_names: &HashSet<Name>,
771    ) -> Result<(), MutualElabError> {
772        match expr {
773            Expr::Const(name, _) => {
774                let _ = block_names.contains(name);
775                Ok(())
776            }
777            Expr::App(f, a) => {
778                Self::check_no_external_forward_refs(f, block_names)?;
779                Self::check_no_external_forward_refs(a, block_names)?;
780                Ok(())
781            }
782            Expr::Lam(_, _, ty, body) | Expr::Pi(_, _, ty, body) => {
783                Self::check_no_external_forward_refs(ty, block_names)?;
784                Self::check_no_external_forward_refs(body, block_names)?;
785                Ok(())
786            }
787            Expr::Let(_, ty, val, body) => {
788                Self::check_no_external_forward_refs(ty, block_names)?;
789                Self::check_no_external_forward_refs(val, block_names)?;
790                Self::check_no_external_forward_refs(body, block_names)?;
791                Ok(())
792            }
793            Expr::Proj(_, _, base) => {
794                Self::check_no_external_forward_refs(base, block_names)?;
795                Ok(())
796            }
797            _ => Ok(()),
798        }
799    }
800    /// Check termination of a mutual block.
801    ///
802    /// Determines if the definitions are:
803    /// - Non-recursive
804    /// - Structurally recursive
805    /// - Requiring well-founded recursion
806    #[allow(dead_code)]
807    pub fn check_termination(block: &MutualBlock) -> Result<TerminationKind, MutualElabError> {
808        let call_graph = CallGraph::build_from_block(block);
809        if !call_graph.is_recursive() {
810            return Ok(TerminationKind::NonRecursive);
811        }
812        let mut structural_args = HashMap::new();
813        let mut all_structural = true;
814        for name in &block.names {
815            if call_graph.is_self_recursive(name) || call_graph.is_mutually_recursive() {
816                match call_graph.find_decreasing_arg(name) {
817                    Some(idx) => {
818                        structural_args.insert(name.clone(), idx);
819                    }
820                    None => {
821                        all_structural = false;
822                        break;
823                    }
824                }
825            }
826        }
827        if all_structural && !structural_args.is_empty() {
828            return Ok(TerminationKind::Structural(structural_args));
829        }
830        Ok(TerminationKind::WellFounded)
831    }
832    /// Elaborate a set of mutual definitions.
833    ///
834    /// This is the main entry point for mutual definition elaboration:
835    /// 1. Forward-declare all names
836    /// 2. Elaborate types
837    /// 3. Elaborate bodies in extended context
838    /// 4. Check termination
839    #[allow(dead_code)]
840    pub fn elaborate_mutual_defs(
841        names: &[Name],
842        types: &[Expr],
843        bodies: &[Expr],
844    ) -> Result<MutualBlock, MutualElabError> {
845        if names.len() != types.len() || names.len() != bodies.len() {
846            return Err(MutualElabError::Other(
847                "mismatched lengths for names, types, and bodies".to_string(),
848            ));
849        }
850        if names.is_empty() {
851            return Err(MutualElabError::Other("empty mutual block".to_string()));
852        }
853        let mut block = MutualBlock::new();
854        for i in 0..names.len() {
855            block.add(names[i].clone(), types[i].clone(), bodies[i].clone());
856        }
857        block.validate()?;
858        Ok(block)
859    }
860    /// Encode recursion in a mutual block based on the termination kind.
861    #[allow(dead_code)]
862    pub fn encode_recursion(
863        block: MutualBlock,
864        kind: &TerminationKind,
865    ) -> Result<MutualBlock, MutualElabError> {
866        match kind {
867            TerminationKind::NonRecursive => Ok(block),
868            TerminationKind::Structural(_args) => {
869                let mut sr = StructuralRecursion::new(block);
870                sr.detect_structural_recursion()?;
871                sr.encode_as_recursor_application()
872            }
873            TerminationKind::WellFounded => {
874                let mut wfr = WellFoundedRecursion::new(block);
875                wfr.detect_decreasing_args()?;
876                if wfr.measure.is_none() && wfr.rel.is_none() {
877                    wfr.set_measure(Name::str("Nat.lt"));
878                }
879                wfr.encode_as_wf_recursion()
880            }
881        }
882    }
883    /// Split a mutual block into individual declarations.
884    #[allow(dead_code)]
885    pub fn split_mutual_block(block: &MutualBlock) -> Vec<Declaration> {
886        let mut decls = Vec::new();
887        for name in &block.names {
888            if let (Some(ty), Some(val)) = (block.get_type(name), block.get_body(name)) {
889                decls.push(Declaration::Definition {
890                    name: name.clone(),
891                    univ_params: block.univ_params.clone(),
892                    ty: ty.clone(),
893                    val: val.clone(),
894                    hint: ReducibilityHint::Regular(100),
895                });
896            }
897        }
898        decls
899    }
900}
/// Failure modes of mutual-definition elaboration.
///
/// Every variant carries a human-readable message describing the problem.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MutualElabError {
    /// A declared type disagrees with the inferred one.
    TypeMismatch(String),
    /// The recursion pattern is not acceptable.
    InvalidRecursion(String),
    /// A definition of the block is missing or incomplete.
    MissingDefinition(String),
    /// The types refer to each other cyclically (only permitted for inductives).
    CyclicType(String),
    /// Termination could not be established.
    TerminationFailure(String),
    /// Any failure not covered by the other variants.
    Other(String),
}
/// A well-founded ordering used to justify that recursion terminates.
///
/// Positions refer to argument indices of the recursive function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WellFoundedOrder {
    /// Lexicographic comparison of the listed arguments, in order.
    Lexicographic(Vec<usize>),
    /// A measure function applied to one argument.
    Measure(usize),
    /// Structural descent on one argument.
    Structural(usize),
    /// Multiset comparison of the listed arguments.
    Multiset(Vec<usize>),
    /// No ordering has been determined.
    Unknown,
}
931/// Tracks the progress of the mutual-elaboration pipeline.
932#[derive(Debug, Clone)]
933pub struct MutualElabProgress {
934    /// Names being elaborated.
935    pub names: Vec<Name>,
936    /// Current stage.
937    pub stage: MutualElabStage,
938    /// Stages that have been completed.
939    pub completed: Vec<MutualElabStage>,
940    /// Any error encountered.
941    pub error: Option<MutualElabError>,
942}
943impl MutualElabProgress {
944    /// Create progress for the given names starting at `SigCollection`.
945    pub fn new(names: Vec<Name>) -> Self {
946        Self {
947            names,
948            stage: MutualElabStage::SigCollection,
949            completed: Vec::new(),
950            error: None,
951        }
952    }
953    /// Advance to the next stage.
954    pub fn advance(&mut self) {
955        let next = match self.stage {
956            MutualElabStage::SigCollection => MutualElabStage::DependencyAnalysis,
957            MutualElabStage::DependencyAnalysis => MutualElabStage::BodyElab,
958            MutualElabStage::BodyElab => MutualElabStage::TerminationCheck,
959            MutualElabStage::TerminationCheck => MutualElabStage::PostProcess,
960            MutualElabStage::PostProcess => MutualElabStage::Done,
961            MutualElabStage::Done => MutualElabStage::Done,
962        };
963        self.completed.push(self.stage);
964        self.stage = next;
965    }
966    /// Mark the elaboration as failed with the given error.
967    pub fn fail(&mut self, err: MutualElabError) {
968        self.error = Some(err);
969        self.stage = MutualElabStage::Done;
970    }
971    /// Return `true` if elaboration has completed (either successfully or with error).
972    pub fn is_done(&self) -> bool {
973        self.stage == MutualElabStage::Done
974    }
975    /// Return `true` if elaboration succeeded (done, no error).
976    pub fn is_success(&self) -> bool {
977        self.is_done() && self.error.is_none()
978    }
979}
980/// A summary of the mutual recursion analysis for a block of definitions.
981#[derive(Clone, Debug)]
982pub struct MutualRecursionSummary {
983    /// Names in this block.
984    pub names: Vec<Name>,
985    /// Whether the block contains genuine mutual recursion.
986    pub is_mutually_recursive: bool,
987    /// Detected mutual groups.
988    pub mutual_groups: Vec<Vec<Name>>,
989    /// Inferred termination measure, if any.
990    pub termination_measure: Option<TerminationMeasure>,
991    /// Diagnostics accumulated during analysis.
992    pub diagnostics: Vec<String>,
993}
994impl MutualRecursionSummary {
995    /// Create a summary from a cycle detector and a termination measure.
996    pub fn from_detector(
997        detector: &MutualDefCycleDetector,
998        measure: Option<TerminationMeasure>,
999    ) -> Self {
1000        let groups = detector.mutual_groups();
1001        let is_mutual = !groups.is_empty();
1002        Self {
1003            names: (0..detector.num_decls())
1004                .filter_map(|i| detector.graph.names.get(i).cloned())
1005                .collect(),
1006            is_mutually_recursive: is_mutual,
1007            mutual_groups: groups,
1008            termination_measure: measure,
1009            diagnostics: Vec::new(),
1010        }
1011    }
1012    /// Add a diagnostic message.
1013    pub fn add_diagnostic(&mut self, msg: impl Into<String>) {
1014        self.diagnostics.push(msg.into());
1015    }
1016    /// Return `true` if the analysis produced any diagnostics.
1017    pub fn has_diagnostics(&self) -> bool {
1018        !self.diagnostics.is_empty()
1019    }
1020}
1021/// A heuristic termination measure inference result.
1022#[derive(Clone, Debug)]
1023pub struct TerminationMeasure {
1024    /// The inferred well-founded ordering.
1025    pub order: WellFoundedOrder,
1026    /// Confidence score in [0, 1].
1027    pub confidence: f64,
1028    /// Human-readable justification.
1029    pub justification: String,
1030}
1031impl TerminationMeasure {
1032    /// Create a measure with full confidence.
1033    pub fn certain(order: WellFoundedOrder, justification: impl Into<String>) -> Self {
1034        Self {
1035            order,
1036            confidence: 1.0,
1037            justification: justification.into(),
1038        }
1039    }
1040    /// Create a measure with partial confidence.
1041    pub fn heuristic(
1042        order: WellFoundedOrder,
1043        confidence: f64,
1044        justification: impl Into<String>,
1045    ) -> Self {
1046        Self {
1047            order,
1048            confidence,
1049            justification: justification.into(),
1050        }
1051    }
1052    /// Return `true` if this measure is judged reliable.
1053    pub fn is_reliable(&self) -> bool {
1054        self.confidence >= 0.8
1055    }
1056}
1057/// How a recursive definition terminates.
1058#[derive(Clone, Debug, PartialEq, Eq)]
1059pub enum TerminationKind {
1060    /// Structurally recursive on the given argument index per function
1061    Structural(HashMap<Name, usize>),
1062    /// Well-founded recursion (requires a measure/relation)
1063    WellFounded,
1064    /// Not recursive at all
1065    NonRecursive,
1066}
1067/// A mutually recursive block of definitions.
1068#[derive(Debug, Clone)]
1069pub struct MutualBlock {
1070    /// Names of all definitions in this mutual block (in declaration order)
1071    pub names: Vec<Name>,
1072    /// Types of all definitions
1073    pub types: HashMap<Name, Expr>,
1074    /// Bodies of all definitions
1075    pub bodies: HashMap<Name, Expr>,
1076    /// Universe parameters shared by all definitions
1077    pub univ_params: Vec<Name>,
1078    /// Attributes per definition
1079    pub attrs: HashMap<Name, Vec<AttributeKind>>,
1080    /// Whether each definition is noncomputable
1081    pub is_noncomputable: HashMap<Name, bool>,
1082}
1083impl MutualBlock {
1084    /// Create a new mutual block.
1085    pub fn new() -> Self {
1086        Self {
1087            names: Vec::new(),
1088            types: HashMap::new(),
1089            bodies: HashMap::new(),
1090            univ_params: Vec::new(),
1091            attrs: HashMap::new(),
1092            is_noncomputable: HashMap::new(),
1093        }
1094    }
1095    /// Add a definition to the mutual block.
1096    pub fn add(&mut self, name: Name, ty: Expr, body: Expr) {
1097        self.names.push(name.clone());
1098        self.types.insert(name.clone(), ty);
1099        self.bodies.insert(name, body);
1100    }
1101    /// Add a definition with attributes and noncomputable flag.
1102    #[allow(dead_code)]
1103    pub fn add_with_attrs(
1104        &mut self,
1105        name: Name,
1106        ty: Expr,
1107        body: Expr,
1108        attrs: Vec<AttributeKind>,
1109        noncomputable: bool,
1110    ) {
1111        self.names.push(name.clone());
1112        self.types.insert(name.clone(), ty);
1113        self.bodies.insert(name.clone(), body);
1114        self.attrs.insert(name.clone(), attrs);
1115        self.is_noncomputable.insert(name, noncomputable);
1116    }
1117    /// Get the type of a definition.
1118    pub fn get_type(&self, name: &Name) -> Option<&Expr> {
1119        self.types.get(name)
1120    }
1121    /// Get the body of a definition.
1122    pub fn get_body(&self, name: &Name) -> Option<&Expr> {
1123        self.bodies.get(name)
1124    }
1125    /// Get the number of definitions in this block.
1126    pub fn size(&self) -> usize {
1127        self.names.len()
1128    }
1129    /// Check if a name is in this mutual block.
1130    pub fn contains(&self, name: &Name) -> bool {
1131        self.names.contains(name)
1132    }
1133    /// Get names in declaration order.
1134    #[allow(dead_code)]
1135    pub fn names_in_order(&self) -> &[Name] {
1136        &self.names
1137    }
1138    /// Get all (name, body) pairs.
1139    #[allow(dead_code)]
1140    pub fn get_all_bodies(&self) -> Vec<(&Name, &Expr)> {
1141        self.names
1142            .iter()
1143            .filter_map(|name| self.bodies.get(name).map(|body| (name, body)))
1144            .collect()
1145    }
1146    /// Validate the mutual block: every name must have both a type and a body.
1147    #[allow(dead_code)]
1148    pub fn validate(&self) -> Result<(), MutualElabError> {
1149        if self.names.is_empty() {
1150            return Err(MutualElabError::Other("empty mutual block".to_string()));
1151        }
1152        let mut seen = HashSet::new();
1153        for name in &self.names {
1154            if !seen.insert(name.clone()) {
1155                return Err(MutualElabError::Other(format!(
1156                    "duplicate name in mutual block: {:?}",
1157                    name
1158                )));
1159            }
1160        }
1161        for name in &self.names {
1162            if !self.types.contains_key(name) {
1163                return Err(MutualElabError::MissingDefinition(format!(
1164                    "no type for '{:?}'",
1165                    name
1166                )));
1167            }
1168        }
1169        for name in &self.names {
1170            if !self.bodies.contains_key(name) {
1171                return Err(MutualElabError::MissingDefinition(format!(
1172                    "no body for '{:?}'",
1173                    name
1174                )));
1175            }
1176        }
1177        Ok(())
1178    }
1179    /// Set universe parameters for the entire block.
1180    #[allow(dead_code)]
1181    pub fn set_univ_params(&mut self, params: Vec<Name>) {
1182        self.univ_params = params;
1183    }
1184    /// Set attributes for a specific definition.
1185    #[allow(dead_code)]
1186    pub fn set_attrs(&mut self, name: &Name, attrs: Vec<AttributeKind>) {
1187        self.attrs.insert(name.clone(), attrs);
1188    }
1189    /// Mark a definition as noncomputable.
1190    #[allow(dead_code)]
1191    pub fn set_noncomputable(&mut self, name: &Name, noncomputable: bool) {
1192        self.is_noncomputable.insert(name.clone(), noncomputable);
1193    }
1194    /// Check if a definition is noncomputable.
1195    #[allow(dead_code)]
1196    pub fn is_def_noncomputable(&self, name: &Name) -> bool {
1197        self.is_noncomputable.get(name).copied().unwrap_or(false)
1198    }
1199    /// Get the attributes for a definition.
1200    #[allow(dead_code)]
1201    pub fn get_attrs(&self, name: &Name) -> &[AttributeKind] {
1202        self.attrs.get(name).map(|v| v.as_slice()).unwrap_or(&[])
1203    }
1204}
1205/// Encoder for structurally recursive definitions.
1206///
1207/// Transforms structurally recursive definitions into applications
1208/// of the recursor (eliminator) for the decreasing argument's type.
1209#[derive(Clone, Debug)]
1210pub struct StructuralRecursion {
1211    /// The mutual block being processed
1212    pub block: MutualBlock,
1213    /// For each function, which argument indices are recursive
1214    pub recursive_args: HashMap<Name, Vec<usize>>,
1215}
1216impl StructuralRecursion {
1217    /// Create a new structural recursion encoder.
1218    #[allow(dead_code)]
1219    pub fn new(block: MutualBlock) -> Self {
1220        Self {
1221            block,
1222            recursive_args: HashMap::new(),
1223        }
1224    }
1225    /// Detect which arguments are structurally decreasing.
1226    #[allow(dead_code)]
1227    pub fn detect_structural_recursion(&mut self) -> Result<(), MutualElabError> {
1228        let call_graph = CallGraph::build_from_block(&self.block);
1229        for name in &self.block.names {
1230            if call_graph.is_self_recursive(name) {
1231                match call_graph.find_decreasing_arg(name) {
1232                    Some(idx) => {
1233                        self.recursive_args
1234                            .entry(name.clone())
1235                            .or_default()
1236                            .push(idx);
1237                    }
1238                    None => {
1239                        return Err(MutualElabError::TerminationFailure(format!(
1240                            "could not find structurally decreasing argument for '{:?}'",
1241                            name
1242                        )));
1243                    }
1244                }
1245            }
1246        }
1247        Ok(())
1248    }
1249    /// Encode the structural recursion as recursor applications.
1250    ///
1251    /// In a full implementation, this would replace each recursive function
1252    /// with an application of the appropriate recursor/eliminator.
1253    #[allow(dead_code)]
1254    pub fn encode_as_recursor_application(&self) -> Result<MutualBlock, MutualElabError> {
1255        let mut result = self.block.clone();
1256        let call_graph = CallGraph::build_from_block(&self.block);
1257        for name in &self.block.names {
1258            if call_graph.is_self_recursive(name) && !self.recursive_args.contains_key(name) {
1259                return Err(MutualElabError::TerminationFailure(format!(
1260                    "no structural recursion info for '{:?}'",
1261                    name
1262                )));
1263            }
1264        }
1265        for (name, args) in &self.recursive_args {
1266            let attr_name = format!(
1267                "_rec_arg_{}",
1268                args.iter()
1269                    .map(|i| i.to_string())
1270                    .collect::<Vec<_>>()
1271                    .join("_")
1272            );
1273            result
1274                .attrs
1275                .entry(name.clone())
1276                .or_default()
1277                .push(AttributeKind::Custom(attr_name));
1278        }
1279        Ok(result)
1280    }
1281    /// Get the detected recursive arguments.
1282    #[allow(dead_code)]
1283    pub fn get_recursive_args(&self) -> &HashMap<Name, Vec<usize>> {
1284        &self.recursive_args
1285    }
1286}