use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::context::{ActionRecord, JobContext};
use crate::error::EvaluationError;
use crate::llm::LlmProvider;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResult {
pub success: bool,
pub confidence: f64,
pub reasoning: String,
pub issues: Vec<String>,
pub suggestions: Vec<String>,
pub quality_score: u32,
}
impl EvaluationResult {
pub fn success(reasoning: impl Into<String>, quality_score: u32) -> Self {
Self {
success: true,
confidence: 0.9,
reasoning: reasoning.into(),
issues: vec![],
suggestions: vec![],
quality_score,
}
}
pub fn failure(reasoning: impl Into<String>, issues: Vec<String>) -> Self {
Self {
success: false,
confidence: 0.9,
reasoning: reasoning.into(),
issues,
suggestions: vec![],
quality_score: 0,
}
}
}
#[async_trait]
pub trait SuccessEvaluator: Send + Sync {
async fn evaluate(
&self,
job: &JobContext,
actions: &[ActionRecord],
output: Option<&str>,
) -> Result<EvaluationResult, EvaluationError>;
}
pub struct RuleBasedEvaluator {
min_action_success_rate: f64,
max_failures: u32,
}
impl RuleBasedEvaluator {
pub fn new() -> Self {
Self {
min_action_success_rate: 0.8,
max_failures: 3,
}
}
#[allow(dead_code)] pub fn with_min_success_rate(mut self, rate: f64) -> Self {
self.min_action_success_rate = rate;
self
}
#[allow(dead_code)] pub fn with_max_failures(mut self, max: u32) -> Self {
self.max_failures = max;
self
}
}
impl Default for RuleBasedEvaluator {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SuccessEvaluator for RuleBasedEvaluator {
async fn evaluate(
&self,
job: &JobContext,
actions: &[ActionRecord],
_output: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let mut issues = Vec::new();
if actions.is_empty() {
return Ok(EvaluationResult::failure(
"No actions were taken",
vec!["No actions recorded".to_string()],
));
}
let successful = actions.iter().filter(|a| a.success).count();
let total = actions.len();
let success_rate = successful as f64 / total as f64;
if success_rate < self.min_action_success_rate {
issues.push(format!(
"Action success rate {:.1}% below threshold {:.1}%",
success_rate * 100.0,
self.min_action_success_rate * 100.0
));
}
let failures = actions.iter().filter(|a| !a.success).count() as u32;
if failures > self.max_failures {
issues.push(format!(
"Too many failures: {} (max {})",
failures, self.max_failures
));
}
for action in actions.iter().filter(|a| !a.success) {
if let Some(ref error) = action.error
&& (error.to_lowercase().contains("critical")
|| error.to_lowercase().contains("fatal"))
{
issues.push(format!("Critical error in {}: {}", action.tool_name, error));
}
}
if job.state != crate::context::JobState::Completed
&& job.state != crate::context::JobState::Submitted
{
issues.push(format!("Job not in completed state: {:?}", job.state));
}
let quality_score = if issues.is_empty() {
let base_score = (success_rate * 80.0) as u32;
let completion_bonus = if job.state == crate::context::JobState::Completed {
20
} else {
0
};
(base_score + completion_bonus).min(100)
} else {
((success_rate * 50.0) as u32).min(50)
};
if issues.is_empty() {
Ok(EvaluationResult::success(
format!(
"Job completed successfully with {}/{} actions succeeding ({:.1}%)",
successful,
total,
success_rate * 100.0
),
quality_score,
))
} else {
Ok(EvaluationResult {
success: false,
confidence: 0.85,
reasoning: format!("Job had {} issues", issues.len()),
issues,
suggestions: vec![
"Review failed actions for common patterns".to_string(),
"Consider adjusting retry logic".to_string(),
],
quality_score,
})
}
}
}
pub struct LlmEvaluator {
llm: Arc<dyn LlmProvider>,
}
impl LlmEvaluator {
#[allow(dead_code)] pub fn new(llm: Arc<dyn LlmProvider>) -> Self {
Self { llm }
}
}
#[async_trait]
impl SuccessEvaluator for LlmEvaluator {
async fn evaluate(
&self,
job: &JobContext,
actions: &[ActionRecord],
output: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let actions_summary: Vec<String> = actions
.iter()
.map(|a| {
format!(
"- {}: {} ({})",
a.tool_name,
if a.success { "success" } else { "failed" },
a.error.as_deref().unwrap_or("ok")
)
})
.collect();
let prompt = format!(
r#"Evaluate if this job was completed successfully.
Job: {}
Description: {}
State: {:?}
Actions taken:
{}
{}
Respond in JSON format:
{{
"success": true/false,
"confidence": 0.0-1.0,
"reasoning": "...",
"issues": ["..."],
"suggestions": ["..."],
"quality_score": 0-100
}}"#,
job.title,
job.description,
job.state,
actions_summary.join("\n"),
output
.map(|o| format!("Output:\n{}", o))
.unwrap_or_default()
);
let request =
crate::llm::CompletionRequest::new(vec![crate::llm::ChatMessage::user(prompt)])
.with_max_tokens(1024)
.with_temperature(0.1);
let response = self
.llm
.complete(request)
.await
.map_err(|e| EvaluationError::Failed {
job_id: job.job_id,
reason: e.to_string(),
})?;
let result: EvaluationResult =
serde_json::from_str(&response.content).map_err(|e| EvaluationError::Failed {
job_id: job.job_id,
reason: format!("Failed to parse LLM evaluation: {}", e),
})?;
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::context::JobContext;
#[tokio::test]
async fn test_rule_based_evaluator_success() {
let evaluator = RuleBasedEvaluator::new();
let mut job = JobContext::new("Test", "Test job");
job.transition_to(crate::context::JobState::InProgress, None)
.unwrap();
job.transition_to(crate::context::JobState::Completed, None)
.unwrap();
let actions = vec![
create_action(true),
create_action(true),
create_action(true),
];
let result = evaluator.evaluate(&job, &actions, None).await.unwrap();
assert!(result.success);
assert!(result.quality_score > 80);
}
#[tokio::test]
async fn test_rule_based_evaluator_failure() {
let evaluator = RuleBasedEvaluator::new().with_max_failures(1);
let job = JobContext::new("Test", "Test job");
let actions = vec![
create_action(true),
create_action(false),
create_action(false),
];
let result = evaluator.evaluate(&job, &actions, None).await.unwrap();
assert!(!result.success);
assert!(!result.issues.is_empty());
}
fn create_action(success: bool) -> ActionRecord {
let mut action = ActionRecord::new(0, "test", serde_json::json!({}));
if success {
action = action.succeed(
None,
serde_json::json!({}),
std::time::Duration::from_secs(1),
);
} else {
action = action.fail("Test error", std::time::Duration::from_secs(1));
}
action
}
}