use super::context_detector::ContextDetector;
use super::detected_context::{ConfidenceScores, DetectedContext};
use super::error::AgentError;
use super::payload::{Payload, PayloadContent};
use crate::context::TaskHealth;
use async_trait::async_trait;
/// Heuristic context detector driven by fixed rules and two tunable thresholds.
///
/// The thresholds are public so callers can inspect or adjust them directly.
#[derive(Debug, Clone)]
pub struct RuleBasedDetector {
    /// A task is flagged at risk once its redesign count is strictly greater
    /// than this value.
    pub at_risk_threshold: usize,
    /// A task is flagged at risk once its journal success rate drops below
    /// `1.0 - failure_rate_threshold`.
    pub failure_rate_threshold: f64,
}

impl Default for RuleBasedDetector {
    /// Returns a detector with the same thresholds as `RuleBasedDetector::new()`.
    ///
    /// A derived `Default` would zero both thresholds, which flags every task
    /// with a single redesign or a success rate below 100% as at risk.
    fn default() -> Self {
        Self {
            at_risk_threshold: 2,
            failure_rate_threshold: 0.4,
        }
    }
}
impl RuleBasedDetector {
    /// Creates a detector with the stock thresholds: more than 2 redesigns, or
    /// a failure-rate threshold of 0.4, marks a task as at risk.
    pub fn new() -> Self {
        Self {
            at_risk_threshold: 2,
            failure_rate_threshold: 0.4,
        }
    }

    /// Creates a detector with caller-supplied thresholds.
    ///
    /// * `at_risk_threshold` - redesign count that must be exceeded before the
    ///   task is flagged at risk.
    /// * `failure_rate_threshold` - the task is flagged at risk when the journal
    ///   success rate is below `1.0 - failure_rate_threshold`.
    pub fn with_thresholds(at_risk_threshold: usize, failure_rate_threshold: f64) -> Self {
        Self {
            at_risk_threshold,
            failure_rate_threshold,
        }
    }

    /// Derives a task-health verdict from the payload's latest env context.
    ///
    /// Returns `(None, 0.0)` when the payload carries no env context. Otherwise
    /// returns `AtRisk` if any rule fires and `OnTrack` if none does, together
    /// with a score that accumulates per fired rule (0.4 for redesigns, 0.3 for
    /// low success rate, 0.3 for consecutive failures), capped at 1.0.
    fn detect_task_health(&self, payload: &Payload) -> (Option<TaskHealth>, f64) {
        let env = match payload.latest_env_context() {
            Some(env) => env,
            None => return (None, 0.0),
        };

        let mut confidence: f64 = 0.0;
        let mut is_at_risk = false;

        if env.redesign_count > self.at_risk_threshold {
            is_at_risk = true;
            confidence += 0.4;
        }

        if let Some(journal) = &env.journal_summary {
            if journal.success_rate < (1.0 - self.failure_rate_threshold) {
                is_at_risk = true;
                confidence += 0.3;
            }
            // The consecutive-failure limit is fixed here, not configurable.
            if journal.consecutive_failures > 2 {
                is_at_risk = true;
                confidence += 0.3;
            }
        }

        // NOTE(review): when no rule fires the verdict is `OnTrack` with a
        // score of 0.0, so the score reflects risk evidence rather than
        // certainty in the verdict - confirm this is what consumers expect.
        let health = if is_at_risk {
            TaskHealth::AtRisk
        } else {
            TaskHealth::OnTrack
        };
        (Some(health), confidence.min(1.0))
    }

    /// Classifies the task from keywords found in the payload's text.
    ///
    /// All `Text` and `Message` contents are joined with spaces and lowercased.
    /// Categories are tried in table order and the first one with at least one
    /// keyword hit wins; its score is `hits / keyword_count * base_confidence`.
    /// Returns `(None, 0.0)` when nothing matches.
    fn detect_task_type(&self, payload: &Payload) -> (Option<String>, f64) {
        // Constant table: (task type, keywords, base confidence). Kept as
        // static slices so no heap allocation happens per call.
        const TASK_PATTERNS: &[(&str, &[&str], f64)] = &[
            (
                "security-review",
                &["security", "vulnerability", "exploit", "auth"],
                0.8,
            ),
            (
                "code-review",
                &["review", "pr", "pull request", "refactor"],
                0.7,
            ),
            ("debug", &["debug", "error", "bug", "fix", "crash"], 0.8),
            (
                "implementation",
                &["implement", "feature", "add", "create"],
                0.6,
            ),
            ("test", &["test", "spec", "coverage"], 0.7),
        ];

        let text = payload
            .contents()
            .iter()
            .filter_map(|c| match c {
                PayloadContent::Text(t) => Some(t.as_str()),
                PayloadContent::Message { content, .. } => Some(content.as_str()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join(" ")
            .to_lowercase();

        // NOTE(review): matching is by substring, so short keywords such as
        // "pr" or "add" also hit words like "project" or "address" - consider
        // word-boundary matching if false positives become a problem.
        for &(task_type, keywords, base_confidence) in TASK_PATTERNS {
            let matches = keywords.iter().filter(|kw| text.contains(**kw)).count();
            if matches > 0 {
                let confidence = (matches as f64 / keywords.len() as f64) * base_confidence;
                return (Some(task_type.to_string()), confidence);
            }
        }
        (None, 0.0)
    }
}
#[async_trait]
impl ContextDetector for RuleBasedDetector {
    /// Runs the health and task-type rules over `payload` and assembles the
    /// findings into a `DetectedContext`.
    ///
    /// A field and its confidence score are only set when the corresponding
    /// rule produced a value. This implementation always returns `Ok`.
    async fn detect(&self, payload: &Payload) -> Result<DetectedContext, AgentError> {
        let (health, health_confidence) = self.detect_task_health(payload);
        let (task_type, type_confidence) = self.detect_task_type(payload);

        let mut context = DetectedContext::new();
        let mut scores = ConfidenceScores::new();

        // Each finding and its score are recorded together, so a score never
        // appears without the value it belongs to.
        if let Some(h) = health {
            context = context.with_task_health(h);
            scores = scores.with_task_health(health_confidence);
        }
        if let Some(kind) = task_type {
            context = context.with_task_type(kind);
            scores = scores.with_task_type(type_confidence);
        }

        Ok(context
            .with_confidence(scores)
            .detected_by("RuleBasedDetector"))
    }

    /// Identifier reported for this detector.
    fn name(&self) -> &str {
        "RuleBasedDetector"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{EnvContext, JournalSummary};

    /// Runs `detector` over `payload`, panicking if detection fails.
    async fn run(detector: &RuleBasedDetector, payload: &Payload) -> DetectedContext {
        detector.detect(payload).await.unwrap()
    }

    /// A redesign count above the default threshold flags the task at risk.
    #[tokio::test]
    async fn test_detect_task_health_at_risk() {
        let payload =
            Payload::text("Test").with_env_context(EnvContext::new().with_redesign_count(3));

        let result = run(&RuleBasedDetector::new(), &payload).await;

        assert_eq!(result.task_health, Some(TaskHealth::AtRisk));
        let scores = result.confidence.unwrap();
        assert!(scores.task_health.unwrap() > 0.0);
    }

    /// Journal data alone is enough to flag the task at risk.
    #[tokio::test]
    async fn test_detect_task_health_from_journal() {
        let summary = JournalSummary::new(10, 5).with_consecutive_failures(3);
        let payload =
            Payload::text("Test").with_env_context(EnvContext::new().with_journal_summary(summary));

        let result = run(&RuleBasedDetector::new(), &payload).await;

        assert_eq!(result.task_health, Some(TaskHealth::AtRisk));
    }

    /// Security keywords classify the task as a security review.
    #[tokio::test]
    async fn test_detect_task_type_security() {
        let payload = Payload::text("Review this security-critical authentication code");

        let result = run(&RuleBasedDetector::new(), &payload).await;

        assert_eq!(result.task_type, Some("security-review".to_string()));
        let scores = result.confidence.unwrap();
        assert!(scores.task_type.unwrap() > 0.0);
    }

    /// Debugging keywords classify the task as debug work.
    #[tokio::test]
    async fn test_detect_task_type_debug() {
        let payload = Payload::text("Debug this error and fix the bug");

        let result = run(&RuleBasedDetector::new(), &payload).await;

        assert_eq!(result.task_type, Some("debug".to_string()));
    }

    /// Health and task type are both reported from a single payload.
    #[tokio::test]
    async fn test_detect_combined() {
        let payload = Payload::text("Debug this security vulnerability")
            .with_env_context(EnvContext::new().with_redesign_count(3));

        let result = run(&RuleBasedDetector::new(), &payload).await;

        assert_eq!(result.task_health, Some(TaskHealth::AtRisk));
        assert!(result.task_type.is_some());
        assert_eq!(result.detected_by, vec!["RuleBasedDetector"]);
    }

    /// Raising the redesign threshold keeps the same payload on track.
    #[tokio::test]
    async fn test_custom_thresholds() {
        let lenient = RuleBasedDetector::with_thresholds(5, 0.6);
        let payload =
            Payload::text("Test").with_env_context(EnvContext::new().with_redesign_count(3));

        let result = run(&lenient, &payload).await;

        assert_eq!(result.task_health, Some(TaskHealth::OnTrack));
    }
}