use super::reflective_agent::ExecutionContext;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Tunable settings that drive `ConfidenceChecker`'s scoring and revision decisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceConfig {
/// Confidence score below which `ConfidenceChecker::should_revise` may ask for a revision.
pub threshold: f32,
/// Number of previous attempts at which `ConfidenceChecker::should_revise` stops asking
/// for further revisions.
pub revision_budget: u32,
/// Not read by any code visible in this file; presumably the smallest confidence gain
/// that counts as a worthwhile revision (TODO confirm against the caller).
pub min_improvement: f32,
/// Weights applied to the individual factor scores in
/// `ConfidenceChecker::compute_confidence`.
pub factor_weights: ConfidenceFactorWeights,
/// When `false`, the structure factor is not computed and a fixed score of 0.8 is used.
pub use_structural_analysis: bool,
/// Phrases that signal uncertainty. They are matched as case-insensitive substrings of
/// the output; each matching phrase lowers the certainty factor and is reported as a
/// weak point.
pub low_confidence_patterns: Vec<String>,
}
impl Default for ConfidenceConfig {
    /// Returns the stock configuration: a 0.7 threshold, a budget of three revisions,
    /// a 0.05 minimum improvement, default factor weights, structural analysis enabled
    /// and the built-in list of uncertainty phrases.
    fn default() -> Self {
        // Phrases and markers that indicate hedging or unfinished work.
        let builtin_phrases = [
            "I'm not sure",
            "might be",
            "possibly",
            "could be wrong",
            "uncertain",
            "TODO",
            "FIXME",
            "not implemented",
        ];
        let low_confidence_patterns = builtin_phrases
            .iter()
            .map(|phrase| String::from(*phrase))
            .collect();

        Self {
            threshold: 0.7,
            revision_budget: 3,
            min_improvement: 0.05,
            factor_weights: ConfidenceFactorWeights::default(),
            use_structural_analysis: true,
            low_confidence_patterns,
        }
    }
}
/// Relative weight of each factor in the overall confidence score.
///
/// `ConfidenceChecker::compute_confidence` multiplies every factor score by its weight
/// and sums the products. The default weights add up to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceFactorWeights {
/// Weight of the completeness factor.
pub completeness: f32,
/// Weight of the structure factor.
pub structure: f32,
/// Weight of the certainty factor.
pub certainty: f32,
/// Weight of the relevance factor.
pub relevance: f32,
/// Weight of the code-validity factor.
pub code_validity: f32,
}
impl Default for ConfidenceFactorWeights {
    /// Returns the stock weighting, which sums to 1.0: completeness counts most,
    /// code validity least, and the remaining three factors share one weight.
    fn default() -> Self {
        // Structure, certainty and relevance are weighted equally.
        let shared = 0.20;
        Self {
            completeness: 0.25,
            structure: shared,
            certainty: shared,
            relevance: shared,
            code_validity: 0.15,
        }
    }
}
/// Coarse bucket for a confidence score, as assigned by `ConfidenceLevel::from_score`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConfidenceLevel {
/// Score above 0.9.
VeryHigh,
/// Score above 0.7, up to and including 0.9.
High,
/// Score above 0.5, up to and including 0.7.
Medium,
/// Score above 0.3, up to and including 0.5.
Low,
/// Any other score: 0.3 or below, or NaN.
VeryLow,
}
impl ConfidenceLevel {
    /// Maps a numeric score onto a level.
    ///
    /// Every lower bound is exclusive: a score of exactly 0.9 is `High`, not
    /// `VeryHigh`. A score that clears no bound, including NaN, is `VeryLow`.
    pub fn from_score(score: f32) -> Self {
        if score > 0.9 {
            Self::VeryHigh
        } else if score > 0.7 {
            Self::High
        } else if score > 0.5 {
            Self::Medium
        } else if score > 0.3 {
            Self::Low
        } else {
            Self::VeryLow
        }
    }

    /// Returns the snake_case name of the level.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::VeryLow => "very_low",
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
            Self::VeryHigh => "very_high",
        }
    }

    /// Returns `true` for the two lowest levels, `Low` and `VeryLow`.
    pub fn should_revise(&self) -> bool {
        match *self {
            Self::Low | Self::VeryLow => true,
            Self::VeryHigh | Self::High | Self::Medium => false,
        }
    }
}
/// A single problem spotted in an output, together with advice on how to fix it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeakPoint {
/// Free-form description of where the problem is, e.g. `"line 3"` or
/// `"multiple locations"`.
pub location: String,
/// Human-readable explanation of the problem.
pub description: String,
/// How serious the problem is; `WeakPoint::new` clamps it to `0.0..=1.0`.
pub severity: f32,
/// Category of the problem.
pub weakness_type: WeaknessType,
/// Advice for fixing the problem; empty until `WeakPoint::with_suggestion` is called.
pub suggestion: String,
/// `WeakPoint::new` sets this to 0.8. It is not read anywhere in this file; presumably
/// it expresses how sure the detector is about the finding (TODO confirm).
pub confidence: f32,
}
impl WeakPoint {
    /// Creates a weak point with an empty suggestion and a confidence of 0.8.
    ///
    /// `severity` is clamped to `0.0..=1.0`.
    pub fn new(
        location: impl Into<String>,
        description: impl Into<String>,
        severity: f32,
        weakness_type: WeaknessType,
    ) -> Self {
        let location = location.into();
        let description = description.into();
        let severity = severity.clamp(0.0, 1.0);
        Self {
            location,
            description,
            severity,
            weakness_type,
            suggestion: String::new(),
            confidence: 0.8,
        }
    }

    /// Consumes the weak point and returns it with `suggestion` attached,
    /// replacing any suggestion set earlier.
    pub fn with_suggestion(self, suggestion: impl Into<String>) -> Self {
        Self {
            suggestion: suggestion.into(),
            ..self
        }
    }
}
/// Category of a `WeakPoint`.
///
/// Only some variants are produced by the checks visible in this file; the others are
/// presumably set by other producers (TODO confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WeaknessType {
/// Unfinished work; reported when TODO/FIXME/XXX/HACK markers are found.
Incomplete,
/// Hedging language; reported when a configured low-confidence pattern matches.
Uncertainty,
/// Reported when function definitions appear without `Result<`, `Option<` or `?`.
MissingErrorHandling,
/// Reported when the task mentions inputs or parameters but the output shows no
/// validation, check or assertion.
MissingValidation,
/// Not produced by the checks visible in this file.
CodeSmell,
/// Reported when the task mentions tests but the output contains no test functions.
MissingTests,
/// Not produced by the checks visible in this file.
DocumentationGap,
/// Not produced by the checks visible in this file.
SecurityConcern,
/// Not produced by the checks visible in this file.
PerformanceIssue,
/// Not produced by the checks visible in this file.
LogicError,
/// Not produced by the checks visible in this file.
Other,
}
/// Outcome of a revision pass.
///
/// NOTE(review): nothing in this file constructs or reads this type, so the field
/// descriptions below are inferred from the names only — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RevisionResult {
/// Presumably the confidence score before the revision.
pub original_confidence: f32,
/// Presumably the confidence score after the revision.
pub new_confidence: f32,
/// Presumably `new_confidence - original_confidence`.
pub improvement: f32,
/// Presumably the weak points the revision resolved.
pub addressed_weak_points: Vec<WeakPoint>,
/// Presumably the weak points still present after the revision.
pub remaining_weak_points: Vec<WeakPoint>,
/// Presumably the number of revisions performed so far.
pub revision_count: u32,
/// Presumably whether the revision counts as a success.
pub successful: bool,
}
/// Heuristic scorer that estimates how trustworthy an output is, lists its weak points
/// and keeps a history of the checks it recorded.
#[derive(Debug)]
pub struct ConfidenceChecker {
/// Settings used by every scoring and revision decision.
config: ConfidenceConfig,
/// One record per `record_check` call, oldest first.
check_history: Vec<ConfidenceCheckRecord>,
/// Penalty patterns added through `learn_pattern`, mapped to their weight. When the
/// output contains a pattern (case-insensitive), the score is multiplied by
/// `1.0 - weight`.
learned_patterns: HashMap<String, f32>,
}
/// Snapshot of one confidence check, as produced by `ConfidenceChecker::record_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceCheckRecord {
/// Overall confidence score.
pub score: f32,
/// Level derived from `score` via `ConfidenceLevel::from_score`.
pub level: ConfidenceLevel,
/// Weak points identified in the checked output.
pub weak_points: Vec<WeakPoint>,
/// Individual factor scores keyed by factor name: `"completeness"`, `"structure"`,
/// `"certainty"`, `"relevance"` and `"code_validity"`.
pub factors: HashMap<String, f32>,
/// First 100 characters of the task description.
pub task_summary: String,
/// Seconds since the Unix epoch at the time of the check; 0 if the system clock
/// reports a time before the epoch.
pub timestamp: u64,
}
impl ConfidenceChecker {
    /// Creates a checker that uses `config`, with an empty check history and no
    /// learned patterns.
    pub fn new(config: ConfidenceConfig) -> Self {
        Self {
            config,
            check_history: Vec::new(),
            learned_patterns: HashMap::new(),
        }
    }

    /// Decides whether `output` should be sent back for another revision.
    ///
    /// Returns `true` only when the computed confidence is below the configured
    /// threshold and the number of previous attempts in `context` is still below
    /// the revision budget.
    pub fn should_revise(&self, output: &str, context: &ExecutionContext) -> bool {
        let confidence = self.compute_confidence(output, context);
        // Compare in u64 so that a huge attempt count cannot wrap around into a
        // small number, as an `as u32` cast would allow.
        let attempts = context.previous_attempts.len() as u64;
        confidence < self.config.threshold && attempts < u64::from(self.config.revision_budget)
    }

    /// Computes the overall confidence score for `output`, clamped to `0.0..=1.0`.
    ///
    /// The score is the weighted sum of the completeness, structure, certainty,
    /// relevance and code-validity factors. It is then multiplied by
    /// `1.0 - weight` for every learned pattern that occurs in the output
    /// (case-insensitive).
    pub fn compute_confidence(&self, output: &str, context: &ExecutionContext) -> f32 {
        let factors = self.factor_scores(output, context);
        self.score_from_factors(&factors, output)
    }

    /// Evaluates each factor once and returns `(report key, factor score, weight)`
    /// triples in scoring order.
    fn factor_scores(
        &self,
        output: &str,
        context: &ExecutionContext,
    ) -> [(&'static str, f32, f32); 5] {
        let weights = &self.config.factor_weights;
        [
            (
                "completeness",
                self.assess_completeness(output, context),
                weights.completeness,
            ),
            ("structure", self.assess_structure(output), weights.structure),
            ("certainty", self.assess_certainty(output), weights.certainty),
            (
                "relevance",
                self.assess_relevance(output, context),
                weights.relevance,
            ),
            (
                "code_validity",
                self.assess_code_validity(output),
                weights.code_validity,
            ),
        ]
    }

    /// Combines already-evaluated factors into the final score and applies the
    /// learned-pattern penalties.
    fn score_from_factors(&self, factors: &[(&'static str, f32, f32); 5], output: &str) -> f32 {
        let mut score = 0.0f32;
        for &(_, value, weight) in factors {
            score += value * weight;
        }
        if !self.learned_patterns.is_empty() {
            // Lowercase the output once rather than once per learned pattern.
            let output_lower = output.to_lowercase();
            for (pattern, penalty) in &self.learned_patterns {
                if output_lower.contains(&pattern.to_lowercase()) {
                    score *= 1.0 - penalty;
                }
            }
        }
        score.clamp(0.0, 1.0)
    }

    /// Scores how completely `output` appears to address the task.
    ///
    /// An empty output scores 0.0. Otherwise the score starts at 0.5 and gains up
    /// to 0.3 for the share of task words found in the output (case-insensitive).
    /// It loses 0.2 if an incompleteness marker is present, and gains 0.1 each for
    /// exceeding 500 and 1000 bytes.
    fn assess_completeness(&self, output: &str, context: &ExecutionContext) -> f32 {
        if output.is_empty() {
            return 0.0;
        }
        let mut score = 0.5f32;

        let output_lower = output.to_lowercase();
        let task_words: Vec<&str> = context.task.split_whitespace().collect();
        let addressed_count = task_words
            .iter()
            .filter(|word| output_lower.contains(&word.to_lowercase()))
            .count();
        // `max(1)` avoids dividing by zero when the task is empty.
        let addressed_ratio = addressed_count as f32 / task_words.len().max(1) as f32;
        score += addressed_ratio * 0.3;

        // These markers are matched case-sensitively against the raw output.
        let incomplete_markers = ["TODO", "FIXME", "...", "to be continued", "incomplete"];
        if incomplete_markers.iter().any(|marker| output.contains(marker)) {
            score -= 0.2;
        }

        if output.len() > 500 {
            score += 0.1;
        }
        if output.len() > 1000 {
            score += 0.1;
        }
        score.clamp(0.0, 1.0)
    }

    /// Scores the formatting of `output`.
    ///
    /// Returns a fixed 0.8 when structural analysis is disabled. Otherwise the
    /// score starts at 0.5 and is adjusted for code fences, headers or bold text,
    /// lists, very short outputs and outputs longer than five lines.
    fn assess_structure(&self, output: &str) -> f32 {
        if !self.config.use_structural_analysis {
            return 0.8;
        }
        let mut score = 0.5f32;

        if output.contains("```") {
            score += 0.2;
        }
        if output.contains("##") || output.contains("**") {
            score += 0.1;
        }
        let has_lists =
            output.contains("\n- ") || output.contains("\n* ") || output.contains("\n1.");
        if has_lists {
            score += 0.1;
        }
        if output.len() < 50 {
            score -= 0.2;
        }
        if output.lines().count() > 5 {
            score += 0.1;
        }
        score.clamp(0.0, 1.0)
    }

    /// Scores how certain `output` sounds: 1.0 when no configured low-confidence
    /// pattern matches, dropping by 0.2 per matching pattern down to a floor of 0.2.
    fn assess_certainty(&self, output: &str) -> f32 {
        let output_lower = output.to_lowercase();
        let uncertainty_count = self
            .config
            .low_confidence_patterns
            .iter()
            .filter(|pattern| output_lower.contains(&pattern.to_lowercase()))
            .count();
        match uncertainty_count {
            0 => 1.0,
            1 => 0.8,
            2 => 0.6,
            3 => 0.4,
            _ => 0.2,
        }
    }

    /// Scores how relevant `output` is to the task, in `0.5..=1.0`.
    ///
    /// Key terms are the task words longer than three bytes. The score is 0.5 plus
    /// half the share of key terms found in the output (case-insensitive), or 0.5
    /// when the task has no key terms.
    fn assess_relevance(&self, output: &str, context: &ExecutionContext) -> f32 {
        let task_lower = context.task.to_lowercase();
        let output_lower = output.to_lowercase();
        let key_terms: Vec<&str> = task_lower
            .split_whitespace()
            .filter(|word| word.len() > 3)
            .collect();
        if key_terms.is_empty() {
            return 0.5;
        }
        let matched = key_terms
            .iter()
            .filter(|term| output_lower.contains(*term))
            .count();
        let ratio = matched as f32 / key_terms.len() as f32;
        (ratio * 0.5 + 0.5).clamp(0.0, 1.0)
    }

    /// Scores the plausibility of any code in `output`.
    ///
    /// Returns 0.8 when nothing looks like code. Otherwise the score starts at 0.7
    /// and is raised or lowered according to whether parentheses, braces and
    /// square brackets are balanced (counted over the whole output). It is lowered
    /// further if the output contains compiler-style error text.
    fn assess_code_validity(&self, output: &str) -> f32 {
        let has_code = output.contains("```")
            || output.contains("fn ")
            || output.contains("def ")
            || output.contains("function ")
            || output.contains("class ");
        if !has_code {
            return 0.8;
        }
        let mut score = 0.7f32;

        if output.matches('(').count() == output.matches(')').count() {
            score += 0.1;
        } else {
            score -= 0.2;
        }
        if output.matches('{').count() == output.matches('}').count() {
            score += 0.1;
        } else {
            score -= 0.2;
        }
        if output.matches('[').count() == output.matches(']').count() {
            score += 0.1;
        } else {
            score -= 0.1;
        }
        if output.contains("error[") || output.contains("Error:") {
            score -= 0.3;
        }
        score.clamp(0.0, 1.0)
    }

    /// Lists the concrete problems found in `output`.
    ///
    /// Checks, in order: configured low-confidence patterns (first occurrence of
    /// each, reported with its line number), TODO/FIXME/XXX/HACK markers, function
    /// definitions without any sign of error handling, missing validation when the
    /// task mentions inputs or parameters, and missing tests when the task
    /// mentions tests.
    pub fn identify_weak_points(&self, output: &str, context: &ExecutionContext) -> Vec<WeakPoint> {
        let mut weak_points = Vec::new();
        let output_lower = output.to_lowercase();
        let task_lower = context.task.to_lowercase();

        for pattern in &self.config.low_confidence_patterns {
            if let Some(pos) = output_lower.find(&pattern.to_lowercase()) {
                // `pos` is a byte offset into the lowercased text. Lowercasing can
                // change byte lengths, so slicing the original `output` with it
                // could panic or count the wrong lines. Newlines are unaffected by
                // lowercasing, so counting them in the lowercased prefix is exact.
                let line_num = output_lower[..pos].matches('\n').count() + 1;
                weak_points.push(
                    WeakPoint::new(
                        format!("line {}", line_num),
                        format!("Uncertainty marker: '{}'", pattern),
                        0.6,
                        WeaknessType::Uncertainty,
                    )
                    .with_suggestion(format!(
                        "Remove or clarify the uncertain statement at '{}'",
                        pattern
                    )),
                );
            }
        }

        for marker in ["TODO", "FIXME", "XXX", "HACK"] {
            let count = output.matches(marker).count();
            if count > 0 {
                weak_points.push(
                    WeakPoint::new(
                        "multiple locations",
                        format!("Found {} {} markers", count, marker),
                        0.7,
                        WeaknessType::Incomplete,
                    )
                    .with_suggestion(format!("Address all {} items", marker)),
                );
            }
        }

        // "fn " also matches "async fn ", so one substring test covers both.
        let lacks_error_handling = !output.contains("Result<")
            && !output.contains("Option<")
            && !output.contains('?');
        if output.contains("fn ") && lacks_error_handling {
            weak_points.push(
                WeakPoint::new(
                    "function definitions",
                    "Functions may lack proper error handling",
                    0.5,
                    WeaknessType::MissingErrorHandling,
                )
                .with_suggestion("Add Result/Option return types and error propagation"),
            );
        }

        let task_mentions_inputs = task_lower.contains("input") || task_lower.contains("parameter");
        let output_mentions_checks = output_lower.contains("valid")
            || output_lower.contains("check")
            || output_lower.contains("assert");
        if task_mentions_inputs && !output_mentions_checks {
            weak_points.push(
                WeakPoint::new(
                    "input handling",
                    "May be missing input validation",
                    0.4,
                    WeaknessType::MissingValidation,
                )
                .with_suggestion("Add input validation and bounds checking"),
            );
        }

        let has_tests = output.contains("#[test]") || output.contains("fn test_");
        if task_lower.contains("test") && !has_tests {
            weak_points.push(
                WeakPoint::new(
                    "test coverage",
                    "No test functions found",
                    0.6,
                    WeaknessType::MissingTests,
                )
                .with_suggestion("Add unit tests with #[test] attribute"),
            );
        }

        weak_points
    }

    /// Builds a revision prompt that lists every weak point followed by the
    /// original output.
    ///
    /// When `weak_points` is empty, the output is returned unchanged instead of a
    /// prompt.
    pub fn generate_targeted_revision(&self, output: &str, weak_points: &[WeakPoint]) -> String {
        if weak_points.is_empty() {
            return output.to_string();
        }
        let mut revision_prompt = String::from(
            "Please revise the following output to address these specific issues:\n\n",
        );
        for (i, wp) in weak_points.iter().enumerate() {
            revision_prompt.push_str(&format!(
                "{}. [{:?}] At {}: {}\n   Suggestion: {}\n\n",
                i + 1,
                wp.weakness_type,
                wp.location,
                wp.description,
                wp.suggestion
            ));
        }
        revision_prompt.push_str("\nOriginal output:\n");
        revision_prompt.push_str(output);
        revision_prompt
    }

    /// Scores `output`, appends the result to the history and returns a copy of
    /// the new record.
    pub fn record_check(
        &mut self,
        output: &str,
        context: &ExecutionContext,
    ) -> ConfidenceCheckRecord {
        // Evaluate each factor once and reuse it for both the overall score and
        // the per-factor report.
        let factor_scores = self.factor_scores(output, context);
        let score = self.score_from_factors(&factor_scores, output);
        let level = ConfidenceLevel::from_score(score);
        let weak_points = self.identify_weak_points(output, context);
        let factors: HashMap<String, f32> = factor_scores
            .iter()
            .map(|&(name, value, _)| (name.to_string(), value))
            .collect();

        let record = ConfidenceCheckRecord {
            score,
            level,
            weak_points,
            factors,
            task_summary: context.task.chars().take(100).collect(),
            // A clock set before the Unix epoch yields a timestamp of 0.
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0),
        };
        self.check_history.push(record.clone());
        record
    }

    /// Registers `pattern` as a penalty pattern, replacing any earlier weight for
    /// the same pattern.
    ///
    /// `weight` is clamped to `0.0..=1.0`. A NaN weight is stored as 0.0 (no
    /// penalty), because `f32::clamp` lets NaN through and a NaN weight would turn
    /// every later score into NaN.
    pub fn learn_pattern(&mut self, pattern: String, weight: f32) {
        let weight = if weight.is_nan() {
            0.0
        } else {
            weight.clamp(0.0, 1.0)
        };
        self.learned_patterns.insert(pattern, weight);
    }

    /// Returns the recorded checks, oldest first.
    pub fn history(&self) -> &[ConfidenceCheckRecord] {
        &self.check_history
    }

    /// Discards all recorded checks; learned patterns are kept.
    pub fn clear_history(&mut self) {
        self.check_history.clear();
    }

    /// Returns the configuration this checker was built with.
    pub fn config(&self) -> &ConfidenceConfig {
        &self.config
    }
}
// Unit tests for the confidence types and `ConfidenceChecker`.
#[cfg(test)]
mod tests {
use super::*;
use crate::claude_flow::AgentType;
// One representative score per bucket maps to the expected level.
#[test]
fn test_confidence_level_from_score() {
assert_eq!(ConfidenceLevel::from_score(0.95), ConfidenceLevel::VeryHigh);
assert_eq!(ConfidenceLevel::from_score(0.8), ConfidenceLevel::High);
assert_eq!(ConfidenceLevel::from_score(0.6), ConfidenceLevel::Medium);
assert_eq!(ConfidenceLevel::from_score(0.4), ConfidenceLevel::Low);
assert_eq!(ConfidenceLevel::from_score(0.2), ConfidenceLevel::VeryLow);
}
// Only the two lowest levels request a revision.
#[test]
fn test_should_revise_low_levels() {
assert!(ConfidenceLevel::Low.should_revise());
assert!(ConfidenceLevel::VeryLow.should_revise());
assert!(!ConfidenceLevel::Medium.should_revise());
assert!(!ConfidenceLevel::High.should_revise());
}
// A checker exposes the configuration it was created with.
#[test]
fn test_confidence_checker_creation() {
let config = ConfidenceConfig::default();
let checker = ConfidenceChecker::new(config);
assert_eq!(checker.config().threshold, 0.7);
}
// An empty output must score below 0.5 with the default configuration.
#[test]
fn test_compute_confidence_empty() {
let checker = ConfidenceChecker::new(ConfidenceConfig::default());
let context = ExecutionContext::new("test task", AgentType::Coder, "input");
let confidence = checker.compute_confidence("", &context);
assert!(confidence < 0.5);
}
// Hedging phrases must lower the score relative to a confident, code-bearing answer.
#[test]
fn test_compute_confidence_with_uncertainty() {
let checker = ConfidenceChecker::new(ConfidenceConfig::default());
let context = ExecutionContext::new("implement function", AgentType::Coder, "input");
let confident_output = "Here is the implementation:\n```rust\nfn example() { }\n```";
let uncertain_output = "I'm not sure but possibly this might work...";
let conf1 = checker.compute_confidence(confident_output, &context);
let conf2 = checker.compute_confidence(uncertain_output, &context);
assert!(conf1 > conf2);
}
// A TODO marker is reported as an `Incomplete` weak point.
#[test]
fn test_identify_weak_points_todo() {
let checker = ConfidenceChecker::new(ConfidenceConfig::default());
let context = ExecutionContext::new("implement function", AgentType::Coder, "input");
let output = "fn example() {\n // TODO: implement this\n}";
let weak_points = checker.identify_weak_points(output, &context);
assert!(!weak_points.is_empty());
assert!(weak_points
.iter()
.any(|wp| matches!(wp.weakness_type, WeaknessType::Incomplete)));
}
// A low-confidence output is revised until the revision budget is used up.
#[test]
fn test_should_revise() {
let checker = ConfidenceChecker::new(ConfidenceConfig {
threshold: 0.7,
revision_budget: 3,
..Default::default()
});
let mut context = ExecutionContext::new("test", AgentType::Coder, "input");
let low_conf_output = "I'm not sure, maybe...";
assert!(checker.should_revise(low_conf_output, &context));
// Fill the budget. `should_revise` only reads the number of attempts, so every
// entry can reuse the same placeholder values.
for _ in 0..3 {
context
.previous_attempts
.push(crate::reflection::reflective_agent::PreviousAttempt {
attempt_number: 1,
output: String::new(),
error: None,
quality_score: None,
duration_ms: 0,
reflection: None,
});
}
assert!(!checker.should_revise(low_conf_output, &context));
}
// `with_suggestion` attaches advice without disturbing the other fields.
#[test]
fn test_weak_point_builder() {
let wp = WeakPoint::new(
"line 5",
"Missing error handling",
0.7,
WeaknessType::MissingErrorHandling,
)
.with_suggestion("Add Result return type");
assert_eq!(wp.location, "line 5");
assert!(!wp.suggestion.is_empty());
}
// The revision prompt carries the issue, its suggestion and the original output.
#[test]
fn test_generate_targeted_revision() {
let checker = ConfidenceChecker::new(ConfidenceConfig::default());
let weak_points = vec![
WeakPoint::new("line 1", "Issue 1", 0.5, WeaknessType::Incomplete)
.with_suggestion("Fix it"),
];
let revision = checker.generate_targeted_revision("original output", &weak_points);
assert!(revision.contains("Issue 1"));
assert!(revision.contains("Fix it"));
assert!(revision.contains("original output"));
}
// A learned pattern penalizes outputs that contain it.
#[test]
fn test_learn_pattern() {
let mut checker = ConfidenceChecker::new(ConfidenceConfig::default());
checker.learn_pattern("problematic pattern".to_string(), 0.3);
let context = ExecutionContext::new("test", AgentType::Coder, "input");
let output_with_pattern = "This has a problematic pattern in it";
let output_without = "This is clean code";
let conf1 = checker.compute_confidence(output_with_pattern, &context);
let conf2 = checker.compute_confidence(output_without, &context);
assert!(conf1 < conf2);
}
// Recording a check appends to the history and reports per-factor scores.
#[test]
fn test_record_check() {
let mut checker = ConfidenceChecker::new(ConfidenceConfig::default());
let context = ExecutionContext::new("test task", AgentType::Coder, "input");
let record = checker.record_check("test output", &context);
assert!(!checker.history().is_empty());
assert!(record.factors.contains_key("completeness"));
}
}