use async_trait::async_trait;
use crate::error::JudgeError;
use crate::types::{
ClaimStructure, ClaimVerificationResult, Draft, LogicalClaim, SearchResult, VerificationResult,
};
/// Turns free-form text into [`LogicalClaim`]s and renders claims as SMT-LIB text.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ClaimExtractor: Send + Sync {
/// Extracts logical claims from `text`.
///
/// `max_claims` is the caller's limit on how many claims to produce; the
/// pattern-based implementation in this file never returns more than that.
///
/// # Errors
///
/// Returns a [`JudgeError`] when the implementation cannot extract claims.
async fn extract_claims(
&self,
text: &str,
max_claims: usize,
) -> Result<Vec<LogicalClaim>, JudgeError>;
/// Renders `claim` as an SMT-LIB string.
///
/// # Errors
///
/// Returns a [`JudgeError`] when the claim cannot be rendered.
fn to_smtlib(&self, claim: &LogicalClaim) -> Result<String, JudgeError>;
}
/// Verifies [`LogicalClaim`]s, individually, in bulk, and for mutual consistency.
///
/// NOTE(review): no implementor is visible in this file; the name suggests an
/// SMT-solver backend, but confirm the exact semantics against the implementors.
#[async_trait]
pub trait SmtVerifier: Send + Sync {
/// Verifies a single claim and reports the outcome as a
/// [`ClaimVerificationResult`].
///
/// # Errors
///
/// Returns a [`JudgeError`] when verification cannot be carried out.
async fn verify_claim(
&self,
claim: &LogicalClaim,
) -> Result<ClaimVerificationResult, JudgeError>;
/// Verifies every claim in `claims`.
///
/// Presumably yields one result per input claim, in input order — TODO confirm.
///
/// # Errors
///
/// Returns a [`JudgeError`] when verification cannot be carried out.
async fn verify_claims(
&self,
claims: &[LogicalClaim],
) -> Result<Vec<ClaimVerificationResult>, JudgeError>;
/// Checks `claims` against each other.
///
/// Presumably `Ok(true)` means the claims are mutually consistent — TODO confirm.
///
/// # Errors
///
/// Returns a [`JudgeError`] when the check cannot be carried out.
async fn check_consistency(&self, claims: &[LogicalClaim]) -> Result<bool, JudgeError>;
}
/// Settings that a [`Judge`] exposes through [`Judge::config`].
///
/// NOTE(review): nothing in this file reads these fields, so the per-field
/// notes below are inferred from the names and should be confirmed against
/// the code that consumes them.
#[derive(Debug, Clone)]
pub struct JudgeConfig {
/// Presumably the verification timeout in milliseconds (default `5000`).
pub timeout_ms: u64,
/// Presumably the cap on claims extracted per draft (default `10`).
pub max_claims: usize,
/// Presumably toggles the cross-claim consistency check (default `true`).
pub check_consistency: bool,
/// Presumably the confidence below which claims are ignored (default `0.5`).
pub min_claim_confidence: f32,
/// Presumably toggles counterexample generation for failed claims (default `true`).
pub generate_counterexamples: bool,
}
impl Default for JudgeConfig {
    /// Builds the stock configuration: a 5000 ms timeout, at most 10 claims,
    /// a 0.5 minimum claim confidence, and both the consistency check and
    /// counterexample generation switched on.
    fn default() -> Self {
        const TIMEOUT_MS: u64 = 5000;
        const MAX_CLAIMS: usize = 10;
        const MIN_CLAIM_CONFIDENCE: f32 = 0.5;

        Self {
            generate_counterexamples: true,
            check_consistency: true,
            min_claim_confidence: MIN_CLAIM_CONFIDENCE,
            max_claims: MAX_CLAIMS,
            timeout_ms: TIMEOUT_MS,
        }
    }
}
/// Evaluates a [`Draft`] and produces a [`VerificationResult`].
///
/// NOTE(review): no implementor is visible in this file; confirm the exact
/// difference between the two entry points against the implementors.
#[async_trait]
pub trait Judge: Send + Sync {
/// Judges `draft` with the supplied search results as `context`.
///
/// # Errors
///
/// Returns a [`JudgeError`] when the draft cannot be judged.
async fn judge(
&self,
draft: &Draft,
context: &[SearchResult],
) -> Result<VerificationResult, JudgeError>;
/// Judges `draft` without any search context.
///
/// Presumably a cheaper or shallower pass than [`Judge::judge`] — TODO confirm.
///
/// # Errors
///
/// Returns a [`JudgeError`] when the draft cannot be judged.
async fn quick_judge(&self, draft: &Draft) -> Result<VerificationResult, JudgeError>;
/// Returns the configuration this judge operates with.
fn config(&self) -> &JudgeConfig;
}
/// Keyword-driven [`ClaimExtractor`]: a sentence yields a claim for each
/// registered [`ClaimPattern`] that has a keyword occurring in it.
pub struct PatternClaimExtractor {
// Patterns are tried in insertion order for every sentence.
patterns: Vec<ClaimPattern>,
}
/// A named set of trigger keywords plus the kind of claim they indicate.
#[derive(Debug, Clone)]
pub struct ClaimPattern {
/// Human-readable label for the pattern (e.g. `"is-a"`).
pub name: String,
/// Keywords that trigger this pattern when found in a sentence.
pub keywords: Vec<String>,
/// Which kind of claim structure a match should produce.
pub structure_type: PatternStructureType,
}
/// The kind of claim a [`ClaimPattern`] is meant to recognise.
#[derive(Debug, Clone)]
pub enum PatternStructureType {
/// A match is emitted as a `ClaimStructure::Predicate`.
Predicate,
/// A comparison; `PatternClaimExtractor` currently emits it as `ClaimStructure::Raw`.
Comparison,
/// A conditional; `PatternClaimExtractor` currently emits it as `ClaimStructure::Raw`.
Implication,
}
impl Default for PatternClaimExtractor {
fn default() -> Self {
Self {
patterns: vec![
ClaimPattern {
name: "is-a".to_string(),
keywords: vec![
"is".to_string(),
"are".to_string(),
"was".to_string(),
"were".to_string(),
],
structure_type: PatternStructureType::Predicate,
},
ClaimPattern {
name: "comparison".to_string(),
keywords: vec![
"greater".to_string(),
"less".to_string(),
"more".to_string(),
"fewer".to_string(),
"equal".to_string(),
],
structure_type: PatternStructureType::Comparison,
},
ClaimPattern {
name: "conditional".to_string(),
keywords: vec!["if".to_string(), "when".to_string(), "then".to_string()],
structure_type: PatternStructureType::Implication,
},
],
}
}
}
impl PatternClaimExtractor {
    /// Creates an extractor preloaded with the built-in patterns.
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `pattern` to the pattern list and returns the extractor, so
    /// calls can be chained builder-style.
    #[must_use]
    pub fn with_pattern(mut self, pattern: ClaimPattern) -> Self {
        self.patterns.push(pattern);
        self
    }

    /// Reports whether `keyword` occurs in `haystack` as a whole word or phrase.
    ///
    /// An occurrence only counts when it is not glued to a neighbouring
    /// letter, digit, or underscore. This keeps `"is"` from firing on
    /// `"this"` and `"if"` from firing on `"life"`. An empty keyword matches
    /// every haystack, mirroring `str::contains("")`.
    fn contains_keyword(haystack: &str, keyword: &str) -> bool {
        fn is_word_char(c: char) -> bool {
            c.is_alphanumeric() || c == '_'
        }

        if keyword.is_empty() {
            return true;
        }
        haystack.match_indices(keyword).any(|(start, matched)| {
            // `match_indices` yields valid char boundaries, so slicing is safe.
            let before = &haystack[..start];
            let after = &haystack[start + matched.len()..];
            !before.ends_with(is_word_char) && !after.starts_with(is_word_char)
        })
    }

    /// Produces the candidate claims for a single sentence.
    ///
    /// Each pattern contributes at most one claim: the first of its keywords
    /// found in the sentence wins. Matching is case-insensitive and respects
    /// word boundaries. Every tuple holds the original sentence text, the
    /// claim structure, and a fixed confidence of `0.6`.
    ///
    /// For predicate patterns the whole sentence becomes the `subject` and the
    /// matched keyword the `predicate`; comparison and implication patterns
    /// are emitted as `ClaimStructure::Raw` carrying the sentence.
    fn extract_sentence_claims(&self, sentence: &str) -> Vec<(String, ClaimStructure, f32)> {
        let sentence_lower = sentence.to_lowercase();
        self.patterns
            .iter()
            .filter_map(|pattern| {
                // Lower-case the keyword as well, so keywords registered with
                // capitals can match the lower-cased sentence.
                let keyword = pattern.keywords.iter().find(|keyword| {
                    Self::contains_keyword(&sentence_lower, &keyword.to_lowercase())
                })?;
                let structure = match pattern.structure_type {
                    PatternStructureType::Predicate => ClaimStructure::Predicate {
                        subject: sentence.to_string(),
                        predicate: keyword.clone(),
                        object: None,
                    },
                    PatternStructureType::Comparison | PatternStructureType::Implication => {
                        ClaimStructure::Raw(sentence.to_string())
                    }
                };
                Some((sentence.to_string(), structure, 0.6))
            })
            .collect()
    }
}
#[async_trait]
impl ClaimExtractor for PatternClaimExtractor {
async fn extract_claims(
&self,
text: &str,
max_claims: usize,
) -> Result<Vec<LogicalClaim>, JudgeError> {
let sentences: Vec<&str> = text
.split(['.', '!', '?'])
.map(str::trim)
.filter(|s| !s.is_empty())
.collect();
let mut claims = Vec::new();
for sentence in sentences {
for (text, structure, confidence) in self.extract_sentence_claims(sentence) {
if claims.len() >= max_claims {
break;
}
claims.push(LogicalClaim::new(text, structure).with_confidence(confidence));
}
}
Ok(claims)
}
fn to_smtlib(&self, claim: &LogicalClaim) -> Result<String, JudgeError> {
let smtlib = match &claim.structure {
ClaimStructure::Predicate {
subject,
predicate,
object,
} => {
let obj_str = object.as_deref().unwrap_or("true");
format!("(assert ({predicate} {subject} {obj_str}))")
}
ClaimStructure::Comparison {
left,
operator,
right,
} => {
format!("(assert ({} {} {}))", operator.to_smtlib(), left, right)
}
ClaimStructure::And(claims) => {
let inner: Result<Vec<String>, _> = claims
.iter()
.map(|c| self.to_smtlib(&LogicalClaim::new("", c.clone())))
.collect();
format!("(assert (and {}))", inner?.join(" "))
}
ClaimStructure::Or(claims) => {
let inner: Result<Vec<String>, _> = claims
.iter()
.map(|c| self.to_smtlib(&LogicalClaim::new("", c.clone())))
.collect();
format!("(assert (or {}))", inner?.join(" "))
}
ClaimStructure::Not(inner) => {
let inner_smt = self.to_smtlib(&LogicalClaim::new("", *inner.clone()))?;
format!("(assert (not {inner_smt}))")
}
ClaimStructure::Implies {
premise,
conclusion,
} => {
let p = self.to_smtlib(&LogicalClaim::new("", *premise.clone()))?;
let c = self.to_smtlib(&LogicalClaim::new("", *conclusion.clone()))?;
format!("(assert (=> {p} {c}))")
}
ClaimStructure::Quantified {
quantifier,
variable,
domain,
body,
} => {
let body_smt = self.to_smtlib(&LogicalClaim::new("", *body.clone()))?;
format!(
"(assert ({} (({} {})) {}))",
quantifier.to_smtlib(),
variable,
domain,
body_smt
)
}
ClaimStructure::Raw(raw) => format!("(assert {raw})"),
ClaimStructure::Temporal {
event,
time_relation,
reference,
} => {
format!(
"(assert ({} {} {}))",
time_relation.to_smtlib(),
event,
reference
)
}
ClaimStructure::Causal {
cause,
effect,
strength,
} => {
let cause_smt = self.to_smtlib(&LogicalClaim::new("", *cause.clone()))?;
let effect_smt = self.to_smtlib(&LogicalClaim::new("", *effect.clone()))?;
format!(
"(assert ({} {} {}))",
strength.to_smtlib(),
cause_smt,
effect_smt
)
}
ClaimStructure::Modal { claim, modality } => {
let claim_smt = self.to_smtlib(&LogicalClaim::new("", *claim.clone()))?;
format!("(assert ({} {}))", modality.to_smtlib(), claim_smt)
}
};
Ok(smtlib)
}
}
// Unit tests for `PatternClaimExtractor`: claim extraction and SMT-LIB rendering.
#[cfg(test)]
#[allow(clippy::single_char_pattern)]
mod tests {
use super::*;
// Text containing a built-in keyword ("is") must yield at least one claim.
#[tokio::test]
async fn test_pattern_extractor_basic() {
let extractor = PatternClaimExtractor::new();
let claims = extractor
.extract_claims("The sky is blue. Water flows downhill.", 10)
.await
.unwrap();
assert!(!claims.is_empty());
}
// Four matching sentences with a limit of two must not exceed the limit.
#[tokio::test]
async fn test_pattern_extractor_max_claims() {
let extractor = PatternClaimExtractor::new();
let text = "A is B. C is D. E is F. G is H.";
let claims = extractor.extract_claims(text, 2).await.unwrap();
assert!(claims.len() <= 2);
}
// Empty input has no sentences and therefore no claims.
#[tokio::test]
async fn test_pattern_extractor_empty() {
let extractor = PatternClaimExtractor::new();
let claims = extractor.extract_claims("", 10).await.unwrap();
assert!(claims.is_empty());
}
// A predicate claim renders as an assertion that mentions the predicate name.
#[test]
fn test_to_smtlib_predicate() {
let extractor = PatternClaimExtractor::new();
let claim = LogicalClaim::new(
"test",
ClaimStructure::Predicate {
subject: "x".to_string(),
predicate: "positive".to_string(),
object: None,
},
);
let smt = extractor.to_smtlib(&claim).unwrap();
assert!(smt.contains("assert"));
assert!(smt.contains("positive"));
}
// A greater-than comparison renders with the operator text produced by
// `ComparisonOp::to_smtlib`, which is expected to contain `>`.
#[test]
fn test_to_smtlib_comparison() {
use crate::types::ComparisonOp;
let extractor = PatternClaimExtractor::new();
let claim = LogicalClaim::new(
"test",
ClaimStructure::Comparison {
left: "a".to_string(),
operator: ComparisonOp::GreaterThan,
right: "b".to_string(),
},
);
let smt = extractor.to_smtlib(&claim).unwrap();
assert!(smt.contains(">"));
}
}