1use crate::chc::{ChcSystem, PredId, Rule};
10use std::collections::{HashMap, HashSet};
11use thiserror::Error;
12use tracing::{debug, trace};
13
14#[derive(Error, Debug)]
16pub enum RecursiveError {
17 #[error("invalid recursion pattern: {0}")]
19 InvalidPattern(String),
20 #[error("cyclic dependency in non-recursive context")]
22 CyclicDependency,
23}
24
/// How a predicate participates in recursion, as derived from the predicate
/// dependency graph.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecursionKind {
    /// No recursion was detected for the predicate.
    NonRecursive,

    /// The predicate depends directly on itself and on no other recursive
    /// predicate.
    DirectRecursive,

    /// The predicate is recursive only through other predicates: it shares a
    /// dependency cycle with at least one other predicate but does not depend
    /// directly on itself.
    MutuallyRecursive,

    /// The predicate depends directly on itself and also on another predicate
    /// that is itself recursive.
    NestedRecursive,
}
37
38#[derive(Debug, Clone)]
40pub struct RecursiveInfo {
41 pub pred: PredId,
43 pub kind: RecursionKind,
45 pub dependencies: HashSet<PredId>,
47 pub dependents: HashSet<PredId>,
49 pub recursive_rules: Vec<usize>, pub base_rules: Vec<usize>,
53}
54
55impl RecursiveInfo {
56 pub fn new(pred: PredId) -> Self {
58 Self {
59 pred,
60 kind: RecursionKind::NonRecursive,
61 dependencies: HashSet::new(),
62 dependents: HashSet::new(),
63 recursive_rules: Vec::new(),
64 base_rules: Vec::new(),
65 }
66 }
67
68 pub fn is_recursive(&self) -> bool {
70 self.kind != RecursionKind::NonRecursive
71 }
72
73 pub fn has_base_cases(&self) -> bool {
75 !self.base_rules.is_empty()
76 }
77
78 pub fn recursion_depth(&self) -> usize {
80 match self.kind {
81 RecursionKind::NonRecursive => 0,
82 RecursionKind::DirectRecursive => 1,
83 RecursionKind::MutuallyRecursive => self.dependencies.len(),
84 RecursionKind::NestedRecursive => self.dependencies.len() + 1,
85 }
86 }
87}
88
89pub struct RecursiveAnalyzer<'a> {
91 system: &'a ChcSystem,
93 info: HashMap<PredId, RecursiveInfo>,
95}
96
97impl<'a> RecursiveAnalyzer<'a> {
98 pub fn new(system: &'a ChcSystem) -> Self {
100 Self {
101 system,
102 info: HashMap::new(),
103 }
104 }
105
106 pub fn analyze(&mut self) -> Result<(), RecursiveError> {
108 debug!("Analyzing CHC system for recursion");
109
110 for pred in self.system.predicates() {
112 self.info.insert(pred.id, RecursiveInfo::new(pred.id));
113 }
114
115 self.build_dependency_graph()?;
117
118 self.detect_recursion_kinds()?;
120
121 self.classify_rules()?;
123
124 debug!(
125 "Found {} recursive predicates",
126 self.info
127 .values()
128 .filter(|info| info.is_recursive())
129 .count()
130 );
131
132 Ok(())
133 }
134
135 fn build_dependency_graph(&mut self) -> Result<(), RecursiveError> {
137 for rule in self.system.rules() {
138 if let Some(head_pred) = rule.head_predicate() {
139 let body_preds: Vec<PredId> =
141 rule.body.predicates.iter().map(|app| app.pred).collect();
142
143 let head_info = self
145 .info
146 .entry(head_pred)
147 .or_insert_with(|| RecursiveInfo::new(head_pred));
148
149 for body_pred in &body_preds {
151 head_info.dependencies.insert(*body_pred);
152 }
153
154 for body_pred in body_preds {
156 let body_info = self
157 .info
158 .entry(body_pred)
159 .or_insert_with(|| RecursiveInfo::new(body_pred));
160 body_info.dependents.insert(head_pred);
161 }
162 }
163 }
164
165 Ok(())
166 }
167
168 fn detect_recursion_kinds(&mut self) -> Result<(), RecursiveError> {
170 let pred_ids: Vec<PredId> = self.info.keys().copied().collect();
172
173 for pred_id in pred_ids {
174 let kind = self.detect_predicate_recursion(pred_id)?;
175
176 if let Some(info) = self.info.get_mut(&pred_id) {
177 info.kind = kind;
178 trace!("Predicate {:?} has recursion kind {:?}", pred_id, kind);
179 }
180 }
181
182 Ok(())
183 }
184
185 fn detect_predicate_recursion(&self, pred: PredId) -> Result<RecursionKind, RecursiveError> {
187 let info = self
188 .info
189 .get(&pred)
190 .ok_or_else(|| RecursiveError::InvalidPattern("predicate not found".to_string()))?;
191
192 if info.dependencies.contains(&pred) {
194 let has_recursive_deps = info.dependencies.iter().any(|dep| {
196 if let Some(dep_info) = self.info.get(dep) {
197 dep_info.dependencies.contains(&pred) || dep_info.dependencies.contains(dep)
198 } else {
199 false
200 }
201 });
202
203 if has_recursive_deps {
204 return Ok(RecursionKind::NestedRecursive);
205 } else {
206 return Ok(RecursionKind::DirectRecursive);
207 }
208 }
209
210 for dep in &info.dependencies {
212 if let Some(dep_info) = self.info.get(dep)
213 && dep_info.dependencies.contains(&pred)
214 {
215 return Ok(RecursionKind::MutuallyRecursive);
216 }
217 }
218
219 Ok(RecursionKind::NonRecursive)
220 }
221
222 fn classify_rules(&mut self) -> Result<(), RecursiveError> {
224 for (rule_idx, rule) in self.system.rules().enumerate() {
225 if let Some(head_pred) = rule.head_predicate() {
226 let is_recursive = self.is_rule_recursive(rule);
227
228 if let Some(info) = self.info.get_mut(&head_pred) {
229 if is_recursive {
230 info.recursive_rules.push(rule_idx);
231 } else {
232 info.base_rules.push(rule_idx);
233 }
234 }
235 }
236 }
237
238 Ok(())
239 }
240
241 fn is_rule_recursive(&self, rule: &Rule) -> bool {
243 if let Some(head_pred) = rule.head_predicate() {
244 rule.body
246 .predicates
247 .iter()
248 .any(|body_app| body_app.pred == head_pred)
249 } else {
250 false
251 }
252 }
253
254 pub fn get_info(&self, pred: PredId) -> Option<&RecursiveInfo> {
256 self.info.get(&pred)
257 }
258
259 pub fn recursive_predicates(&self) -> impl Iterator<Item = &RecursiveInfo> {
261 self.info.values().filter(|info| info.is_recursive())
262 }
263
264 pub fn strongly_connected_components(&self) -> Vec<Vec<PredId>> {
266 let mut sccs = Vec::new();
267 let mut visited = HashSet::new();
268 let mut stack = Vec::new();
269
270 for pred_id in self.info.keys() {
271 if !visited.contains(pred_id) {
272 self.tarjan_scc(
273 *pred_id,
274 &mut visited,
275 &mut stack,
276 &mut sccs,
277 &mut HashMap::new(),
278 &mut 0,
279 );
280 }
281 }
282
283 sccs
284 }
285
286 #[allow(clippy::too_many_arguments)]
288 fn tarjan_scc(
289 &self,
290 pred: PredId,
291 visited: &mut HashSet<PredId>,
292 stack: &mut Vec<PredId>,
293 sccs: &mut Vec<Vec<PredId>>,
294 indices: &mut HashMap<PredId, usize>,
295 index_counter: &mut usize,
296 ) {
297 visited.insert(pred);
298 indices.insert(pred, *index_counter);
299 let mut low_link = *index_counter;
300 *index_counter += 1;
301 stack.push(pred);
302
303 if let Some(info) = self.info.get(&pred) {
304 for &dep in &info.dependencies {
305 if !visited.contains(&dep) {
306 self.tarjan_scc(dep, visited, stack, sccs, indices, index_counter);
307 if let Some(&dep_low) = indices.get(&dep) {
308 low_link = low_link.min(dep_low);
309 }
310 } else if stack.contains(&dep)
311 && let Some(&dep_idx) = indices.get(&dep)
312 {
313 low_link = low_link.min(dep_idx);
314 }
315 }
316 }
317
318 if low_link == indices[&pred] {
319 let mut scc = Vec::new();
320 while let Some(node) = stack.pop() {
321 scc.push(node);
322 if node == pred {
323 break;
324 }
325 }
326 if scc.len() > 1
327 || (scc.len() == 1
328 && self
329 .info
330 .get(&scc[0])
331 .map(|i| i.dependencies.contains(&scc[0]))
332 .unwrap_or(false))
333 {
334 sccs.push(scc);
335 }
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use oxiz_core::TermManager;

    /// A freshly created record is classified as non-recursive.
    #[test]
    fn test_recursion_kind() {
        let fresh = RecursiveInfo::new(PredId(0));
        assert_eq!(fresh.kind, RecursionKind::NonRecursive);
        assert!(!fresh.is_recursive());
    }

    /// Accessors reflect a hand-built, directly recursive record that has one
    /// base rule.
    #[test]
    fn test_recursive_info() {
        let mut direct = RecursiveInfo::new(PredId(0));
        direct.kind = RecursionKind::DirectRecursive;
        direct.dependencies.insert(PredId(0));
        direct.recursive_rules.push(0);
        direct.base_rules.push(1);

        assert!(direct.is_recursive());
        assert!(direct.has_base_cases());
        assert_eq!(direct.recursion_depth(), 1);
    }

    /// Analyzing a system with no predicates or rules succeeds.
    #[test]
    fn test_analyzer_empty_system() {
        let empty = ChcSystem::new();
        let mut analyzer = RecursiveAnalyzer::new(&empty);
        let outcome = analyzer.analyze();
        assert!(outcome.is_ok());
    }

    /// A predicate whose only rule is an init rule (constraint `x = 0`, head
    /// `Inv(x)`) is classified as non-recursive.
    #[test]
    fn test_analyzer_simple_system() {
        let mut tm = TermManager::new();
        let mut chc = ChcSystem::new();

        let inv = chc.declare_predicate("Inv", [tm.sorts.int_sort]);
        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(0);
        let x_is_zero = tm.mk_eq(x, zero);

        chc.add_init_rule(
            [("x".to_string(), tm.sorts.int_sort)],
            x_is_zero,
            inv,
            [x],
        );

        let mut analyzer = RecursiveAnalyzer::new(&chc);
        assert!(analyzer.analyze().is_ok());

        let found = analyzer.get_info(inv);
        assert!(found.is_some());
        let found = found.expect("test operation should succeed");
        assert_eq!(found.kind, RecursionKind::NonRecursive);
    }

    /// An empty system has no recursive strongly connected components.
    #[test]
    fn test_scc_computation() {
        let empty = ChcSystem::new();
        let analyzer = RecursiveAnalyzer::new(&empty);
        assert!(analyzer.strongly_connected_components().is_empty());
    }
}
407}