Skip to main content

oxiz_solver/
context.rs

1//! Solver context
2
3#[allow(unused_imports)]
4use crate::prelude::*;
5use crate::solver::{Solver, SolverResult};
6use oxiz_core::ast::{TermId, TermKind, TermManager};
7#[cfg(feature = "std")]
8use oxiz_core::error::Result;
9#[cfg(feature = "std")]
10use oxiz_core::smtlib::{Command, parse_script};
11use oxiz_core::sort::SortId;
12#[cfg(feature = "std")]
13use std::path::{Path, PathBuf};
14
/// Raw function interpretation: a list of `(arg_strings, value_string)` entries
/// together with an `else_value` string and the function arity.
///
/// Tuple layout, in order:
///
/// 1. `Vec<(Vec<String>, String)>` — the explicit entries; each pairs the
///    formatted argument values with the formatted result value.
/// 2. `String` — the `else_value` used for every argument tuple not listed.
/// 3. `usize` — the arity (number of arguments) of the function.
///
/// Used as the return type of [`Context::get_func_interp_raw`] to avoid pulling
/// `oxiz_core::model` types into the public API of this file.
pub type RawFuncInterp = (Vec<(Vec<String>, String)>, String, usize);
21
/// A declared constant
///
/// One record is appended to `Context::declared_consts` for every call to
/// `Context::declare_const`; the records are later used to enumerate the
/// model in `Context::get_model`.
#[derive(Debug, Clone)]
struct DeclaredConst {
    /// The term ID for this constant (the variable term created on declaration)
    term: TermId,
    /// The sort of this constant
    sort: SortId,
    /// The name of this constant, exactly as passed to `declare_const`
    name: String,
}
32
/// A declared function
///
/// One record is appended to `Context::declared_funs` for every call to
/// `Context::declare_fun`.
#[derive(Debug, Clone)]
struct DeclaredFun {
    /// The function name, exactly as passed to `declare_fun`
    name: String,
    /// Argument sorts, in positional order; the length is the function arity
    arg_sorts: Vec<SortId>,
    /// Return sort
    ret_sort: SortId,
}
43
/// Solver context for managing the solving process
///
/// The `Context` provides a high-level API for SMT solving, similar to
/// the SMT-LIB2 standard. It manages declarations, assertions, and solver state.
///
/// # Examples
///
/// ## Basic Usage
///
/// ```
/// use oxiz_solver::Context;
///
/// let mut ctx = Context::new();
/// ctx.set_logic("QF_UF");
///
/// // Declare boolean constants
/// let p = ctx.declare_const("p", ctx.terms.sorts.bool_sort);
/// let q = ctx.declare_const("q", ctx.terms.sorts.bool_sort);
///
/// // Assert p AND q
/// let formula = ctx.terms.mk_and(vec![p, q]);
/// ctx.assert(formula);
///
/// // Check satisfiability
/// ctx.check_sat();
/// ```
///
/// ## SMT-LIB2 Script Execution
///
/// ```
/// use oxiz_solver::Context;
///
/// let mut ctx = Context::new();
///
/// let script = r#"
/// (set-logic QF_LIA)
/// (declare-const x Int)
/// (assert (>= x 0))
/// (assert (<= x 10))
/// (check-sat)
/// "#;
///
/// let _ = ctx.execute_script(script);
/// ```
#[derive(Debug)]
pub struct Context {
    /// Term manager (public so callers can build terms and look up sorts)
    pub terms: TermManager,
    /// Solver instance
    solver: Solver,
    /// Current logic, as last passed to `set_logic`; `None` until one is set
    logic: Option<String>,
    /// Assertions, in the order they were added
    assertions: Vec<TermId>,
    /// Assertion stack for push/pop: `assertions.len()` recorded at each `push`
    assertion_stack: Vec<usize>,
    /// Declared constants, in declaration order
    declared_consts: Vec<DeclaredConst>,
    /// Declared constants stack for push/pop: `declared_consts.len()` recorded
    /// at each `push`
    const_stack: Vec<usize>,
    /// Mapping from constant names to indices (for efficient removal)
    const_name_to_index: crate::prelude::HashMap<String, usize>,
    /// Declared functions, in declaration order
    declared_funs: Vec<DeclaredFun>,
    /// Declared functions stack for push/pop: `declared_funs.len()` recorded
    /// at each `push`
    fun_stack: Vec<usize>,
    /// Mapping from function names to indices into `declared_funs`
    fun_name_to_index: crate::prelude::HashMap<String, usize>,
    /// Last check-sat result; `None` before the first check and after a reset
    last_result: Option<SolverResult>,
    /// Options, stored verbatim as key/value strings
    options: crate::prelude::HashMap<String, String>,
    /// Optional path for binary proof logging.
    ///
    /// When set, `check_sat` creates a `ProofLogger` at this path, records
    /// proof steps derived from the solver result, and flushes/closes the log
    /// before returning.
    #[cfg(feature = "std")]
    proof_log_path: Option<PathBuf>,
}
124
125impl Default for Context {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl Context {
132    /// Create a new context
133    #[must_use]
134    pub fn new() -> Self {
135        Self {
136            terms: TermManager::new(),
137            solver: Solver::new(),
138            logic: None,
139            assertions: Vec::new(),
140            assertion_stack: Vec::new(),
141            declared_consts: Vec::new(),
142            const_stack: Vec::new(),
143            const_name_to_index: crate::prelude::HashMap::new(),
144            declared_funs: Vec::new(),
145            fun_stack: Vec::new(),
146            fun_name_to_index: crate::prelude::HashMap::new(),
147            last_result: None,
148            options: crate::prelude::HashMap::new(),
149            #[cfg(feature = "std")]
150            proof_log_path: None,
151        }
152    }
153
    /// Configure a path for binary proof logging.
    ///
    /// When a path is configured, every subsequent call to `check_sat` opens a
    /// [`oxiz_proof::logging::ProofLogger`] at that path, writes a structural
    /// summary of the proof, and flushes/closes the log before returning.
    /// Pass `None` to disable proof logging.
    ///
    /// This only stores the path; no file is touched until the next
    /// `check_sat`.
    #[cfg(feature = "std")]
    pub fn set_proof_log_path(&mut self, path: Option<PathBuf>) {
        self.proof_log_path = path;
    }
164
165    /// Return the currently configured proof log path, if any.
166    #[cfg(feature = "std")]
167    #[must_use]
168    pub fn proof_log_path(&self) -> Option<&Path> {
169        self.proof_log_path.as_deref()
170    }
171
    /// Verify a binary proof log produced by a previous `check_sat` call with
    /// proof logging enabled.
    ///
    /// Delegates to [`oxiz_proof::replay::ProofReplayer::replay_from_file`].
    /// This is an associated function: it does not need (or inspect) a
    /// `Context` instance, only the path of the log file.
    ///
    /// # Errors
    ///
    /// Returns `Err` only for hard I/O or binary-format failures; logical
    /// invalidity is encoded as `Ok(VerificationResult::Invalid(_))`.
    #[cfg(feature = "std")]
    pub fn verify_proof_log(
        path: &Path,
    ) -> std::result::Result<oxiz_proof::replay::VerificationResult, oxiz_proof::replay::ProofError>
    {
        oxiz_proof::replay::ProofReplayer::replay_from_file(path)
    }
188
189    /// Declare a constant
190    pub fn declare_const(&mut self, name: &str, sort: SortId) -> TermId {
191        let term = self.terms.mk_var(name, sort);
192        let index = self.declared_consts.len();
193        self.declared_consts.push(DeclaredConst {
194            term,
195            sort,
196            name: name.to_string(),
197        });
198        self.const_name_to_index.insert(name.to_string(), index);
199        term
200    }
201
202    /// Declare a function
203    ///
204    /// Registers a function signature in the context. For nullary functions (constants),
205    /// use `declare_const` instead.
206    pub fn declare_fun(&mut self, name: &str, arg_sorts: Vec<SortId>, ret_sort: SortId) {
207        let index = self.declared_funs.len();
208        self.declared_funs.push(DeclaredFun {
209            name: name.to_string(),
210            arg_sorts,
211            ret_sort,
212        });
213        self.fun_name_to_index.insert(name.to_string(), index);
214    }
215
216    /// Get function signature if it exists
217    pub fn get_fun_signature(&self, name: &str) -> Option<(Vec<SortId>, SortId)> {
218        self.fun_name_to_index.get(name).and_then(|&idx| {
219            self.declared_funs
220                .get(idx)
221                .map(|f| (f.arg_sorts.clone(), f.ret_sort))
222        })
223    }
224
225    /// Iterate over the names of all currently declared uninterpreted functions.
226    pub fn declared_function_names(&self) -> impl Iterator<Item = &str> {
227        self.declared_funs.iter().map(|d| d.name.as_str())
228    }
229
230    /// Set the logic
231    pub fn set_logic(&mut self, logic: &str) {
232        self.logic = Some(logic.to_string());
233        self.solver.set_logic(logic);
234    }
235
236    /// Get the current logic
237    #[must_use]
238    pub fn logic(&self) -> Option<&str> {
239        self.logic.as_deref()
240    }
241
    /// Add an assertion
    ///
    /// The term is appended to the context's own assertion list (so it can be
    /// reported by `get_assertions` and truncated on `pop`) and is also
    /// forwarded to the underlying solver.
    ///
    /// NOTE(review): `last_result` is not cleared here, so a model or proof
    /// queried after a new assertion still reflects the previous `check_sat`
    /// — confirm this is intended.
    pub fn assert(&mut self, term: TermId) {
        // Record first, then hand to the solver.
        self.assertions.push(term);
        self.solver.assert(term, &mut self.terms);
    }
247
248    /// Check satisfiability
249    pub fn check_sat(&mut self) -> SolverResult {
250        let result = self.solver.check(&mut self.terms);
251        self.last_result = Some(result);
252
253        // Write a binary proof log if a path is configured (std-only).
254        #[cfg(feature = "std")]
255        if let Some(ref path) = self.proof_log_path.clone() {
256            if let Err(e) = self.write_proof_log(path, result) {
257                // Non-fatal: warn but do not abort the solve.
258                #[cfg(feature = "tracing")]
259                tracing::warn!("proof log write failed for {:?}: {}", path, e);
260                let _ = e;
261            }
262        }
263
264        result
265    }
266
267    /// Serialise a proof log entry for the given result.
268    ///
269    /// For `Unsat`, resolution proof steps are emitted when available;
270    /// for `Sat` and `Unknown`, a single axiom node is written so the log is
271    /// never empty and can be cleanly replayed.
272    #[cfg(feature = "std")]
273    fn write_proof_log(
274        &self,
275        path: &Path,
276        result: SolverResult,
277    ) -> std::result::Result<(), oxiz_proof::logging::LoggingError> {
278        use oxiz_proof::logging::ProofLogger;
279        use oxiz_proof::proof::{ProofNodeId, ProofStep};
280        use smallvec::SmallVec;
281
282        let mut logger = ProofLogger::create(path)?;
283
284        match result {
285            SolverResult::Unsat => {
286                if let Some(proof) = self.solver.get_proof() {
287                    let mut counter: u32 = 0;
288                    for step in proof.steps() {
289                        let entry = match step {
290                            crate::solver::ProofStep::Input { index, .. } => ProofStep::Axiom {
291                                conclusion: format!("input-clause-{}", index),
292                            },
293                            crate::solver::ProofStep::Resolution {
294                                index,
295                                left,
296                                right,
297                                pivot,
298                                ..
299                            } => {
300                                let mut premises: SmallVec<[ProofNodeId; 4]> = SmallVec::new();
301                                premises.push(ProofNodeId(*left));
302                                premises.push(ProofNodeId(*right));
303                                let mut args: SmallVec<[String; 2]> = SmallVec::new();
304                                args.push(format!("{:?}", pivot));
305                                ProofStep::Inference {
306                                    rule: "resolution".to_string(),
307                                    premises,
308                                    conclusion: format!("resolution-{}", index),
309                                    args,
310                                }
311                            }
312                            crate::solver::ProofStep::TheoryLemma { index, theory, .. } => {
313                                ProofStep::Axiom {
314                                    conclusion: format!("theory-lemma-{}-{}", theory, index),
315                                }
316                            }
317                        };
318                        logger.log_step(ProofNodeId(counter), &entry)?;
319                        counter += 1;
320                    }
321                    if counter == 0 {
322                        // Proof object present but empty — emit minimal witness.
323                        logger.log_step(
324                            ProofNodeId(0),
325                            &ProofStep::Axiom {
326                                conclusion: "unsat".to_string(),
327                            },
328                        )?;
329                    }
330                } else {
331                    logger.log_step(
332                        ProofNodeId(0),
333                        &ProofStep::Axiom {
334                            conclusion: "unsat".to_string(),
335                        },
336                    )?;
337                }
338            }
339            SolverResult::Sat => {
340                logger.log_step(
341                    ProofNodeId(0),
342                    &ProofStep::Axiom {
343                        conclusion: "sat".to_string(),
344                    },
345                )?;
346            }
347            SolverResult::Unknown => {
348                logger.log_step(
349                    ProofNodeId(0),
350                    &ProofStep::Axiom {
351                        conclusion: "unknown".to_string(),
352                    },
353                )?;
354            }
355        }
356
357        logger.flush()?;
358        logger.close()
359    }
360
361    /// Get the model (if SAT)
362    /// Returns a list of (name, sort, value) tuples
363    pub fn get_model(&self) -> Option<Vec<(String, String, String)>> {
364        if self.last_result != Some(SolverResult::Sat) {
365            return None;
366        }
367
368        let mut model = Vec::new();
369        let solver_model = self.solver.model()?;
370
371        for decl in &self.declared_consts {
372            let value = if let Some(val) = solver_model.get(decl.term) {
373                self.format_value(val)
374            } else {
375                // Default value based on sort
376                self.default_value(decl.sort)
377            };
378            let sort_name = self.format_sort_name(decl.sort);
379            model.push((decl.name.clone(), sort_name, value));
380        }
381
382        Some(model)
383    }
384
385    /// Build a raw function interpretation for a declared uninterpreted function.
386    ///
387    /// Derives entries from the EUF congruence closure rather than from raw
388    /// `Apply` terms alone.  For every application `f(a1, …, an)` interned in the
389    /// E-graph, the arguments and the result are canonicalized through their
390    /// equivalence-class representatives, so:
391    ///
392    /// - Two applications whose arguments are pairwise congruent (e.g. `f(a)` and
393    ///   `f(b)` when `a = b` is implied by the assertions) collapse to a **single**
394    ///   entry keyed by the shared argument class.
395    /// - The reported argument and result strings are **model values** taken from
396    ///   the class (resolving through the representative), not raw term ids.
397    /// - When an application has no direct model value, the value of any congruent
398    ///   member of its class is used.
399    ///
400    /// `else_value` is chosen as the most frequently occurring entry value (ties
401    /// broken by first occurrence), mirroring how Z3 selects a default; if there
402    /// are no entries it falls back to the return sort's default value.
403    ///
404    /// Returns `None` when:
405    /// - the last check was not `Sat`, or
406    /// - no model is available, or
407    /// - `func_name` is not a declared function.
408    ///
409    /// The return type is `(entries, else_value_string, arity)` to avoid
410    /// pulling `oxiz_core::model` types into this file.
411    pub fn get_func_interp_raw(&self, func_name: &str) -> Option<RawFuncInterp> {
412        if self.last_result != Some(SolverResult::Sat) {
413            return None;
414        }
415        let solver_model = self.solver.model()?;
416
417        // Find the declared function so we know its arity and default sort.
418        let decl = self.declared_funs.iter().find(|d| d.name == func_name)?;
419        let arity = decl.arg_sorts.len();
420        let default_else = self.default_value(decl.ret_sort);
421
422        // Resolve `func_name` to the EUF function-symbol id.  For an `Apply`
423        // term the EUF id is the underlying value of the function-name `Spur`,
424        // so we recover it from any matching application term (read-only — no
425        // mutable interner access required).
426        let mut func_id: Option<u32> = None;
427        for idx in 0..(self.terms.len() as u32) {
428            let tid = TermId(idx);
429            let Some(term) = self.terms.get(tid) else {
430                continue;
431            };
432            if let TermKind::Apply {
433                func: func_spur, ..
434            } = &term.kind
435                && self.terms.resolve_str(*func_spur) == func_name
436            {
437                func_id = Some(func_spur.into_inner().get());
438                break;
439            }
440        }
441
442        // No application of this function exists in the E-graph: the function is
443        // declared but never applied, so its interpretation is purely the default.
444        let Some(func_id) = func_id else {
445            return Some((Vec::new(), default_else, arity));
446        };
447
448        // Pull congruence-closed application entries from the EUF solver.  Each
449        // entry already has its argument and result classes canonicalized, so
450        // congruence (e.g. f(a) == f(b) when a == b) is applied for us.
451        let euf_entries = self.solver.euf_function_entries(func_id);
452
453        // Deduplicate on the canonical argument-class representative tuple so
454        // congruent applications produce exactly one entry.  Because congruence
455        // forces congruent applications into the same result class, the values
456        // agree in a consistent model.
457        let mut seen_arg_keys: crate::prelude::HashSet<smallvec::SmallVec<[u32; 4]>> =
458            crate::prelude::HashSet::new();
459        let mut entries: Vec<(Vec<String>, String)> = Vec::new();
460        for entry in &euf_entries {
461            // Resolve the result value first: skip applications whose class has
462            // no concrete model value (an unconstrained application contributes
463            // nothing observable beyond the else-branch).
464            let Some(val_str) = self.class_value_string(&entry.result_class_terms, solver_model)
465            else {
466                continue;
467            };
468
469            if !seen_arg_keys.insert(entry.arg_reps.clone()) {
470                continue; // already emitted this congruence class of arguments
471            }
472
473            // Resolve each argument to its canonical model value.  Falls back to
474            // the default value for the corresponding argument sort when the
475            // class carries no concrete value (rare: an unconstrained argument).
476            let arg_strs: Vec<String> = entry
477                .arg_class_terms
478                .iter()
479                .enumerate()
480                .map(|(i, members)| {
481                    self.class_value_string(members, solver_model)
482                        .unwrap_or_else(|| {
483                            decl.arg_sorts
484                                .get(i)
485                                .map_or_else(|| "?".to_string(), |&s| self.default_value(s))
486                        })
487                })
488                .collect();
489            entries.push((arg_strs, val_str));
490        }
491
492        // Pick `else_value`: the most common entry value (ties → first seen),
493        // matching Z3's habit of reusing an existing value as the default.
494        let else_value = Self::most_common_value(&entries).unwrap_or(default_else);
495
496        Some((entries, else_value, arity))
497    }
498
499    /// Resolve an equivalence class (its member `TermId`s) to a formatted model
500    /// value string, by finding the first member that carries either a direct
501    /// model assignment or is itself a literal constant.
502    ///
503    /// Returns `None` when no member of the class has an observable value.
504    fn class_value_string(
505        &self,
506        members: &[TermId],
507        solver_model: &crate::solver::Model,
508    ) -> Option<String> {
509        for &member in members {
510            // Direct model assignment (covers variables and applications whose
511            // value was extracted from an equality constraint).
512            if let Some(val_term) = solver_model.get(member) {
513                return Some(self.format_value(val_term));
514            }
515            // The member may itself be a literal constant (e.g. the term `5` in
516            // `f(a) = 5`), which has no separate model entry but is its own value.
517            if let Some(term) = self.terms.get(member)
518                && matches!(
519                    term.kind,
520                    TermKind::True
521                        | TermKind::False
522                        | TermKind::IntConst(_)
523                        | TermKind::RealConst(_)
524                        | TermKind::BitVecConst { .. }
525                )
526            {
527                return Some(self.format_value(member));
528            }
529        }
530        None
531    }
532
533    /// Choose the most frequently occurring value among the interpretation
534    /// entries, breaking ties in favour of the earliest occurrence.  Returns
535    /// `None` for an empty entry list.
536    fn most_common_value(entries: &[(Vec<String>, String)]) -> Option<String> {
537        let mut counts: crate::prelude::HashMap<&str, (usize, usize)> =
538            crate::prelude::HashMap::new();
539        for (order, (_, value)) in entries.iter().enumerate() {
540            let slot = counts.entry(value.as_str()).or_insert((0, order));
541            slot.0 += 1;
542        }
543        counts
544            .into_iter()
545            .max_by(|(_, (count_a, order_a)), (_, (count_b, order_b))| {
546                // Higher count wins; on a tie the smaller insertion order wins,
547                // so we reverse the order comparison.
548                count_a.cmp(count_b).then_with(|| order_b.cmp(order_a))
549            })
550            .map(|(value, _)| value.to_string())
551    }
552
553    /// Format a sort ID to its SMT-LIB2 name
554    fn format_sort_name(&self, sort: SortId) -> String {
555        if sort == self.terms.sorts.bool_sort {
556            "Bool".to_string()
557        } else if sort == self.terms.sorts.int_sort {
558            "Int".to_string()
559        } else if sort == self.terms.sorts.real_sort {
560            "Real".to_string()
561        } else if let Some(s) = self.terms.sorts.get(sort) {
562            if let Some(w) = s.bitvec_width() {
563                format!("(_ BitVec {})", w)
564            } else {
565                "Unknown".to_string()
566            }
567        } else {
568            "Unknown".to_string()
569        }
570    }
571
572    /// Format a model value
573    fn format_value(&self, term: TermId) -> String {
574        match self.terms.get(term).map(|t| &t.kind) {
575            Some(TermKind::True) => "true".to_string(),
576            Some(TermKind::False) => "false".to_string(),
577            Some(TermKind::IntConst(n)) => n.to_string(),
578            Some(TermKind::RealConst(r)) => {
579                if *r.denom() == 1 {
580                    format!("{}.0", r.numer())
581                } else {
582                    format!("(/ {} {})", r.numer(), r.denom())
583                }
584            }
585            Some(TermKind::BitVecConst { value, width }) => {
586                format!(
587                    "#b{:0>width$}",
588                    format!("{:b}", value),
589                    width = *width as usize
590                )
591            }
592            _ => "?".to_string(),
593        }
594    }
595
596    /// Get a default value for a sort
597    fn default_value(&self, sort: SortId) -> String {
598        if sort == self.terms.sorts.bool_sort {
599            "false".to_string()
600        } else if sort == self.terms.sorts.int_sort {
601            "0".to_string()
602        } else if sort == self.terms.sorts.real_sort {
603            "0.0".to_string()
604        } else if let Some(s) = self.terms.sorts.get(sort) {
605            if let Some(w) = s.bitvec_width() {
606                format!("#b{:0>width$}", "0", width = w as usize)
607            } else {
608                "?".to_string()
609            }
610        } else {
611            "?".to_string()
612        }
613    }
614
615    /// Format the model as SMT-LIB2
616    pub fn format_model(&self) -> String {
617        match self.get_model() {
618            None => "(error \"No model available\")".to_string(),
619            Some(model) if model.is_empty() => "(model)".to_string(),
620            Some(model) => {
621                let mut lines = vec!["(model".to_string()];
622                for (name, sort, value) in model {
623                    lines.push(format!("  (define-fun {} () {} {})", name, sort, value));
624                }
625                lines.push(")".to_string());
626                lines.join("\n")
627            }
628        }
629    }
630
631    /// Push a context level
632    pub fn push(&mut self) {
633        self.assertion_stack.push(self.assertions.len());
634        self.const_stack.push(self.declared_consts.len());
635        self.fun_stack.push(self.declared_funs.len());
636        self.solver.push();
637    }
638
639    /// Pop a context level with incremental declaration removal
640    pub fn pop(&mut self) {
641        if let Some(len) = self.assertion_stack.pop() {
642            self.assertions.truncate(len);
643            if let Some(const_len) = self.const_stack.pop() {
644                // Remove constants from the name-to-index mapping
645                while self.declared_consts.len() > const_len {
646                    if let Some(decl) = self.declared_consts.pop() {
647                        self.const_name_to_index.remove(&decl.name);
648                    }
649                }
650            }
651            if let Some(fun_len) = self.fun_stack.pop() {
652                // Remove functions from the name-to-index mapping
653                while self.declared_funs.len() > fun_len {
654                    if let Some(decl) = self.declared_funs.pop() {
655                        self.fun_name_to_index.remove(&decl.name);
656                    }
657                }
658            }
659            self.solver.pop();
660        }
661    }
662
663    /// Reset the context
664    pub fn reset(&mut self) {
665        self.solver.reset();
666        self.assertions.clear();
667        self.assertion_stack.clear();
668        self.declared_consts.clear();
669        self.const_stack.clear();
670        self.const_name_to_index.clear();
671        self.declared_funs.clear();
672        self.fun_stack.clear();
673        self.fun_name_to_index.clear();
674        self.logic = None;
675        self.last_result = None;
676        self.options.clear();
677    }
678
    /// Reset assertions (keep declarations and options)
    ///
    /// Resets the underlying solver, drops every assertion together with the
    /// assertion push/pop marks, and forgets the last check-sat result.
    /// Declarations, options, and the logic recorded in the context are kept.
    ///
    /// NOTE(review): `assertion_stack` is cleared while `const_stack` and
    /// `fun_stack` keep their entries. Because `pop` only acts when
    /// `assertion_stack` is non-empty, those older marks can no longer be
    /// popped — confirm that retaining declarations from open levels is the
    /// intended behavior.
    ///
    /// NOTE(review): the solver is reset but `set_logic` is not re-applied;
    /// verify whether `Solver::reset` preserves the logic.
    pub fn reset_assertions(&mut self) {
        self.solver.reset();
        self.assertions.clear();
        self.assertion_stack.clear();
        // Keep declared_consts, const_stack, const_name_to_index,
        // declared_funs, fun_stack, and fun_name_to_index
        // Re-assert nothing - solver is fresh
        self.last_result = None;
    }
689
690    /// Get all current assertions
691    #[must_use]
692    pub fn get_assertions(&self) -> &[TermId] {
693        &self.assertions
694    }
695
696    /// Format assertions as SMT-LIB2
697    #[cfg(feature = "std")]
698    pub fn format_assertions(&self) -> String {
699        if self.assertions.is_empty() {
700            return "()".to_string();
701        }
702        let printer = oxiz_core::smtlib::Printer::new(&self.terms);
703        let mut parts = Vec::new();
704        for &term in &self.assertions {
705            parts.push(printer.print_term(term));
706        }
707        format!("({})", parts.join("\n "))
708    }
709
710    /// Set an option
711    pub fn set_option(&mut self, key: &str, value: &str) {
712        self.options.insert(key.to_string(), value.to_string());
713
714        // Handle special options that affect the solver
715        match key {
716            "produce-proofs" => {
717                let mut config = self.solver.config().clone();
718                config.proof = value == "true";
719                self.solver.set_config(config);
720            }
721            "produce-unsat-cores" => {
722                self.solver.set_produce_unsat_cores(value == "true");
723            }
724            _ => {}
725        }
726    }
727
728    /// Get an option
729    #[must_use]
730    pub fn get_option(&self, key: &str) -> Option<&str> {
731        self.options.get(key).map(String::as_str)
732    }
733
734    /// Format an option value
735    fn format_option(&self, key: &str) -> String {
736        match self.get_option(key) {
737            Some(val) => val.to_string(),
738            None => {
739                // Return default values for well-known options
740                match key {
741                    "produce-models" => "false".to_string(),
742                    "produce-unsat-cores" => "false".to_string(),
743                    "produce-proofs" => "false".to_string(),
744                    "produce-assignments" => "false".to_string(),
745                    "print-success" => "true".to_string(),
746                    _ => "unsupported".to_string(),
747                }
748            }
749        }
750    }
751
752    /// Get assignment (for propositional variables with :named attribute)
753    /// Returns an empty list as we don't track named literals yet
754    pub fn get_assignment(&self) -> String {
755        "()".to_string()
756    }
757
758    /// Get proof (if proof generation is enabled and result is unsat)
759    pub fn get_proof(&self) -> String {
760        if self.last_result != Some(SolverResult::Unsat) {
761            return "(error \"Proof is only available after unsat result\")".to_string();
762        }
763
764        match self.solver.get_proof() {
765            Some(proof) => proof.format(),
766            None => {
767                "(error \"Proof generation not enabled. Set :produce-proofs to true\")".to_string()
768            }
769        }
770    }
771
772    /// Get solver statistics
773    /// Returns statistics about the last solving run
774    pub fn get_statistics(&self) -> String {
775        let stats = self.solver.get_statistics();
776        format!(
777            "(:decisions {} :conflicts {} :propagations {} :restarts {} :learned-clauses {} :theory-propagations {} :theory-conflicts {})",
778            stats.decisions,
779            stats.conflicts,
780            stats.propagations,
781            stats.restarts,
782            stats.learned_clauses,
783            stats.theory_propagations,
784            stats.theory_conflicts
785        )
786    }
787
    /// Return the raw solver statistics (crate-internal use only).
    ///
    /// Borrows the statistics structure straight from the underlying solver;
    /// `get_statistics` is the string-formatted counterpart.
    #[must_use]
    pub(crate) fn raw_statistics(&self) -> &crate::solver::Statistics {
        self.solver.get_statistics()
    }
793
    /// Return the current solver configuration (crate-internal use only).
    ///
    /// Borrows the configuration held by the underlying solver; clone it to
    /// modify and pass the result to `set_solver_config`.
    #[must_use]
    pub(crate) fn solver_config(&self) -> &crate::solver::SolverConfig {
        self.solver.config()
    }
799
    /// Update the solver configuration (crate-internal use only).
    ///
    /// Forwards `config` to the underlying solver. The context's own option
    /// map is not touched, so `get_option` does not reflect changes made
    /// through this method.
    pub(crate) fn set_solver_config(&mut self, config: crate::solver::SolverConfig) {
        self.solver.set_config(config);
    }
804
805    /// Check satisfiability under temporary assumptions (crate-internal use only).
806    pub(crate) fn check_with_assumptions_raw(
807        &mut self,
808        assumptions: &[oxiz_core::ast::TermId],
809    ) -> crate::solver::SolverResult {
810        self.solver
811            .check_with_assumptions(assumptions, &mut self.terms)
812    }
813
814    /// Return the unsat core from the last check (crate-internal use only).
815    #[must_use]
816    pub(crate) fn get_unsat_core_raw(&self) -> Option<&crate::solver::UnsatCore> {
817        self.solver.get_unsat_core()
818    }
819
820    /// Parse a sort name and return its SortId
821    fn parse_sort_name(&mut self, name: &str) -> SortId {
822        match name {
823            "Bool" => self.terms.sorts.bool_sort,
824            "Int" => self.terms.sorts.int_sort,
825            "Real" => self.terms.sorts.real_sort,
826            _ => {
827                // Check for BitVec
828                if let Some(width_str) = name.strip_prefix("BitVec")
829                    && let Ok(width) = width_str.trim().parse::<u32>()
830                {
831                    return self.terms.sorts.bitvec(width);
832                }
833                // Default to Bool for unknown sorts
834                self.terms.sorts.bool_sort
835            }
836        }
837    }
838
839    /// Execute an SMT-LIB2 script
840    #[cfg(feature = "std")]
841    pub fn execute_script(&mut self, script: &str) -> Result<Vec<String>> {
842        let commands = parse_script(script, &mut self.terms)?;
843        let mut output = Vec::new();
844
845        for cmd in commands {
846            match cmd {
847                Command::SetLogic(logic) => {
848                    self.set_logic(&logic);
849                }
850                Command::DeclareConst(name, sort_name) => {
851                    let sort = self.parse_sort_name(&sort_name);
852                    self.declare_const(&name, sort);
853                }
854                Command::DeclareFun(name, arg_sorts, ret_sort) => {
855                    // Treat nullary functions as constants
856                    if arg_sorts.is_empty() {
857                        let sort = self.parse_sort_name(&ret_sort);
858                        self.declare_const(&name, sort);
859                    } else {
860                        // Parse argument sorts and return sort
861                        let parsed_arg_sorts: Vec<SortId> =
862                            arg_sorts.iter().map(|s| self.parse_sort_name(s)).collect();
863                        let parsed_ret_sort = self.parse_sort_name(&ret_sort);
864                        self.declare_fun(&name, parsed_arg_sorts, parsed_ret_sort);
865                    }
866                }
867                Command::Assert(term) => {
868                    self.assert(term);
869                }
870                Command::CheckSat => {
871                    let result = self.check_sat();
872                    output.push(match result {
873                        SolverResult::Sat => "sat".to_string(),
874                        SolverResult::Unsat => "unsat".to_string(),
875                        SolverResult::Unknown => "unknown".to_string(),
876                    });
877                }
878                Command::Push(n) => {
879                    for _ in 0..n {
880                        self.push();
881                    }
882                }
883                Command::Pop(n) => {
884                    for _ in 0..n {
885                        self.pop();
886                    }
887                }
888                Command::Reset => {
889                    self.reset();
890                }
891                Command::ResetAssertions => {
892                    self.reset_assertions();
893                }
894                Command::Exit => {
895                    break;
896                }
897                Command::Echo(msg) => {
898                    output.push(msg);
899                }
900                Command::GetModel => {
901                    output.push(self.format_model());
902                }
903                Command::GetAssertions => {
904                    output.push(self.format_assertions());
905                }
906                Command::GetAssignment => {
907                    output.push(self.get_assignment());
908                }
909                Command::GetProof => {
910                    output.push(self.get_proof());
911                }
912                Command::GetOption(key) => {
913                    output.push(self.format_option(&key));
914                }
915                Command::SetOption(key, value) => {
916                    self.set_option(&key, &value);
917                }
918                Command::CheckSatAssuming(assumptions) => {
919                    // For now, we push, assert all assumptions, check, then pop
920                    self.push();
921                    for assumption in assumptions {
922                        self.assert(assumption);
923                    }
924                    let result = self.check_sat();
925                    self.pop();
926                    output.push(match result {
927                        SolverResult::Sat => "sat".to_string(),
928                        SolverResult::Unsat => "unsat".to_string(),
929                        SolverResult::Unknown => "unknown".to_string(),
930                    });
931                }
932                Command::Simplify(term) => {
933                    // Simplify and output the term
934                    let simplified = self.terms.simplify(term);
935                    let printer = oxiz_core::smtlib::Printer::new(&self.terms);
936                    output.push(printer.print_term(simplified));
937                }
938                Command::GetUnsatCore => {
939                    if let Some(core) = self.solver.get_unsat_core() {
940                        if core.names.is_empty() {
941                            output.push("()".to_string());
942                        } else {
943                            output.push(format!("({})", core.names.join(" ")));
944                        }
945                    } else {
946                        output.push("(error \"No unsat core available\")".to_string());
947                    }
948                }
949                Command::GetValue(terms) => {
950                    if self.last_result != Some(SolverResult::Sat) {
951                        output.push("(error \"No model available\")".to_string());
952                    } else if let Some(model) = self.solver.model() {
953                        let mut values = Vec::new();
954                        for term in terms {
955                            // Evaluate the term in the model first
956                            let value = model.eval(term, &mut self.terms);
957                            // Then create printer and format
958                            let printer = oxiz_core::smtlib::Printer::new(&self.terms);
959                            let term_str = printer.print_term(term);
960                            let value_str = printer.print_term(value);
961                            values.push(format!("({} {})", term_str, value_str));
962                        }
963                        output.push(format!("({})", values.join("\n ")));
964                    } else {
965                        output.push("(error \"No model available\")".to_string());
966                    }
967                }
968                Command::GetInfo(keyword) => {
969                    // Handle get-info command
970                    if keyword == ":all-statistics" {
971                        output.push(self.get_statistics());
972                    } else {
973                        output.push(format!("(error \"Unsupported info keyword: {}\")", keyword));
974                    }
975                }
976                Command::SetInfo(_, _)
977                | Command::DeclareSort(_, _)
978                | Command::DefineSort(_, _, _)
979                | Command::DefineFun(_, _, _, _)
980                | Command::DeclareDatatype { .. } => {
981                    // Ignore these commands for now
982                }
983            }
984        }
985
986        Ok(output)
987    }
988
989    /// Get solver statistics
990    #[must_use]
991    pub fn stats(&self) -> &oxiz_sat::SolverStats {
992        self.solver.stats()
993    }
994}
995
#[cfg(test)]
mod tests {
    use super::*;

    /// Execute `script` in a brand-new context and return its responses.
    fn run_script(script: &str) -> Vec<String> {
        Context::new()
            .execute_script(script)
            .expect("test operation should succeed")
    }

    /// Setting a logic and asserting `true` yields a satisfiable context.
    #[test]
    fn test_context_basic() {
        let mut ctx = Context::new();

        ctx.set_logic("QF_UF");
        assert_eq!(ctx.logic(), Some("QF_UF"));

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);

        assert_eq!(ctx.check_sat(), SolverResult::Sat);
    }

    /// An assertion made after `push` disappears again on `pop`.
    #[test]
    fn test_context_push_pop() {
        let mut ctx = Context::new();

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        ctx.push();

        let falsity = ctx.terms.mk_false();
        ctx.assert(falsity);

        // `false` is asserted in the inner scope, so the check must fail.
        assert_eq!(ctx.check_sat(), SolverResult::Unsat);

        ctx.pop();

        // Only `true` is left once the inner scope is gone.
        assert_eq!(ctx.check_sat(), SolverResult::Sat);
    }

    /// A minimal script produces exactly one `sat` response.
    #[test]
    fn test_execute_script() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (check-sat)
        "#,
        );
        assert_eq!(output, vec!["sat"]);
    }

    /// Declared constants show up in the model after a sat check.
    #[test]
    fn test_declare_const() {
        let mut ctx = Context::new();

        let bool_sort = ctx.terms.sorts.bool_sort;
        let int_sort = ctx.terms.sorts.int_sort;
        ctx.declare_const("x", bool_sort);
        ctx.declare_const("y", int_sort);

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        assert_eq!(ctx.check_sat(), SolverResult::Sat);

        // One model entry is expected per declared constant.
        let model = ctx.get_model();
        assert!(model.is_some());
        let entries = model.expect("test operation should succeed");
        assert_eq!(entries.len(), 2);
    }

    /// The formatted model contains a `define-fun` for a declared constant.
    #[test]
    fn test_format_model() {
        let mut ctx = Context::new();

        let bool_sort = ctx.terms.sorts.bool_sort;
        ctx.declare_const("p", bool_sort);

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        let _ = ctx.check_sat();

        let rendered = ctx.format_model();
        assert!(rendered.contains("(model"));
        assert!(rendered.contains("define-fun p () Bool"));
    }

    /// `get-model` in a script yields a second response holding the model.
    #[test]
    fn test_get_model_script() {
        let output = run_script(
            r#"
            (set-logic QF_LIA)
            (declare-const x Int)
            (declare-const y Bool)
            (assert true)
            (check-sat)
            (get-model)
        "#,
        );
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "sat");
        assert!(
            output[1].contains("(model"),
            "Expected '(model' in: {}",
            output[1]
        );
        // Note: sorts may not always appear in the model output if values are
        // defaults. The model format is: (define-fun name () Sort value)
    }

    /// Constants declared inside a scope are dropped when it is popped.
    #[test]
    fn test_push_pop_consts() {
        let mut ctx = Context::new();

        let bool_sort = ctx.terms.sorts.bool_sort;
        ctx.declare_const("a", bool_sort);
        ctx.push();
        ctx.declare_const("b", bool_sort);

        let truth = ctx.terms.mk_true();
        ctx.assert(truth);
        let _ = ctx.check_sat();

        let inner = ctx.get_model().expect("test operation should succeed");
        assert_eq!(inner.len(), 2);

        ctx.pop();
        let _ = ctx.check_sat();

        let outer = ctx.get_model().expect("test operation should succeed");
        assert_eq!(outer.len(), 1);
        assert_eq!(outer[0].0, "a");
    }

    /// `get-assertions` returns one parenthesised list mentioning `p`.
    #[test]
    fn test_get_assertions() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (assert (not p))
            (get-assertions)
        "#,
        );
        assert_eq!(output.len(), 1);
        assert!(output[0].starts_with('('));
        // Both assertions mention `p`; only its presence is checked here.
        assert!(output[0].contains("p"));
    }

    /// `check-sat-assuming` reports `sat` for a consistent assumption.
    #[test]
    fn test_check_sat_assuming_script() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (declare-const q Bool)
            (assert p)
            (check-sat-assuming (q))
        "#,
        );
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "sat");
    }

    /// An option set by the script can be read back with `get-option`.
    #[test]
    fn test_get_option_script() {
        let output = run_script(
            r#"
            (set-option :produce-models true)
            (get-option :produce-models)
        "#,
        );
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "true");
    }

    /// `reset-assertions` empties the assertion list.
    #[test]
    fn test_reset_assertions() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (reset-assertions)
            (get-assertions)
            (check-sat)
        "#,
        );
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "()"); // nothing is asserted after the reset
        assert_eq!(output[1], "sat"); // the empty formula is satisfiable
    }

    /// `simplify` folds a constant sum.
    #[test]
    fn test_simplify_command() {
        let output = run_script(
            r#"
            (simplify (+ 1 2))
        "#,
        );
        assert_eq!(output.len(), 1);
        // 1 + 2 folds to 3.
        assert_eq!(output[0], "3");
    }

    /// `simplify` folds a constant product of three factors.
    #[test]
    fn test_simplify_complex() {
        let output = run_script(
            r#"
            (simplify (* 2 3 4))
        "#,
        );
        assert_eq!(output.len(), 1);
        // 2 * 3 * 4 folds to 24.
        assert_eq!(output[0], "24");
    }

    /// `get-value` after a sat check reports the queried terms and values.
    #[test]
    fn test_get_value() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (declare-const q Bool)
            (assert p)
            (assert (not q))
            (check-sat)
            (get-value (p q (and p q) (or p q)))
        "#,
        );
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "sat");

        // Inspect the get-value response.
        let values = &output[1];
        assert!(values.contains("p"));
        assert!(values.contains("q"));
        // p is asserted, so `true` must appear.
        assert!(values.contains("true"));
        // q is negated, so `false` must appear.
        assert!(values.contains("false"));
    }

    /// `get-value` before any check yields an error response.
    #[test]
    fn test_get_value_no_model() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (get-value (p))
        "#,
        );
        assert_eq!(output.len(), 1);
        assert!(output[0].contains("error") || output[0].contains("No model"));
    }

    /// `get-value` after an unsat check yields an error response.
    #[test]
    fn test_get_value_after_unsat() {
        let output = run_script(
            r#"
            (set-logic QF_UF)
            (declare-const p Bool)
            (assert p)
            (assert (not p))
            (check-sat)
            (get-value (p))
        "#,
        );
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "unsat");
        assert!(output[1].contains("error") || output[1].contains("No model"));
    }
}