use crate::chc::{ChcSystem, PredId, Rule};
use crate::pdr::SpacerError;
use oxiz_core::{TermId, TermManager};
use smallvec::SmallVec;
use std::collections::HashSet;
use thiserror::Error;
/// Errors produced by hypothesis generation and abductive solving.
#[derive(Error, Debug)]
pub enum AbductionError {
/// No hypothesis was accepted (returned by `AbductiveSolver::solve`).
#[error("no hypothesis found")]
NoHypothesis,
/// Hypothesis generation could not proceed; the message says why
/// (e.g. the requested predicate is not part of the system).
#[error("hypothesis generation failed: {0}")]
GenerationFailed(String),
/// Reports an inconsistent hypothesis; the message describes the conflict.
#[error("inconsistent hypothesis: {0}")]
Inconsistent(String),
/// Error propagated from the Spacer/PDR engine; `#[from]` lets `?` convert it.
#[error("spacer error: {0}")]
Spacer(#[from] SpacerError),
}
/// Result alias for the fallible operations in this module.
pub type AbductionResult<T> = Result<T, AbductionError>;
/// A candidate formula proposed for one predicate, together with a ranking
/// score and a record of the technique that produced it.
#[derive(Debug, Clone)]
pub struct Hypothesis {
    /// Predicate this hypothesis is proposed for.
    pub pred: PredId,
    /// The candidate formula, as a term id.
    pub formula: TermId,
    /// Ranking score; candidates are ordered highest-first by the generator.
    pub score: f64,
    /// Technique that produced this hypothesis.
    pub origin: HypothesisOrigin,
}

impl Hypothesis {
    /// Bundles the four components into a `Hypothesis`; no validation is done.
    pub fn new(pred: PredId, formula: TermId, score: f64, origin: HypothesisOrigin) -> Self {
        Hypothesis { pred, formula, score, origin }
    }
}
/// Records which technique produced a [`Hypothesis`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HypothesisOrigin {
/// Produced by strengthening a hypothesis against a counterexample
/// (`HypothesisGenerator::refine_hypothesis`).
CounterexampleAnalysis,
/// Produced from templates, including the fallback `true` candidate added by
/// `HypothesisGenerator::generate_hypotheses`.
TemplateInstantiation,
// The remaining variants name other hypothesis sources; none of them is
// constructed in this file.
/// Derived via interpolation.
Interpolation,
/// Learned from data / examples.
DataDriven,
/// Found by syntax-guided search.
SyntaxGuided,
/// Supplied directly by the user.
UserProvided,
}
/// A parameterized shape from which candidate hypotheses are instantiated.
#[derive(Debug, Clone)]
pub struct HypothesisTemplate {
/// Human-readable identifier, e.g. `"linear_inequality"`.
pub name: String,
/// Named parameters with their kinds; `SmallVec` keeps up to four inline
/// before spilling to the heap.
pub params: SmallVec<[(String, TemplateParamKind); 4]>,
/// Overall formula shape; this is what the generator dispatches on.
pub structure: TemplateStructure,
}
/// What kind of value a template parameter slot stands for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TemplateParamKind {
/// A numeric constant slot (e.g. a coefficient or bound).
NumericConstant,
/// A variable slot.
Variable,
/// An operator slot.
Operator,
/// A predicate slot.
Predicate,
}
/// The shape of formula a template describes.
///
/// NOTE: `HypothesisGenerator` only instantiates `LinearArithmetic` and
/// `Interval`; the other shapes currently produce no hypothesis.
#[derive(Debug, Clone)]
pub enum TemplateStructure {
/// Linear arithmetic constraints.
LinearArithmetic,
/// Octagon constraints over pairs of variables.
Octagon,
/// Per-variable interval bounds.
Interval,
/// Boolean combinations of simpler constraints.
BooleanCombination,
/// A user-named custom shape.
Custom(String),
}
/// Produces, validates and refines candidate hypotheses from a list of templates.
pub struct HypothesisGenerator {
/// Registered templates, in insertion order.
templates: Vec<HypothesisTemplate>,
/// Presumably a cap on the number of generated hypotheses (set to 100 by
/// `new`); the generator does not enforce it — TODO confirm intent.
// Only read by the unit tests in this file, hence the `allow`.
#[allow(dead_code)]
max_hypotheses: usize,
}
impl HypothesisGenerator {
    /// Creates a generator with no templates and `max_hypotheses` set to 100.
    pub fn new() -> Self {
        Self {
            templates: Vec::new(),
            max_hypotheses: 100,
        }
    }

    /// Creates a generator pre-loaded with three templates, in this order:
    /// `linear_inequality` (`LinearArithmetic`), `octagon` (`Octagon`) and
    /// `interval` (`Interval`).
    pub fn with_common_templates() -> Self {
        let mut generator = Self::new();
        generator.add_template(HypothesisTemplate {
            name: "linear_inequality".to_string(),
            params: vec![
                ("coeff".to_string(), TemplateParamKind::NumericConstant),
                ("var".to_string(), TemplateParamKind::Variable),
                ("bound".to_string(), TemplateParamKind::NumericConstant),
            ]
            .into(),
            structure: TemplateStructure::LinearArithmetic,
        });
        generator.add_template(HypothesisTemplate {
            name: "octagon".to_string(),
            params: vec![
                ("var1".to_string(), TemplateParamKind::Variable),
                ("var2".to_string(), TemplateParamKind::Variable),
                ("bound".to_string(), TemplateParamKind::NumericConstant),
            ]
            .into(),
            structure: TemplateStructure::Octagon,
        });
        generator.add_template(HypothesisTemplate {
            name: "interval".to_string(),
            params: vec![
                ("var".to_string(), TemplateParamKind::Variable),
                ("lower".to_string(), TemplateParamKind::NumericConstant),
                ("upper".to_string(), TemplateParamKind::NumericConstant),
            ]
            .into(),
            structure: TemplateStructure::Interval,
        });
        generator
    }

    /// Appends `template`; templates are tried in insertion order.
    pub fn add_template(&mut self, template: HypothesisTemplate) {
        self.templates.push(template);
    }

    /// Number of registered templates.
    pub fn template_count(&self) -> usize {
        self.templates.len()
    }

    /// Generates candidate hypotheses for `pred`, ordered by descending score.
    ///
    /// Each registered template that [`Self::instantiate_template`] supports
    /// contributes one candidate. The trivial hypothesis `true`, with score
    /// `0.0`, is always appended as a fallback. Ties keep template order
    /// because the sort is stable.
    ///
    /// # Errors
    /// [`AbductionError::GenerationFailed`] if `pred` is not a predicate of `system`.
    pub fn generate_hypotheses(
        &self,
        terms: &mut TermManager,
        system: &ChcSystem,
        pred: PredId,
    ) -> AbductionResult<Vec<Hypothesis>> {
        let predicate = system
            .get_predicate(pred)
            .ok_or_else(|| AbductionError::GenerationFailed("Predicate not found".to_string()))?;

        let mut hypotheses: Vec<Hypothesis> = self
            .templates
            .iter()
            .filter_map(|template| self.instantiate_template(terms, template, pred, predicate))
            .collect();

        // Always offer the trivially valid hypothesis `true` as a last resort.
        hypotheses.push(Hypothesis::new(
            pred,
            terms.mk_true(),
            0.0,
            HypothesisOrigin::TemplateInstantiation,
        ));

        // Highest score first. `total_cmp` makes the comparator a total order:
        // `partial_cmp(..).unwrap_or(Equal)` is not one once a NaN is present,
        // and `sort_by` may panic on a non-total comparator. NaN scores are
        // ranked as -inf so they sink to the end instead of floating to the top.
        let rank = |h: &Hypothesis| {
            if h.score.is_nan() {
                f64::NEG_INFINITY
            } else {
                h.score
            }
        };
        hypotheses.sort_by(|a, b| rank(b).total_cmp(&rank(a)));

        Ok(hypotheses)
    }

    /// Instantiates `template` for `predicate`, if its structure is supported.
    ///
    /// `LinearArithmetic` and `Interval` templates currently yield the same
    /// candidate, scored `1.0`. It is the conjunction of `x_i >= 0` over every
    /// Int-sorted parameter, where the variable is named `x0`, `x1`, … after the
    /// parameter's position. A single constraint is returned unwrapped. The
    /// template's own `params` are not consulted.
    ///
    /// Returns `None` if the predicate has no Int parameters, or if the
    /// structure has no instantiation yet.
    fn instantiate_template(
        &self,
        terms: &mut TermManager,
        template: &HypothesisTemplate,
        pred: PredId,
        predicate: &crate::chc::Predicate,
    ) -> Option<Hypothesis> {
        match template.structure {
            TemplateStructure::LinearArithmetic | TemplateStructure::Interval => {
                let mut constraints = Vec::new();
                for (i, &sort) in predicate.params.iter().enumerate() {
                    if sort == terms.sorts.int_sort {
                        let var_name = format!("x{}", i);
                        let var = terms.mk_var(&var_name, sort);
                        let zero = terms.mk_int(0);
                        constraints.push(terms.mk_ge(var, zero));
                    }
                }
                let formula = match constraints.len() {
                    0 => return None,
                    1 => constraints[0],
                    _ => terms.mk_and(constraints),
                };
                Some(Hypothesis::new(
                    pred,
                    formula,
                    1.0,
                    HypothesisOrigin::TemplateInstantiation,
                ))
            }
            // Listed explicitly so adding a new structure forces a decision here.
            TemplateStructure::Octagon
            | TemplateStructure::BooleanCombination
            | TemplateStructure::Custom(_) => None,
        }
    }

    /// Accepts a hypothesis iff its score is strictly positive.
    ///
    /// `terms` and `system` are not consulted. As a result, the score-0
    /// fallback `true` is always rejected.
    ///
    /// # Errors
    /// Currently never returns `Err`.
    pub fn validate_hypothesis(
        &self,
        _terms: &mut TermManager,
        _system: &ChcSystem,
        hypothesis: &Hypothesis,
    ) -> AbductionResult<bool> {
        Ok(hypothesis.score > 0.0)
    }

    /// Strengthens `hypothesis` to `formula ∧ ¬counterexample`.
    ///
    /// The refined hypothesis keeps the predicate, raises the score by `0.1`,
    /// and is tagged [`HypothesisOrigin::CounterexampleAnalysis`].
    ///
    /// # Errors
    /// Currently never returns `Err`.
    pub fn refine_hypothesis(
        &self,
        terms: &mut TermManager,
        hypothesis: &Hypothesis,
        counterexample: TermId,
    ) -> AbductionResult<Hypothesis> {
        let neg_cex = terms.mk_not(counterexample);
        let strengthened = terms.mk_and([hypothesis.formula, neg_cex]);
        Ok(Hypothesis::new(
            hypothesis.pred,
            strengthened,
            hypothesis.score + 0.1,
            HypothesisOrigin::CounterexampleAnalysis,
        ))
    }
}
impl Default for HypothesisGenerator {
fn default() -> Self {
Self::new()
}
}
/// Runs hypothesis generation and validation over every predicate of a CHC system.
// NOTE(review): all four fields are read by this type's methods, so the
// `allow(dead_code)` below looks unnecessary — consider removing it.
#[allow(dead_code)]
pub struct AbductiveSolver<'a> {
/// Term manager used to build candidate formulas (borrowed mutably for `'a`).
terms: &'a mut TermManager,
/// The CHC system whose predicates receive hypotheses.
system: &'a ChcSystem,
/// Candidate generator; `new` creates it with no templates.
generator: HypothesisGenerator,
/// Every hypothesis accepted so far, tagged with its predicate.
hypotheses: Vec<(PredId, Hypothesis)>,
}
impl<'a> AbductiveSolver<'a> {
    /// Creates a solver over `system` that builds terms in `terms`.
    ///
    /// The internal generator starts with no templates (see [`Self::add_template`]).
    pub fn new(terms: &'a mut TermManager, system: &'a ChcSystem) -> Self {
        Self {
            terms,
            system,
            generator: HypothesisGenerator::new(),
            hypotheses: Vec::new(),
        }
    }

    /// Registers `template` with the solver's hypothesis generator.
    pub fn add_template(&mut self, template: HypothesisTemplate) {
        self.generator.add_template(template);
    }

    /// Generates candidates for every predicate and keeps those that validate.
    ///
    /// Accepted hypotheses are returned in predicate order; within one
    /// predicate they follow the generator's ranking. They are also appended
    /// to the list exposed by [`Self::hypotheses`], and entries from earlier
    /// calls are kept.
    ///
    /// If no templates were added, the generator yields only the score-0 `true`
    /// candidate. Validation rejects it, so the result is `NoHypothesis`.
    ///
    /// # Errors
    /// - Errors from `generate_hypotheses` / `validate_hypothesis` are propagated.
    /// - [`AbductionError::NoHypothesis`] if no candidate was accepted.
    pub fn solve(&mut self) -> AbductionResult<Vec<Hypothesis>> {
        // A single round is enough. Nothing refines the generator or the system
        // between rounds, and generation/validation depend only on the templates
        // and predicate signatures. Repeating a round that accepted nothing would
        // redo identical work and fail the same way.
        let predicates: Vec<PredId> = self.system.predicates().map(|p| p.id).collect();
        let mut accepted = Vec::new();
        for pred in predicates {
            let candidates = self
                .generator
                .generate_hypotheses(self.terms, self.system, pred)?;
            for hyp in candidates {
                if self
                    .generator
                    .validate_hypothesis(self.terms, self.system, &hyp)?
                {
                    accepted.push(hyp.clone());
                    self.hypotheses.push((pred, hyp));
                }
            }
        }

        if accepted.is_empty() {
            return Err(AbductionError::NoHypothesis);
        }
        tracing::info!("CEGIS converged after {} iterations", 1);
        Ok(accepted)
    }

    /// All hypotheses accepted by `solve` so far, each tagged with its predicate.
    pub fn hypotheses(&self) -> &[(PredId, Hypothesis)] {
        &self.hypotheses
    }
}
/// Stateless helper for deriving preconditions from CHC rules.
pub struct PreconditionSynthesizer;

impl PreconditionSynthesizer {
    /// Returns `¬c ∨ postcondition`, i.e. the implication `c → postcondition`,
    /// where `c` is the constraint of `rule`'s body.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn synthesize_precondition(
        terms: &mut TermManager,
        rule: &Rule,
        postcondition: TermId,
    ) -> AbductionResult<TermId> {
        let guard_fails = terms.mk_not(rule.body.constraint);
        Ok(terms.mk_or([guard_fails, postcondition]))
    }
}
/// Example-driven invariant synthesizer: records positive and negative example
/// states and builds a candidate invariant from them.
// `generator` is never read by any method in this file, hence the `allow`.
#[allow(dead_code)]
pub struct InvariantSynthesizer {
/// Not used by `synthesize`.
generator: HypothesisGenerator,
/// States recorded via `add_positive_example` (set semantics).
positive_examples: HashSet<TermId>,
/// States recorded via `add_negative_example` (set semantics).
negative_examples: HashSet<TermId>,
}
impl InvariantSynthesizer {
pub fn new() -> Self {
Self {
generator: HypothesisGenerator::new(),
positive_examples: HashSet::new(),
negative_examples: HashSet::new(),
}
}
pub fn add_positive_example(&mut self, state: TermId) {
self.positive_examples.insert(state);
}
pub fn add_negative_example(&mut self, state: TermId) {
self.negative_examples.insert(state);
}
pub fn synthesize(&self, terms: &mut TermManager, _pred: PredId) -> AbductionResult<TermId> {
if self.positive_examples.is_empty() && self.negative_examples.is_empty() {
return Ok(terms.mk_true());
}
let mut constraints = Vec::new();
if !self.positive_examples.is_empty() && !self.negative_examples.is_empty() {
for &neg_example in &self.negative_examples {
constraints.push(terms.mk_not(neg_example));
}
}
if constraints.is_empty() {
Ok(terms.mk_true())
} else if constraints.len() == 1 {
Ok(constraints[0])
} else {
Ok(terms.mk_and(constraints))
}
}
}
impl Default for InvariantSynthesizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Hypothesis::new` stores each argument unchanged.
    #[test]
    fn test_hypothesis_creation() {
        let origin = HypothesisOrigin::TemplateInstantiation;
        let h = Hypothesis::new(PredId(0), TermId(1), 0.9, origin.clone());
        assert_eq!(h.pred, PredId(0));
        assert_eq!(h.formula, TermId(1));
        assert_eq!(h.score, 0.9);
        assert_eq!(h.origin, origin);
    }

    /// A fresh generator has no templates and the default cap of 100.
    #[test]
    fn test_hypothesis_generator() {
        let fresh = HypothesisGenerator::new();
        assert!(fresh.templates.is_empty());
        assert_eq!(fresh.max_hypotheses, 100);
    }

    /// Examples are stored in separate positive/negative sets.
    #[test]
    fn test_invariant_synthesizer_examples() {
        let mut synth = InvariantSynthesizer::new();
        for id in [1, 2] {
            synth.add_positive_example(TermId(id));
        }
        synth.add_negative_example(TermId(3));
        assert_eq!(
            (synth.positive_examples.len(), synth.negative_examples.len()),
            (2, 1)
        );
    }

    /// `with_common_templates` registers exactly three templates.
    #[test]
    fn test_hypothesis_generator_with_templates() {
        let stocked = HypothesisGenerator::with_common_templates();
        assert_eq!(stocked.template_count(), 3);
        assert!(!stocked.templates.is_empty());
    }

    /// A hand-built template keeps its name and parameter list.
    #[test]
    fn test_template_structure() {
        let params = vec![
            ("x".to_string(), TemplateParamKind::Variable),
            ("c".to_string(), TemplateParamKind::NumericConstant),
        ];
        let template = HypothesisTemplate {
            name: String::from("test_template"),
            params: params.into(),
            structure: TemplateStructure::LinearArithmetic,
        };
        assert_eq!(template.name, "test_template");
        assert_eq!(template.params.len(), 2);
    }
}