use super::detected_context::DetectedContext;
use super::error::AgentError;
use super::payload::Payload;
use async_trait::async_trait;
/// A pluggable analysis step that inspects a [`Payload`] and reports what it found
/// as a [`DetectedContext`].
///
/// Implementors must be `Send + Sync`, so a detector can be shared across threads,
/// including behind a `&dyn ContextDetector`.
#[async_trait]
pub trait ContextDetector: Send + Sync {
    /// Examines `payload` (borrowed immutably) and returns the context derived from it.
    ///
    /// # Errors
    ///
    /// Returns an [`AgentError`] when detection cannot be completed.
    async fn detect(&self, payload: &Payload) -> Result<DetectedContext, AgentError>;

    /// Identifier for this detector, intended for diagnostics.
    ///
    /// The default returns [`std::any::type_name`] of the implementing type. The default
    /// body is instantiated separately for each implementor, so calling it through a trait
    /// object still yields the concrete type's name. The exact string (for example, whether
    /// the module path is included) is not guaranteed stable across compiler versions.
    /// Override this method when a fixed, human-chosen name is required.
    fn name(&self) -> &str {
        let type_label: &'static str = std::any::type_name::<Self>();
        type_label
    }
}
/// Extension trait for running a [`ContextDetector`] against a value and getting the
/// value back with the detection result applied.
///
/// The method consumes `self` and returns `Result<Self, _>`, so calls can be chained to
/// layer several detectors one after another.
#[async_trait]
pub trait DetectContextExt: Sized {
    /// Runs `detector` against `self` and returns the updated value.
    ///
    /// `D` may be unsized, so a `&dyn ContextDetector` can be passed directly.
    ///
    /// # Errors
    ///
    /// Returns an [`AgentError`] when detection fails.
    async fn detect_with<D>(self, detector: &D) -> Result<Self, AgentError>
    where
        D: ContextDetector + ?Sized;
}
/// Lets a [`Payload`] be run through any [`ContextDetector`] and have the result merged in.
#[async_trait]
impl DetectContextExt for Payload {
    /// Detects context for this payload with `detector`, then merges it in via
    /// `Payload::merge_detected_context`.
    ///
    /// `merge_detected_context` decides how the new context combines with any context
    /// detected earlier.
    ///
    /// # Errors
    ///
    /// Returns the detector's [`AgentError`] unchanged. The payload is consumed and
    /// dropped in that case.
    async fn detect_with<D>(self, detector: &D) -> Result<Self, AgentError>
    where
        D: ContextDetector + ?Sized,
    {
        // Detection only borrows the payload. That borrow ends once the future has been
        // awaited, which is what allows `self` to be handed to the merge afterwards.
        let outcome = detector.detect(&self).await;
        outcome.map(|context| self.merge_detected_context(context))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::EnvContext;
    use crate::context::TaskHealth;

    /// Test detector that marks a payload `AtRisk` when `should_detect_at_risk` is set and
    /// the payload's latest env context reports more than two redesigns.
    struct MockDetector {
        should_detect_at_risk: bool,
    }

    #[async_trait]
    impl ContextDetector for MockDetector {
        async fn detect(&self, payload: &Payload) -> Result<DetectedContext, AgentError> {
            let mut detected = DetectedContext::new();
            if let Some(env_ctx) = payload.latest_env_context()
                && self.should_detect_at_risk
                && env_ctx.redesign_count > 2
            {
                detected = detected.with_task_health(TaskHealth::AtRisk);
            }
            Ok(detected.detected_by("MockDetector"))
        }

        fn name(&self) -> &str {
            "MockDetector"
        }
    }

    /// Detector that does not override `name`, used to exercise the trait's default.
    struct UnnamedDetector;

    #[async_trait]
    impl ContextDetector for UnnamedDetector {
        async fn detect(&self, _payload: &Payload) -> Result<DetectedContext, AgentError> {
            Ok(DetectedContext::new())
        }
    }

    #[tokio::test]
    async fn test_context_detector_basic() {
        let env_ctx = EnvContext::new().with_redesign_count(3);
        let payload = Payload::text("Test").with_env_context(env_ctx);
        let detector = MockDetector {
            should_detect_at_risk: true,
        };
        let detected = detector.detect(&payload).await.unwrap();
        assert_eq!(detected.task_health, Some(TaskHealth::AtRisk));
        assert_eq!(detected.detected_by, vec!["MockDetector"]);
    }

    #[tokio::test]
    async fn test_context_detector_below_threshold() {
        // Exactly two redesigns is not "more than two", so no task health is reported.
        let env_ctx = EnvContext::new().with_redesign_count(2);
        let payload = Payload::text("Test").with_env_context(env_ctx);
        let detector = MockDetector {
            should_detect_at_risk: true,
        };
        let detected = detector.detect(&payload).await.unwrap();
        assert_eq!(detected.task_health, None);
        assert_eq!(detected.detected_by, vec!["MockDetector"]);
    }

    #[tokio::test]
    async fn test_detect_context_ext() {
        let env_ctx = EnvContext::new().with_redesign_count(3);
        let detector = MockDetector {
            should_detect_at_risk: true,
        };
        let payload = Payload::text("Test")
            .with_env_context(env_ctx)
            .detect_with(&detector)
            .await
            .unwrap();
        let detected = payload.detected_context().unwrap();
        assert_eq!(detected.task_health, Some(TaskHealth::AtRisk));
    }

    #[tokio::test]
    async fn test_detect_with_trait_object() {
        // `detect_with` accepts `?Sized` detectors; make sure `&dyn ContextDetector` works.
        let env_ctx = EnvContext::new().with_redesign_count(3);
        let detector: &dyn ContextDetector = &MockDetector {
            should_detect_at_risk: true,
        };
        let payload = Payload::text("Test")
            .with_env_context(env_ctx)
            .detect_with(detector)
            .await
            .unwrap();
        let detected = payload.detected_context().unwrap();
        assert_eq!(detected.task_health, Some(TaskHealth::AtRisk));
    }

    #[tokio::test]
    async fn test_layered_detection() {
        let env_ctx = EnvContext::new().with_redesign_count(3);
        let detector1 = MockDetector {
            should_detect_at_risk: true,
        };
        // The second detector reports no task health; layering it on top must not erase
        // the `AtRisk` verdict from the first, but it is still recorded as a detector.
        let detector2 = MockDetector {
            should_detect_at_risk: false,
        };
        let payload = Payload::text("Test")
            .with_env_context(env_ctx)
            .detect_with(&detector1)
            .await
            .unwrap()
            .detect_with(&detector2)
            .await
            .unwrap();
        let detected = payload.detected_context().unwrap();
        assert_eq!(detected.task_health, Some(TaskHealth::AtRisk));
        assert_eq!(detected.detected_by.len(), 2);
    }

    #[test]
    fn test_default_name_uses_type_name() {
        let detector = UnnamedDetector;
        assert!(detector.name().contains("UnnamedDetector"));
        // Through a trait object the default still reports the concrete type.
        let as_dyn: &dyn ContextDetector = &detector;
        assert_eq!(as_dyn.name(), detector.name());
    }
}