1use indexmap::IndexMap;
10use oxiz_core::{TermId, TermManager};
11use rustc_hash::FxHashMap;
12use smallvec::SmallVec;
13use std::sync::atomic::{AtomicU32, Ordering};
14
/// Identifier of a predicate, represented as a thin wrapper around a `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PredId(pub u32);

impl PredId {
    /// Wraps the raw value `id` as a predicate identifier.
    #[inline]
    #[must_use]
    pub const fn new(id: u32) -> Self {
        PredId(id)
    }

    /// Returns the raw `u32` stored in this identifier.
    #[inline]
    #[must_use]
    pub const fn raw(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
34
/// Identifier of a rule, represented as a thin wrapper around a `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub u32);

impl RuleId {
    /// Wraps the raw value `id` as a rule identifier.
    #[inline]
    #[must_use]
    pub const fn new(id: u32) -> Self {
        RuleId(id)
    }

    /// Returns the raw `u32` stored in this identifier.
    #[inline]
    #[must_use]
    pub const fn raw(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
54
55#[derive(Debug, Clone)]
57pub struct Predicate {
58 pub id: PredId,
60 pub name: String,
62 pub params: SmallVec<[oxiz_core::SortId; 4]>,
64}
65
66impl Predicate {
67 #[inline]
69 #[must_use]
70 pub fn arity(&self) -> usize {
71 self.params.len()
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct PredicateApp {
78 pub pred: PredId,
80 pub args: SmallVec<[TermId; 4]>,
82}
83
84impl PredicateApp {
85 pub fn new(pred: PredId, args: impl IntoIterator<Item = TermId>) -> Self {
87 Self {
88 pred,
89 args: args.into_iter().collect(),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96pub enum RuleHead {
97 Predicate(PredicateApp),
99 Query,
101}
102
103impl RuleHead {
104 #[inline]
106 #[must_use]
107 pub fn is_query(&self) -> bool {
108 matches!(self, RuleHead::Query)
109 }
110
111 #[inline]
113 #[must_use]
114 pub fn as_predicate(&self) -> Option<&PredicateApp> {
115 match self {
116 RuleHead::Predicate(app) => Some(app),
117 RuleHead::Query => None,
118 }
119 }
120}
121
122#[derive(Debug, Clone)]
124pub struct RuleBody {
125 pub predicates: SmallVec<[PredicateApp; 2]>,
127 pub constraint: TermId,
129}
130
131impl RuleBody {
132 pub fn init(constraint: TermId) -> Self {
134 Self {
135 predicates: SmallVec::new(),
136 constraint,
137 }
138 }
139
140 pub fn new(predicates: impl IntoIterator<Item = PredicateApp>, constraint: TermId) -> Self {
142 Self {
143 predicates: predicates.into_iter().collect(),
144 constraint,
145 }
146 }
147
148 #[inline]
150 #[must_use]
151 pub fn is_init(&self) -> bool {
152 self.predicates.is_empty()
153 }
154
155 #[inline]
157 #[must_use]
158 pub fn uninterpreted_tail_size(&self) -> usize {
159 self.predicates.len()
160 }
161}
162
163#[derive(Debug, Clone)]
165pub struct Rule {
166 pub id: RuleId,
168 pub vars: SmallVec<[(String, oxiz_core::SortId); 4]>,
170 pub body: RuleBody,
172 pub head: RuleHead,
174 pub name: Option<String>,
176}
177
178impl Rule {
179 #[inline]
181 #[must_use]
182 pub fn is_init(&self) -> bool {
183 self.body.is_init()
184 }
185
186 #[inline]
188 #[must_use]
189 pub fn is_query(&self) -> bool {
190 self.head.is_query()
191 }
192
193 #[inline]
195 #[must_use]
196 pub fn head_predicate(&self) -> Option<PredId> {
197 match &self.head {
198 RuleHead::Predicate(app) => Some(app.pred),
199 RuleHead::Query => None,
200 }
201 }
202
203 pub fn body_predicates(&self) -> impl Iterator<Item = PredId> + '_ {
205 self.body.predicates.iter().map(|app| app.pred)
206 }
207}
208
209#[derive(Debug)]
211pub struct ChcSystem {
212 predicates: Vec<Predicate>,
214 pred_by_name: FxHashMap<String, PredId>,
216 next_pred_id: AtomicU32,
218
219 rules: Vec<Rule>,
221 next_rule_id: AtomicU32,
223
224 rules_by_head: IndexMap<PredId, SmallVec<[RuleId; 4]>>,
226 rules_by_body: IndexMap<PredId, SmallVec<[RuleId; 4]>>,
228
229 queries: SmallVec<[RuleId; 2]>,
231 entries: SmallVec<[RuleId; 2]>,
233}
234
235impl Default for ChcSystem {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241impl ChcSystem {
242 pub fn new() -> Self {
244 Self {
245 predicates: Vec::new(),
246 pred_by_name: FxHashMap::default(),
247 next_pred_id: AtomicU32::new(0),
248 rules: Vec::new(),
249 next_rule_id: AtomicU32::new(0),
250 rules_by_head: IndexMap::new(),
251 rules_by_body: IndexMap::new(),
252 queries: SmallVec::new(),
253 entries: SmallVec::new(),
254 }
255 }
256
257 pub fn declare_predicate(
259 &mut self,
260 name: impl Into<String>,
261 params: impl IntoIterator<Item = oxiz_core::SortId>,
262 ) -> PredId {
263 let name = name.into();
264 if let Some(&id) = self.pred_by_name.get(&name) {
265 return id;
266 }
267
268 let id = PredId(self.next_pred_id.fetch_add(1, Ordering::Relaxed));
269 let pred = Predicate {
270 id,
271 name: name.clone(),
272 params: params.into_iter().collect(),
273 };
274
275 self.pred_by_name.insert(name, id);
276 self.predicates.push(pred);
277 id
278 }
279
280 #[must_use]
282 pub fn get_predicate(&self, id: PredId) -> Option<&Predicate> {
283 self.predicates.get(id.0 as usize)
284 }
285
286 #[must_use]
288 pub fn get_predicate_by_name(&self, name: &str) -> Option<&Predicate> {
289 self.pred_by_name
290 .get(name)
291 .and_then(|&id| self.get_predicate(id))
292 }
293
294 #[must_use]
296 pub fn get_predicate_id(&self, name: &str) -> Option<PredId> {
297 self.pred_by_name.get(name).copied()
298 }
299
300 pub fn add_rule(
302 &mut self,
303 vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
304 body: RuleBody,
305 head: RuleHead,
306 name: Option<String>,
307 ) -> RuleId {
308 let id = RuleId(self.next_rule_id.fetch_add(1, Ordering::Relaxed));
309
310 if head.is_query() {
312 self.queries.push(id);
313 }
314 if body.is_init() {
315 self.entries.push(id);
316 }
317
318 if let Some(pred_id) = head.as_predicate().map(|a| a.pred) {
320 self.rules_by_head.entry(pred_id).or_default().push(id);
321 }
322
323 for app in &body.predicates {
325 self.rules_by_body.entry(app.pred).or_default().push(id);
326 }
327
328 let rule = Rule {
329 id,
330 vars: vars.into_iter().collect(),
331 body,
332 head,
333 name,
334 };
335
336 self.rules.push(rule);
337 id
338 }
339
340 pub fn add_init_rule(
342 &mut self,
343 vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
344 constraint: TermId,
345 head_pred: PredId,
346 head_args: impl IntoIterator<Item = TermId>,
347 ) -> RuleId {
348 let body = RuleBody::init(constraint);
349 let head = RuleHead::Predicate(PredicateApp::new(head_pred, head_args));
350 self.add_rule(vars, body, head, None)
351 }
352
353 pub fn add_transition_rule(
355 &mut self,
356 vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
357 body_preds: impl IntoIterator<Item = PredicateApp>,
358 constraint: TermId,
359 head_pred: PredId,
360 head_args: impl IntoIterator<Item = TermId>,
361 ) -> RuleId {
362 let body = RuleBody::new(body_preds, constraint);
363 let head = RuleHead::Predicate(PredicateApp::new(head_pred, head_args));
364 self.add_rule(vars, body, head, None)
365 }
366
367 pub fn add_query(
369 &mut self,
370 vars: impl IntoIterator<Item = (String, oxiz_core::SortId)>,
371 body_preds: impl IntoIterator<Item = PredicateApp>,
372 constraint: TermId,
373 ) -> RuleId {
374 let body = RuleBody::new(body_preds, constraint);
375 self.add_rule(vars, body, RuleHead::Query, None)
376 }
377
378 #[must_use]
380 pub fn get_rule(&self, id: RuleId) -> Option<&Rule> {
381 self.rules.get(id.0 as usize)
382 }
383
384 pub fn rules(&self) -> impl Iterator<Item = &Rule> {
386 self.rules.iter()
387 }
388
389 pub fn predicates(&self) -> impl Iterator<Item = &Predicate> {
391 self.predicates.iter()
392 }
393
394 pub fn queries(&self) -> impl Iterator<Item = &Rule> {
396 self.queries.iter().filter_map(|&id| self.get_rule(id))
397 }
398
399 pub fn entries(&self) -> impl Iterator<Item = &Rule> {
401 self.entries.iter().filter_map(|&id| self.get_rule(id))
402 }
403
404 pub fn rules_by_head(&self, pred: PredId) -> impl Iterator<Item = &Rule> {
406 self.rules_by_head
407 .get(&pred)
408 .into_iter()
409 .flat_map(|ids| ids.iter())
410 .filter_map(|&id| self.get_rule(id))
411 }
412
413 pub fn rules_using(&self, pred: PredId) -> impl Iterator<Item = &Rule> {
415 self.rules_by_body
416 .get(&pred)
417 .into_iter()
418 .flat_map(|ids| ids.iter())
419 .filter_map(|&id| self.get_rule(id))
420 }
421
422 #[must_use]
424 pub fn num_predicates(&self) -> usize {
425 self.predicates.len()
426 }
427
428 #[must_use]
430 pub fn num_rules(&self) -> usize {
431 self.rules.len()
432 }
433
434 #[must_use]
436 pub fn is_empty(&self) -> bool {
437 self.rules.is_empty()
438 }
439
440 pub fn topological_order(&self) -> Option<Vec<PredId>> {
442 let mut in_degree: FxHashMap<PredId, usize> = FxHashMap::default();
443 let mut result = Vec::new();
444
445 for pred in &self.predicates {
447 in_degree.insert(pred.id, 0);
448 }
449
450 for rule in &self.rules {
452 if let Some(head_pred) = rule.head_predicate() {
453 for body_pred in rule.body_predicates() {
454 if body_pred != head_pred {
455 *in_degree.entry(head_pred).or_default() += 1;
456 }
457 }
458 }
459 }
460
461 let mut queue: Vec<PredId> = in_degree
463 .iter()
464 .filter(|&(_, deg)| *deg == 0)
465 .map(|(&id, _)| id)
466 .collect();
467
468 while let Some(pred) = queue.pop() {
469 result.push(pred);
470
471 for rule in self.rules_by_body.get(&pred).into_iter().flatten() {
472 if let Some(head_pred) = self.get_rule(*rule).and_then(|r| r.head_predicate())
473 && let Some(deg) = in_degree.get_mut(&head_pred)
474 {
475 *deg = deg.saturating_sub(1);
476 if *deg == 0 {
477 queue.push(head_pred);
478 }
479 }
480 }
481 }
482
483 if result.len() == self.predicates.len() {
484 Some(result)
485 } else {
486 None }
488 }
489}
490
491pub struct ChcBuilder<'a> {
493 system: ChcSystem,
494 terms: &'a mut TermManager,
495}
496
497impl<'a> ChcBuilder<'a> {
498 pub fn new(terms: &'a mut TermManager) -> Self {
500 Self {
501 system: ChcSystem::new(),
502 terms,
503 }
504 }
505
506 pub fn declare_pred(
508 &mut self,
509 name: impl Into<String>,
510 params: impl IntoIterator<Item = oxiz_core::SortId>,
511 ) -> PredId {
512 self.system.declare_predicate(name, params)
513 }
514
515 pub fn terms(&mut self) -> &mut TermManager {
517 self.terms
518 }
519
520 pub fn build(self) -> ChcSystem {
522 self.system
523 }
524
525 pub fn system_mut(&mut self) -> &mut ChcSystem {
527 &mut self.system
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Declared predicates are counted and expose their name and arity.
    #[test]
    fn test_chc_system_creation() {
        let terms = TermManager::new();
        let mut system = ChcSystem::new();

        let inv = system.declare_predicate("Inv", [terms.sorts.int_sort]);
        let err = system.declare_predicate("Err", []);

        assert_eq!(system.num_predicates(), 2);

        let inv_decl = system
            .get_predicate(inv)
            .expect("test operation should succeed");
        assert_eq!(inv_decl.name, "Inv");

        let err_decl = system
            .get_predicate(err)
            .expect("test operation should succeed");
        assert_eq!(err_decl.arity(), 0);
    }

    /// An init rule, a transition rule and a query are stored and classified.
    #[test]
    fn test_chc_rules() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        let inv = system.declare_predicate("Inv", [int_sort]);

        // Init rule: x = 0 => Inv(x)
        let x = terms.mk_var("x", int_sort);
        let zero = terms.mk_int(0);
        let init_constraint = terms.mk_eq(x, zero);
        system.add_init_rule([("x".to_string(), int_sort)], init_constraint, inv, [x]);

        // Transition rule: Inv(x) /\ x' = x + 1 => Inv(x')
        let x_prime = terms.mk_var("x'", int_sort);
        let one = terms.mk_int(1);
        let x_plus_one = terms.mk_add([x, one]);
        let trans_constraint = terms.mk_eq(x_prime, x_plus_one);
        let trans_vars = [("x".to_string(), int_sort), ("x'".to_string(), int_sort)];
        system.add_transition_rule(
            trans_vars,
            [PredicateApp::new(inv, [x])],
            trans_constraint,
            inv,
            [x_prime],
        );

        // Query: Inv(x) /\ x < 0
        let neg_constraint = terms.mk_lt(x, zero);
        system.add_query(
            [("x".to_string(), int_sort)],
            [PredicateApp::new(inv, [x])],
            neg_constraint,
        );

        assert_eq!(system.num_rules(), 3);
        assert_eq!(system.entries().count(), 1);
        assert_eq!(system.queries().count(), 1);
    }

    /// A rule P(x) => Q(x) is indexed under Q as head and under P as body.
    #[test]
    fn test_rule_indexing() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        let p = system.declare_predicate("P", [int_sort]);
        let q = system.declare_predicate("Q", [int_sort]);

        let x = terms.mk_var("x", int_sort);
        let constraint = terms.mk_true();

        system.add_transition_rule(
            [("x".to_string(), int_sort)],
            [PredicateApp::new(p, [x])],
            constraint,
            q,
            [x],
        );

        // Head index
        assert_eq!(system.rules_by_head(q).count(), 1);
        assert_eq!(system.rules_by_head(p).count(), 0);

        // Body index
        assert_eq!(system.rules_using(p).count(), 1);
        assert_eq!(system.rules_using(q).count(), 0);
    }

    /// For the chain P1 => P2 => P3 the order places P1 before P2 before P3.
    #[test]
    fn test_topological_order() {
        let mut terms = TermManager::new();
        let mut system = ChcSystem::new();
        let int_sort = terms.sorts.int_sort;

        let p1 = system.declare_predicate("P1", [int_sort]);
        let p2 = system.declare_predicate("P2", [int_sort]);
        let p3 = system.declare_predicate("P3", [int_sort]);

        let x = terms.mk_var("x", int_sort);
        let constraint = terms.mk_true();

        // P1(x) => P2(x)
        system.add_transition_rule(
            [("x".to_string(), int_sort)],
            [PredicateApp::new(p1, [x])],
            constraint,
            p2,
            [x],
        );
        // P2(x) => P3(x)
        system.add_transition_rule(
            [("x".to_string(), int_sort)],
            [PredicateApp::new(p2, [x])],
            constraint,
            p3,
            [x],
        );

        let order = system.topological_order();
        assert!(order.is_some());
        let order = order.expect("test operation should succeed");

        let index_of = |target: PredId| {
            order
                .iter()
                .position(|&id| id == target)
                .expect("element should be found")
        };
        let p1_pos = index_of(p1);
        let p2_pos = index_of(p2);
        let p3_pos = index_of(p3);

        assert!(p1_pos < p2_pos);
        assert!(p2_pos < p3_pos);
    }
}