Skip to main content

oxiz_spacer/
chc.rs

1//! Constrained Horn Clause (CHC) representation.
2//!
3//! CHC format: `forall X. (body => head)`
4//! where body is a conjunction of constraints and predicate applications,
5//! and head is a predicate application or false (query).
6//!
7//! Reference: Z3's `muz/spacer/` implementation.
8
9use indexmap::IndexMap;
10use oxiz_core::{TermId, TermManager};
11use rustc_hash::FxHashMap;
12use smallvec::SmallVec;
13use std::sync::atomic::{AtomicU32, Ordering};
14
/// Opaque handle naming one predicate declared in a [`ChcSystem`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PredId(pub u32);

impl PredId {
    /// Wrap a raw `u32` as a predicate handle.
    #[inline]
    #[must_use]
    pub const fn new(id: u32) -> Self {
        PredId(id)
    }

    /// Unwrap the handle back into its raw `u32`.
    #[inline]
    #[must_use]
    pub const fn raw(self) -> u32 {
        let PredId(value) = self;
        value
    }
}
34
/// Opaque handle naming one rule stored in a [`ChcSystem`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub u32);

impl RuleId {
    /// Wrap a raw `u32` as a rule handle.
    #[inline]
    #[must_use]
    pub const fn new(id: u32) -> Self {
        RuleId(id)
    }

    /// Unwrap the handle back into its raw `u32`.
    #[inline]
    #[must_use]
    pub const fn raw(self) -> u32 {
        let RuleId(value) = self;
        value
    }
}
54
55/// A predicate declaration with arity and parameter sorts
56#[derive(Debug, Clone)]
57pub struct Predicate {
58    /// Unique identifier
59    pub id: PredId,
60    /// Name of the predicate
61    pub name: String,
62    /// Parameter sorts (arity = params.len())
63    pub params: SmallVec<[oxiz_core::SortId; 4]>,
64}
65
66impl Predicate {
67    /// Get the arity of this predicate
68    #[inline]
69    #[must_use]
70    pub fn arity(&self) -> usize {
71        self.params.len()
72    }
73}
74
75/// An application of a predicate to arguments
76#[derive(Debug, Clone)]
77pub struct PredicateApp {
78    /// The predicate being applied
79    pub pred: PredId,
80    /// Arguments (must match arity)
81    pub args: SmallVec<[TermId; 4]>,
82}
83
84impl PredicateApp {
85    /// Create a new predicate application
86    pub fn new(pred: PredId, args: impl IntoIterator<Item = TermId>) -> Self {
87        Self {
88            pred,
89            args: args.into_iter().collect(),
90        }
91    }
92}
93
94/// The head of a CHC rule: either a predicate application or `false` (query)
95#[derive(Debug, Clone)]
96pub enum RuleHead {
97    /// Predicate application (non-query rule)
98    Predicate(PredicateApp),
99    /// Query rule (head is false)
100    Query,
101}
102
103impl RuleHead {
104    /// Check if this is a query (false head)
105    #[inline]
106    #[must_use]
107    pub fn is_query(&self) -> bool {
108        matches!(self, RuleHead::Query)
109    }
110
111    /// Get the predicate application if not a query
112    #[inline]
113    #[must_use]
114    pub fn as_predicate(&self) -> Option<&PredicateApp> {
115        match self {
116            RuleHead::Predicate(app) => Some(app),
117            RuleHead::Query => None,
118        }
119    }
120}
121
122/// The body of a CHC rule: conjunction of predicate applications and constraints
123#[derive(Debug, Clone)]
124pub struct RuleBody {
125    /// Predicate applications in the body (uninterpreted tail)
126    pub predicates: SmallVec<[PredicateApp; 2]>,
127    /// Interpreted constraint (conjunction)
128    pub constraint: TermId,
129}
130
131impl RuleBody {
132    /// Create a body with no predicate applications (init rule)
133    pub fn init(constraint: TermId) -> Self {
134        Self {
135            predicates: SmallVec::new(),
136            constraint,
137        }
138    }
139
140    /// Create a body with predicate applications
141    pub fn new(predicates: impl IntoIterator<Item = PredicateApp>, constraint: TermId) -> Self {
142        Self {
143            predicates: predicates.into_iter().collect(),
144            constraint,
145        }
146    }
147
148    /// Check if this is an init body (no predicate applications)
149    #[inline]
150    #[must_use]
151    pub fn is_init(&self) -> bool {
152        self.predicates.is_empty()
153    }
154
155    /// Get the number of predicate applications (uninterpreted tail size)
156    #[inline]
157    #[must_use]
158    pub fn uninterpreted_tail_size(&self) -> usize {
159        self.predicates.len()
160    }
161}
162
163/// A CHC rule: `forall vars. (body => head)`
164#[derive(Debug, Clone)]
165pub struct Rule {
166    /// Unique identifier
167    pub id: RuleId,
168    /// Universally quantified variables
169    pub vars: SmallVec<[(String, oxiz_core::SortId); 4]>,
170    /// Body of the rule
171    pub body: RuleBody,
172    /// Head of the rule
173    pub head: RuleHead,
174    /// Optional name/label for the rule
175    pub name: Option<String>,
176}
177
178impl Rule {
179    /// Check if this is an init rule (no predicates in body)
180    #[inline]
181    #[must_use]
182    pub fn is_init(&self) -> bool {
183        self.body.is_init()
184    }
185
186    /// Check if this is a query rule (head is false)
187    #[inline]
188    #[must_use]
189    pub fn is_query(&self) -> bool {
190        self.head.is_query()
191    }
192
193    /// Get the head predicate if not a query
194    #[inline]
195    #[must_use]
196    pub fn head_predicate(&self) -> Option<PredId> {
197        match &self.head {
198            RuleHead::Predicate(app) => Some(app.pred),
199            RuleHead::Query => None,
200        }
201    }
202
203    /// Get all body predicate IDs
204    pub fn body_predicates(&self) -> impl Iterator<Item = PredId> + '_ {
205        self.body.predicates.iter().map(|app| app.pred)
206    }
207}
208
209/// A complete CHC system containing predicates and rules
210#[derive(Debug)]
211pub struct ChcSystem {
212    /// Predicate declarations indexed by ID
213    predicates: Vec<Predicate>,
214    /// Predicate lookup by name
215    pred_by_name: FxHashMap<String, PredId>,
216    /// Next predicate ID
217    next_pred_id: AtomicU32,
218
219    /// Rules indexed by ID
220    rules: Vec<Rule>,
221    /// Next rule ID
222    next_rule_id: AtomicU32,
223
224    /// Rules grouped by head predicate (for forward analysis)
225    rules_by_head: IndexMap<PredId, SmallVec<[RuleId; 4]>>,
226    /// Rules that use a predicate in the body (for backward analysis)
227    rules_by_body: IndexMap<PredId, SmallVec<[RuleId; 4]>>,
228
229    /// Query rules (head is false)
230    queries: SmallVec<[RuleId; 2]>,
231    /// Entry rules (no predicates in body)
232    entries: SmallVec<[RuleId; 2]>,
233}
234
235impl Default for ChcSystem {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241impl ChcSystem {
242    /// Create a new empty CHC system
243    pub fn new() -> Self {
244        Self {
245            predicates: Vec::new(),
246            pred_by_name: FxHashMap::default(),
247            next_pred_id: AtomicU32::new(0),
248            rules: Vec::new(),
249            next_rule_id: AtomicU32::new(0),
250            rules_by_head: IndexMap::new(),
251            rules_by_body: IndexMap::new(),
252            queries: SmallVec::new(),
253            entries: SmallVec::new(),
254        }
255    }
256
257    /// Declare a new predicate
258    pub fn declare_predicate(
259        &mut self,
260        name: impl Into<String>,
261        params: impl IntoIterator<Item = oxiz_core::SortId>,
262    ) -> PredId {
263        let name = name.into();
264        if let Some(&id) = self.pred_by_name.get(&name) {
265            return id;
266        }
267
268        let id = PredId(self.next_pred_id.fetch_add(1, Ordering::Relaxed));
269        let pred = Predicate {
270            id,
271            name: name.clone(),
272            params: params.into_iter().collect(),
273        };
274
275        self.pred_by_name.insert(name, id);
276        self.predicates.push(pred);
277        id
278    }
279
280    /// Get a predicate by ID
281    #[must_use]
282    pub fn get_predicate(&self, id: PredId) -> Option<&Predicate> {
283        self.predicates.get(id.0 as usize)
284    }
285
286    /// Get a predicate by name
287    #[must_use]
288    pub fn get_predicate_by_name(&self, name: &str) -> Option<&Predicate> {
289        self.pred_by_name
290            .get(name)
291            .and_then(|&id| self.get_predicate(id))
292    }
293
294    /// Get a predicate ID by name
295    #[must_use]
296    pub fn get_predicate_id(&self, name: &str) -> Option<PredId> {
297        self.pred_by_name.get(name).copied()
298    }
299
300    /// Add a rule to the system
301    pub fn add_rule(
302        &mut self,
303        vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
304        body: RuleBody,
305        head: RuleHead,
306        name: Option<String>,
307    ) -> RuleId {
308        let id = RuleId(self.next_rule_id.fetch_add(1, Ordering::Relaxed));
309
310        // Track queries and entries
311        if head.is_query() {
312            self.queries.push(id);
313        }
314        if body.is_init() {
315            self.entries.push(id);
316        }
317
318        // Index by head predicate
319        if let Some(pred_id) = head.as_predicate().map(|a| a.pred) {
320            self.rules_by_head.entry(pred_id).or_default().push(id);
321        }
322
323        // Index by body predicates
324        for app in &body.predicates {
325            self.rules_by_body.entry(app.pred).or_default().push(id);
326        }
327
328        let rule = Rule {
329            id,
330            vars: vars.into_iter().collect(),
331            body,
332            head,
333            name,
334        };
335
336        self.rules.push(rule);
337        id
338    }
339
340    /// Add a simple init rule: `constraint => P(args)`
341    pub fn add_init_rule(
342        &mut self,
343        vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
344        constraint: TermId,
345        head_pred: PredId,
346        head_args: impl IntoIterator<Item = TermId>,
347    ) -> RuleId {
348        let body = RuleBody::init(constraint);
349        let head = RuleHead::Predicate(PredicateApp::new(head_pred, head_args));
350        self.add_rule(vars, body, head, None)
351    }
352
353    /// Add a transition rule: `P1(args1) /\ ... /\ constraint => P(args)`
354    pub fn add_transition_rule(
355        &mut self,
356        vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
357        body_preds: impl IntoIterator<Item = PredicateApp>,
358        constraint: TermId,
359        head_pred: PredId,
360        head_args: impl IntoIterator<Item = TermId>,
361    ) -> RuleId {
362        let body = RuleBody::new(body_preds, constraint);
363        let head = RuleHead::Predicate(PredicateApp::new(head_pred, head_args));
364        self.add_rule(vars, body, head, None)
365    }
366
367    /// Add a query rule: `P(args) /\ constraint => false`
368    pub fn add_query(
369        &mut self,
370        vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
371        body_preds: impl IntoIterator<Item = PredicateApp>,
372        constraint: TermId,
373    ) -> RuleId {
374        let body = RuleBody::new(body_preds, constraint);
375        self.add_rule(vars, body, RuleHead::Query, None)
376    }
377
378    /// Get a rule by ID
379    #[must_use]
380    pub fn get_rule(&self, id: RuleId) -> Option<&Rule> {
381        self.rules.get(id.0 as usize)
382    }
383
384    /// Get all rules
385    pub fn rules(&self) -> impl Iterator<Item = &Rule> {
386        self.rules.iter()
387    }
388
389    /// Get all predicates
390    pub fn predicates(&self) -> impl Iterator<Item = &Predicate> {
391        self.predicates.iter()
392    }
393
394    /// Get query rules
395    pub fn queries(&self) -> impl Iterator<Item = &Rule> {
396        self.queries.iter().filter_map(|&id| self.get_rule(id))
397    }
398
399    /// Get entry rules (init rules)
400    pub fn entries(&self) -> impl Iterator<Item = &Rule> {
401        self.entries.iter().filter_map(|&id| self.get_rule(id))
402    }
403
404    /// Get rules with a given head predicate
405    pub fn rules_by_head(&self, pred: PredId) -> impl Iterator<Item = &Rule> {
406        self.rules_by_head
407            .get(&pred)
408            .into_iter()
409            .flat_map(|ids| ids.iter())
410            .filter_map(|&id| self.get_rule(id))
411    }
412
413    /// Get rules that use a predicate in the body
414    pub fn rules_using(&self, pred: PredId) -> impl Iterator<Item = &Rule> {
415        self.rules_by_body
416            .get(&pred)
417            .into_iter()
418            .flat_map(|ids| ids.iter())
419            .filter_map(|&id| self.get_rule(id))
420    }
421
422    /// Get the number of predicates
423    #[must_use]
424    pub fn num_predicates(&self) -> usize {
425        self.predicates.len()
426    }
427
428    /// Get the number of rules
429    #[must_use]
430    pub fn num_rules(&self) -> usize {
431        self.rules.len()
432    }
433
434    /// Check if the system is empty
435    #[must_use]
436    pub fn is_empty(&self) -> bool {
437        self.rules.is_empty()
438    }
439
440    /// Get predicates in topological order (if acyclic)
441    pub fn topological_order(&self) -> Option<Vec<PredId>> {
442        let mut in_degree: FxHashMap<PredId, usize> = FxHashMap::default();
443        let mut result = Vec::new();
444
445        // Initialize in-degrees
446        for pred in &self.predicates {
447            in_degree.insert(pred.id, 0);
448        }
449
450        // Count dependencies
451        for rule in &self.rules {
452            if let Some(head_pred) = rule.head_predicate() {
453                for body_pred in rule.body_predicates() {
454                    if body_pred != head_pred {
455                        *in_degree.entry(head_pred).or_default() += 1;
456                    }
457                }
458            }
459        }
460
461        // Kahn's algorithm
462        let mut queue: Vec<PredId> = in_degree
463            .iter()
464            .filter(|&(_, deg)| *deg == 0)
465            .map(|(&id, _)| id)
466            .collect();
467
468        while let Some(pred) = queue.pop() {
469            result.push(pred);
470
471            for rule in self.rules_by_body.get(&pred).into_iter().flatten() {
472                if let Some(head_pred) = self.get_rule(*rule).and_then(|r| r.head_predicate())
473                    && let Some(deg) = in_degree.get_mut(&head_pred)
474                {
475                    *deg = deg.saturating_sub(1);
476                    if *deg == 0 {
477                        queue.push(head_pred);
478                    }
479                }
480            }
481        }
482
483        if result.len() == self.predicates.len() {
484            Some(result)
485        } else {
486            None // Cycle detected
487        }
488    }
489}
490
491/// Builder for constructing CHC systems conveniently
492pub struct ChcBuilder<'a> {
493    system: ChcSystem,
494    terms: &'a mut TermManager,
495}
496
497impl<'a> ChcBuilder<'a> {
498    /// Create a new CHC builder
499    pub fn new(terms: &'a mut TermManager) -> Self {
500        Self {
501            system: ChcSystem::new(),
502            terms,
503        }
504    }
505
506    /// Declare a predicate
507    pub fn declare_pred(
508        &mut self,
509        name: impl Into<String>,
510        params: impl IntoIterator<Item = oxiz_core::SortId>,
511    ) -> PredId {
512        self.system.declare_predicate(name, params)
513    }
514
515    /// Get the term manager
516    pub fn terms(&mut self) -> &mut TermManager {
517        self.terms
518    }
519
520    /// Build the CHC system
521    pub fn build(self) -> ChcSystem {
522        self.system
523    }
524
525    /// Get mutable access to the system
526    pub fn system_mut(&mut self) -> &mut ChcSystem {
527        &mut self.system
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Position of `target` within a topological order; panics if it is absent.
    fn position_of(order: &[PredId], target: PredId) -> usize {
        order
            .iter()
            .position(|&id| id == target)
            .expect("element should be found")
    }

    #[test]
    fn test_chc_system_creation() {
        let terms = TermManager::new();
        let mut system = ChcSystem::new();

        // One unary predicate and one nullary predicate.
        let inv = system.declare_predicate("Inv", [terms.sorts.int_sort]);
        let err = system.declare_predicate("Err", []);

        assert_eq!(system.num_predicates(), 2);

        let inv_decl = system
            .get_predicate(inv)
            .expect("test operation should succeed");
        assert_eq!(inv_decl.name, "Inv");

        let err_decl = system
            .get_predicate(err)
            .expect("test operation should succeed");
        assert_eq!(err_decl.arity(), 0);
    }

    #[test]
    fn test_chc_rules() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        let inv = system.declare_predicate("Inv", [int_sort]);

        // Init: x = 0 => Inv(x)
        let x = terms.mk_var("x", int_sort);
        let zero = terms.mk_int(0);
        let x_is_zero = terms.mk_eq(x, zero);
        system.add_init_rule([(String::from("x"), int_sort)], x_is_zero, inv, [x]);

        // Step: Inv(x) /\ x' = x + 1 => Inv(x')
        let x_next = terms.mk_var("x'", int_sort);
        let one = terms.mk_int(1);
        let x_succ = terms.mk_add([x, one]);
        let step = terms.mk_eq(x_next, x_succ);
        system.add_transition_rule(
            [(String::from("x"), int_sort), (String::from("x'"), int_sort)],
            [PredicateApp::new(inv, [x])],
            step,
            inv,
            [x_next],
        );

        // Query: Inv(x) /\ x < 0 => false
        let x_negative = terms.mk_lt(x, zero);
        system.add_query(
            [(String::from("x"), int_sort)],
            [PredicateApp::new(inv, [x])],
            x_negative,
        );

        assert_eq!(system.num_rules(), 3);
        assert_eq!(system.entries().count(), 1);
        assert_eq!(system.queries().count(), 1);
    }

    #[test]
    fn test_rule_indexing() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        let p = system.declare_predicate("P", [int_sort]);
        let q = system.declare_predicate("Q", [int_sort]);

        let x = terms.mk_var("x", int_sort);
        let always = terms.mk_true();

        // P(x) => Q(x)
        system.add_transition_rule(
            [(String::from("x"), int_sort)],
            [PredicateApp::new(p, [x])],
            always,
            q,
            [x],
        );

        // Only Q is derived by a rule.
        assert_eq!(system.rules_by_head(q).count(), 1);
        assert_eq!(system.rules_by_head(p).count(), 0);

        // Only P appears in a rule body.
        assert_eq!(system.rules_using(p).count(), 1);
        assert_eq!(system.rules_using(q).count(), 0);
    }

    #[test]
    fn test_topological_order() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        let p1 = system.declare_predicate("P1", [int_sort]);
        let p2 = system.declare_predicate("P2", [int_sort]);
        let p3 = system.declare_predicate("P3", [int_sort]);

        let x = terms.mk_var("x", int_sort);
        let always = terms.mk_true();

        // Chain P1 => P2 => P3 (acyclic).
        for (from, to) in [(p1, p2), (p2, p3)] {
            system.add_transition_rule(
                [(String::from("x"), int_sort)],
                [PredicateApp::new(from, [x])],
                always,
                to,
                [x],
            );
        }

        let order = system.topological_order();
        assert!(order.is_some());

        let order = order.expect("test operation should succeed");
        let (pos1, pos2, pos3) = (
            position_of(&order, p1),
            position_of(&order, p2),
            position_of(&order, p3),
        );

        // Dependencies must precede their dependents.
        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }
}