use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use serde_json::Value;
use crate::actions::ActionsConfig;
use crate::agent::Guidance;
use crate::state::Escalation;
use crate::types::WorkerId;
/// Snapshot of a single worker's recent state, as stored in a `TaskContext`.
///
/// Construct with [`WorkerSummary::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct WorkerSummary {
    /// Identifier of the worker this summary describes.
    pub id: WorkerId,
    /// Count of consecutive failures, as supplied by the caller (0 by default).
    pub consecutive_failures: u32,
    /// Most recent action, if one has been recorded.
    pub last_action: Option<String>,
    /// Outcome of `last_action`; set together with it by [`WorkerSummary::with_last_action`].
    pub last_success: Option<bool>,
    /// Output of the most recent action, if one has been recorded.
    pub last_output: Option<String>,
    /// Length of the worker's history, as supplied by the caller (0 by default).
    pub history_len: usize,
    /// Whether this worker is flagged as having an escalation.
    pub has_escalation: bool,
}

impl WorkerSummary {
    /// Creates a summary for `id` with zero failures, no last action/outcome/output,
    /// an empty history length and no escalation flag.
    pub fn new(id: WorkerId) -> Self {
        Self {
            id,
            consecutive_failures: 0,
            last_action: None,
            last_success: None,
            last_output: None,
            history_len: 0,
            has_escalation: false,
        }
    }

    /// Sets the consecutive-failure count.
    #[must_use]
    pub fn with_failures(mut self, count: u32) -> Self {
        self.consecutive_failures = count;
        self
    }

    /// Records the most recent action together with whether it succeeded.
    #[must_use]
    pub fn with_last_action(mut self, action: impl Into<String>, success: bool) -> Self {
        self.last_action = Some(action.into());
        self.last_success = Some(success);
        self
    }

    /// Records the output of the most recent action.
    ///
    /// `last_output` previously had no builder; this mirrors the other `with_*` setters.
    #[must_use]
    pub fn with_last_output(mut self, output: impl Into<String>) -> Self {
        self.last_output = Some(output.into());
        self
    }

    /// Sets the history length.
    #[must_use]
    pub fn with_history_len(mut self, len: usize) -> Self {
        self.history_len = len;
        self
    }

    /// Sets the escalation flag.
    #[must_use]
    pub fn with_escalation(mut self, has_escalation: bool) -> Self {
        self.has_escalation = has_escalation;
        self
    }
}
/// Snapshot of task-wide state at a given tick, built with `TaskContext::new`
/// and the `with_*` builder methods.
///
/// Arbitrary extra data lives in `metadata` and is accessed through
/// `insert` / `set` / `get` and the typed `get_*` helpers.
#[derive(Debug, Clone)]
pub struct TaskContext {
    /// Tick number supplied to `TaskContext::new`.
    pub tick: u64,
    /// Per-worker summaries keyed by worker id (`with_worker` keys by `summary.id`).
    pub workers: HashMap<WorkerId, WorkerSummary>,
    /// Success rate (0.0 by default); presumably a fraction in [0, 1] — TODO confirm.
    pub success_rate: f64,
    /// Progress (0.0 by default); presumably a fraction in [0, 1] — TODO confirm.
    pub progress: f64,
    /// Escalations paired with the worker they concern. This list is separate from
    /// `WorkerSummary::has_escalation`; nothing in this type keeps the two in sync.
    pub escalations: Vec<(WorkerId, Escalation)>,
    /// Action configuration, if one has been supplied via `with_actions`.
    pub available_actions: Option<ActionsConfig>,
    /// Guidances from the "v2" path; when `Some`, `has_exploration` returns `true`.
    pub v2_guidances: Option<Vec<Guidance>>,
    /// Action names to exclude; NOTE(review): how consumers apply this isn't visible here.
    pub excluded_actions: Vec<String>,
    /// Guidance per worker, shared via `Arc`; presumably the guidance issued on an
    /// earlier tick — confirm against the caller.
    pub previous_guidances: HashMap<WorkerId, Arc<Guidance>>,
    /// Ids of workers considered finished; presumably populated by the caller.
    pub done_workers: HashSet<WorkerId>,
    /// Free-form JSON metadata keyed by string.
    pub metadata: HashMap<String, Value>,
}
impl TaskContext {
pub fn new(tick: u64) -> Self {
Self {
tick,
workers: HashMap::new(),
success_rate: 0.0,
progress: 0.0,
escalations: Vec::new(),
available_actions: None,
v2_guidances: None,
excluded_actions: Vec::new(),
previous_guidances: HashMap::new(),
done_workers: HashSet::new(),
metadata: HashMap::new(),
}
}
pub fn with_worker(mut self, summary: WorkerSummary) -> Self {
self.workers.insert(summary.id, summary);
self
}
pub fn with_success_rate(mut self, rate: f64) -> Self {
self.success_rate = rate;
self
}
pub fn with_progress(mut self, progress: f64) -> Self {
self.progress = progress;
self
}
pub fn with_escalation(mut self, worker_id: WorkerId, escalation: Escalation) -> Self {
self.escalations.push((worker_id, escalation));
self
}
pub fn with_actions(mut self, actions: ActionsConfig) -> Self {
self.available_actions = Some(actions);
self
}
pub fn with_previous_guidances(mut self, guidances: HashMap<WorkerId, Arc<Guidance>>) -> Self {
self.previous_guidances = guidances;
self
}
pub fn with_previous_guidance(mut self, worker_id: WorkerId, guidance: Arc<Guidance>) -> Self {
self.previous_guidances.insert(worker_id, guidance);
self
}
pub fn insert<V: Into<Value>>(mut self, key: impl Into<String>, value: V) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn set<V: Into<Value>>(&mut self, key: impl Into<String>, value: V) {
self.metadata.insert(key.into(), value.into());
}
pub fn get(&self, key: &str) -> Option<&Value> {
self.metadata.get(key)
}
pub fn get_str(&self, key: &str) -> Option<&str> {
self.metadata.get(key).and_then(|v| v.as_str())
}
pub fn get_f64(&self, key: &str) -> Option<f64> {
self.metadata.get(key).and_then(|v| v.as_f64())
}
pub fn get_i64(&self, key: &str) -> Option<i64> {
self.metadata.get(key).and_then(|v| v.as_i64())
}
pub fn get_bool(&self, key: &str) -> Option<bool> {
self.metadata.get(key).and_then(|v| v.as_bool())
}
pub fn has_escalations(&self) -> bool {
!self.escalations.is_empty()
}
pub fn has_escalation_for(&self, worker_id: WorkerId) -> bool {
self.escalations.iter().any(|(id, _)| *id == worker_id)
}
pub fn worker(&self, id: WorkerId) -> Option<&WorkerSummary> {
self.workers.get(&id)
}
pub fn escalated_worker_count(&self) -> usize {
self.workers.values().filter(|w| w.has_escalation).count()
}
pub fn worker_ids(&self) -> Vec<WorkerId> {
self.workers.keys().copied().collect()
}
}
impl TaskContext {
pub fn has_exploration(&self) -> bool {
self.v2_guidances.is_some()
}
pub fn filter_for_workers(&self, worker_ids: &[WorkerId]) -> TaskContext {
use std::collections::HashSet;
let worker_set: HashSet<WorkerId> = worker_ids.iter().copied().collect();
let filtered_workers: HashMap<WorkerId, WorkerSummary> = self
.workers
.iter()
.filter(|(id, _)| worker_set.contains(id))
.map(|(id, summary)| (*id, summary.clone()))
.collect();
let filtered_escalations: Vec<(WorkerId, Escalation)> = self
.escalations
.iter()
.filter(|(id, _)| worker_set.contains(id))
.cloned()
.collect();
let filtered_guidances: HashMap<WorkerId, Arc<Guidance>> = self
.previous_guidances
.iter()
.filter(|(id, _)| worker_set.contains(id))
.map(|(id, g)| (*id, Arc::clone(g)))
.collect();
let filtered_done_workers: HashSet<WorkerId> = self
.done_workers
.iter()
.filter(|id| worker_set.contains(id))
.copied()
.collect();
TaskContext {
tick: self.tick,
workers: filtered_workers,
success_rate: self.success_rate,
progress: self.progress,
escalations: filtered_escalations,
available_actions: self.available_actions.clone(),
v2_guidances: self.v2_guidances.clone(),
excluded_actions: self.excluded_actions.clone(),
previous_guidances: filtered_guidances,
done_workers: filtered_done_workers,
metadata: self.metadata.clone(),
}
}
}
impl Default for TaskContext {
fn default() -> Self {
Self::new(0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_context_new() {
        let ctx = TaskContext::new(10);
        assert_eq!(ctx.tick, 10);
        assert!(ctx.workers.is_empty());
        assert_eq!(ctx.success_rate, 0.0);
        assert_eq!(ctx.progress, 0.0);
    }

    #[test]
    fn test_task_context_builder() {
        let ctx = TaskContext::new(5)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_worker(WorkerSummary::new(WorkerId(1)).with_escalation(true))
            .with_success_rate(0.8)
            .with_progress(0.5)
            .insert("key1", "value1")
            .insert("count", 42);
        assert_eq!(ctx.tick, 5);
        assert_eq!(ctx.workers.len(), 2);
        assert_eq!(ctx.success_rate, 0.8);
        assert_eq!(ctx.progress, 0.5);
        assert_eq!(ctx.get_str("key1"), Some("value1"));
        assert_eq!(ctx.get_i64("count"), Some(42));
    }

    #[test]
    fn test_worker_summary() {
        let summary = WorkerSummary::new(WorkerId(0))
            .with_failures(2)
            .with_last_action("read:/path", true)
            .with_history_len(10)
            .with_escalation(true);
        assert_eq!(summary.id, WorkerId(0));
        assert_eq!(summary.consecutive_failures, 2);
        assert_eq!(summary.last_action, Some("read:/path".to_string()));
        assert_eq!(summary.last_success, Some(true));
        assert_eq!(summary.history_len, 10);
        assert!(summary.has_escalation);
    }

    #[test]
    fn test_query_methods() {
        let ctx = TaskContext::new(0)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_worker(WorkerSummary::new(WorkerId(1)).with_escalation(true))
            .with_worker(WorkerSummary::new(WorkerId(2)));
        assert_eq!(ctx.escalated_worker_count(), 1);
        assert_eq!(ctx.worker_ids().len(), 3);
    }

    #[test]
    fn test_default_matches_new_zero() {
        let ctx = TaskContext::default();
        assert_eq!(ctx.tick, 0);
        assert!(ctx.workers.is_empty());
        assert!(ctx.metadata.is_empty());
        assert!(!ctx.has_exploration());
    }

    #[test]
    fn test_metadata_accessors() {
        let mut ctx = TaskContext::new(0);
        ctx.set("ratio", 0.25);
        ctx.set("flag", true);
        ctx.set("name", "alpha");
        // A later `set` on the same key overwrites the earlier value.
        ctx.set("name", "beta");
        assert_eq!(ctx.get_f64("ratio"), Some(0.25));
        assert_eq!(ctx.get_bool("flag"), Some(true));
        assert_eq!(ctx.get_str("name"), Some("beta"));
        // Type-mismatched and missing keys yield None rather than panicking.
        assert_eq!(ctx.get_i64("name"), None);
        assert_eq!(ctx.get_bool("ratio"), None);
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn test_escalation_queries() {
        let ctx = TaskContext::new(0)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_escalation(WorkerId(0), Escalation::consecutive_failures(3, 5));
        assert!(ctx.has_escalations());
        assert!(ctx.has_escalation_for(WorkerId(0)));
        assert!(!ctx.has_escalation_for(WorkerId(1)));
        assert!(ctx.worker(WorkerId(0)).is_some());
        assert!(ctx.worker(WorkerId(1)).is_none());
        assert!(!TaskContext::new(0).has_escalations());
    }

    #[test]
    fn test_filter_for_workers() {
        let ctx = TaskContext::new(10)
            .with_worker(WorkerSummary::new(WorkerId(0)).with_failures(1))
            .with_worker(WorkerSummary::new(WorkerId(1)).with_escalation(true))
            .with_worker(WorkerSummary::new(WorkerId(2)).with_history_len(5))
            .with_worker(WorkerSummary::new(WorkerId(3)).with_last_action("read", true))
            .with_escalation(WorkerId(1), Escalation::consecutive_failures(3, 5))
            .with_success_rate(0.75)
            .with_progress(0.5)
            .insert("meta_key", "meta_value");
        let filtered = ctx.filter_for_workers(&[WorkerId(0), WorkerId(2)]);
        assert_eq!(filtered.tick, 10);
        assert_eq!(filtered.workers.len(), 2);
        assert!(filtered.workers.contains_key(&WorkerId(0)));
        assert!(filtered.workers.contains_key(&WorkerId(2)));
        assert!(!filtered.workers.contains_key(&WorkerId(1)));
        assert!(!filtered.workers.contains_key(&WorkerId(3)));
        assert_eq!(
            filtered
                .workers
                .get(&WorkerId(0))
                .unwrap()
                .consecutive_failures,
            1
        );
        assert_eq!(filtered.workers.get(&WorkerId(2)).unwrap().history_len, 5);
        assert!(filtered.escalations.is_empty());
        assert_eq!(filtered.success_rate, 0.75);
        assert_eq!(filtered.progress, 0.5);
        assert_eq!(filtered.get_str("meta_key"), Some("meta_value"));
    }

    #[test]
    fn test_filter_for_workers_keeps_selected_per_worker_state() {
        let mut ctx = TaskContext::new(3)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_worker(WorkerSummary::new(WorkerId(1)))
            .with_escalation(WorkerId(0), Escalation::consecutive_failures(3, 5))
            .with_escalation(WorkerId(1), Escalation::consecutive_failures(3, 5));
        ctx.done_workers.insert(WorkerId(0));
        ctx.done_workers.insert(WorkerId(1));
        ctx.excluded_actions.push("noop".to_string());
        ctx.v2_guidances = Some(Vec::new());

        // Duplicate and unknown ids must not break the filter.
        let filtered = ctx.filter_for_workers(&[WorkerId(0), WorkerId(0), WorkerId(9)]);
        assert_eq!(filtered.workers.len(), 1);
        assert_eq!(filtered.escalations.len(), 1);
        assert!(filtered.has_escalation_for(WorkerId(0)));
        assert!(!filtered.has_escalation_for(WorkerId(1)));
        assert_eq!(filtered.done_workers.len(), 1);
        assert!(filtered.done_workers.contains(&WorkerId(0)));
        // Non-per-worker fields are carried over unchanged.
        assert_eq!(filtered.excluded_actions, vec!["noop".to_string()]);
        assert!(filtered.has_exploration());
    }

    #[test]
    fn test_filter_for_workers_empty() {
        let ctx = TaskContext::new(5)
            .with_worker(WorkerSummary::new(WorkerId(0)))
            .with_worker(WorkerSummary::new(WorkerId(1)));
        let filtered = ctx.filter_for_workers(&[]);
        assert_eq!(filtered.tick, 5);
        assert!(filtered.workers.is_empty());
    }
}