Skip to main content

oxiz_spacer/
recursive.rs

1//! Recursive CHC support and analysis.
2//!
3//! This module provides utilities for detecting and handling recursive
4//! predicates in CHC systems, which are common in verification of
5//! recursive functions and data structures.
6//!
7//! Reference: Z3's recursive predicate handling in Spacer
8
9use crate::chc::{ChcSystem, PredId, Rule};
10use std::collections::{HashMap, HashSet};
11use thiserror::Error;
12use tracing::{debug, trace};
13
14/// Errors in recursive CHC analysis
15#[derive(Error, Debug)]
16pub enum RecursiveError {
17    /// Invalid recursion pattern
18    #[error("invalid recursion pattern: {0}")]
19    InvalidPattern(String),
20    /// Cyclic dependency detected
21    #[error("cyclic dependency in non-recursive context")]
22    CyclicDependency,
23}
24
/// Classification of how a predicate takes part in recursion.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecursionKind {
    /// No recursion was found for the predicate.
    NonRecursive,
    /// The predicate occurs in the body of one of its own rules.
    DirectRecursive,
    /// The predicate recurses only by way of other predicates.
    MutuallyRecursive,
    /// Directly recursive, and additionally depends on other recursive predicates.
    NestedRecursive,
}
37
38/// Information about a recursive predicate
39#[derive(Debug, Clone)]
40pub struct RecursiveInfo {
41    /// The predicate ID
42    pub pred: PredId,
43    /// Kind of recursion
44    pub kind: RecursionKind,
45    /// Predicates this one depends on
46    pub dependencies: HashSet<PredId>,
47    /// Predicates that depend on this one
48    pub dependents: HashSet<PredId>,
49    /// Recursive rules (rules that contain the predicate in both head and body)
50    pub recursive_rules: Vec<usize>, // Rule indices
51    /// Base case rules (non-recursive rules)
52    pub base_rules: Vec<usize>,
53}
54
55impl RecursiveInfo {
56    /// Create new recursive info
57    pub fn new(pred: PredId) -> Self {
58        Self {
59            pred,
60            kind: RecursionKind::NonRecursive,
61            dependencies: HashSet::new(),
62            dependents: HashSet::new(),
63            recursive_rules: Vec::new(),
64            base_rules: Vec::new(),
65        }
66    }
67
68    /// Check if predicate is recursive
69    pub fn is_recursive(&self) -> bool {
70        self.kind != RecursionKind::NonRecursive
71    }
72
73    /// Check if predicate has base cases
74    pub fn has_base_cases(&self) -> bool {
75        !self.base_rules.is_empty()
76    }
77
78    /// Get recursion depth (number of predicates in mutual recursion)
79    pub fn recursion_depth(&self) -> usize {
80        match self.kind {
81            RecursionKind::NonRecursive => 0,
82            RecursionKind::DirectRecursive => 1,
83            RecursionKind::MutuallyRecursive => self.dependencies.len(),
84            RecursionKind::NestedRecursive => self.dependencies.len() + 1,
85        }
86    }
87}
88
89/// Analyzer for recursive CHC systems
90pub struct RecursiveAnalyzer<'a> {
91    /// The CHC system to analyze
92    system: &'a ChcSystem,
93    /// Recursive information for each predicate
94    info: HashMap<PredId, RecursiveInfo>,
95}
96
97impl<'a> RecursiveAnalyzer<'a> {
98    /// Create a new recursive analyzer
99    pub fn new(system: &'a ChcSystem) -> Self {
100        Self {
101            system,
102            info: HashMap::new(),
103        }
104    }
105
106    /// Analyze the CHC system for recursion
107    pub fn analyze(&mut self) -> Result<(), RecursiveError> {
108        debug!("Analyzing CHC system for recursion");
109
110        // Initialize info for all predicates
111        for pred in self.system.predicates() {
112            self.info.insert(pred.id, RecursiveInfo::new(pred.id));
113        }
114
115        // Build dependency graph
116        self.build_dependency_graph()?;
117
118        // Detect recursion kinds
119        self.detect_recursion_kinds()?;
120
121        // Classify rules
122        self.classify_rules()?;
123
124        debug!(
125            "Found {} recursive predicates",
126            self.info
127                .values()
128                .filter(|info| info.is_recursive())
129                .count()
130        );
131
132        Ok(())
133    }
134
135    /// Build the dependency graph between predicates
136    fn build_dependency_graph(&mut self) -> Result<(), RecursiveError> {
137        for rule in self.system.rules() {
138            if let Some(head_pred) = rule.head_predicate() {
139                // Collect body predicates first
140                let body_preds: Vec<PredId> =
141                    rule.body.predicates.iter().map(|app| app.pred).collect();
142
143                // Get or create info for head predicate
144                let head_info = self
145                    .info
146                    .entry(head_pred)
147                    .or_insert_with(|| RecursiveInfo::new(head_pred));
148
149                // Add dependencies from body predicates
150                for body_pred in &body_preds {
151                    head_info.dependencies.insert(*body_pred);
152                }
153
154                // Now update body predicates (separate borrow)
155                for body_pred in body_preds {
156                    let body_info = self
157                        .info
158                        .entry(body_pred)
159                        .or_insert_with(|| RecursiveInfo::new(body_pred));
160                    body_info.dependents.insert(head_pred);
161                }
162            }
163        }
164
165        Ok(())
166    }
167
168    /// Detect recursion kinds for each predicate
169    fn detect_recursion_kinds(&mut self) -> Result<(), RecursiveError> {
170        // Clone the info to avoid borrow issues
171        let pred_ids: Vec<PredId> = self.info.keys().copied().collect();
172
173        for pred_id in pred_ids {
174            let kind = self.detect_predicate_recursion(pred_id)?;
175
176            if let Some(info) = self.info.get_mut(&pred_id) {
177                info.kind = kind;
178                trace!("Predicate {:?} has recursion kind {:?}", pred_id, kind);
179            }
180        }
181
182        Ok(())
183    }
184
185    /// Detect recursion kind for a specific predicate
186    fn detect_predicate_recursion(&self, pred: PredId) -> Result<RecursionKind, RecursiveError> {
187        let info = self
188            .info
189            .get(&pred)
190            .ok_or_else(|| RecursiveError::InvalidPattern("predicate not found".to_string()))?;
191
192        // Check for direct recursion
193        if info.dependencies.contains(&pred) {
194            // Check for nested recursion (depends on other recursive predicates)
195            let has_recursive_deps = info.dependencies.iter().any(|dep| {
196                if let Some(dep_info) = self.info.get(dep) {
197                    dep_info.dependencies.contains(&pred) || dep_info.dependencies.contains(dep)
198                } else {
199                    false
200                }
201            });
202
203            if has_recursive_deps {
204                return Ok(RecursionKind::NestedRecursive);
205            } else {
206                return Ok(RecursionKind::DirectRecursive);
207            }
208        }
209
210        // Check for mutual recursion
211        for dep in &info.dependencies {
212            if let Some(dep_info) = self.info.get(dep)
213                && dep_info.dependencies.contains(&pred)
214            {
215                return Ok(RecursionKind::MutuallyRecursive);
216            }
217        }
218
219        Ok(RecursionKind::NonRecursive)
220    }
221
222    /// Classify rules as recursive or base cases
223    fn classify_rules(&mut self) -> Result<(), RecursiveError> {
224        for (rule_idx, rule) in self.system.rules().enumerate() {
225            if let Some(head_pred) = rule.head_predicate() {
226                let is_recursive = self.is_rule_recursive(rule);
227
228                if let Some(info) = self.info.get_mut(&head_pred) {
229                    if is_recursive {
230                        info.recursive_rules.push(rule_idx);
231                    } else {
232                        info.base_rules.push(rule_idx);
233                    }
234                }
235            }
236        }
237
238        Ok(())
239    }
240
241    /// Check if a rule is recursive
242    fn is_rule_recursive(&self, rule: &Rule) -> bool {
243        if let Some(head_pred) = rule.head_predicate() {
244            // Check if head predicate appears in body
245            rule.body
246                .predicates
247                .iter()
248                .any(|body_app| body_app.pred == head_pred)
249        } else {
250            false
251        }
252    }
253
254    /// Get recursive info for a predicate
255    pub fn get_info(&self, pred: PredId) -> Option<&RecursiveInfo> {
256        self.info.get(&pred)
257    }
258
259    /// Get all recursive predicates
260    pub fn recursive_predicates(&self) -> impl Iterator<Item = &RecursiveInfo> {
261        self.info.values().filter(|info| info.is_recursive())
262    }
263
264    /// Get strongly connected components (mutual recursion groups)
265    pub fn strongly_connected_components(&self) -> Vec<Vec<PredId>> {
266        let mut sccs = Vec::new();
267        let mut visited = HashSet::new();
268        let mut stack = Vec::new();
269
270        for pred_id in self.info.keys() {
271            if !visited.contains(pred_id) {
272                self.tarjan_scc(
273                    *pred_id,
274                    &mut visited,
275                    &mut stack,
276                    &mut sccs,
277                    &mut HashMap::new(),
278                    &mut 0,
279                );
280            }
281        }
282
283        sccs
284    }
285
286    /// Tarjan's algorithm for finding SCCs
287    #[allow(clippy::too_many_arguments)]
288    fn tarjan_scc(
289        &self,
290        pred: PredId,
291        visited: &mut HashSet<PredId>,
292        stack: &mut Vec<PredId>,
293        sccs: &mut Vec<Vec<PredId>>,
294        indices: &mut HashMap<PredId, usize>,
295        index_counter: &mut usize,
296    ) {
297        visited.insert(pred);
298        indices.insert(pred, *index_counter);
299        let mut low_link = *index_counter;
300        *index_counter += 1;
301        stack.push(pred);
302
303        if let Some(info) = self.info.get(&pred) {
304            for &dep in &info.dependencies {
305                if !visited.contains(&dep) {
306                    self.tarjan_scc(dep, visited, stack, sccs, indices, index_counter);
307                    if let Some(&dep_low) = indices.get(&dep) {
308                        low_link = low_link.min(dep_low);
309                    }
310                } else if stack.contains(&dep)
311                    && let Some(&dep_idx) = indices.get(&dep)
312                {
313                    low_link = low_link.min(dep_idx);
314                }
315            }
316        }
317
318        if low_link == indices[&pred] {
319            let mut scc = Vec::new();
320            while let Some(node) = stack.pop() {
321                scc.push(node);
322                if node == pred {
323                    break;
324                }
325            }
326            if scc.len() > 1
327                || (scc.len() == 1
328                    && self
329                        .info
330                        .get(&scc[0])
331                        .map(|i| i.dependencies.contains(&scc[0]))
332                        .unwrap_or(false))
333            {
334                sccs.push(scc);
335            }
336        }
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use oxiz_core::TermManager;

    #[test]
    fn test_recursion_kind() {
        let fresh = RecursiveInfo::new(PredId(0));
        assert!(!fresh.is_recursive());
        assert_eq!(fresh.kind, RecursionKind::NonRecursive);
    }

    #[test]
    fn test_recursive_info() {
        let mut info = RecursiveInfo {
            kind: RecursionKind::DirectRecursive,
            recursive_rules: vec![0],
            base_rules: vec![1],
            ..RecursiveInfo::new(PredId(0))
        };
        info.dependencies.insert(PredId(0));

        assert!(info.is_recursive());
        assert!(info.has_base_cases());
        assert_eq!(info.recursion_depth(), 1);
    }

    #[test]
    fn test_analyzer_empty_system() {
        let system = ChcSystem::new();
        let outcome = RecursiveAnalyzer::new(&system).analyze();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_analyzer_simple_system() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        // A single init rule `x = 0 => Inv(x)`: nothing recursive here.
        let inv = system.declare_predicate("Inv", [int_sort]);
        let x = terms.mk_var("x", int_sort);
        let zero = terms.mk_int(0);
        let init_constraint = terms.mk_eq(x, zero);
        system.add_init_rule([("x".to_string(), int_sort)], init_constraint, inv, [x]);

        let mut analyzer = RecursiveAnalyzer::new(&system);
        assert!(analyzer.analyze().is_ok());

        let info = analyzer
            .get_info(inv)
            .expect("test operation should succeed");
        assert_eq!(info.kind, RecursionKind::NonRecursive);
    }

    #[test]
    fn test_scc_computation() {
        let system = ChcSystem::new();
        let components = RecursiveAnalyzer::new(&system).strongly_connected_components();
        assert!(components.is_empty());
    }
}