use super::executor::{ProtocolInput, ProtocolOutput};
use super::llm::LlmProvider;
use super::validation::ValidationVerdict;
use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
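/// Configuration for multi-model consensus validation.
///
/// `model_count` is the number of models queried per validation round,
/// `majority_threshold` is the fraction of successful responses that must
/// agree before consensus is declared, and the two `enable_*` flags control
/// which evaluation criteria are included in the validation prompt.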
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModelValidatorConfig {
pub model_count: usize,
pub majority_threshold: f64,
pub query_timeout_secs: u64,
pub max_concurrent_queries: usize,
pub enable_factual_check: bool,
pub enable_logical_check: bool,
pub min_consensus_confidence: f64,
}
impl Default for MultiModelValidatorConfig {
fn default() -> Self {
Self {
model_count: 3,
majority_threshold: 0.67,
query_timeout_secs: 30,
max_concurrent_queries: 5,
enable_factual_check: true,
enable_logical_check: true,
min_consensus_confidence: 0.70,
}
}
}
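/// The verdict returned by a single model, with its self-reported
/// confidence and optional per-dimension scores. A populated `error` marks
/// the response as failed and excludes it from consensus aggregation.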
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelValidationResult {
pub model_name: String,
pub provider: LlmProvider,
pub verdict: ValidationVerdict,
pub confidence: f64,
pub reasoning: String,
pub factual_accuracy: Option<f64>,
pub logical_consistency: Option<f64>,
pub duration_ms: u64,
pub error: Option<String>,
}
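/// The aggregated outcome of one validation round: the consensus verdict,
/// the averaged confidence, counts of successful and failed model
/// responses, and per-round performance figures.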
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModelValidationResult {
pub consensus_verdict: ValidationVerdict,
pub consensus_confidence: f64,
pub successful_validations: usize,
pub failed_validations: usize,
pub consensus_reached: bool,
pub majority_threshold_met: bool,
pub total_duration_ms: u64,
pub model_results: Vec<ModelValidationResult>,
pub performance: MultiModelValidationPerformance,
}
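/// Response-time and usage statistics computed over the successful model
/// responses in a round.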
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MultiModelValidationPerformance {
pub avg_response_time_ms: f64,
pub min_response_time_ms: u64,
pub max_response_time_ms: u64,
pub total_tokens: u64,
pub estimated_cost_usd: f64,
}
pub struct MultiModelValidator {
pub config: MultiModelValidatorConfig,
}
impl MultiModelValidator {
pub fn new() -> Result<Self> {
Self::with_config(MultiModelValidatorConfig::default())
}
pub fn with_config(config: MultiModelValidatorConfig) -> Result<Self> {
if config.model_count < 3 {
return Err(Error::validation("model_count must be at least 3"));
}
if config.majority_threshold < 0.5 || config.majority_threshold > 1.0 {
return Err(Error::validation(
"majority_threshold must be between 0.5 and 1.0",
));
}
Ok(Self { config })
}
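/// Reduces per-model results to a consensus verdict.
///
/// Failed responses (those with `error` set) are excluded first. Consensus
/// requires that the ratio of `Validated` verdicts among the remaining
/// responses meets `majority_threshold` and that at least two responses
/// survived. Worked example under the default threshold of 0.67: three
/// successful responses with two `Validated` yields a ratio of 2/3 ≈ 0.667,
/// just below the threshold, so no consensus (all three must agree); four
/// successful responses with three `Validated` yields 0.75 and passes.
/// `consensus_confidence` is the mean confidence of the successful
/// responses when consensus is reached, and 0.0 otherwise.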
pub fn aggregate_validation_results(
&self,
results: Vec<ModelValidationResult>,
total_duration_ms: u64,
) -> MultiModelValidationResult {
let successful: Vec<_> = results.iter().filter(|r| r.error.is_none()).collect();
let failed_count = results.len() - successful.len();
let validated_count = successful
.iter()
.filter(|r| r.verdict == ValidationVerdict::Validated)
.count();
let total_successful = successful.len();
let majority_ratio = if total_successful > 0 {
validated_count as f64 / total_successful as f64
} else {
0.0
};
let majority_threshold_met = majority_ratio >= self.config.majority_threshold;
// Consensus additionally requires at least two successful responses, so a
// single surviving model can never declare consensus on its own.
let consensus_reached = majority_threshold_met && total_successful >= 2;
let consensus_verdict = if consensus_reached && validated_count > total_successful / 2 {
ValidationVerdict::Validated
} else if !consensus_reached {
ValidationVerdict::NeedsImprovement
} else {
let invalid_count = successful
.iter()
.filter(|r| r.verdict == ValidationVerdict::Invalid)
.count();
if invalid_count > validated_count {
ValidationVerdict::Invalid
} else {
ValidationVerdict::NeedsImprovement
}
};
let consensus_confidence = if consensus_reached {
successful.iter().map(|r| r.confidence).sum::<f64>() / total_successful as f64
} else {
0.0
};
let durations: Vec<u64> = successful.iter().map(|r| r.duration_ms).collect();
let performance = MultiModelValidationPerformance {
avg_response_time_ms: if durations.is_empty() {
0.0
} else {
durations.iter().sum::<u64>() as f64 / durations.len() as f64
},
min_response_time_ms: durations.iter().copied().min().unwrap_or(0),
max_response_time_ms: durations.iter().copied().max().unwrap_or(0),
// Token usage and cost are not tracked at this layer; aggregation only
// has the per-model durations to work with.
total_tokens: 0,
estimated_cost_usd: 0.0,
};
MultiModelValidationResult {
consensus_verdict,
consensus_confidence,
successful_validations: total_successful,
failed_validations: failed_count,
consensus_reached,
majority_threshold_met,
total_duration_ms,
model_results: results,
performance,
}
}
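/// Assembles the validation prompt sent to each model, embedding the
/// original query, the protocol id, and the protocol's confidence score,
/// then appending one evaluation criterion per enabled check.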
pub fn build_validation_prompt(
&self,
output: &ProtocolOutput,
input: &ProtocolInput,
) -> String {
let query = input
.fields
.get("query")
.and_then(|v| v.as_str())
.unwrap_or("No query provided");
let mut prompt = format!(
r#"## REASONING OUTPUT VALIDATION
### Original Query
{query}
### Protocol Used
{protocol_id}
### Confidence Score
{confidence:.1}%
### VALIDATION TASK
Evaluate the reasoning output for:"#,
protocol_id = output.protocol_id,
confidence = output.confidence * 100.0
);
if self.config.enable_logical_check {
prompt.push_str("\n- Logical Consistency: Are the reasoning steps coherent?");
}
if self.config.enable_factual_check {
prompt.push_str("\n- Factual Accuracy: Are claims verifiable?");
}
prompt
}
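/// Heuristically parses a free-text validation response by keyword
/// matching, assigning a fixed confidence per verdict tier. The `provider`
/// field is set to `Anthropic` as a placeholder, since the provider cannot
/// be recovered from the text alone; callers that know the real provider
/// can overwrite it.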
pub fn parse_validation_from_text(
&self,
text: &str,
model_name: &str,
) -> Result<ModelValidationResult> {
let lower = text.to_lowercase();
// Check "invalid" before "validated": the substring "validated" also
// occurs in words like "invalidated", so the order of these checks
// matters for correctness.
let (verdict, confidence) = if lower.contains("invalid") {
(ValidationVerdict::Invalid, 0.3)
} else if lower.contains("validated") && lower.contains("high confidence") {
(ValidationVerdict::Validated, 0.9)
} else if lower.contains("validated") {
(ValidationVerdict::Validated, 0.85)
} else {
(ValidationVerdict::NeedsImprovement, 0.6)
};
Ok(ModelValidationResult {
model_name: model_name.to_string(),
provider: LlmProvider::Anthropic,
verdict,
confidence,
reasoning: text.to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 0,
error: None,
})
}
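/// Parses a structured JSON validation response. Missing or malformed
/// fields fall back to defaults: `needs_improvement` for the verdict, 0.5
/// for confidence, and an empty reasoning string. As with the text parser,
/// `provider` defaults to `Anthropic` because the payload does not carry it.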
pub fn parse_validation_json(
&self,
json: &serde_json::Value,
model_name: &str,
) -> Result<ModelValidationResult> {
let verdict_str = json
.get("verdict")
.and_then(|v| v.as_str())
.unwrap_or("needs_improvement");
let verdict = match verdict_str {
"validated" => ValidationVerdict::Validated,
"invalid" => ValidationVerdict::Invalid,
"critical_issues" => ValidationVerdict::CriticalIssues,
"partially_validated" => ValidationVerdict::PartiallyValidated,
_ => ValidationVerdict::NeedsImprovement,
};
let confidence = json
.get("confidence")
.and_then(|v| v.as_f64())
.unwrap_or(0.5);
let reasoning = json
.get("reasoning")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let factual_accuracy = json.get("factual_accuracy").and_then(|v| v.as_f64());
let logical_consistency = json.get("logical_consistency").and_then(|v| v.as_f64());
Ok(ModelValidationResult {
model_name: model_name.to_string(),
provider: LlmProvider::Anthropic,
verdict,
confidence,
reasoning,
factual_accuracy,
logical_consistency,
duration_ms: 0,
error: None,
})
}
}
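/// Maps an `LlmProvider` variant to its canonical kebab-case identifier.
/// Currently exercised only by the tests, hence the `dead_code` allowance.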
#[allow(dead_code)]
fn provider_to_string(provider: &LlmProvider) -> &'static str {
match provider {
LlmProvider::Anthropic => "anthropic",
LlmProvider::OpenAI => "openai",
LlmProvider::GoogleGemini => "google-gemini",
LlmProvider::GoogleVertex => "google-vertex",
LlmProvider::AzureOpenAI => "azure-openai",
LlmProvider::AWSBedrock => "aws-bedrock",
LlmProvider::Ollama => "ollama",
LlmProvider::XAI => "xai",
LlmProvider::Groq => "groq",
LlmProvider::Mistral => "mistral",
LlmProvider::DeepSeek => "deepseek",
LlmProvider::Cohere => "cohere",
LlmProvider::Perplexity => "perplexity",
LlmProvider::Cerebras => "cerebras",
LlmProvider::TogetherAI => "together-ai",
LlmProvider::FireworksAI => "fireworks-ai",
LlmProvider::AlibabaQwen => "alibaba-qwen",
LlmProvider::OpenRouter => "openrouter",
LlmProvider::CloudflareAI => "cloudflare-ai",
LlmProvider::Opencode => "opencode",
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::thinktool::executor::{ProtocolInput, ProtocolOutput};
use crate::thinktool::validation::ValidationVerdict;
use std::collections::HashMap;
fn create_mock_output() -> ProtocolOutput {
ProtocolOutput {
protocol_id: "test_protocol".to_string(),
success: true,
data: HashMap::new(),
confidence: 0.85,
steps: vec![],
tokens: crate::thinktool::step::TokenUsage::default(),
duration_ms: 100,
error: None,
trace_id: None,
budget_summary: None,
}
}
fn create_mock_input() -> ProtocolInput {
ProtocolInput::query("What is the capital of France?")
}
#[test]
fn test_multi_model_validator_config_defaults() {
let config = MultiModelValidatorConfig::default();
assert_eq!(config.model_count, 3);
assert!((config.majority_threshold - 0.67).abs() < 0.001);
assert_eq!(config.query_timeout_secs, 30);
assert_eq!(config.max_concurrent_queries, 5);
assert!(config.enable_factual_check);
assert!(config.enable_logical_check);
assert!((config.min_consensus_confidence - 0.70).abs() < 0.001);
}
#[test]
fn test_multi_model_validator_config_validation() {
let result = MultiModelValidator::with_config(MultiModelValidatorConfig {
model_count: 2,
..Default::default()
});
assert!(result.is_err());
let result = MultiModelValidator::with_config(MultiModelValidatorConfig {
majority_threshold: 0.4,
..Default::default()
});
assert!(result.is_err());
let result = MultiModelValidator::with_config(MultiModelValidatorConfig {
majority_threshold: 1.1,
..Default::default()
});
assert!(result.is_err());
}
#[test]
fn test_model_validation_result_creation() {
let result = ModelValidationResult {
model_name: "test_model".to_string(),
provider: crate::thinktool::llm::LlmProvider::Anthropic,
verdict: ValidationVerdict::Validated,
confidence: 0.9,
reasoning: "Test reasoning".to_string(),
factual_accuracy: Some(0.95),
logical_consistency: Some(0.85),
duration_ms: 1000,
error: None,
};
assert_eq!(result.model_name, "test_model");
assert_eq!(result.verdict, ValidationVerdict::Validated);
assert!((result.confidence - 0.9).abs() < 0.001);
}
#[test]
fn test_validation_performance_creation() {
let perf = MultiModelValidationPerformance {
avg_response_time_ms: 1500.0,
min_response_time_ms: 1000,
max_response_time_ms: 2000,
total_tokens: 3000,
estimated_cost_usd: 0.15,
};
assert!((perf.avg_response_time_ms - 1500.0).abs() < 0.001);
assert_eq!(perf.min_response_time_ms, 1000);
assert_eq!(perf.max_response_time_ms, 2000);
}
#[test]
fn test_multi_model_validation_result_aggregation() {
let validator = MultiModelValidator::new().unwrap();
let results = vec![
ModelValidationResult {
model_name: "model1".to_string(),
provider: crate::thinktool::llm::LlmProvider::Anthropic,
verdict: ValidationVerdict::Validated,
confidence: 0.9,
reasoning: "Good reasoning".to_string(),
factual_accuracy: Some(0.95),
logical_consistency: Some(0.85),
duration_ms: 1000,
error: None,
},
ModelValidationResult {
model_name: "model2".to_string(),
provider: crate::thinktool::llm::LlmProvider::OpenAI,
verdict: ValidationVerdict::Validated,
confidence: 0.85,
reasoning: "Also good".to_string(),
factual_accuracy: Some(0.90),
logical_consistency: Some(0.80),
duration_ms: 1200,
error: None,
},
ModelValidationResult {
model_name: "model3".to_string(),
provider: crate::thinktool::llm::LlmProvider::GoogleGemini,
verdict: ValidationVerdict::Validated,
confidence: 0.95,
reasoning: "Excellent".to_string(),
factual_accuracy: Some(0.98),
logical_consistency: Some(0.90),
duration_ms: 1100,
error: None,
},
];
let aggregated = validator.aggregate_validation_results(results, 3300);
assert_eq!(aggregated.consensus_verdict, ValidationVerdict::Validated);
assert!(aggregated.consensus_confidence > 0.8);
assert_eq!(aggregated.successful_validations, 3);
assert_eq!(aggregated.failed_validations, 0);
assert!(aggregated.consensus_reached);
assert!(aggregated.majority_threshold_met);
assert_eq!(aggregated.total_duration_ms, 3300);
}
#[test]
fn test_aggregation_with_failures() {
let validator = MultiModelValidator::new().unwrap();
let results = vec![
ModelValidationResult {
model_name: "model1".to_string(),
provider: crate::thinktool::llm::LlmProvider::Anthropic,
verdict: ValidationVerdict::Validated,
confidence: 0.9,
reasoning: "Good".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 1000,
error: None,
},
ModelValidationResult {
model_name: "model2".to_string(),
provider: crate::thinktool::llm::LlmProvider::OpenAI,
verdict: ValidationVerdict::Invalid,
confidence: 0.0,
reasoning: "Failed".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 0,
error: Some("Timeout".to_string()),
},
ModelValidationResult {
model_name: "model3".to_string(),
provider: crate::thinktool::llm::LlmProvider::GoogleGemini,
verdict: ValidationVerdict::Validated,
confidence: 0.8,
reasoning: "Good".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 1100,
error: None,
},
];
let aggregated = validator.aggregate_validation_results(results, 2100);
assert_eq!(aggregated.successful_validations, 2);
assert_eq!(aggregated.failed_validations, 1);
assert_eq!(aggregated.consensus_verdict, ValidationVerdict::Validated);
assert!(aggregated.consensus_reached);
}
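// A hedged sketch, not in the original suite: exercises the all-failures
// path. With zero successful responses the majority ratio is 0.0, so no
// consensus is reached, the verdict falls back to NeedsImprovement, and
// both the consensus confidence and the response-time stats stay at zero.
#[test]
fn test_aggregation_all_failures() {
let validator = MultiModelValidator::new().unwrap();
let results = vec![
ModelValidationResult {
model_name: "model1".to_string(),
provider: crate::thinktool::llm::LlmProvider::Anthropic,
verdict: ValidationVerdict::Invalid,
confidence: 0.0,
reasoning: "Failed".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 0,
error: Some("Timeout".to_string()),
},
ModelValidationResult {
model_name: "model2".to_string(),
provider: crate::thinktool::llm::LlmProvider::OpenAI,
verdict: ValidationVerdict::Invalid,
confidence: 0.0,
reasoning: "Failed".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 0,
error: Some("Rate limited".to_string()),
},
];
let aggregated = validator.aggregate_validation_results(results, 500);
assert_eq!(aggregated.successful_validations, 0);
assert_eq!(aggregated.failed_validations, 2);
assert!(!aggregated.consensus_reached);
assert_eq!(
aggregated.consensus_verdict,
ValidationVerdict::NeedsImprovement
);
assert!((aggregated.consensus_confidence - 0.0).abs() < 0.001);
assert!((aggregated.performance.avg_response_time_ms - 0.0).abs() < 0.001);
}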
#[test]
fn test_aggregation_no_consensus() {
let validator = MultiModelValidator::with_config(MultiModelValidatorConfig {
majority_threshold: 0.8,
..Default::default()
})
.unwrap();
let results = vec![
ModelValidationResult {
model_name: "model1".to_string(),
provider: crate::thinktool::llm::LlmProvider::Anthropic,
verdict: ValidationVerdict::Validated,
confidence: 0.9,
reasoning: "Good".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 1000,
error: None,
},
ModelValidationResult {
model_name: "model2".to_string(),
provider: crate::thinktool::llm::LlmProvider::OpenAI,
verdict: ValidationVerdict::Invalid,
confidence: 0.3,
reasoning: "Bad".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 1000,
error: None,
},
ModelValidationResult {
model_name: "model3".to_string(),
provider: crate::thinktool::llm::LlmProvider::GoogleGemini,
verdict: ValidationVerdict::NeedsImprovement,
confidence: 0.6,
reasoning: "Mixed".to_string(),
factual_accuracy: None,
logical_consistency: None,
duration_ms: 1000,
error: None,
},
];
let aggregated = validator.aggregate_validation_results(results, 3000);
assert_eq!(
aggregated.consensus_verdict,
ValidationVerdict::NeedsImprovement
);
assert!(!aggregated.consensus_reached);
assert!(!aggregated.majority_threshold_met);
assert!((aggregated.consensus_confidence - 0.0).abs() < 0.001);
}
#[test]
fn test_validation_prompt_building() {
let validator = MultiModelValidator::new().unwrap();
let output = create_mock_output();
let input = create_mock_input();
let prompt = validator.build_validation_prompt(&output, &input);
assert!(prompt.contains("REASONING OUTPUT VALIDATION"));
assert!(prompt.contains("What is the capital of France"));
assert!(prompt.contains("test_protocol"));
assert!(prompt.contains("85.0%"));
assert!(prompt.contains("VALIDATION TASK"));
assert!(prompt.contains("Logical Consistency"));
assert!(prompt.contains("Factual Accuracy"));
}
#[test]
fn test_validation_prompt_with_disabled_checks() {
let validator = MultiModelValidator::with_config(MultiModelValidatorConfig {
enable_factual_check: false,
enable_logical_check: true,
..Default::default()
})
.unwrap();
let output = create_mock_output();
let input = create_mock_input();
let prompt = validator.build_validation_prompt(&output, &input);
assert!(prompt.contains("Logical Consistency"));
assert!(!prompt.contains("Factual Accuracy"));
}
#[test]
fn test_parse_validation_from_text() {
let validator = MultiModelValidator::new().unwrap();
let result = validator
.parse_validation_from_text(
"This reasoning appears to be validated with high confidence. The logic is sound.",
"test-model",
)
.unwrap();
assert_eq!(result.verdict, ValidationVerdict::Validated);
assert!(result.confidence > 0.8);
assert_eq!(result.model_name, "test-model");
}
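// A hedged sketch, not in the original suite: text containing "invalid"
// maps to the Invalid verdict with the fixed 0.3 confidence assigned by
// the keyword heuristic.
#[test]
fn test_parse_invalid_text() {
let validator = MultiModelValidator::new().unwrap();
let result = validator
.parse_validation_from_text(
"The reasoning is invalid: step 2 contradicts step 1.",
"test-model",
)
.unwrap();
assert_eq!(result.verdict, ValidationVerdict::Invalid);
assert!((result.confidence - 0.3).abs() < 0.001);
}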
#[test]
fn test_parse_validation_json() {
let validator = MultiModelValidator::new().unwrap();
let json = serde_json::json!({
"verdict": "validated",
"confidence": 0.9,
"reasoning": "Excellent reasoning",
"factual_accuracy": 0.95,
"logical_consistency": 0.85
});
let result = validator
.parse_validation_json(&json, "test-model")
.unwrap();
assert_eq!(result.verdict, ValidationVerdict::Validated);
assert!((result.confidence - 0.9).abs() < 0.001);
assert_eq!(result.reasoning, "Excellent reasoning");
assert_eq!(result.factual_accuracy, Some(0.95));
assert_eq!(result.logical_consistency, Some(0.85));
}
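// A hedged sketch, not in the original suite: an empty JSON object should
// fall back to the documented defaults of parse_validation_json.
#[test]
fn test_parse_validation_json_defaults() {
let validator = MultiModelValidator::new().unwrap();
let json = serde_json::json!({});
let result = validator
.parse_validation_json(&json, "test-model")
.unwrap();
assert_eq!(result.verdict, ValidationVerdict::NeedsImprovement);
assert!((result.confidence - 0.5).abs() < 0.001);
assert!(result.reasoning.is_empty());
assert_eq!(result.factual_accuracy, None);
assert_eq!(result.logical_consistency, None);
}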
#[test]
fn test_model_provider_string_conversion() {
use crate::thinktool::llm::LlmProvider;
assert_eq!(provider_to_string(&LlmProvider::Anthropic), "anthropic");
assert_eq!(provider_to_string(&LlmProvider::OpenAI), "openai");
assert_eq!(provider_to_string(&LlmProvider::DeepSeek), "deepseek");
assert_eq!(provider_to_string(&LlmProvider::XAI), "xai");
}
#[test]
fn test_validator_creation_with_custom_config() {
let config = MultiModelValidatorConfig {
model_count: 4,
majority_threshold: 0.75,
query_timeout_secs: 60,
enable_factual_check: false,
..Default::default()
};
let validator = MultiModelValidator::with_config(config).unwrap();
assert_eq!(validator.config.model_count, 4);
assert!((validator.config.majority_threshold - 0.75).abs() < 0.001);
assert_eq!(validator.config.query_timeout_secs, 60);
assert!(!validator.config.enable_factual_check);
}
}