use async_trait::async_trait;
use crate::error::SpeculatorError;
use crate::types::{Draft, SearchResult, SpeculationDecision, SpeculationResult};
/// Prompt text intended for model-backed speculators.
///
/// The `{query}`, `{context}`, `{draft}` and `{issues}` markers are plain text inside the
/// strings; presumably callers substitute them (e.g. via `str::replace`) — that code is not
/// shown here. Nothing in this file reads these constants.
pub mod prompts {
/// System prompt asking the model to check a draft and answer ACCEPT, REVISE or REJECT,
/// which mirror the `SpeculationDecision` variants used by the rule-based speculator.
pub const VERIFICATION_SYSTEM: &str = r"You are a verification assistant. Your task is to verify if a draft answer is accurate and well-supported by the provided context.
Analyze the draft for:
1. Factual accuracy - Is the information correct based on the context?
2. Completeness - Does it fully answer the question?
3. Consistency - Are there any contradictions?
4. Relevance - Is the answer relevant to the question?
Respond with your analysis and a decision: ACCEPT, REVISE, or REJECT.";
/// User-turn template for verification; contains `{query}`, `{context}` and `{draft}` markers.
pub const VERIFICATION_TEMPLATE: &str = r"Question: {query}
Context:
{context}
Draft Answer:
{draft}
Please verify this draft answer and provide your decision.";
/// User-turn template for revision; contains `{query}`, `{context}`, `{draft}` and `{issues}`
/// markers.
pub const REVISION_TEMPLATE: &str = r"Question: {query}
Context:
{context}
Original Draft:
{draft}
Issues Found:
{issues}
Please provide a revised answer that addresses the issues while staying faithful to the context.";
}
/// Tuning knobs shared by [`Speculator`] implementations.
///
/// `RuleBasedSpeculator` reads only `accept_threshold` and `reject_threshold`; the sampling
/// fields and `max_revisions` are not read by it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SpeculatorConfig {
/// Sampling temperature — presumably forwarded to a model-backed speculator; unused by the
/// rule-based one.
pub temperature: f32,
/// Nucleus-sampling cutoff — presumably for model-backed speculators; unused by the rule-based
/// one.
pub top_p: f32,
/// Output token budget — presumably for model-backed speculators; unused by the rule-based one.
pub max_tokens: usize,
/// In `RuleBasedSpeculator`, drafts whose confidence is at or above this value are accepted.
pub accept_threshold: f32,
/// In `RuleBasedSpeculator`, drafts that are not accepted and whose confidence is at or below
/// this value are rejected.
pub reject_threshold: f32,
/// Upper bound on revision rounds — presumably enforced by the caller driving the
/// verify/revise loop; TODO confirm, it is not read here.
pub max_revisions: usize,
}
impl Default for SpeculatorConfig {
fn default() -> Self {
Self {
temperature: 0.3,
top_p: 0.9,
max_tokens: 512,
accept_threshold: 0.9,
reject_threshold: 0.3,
max_revisions: 2,
}
}
}
/// Verifies drafted answers against retrieved context and, when needed, revises them.
///
/// Implementors must be `Send + Sync`. The async methods are provided through `#[async_trait]`.
#[async_trait]
pub trait Speculator: Send + Sync {
/// Checks `draft` against `context` and returns a decision, a confidence score and any issues
/// found.
///
/// # Errors
///
/// Returns a [`SpeculatorError`] when verification cannot be carried out; which conditions
/// trigger it is implementation-specific.
async fn verify_draft(
&self,
draft: &Draft,
context: &[SearchResult],
) -> Result<SpeculationResult, SpeculatorError>;
/// Produces a new draft that tries to address the issues recorded in `speculation`.
///
/// # Errors
///
/// Returns a [`SpeculatorError`] when revision cannot be carried out; which conditions
/// trigger it is implementation-specific.
async fn revise_draft(
&self,
draft: &Draft,
context: &[SearchResult],
speculation: &SpeculationResult,
) -> Result<Draft, SpeculatorError>;
/// Returns the configuration this speculator was built with.
fn config(&self) -> &SpeculatorConfig;
}
/// A heuristic [`Speculator`] that scores drafts with simple text rules instead of a model.
///
/// It flags empty or very short drafts, drafts with little word overlap with the context, and
/// drafts containing hedging phrases. The resulting confidence is then compared against the
/// configured accept/reject thresholds.
pub struct RuleBasedSpeculator {
/// Settings governing decisions; only the two thresholds are consulted.
config: SpeculatorConfig,
}
impl RuleBasedSpeculator {
    /// Creates a rule-based speculator whose decisions use `config`'s accept/reject thresholds.
    #[must_use]
    pub fn new(config: SpeculatorConfig) -> Self {
        Self { config }
    }

    /// Lowercases `text` and splits it into word tokens.
    ///
    /// Every character that is neither alphanumeric nor an apostrophe is a separator.
    /// Apostrophes at token edges are trimmed, so surrounding punctuation ("Paris," vs
    /// "Paris.") and letter case do not affect matching.
    fn tokenize(text: &str) -> Vec<String> {
        text.split(|c: char| !(c.is_alphanumeric() || c == '\''))
            .map(|token| token.trim_matches('\''))
            .filter(|token| !token.is_empty())
            .map(str::to_lowercase)
            .collect()
    }

    /// Whether `word` is long enough (more than four characters) to count toward overlap.
    fn is_significant_word(word: &str) -> bool {
        word.chars().count() > 4
    }

    /// Whether the whitespace-separated words of `phrase` occur as consecutive `tokens`.
    ///
    /// Matching whole tokens (rather than substrings) keeps "possibly" from firing on
    /// "impossibly" or "might" on "mighty".
    fn contains_phrase(tokens: &[String], phrase: &str) -> bool {
        let needle: Vec<&str> = phrase.split_whitespace().collect();
        // `windows(0)` panics, so an empty phrase must short-circuit first.
        !needle.is_empty()
            && tokens.windows(needle.len()).any(|window| {
                window
                    .iter()
                    .zip(&needle)
                    .all(|(token, word)| token.as_str() == *word)
            })
    }

    /// Runs the heuristic checks and returns one human-readable message per problem found.
    ///
    /// An empty (whitespace-only) draft yields only `"Draft is empty"`. Otherwise it may report:
    /// - a draft shorter than 10 characters after trimming;
    /// - fewer than 10% of its significant words found in a non-empty `context`;
    /// - each uncertainty marker it contains.
    // `self` is unused today; keeping it lets future rules consult `self.config`.
    #[allow(clippy::unused_self)]
    fn analyze_draft(&self, draft: &Draft, context: &[SearchResult]) -> Vec<String> {
        const MIN_DRAFT_CHARS: usize = 10;
        const MIN_OVERLAP_RATIO: f32 = 0.1;
        const UNCERTAINTY_MARKERS: [&str; 6] = [
            "maybe",
            "might",
            "possibly",
            "i think",
            "not sure",
            "uncertain",
        ];

        let content = draft.content.trim();
        if content.is_empty() {
            // Nothing else can be judged about an empty draft.
            return vec!["Draft is empty".to_string()];
        }

        let mut issues = Vec::new();
        // Count characters, not bytes, so non-ASCII drafts are measured fairly.
        if content.chars().count() < MIN_DRAFT_CHARS {
            issues.push("Draft is too short".to_string());
        }

        let draft_tokens = Self::tokenize(content);

        if !context.is_empty() {
            let context_words: std::collections::HashSet<String> = context
                .iter()
                .flat_map(|result| Self::tokenize(&result.document.content))
                .filter(|word| Self::is_significant_word(word))
                .collect();
            let draft_words: std::collections::HashSet<&str> = draft_tokens
                .iter()
                .map(String::as_str)
                .filter(|word| Self::is_significant_word(word))
                .collect();
            let overlap = draft_words
                .iter()
                .filter(|word| context_words.contains(**word))
                .count();
            // Word counts are far below f32's exact-integer range, so the casts are harmless.
            #[allow(clippy::cast_precision_loss)]
            let overlap_ratio = if draft_words.is_empty() {
                0.0
            } else {
                overlap as f32 / draft_words.len() as f32
            };
            if overlap_ratio < MIN_OVERLAP_RATIO {
                issues.push("Draft does not appear to use information from context".to_string());
            }
        }

        issues.extend(
            UNCERTAINTY_MARKERS
                .iter()
                .filter(|marker| Self::contains_phrase(&draft_tokens, marker))
                .map(|marker| format!("Draft contains uncertainty marker: '{marker}'")),
        );
        issues
    }

    /// Scores a draft in `[0.0, 1.0]`.
    ///
    /// An empty draft scores 0. Otherwise the score starts at 1.0 and:
    /// - loses 0.15 per issue;
    /// - shifts by 0.2 × (draft confidence − 0.5);
    /// - shifts by 0.3 × (mean context score − 0.5) when context is present.
    ///
    /// The result is then clamped to `[0.0, 1.0]`.
    // Casts are on small counts; `self` is kept for parity with `analyze_draft`.
    #[allow(clippy::cast_precision_loss, clippy::unused_self)]
    fn calculate_confidence(
        &self,
        draft: &Draft,
        context: &[SearchResult],
        issues: &[String],
    ) -> f32 {
        if draft.content.trim().is_empty() {
            return 0.0;
        }
        let mut confidence = 1.0 - issues.len() as f32 * 0.15;
        confidence += (draft.confidence - 0.5) * 0.2;
        if !context.is_empty() {
            let avg_context_score: f32 =
                context.iter().map(|r| r.score).sum::<f32>() / context.len() as f32;
            confidence += (avg_context_score - 0.5) * 0.3;
        }
        confidence.clamp(0.0, 1.0)
    }
}
impl Default for RuleBasedSpeculator {
    /// Builds a speculator that uses [`SpeculatorConfig::default`] settings.
    fn default() -> Self {
        Self {
            config: SpeculatorConfig::default(),
        }
    }
}
#[async_trait]
impl Speculator for RuleBasedSpeculator {
    /// Scores `draft` with the rule-based heuristics and maps the score onto a decision.
    ///
    /// The draft is:
    /// - accepted when its confidence reaches `accept_threshold`;
    /// - otherwise rejected when the confidence is at or below `reject_threshold`;
    /// - otherwise sent for revision.
    ///
    /// Every detected issue is attached to the returned [`SpeculationResult`].
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    async fn verify_draft(
        &self,
        draft: &Draft,
        context: &[SearchResult],
    ) -> Result<SpeculationResult, SpeculatorError> {
        let issues = self.analyze_draft(draft, context);
        let confidence = self.calculate_confidence(draft, context, &issues);
        let decision = if confidence >= self.config.accept_threshold {
            SpeculationDecision::Accept
        } else if confidence <= self.config.reject_threshold {
            SpeculationDecision::Reject
        } else {
            SpeculationDecision::Revise
        };
        let explanation = match &decision {
            SpeculationDecision::Accept => "Draft appears accurate and well-supported.".to_string(),
            // No rule fired, so the low score came from the draft's own confidence or weak
            // context scores. Explain that instead of listing an empty set of issues.
            SpeculationDecision::Reject if issues.is_empty() => format!(
                "Draft confidence {confidence:.2} is at or below the rejection threshold {:.2}",
                self.config.reject_threshold
            ),
            SpeculationDecision::Revise if issues.is_empty() => format!(
                "Draft confidence {confidence:.2} is below the acceptance threshold {:.2}",
                self.config.accept_threshold
            ),
            SpeculationDecision::Reject => {
                format!("Draft has significant issues: {}", issues.join("; "))
            }
            SpeculationDecision::Revise => {
                format!("Draft needs revision to address: {}", issues.join("; "))
            }
        };
        let mut result = SpeculationResult::new(decision, confidence).with_explanation(explanation);
        for issue in issues {
            result = result.with_issue(issue);
        }
        Ok(result)
    }

    /// Builds a revised draft for the same query.
    ///
    /// The new content has two parts:
    /// - a prefix of up to three context snippets (first 100 characters each), added only
    ///   when `context` is non-empty;
    /// - the original draft, followed by a note listing `speculation.issues` if any.
    ///
    /// The revised draft's confidence is the speculation confidence plus 0.1, clamped to
    /// `[0.0, 1.0]`.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    async fn revise_draft(
        &self,
        draft: &Draft,
        context: &[SearchResult],
        speculation: &SpeculationResult,
    ) -> Result<Draft, SpeculatorError> {
        let revision_note = if speculation.issues.is_empty() {
            String::new()
        } else {
            format!("\n\n[Revision notes: {}]", speculation.issues.join("; "))
        };
        let revised_content = if context.is_empty() {
            format!("{}{}", draft.content, revision_note)
        } else {
            let context_summary: String = context
                .iter()
                .take(3)
                .map(|r| r.document.content.chars().take(100).collect::<String>())
                .collect::<Vec<_>>()
                .join(" ... ");
            format!(
                "Based on the available information: {}\n\n{}{}",
                context_summary, draft.content, revision_note
            )
        };
        // Scores produced by this speculator are clamped to [0, 1]. Keep the +0.1 revision
        // bonus from pushing the revised draft outside that range.
        let revised_confidence = (speculation.confidence + 0.1).clamp(0.0, 1.0);
        Ok(Draft::new(revised_content, &draft.query).with_confidence(revised_confidence))
    }

    /// Returns the configuration supplied at construction.
    fn config(&self) -> &SpeculatorConfig {
        &self.config
    }
}
#[cfg(test)]
// Exact float equality is intended: test_config reads back values copied verbatim.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::types::Document;

    /// Two search hits about Paris, scored 0.9 and 0.85 at positions 0 and 1.
    fn france_context() -> Vec<SearchResult> {
        let hit =
            |text: &'static str, score, rank| SearchResult::new(Document::new(text), score, rank);
        vec![
            hit("The capital of France is Paris. It is known for the Eiffel Tower.", 0.9, 0),
            hit("Paris has a population of about 2 million in the city proper.", 0.85, 1),
        ]
    }

    #[tokio::test]
    async fn test_verify_good_draft() {
        let context = france_context();
        let draft = Draft::new(
            "The capital of France is Paris, which is famous for the Eiffel Tower.",
            "What is the capital of France?",
        )
        .with_confidence(0.9);
        let outcome = RuleBasedSpeculator::default()
            .verify_draft(&draft, &context)
            .await
            .unwrap();
        assert!(outcome.confidence > 0.5);
        let acceptable = matches!(
            outcome.decision,
            SpeculationDecision::Accept | SpeculationDecision::Revise
        );
        assert!(acceptable);
    }

    #[tokio::test]
    async fn test_verify_empty_draft() {
        let context = france_context();
        let draft = Draft::new("", "What is the capital of France?");
        let outcome = RuleBasedSpeculator::default()
            .verify_draft(&draft, &context)
            .await
            .unwrap();
        assert!(outcome.confidence < 0.5);
        assert!(!outcome.issues.is_empty());
    }

    #[tokio::test]
    async fn test_verify_uncertain_draft() {
        let context = france_context();
        let draft = Draft::new(
            "I think maybe the capital might be Paris, but I'm not sure.",
            "What is the capital of France?",
        );
        let outcome = RuleBasedSpeculator::default()
            .verify_draft(&draft, &context)
            .await
            .unwrap();
        let flagged = outcome
            .issues
            .iter()
            .any(|issue| issue.contains("uncertainty marker"));
        assert!(flagged);
    }

    #[tokio::test]
    async fn test_revise_draft() {
        let speculator = RuleBasedSpeculator::default();
        let context = france_context();
        let draft = Draft::new("Paris", "What is the capital of France?");
        let verdict = speculator.verify_draft(&draft, &context).await.unwrap();
        let revised = speculator
            .revise_draft(&draft, &context, &verdict)
            .await
            .unwrap();
        assert!(draft.content.len() < revised.content.len());
    }

    #[tokio::test]
    async fn test_config() {
        let speculator = RuleBasedSpeculator::new(SpeculatorConfig {
            accept_threshold: 0.85,
            temperature: 0.5,
            ..SpeculatorConfig::default()
        });
        let config = speculator.config();
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.accept_threshold, 0.85);
    }
}