1use crate::solver::{Solver, SolverResult};
4use oxiz_core::ast::{TermId, TermKind, TermManager};
5use oxiz_core::error::Result;
6use oxiz_core::smtlib::{Command, parse_script};
7use oxiz_core::sort::SortId;
8
9#[derive(Debug, Clone)]
11struct DeclaredConst {
12 term: TermId,
14 sort: SortId,
16 name: String,
18}
19
20#[derive(Debug, Clone)]
22struct DeclaredFun {
23 name: String,
25 arg_sorts: Vec<SortId>,
27 ret_sort: SortId,
29}
30
31#[derive(Debug)]
76pub struct Context {
77 pub terms: TermManager,
79 solver: Solver,
81 logic: Option<String>,
83 assertions: Vec<TermId>,
85 assertion_stack: Vec<usize>,
87 declared_consts: Vec<DeclaredConst>,
89 const_stack: Vec<usize>,
91 const_name_to_index: std::collections::HashMap<String, usize>,
93 declared_funs: Vec<DeclaredFun>,
95 fun_stack: Vec<usize>,
97 fun_name_to_index: std::collections::HashMap<String, usize>,
99 last_result: Option<SolverResult>,
101 options: std::collections::HashMap<String, String>,
103}
104
105impl Default for Context {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl Context {
112 #[must_use]
114 pub fn new() -> Self {
115 Self {
116 terms: TermManager::new(),
117 solver: Solver::new(),
118 logic: None,
119 assertions: Vec::new(),
120 assertion_stack: Vec::new(),
121 declared_consts: Vec::new(),
122 const_stack: Vec::new(),
123 const_name_to_index: std::collections::HashMap::new(),
124 declared_funs: Vec::new(),
125 fun_stack: Vec::new(),
126 fun_name_to_index: std::collections::HashMap::new(),
127 last_result: None,
128 options: std::collections::HashMap::new(),
129 }
130 }
131
132 pub fn declare_const(&mut self, name: &str, sort: SortId) -> TermId {
134 let term = self.terms.mk_var(name, sort);
135 let index = self.declared_consts.len();
136 self.declared_consts.push(DeclaredConst {
137 term,
138 sort,
139 name: name.to_string(),
140 });
141 self.const_name_to_index.insert(name.to_string(), index);
142 term
143 }
144
145 pub fn declare_fun(&mut self, name: &str, arg_sorts: Vec<SortId>, ret_sort: SortId) {
150 let index = self.declared_funs.len();
151 self.declared_funs.push(DeclaredFun {
152 name: name.to_string(),
153 arg_sorts,
154 ret_sort,
155 });
156 self.fun_name_to_index.insert(name.to_string(), index);
157 }
158
159 pub fn get_fun_signature(&self, name: &str) -> Option<(Vec<SortId>, SortId)> {
161 self.fun_name_to_index.get(name).and_then(|&idx| {
162 self.declared_funs
163 .get(idx)
164 .map(|f| (f.arg_sorts.clone(), f.ret_sort))
165 })
166 }
167
168 pub fn set_logic(&mut self, logic: &str) {
170 self.logic = Some(logic.to_string());
171 self.solver.set_logic(logic);
172 }
173
174 #[must_use]
176 pub fn logic(&self) -> Option<&str> {
177 self.logic.as_deref()
178 }
179
180 pub fn assert(&mut self, term: TermId) {
182 self.assertions.push(term);
183 self.solver.assert(term, &mut self.terms);
184 }
185
186 pub fn check_sat(&mut self) -> SolverResult {
188 let result = self.solver.check(&mut self.terms);
189 self.last_result = Some(result);
190 result
191 }
192
193 pub fn get_model(&self) -> Option<Vec<(String, String, String)>> {
196 if self.last_result != Some(SolverResult::Sat) {
197 return None;
198 }
199
200 let mut model = Vec::new();
201 let solver_model = self.solver.model()?;
202
203 for decl in &self.declared_consts {
204 let value = if let Some(val) = solver_model.get(decl.term) {
205 self.format_value(val)
206 } else {
207 self.default_value(decl.sort)
209 };
210 let sort_name = self.format_sort_name(decl.sort);
211 model.push((decl.name.clone(), sort_name, value));
212 }
213
214 Some(model)
215 }
216
217 fn format_sort_name(&self, sort: SortId) -> String {
219 if sort == self.terms.sorts.bool_sort {
220 "Bool".to_string()
221 } else if sort == self.terms.sorts.int_sort {
222 "Int".to_string()
223 } else if sort == self.terms.sorts.real_sort {
224 "Real".to_string()
225 } else if let Some(s) = self.terms.sorts.get(sort) {
226 if let Some(w) = s.bitvec_width() {
227 format!("(_ BitVec {})", w)
228 } else {
229 "Unknown".to_string()
230 }
231 } else {
232 "Unknown".to_string()
233 }
234 }
235
236 fn format_value(&self, term: TermId) -> String {
238 match self.terms.get(term).map(|t| &t.kind) {
239 Some(TermKind::True) => "true".to_string(),
240 Some(TermKind::False) => "false".to_string(),
241 Some(TermKind::IntConst(n)) => n.to_string(),
242 Some(TermKind::RealConst(r)) => {
243 if *r.denom() == 1 {
244 format!("{}.0", r.numer())
245 } else {
246 format!("(/ {} {})", r.numer(), r.denom())
247 }
248 }
249 Some(TermKind::BitVecConst { value, width }) => {
250 format!(
251 "#b{:0>width$}",
252 format!("{:b}", value),
253 width = *width as usize
254 )
255 }
256 _ => "?".to_string(),
257 }
258 }
259
260 fn default_value(&self, sort: SortId) -> String {
262 if sort == self.terms.sorts.bool_sort {
263 "false".to_string()
264 } else if sort == self.terms.sorts.int_sort {
265 "0".to_string()
266 } else if sort == self.terms.sorts.real_sort {
267 "0.0".to_string()
268 } else if let Some(s) = self.terms.sorts.get(sort) {
269 if let Some(w) = s.bitvec_width() {
270 format!("#b{:0>width$}", "0", width = w as usize)
271 } else {
272 "?".to_string()
273 }
274 } else {
275 "?".to_string()
276 }
277 }
278
279 pub fn format_model(&self) -> String {
281 match self.get_model() {
282 None => "(error \"No model available\")".to_string(),
283 Some(model) if model.is_empty() => "(model)".to_string(),
284 Some(model) => {
285 let mut lines = vec!["(model".to_string()];
286 for (name, sort, value) in model {
287 lines.push(format!(" (define-fun {} () {} {})", name, sort, value));
288 }
289 lines.push(")".to_string());
290 lines.join("\n")
291 }
292 }
293 }
294
295 pub fn push(&mut self) {
297 self.assertion_stack.push(self.assertions.len());
298 self.const_stack.push(self.declared_consts.len());
299 self.fun_stack.push(self.declared_funs.len());
300 self.solver.push();
301 }
302
303 pub fn pop(&mut self) {
305 if let Some(len) = self.assertion_stack.pop() {
306 self.assertions.truncate(len);
307 if let Some(const_len) = self.const_stack.pop() {
308 while self.declared_consts.len() > const_len {
310 if let Some(decl) = self.declared_consts.pop() {
311 self.const_name_to_index.remove(&decl.name);
312 }
313 }
314 }
315 if let Some(fun_len) = self.fun_stack.pop() {
316 while self.declared_funs.len() > fun_len {
318 if let Some(decl) = self.declared_funs.pop() {
319 self.fun_name_to_index.remove(&decl.name);
320 }
321 }
322 }
323 self.solver.pop();
324 }
325 }
326
327 pub fn reset(&mut self) {
329 self.solver.reset();
330 self.assertions.clear();
331 self.assertion_stack.clear();
332 self.declared_consts.clear();
333 self.const_stack.clear();
334 self.const_name_to_index.clear();
335 self.declared_funs.clear();
336 self.fun_stack.clear();
337 self.fun_name_to_index.clear();
338 self.logic = None;
339 self.last_result = None;
340 self.options.clear();
341 }
342
343 pub fn reset_assertions(&mut self) {
345 self.solver.reset();
346 self.assertions.clear();
347 self.assertion_stack.clear();
348 self.last_result = None;
352 }
353
354 #[must_use]
356 pub fn get_assertions(&self) -> &[TermId] {
357 &self.assertions
358 }
359
360 pub fn format_assertions(&self) -> String {
362 if self.assertions.is_empty() {
363 return "()".to_string();
364 }
365 let printer = oxiz_core::smtlib::Printer::new(&self.terms);
366 let mut parts = Vec::new();
367 for &term in &self.assertions {
368 parts.push(printer.print_term(term));
369 }
370 format!("({})", parts.join("\n "))
371 }
372
373 pub fn set_option(&mut self, key: &str, value: &str) {
375 self.options.insert(key.to_string(), value.to_string());
376
377 match key {
379 "produce-proofs" => {
380 let mut config = self.solver.config().clone();
381 config.proof = value == "true";
382 self.solver.set_config(config);
383 }
384 "produce-unsat-cores" => {
385 self.solver.set_produce_unsat_cores(value == "true");
386 }
387 _ => {}
388 }
389 }
390
391 #[must_use]
393 pub fn get_option(&self, key: &str) -> Option<&str> {
394 self.options.get(key).map(String::as_str)
395 }
396
397 fn format_option(&self, key: &str) -> String {
399 match self.get_option(key) {
400 Some(val) => val.to_string(),
401 None => {
402 match key {
404 "produce-models" => "false".to_string(),
405 "produce-unsat-cores" => "false".to_string(),
406 "produce-proofs" => "false".to_string(),
407 "produce-assignments" => "false".to_string(),
408 "print-success" => "true".to_string(),
409 _ => "unsupported".to_string(),
410 }
411 }
412 }
413 }
414
415 pub fn get_assignment(&self) -> String {
418 "()".to_string()
419 }
420
421 pub fn get_proof(&self) -> String {
423 if self.last_result != Some(SolverResult::Unsat) {
424 return "(error \"Proof is only available after unsat result\")".to_string();
425 }
426
427 match self.solver.get_proof() {
428 Some(proof) => proof.format(),
429 None => {
430 "(error \"Proof generation not enabled. Set :produce-proofs to true\")".to_string()
431 }
432 }
433 }
434
435 pub fn get_statistics(&self) -> String {
438 let stats = self.solver.get_statistics();
439 format!(
440 "(:decisions {} :conflicts {} :propagations {} :restarts {} :learned-clauses {} :theory-propagations {} :theory-conflicts {})",
441 stats.decisions,
442 stats.conflicts,
443 stats.propagations,
444 stats.restarts,
445 stats.learned_clauses,
446 stats.theory_propagations,
447 stats.theory_conflicts
448 )
449 }
450
451 fn parse_sort_name(&mut self, name: &str) -> SortId {
453 match name {
454 "Bool" => self.terms.sorts.bool_sort,
455 "Int" => self.terms.sorts.int_sort,
456 "Real" => self.terms.sorts.real_sort,
457 _ => {
458 if let Some(width_str) = name.strip_prefix("BitVec")
460 && let Ok(width) = width_str.trim().parse::<u32>()
461 {
462 return self.terms.sorts.bitvec(width);
463 }
464 self.terms.sorts.bool_sort
466 }
467 }
468 }
469
470 pub fn execute_script(&mut self, script: &str) -> Result<Vec<String>> {
472 let commands = parse_script(script, &mut self.terms)?;
473 let mut output = Vec::new();
474
475 for cmd in commands {
476 match cmd {
477 Command::SetLogic(logic) => {
478 self.set_logic(&logic);
479 }
480 Command::DeclareConst(name, sort_name) => {
481 let sort = self.parse_sort_name(&sort_name);
482 self.declare_const(&name, sort);
483 }
484 Command::DeclareFun(name, arg_sorts, ret_sort) => {
485 if arg_sorts.is_empty() {
487 let sort = self.parse_sort_name(&ret_sort);
488 self.declare_const(&name, sort);
489 } else {
490 let parsed_arg_sorts: Vec<SortId> =
492 arg_sorts.iter().map(|s| self.parse_sort_name(s)).collect();
493 let parsed_ret_sort = self.parse_sort_name(&ret_sort);
494 self.declare_fun(&name, parsed_arg_sorts, parsed_ret_sort);
495 }
496 }
497 Command::Assert(term) => {
498 self.assert(term);
499 }
500 Command::CheckSat => {
501 let result = self.check_sat();
502 output.push(match result {
503 SolverResult::Sat => "sat".to_string(),
504 SolverResult::Unsat => "unsat".to_string(),
505 SolverResult::Unknown => "unknown".to_string(),
506 });
507 }
508 Command::Push(n) => {
509 for _ in 0..n {
510 self.push();
511 }
512 }
513 Command::Pop(n) => {
514 for _ in 0..n {
515 self.pop();
516 }
517 }
518 Command::Reset => {
519 self.reset();
520 }
521 Command::ResetAssertions => {
522 self.reset_assertions();
523 }
524 Command::Exit => {
525 break;
526 }
527 Command::Echo(msg) => {
528 output.push(msg);
529 }
530 Command::GetModel => {
531 output.push(self.format_model());
532 }
533 Command::GetAssertions => {
534 output.push(self.format_assertions());
535 }
536 Command::GetAssignment => {
537 output.push(self.get_assignment());
538 }
539 Command::GetProof => {
540 output.push(self.get_proof());
541 }
542 Command::GetOption(key) => {
543 output.push(self.format_option(&key));
544 }
545 Command::SetOption(key, value) => {
546 self.set_option(&key, &value);
547 }
548 Command::CheckSatAssuming(assumptions) => {
549 self.push();
551 for assumption in assumptions {
552 self.assert(assumption);
553 }
554 let result = self.check_sat();
555 self.pop();
556 output.push(match result {
557 SolverResult::Sat => "sat".to_string(),
558 SolverResult::Unsat => "unsat".to_string(),
559 SolverResult::Unknown => "unknown".to_string(),
560 });
561 }
562 Command::Simplify(term) => {
563 let simplified = self.terms.simplify(term);
565 let printer = oxiz_core::smtlib::Printer::new(&self.terms);
566 output.push(printer.print_term(simplified));
567 }
568 Command::GetUnsatCore => {
569 if let Some(core) = self.solver.get_unsat_core() {
570 if core.names.is_empty() {
571 output.push("()".to_string());
572 } else {
573 output.push(format!("({})", core.names.join(" ")));
574 }
575 } else {
576 output.push("(error \"No unsat core available\")".to_string());
577 }
578 }
579 Command::GetValue(terms) => {
580 if self.last_result != Some(SolverResult::Sat) {
581 output.push("(error \"No model available\")".to_string());
582 } else if let Some(model) = self.solver.model() {
583 let mut values = Vec::new();
584 for term in terms {
585 let value = model.eval(term, &mut self.terms);
587 let printer = oxiz_core::smtlib::Printer::new(&self.terms);
589 let term_str = printer.print_term(term);
590 let value_str = printer.print_term(value);
591 values.push(format!("({} {})", term_str, value_str));
592 }
593 output.push(format!("({})", values.join("\n ")));
594 } else {
595 output.push("(error \"No model available\")".to_string());
596 }
597 }
598 Command::GetInfo(keyword) => {
599 if keyword == ":all-statistics" {
601 output.push(self.get_statistics());
602 } else {
603 output.push(format!("(error \"Unsupported info keyword: {}\")", keyword));
604 }
605 }
606 Command::SetInfo(_, _)
607 | Command::DeclareSort(_, _)
608 | Command::DefineSort(_, _, _)
609 | Command::DefineFun(_, _, _, _)
610 | Command::DeclareDatatype { .. } => {
611 }
613 }
614 }
615
616 Ok(output)
617 }
618
619 #[must_use]
621 pub fn stats(&self) -> &oxiz_sat::SolverStats {
622 self.solver.stats()
623 }
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes `script` in a brand-new context and returns its responses.
    fn run(script: &str) -> Vec<String> {
        let mut ctx = Context::new();
        ctx.execute_script(script).unwrap()
    }

    #[test]
    fn test_context_basic() {
        let mut ctx = Context::new();
        ctx.set_logic("QF_UF");
        assert_eq!(ctx.logic(), Some("QF_UF"));

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        assert_eq!(ctx.check_sat(), SolverResult::Sat);
    }

    #[test]
    fn test_context_push_pop() {
        let mut ctx = Context::new();
        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        ctx.push();

        let falsity = ctx.terms.mk_false();
        ctx.assert(falsity);
        // `true` together with `false` is contradictory.
        assert_eq!(ctx.check_sat(), SolverResult::Unsat);

        // Popping discards `false`, leaving only `true`.
        ctx.pop();
        assert_eq!(ctx.check_sat(), SolverResult::Sat);
    }

    #[test]
    fn test_execute_script() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (check-sat)
        "#);
        assert_eq!(output, vec!["sat"]);
    }

    #[test]
    fn test_declare_const() {
        let mut ctx = Context::new();
        let (bool_sort, int_sort) = (ctx.terms.sorts.bool_sort, ctx.terms.sorts.int_sort);
        ctx.declare_const("x", bool_sort);
        ctx.declare_const("y", int_sort);

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        assert_eq!(ctx.check_sat(), SolverResult::Sat);

        // Both declared constants should be reported.
        let model = ctx.get_model().expect("a sat answer should yield a model");
        assert_eq!(model.len(), 2);
    }

    #[test]
    fn test_format_model() {
        let mut ctx = Context::new();
        let bool_sort = ctx.terms.sorts.bool_sort;
        ctx.declare_const("p", bool_sort);

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        ctx.check_sat();

        let rendered = ctx.format_model();
        assert!(rendered.contains("(model"));
        assert!(rendered.contains("define-fun p () Bool"));
    }

    #[test]
    fn test_get_model_script() {
        let output = run(r#"
            (set-logic QF_LIA)
            (declare-const x Int)
            (declare-const y Bool)
            (assert true)
            (check-sat)
            (get-model)
        "#);
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "sat");
        let model = &output[1];
        assert!(model.contains("(model"));
        assert!(model.contains("Int"));
        assert!(model.contains("Bool"));
    }

    #[test]
    fn test_push_pop_consts() {
        let mut ctx = Context::new();
        let bool_sort = ctx.terms.sorts.bool_sort;
        ctx.declare_const("a", bool_sort);
        ctx.push();
        ctx.declare_const("b", bool_sort);

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        ctx.check_sat();
        assert_eq!(ctx.get_model().unwrap().len(), 2);

        // `b` belonged to the popped level, so only `a` remains.
        ctx.pop();
        ctx.check_sat();
        let model = ctx.get_model().unwrap();
        assert_eq!(model.len(), 1);
        assert_eq!(model[0].0, "a");
    }

    #[test]
    fn test_get_assertions() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (assert (not p))
            (get-assertions)
        "#);
        assert_eq!(output.len(), 1);
        assert!(output[0].starts_with('('));
        assert!(output[0].contains("p"));
    }

    #[test]
    fn test_check_sat_assuming_script() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (declare-const q Bool)
            (assert p)
            (check-sat-assuming (q))
        "#);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "sat");
    }

    #[test]
    fn test_get_option_script() {
        let output = run(r#"
            (set-option :produce-models true)
            (get-option :produce-models)
        "#);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "true");
    }

    #[test]
    fn test_reset_assertions() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (reset-assertions)
            (get-assertions)
            (check-sat)
        "#);
        assert_eq!(output.len(), 2);
        // No assertions survive the reset, and the empty problem is satisfiable.
        assert_eq!(output[0], "()");
        assert_eq!(output[1], "sat");
    }

    #[test]
    fn test_simplify_command() {
        let output = run(r#"
            (simplify (+ 1 2))
        "#);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "3");
    }

    #[test]
    fn test_simplify_complex() {
        let output = run(r#"
            (simplify (* 2 3 4))
        "#);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "24");
    }

    #[test]
    fn test_get_value() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (declare-const q Bool)
            (assert p)
            (assert (not q))
            (check-sat)
            (get-value (p q (and p q) (or p q)))
        "#);
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "sat");

        let values = &output[1];
        for needle in ["p", "q", "true", "false"] {
            assert!(values.contains(needle));
        }
    }

    #[test]
    fn test_get_value_no_model() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (get-value (p))
        "#);
        assert_eq!(output.len(), 1);
        assert!(output[0].contains("error") || output[0].contains("No model"));
    }

    #[test]
    fn test_get_value_after_unsat() {
        let output = run(r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (assert (not p))
            (check-sat)
            (get-value (p))
        "#);
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "unsat");
        assert!(output[1].contains("error") || output[1].contains("No model"));
    }
}