use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::RwLock;
use std::time::{SystemTime, UNIX_EPOCH};
/// Upper bound on retained history entries; shared by every learner in this
/// module as both a record cap and (divided by 10) a map-size eviction trigger.
const MAX_ENTRIES: usize = 10_000;
/// Coarse result classification for a recorded attempt (prompt, tool call, or task).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Outcome {
    /// Completed as intended.
    Success,
    /// Partially achieved the goal.
    Partial,
    /// Attempted and failed.
    Failure,
    /// Given up before completion.
    Abandoned,
}
impl Outcome {
    /// Numeric quality associated with this outcome, in [0, 1].
    pub fn score(&self) -> f32 {
        match self {
            Self::Success => 1.0,
            Self::Partial => 0.5,
            Self::Failure | Self::Abandoned => 0.0,
        }
    }

    /// True for outcomes that count as at least a partial success.
    pub fn is_positive(&self) -> bool {
        match self {
            Self::Success | Self::Partial => true,
            Self::Failure | Self::Abandoned => false,
        }
    }
}
/// One observed prompt execution and its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptRecord {
    /// The full prompt text that was sent.
    pub prompt: String,
    /// Free-form task category label (e.g. "code_write").
    pub task_type: String,
    /// How the attempt turned out.
    pub outcome: Outcome,
    /// Quality in [0, 1]; initialized from `outcome.score()`.
    pub quality_score: f32,
    /// Tokens consumed by the exchange (0 when unknown).
    pub tokens_used: usize,
    /// Response latency in milliseconds (0 when unknown).
    pub response_time_ms: u64,
    /// Creation time, seconds since the Unix epoch.
    pub timestamp: u64,
}
impl PromptRecord {
    /// Builds a record stamped with the current time. Quality defaults to the
    /// outcome's score; token and latency counters start at zero.
    pub fn new(prompt: String, task_type: String, outcome: Outcome) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            prompt,
            task_type,
            outcome,
            quality_score: outcome.score(),
            tokens_used: 0,
            response_time_ms: 0,
            timestamp,
        }
    }

    /// Overrides the quality score, clamped into [0, 1].
    pub fn with_quality(self, score: f32) -> Self {
        Self {
            quality_score: score.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Records how many tokens the exchange consumed.
    pub fn with_tokens(self, tokens: usize) -> Self {
        Self {
            tokens_used: tokens,
            ..self
        }
    }

    /// Records the response latency in milliseconds.
    pub fn with_response_time(self, time_ms: u64) -> Self {
        Self {
            response_time_ms: time_ms,
            ..self
        }
    }
}
/// A reusable prompt template with aggregated effectiveness statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptPattern {
    /// Unique identifier.
    pub id: String,
    /// Template text; may contain `{action}` / `{target}` placeholders.
    pub template: String,
    /// Task types this pattern works well for; empty means "generic".
    pub effective_for: Vec<String>,
    /// Running mean quality over all recorded usages.
    pub avg_quality: f32,
    /// Number of recorded usages.
    pub usage_count: usize,
    /// Fraction of usages with a positive outcome, in [0, 1].
    pub success_rate: f32,
}
impl PromptPattern {
    /// Creates an empty pattern with no usage history.
    pub fn new(id: &str, template: &str) -> Self {
        Self {
            id: id.to_string(),
            template: template.to_string(),
            effective_for: Vec::new(),
            avg_quality: 0.0,
            usage_count: 0,
            success_rate: 0.0,
        }
    }

    /// Folds one more observation into the running averages.
    ///
    /// The success count is not stored directly; it is reconstructed from the
    /// previous `success_rate` (rate * previous count, rounded), incremented
    /// when `outcome` is positive, and re-divided. Rounding keeps the
    /// reconstruction exact while counts are small enough for f32 precision.
    pub fn update(&mut self, outcome: Outcome, quality: f32) {
        // Running mean: previous total = mean * count.
        let old_total = self.usage_count as f32 * self.avg_quality;
        self.usage_count += 1;
        self.avg_quality = (old_total + quality) / self.usage_count as f32;
        let previous_count = self.usage_count.saturating_sub(1);
        // A non-finite stored rate (e.g. from corrupted serialized data)
        // resets the recovered success count to zero.
        let old_success = if self.success_rate.is_finite() {
            (self.success_rate.clamp(0.0, 1.0) * previous_count as f32).round() as usize
        } else {
            0
        };
        let new_success = if outcome.is_positive() {
            old_success + 1
        } else {
            old_success
        };
        self.success_rate = new_success as f32 / self.usage_count as f32;
    }
}
/// Learns which prompts and prompt patterns work best per task type.
pub struct PromptOptimizer {
    /// Bounded history of individual prompt observations.
    records: Vec<PromptRecord>,
    /// Registered reusable patterns, keyed by pattern id.
    patterns: HashMap<String, PromptPattern>,
    /// Per-task-type aggregate statistics.
    task_stats: HashMap<String, TaskPromptStats>,
    /// Cap on `records`; the oldest half is dropped when exceeded.
    max_records: usize,
}
/// Aggregate prompt statistics for a single task type.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskPromptStats {
    /// Total prompts recorded for this task type.
    pub total_attempts: usize,
    /// Attempts with a positive outcome.
    pub successful: usize,
    /// Running mean of quality scores.
    pub avg_quality: f32,
    /// Running mean of tokens used per attempt.
    pub avg_tokens: f32,
    /// Ids of patterns considered best for this task type.
    pub best_patterns: Vec<String>,
}
/// One candidate prompt produced during a tournament, with its predicted quality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptVariantScore {
    /// Short identifier for the variant (e.g. "baseline", "verify_first").
    pub variant_id: String,
    /// Name of the rewriting strategy that produced the variant.
    pub strategy: String,
    /// Full variant prompt text.
    pub prompt: String,
    /// Heuristic quality estimate in [0, 1].
    pub predicted_quality: f32,
}
/// Outcome of [`PromptOptimizer::evolve_prompt`]: the winning variant plus all scored candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTournamentResult {
    /// Task type the tournament was run for.
    pub task_type: String,
    /// Winning prompt text.
    pub winner_prompt: String,
    /// Strategy that produced the winner.
    pub winner_strategy: String,
    /// Winner's predicted quality.
    pub winner_score: f32,
    /// All scored variants, best first.
    pub variants: Vec<PromptVariantScore>,
}
impl PromptOptimizer {
    /// Creates an optimizer with empty history and the default record cap.
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            patterns: HashMap::new(),
            task_stats: HashMap::new(),
            max_records: MAX_ENTRIES,
        }
    }

    /// Ingests one prompt observation, updating per-task running stats and
    /// bounding both the stats map and the raw record history.
    pub fn record(&mut self, record: PromptRecord) {
        let stats = self.task_stats.entry(record.task_type.clone()).or_default();
        // Running means: recover the previous totals before bumping the count.
        let old_total_quality = stats.avg_quality * stats.total_attempts as f32;
        let old_total_tokens = stats.avg_tokens * stats.total_attempts as f32;
        stats.total_attempts += 1;
        if record.outcome.is_positive() {
            stats.successful += 1;
        }
        stats.avg_quality =
            (old_total_quality + record.quality_score) / stats.total_attempts as f32;
        stats.avg_tokens =
            (old_total_tokens + record.tokens_used as f32) / stats.total_attempts as f32;
        // Evict the least-attempted half of the task stats when over budget.
        if self.task_stats.len() > MAX_ENTRIES / 10 {
            let mut entries: Vec<_> = self
                .task_stats
                .iter()
                .map(|(k, v)| (k.clone(), v.total_attempts))
                .collect();
            entries.sort_by_key(|(_, count)| *count);
            for (key, _) in entries.iter().take(entries.len() / 2) {
                self.task_stats.remove(key);
            }
        }
        self.records.push(record);
        // Drop the oldest half of the raw records once over the cap.
        if self.records.len() > self.max_records {
            self.records.drain(0..self.max_records / 2);
        }
    }

    /// Registers (or replaces) a reusable prompt pattern by id.
    pub fn register_pattern(&mut self, pattern: PromptPattern) {
        self.patterns.insert(pattern.id.clone(), pattern);
    }

    /// Feeds one observation into an existing pattern; unknown ids are ignored.
    pub fn update_pattern(&mut self, pattern_id: &str, outcome: Outcome, quality: f32) {
        if let Some(pattern) = self.patterns.get_mut(pattern_id) {
            pattern.update(outcome, quality);
        }
    }

    /// Up to five patterns applicable to `task_type`, best average quality first.
    ///
    /// Patterns with an empty `effective_for` list are treated as generic and
    /// always considered; a pattern needs at least 5 usages to qualify.
    pub fn best_patterns_for(&self, task_type: &str) -> Vec<&PromptPattern> {
        let mut patterns: Vec<_> = self
            .patterns
            .values()
            .filter(|p| {
                // `any` compares &str directly; the old `contains(&task_type
                // .to_string())` allocated a String per pattern per call.
                p.effective_for.iter().any(|t| t == task_type) || p.effective_for.is_empty()
            })
            .filter(|p| p.usage_count >= 5)
            .collect();
        patterns.sort_by(|a, b| {
            b.avg_quality
                .partial_cmp(&a.avg_quality)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        patterns.into_iter().take(5).collect()
    }

    /// Aggregated stats for one task type, if any attempt has been recorded.
    pub fn get_task_stats(&self, task_type: &str) -> Option<&TaskPromptStats> {
        self.task_stats.get(task_type)
    }

    /// Heuristic suggestions for improving `prompt` on `task_type`.
    pub fn suggest_improvements(&self, prompt: &str, task_type: &str) -> Vec<PromptSuggestion> {
        let mut suggestions = Vec::new();
        if prompt.len() < 20 {
            suggestions.push(PromptSuggestion {
                suggestion_type: SuggestionType::AddContext,
                description: "Prompt may be too short. Consider adding more context.".to_string(),
                example: None,
            });
        }
        // Case-insensitive directive check. The previous check was
        // case-sensitive, so prompts starting with "Please ..." still
        // triggered a spurious ClarifyIntent suggestion.
        let lower = prompt.to_lowercase();
        if !lower.contains("please") && !lower.contains("should") && !lower.contains("must") {
            suggestions.push(PromptSuggestion {
                suggestion_type: SuggestionType::ClarifyIntent,
                description: "Consider using clearer directive words.".to_string(),
                example: Some("Please implement...".to_string()),
            });
        }
        let best_patterns = self.best_patterns_for(task_type);
        for pattern in best_patterns.iter().take(2) {
            suggestions.push(PromptSuggestion {
                suggestion_type: SuggestionType::UsePattern,
                description: format!(
                    "Pattern '{}' has {:.0}% success rate for this task type",
                    pattern.id,
                    pattern.success_rate * 100.0
                ),
                example: Some(pattern.template.clone()),
            });
        }
        suggestions
    }

    /// Runs a small tournament of prompt variants, registers the winner as a
    /// pattern, and returns the full scored field.
    pub fn evolve_prompt(
        &mut self,
        task_type: &str,
        baseline_prompt: &str,
    ) -> PromptTournamentResult {
        let variants = self.generate_prompt_variants(task_type, baseline_prompt);
        let mut scored: Vec<PromptVariantScore> = variants
            .into_iter()
            .map(|(id, strategy, prompt)| PromptVariantScore {
                variant_id: id,
                strategy,
                predicted_quality: self.estimate_prompt_quality(task_type, &prompt),
                prompt,
            })
            .collect();
        scored.sort_by(|a, b| {
            b.predicted_quality
                .partial_cmp(&a.predicted_quality)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // `scored` always contains the baseline, but fall back defensively.
        let winner = scored.first().cloned().unwrap_or(PromptVariantScore {
            variant_id: "baseline".to_string(),
            strategy: "baseline".to_string(),
            prompt: baseline_prompt.to_string(),
            predicted_quality: self.estimate_prompt_quality(task_type, baseline_prompt),
        });
        // Seed the winner as a pattern so future suggestions can reuse it.
        let pattern_id = format!("evo-{}-{}", task_type, winner.variant_id);
        let mut pattern = PromptPattern::new(&pattern_id, &winner.prompt);
        pattern.effective_for.push(task_type.to_string());
        pattern.usage_count = 1;
        pattern.avg_quality = winner.predicted_quality;
        pattern.success_rate = winner.predicted_quality;
        self.register_pattern(pattern);
        PromptTournamentResult {
            task_type: task_type.to_string(),
            winner_prompt: winner.prompt,
            winner_strategy: winner.strategy,
            winner_score: winner.predicted_quality,
            variants: scored,
        }
    }

    /// Produces (variant_id, strategy, prompt) candidates: the baseline, three
    /// fixed rewrites, and up to two rendered best patterns, deduplicated by
    /// normalized text.
    fn generate_prompt_variants(
        &self,
        task_type: &str,
        baseline_prompt: &str,
    ) -> Vec<(String, String, String)> {
        let mut variants = vec![
            (
                "baseline".to_string(),
                "baseline".to_string(),
                baseline_prompt.to_string(),
            ),
            (
                "verify_first".to_string(),
                "add_verification_clause".to_string(),
                format!(
                    "{}\n\nBefore finalizing, explicitly verify assumptions and list validation checks.",
                    baseline_prompt
                ),
            ),
            (
                "structured_output".to_string(),
                "force_structure".to_string(),
                format!(
                    "{}\n\nRespond in sections: Plan, Changes, Verification, Risks.",
                    baseline_prompt
                ),
            ),
            (
                "concise_steps".to_string(),
                "concise_step_plan".to_string(),
                format!(
                    "{}\n\nKeep each step concise and actionable; avoid redundant narration.",
                    baseline_prompt
                ),
            ),
        ];
        for (idx, pattern) in self.best_patterns_for(task_type).iter().take(2).enumerate() {
            let rendered = Self::render_pattern_template(&pattern.template);
            variants.push((
                format!("pattern_{}", idx + 1),
                format!("pattern_{}", pattern.id),
                rendered,
            ));
        }
        // Keep only the first occurrence of each normalized prompt.
        let mut seen = std::collections::HashSet::new();
        variants
            .into_iter()
            .filter(|(_, _, prompt)| seen.insert(Self::normalize_prompt(prompt)))
            .collect()
    }

    /// Fills the `{action}` / `{target}` placeholders with generic text.
    fn render_pattern_template(template: &str) -> String {
        template
            .replace("{action}", "complete the requested task")
            .replace("{target}", "codebase")
    }

    /// Heuristic quality estimate: historical quality for identical prompts
    /// (falling back to the task-type average, then 0.5), plus small bonuses
    /// for verification language, step/plan structure, and moderate length.
    fn estimate_prompt_quality(&self, task_type: &str, prompt: &str) -> f32 {
        let norm = Self::normalize_prompt(prompt);
        let matching: Vec<&PromptRecord> = self
            .records
            .iter()
            .filter(|r| r.task_type == task_type && Self::normalize_prompt(&r.prompt) == norm)
            .collect();
        let historical = if matching.is_empty() {
            self.task_stats
                .get(task_type)
                .map(|s| s.avg_quality)
                .unwrap_or(0.5)
        } else {
            matching.iter().map(|r| r.quality_score).sum::<f32>() / matching.len() as f32
        };
        let mut score = historical;
        let lower = prompt.to_lowercase();
        if lower.contains("verify") || lower.contains("validation") {
            score += 0.05;
        }
        if lower.contains("step") || lower.contains("plan") {
            score += 0.03;
        }
        if (200..=1800).contains(&prompt.len()) {
            score += 0.02;
        }
        score.clamp(0.0, 1.0)
    }

    /// Collapses whitespace runs and lowercases, for fuzzy prompt identity.
    fn normalize_prompt(prompt: &str) -> String {
        prompt
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ")
            .to_lowercase()
    }

    /// Aggregate counters over the retained record history.
    pub fn get_stats(&self) -> PromptOptimizerStats {
        let total_records = self.records.len();
        let successful = self
            .records
            .iter()
            .filter(|r| r.outcome.is_positive())
            .count();
        let avg_quality = if total_records > 0 {
            self.records.iter().map(|r| r.quality_score).sum::<f32>() / total_records as f32
        } else {
            0.0
        };
        PromptOptimizerStats {
            total_records,
            successful_records: successful,
            pattern_count: self.patterns.len(),
            task_types_tracked: self.task_stats.len(),
            avg_quality,
        }
    }
}
impl PromptOptimizer {
    /// Captures the persistent state for serialization.
    fn to_snapshot(&self) -> PromptOptimizerSnapshot {
        let Self {
            records,
            patterns,
            task_stats,
            ..
        } = self;
        PromptOptimizerSnapshot {
            records: records.clone(),
            patterns: patterns.clone(),
            task_stats: task_stats.clone(),
        }
    }

    /// Rebuilds an optimizer from persisted state. The record cap is not
    /// persisted and resets to the compile-time default.
    fn from_snapshot(snapshot: PromptOptimizerSnapshot) -> Self {
        let PromptOptimizerSnapshot {
            records,
            patterns,
            task_stats,
        } = snapshot;
        Self {
            records,
            patterns,
            task_stats,
            max_records: MAX_ENTRIES,
        }
    }
}
impl Default for PromptOptimizer {
    /// Same as [`PromptOptimizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// One actionable recommendation from [`PromptOptimizer::suggest_improvements`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptSuggestion {
    /// Category of the suggestion.
    pub suggestion_type: SuggestionType,
    /// Human-readable explanation.
    pub description: String,
    /// Optional illustrative snippet or template.
    pub example: Option<String>,
}
/// Category of a [`PromptSuggestion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SuggestionType {
    /// Prompt looks too short; add context.
    AddContext,
    /// Directive wording could be clearer.
    ClarifyIntent,
    /// A known-effective pattern applies.
    UsePattern,
    /// Prompt could be simplified.
    SimplifyPrompt,
    /// Concrete examples would help.
    AddExamples,
}
/// Serializable persistent state of a [`PromptOptimizer`] (the record cap is not persisted).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PromptOptimizerSnapshot {
    records: Vec<PromptRecord>,
    patterns: HashMap<String, PromptPattern>,
    task_stats: HashMap<String, TaskPromptStats>,
}
/// Aggregate counters reported by [`PromptOptimizer::get_stats`].
#[derive(Debug, Clone)]
pub struct PromptOptimizerStats {
    /// Records currently retained.
    pub total_records: usize,
    /// Retained records with a positive outcome.
    pub successful_records: usize,
    /// Registered patterns.
    pub pattern_count: usize,
    /// Distinct task types with stats.
    pub task_types_tracked: usize,
    /// Mean quality over retained records (0.0 when empty).
    pub avg_quality: f32,
}
/// One observed tool invocation and its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUsageRecord {
    /// Tool name.
    pub tool: String,
    /// Free-form description of the task context it was used in.
    pub task_context: String,
    /// How the invocation turned out.
    pub outcome: Outcome,
    /// Execution time in milliseconds (0 when unknown).
    pub execution_time_ms: u64,
    /// Error message, when the invocation failed.
    pub error: Option<String>,
    /// Creation time, seconds since the Unix epoch.
    pub timestamp: u64,
}
impl ToolUsageRecord {
    /// Builds a record stamped with the current time; timing starts at zero
    /// and no error is attached.
    pub fn new(tool: String, task_context: String, outcome: Outcome) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            tool,
            task_context,
            outcome,
            execution_time_ms: 0,
            error: None,
            timestamp,
        }
    }

    /// Records how long the invocation took, in milliseconds.
    pub fn with_execution_time(self, time_ms: u64) -> Self {
        Self {
            execution_time_ms: time_ms,
            ..self
        }
    }

    /// Attaches the error message produced by a failed invocation.
    pub fn with_error(self, error: String) -> Self {
        Self {
            error: Some(error),
            ..self
        }
    }
}
/// Aggregate statistics for a single tool.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolStats {
    /// Total recorded invocations.
    pub usage_count: usize,
    /// Invocations with a positive outcome.
    pub success_count: usize,
    /// Invocations with a negative outcome.
    pub failure_count: usize,
    /// Running mean execution time in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Recent contexts (up to 20) where the tool succeeded.
    pub effective_contexts: Vec<String>,
    /// Normalized error message -> occurrence count.
    pub common_errors: HashMap<String, usize>,
}
impl ToolStats {
    /// Fraction of recorded invocations that succeeded; 0.0 with no history.
    pub fn success_rate(&self) -> f32 {
        match self.usage_count {
            0 => 0.0,
            n => self.success_count as f32 / n as f32,
        }
    }
}
/// Learns which tools work best in which contexts.
pub struct ToolSelectionLearner {
    /// Bounded history of individual tool invocations.
    records: Vec<ToolUsageRecord>,
    /// Per-tool aggregate statistics.
    tool_stats: HashMap<String, ToolStats>,
    /// Normalized context -> (tool, EMA score) rankings.
    context_tools: HashMap<String, Vec<(String, f32)>>,
    /// Cap on `records`; the oldest half is dropped when exceeded.
    max_records: usize,
}
impl ToolSelectionLearner {
    /// Creates a learner with empty history and the default record cap.
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            tool_stats: HashMap::new(),
            context_tools: HashMap::new(),
            max_records: MAX_ENTRIES,
        }
    }

    /// Ingests one tool invocation: updates per-tool stats, the per-context
    /// tool ranking (exponential moving average), and bounds both the
    /// context map and the raw record history.
    pub fn record(&mut self, record: ToolUsageRecord) {
        let stats = self.tool_stats.entry(record.tool.clone()).or_default();
        // Running mean of execution time: recover the previous total first.
        let old_total_time = stats.avg_execution_time_ms * stats.usage_count as f64;
        stats.usage_count += 1;
        stats.avg_execution_time_ms =
            (old_total_time + record.execution_time_ms as f64) / stats.usage_count as f64;
        if record.outcome.is_positive() {
            stats.success_count += 1;
            // Keep a rolling window of up to 20 distinct successful contexts.
            if !stats.effective_contexts.contains(&record.task_context) {
                stats.effective_contexts.push(record.task_context.clone());
                if stats.effective_contexts.len() > 20 {
                    stats.effective_contexts.remove(0);
                }
            }
        } else {
            stats.failure_count += 1;
            // Tally failures by normalized error message.
            if let Some(error) = &record.error {
                let error_key = Self::normalize_error(error);
                *stats.common_errors.entry(error_key).or_insert(0) += 1;
            }
        }
        // Blend this outcome into the context's per-tool score (EMA, 0.2 new).
        let context_key = Self::normalize_context(&record.task_context);
        let tool_scores = self.context_tools.entry(context_key).or_default();
        if let Some((_, score)) = tool_scores.iter_mut().find(|(t, _)| t == &record.tool) {
            *score = 0.8 * *score + 0.2 * record.outcome.score();
        } else {
            tool_scores.push((record.tool.clone(), record.outcome.score()));
        }
        // NOTE(review): this eviction removes half the context keys in
        // HashMap iteration order, i.e. an arbitrary half — unlike the
        // count-based eviction used elsewhere in this module. Confirm
        // whether least-used eviction was intended here.
        if self.context_tools.len() > MAX_ENTRIES / 10 {
            let mut entries: Vec<_> = self.context_tools.keys().cloned().collect();
            entries.truncate(entries.len() / 2);
            for key in entries {
                self.context_tools.remove(&key);
            }
        }
        self.records.push(record);
        // Drop the oldest half of the raw records once over the cap.
        if self.records.len() > self.max_records {
            self.records.drain(0..self.max_records / 2);
        }
    }

    /// Collapses an error message into a bucketing key: first line only,
    /// truncated to 100 chars, with digits replaced by '#' so messages that
    /// differ only in ids/addresses group together.
    fn normalize_error(error: &str) -> String {
        error
            .lines()
            .next()
            .unwrap_or(error)
            .chars()
            .take(100)
            .collect::<String>()
            .replace(char::is_numeric, "#")
    }

    /// Collapses a context description to its first five lowercased words.
    fn normalize_context(context: &str) -> String {
        context
            .to_lowercase()
            .split_whitespace()
            .take(5)
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Up to five (tool, score) pairs for `context`, best first. Falls back
    /// to global per-tool success rates when the context is unknown.
    pub fn best_tools_for(&self, context: &str) -> Vec<(String, f32)> {
        let context_key = Self::normalize_context(context);
        if let Some(tools) = self.context_tools.get(&context_key) {
            let mut tools = tools.clone();
            tools.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            return tools.into_iter().take(5).collect();
        }
        let mut tool_rates: Vec<_> = self
            .tool_stats
            .iter()
            .map(|(tool, stats)| (tool.clone(), stats.success_rate()))
            .collect();
        tool_rates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        tool_rates.into_iter().take(5).collect()
    }

    /// Aggregate stats for one tool, if it has been recorded.
    pub fn get_tool_stats(&self, tool: &str) -> Option<&ToolStats> {
        self.tool_stats.get(tool)
    }

    /// Up to five most frequent normalized errors for `tool`, most common first.
    pub fn common_errors_for(&self, tool: &str) -> Vec<(String, usize)> {
        if let Some(stats) = self.tool_stats.get(tool) {
            let mut errors: Vec<_> = stats
                .common_errors
                .iter()
                .map(|(e, c)| (e.clone(), *c))
                .collect();
            errors.sort_by(|a, b| b.1.cmp(&a.1));
            errors.into_iter().take(5).collect()
        } else {
            Vec::new()
        }
    }

    /// Aggregate counters over the retained record history.
    pub fn get_stats(&self) -> ToolLearnerStats {
        let total_records = self.records.len();
        let successful = self
            .records
            .iter()
            .filter(|r| r.outcome.is_positive())
            .count();
        let unique_tools = self.tool_stats.len();
        ToolLearnerStats {
            total_records,
            successful_records: successful,
            unique_tools,
            contexts_tracked: self.context_tools.len(),
        }
    }
}
/// Serializable persistent state of a [`ToolSelectionLearner`] (the record cap is not persisted).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ToolLearnerSnapshot {
    records: Vec<ToolUsageRecord>,
    tool_stats: HashMap<String, ToolStats>,
    context_tools: HashMap<String, Vec<(String, f32)>>,
}
impl ToolSelectionLearner {
    /// Captures the persistent state for serialization.
    fn to_snapshot(&self) -> ToolLearnerSnapshot {
        let Self {
            records,
            tool_stats,
            context_tools,
            ..
        } = self;
        ToolLearnerSnapshot {
            records: records.clone(),
            tool_stats: tool_stats.clone(),
            context_tools: context_tools.clone(),
        }
    }

    /// Rebuilds a learner from persisted state. The record cap is not
    /// persisted and resets to the compile-time default.
    fn from_snapshot(snapshot: ToolLearnerSnapshot) -> Self {
        let ToolLearnerSnapshot {
            records,
            tool_stats,
            context_tools,
        } = snapshot;
        Self {
            records,
            tool_stats,
            context_tools,
            max_records: MAX_ENTRIES,
        }
    }
}
impl Default for ToolSelectionLearner {
    /// Same as [`ToolSelectionLearner::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Aggregate counters reported by [`ToolSelectionLearner::get_stats`].
#[derive(Debug, Clone)]
pub struct ToolLearnerStats {
    /// Records currently retained.
    pub total_records: usize,
    /// Retained records with a positive outcome.
    pub successful_records: usize,
    /// Distinct tools with stats.
    pub unique_tools: usize,
    /// Distinct normalized contexts tracked.
    pub contexts_tracked: usize,
}
/// One observed error, with optional recovery information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorRecord {
    /// Raw error message.
    pub message: String,
    /// Coarse error classification (free-form).
    pub error_type: String,
    /// Context in which the error occurred.
    pub context: String,
    /// Action that triggered the error.
    pub action: String,
    /// Whether the error was recovered from.
    pub recovered: bool,
    /// How the error was recovered, when it was.
    pub recovery_action: Option<String>,
    /// Creation time, seconds since the Unix epoch.
    pub timestamp: u64,
}
impl ErrorRecord {
    /// Builds an unrecovered record stamped with the current time.
    pub fn new(message: String, error_type: String, context: String, action: String) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            message,
            error_type,
            context,
            action,
            recovered: false,
            recovery_action: None,
            timestamp,
        }
    }

    /// Marks the record as recovered via `action`.
    pub fn with_recovery(self, action: String) -> Self {
        Self {
            recovered: true,
            recovery_action: Some(action),
            ..self
        }
    }
}
/// A recurring error signature with the contexts, triggers, and recoveries seen for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPattern {
    /// Pattern key (error type plus a message prefix).
    pub id: String,
    /// Coarse error classification.
    pub error_type: String,
    /// Recent contexts (up to 10) where the error occurred.
    pub contexts: Vec<String>,
    /// Recent actions (up to 10) that triggered it.
    pub triggering_actions: Vec<String>,
    /// Distinct recovery actions that worked.
    pub recovery_strategies: Vec<String>,
    /// Total occurrences folded into this pattern.
    pub count: usize,
    /// Suggested preventive measures.
    pub prevention: Vec<String>,
}
impl ErrorPattern {
    /// Creates an empty pattern with no recorded occurrences.
    pub fn new(id: &str, error_type: &str) -> Self {
        Self {
            id: id.to_string(),
            error_type: error_type.to_string(),
            contexts: Vec::new(),
            triggering_actions: Vec::new(),
            recovery_strategies: Vec::new(),
            count: 0,
            prevention: Vec::new(),
        }
    }

    /// Appends `value` if not already present, dropping the oldest entry once
    /// the list grows past `cap`.
    fn remember(list: &mut Vec<String>, value: &str, cap: usize) {
        if list.iter().any(|v| v == value) {
            return;
        }
        list.push(value.to_string());
        if list.len() > cap {
            list.remove(0);
        }
    }

    /// Folds one more occurrence into the pattern's rolling lists.
    pub fn update(&mut self, record: &ErrorRecord) {
        self.count += 1;
        Self::remember(&mut self.contexts, &record.context, 10);
        Self::remember(&mut self.triggering_actions, &record.action, 10);
        if let Some(recovery) = &record.recovery_action {
            // Recovery strategies are deduplicated but never evicted.
            Self::remember(&mut self.recovery_strategies, recovery, usize::MAX);
        }
    }

    /// Adds a preventive suggestion, ignoring duplicates.
    pub fn add_prevention(&mut self, suggestion: String) {
        if !self.prevention.contains(&suggestion) {
            self.prevention.push(suggestion);
        }
    }
}
/// Learns recurring error signatures and their recoveries.
pub struct ErrorPatternLearner {
    /// Bounded history of individual errors.
    records: Vec<ErrorRecord>,
    /// Recognized patterns, keyed by computed pattern id.
    patterns: HashMap<String, ErrorPattern>,
    /// Occurrences per error type.
    type_counts: HashMap<String, usize>,
    /// Cap on `records`; the oldest half is dropped when exceeded.
    max_records: usize,
}
impl ErrorPatternLearner {
    /// Creates a learner with empty history and the default record cap.
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            patterns: HashMap::new(),
            type_counts: HashMap::new(),
            max_records: MAX_ENTRIES,
        }
    }

    /// Ingests one error: bumps the per-type count, folds it into the
    /// matching pattern (creating it if new), and bounds both the pattern
    /// map and the raw record history.
    pub fn record(&mut self, record: ErrorRecord) {
        *self
            .type_counts
            .entry(record.error_type.clone())
            .or_insert(0) += 1;
        let pattern_id = Self::compute_pattern_id(&record);
        let pattern = self
            .patterns
            .entry(pattern_id.clone())
            .or_insert_with(|| ErrorPattern::new(&pattern_id, &record.error_type));
        pattern.update(&record);
        // Evict the least-frequent half of the patterns when over budget.
        if self.patterns.len() > MAX_ENTRIES / 10 {
            let mut entries: Vec<_> = self
                .patterns
                .iter()
                .map(|(k, v)| (k.clone(), v.count))
                .collect();
            entries.sort_by_key(|(_, count)| *count);
            for (key, _) in entries.iter().take(entries.len() / 2) {
                self.patterns.remove(key);
            }
        }
        self.records.push(record);
        // Drop the oldest half of the raw records once over the cap.
        if self.records.len() > self.max_records {
            self.records.drain(0..self.max_records / 2);
        }
    }

    /// Bucketing key: error type plus the first 30 alphanumeric/whitespace
    /// characters of the message (trimmed), so similar messages group.
    fn compute_pattern_id(record: &ErrorRecord) -> String {
        let msg_prefix: String = record
            .message
            .chars()
            .filter(|c| c.is_alphanumeric() || c.is_whitespace())
            .take(30)
            .collect();
        format!("{}:{}", record.error_type, msg_prefix.trim())
    }

    /// Looks up a pattern by its computed id.
    pub fn get_pattern(&self, pattern_id: &str) -> Option<&ErrorPattern> {
        self.patterns.get(pattern_id)
    }

    /// Up to `limit` patterns, most frequently seen first.
    pub fn most_common_patterns(&self, limit: usize) -> Vec<&ErrorPattern> {
        let mut patterns: Vec<_> = self.patterns.values().collect();
        patterns.sort_by(|a, b| b.count.cmp(&a.count));
        patterns.into_iter().take(limit).collect()
    }

    /// Warnings for patterns whose known triggering actions or contexts
    /// overlap the proposed `action`/`context` (substring match in either
    /// direction), sorted by likelihood, highest first.
    pub fn might_trigger_error(&self, action: &str, context: &str) -> Vec<ErrorWarning> {
        let mut warnings = Vec::new();
        for pattern in self.patterns.values() {
            let action_match = pattern
                .triggering_actions
                .iter()
                .any(|a| action.contains(a) || a.contains(action));
            let context_match = pattern
                .contexts
                .iter()
                .any(|c| context.contains(c) || c.contains(context));
            if action_match || context_match {
                warnings.push(ErrorWarning {
                    pattern_id: pattern.id.clone(),
                    error_type: pattern.error_type.clone(),
                    // Heuristic: both dimensions matching is a stronger signal.
                    likelihood: if action_match && context_match {
                        0.8
                    } else {
                        0.4
                    },
                    prevention: pattern.prevention.clone(),
                    recovery: pattern.recovery_strategies.clone(),
                });
            }
        }
        warnings.sort_by(|a, b| {
            b.likelihood
                .partial_cmp(&a.likelihood)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        warnings
    }

    /// Aggregate counters over the retained history.
    pub fn get_stats(&self) -> ErrorLearnerStats {
        let total_errors = self.records.len();
        let recovered = self.records.iter().filter(|r| r.recovered).count();
        let pattern_count = self.patterns.len();
        // Sort descending by count so the field actually lists the *top*
        // error types; it was previously emitted in arbitrary HashMap order.
        let mut top_error_types: Vec<(String, usize)> = self
            .type_counts
            .iter()
            .map(|(t, c)| (t.clone(), *c))
            .collect();
        top_error_types.sort_by(|a, b| b.1.cmp(&a.1));
        ErrorLearnerStats {
            total_errors,
            recovered_count: recovered,
            pattern_count,
            top_error_types,
        }
    }
}
/// Serializable persistent state of an [`ErrorPatternLearner`] (the record cap is not persisted).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorLearnerSnapshot {
    records: Vec<ErrorRecord>,
    patterns: HashMap<String, ErrorPattern>,
    type_counts: HashMap<String, usize>,
}
impl ErrorPatternLearner {
    /// Captures the persistent state for serialization.
    fn to_snapshot(&self) -> ErrorLearnerSnapshot {
        let Self {
            records,
            patterns,
            type_counts,
            ..
        } = self;
        ErrorLearnerSnapshot {
            records: records.clone(),
            patterns: patterns.clone(),
            type_counts: type_counts.clone(),
        }
    }

    /// Rebuilds a learner from persisted state. The record cap is not
    /// persisted and resets to the compile-time default.
    fn from_snapshot(snapshot: ErrorLearnerSnapshot) -> Self {
        let ErrorLearnerSnapshot {
            records,
            patterns,
            type_counts,
        } = snapshot;
        Self {
            records,
            patterns,
            type_counts,
            max_records: MAX_ENTRIES,
        }
    }
}
impl Default for ErrorPatternLearner {
    /// Same as [`ErrorPatternLearner::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A predicted-error warning from [`ErrorPatternLearner::might_trigger_error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorWarning {
    /// Id of the matching pattern.
    pub pattern_id: String,
    /// Error classification of the pattern.
    pub error_type: String,
    /// Heuristic likelihood: 0.8 when both action and context match, else 0.4.
    pub likelihood: f32,
    /// Known preventive measures.
    pub prevention: Vec<String>,
    /// Known recovery strategies.
    pub recovery: Vec<String>,
}
/// Aggregate counters reported by [`ErrorPatternLearner::get_stats`].
#[derive(Debug, Clone)]
pub struct ErrorLearnerStats {
    /// Errors currently retained.
    pub total_errors: usize,
    /// Retained errors marked recovered.
    pub recovered_count: usize,
    /// Recognized patterns.
    pub pattern_count: usize,
    /// (error type, occurrence count) pairs.
    pub top_error_types: Vec<(String, usize)>,
}
/// One user session: task, tool, and error activity between start and end.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageSession {
    /// Session identifier.
    pub id: String,
    /// Start time, seconds since the Unix epoch.
    pub start_time: u64,
    /// End time (seconds since epoch); `None` while the session is open.
    pub end_time: Option<u64>,
    /// Tasks attempted during the session.
    pub tasks_attempted: usize,
    /// Tasks completed during the session.
    pub tasks_completed: usize,
    /// Distinct tools used during the session.
    pub tools_used: Vec<String>,
    /// Errors recorded during the session.
    pub errors: usize,
    /// Optional user satisfaction rating.
    pub satisfaction: Option<f32>,
}
impl UsageSession {
    /// Current time as seconds since the Unix epoch (0 if the clock is
    /// before the epoch).
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Opens a new session starting now, with all counters at zero.
    pub fn new(id: &str) -> Self {
        Self {
            id: id.to_string(),
            start_time: Self::unix_now(),
            end_time: None,
            tasks_attempted: 0,
            tasks_completed: 0,
            tools_used: Vec::new(),
            errors: 0,
            satisfaction: None,
        }
    }

    /// Closes the session at the current time.
    pub fn end(&mut self) {
        self.end_time = Some(Self::unix_now());
    }

    /// Session length in seconds; an open session is measured up to now.
    pub fn duration_secs(&self) -> u64 {
        self.end_time
            .unwrap_or_else(Self::unix_now)
            .saturating_sub(self.start_time)
    }

    /// Completed / attempted tasks; 0.0 when nothing was attempted.
    pub fn completion_rate(&self) -> f32 {
        match self.tasks_attempted {
            0 => 0.0,
            n => self.tasks_completed as f32 / n as f32,
        }
    }
}
/// Tracks sessions and aggregates daily usage statistics.
pub struct UsageAnalyzer {
    /// Bounded history of completed sessions.
    sessions: Vec<UsageSession>,
    /// The currently open session, if any.
    current_session: Option<UsageSession>,
    /// Per-day aggregates, keyed by "day_N" (days since epoch).
    daily_stats: HashMap<String, DailyStats>,
    /// Global tool usage counts across all sessions.
    tool_frequency: HashMap<String, usize>,
}
/// Aggregate usage for a single day.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DailyStats {
    /// Sessions recorded that day.
    pub sessions: usize,
    /// Tasks attempted across those sessions.
    pub tasks_attempted: usize,
    /// Tasks completed across those sessions.
    pub tasks_completed: usize,
    /// Errors across those sessions.
    pub errors: usize,
    /// Mean satisfaction; see the review note in `UsageAnalyzer::record_session`
    /// about how unrated sessions affect this value.
    pub avg_satisfaction: f32,
    /// Total session duration that day, in seconds.
    pub total_duration_secs: u64,
}
impl UsageAnalyzer {
    /// Creates an analyzer with no history and no open session.
    pub fn new() -> Self {
        Self {
            sessions: Vec::new(),
            current_session: None,
            daily_stats: HashMap::new(),
            tool_frequency: HashMap::new(),
        }
    }

    /// Opens a new session; any session still open is ended and recorded first.
    pub fn start_session(&mut self, session_id: &str) {
        if let Some(mut old_session) = self.current_session.take() {
            old_session.end();
            self.record_session(old_session);
        }
        self.current_session = Some(UsageSession::new(session_id));
    }

    /// Counts one task attempt (and optionally completion) against the open
    /// session; a no-op when no session is open.
    pub fn record_task_attempt(&mut self, completed: bool) {
        if let Some(ref mut session) = self.current_session {
            session.tasks_attempted += 1;
            if completed {
                session.tasks_completed += 1;
            }
        }
    }

    /// Bumps the global usage count for `tool` and adds it (once) to the
    /// open session's tool list.
    pub fn record_tool_usage(&mut self, tool: &str) {
        *self.tool_frequency.entry(tool.to_string()).or_insert(0) += 1;
        if let Some(ref mut session) = self.current_session {
            if !session.tools_used.contains(&tool.to_string()) {
                session.tools_used.push(tool.to_string());
            }
        }
    }

    /// Counts one error against the open session; a no-op otherwise.
    pub fn record_error(&mut self) {
        if let Some(ref mut session) = self.current_session {
            session.errors += 1;
        }
    }

    /// Ends the open session with an optional satisfaction rating and
    /// records it; a no-op when no session is open.
    pub fn end_session(&mut self, satisfaction: Option<f32>) {
        if let Some(mut session) = self.current_session.take() {
            session.satisfaction = satisfaction;
            session.end();
            self.record_session(session);
        }
    }

    /// Folds a finished session into the daily aggregates and the bounded
    /// session history.
    fn record_session(&mut self, session: UsageSession) {
        let date = Self::timestamp_to_date(session.start_time);
        let daily = self.daily_stats.entry(date).or_default();
        daily.sessions += 1;
        daily.tasks_attempted += session.tasks_attempted;
        daily.tasks_completed += session.tasks_completed;
        daily.errors += session.errors;
        daily.total_duration_secs += session.duration_secs();
        // NOTE(review): this running mean divides by the day's *total*
        // session count, but only sessions carrying a rating contribute to
        // the numerator — unrated sessions drag the average toward zero.
        // A correct mean needs a separate rated-session counter in
        // DailyStats; confirm intent before changing the schema.
        if let Some(sat) = session.satisfaction {
            let old_total = daily.avg_satisfaction * (daily.sessions - 1) as f32;
            daily.avg_satisfaction = (old_total + sat) / daily.sessions as f32;
        }
        self.sessions.push(session);
        // Keep at most ~1000 sessions; drop the oldest half when exceeded.
        if self.sessions.len() > 1000 {
            self.sessions.drain(0..500);
        }
    }

    /// Bucketing key: whole days since the Unix epoch, as "day_N".
    fn timestamp_to_date(timestamp: u64) -> String {
        let days = timestamp / 86400;
        format!("day_{}", days)
    }

    /// Up to `limit` tools by global usage count, most used first.
    pub fn most_used_tools(&self, limit: usize) -> Vec<(String, usize)> {
        let mut tools: Vec<_> = self
            .tool_frequency
            .iter()
            .map(|(t, c)| (t.clone(), *c))
            .collect();
        tools.sort_by(|a, b| b.1.cmp(&a.1));
        tools.into_iter().take(limit).collect()
    }

    /// Aggregate counters over the retained session history. The mean
    /// satisfaction here is computed over rated sessions only.
    pub fn get_stats(&self) -> UsageStats {
        let total_sessions = self.sessions.len();
        let total_tasks: usize = self.sessions.iter().map(|s| s.tasks_attempted).sum();
        let completed_tasks: usize = self.sessions.iter().map(|s| s.tasks_completed).sum();
        let total_errors: usize = self.sessions.iter().map(|s| s.errors).sum();
        let avg_satisfaction = {
            let rated: Vec<_> = self
                .sessions
                .iter()
                .filter_map(|s| s.satisfaction)
                .collect();
            if rated.is_empty() {
                0.0
            } else {
                rated.iter().sum::<f32>() / rated.len() as f32
            }
        };
        UsageStats {
            total_sessions,
            total_tasks,
            completed_tasks,
            total_errors,
            avg_satisfaction,
            unique_tools: self.tool_frequency.len(),
        }
    }
}
/// Serializable persistent state of a [`UsageAnalyzer`], including any open session.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct UsageAnalyzerSnapshot {
    sessions: Vec<UsageSession>,
    current_session: Option<UsageSession>,
    daily_stats: HashMap<String, DailyStats>,
    tool_frequency: HashMap<String, usize>,
}
impl UsageAnalyzer {
    /// Captures the full analyzer state (including any open session) for
    /// serialization.
    fn to_snapshot(&self) -> UsageAnalyzerSnapshot {
        let Self {
            sessions,
            current_session,
            daily_stats,
            tool_frequency,
        } = self;
        UsageAnalyzerSnapshot {
            sessions: sessions.clone(),
            current_session: current_session.clone(),
            daily_stats: daily_stats.clone(),
            tool_frequency: tool_frequency.clone(),
        }
    }

    /// Rebuilds an analyzer from persisted state.
    fn from_snapshot(snapshot: UsageAnalyzerSnapshot) -> Self {
        let UsageAnalyzerSnapshot {
            sessions,
            current_session,
            daily_stats,
            tool_frequency,
        } = snapshot;
        Self {
            sessions,
            current_session,
            daily_stats,
            tool_frequency,
        }
    }
}
impl Default for UsageAnalyzer {
    /// Same as [`UsageAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Aggregate counters reported by [`UsageAnalyzer::get_stats`].
#[derive(Debug, Clone)]
pub struct UsageStats {
    /// Completed sessions currently retained.
    pub total_sessions: usize,
    /// Tasks attempted across retained sessions.
    pub total_tasks: usize,
    /// Tasks completed across retained sessions.
    pub completed_tasks: usize,
    /// Errors across retained sessions.
    pub total_errors: usize,
    /// Mean satisfaction over rated sessions only (0.0 when none rated).
    pub avg_satisfaction: f32,
    /// Distinct tools ever recorded.
    pub unique_tools: usize,
}
/// Facade combining all learners behind `RwLock`s for shared (&self) recording.
pub struct SelfImprovementEngine {
    /// Prompt effectiveness learner.
    prompt_optimizer: RwLock<PromptOptimizer>,
    /// Tool selection learner.
    tool_learner: RwLock<ToolSelectionLearner>,
    /// Error pattern learner.
    error_learner: RwLock<ErrorPatternLearner>,
    /// Session/usage analyzer.
    usage_analyzer: RwLock<UsageAnalyzer>,
    /// Global switch; when false, all record_* methods are no-ops.
    learning_enabled: bool,
}
impl SelfImprovementEngine {
    /// Builds an engine with empty learners and learning switched on.
    pub fn new() -> Self {
        Self {
            prompt_optimizer: RwLock::new(PromptOptimizer::new()),
            tool_learner: RwLock::new(ToolSelectionLearner::new()),
            error_learner: RwLock::new(ErrorPatternLearner::new()),
            usage_analyzer: RwLock::new(UsageAnalyzer::new()),
            learning_enabled: true,
        }
    }

    /// Globally enables or disables all recording.
    pub fn set_learning_enabled(&mut self, enabled: bool) {
        self.learning_enabled = enabled;
    }

    /// Records one prompt execution. A no-op when learning is disabled or
    /// the optimizer lock is poisoned (best-effort, never panics).
    pub fn record_prompt(&self, prompt: &str, task_type: &str, outcome: Outcome, quality: f32) {
        if !self.learning_enabled {
            return;
        }
        let Ok(mut optimizer) = self.prompt_optimizer.write() else {
            return;
        };
        let record = PromptRecord::new(prompt.to_string(), task_type.to_string(), outcome)
            .with_quality(quality);
        optimizer.record(record);
    }

    /// Records one tool invocation in both the tool learner and the usage
    /// analyzer. Each lock is attempted independently so a poisoned learner
    /// lock does not prevent usage tracking.
    pub fn record_tool(
        &self,
        tool: &str,
        context: &str,
        outcome: Outcome,
        time_ms: u64,
        error: Option<String>,
    ) {
        if !self.learning_enabled {
            return;
        }
        if let Ok(mut learner) = self.tool_learner.write() {
            let base = ToolUsageRecord::new(tool.to_string(), context.to_string(), outcome)
                .with_execution_time(time_ms);
            let record = match error {
                Some(err) => base.with_error(err),
                None => base,
            };
            learner.record(record);
        }
        if let Ok(mut analyzer) = self.usage_analyzer.write() {
            analyzer.record_tool_usage(tool);
        }
    }

    /// Records one error in both the error learner and the usage analyzer;
    /// locks are attempted independently, as in `record_tool`.
    pub fn record_error(
        &self,
        message: &str,
        error_type: &str,
        context: &str,
        action: &str,
        recovery: Option<String>,
    ) {
        if !self.learning_enabled {
            return;
        }
        if let Ok(mut learner) = self.error_learner.write() {
            let base = ErrorRecord::new(
                message.to_string(),
                error_type.to_string(),
                context.to_string(),
                action.to_string(),
            );
            let record = match recovery {
                Some(rec) => base.with_recovery(rec),
                None => base,
            };
            learner.record(record);
        }
        if let Ok(mut analyzer) = self.usage_analyzer.write() {
            analyzer.record_error();
        }
    }

    /// Best tools for `context`; empty when the lock is poisoned.
    pub fn best_tools_for(&self, context: &str) -> Vec<(String, f32)> {
        self.tool_learner
            .read()
            .map(|learner| learner.best_tools_for(context))
            .unwrap_or_default()
    }

    /// Warnings for errors the proposed action/context might trigger;
    /// empty when the lock is poisoned.
    pub fn check_for_errors(&self, action: &str, context: &str) -> Vec<ErrorWarning> {
        self.error_learner
            .read()
            .map(|learner| learner.might_trigger_error(action, context))
            .unwrap_or_default()
    }

    /// Suggestions for improving `prompt`; empty when the lock is poisoned.
    pub fn suggest_prompt_improvements(
        &self,
        prompt: &str,
        task_type: &str,
    ) -> Vec<PromptSuggestion> {
        self.prompt_optimizer
            .read()
            .map(|optimizer| optimizer.suggest_improvements(prompt, task_type))
            .unwrap_or_default()
    }

    /// Runs a prompt tournament; `None` when learning is disabled or the
    /// lock is poisoned.
    pub fn evolve_prompt(&self, prompt: &str, task_type: &str) -> Option<PromptTournamentResult> {
        if !self.learning_enabled {
            return None;
        }
        self.prompt_optimizer
            .write()
            .ok()
            .map(|mut optimizer| optimizer.evolve_prompt(task_type, prompt))
    }

    /// Opens a usage session (ending any previous one); best-effort.
    pub fn start_session(&self, session_id: &str) {
        let _ = self
            .usage_analyzer
            .write()
            .map(|mut analyzer| analyzer.start_session(session_id));
    }

    /// Counts a task attempt against the open session; best-effort.
    pub fn record_task(&self, completed: bool) {
        let _ = self
            .usage_analyzer
            .write()
            .map(|mut analyzer| analyzer.record_task_attempt(completed));
    }

    /// Ends the open session with an optional rating; best-effort.
    pub fn end_session(&self, satisfaction: Option<f32>) {
        let _ = self
            .usage_analyzer
            .write()
            .map(|mut analyzer| analyzer.end_session(satisfaction));
    }

    /// Snapshot of every learner's statistics; any learner whose lock is
    /// poisoned reports `None`.
    pub fn get_stats(&self) -> ImprovementStats {
        ImprovementStats {
            prompt_stats: self.prompt_optimizer.read().ok().map(|o| o.get_stats()),
            tool_stats: self.tool_learner.read().ok().map(|l| l.get_stats()),
            error_stats: self.error_learner.read().ok().map(|l| l.get_stats()),
            usage_stats: self.usage_analyzer.read().ok().map(|a| a.get_stats()),
            learning_enabled: self.learning_enabled,
        }
    }
}
/// Serializable persistent state of the whole [`SelfImprovementEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct EngineSnapshot {
    prompt_optimizer: PromptOptimizerSnapshot,
    tool_learner: ToolLearnerSnapshot,
    error_learner: ErrorLearnerSnapshot,
    usage_analyzer: UsageAnalyzerSnapshot,
    learning_enabled: bool,
}
impl SelfImprovementEngine {
    /// Persists the full engine state to `path` as pretty-printed JSON.
    ///
    /// Parent directories are created as needed. The snapshot is first
    /// written to a sibling `<path>.tmp` file and then renamed into place,
    /// so a crash or disk-full mid-write cannot leave a truncated file at
    /// `path` — readers see either the old complete snapshot or the new one.
    ///
    /// # Errors
    /// Returns an error if any internal lock is poisoned, if serialization
    /// fails, or on any I/O failure.
    pub fn save(&self, path: &Path) -> Result<()> {
        let snapshot = EngineSnapshot {
            prompt_optimizer: self
                .prompt_optimizer
                .read()
                .map_err(|e| anyhow::anyhow!("Lock poisoned: {}", e))?
                .to_snapshot(),
            tool_learner: self
                .tool_learner
                .read()
                .map_err(|e| anyhow::anyhow!("Lock poisoned: {}", e))?
                .to_snapshot(),
            error_learner: self
                .error_learner
                .read()
                .map_err(|e| anyhow::anyhow!("Lock poisoned: {}", e))?
                .to_snapshot(),
            usage_analyzer: self
                .usage_analyzer
                .read()
                .map_err(|e| anyhow::anyhow!("Lock poisoned: {}", e))?
                .to_snapshot(),
            learning_enabled: self.learning_enabled,
        };
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let content = serde_json::to_string_pretty(&snapshot)?;
        // Write-then-rename keeps the save atomic on the same filesystem;
        // std::fs::rename replaces an existing destination on all platforms.
        let mut tmp_name = path.as_os_str().to_owned();
        tmp_name.push(".tmp");
        let tmp = std::path::PathBuf::from(tmp_name);
        std::fs::write(&tmp, content)?;
        std::fs::rename(&tmp, path)?;
        Ok(())
    }
    /// Restores an engine from a JSON snapshot previously written by
    /// [`Self::save`].
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or does not contain a
    /// valid `EngineSnapshot`.
    pub fn load(path: &Path) -> Result<Self> {
        let content = std::fs::read_to_string(path)?;
        let snapshot: EngineSnapshot = serde_json::from_str(&content)?;
        Ok(Self {
            prompt_optimizer: RwLock::new(PromptOptimizer::from_snapshot(
                snapshot.prompt_optimizer,
            )),
            tool_learner: RwLock::new(ToolSelectionLearner::from_snapshot(snapshot.tool_learner)),
            error_learner: RwLock::new(ErrorPatternLearner::from_snapshot(snapshot.error_learner)),
            usage_analyzer: RwLock::new(UsageAnalyzer::from_snapshot(snapshot.usage_analyzer)),
            learning_enabled: snapshot.learning_enabled,
        })
    }
}
impl Default for SelfImprovementEngine {
/// Equivalent to [`SelfImprovementEngine::new`]: a fresh engine with no
/// recorded history.
fn default() -> Self {
Self::new()
}
}
/// Aggregated statistics report returned by `SelfImprovementEngine::get_stats`.
///
/// Each sub-statistic is `None` when the owning lock could not be read
/// (i.e. it was poisoned by a panic in another thread).
#[derive(Debug, Clone)]
pub struct ImprovementStats {
pub prompt_stats: Option<PromptOptimizerStats>,
pub tool_stats: Option<ToolLearnerStats>,
pub error_stats: Option<ErrorLearnerStats>,
pub usage_stats: Option<UsageStats>,
// Mirrors the engine's learning switch at the time the report was built.
pub learning_enabled: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process-unique path in the system temp directory.
    ///
    /// The previous fixed names (e.g. `selfware_test_engine.json`) collide
    /// when two invocations of the test binary run concurrently on one
    /// machine; keying on the process id makes each run self-contained.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("selfware_{}_{}", std::process::id(), name))
    }

    #[test]
    fn test_outcome_score() {
        assert_eq!(Outcome::Success.score(), 1.0);
        assert_eq!(Outcome::Partial.score(), 0.5);
        assert_eq!(Outcome::Failure.score(), 0.0);
    }

    #[test]
    fn test_outcome_is_positive() {
        assert!(Outcome::Success.is_positive());
        assert!(Outcome::Partial.is_positive());
        assert!(!Outcome::Failure.is_positive());
        assert!(!Outcome::Abandoned.is_positive());
    }

    #[test]
    fn test_prompt_record_new() {
        let record = PromptRecord::new(
            "test prompt".to_string(),
            "code_write".to_string(),
            Outcome::Success,
        );
        // quality_score defaults to the outcome's score.
        assert_eq!(record.quality_score, 1.0);
        assert!(record.timestamp > 0);
    }

    #[test]
    fn test_prompt_record_with_quality() {
        let record = PromptRecord::new("test".to_string(), "code".to_string(), Outcome::Partial)
            .with_quality(0.8);
        assert_eq!(record.quality_score, 0.8);
    }

    #[test]
    fn test_prompt_pattern_new() {
        let pattern = PromptPattern::new("p1", "Please {action} the {target}");
        assert_eq!(pattern.id, "p1");
        assert_eq!(pattern.usage_count, 0);
    }

    #[test]
    fn test_prompt_pattern_update() {
        let mut pattern = PromptPattern::new("p1", "template");
        pattern.update(Outcome::Success, 0.9);
        pattern.update(Outcome::Failure, 0.2);
        assert_eq!(pattern.usage_count, 2);
        // One success out of two updates.
        assert_eq!(pattern.success_rate, 0.5);
    }

    #[test]
    fn test_prompt_optimizer_new() {
        let optimizer = PromptOptimizer::new();
        assert_eq!(optimizer.get_stats().total_records, 0);
    }

    #[test]
    fn test_prompt_optimizer_record() {
        let mut optimizer = PromptOptimizer::new();
        optimizer.record(PromptRecord::new(
            "test".to_string(),
            "code".to_string(),
            Outcome::Success,
        ));
        assert_eq!(optimizer.get_stats().total_records, 1);
    }

    #[test]
    fn test_prompt_optimizer_suggest_improvements() {
        // A one-character prompt should always yield suggestions.
        let optimizer = PromptOptimizer::new();
        let suggestions = optimizer.suggest_improvements("x", "code");
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn test_prompt_optimizer_evolve_prompt() {
        let mut optimizer = PromptOptimizer::new();
        // Seed enough high-quality history for evolution to draw on.
        for _ in 0..4 {
            optimizer.record(
                PromptRecord::new(
                    "Use a step-by-step plan and verify output".to_string(),
                    "system_prompt".to_string(),
                    Outcome::Success,
                )
                .with_quality(0.9),
            );
        }
        let result = optimizer.evolve_prompt("system_prompt", "You are a coding agent.");
        assert!(!result.variants.is_empty());
        assert!(result.winner_score >= 0.0);
        assert!(!result.winner_prompt.is_empty());
    }

    #[test]
    fn test_tool_usage_record_new() {
        let record = ToolUsageRecord::new(
            "file_read".to_string(),
            "reading config".to_string(),
            Outcome::Success,
        );
        assert_eq!(record.tool, "file_read");
        assert!(record.error.is_none());
    }

    #[test]
    fn test_tool_stats_success_rate() {
        let stats = ToolStats {
            usage_count: 10,
            success_count: 8,
            ..Default::default()
        };
        assert_eq!(stats.success_rate(), 0.8);
    }

    #[test]
    fn test_tool_selection_learner_new() {
        let learner = ToolSelectionLearner::new();
        assert_eq!(learner.get_stats().total_records, 0);
    }

    #[test]
    fn test_tool_selection_learner_record() {
        let mut learner = ToolSelectionLearner::new();
        learner.record(ToolUsageRecord::new(
            "file_read".to_string(),
            "reading file".to_string(),
            Outcome::Success,
        ));
        assert_eq!(learner.get_stats().total_records, 1);
        assert_eq!(learner.get_stats().unique_tools, 1);
    }

    #[test]
    fn test_tool_selection_learner_best_tools() {
        let mut learner = ToolSelectionLearner::new();
        for _ in 0..5 {
            learner.record(ToolUsageRecord::new(
                "file_read".to_string(),
                "reading".to_string(),
                Outcome::Success,
            ));
        }
        for _ in 0..3 {
            learner.record(ToolUsageRecord::new(
                "file_write".to_string(),
                "writing".to_string(),
                Outcome::Failure,
            ));
        }
        let best = learner.best_tools_for("reading");
        assert!(!best.is_empty());
    }

    #[test]
    fn test_error_record_new() {
        let record = ErrorRecord::new(
            "file not found".to_string(),
            "io_error".to_string(),
            "loading config".to_string(),
            "file_read".to_string(),
        );
        assert!(!record.recovered);
    }

    #[test]
    fn test_error_record_with_recovery() {
        let record = ErrorRecord::new(
            "error".to_string(),
            "type".to_string(),
            "ctx".to_string(),
            "action".to_string(),
        )
        .with_recovery("retry".to_string());
        assert!(record.recovered);
        assert_eq!(record.recovery_action, Some("retry".to_string()));
    }

    #[test]
    fn test_error_pattern_new() {
        let pattern = ErrorPattern::new("p1", "io_error");
        assert_eq!(pattern.count, 0);
    }

    #[test]
    fn test_error_pattern_update() {
        let mut pattern = ErrorPattern::new("p1", "io_error");
        let record = ErrorRecord::new(
            "error".to_string(),
            "io_error".to_string(),
            "context".to_string(),
            "action".to_string(),
        );
        pattern.update(&record);
        assert_eq!(pattern.count, 1);
        assert!(pattern.contexts.contains(&"context".to_string()));
    }

    #[test]
    fn test_error_pattern_learner_new() {
        let learner = ErrorPatternLearner::new();
        assert_eq!(learner.get_stats().total_errors, 0);
    }

    #[test]
    fn test_error_pattern_learner_record() {
        let mut learner = ErrorPatternLearner::new();
        learner.record(ErrorRecord::new(
            "error".to_string(),
            "type".to_string(),
            "ctx".to_string(),
            "action".to_string(),
        ));
        assert_eq!(learner.get_stats().total_errors, 1);
    }

    #[test]
    fn test_error_pattern_learner_might_trigger() {
        let mut learner = ErrorPatternLearner::new();
        learner.record(ErrorRecord::new(
            "file not found".to_string(),
            "io_error".to_string(),
            "loading config".to_string(),
            "file_read".to_string(),
        ));
        let warnings = learner.might_trigger_error("file_read", "loading");
        assert!(!warnings.is_empty());
    }

    #[test]
    fn test_usage_session_new() {
        let session = UsageSession::new("s1");
        assert_eq!(session.id, "s1");
        assert!(session.end_time.is_none());
    }

    #[test]
    fn test_usage_session_end() {
        let mut session = UsageSession::new("s1");
        session.end();
        assert!(session.end_time.is_some());
    }

    #[test]
    fn test_usage_session_completion_rate() {
        let mut session = UsageSession::new("s1");
        session.tasks_attempted = 10;
        session.tasks_completed = 8;
        assert_eq!(session.completion_rate(), 0.8);
    }

    #[test]
    fn test_usage_analyzer_new() {
        let analyzer = UsageAnalyzer::new();
        assert_eq!(analyzer.get_stats().total_sessions, 0);
    }

    #[test]
    fn test_usage_analyzer_session() {
        let mut analyzer = UsageAnalyzer::new();
        analyzer.start_session("s1");
        analyzer.record_task_attempt(true);
        analyzer.record_tool_usage("file_read");
        analyzer.end_session(Some(0.9));
        let stats = analyzer.get_stats();
        assert_eq!(stats.total_sessions, 1);
        assert_eq!(stats.completed_tasks, 1);
    }

    #[test]
    fn test_usage_analyzer_most_used_tools() {
        let mut analyzer = UsageAnalyzer::new();
        for _ in 0..5 {
            analyzer.record_tool_usage("file_read");
        }
        for _ in 0..3 {
            analyzer.record_tool_usage("file_write");
        }
        let tools = analyzer.most_used_tools(2);
        assert_eq!(tools.len(), 2);
        // Most-used tool must sort first.
        assert_eq!(tools[0].0, "file_read");
    }

    #[test]
    fn test_self_improvement_engine_new() {
        let engine = SelfImprovementEngine::new();
        assert!(engine.learning_enabled);
    }

    #[test]
    fn test_self_improvement_engine_record_prompt() {
        let engine = SelfImprovementEngine::new();
        engine.record_prompt("test prompt", "code", Outcome::Success, 0.9);
        let stats = engine.get_stats();
        assert!(stats.prompt_stats.is_some());
    }

    #[test]
    fn test_self_improvement_engine_record_tool() {
        let engine = SelfImprovementEngine::new();
        engine.record_tool("file_read", "reading config", Outcome::Success, 100, None);
        let stats = engine.get_stats();
        assert!(stats.tool_stats.is_some());
    }

    #[test]
    fn test_self_improvement_engine_evolve_prompt() {
        let engine = SelfImprovementEngine::new();
        let result = engine.evolve_prompt("You are Selfware.", "system_prompt");
        assert!(result.is_some());
        let result = result.unwrap();
        assert!(!result.winner_prompt.is_empty());
        assert!(!result.variants.is_empty());
    }

    #[test]
    fn test_self_improvement_engine_record_error() {
        let engine = SelfImprovementEngine::new();
        engine.record_error(
            "error msg",
            "io_error",
            "context",
            "action",
            Some("retry".to_string()),
        );
        let stats = engine.get_stats();
        assert!(stats.error_stats.is_some());
    }

    #[test]
    fn test_self_improvement_engine_best_tools() {
        let engine = SelfImprovementEngine::new();
        for _ in 0..5 {
            engine.record_tool("file_read", "reading", Outcome::Success, 100, None);
        }
        let best = engine.best_tools_for("reading");
        assert!(!best.is_empty());
    }

    #[test]
    fn test_self_improvement_engine_check_errors() {
        let engine = SelfImprovementEngine::new();
        engine.record_error("file not found", "io_error", "loading", "file_read", None);
        let warnings = engine.check_for_errors("file_read", "loading");
        assert!(!warnings.is_empty());
    }

    #[test]
    fn test_self_improvement_engine_session() {
        let engine = SelfImprovementEngine::new();
        engine.start_session("s1");
        engine.record_task(true);
        engine.end_session(Some(0.9));
        let stats = engine.get_stats();
        assert!(stats.usage_stats.is_some());
    }

    #[test]
    fn test_self_improvement_engine_disable_learning() {
        let mut engine = SelfImprovementEngine::new();
        engine.set_learning_enabled(false);
        engine.record_prompt("test", "code", Outcome::Success, 1.0);
        let stats = engine.get_stats();
        // With learning disabled the record call must be a no-op.
        assert_eq!(stats.prompt_stats.unwrap().total_records, 0);
    }

    #[test]
    fn test_self_improvement_engine_save_load_roundtrip() {
        let engine = SelfImprovementEngine::new();
        engine.record_prompt("test prompt", "code", Outcome::Success, 0.9);
        engine.record_tool("file_read", "reading config", Outcome::Success, 100, None);
        engine.record_error("file not found", "io_error", "loading", "file_read", None);
        engine.start_session("s1");
        engine.record_task(true);
        let tmp = temp_path("test_engine.json");
        engine.save(&tmp).unwrap();
        let loaded = SelfImprovementEngine::load(&tmp).unwrap();
        let stats = loaded.get_stats();
        assert_eq!(stats.prompt_stats.unwrap().total_records, 1);
        assert_eq!(stats.tool_stats.unwrap().total_records, 1);
        assert_eq!(stats.error_stats.unwrap().total_errors, 1);
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn test_save_load_preserves_tool_stats() {
        let engine = SelfImprovementEngine::new();
        for _ in 0..5 {
            engine.record_tool("file_read", "reading", Outcome::Success, 50, None);
        }
        for _ in 0..3 {
            engine.record_tool(
                "file_write",
                "writing",
                Outcome::Failure,
                100,
                Some("permission denied".to_string()),
            );
        }
        let tmp = temp_path("test_engine_tools.json");
        engine.save(&tmp).unwrap();
        let loaded = SelfImprovementEngine::load(&tmp).unwrap();
        let best = loaded.best_tools_for("reading");
        assert!(!best.is_empty());
        let file_read_score = best.iter().find(|(t, _)| t == "file_read").map(|(_, s)| *s);
        assert!(file_read_score.is_some());
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn test_save_load_preserves_error_patterns() {
        let engine = SelfImprovementEngine::new();
        engine.record_error("timeout waiting", "timeout", "api_call", "shell_exec", None);
        engine.record_error(
            "timeout waiting",
            "timeout",
            "api_call",
            "shell_exec",
            Some("retry".to_string()),
        );
        let tmp = temp_path("test_engine_errors.json");
        engine.save(&tmp).unwrap();
        let loaded = SelfImprovementEngine::load(&tmp).unwrap();
        let warnings = loaded.check_for_errors("shell_exec", "api_call");
        assert!(!warnings.is_empty());
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn test_save_load_preserves_usage_sessions() {
        let engine = SelfImprovementEngine::new();
        engine.start_session("s1");
        engine.record_task(true);
        engine.record_task(false);
        engine.end_session(Some(0.7));
        let tmp = temp_path("test_engine_sessions.json");
        engine.save(&tmp).unwrap();
        let loaded = SelfImprovementEngine::load(&tmp).unwrap();
        let stats = loaded.get_stats();
        let usage = stats.usage_stats.unwrap();
        assert_eq!(usage.total_sessions, 1);
        assert_eq!(usage.total_tasks, 2);
        assert_eq!(usage.completed_tasks, 1);
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn test_load_nonexistent_file_errors() {
        // Portable replacement for the old hardcoded "/tmp/..." path; the
        // explicit remove guarantees the file is absent.
        let missing = temp_path("nonexistent_engine_12345.json");
        std::fs::remove_file(&missing).ok();
        let result = SelfImprovementEngine::load(&missing);
        assert!(result.is_err());
    }

    #[test]
    fn test_save_creates_parent_dirs() {
        let root = temp_path("test_nested");
        let tmp = root.join("deep/dir/engine.json");
        std::fs::remove_dir_all(&root).ok();
        let engine = SelfImprovementEngine::new();
        engine.save(&tmp).unwrap();
        assert!(tmp.exists());
        std::fs::remove_dir_all(&root).ok();
    }

    #[test]
    fn test_outcome_serialization_roundtrip() {
        for outcome in [
            Outcome::Success,
            Outcome::Partial,
            Outcome::Failure,
            Outcome::Abandoned,
        ] {
            let json = serde_json::to_string(&outcome).unwrap();
            let deserialized: Outcome = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, outcome);
        }
    }

    #[test]
    fn test_prompt_record_serialization_roundtrip() {
        let record = PromptRecord::new(
            "test prompt".to_string(),
            "code".to_string(),
            Outcome::Success,
        )
        .with_quality(0.85)
        .with_tokens(1500)
        .with_response_time(2000);
        let json = serde_json::to_string(&record).unwrap();
        let deserialized: PromptRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.prompt, "test prompt");
        assert_eq!(deserialized.quality_score, 0.85);
        assert_eq!(deserialized.tokens_used, 1500);
        assert_eq!(deserialized.response_time_ms, 2000);
    }

    #[test]
    fn test_tool_usage_record_serialization_roundtrip() {
        let record = ToolUsageRecord::new(
            "cargo_check".to_string(),
            "building".to_string(),
            Outcome::Failure,
        )
        .with_execution_time(5000)
        .with_error("compilation error".to_string());
        let json = serde_json::to_string(&record).unwrap();
        let deserialized: ToolUsageRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool, "cargo_check");
        assert_eq!(deserialized.outcome, Outcome::Failure);
        assert_eq!(deserialized.execution_time_ms, 5000);
        assert_eq!(deserialized.error, Some("compilation error".to_string()));
    }

    #[test]
    fn test_error_record_serialization_roundtrip() {
        let record = ErrorRecord::new(
            "file not found".to_string(),
            "io_error".to_string(),
            "loading config".to_string(),
            "file_read".to_string(),
        )
        .with_recovery("use default".to_string());
        let json = serde_json::to_string(&record).unwrap();
        let deserialized: ErrorRecord = serde_json::from_str(&json).unwrap();
        assert!(deserialized.recovered);
        assert_eq!(
            deserialized.recovery_action,
            Some("use default".to_string())
        );
    }

    #[test]
    fn test_usage_session_zero_tasks_completion_rate() {
        // Zero attempts must not divide by zero.
        let session = UsageSession::new("s1");
        assert_eq!(session.completion_rate(), 0.0);
    }

    #[test]
    fn test_usage_session_serialization_roundtrip() {
        let mut session = UsageSession::new("s1");
        session.tasks_attempted = 5;
        session.tasks_completed = 3;
        session.tools_used = vec!["file_read".to_string(), "shell_exec".to_string()];
        session.end();
        let json = serde_json::to_string(&session).unwrap();
        let deserialized: UsageSession = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, "s1");
        assert!(deserialized.end_time.is_some());
        assert_eq!(deserialized.tools_used.len(), 2);
    }

    #[test]
    fn test_tool_stats_serialization_roundtrip() {
        let stats = ToolStats {
            usage_count: 10,
            success_count: 8,
            failure_count: 2,
            avg_execution_time_ms: 150.0,
            effective_contexts: vec!["reading files".to_string()],
            common_errors: HashMap::from([("permission denied".to_string(), 2)]),
        };
        let json = serde_json::to_string(&stats).unwrap();
        let deserialized: ToolStats = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.usage_count, 10);
        assert_eq!(deserialized.success_rate(), 0.8);
        assert_eq!(deserialized.common_errors.len(), 1);
    }

    #[test]
    fn test_error_warning_serialization_roundtrip() {
        let warning = ErrorWarning {
            pattern_id: "p1".to_string(),
            error_type: "timeout".to_string(),
            likelihood: 0.8,
            prevention: vec!["set longer timeout".to_string()],
            recovery: vec!["retry".to_string()],
        };
        let json = serde_json::to_string(&warning).unwrap();
        let deserialized: ErrorWarning = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.pattern_id, "p1");
        assert_eq!(deserialized.likelihood, 0.8);
    }

    #[test]
    fn test_error_pattern_serialization_roundtrip() {
        let mut pattern = ErrorPattern::new("p1", "io_error");
        let record = ErrorRecord::new(
            "not found".to_string(),
            "io_error".to_string(),
            "context".to_string(),
            "action".to_string(),
        );
        pattern.update(&record);
        pattern.add_prevention("check existence first".to_string());
        let json = serde_json::to_string(&pattern).unwrap();
        let deserialized: ErrorPattern = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.count, 1);
        assert_eq!(deserialized.prevention.len(), 1);
    }

    #[test]
    fn test_suggest_improvements_short_prompt() {
        let engine = SelfImprovementEngine::new();
        let suggestions = engine.suggest_prompt_improvements("x", "code");
        assert!(!suggestions.is_empty());
        assert!(suggestions
            .iter()
            .any(|s| s.suggestion_type == SuggestionType::AddContext));
    }

    #[test]
    fn test_tool_selection_learner_common_errors() {
        let mut learner = ToolSelectionLearner::new();
        learner.record(
            ToolUsageRecord::new(
                "shell_exec".to_string(),
                "running".to_string(),
                Outcome::Failure,
            )
            .with_error("permission denied".to_string()),
        );
        learner.record(
            ToolUsageRecord::new(
                "shell_exec".to_string(),
                "running".to_string(),
                Outcome::Failure,
            )
            .with_error("permission denied".to_string()),
        );
        let errors = learner.common_errors_for("shell_exec");
        assert!(!errors.is_empty());
        assert!(errors[0].1 >= 2);
    }

    #[test]
    fn test_tool_selection_learner_no_stats() {
        let learner = ToolSelectionLearner::new();
        assert!(learner.get_tool_stats("nonexistent").is_none());
        assert!(learner.common_errors_for("nonexistent").is_empty());
    }

    #[test]
    fn test_usage_analyzer_multiple_sessions() {
        let mut analyzer = UsageAnalyzer::new();
        analyzer.start_session("s1");
        analyzer.record_task_attempt(true);
        analyzer.record_tool_usage("file_read");
        analyzer.end_session(Some(0.8));
        analyzer.start_session("s2");
        analyzer.record_task_attempt(false);
        analyzer.record_error();
        analyzer.end_session(Some(0.5));
        let stats = analyzer.get_stats();
        assert_eq!(stats.total_sessions, 2);
        assert_eq!(stats.total_tasks, 2);
        assert_eq!(stats.completed_tasks, 1);
        assert_eq!(stats.total_errors, 1);
    }

    #[test]
    fn test_prompt_optimizer_best_patterns() {
        let mut optimizer = PromptOptimizer::new();
        let mut pattern = PromptPattern::new("p1", "Step by step: {action}");
        pattern.effective_for = vec!["code".to_string()];
        for _ in 0..6 {
            pattern.update(Outcome::Success, 0.9);
        }
        optimizer.register_pattern(pattern);
        let best = optimizer.best_patterns_for("code");
        assert_eq!(best.len(), 1);
        assert_eq!(best[0].id, "p1");
    }
}