use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use super::{
executor::{ExecutorConfig, ProtocolExecutor, ProtocolInput, ProtocolOutput},
protocol::Protocol,
validation::{DeepSeekValidationEngine, DeepSeekValidationResult, ValidationVerdict},
};
/// Wraps a [`ProtocolExecutor`] and optionally passes its output through a
/// [`DeepSeekValidationEngine`] before returning it to the caller.
pub struct ValidatingProtocolExecutor {
/// Executor that runs the protocol or profile itself.
base_executor: ProtocolExecutor,
/// Validation engine; `None` when the executor was built with
/// `enable_validation` set to `false`.
validation_engine: Option<DeepSeekValidationEngine>,
/// Settings that decide whether, and at which level, a result is validated.
validation_config: ValidationExecutorConfig,
}
/// Settings controlling when and how protocol results are validated.
///
/// NOTE(review): when deserializing, a missing `enable_validation` field falls
/// back to `bool::default()` (`false`), whereas `ValidationExecutorConfig::default()`
/// sets it to `true`. Confirm that this difference is intentional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationExecutorConfig {
/// Master switch: when `false` no validation engine is constructed and no
/// result is ever validated.
#[serde(default)]
pub enable_validation: bool,
/// Selects which validation routine is run on a result.
#[serde(default)]
pub validation_level: ValidationLevel,
/// Results whose confidence is below this value are returned without
/// validation. Falls back to `0.70` when absent from serialized input.
#[serde(default = "default_min_confidence_threshold")]
pub min_confidence_threshold: f64,
/// When non-empty, only protocol ids listed here are validated.
#[serde(default)]
pub validate_protocols: Vec<String>,
/// Protocol ids that are never validated, even if they also appear in
/// `validate_protocols`.
#[serde(default)]
pub skip_protocols: Vec<String>,
}
/// Fallback for `min_confidence_threshold`, used both by serde when the field
/// is missing and by the `Default` implementation of the config.
fn default_min_confidence_threshold() -> f64 {
    const FALLBACK_THRESHOLD: f64 = 0.70;
    FALLBACK_THRESHOLD
}
/// How thoroughly a result is validated; serialized in `snake_case`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValidationLevel {
/// No validation is performed.
None,
/// Runs the engine's `validate_quick` routine.
Quick,
/// Runs the engine's `validate_chain` routine. This is the default level.
#[default]
Standard,
/// Runs the engine's `validate_rigorous` routine.
Rigorous,
/// Runs the engine's `validate_with_statistical_significance` routine.
Paranoid,
}
impl Default for ValidationExecutorConfig {
    /// Validation enabled at the `Standard` level, the shared fallback
    /// confidence threshold, and no protocol allow/skip lists.
    fn default() -> Self {
        let enable_validation = true;
        let validation_level = ValidationLevel::Standard;
        let min_confidence_threshold = default_min_confidence_threshold();

        Self {
            enable_validation,
            validation_level,
            min_confidence_threshold,
            validate_protocols: vec![],
            skip_protocols: vec![],
        }
    }
}
impl ValidatingProtocolExecutor {
pub fn new() -> Result<Self> {
Self::with_configs(
ExecutorConfig::default(),
ValidationExecutorConfig::default(),
)
}
pub fn with_configs(
executor_config: ExecutorConfig,
validation_config: ValidationExecutorConfig,
) -> Result<Self> {
let base_executor = ProtocolExecutor::with_config(executor_config)?;
let validation_engine = if validation_config.enable_validation {
Some(DeepSeekValidationEngine::new()?)
} else {
None
};
Ok(Self {
base_executor,
validation_engine,
validation_config,
})
}
pub async fn execute_with_validation(
&self,
protocol_id: &str,
input: ProtocolInput,
) -> Result<ProtocolOutput> {
let base_result = self
.base_executor
.execute(protocol_id, input.clone())
.await?;
if !self.should_validate(protocol_id, &base_result) {
return Ok(base_result);
}
let validation_result = self
.apply_validation(protocol_id, &input, &base_result)
.await?;
self.merge_validation_results(base_result, validation_result)
}
pub async fn execute_profile_with_validation(
&self,
profile_id: &str,
input: ProtocolInput,
) -> Result<ProtocolOutput> {
let base_result = self
.base_executor
.execute_profile(profile_id, input.clone())
.await?;
if !self.should_validate_profile(profile_id, &base_result) {
return Ok(base_result);
}
let validation_result = self
.apply_validation(profile_id, &input, &base_result)
.await?;
self.merge_validation_results(base_result, validation_result)
}
fn should_validate(&self, protocol_id: &str, result: &ProtocolOutput) -> bool {
if !self.validation_config.enable_validation {
return false;
}
if self.validation_config.validation_level == ValidationLevel::None {
return false;
}
if !self.validation_config.validate_protocols.is_empty()
&& !self
.validation_config
.validate_protocols
.contains(&protocol_id.to_string())
{
return false;
}
if self
.validation_config
.skip_protocols
.contains(&protocol_id.to_string())
{
return false;
}
if result.confidence < self.validation_config.min_confidence_threshold {
return false; }
true
}
fn should_validate_profile(&self, profile_id: &str, result: &ProtocolOutput) -> bool {
if !self.validation_config.enable_validation {
return false;
}
match profile_id {
"paranoid" | "deep" => true, "quick" => self.validation_config.validation_level != ValidationLevel::None,
_ => result.confidence >= self.validation_config.min_confidence_threshold,
}
}
async fn apply_validation(
&self,
_target_id: &str,
original_input: &ProtocolInput,
protocol_result: &ProtocolOutput,
) -> Result<DeepSeekValidationResult> {
let validation_engine = self
.validation_engine
.as_ref()
.ok_or_else(|| Error::Config("Validation engine not available".into()))?;
match self.validation_config.validation_level {
ValidationLevel::Quick => {
validation_engine
.validate_quick(protocol_result, original_input)
.await
}
ValidationLevel::Standard => {
validation_engine
.validate_chain(protocol_result, original_input, &Default::default())
.await
}
ValidationLevel::Rigorous => {
validation_engine
.validate_rigorous(protocol_result, original_input, &Default::default())
.await
}
ValidationLevel::Paranoid => {
validation_engine
.validate_with_statistical_significance(
protocol_result,
original_input,
&Default::default(),
)
.await
}
ValidationLevel::None => {
unreachable!("ValidationLevel::None should be filtered earlier")
}
}
}
fn merge_validation_results(
&self,
mut base_result: ProtocolOutput,
validation_result: DeepSeekValidationResult,
) -> Result<ProtocolOutput> {
base_result.data.insert(
"deepseek_validation".into(),
serde_json::to_value(&validation_result)?,
);
let validation_impact = match validation_result.verdict {
ValidationVerdict::Validated => 1.10, ValidationVerdict::PartiallyValidated => 1.00, ValidationVerdict::NeedsImprovement => 0.85, ValidationVerdict::Invalid => 0.60, ValidationVerdict::CriticalIssues => 0.30, };
base_result.confidence = (base_result.confidence * validation_impact).clamp(0.0, 1.0);
let tokens_used = super::step::TokenUsage {
input_tokens: validation_result.tokens_used.input_tokens,
output_tokens: validation_result.tokens_used.output_tokens,
total_tokens: validation_result.tokens_used.total_tokens,
cost_usd: validation_result.tokens_used.cost_usd,
};
base_result.tokens.add(&tokens_used);
base_result.duration_ms += validation_result.performance.duration_ms;
Ok(base_result)
}
pub fn list_protocols(&self) -> Vec<&str> {
self.base_executor.list_protocols()
}
pub fn list_profiles(&self) -> Vec<&str> {
self.base_executor.list_profiles()
}
pub fn get_protocol(&self, id: &str) -> Option<&Protocol> {
self.base_executor.get_protocol(id)
}
pub fn get_profile(&self, id: &str) -> Option<super::profiles::ReasoningProfile> {
self.base_executor.get_profile(id).cloned()
}
}
impl Default for ValidatingProtocolExecutor {
    /// Equivalent to [`ValidatingProtocolExecutor::new`], unwrapping the result.
    ///
    /// # Panics
    ///
    /// Panics if `ValidatingProtocolExecutor::new` returns an error.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create default validating executor")
    }
}
impl ValidationExecutorConfig {
    /// Turns a list of string literals into the owned ids the config stores.
    fn owned_ids(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| String::from(*id)).collect()
    }

    /// Preset: `Rigorous` validation, threshold `0.80`, restricted to the
    /// "proofguard", "brutalhonesty" and "laserlogic" protocols.
    pub fn enterprise() -> Self {
        let validate_protocols = Self::owned_ids(&["proofguard", "brutalhonesty", "laserlogic"]);
        Self {
            enable_validation: true,
            validation_level: ValidationLevel::Rigorous,
            min_confidence_threshold: 0.80,
            validate_protocols,
            skip_protocols: Self::owned_ids(&[]),
        }
    }

    /// Preset: `Paranoid` validation, threshold `0.90`, restricted to the
    /// "gigathink" and "scientific" protocols.
    pub fn research() -> Self {
        let validate_protocols = Self::owned_ids(&["gigathink", "scientific"]);
        Self {
            enable_validation: true,
            validation_level: ValidationLevel::Paranoid,
            min_confidence_threshold: 0.90,
            validate_protocols,
            skip_protocols: Self::owned_ids(&[]),
        }
    }

    /// Preset: `Standard` validation, threshold `0.70`, no allow-list, with
    /// "quick" on the skip list.
    pub fn production() -> Self {
        let skip_protocols = Self::owned_ids(&["quick"]);
        Self {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            min_confidence_threshold: 0.70,
            validate_protocols: Self::owned_ids(&[]),
            skip_protocols,
        }
    }

    /// Preset: `Rigorous` validation, threshold `0.85`, restricted to
    /// "proofguard" and "brutalhonesty", with "gigathink" and "quick" skipped.
    pub fn compliance() -> Self {
        let validate_protocols = Self::owned_ids(&["proofguard", "brutalhonesty"]);
        let skip_protocols = Self::owned_ids(&["gigathink", "quick"]);
        Self {
            enable_validation: true,
            validation_level: ValidationLevel::Rigorous,
            min_confidence_threshold: 0.85,
            validate_protocols,
            skip_protocols,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thinktool::step::TokenUsage;
    use crate::thinktool::validation::{
        ChainIntegrityResult, DependencyStatus, LogicalFlowStatus, ProgressionStatus,
        ValidationPerformance,
    };
    use std::collections::HashMap;

    // ------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------

    /// Builds an executor on top of `ExecutorConfig::mock()` with the given
    /// validation settings.
    fn mock_executor(validation_config: ValidationExecutorConfig) -> ValidatingProtocolExecutor {
        ValidatingProtocolExecutor::with_configs(ExecutorConfig::mock(), validation_config)
            .unwrap()
    }

    /// Default settings with validation switched off.
    fn validation_disabled() -> ValidationExecutorConfig {
        ValidationExecutorConfig {
            enable_validation: false,
            ..Default::default()
        }
    }

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Asserts that `list` contains the id `wanted`.
    fn assert_lists(list: &[String], wanted: &str) {
        assert!(list.contains(&wanted.to_string()));
    }

    /// Successful, otherwise empty protocol output with the given confidence.
    fn create_mock_output(protocol_id: &str, confidence: f64) -> ProtocolOutput {
        ProtocolOutput {
            protocol_id: protocol_id.to_string(),
            success: true,
            data: HashMap::new(),
            confidence,
            steps: vec![],
            tokens: TokenUsage::default(),
            duration_ms: 100,
            error: None,
            trace_id: None,
            budget_summary: None,
        }
    }

    /// Validation result with fixed token/performance figures and the given verdict.
    fn create_mock_validation_result(verdict: ValidationVerdict) -> DeepSeekValidationResult {
        use crate::thinktool::validation::TokenUsage as ValidationTokenUsage;

        let chain_integrity = ChainIntegrityResult {
            logical_flow: LogicalFlowStatus::Good,
            step_dependencies: DependencyStatus::FullySatisfied,
            confidence_progression: ProgressionStatus::Monotonic,
            gaps_detected: vec![],
            continuity_score: 0.85,
        };
        let tokens_used = ValidationTokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
            cost_usd: 0.002,
        };
        let performance = ValidationPerformance {
            duration_ms: 500,
            tokens_per_second: 300.0,
            memory_usage_mb: 50.0,
        };

        DeepSeekValidationResult {
            verdict,
            chain_integrity,
            statistical_results: None,
            compliance_results: None,
            meta_cognitive_results: None,
            validation_confidence: 0.90,
            findings: vec![],
            tokens_used,
            performance,
        }
    }

    /// Merges a mock validation result carrying `verdict` into `base` using a
    /// default-configured mock executor.
    fn merge_with_verdict(base: ProtocolOutput, verdict: ValidationVerdict) -> ProtocolOutput {
        mock_executor(ValidationExecutorConfig::default())
            .merge_validation_results(base, create_mock_validation_result(verdict))
            .unwrap()
    }

    // ------------------------------------------------------------------
    // Configuration presets and serde
    // ------------------------------------------------------------------

    #[test]
    fn test_default_validation_executor_config() {
        let cfg = ValidationExecutorConfig::default();
        assert!(cfg.enable_validation);
        assert_eq!(cfg.validation_level, ValidationLevel::Standard);
        assert_close(cfg.min_confidence_threshold, 0.70, f64::EPSILON);
        assert!(cfg.validate_protocols.is_empty());
        assert!(cfg.skip_protocols.is_empty());
    }

    #[test]
    fn test_enterprise_configuration() {
        let cfg = ValidationExecutorConfig::enterprise();
        assert!(cfg.enable_validation);
        assert_eq!(cfg.validation_level, ValidationLevel::Rigorous);
        assert_close(cfg.min_confidence_threshold, 0.80, f64::EPSILON);
        for id in ["proofguard", "brutalhonesty", "laserlogic"] {
            assert_lists(&cfg.validate_protocols, id);
        }
        assert!(cfg.skip_protocols.is_empty());
    }

    #[test]
    fn test_research_configuration() {
        let cfg = ValidationExecutorConfig::research();
        assert!(cfg.enable_validation);
        assert_eq!(cfg.validation_level, ValidationLevel::Paranoid);
        assert_close(cfg.min_confidence_threshold, 0.90, f64::EPSILON);
        for id in ["gigathink", "scientific"] {
            assert_lists(&cfg.validate_protocols, id);
        }
    }

    #[test]
    fn test_production_configuration() {
        let cfg = ValidationExecutorConfig::production();
        assert!(cfg.enable_validation);
        assert_eq!(cfg.validation_level, ValidationLevel::Standard);
        assert_close(cfg.min_confidence_threshold, 0.70, f64::EPSILON);
        assert!(cfg.validate_protocols.is_empty());
        assert_lists(&cfg.skip_protocols, "quick");
    }

    #[test]
    fn test_compliance_configuration() {
        let cfg = ValidationExecutorConfig::compliance();
        assert!(cfg.enable_validation);
        assert_eq!(cfg.validation_level, ValidationLevel::Rigorous);
        assert_close(cfg.min_confidence_threshold, 0.85, f64::EPSILON);
        for id in ["proofguard", "brutalhonesty"] {
            assert_lists(&cfg.validate_protocols, id);
        }
        for id in ["gigathink", "quick"] {
            assert_lists(&cfg.skip_protocols, id);
        }
    }

    #[test]
    fn test_validation_level_default() {
        assert_eq!(ValidationLevel::default(), ValidationLevel::Standard);
    }

    #[test]
    fn test_validation_level_equality() {
        // Every level equals itself...
        assert_eq!(ValidationLevel::None, ValidationLevel::None);
        assert_eq!(ValidationLevel::Quick, ValidationLevel::Quick);
        assert_eq!(ValidationLevel::Standard, ValidationLevel::Standard);
        assert_eq!(ValidationLevel::Rigorous, ValidationLevel::Rigorous);
        assert_eq!(ValidationLevel::Paranoid, ValidationLevel::Paranoid);
        // ...and differs from its neighbour.
        assert_ne!(ValidationLevel::None, ValidationLevel::Quick);
        assert_ne!(ValidationLevel::Quick, ValidationLevel::Standard);
        assert_ne!(ValidationLevel::Standard, ValidationLevel::Rigorous);
        assert_ne!(ValidationLevel::Rigorous, ValidationLevel::Paranoid);
    }

    #[test]
    fn test_config_serialization() {
        let original = ValidationExecutorConfig::enterprise();
        let json = serde_json::to_string(&original).expect("Failed to serialize config");
        for fragment in ["enable_validation", "rigorous", "proofguard"] {
            assert!(json.contains(fragment));
        }

        let restored: ValidationExecutorConfig =
            serde_json::from_str(&json).expect("Failed to deserialize config");
        assert_eq!(restored.validation_level, original.validation_level);
        assert_eq!(
            restored.min_confidence_threshold,
            original.min_confidence_threshold
        );
    }

    #[test]
    fn test_validation_level_serde_roundtrip() {
        for level in [
            ValidationLevel::None,
            ValidationLevel::Quick,
            ValidationLevel::Standard,
            ValidationLevel::Rigorous,
            ValidationLevel::Paranoid,
        ] {
            let json = serde_json::to_string(&level).unwrap();
            let restored: ValidationLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(restored, level);
        }
    }

    #[test]
    fn test_config_with_all_fields_set() {
        let cfg = ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Rigorous,
            min_confidence_threshold: 0.95,
            validate_protocols: vec!["a".to_string(), "b".to_string()],
            skip_protocols: vec!["c".to_string()],
        };
        assert!(cfg.enable_validation);
        assert_eq!(cfg.validation_level, ValidationLevel::Rigorous);
        assert_close(cfg.min_confidence_threshold, 0.95, f64::EPSILON);
        assert_eq!(cfg.validate_protocols.len(), 2);
        assert_eq!(cfg.skip_protocols.len(), 1);
    }

    // ------------------------------------------------------------------
    // Construction and pass-through accessors
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_executor_creation_with_default_config() {
        let executor = ValidatingProtocolExecutor::new().unwrap();
        let protocols = executor.list_protocols();
        for id in ["gigathink", "laserlogic", "bedrock"] {
            assert!(protocols.contains(&id));
        }
    }

    #[tokio::test]
    async fn test_executor_creation_with_mock_config() {
        let executor = mock_executor(ValidationExecutorConfig::default());
        assert!(!executor.list_protocols().is_empty());
        assert!(!executor.list_profiles().is_empty());
    }

    #[tokio::test]
    async fn test_executor_creation_with_disabled_validation() {
        let executor = mock_executor(validation_disabled());
        assert!(!executor.list_protocols().is_empty());
    }

    #[test]
    fn test_executor_list_protocols() {
        let executor = ValidatingProtocolExecutor::new().unwrap();
        let protocols = executor.list_protocols();
        for id in [
            "gigathink",
            "laserlogic",
            "bedrock",
            "proofguard",
            "brutalhonesty",
        ] {
            assert!(protocols.contains(&id));
        }
    }

    #[test]
    fn test_executor_list_profiles() {
        let executor = ValidatingProtocolExecutor::new().unwrap();
        let profiles = executor.list_profiles();
        for id in ["quick", "balanced", "deep", "paranoid"] {
            assert!(profiles.contains(&id));
        }
    }

    #[test]
    fn test_executor_get_protocol() {
        let executor = ValidatingProtocolExecutor::new().unwrap();

        let found = executor.get_protocol("gigathink");
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, "gigathink");

        assert!(executor.get_protocol("nonexistent").is_none());
    }

    #[test]
    fn test_executor_get_profile() {
        let executor = ValidatingProtocolExecutor::new().unwrap();
        assert!(executor.get_profile("balanced").is_some());
        assert!(executor.get_profile("nonexistent").is_none());
    }

    #[test]
    fn test_default_executor_via_default_trait() {
        let executor = ValidatingProtocolExecutor::default();
        assert!(!executor.list_protocols().is_empty());
    }

    #[test]
    fn test_validation_engine_not_available_error() {
        let executor = mock_executor(validation_disabled());
        assert!(executor.validation_engine.is_none());
    }

    // ------------------------------------------------------------------
    // should_validate
    // ------------------------------------------------------------------

    #[test]
    fn test_should_validate_disabled_validation() {
        let executor = mock_executor(validation_disabled());
        let output = create_mock_output("gigathink", 0.85);
        assert!(!executor.should_validate("gigathink", &output));
    }

    #[test]
    fn test_should_validate_with_none_level() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::None,
            ..Default::default()
        });
        let output = create_mock_output("gigathink", 0.85);
        assert!(!executor.should_validate("gigathink", &output));
    }

    #[test]
    fn test_should_validate_with_specific_protocols() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            validate_protocols: vec!["proofguard".to_string(), "laserlogic".to_string()],
            ..Default::default()
        });

        let listed = create_mock_output("proofguard", 0.85);
        assert!(executor.should_validate("proofguard", &listed));

        let unlisted = create_mock_output("gigathink", 0.85);
        assert!(!executor.should_validate("gigathink", &unlisted));
    }

    #[test]
    fn test_should_validate_with_skip_protocols() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            skip_protocols: vec!["quick".to_string()],
            ..Default::default()
        });

        let skipped = create_mock_output("quick", 0.85);
        assert!(!executor.should_validate("quick", &skipped));

        let kept = create_mock_output("gigathink", 0.85);
        assert!(executor.should_validate("gigathink", &kept));
    }

    #[test]
    fn test_should_validate_below_confidence_threshold() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            min_confidence_threshold: 0.80,
            ..Default::default()
        });

        let below = create_mock_output("gigathink", 0.75);
        assert!(!executor.should_validate("gigathink", &below));

        let at_threshold = create_mock_output("gigathink", 0.80);
        assert!(executor.should_validate("gigathink", &at_threshold));

        let above = create_mock_output("gigathink", 0.90);
        assert!(executor.should_validate("gigathink", &above));
    }

    #[test]
    fn test_should_validate_all_conditions_met() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            min_confidence_threshold: 0.70,
            validate_protocols: Vec::new(),
            skip_protocols: Vec::new(),
        });
        let output = create_mock_output("gigathink", 0.85);
        assert!(executor.should_validate("gigathink", &output));
    }

    #[test]
    fn test_zero_confidence_output() {
        let executor = mock_executor(ValidationExecutorConfig::default());
        let output = create_mock_output("gigathink", 0.0);
        assert!(!executor.should_validate("gigathink", &output));
    }

    #[test]
    fn test_exact_threshold_confidence() {
        let executor = mock_executor(ValidationExecutorConfig {
            min_confidence_threshold: 0.75,
            ..Default::default()
        });
        let output = create_mock_output("gigathink", 0.75);
        assert!(executor.should_validate("gigathink", &output));
    }

    #[test]
    fn test_just_below_threshold() {
        let executor = mock_executor(ValidationExecutorConfig {
            min_confidence_threshold: 0.75,
            ..Default::default()
        });
        let output = create_mock_output("gigathink", 0.7499999);
        assert!(!executor.should_validate("gigathink", &output));
    }

    #[test]
    fn test_empty_validate_protocols_means_all() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            validate_protocols: Vec::new(),
            skip_protocols: Vec::new(),
            min_confidence_threshold: 0.0,
        });
        let output = create_mock_output("any_protocol", 0.50);
        assert!(executor.should_validate("any_protocol", &output));
    }

    #[test]
    fn test_skip_takes_precedence_over_validate() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            validate_protocols: vec!["gigathink".to_string()],
            skip_protocols: vec!["gigathink".to_string()],
            min_confidence_threshold: 0.0,
        });
        let output = create_mock_output("gigathink", 0.90);
        assert!(!executor.should_validate("gigathink", &output));
    }

    // ------------------------------------------------------------------
    // should_validate_profile
    // ------------------------------------------------------------------

    #[test]
    fn test_should_validate_profile_disabled() {
        let executor = mock_executor(validation_disabled());
        let output = create_mock_output("paranoid", 0.95);
        assert!(!executor.should_validate_profile("paranoid", &output));
    }

    #[test]
    fn test_should_validate_profile_paranoid_always() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            min_confidence_threshold: 0.99,
            ..Default::default()
        });
        let output = create_mock_output("paranoid", 0.50);
        assert!(executor.should_validate_profile("paranoid", &output));
    }

    #[test]
    fn test_should_validate_profile_deep_always() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            min_confidence_threshold: 0.99,
            ..Default::default()
        });
        let output = create_mock_output("deep", 0.50);
        assert!(executor.should_validate_profile("deep", &output));
    }

    #[test]
    fn test_should_validate_profile_quick_with_level() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            ..Default::default()
        });
        let output = create_mock_output("quick", 0.50);
        assert!(executor.should_validate_profile("quick", &output));
    }

    #[test]
    fn test_should_validate_profile_quick_with_none_level() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::None,
            ..Default::default()
        });
        let output = create_mock_output("quick", 0.90);
        assert!(!executor.should_validate_profile("quick", &output));
    }

    #[test]
    fn test_should_validate_profile_other_profiles_use_threshold() {
        let executor = mock_executor(ValidationExecutorConfig {
            enable_validation: true,
            validation_level: ValidationLevel::Standard,
            min_confidence_threshold: 0.80,
            ..Default::default()
        });

        let below = create_mock_output("balanced", 0.75);
        assert!(!executor.should_validate_profile("balanced", &below));

        let above = create_mock_output("balanced", 0.85);
        assert!(executor.should_validate_profile("balanced", &above));
    }

    // ------------------------------------------------------------------
    // merge_validation_results
    // ------------------------------------------------------------------

    #[test]
    fn test_merge_validation_results_validated_boosts_confidence() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.80),
            ValidationVerdict::Validated,
        );
        assert_close(merged.confidence, 0.88, 0.001);
    }

    #[test]
    fn test_merge_validation_results_validated_clamps_to_one() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.95),
            ValidationVerdict::Validated,
        );
        assert_close(merged.confidence, 1.0, f64::EPSILON);
    }

    #[test]
    fn test_merge_validation_results_partially_validated_neutral() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.80),
            ValidationVerdict::PartiallyValidated,
        );
        assert_close(merged.confidence, 0.80, f64::EPSILON);
    }

    #[test]
    fn test_merge_validation_results_needs_improvement_reduces() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.80),
            ValidationVerdict::NeedsImprovement,
        );
        assert_close(merged.confidence, 0.68, 0.001);
    }

    #[test]
    fn test_merge_validation_results_invalid_significantly_reduces() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.80),
            ValidationVerdict::Invalid,
        );
        assert_close(merged.confidence, 0.48, 0.001);
    }

    #[test]
    fn test_merge_validation_results_critical_issues_severely_reduces() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.80),
            ValidationVerdict::CriticalIssues,
        );
        assert_close(merged.confidence, 0.24, 0.001);
    }

    #[test]
    fn test_merge_validation_results_adds_metadata() {
        let merged = merge_with_verdict(
            create_mock_output("gigathink", 0.80),
            ValidationVerdict::Validated,
        );
        assert!(merged.data.contains_key("deepseek_validation"));
    }

    #[test]
    fn test_merge_validation_results_adds_tokens() {
        let mut base = create_mock_output("gigathink", 0.80);
        base.tokens = TokenUsage::new(200, 100, 0.003);

        let merged = merge_with_verdict(base, ValidationVerdict::Validated);
        assert_eq!(merged.tokens.input_tokens, 300);
        assert_eq!(merged.tokens.output_tokens, 150);
        assert_eq!(merged.tokens.total_tokens, 450);
    }

    #[test]
    fn test_merge_validation_results_adds_duration() {
        let mut base = create_mock_output("gigathink", 0.80);
        base.duration_ms = 1000;

        let merged = merge_with_verdict(base, ValidationVerdict::Validated);
        assert_eq!(merged.duration_ms, 1500);
    }

    // ------------------------------------------------------------------
    // End-to-end execution against the mock executor (validation off)
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_execute_with_mock_returns_success() {
        let executor = mock_executor(validation_disabled());
        let input = ProtocolInput::query("What is machine learning?");

        let output = executor
            .execute_with_validation("gigathink", input)
            .await
            .unwrap();
        assert!(output.success);
        assert!(output.confidence > 0.0);
        assert!(!output.steps.is_empty());
    }

    #[tokio::test]
    async fn test_execute_profile_with_mock_returns_success() {
        let executor = mock_executor(validation_disabled());
        let input = ProtocolInput::query("Should we use microservices?");

        let output = executor
            .execute_profile_with_validation("quick", input)
            .await
            .unwrap();
        assert!(output.success);
        assert!(output.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_execute_nonexistent_protocol_returns_error() {
        let executor = mock_executor(validation_disabled());
        let input = ProtocolInput::query("Test query");

        let outcome = executor
            .execute_with_validation("nonexistent_protocol", input)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_profile_returns_error() {
        let executor = mock_executor(validation_disabled());
        let input = ProtocolInput::query("Test query");

        let outcome = executor
            .execute_profile_with_validation("nonexistent_profile", input)
            .await;
        assert!(outcome.is_err());
    }
}