use std::collections::HashMap;
use serde_json::Value;
use crate::actions::ActionsConfig;
use crate::state::Escalation;
use crate::types::WorkerId;
use crate::agent::{ManagerId, ManagerInstruction};
#[derive(Debug, Clone)]
pub struct ContextStore {
pub global: GlobalContext,
pub workers: HashMap<WorkerId, WorkerContext>,
pub managers: HashMap<ManagerId, ManagerContext>,
pub escalations: Vec<(WorkerId, Escalation)>,
pub actions: Option<ActionsConfig>,
pub metadata: HashMap<String, Value>,
}
impl ContextStore {
pub fn new(tick: u64) -> Self {
Self {
global: GlobalContext::new(tick),
workers: HashMap::new(),
managers: HashMap::new(),
escalations: Vec::new(),
actions: None,
metadata: HashMap::new(),
}
}
pub fn with_worker(mut self, ctx: WorkerContext) -> Self {
self.workers.insert(ctx.id, ctx);
self
}
pub fn with_manager(mut self, ctx: ManagerContext) -> Self {
self.managers.insert(ctx.id, ctx);
self
}
pub fn with_escalation(mut self, worker_id: WorkerId, escalation: Escalation) -> Self {
self.escalations.push((worker_id, escalation));
self
}
pub fn with_actions(mut self, actions: ActionsConfig) -> Self {
self.actions = Some(actions);
self
}
pub fn insert<V: Into<Value>>(mut self, key: impl Into<String>, value: V) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn get(&self, key: &str) -> Option<&Value> {
self.metadata.get(key)
}
pub fn get_str(&self, key: &str) -> Option<&str> {
self.metadata.get(key).and_then(|v| v.as_str())
}
}
/// State shared by every agent for the current tick.
///
/// Fields not set through a builder keep their `Default` values:
/// `0` / `0.0` for numbers and `None` for the optional strings.
#[derive(Debug, Clone, Default)]
pub struct GlobalContext {
    /// Current tick, set by [`GlobalContext::new`].
    pub tick: u64,
    /// Tick limit; `0` unless set via [`GlobalContext::with_max_ticks`].
    pub max_ticks: u64,
    /// Progress value as supplied by the caller; no range is enforced here.
    pub progress: f64,
    /// Success rate as supplied by the caller; no range is enforced here.
    pub success_rate: f64,
    /// Optional task description.
    pub task_description: Option<String>,
    /// Optional free-text hint.
    pub hint: Option<String>,
}

impl GlobalContext {
    /// Creates a context for `tick` with every other field at its default.
    pub fn new(tick: u64) -> Self {
        Self {
            tick,
            ..Default::default()
        }
    }

    /// Sets `max_ticks`.
    #[must_use]
    pub fn with_max_ticks(mut self, max: u64) -> Self {
        self.max_ticks = max;
        self
    }

    /// Sets `progress` (stored as-is, not clamped).
    #[must_use]
    pub fn with_progress(mut self, progress: f64) -> Self {
        self.progress = progress;
        self
    }

    /// Sets `success_rate` (stored as-is, not clamped).
    #[must_use]
    pub fn with_success_rate(mut self, rate: f64) -> Self {
        self.success_rate = rate;
        self
    }

    /// Sets the task description, replacing any previous one.
    #[must_use]
    pub fn with_task(mut self, description: impl Into<String>) -> Self {
        self.task_description = Some(description.into());
        self
    }

    /// Sets the hint, replacing any previous one.
    #[must_use]
    pub fn with_hint(mut self, hint: impl Into<String>) -> Self {
        self.hint = Some(hint.into());
        self
    }
}
/// Snapshot of a single worker's state.
#[derive(Debug, Clone)]
pub struct WorkerContext {
    /// Worker identifier; also the key under which [`ContextStore::with_worker`] stores it.
    pub id: WorkerId,
    /// Consecutive failure count; `0` on construction.
    pub consecutive_failures: u32,
    /// Most recent action, set together with `last_success` by
    /// [`WorkerContext::with_last_action`].
    pub last_action: Option<String>,
    /// Outcome of `last_action`; `None` until an action is recorded.
    pub last_success: Option<bool>,
    /// History length as reported by the caller; `0` on construction.
    pub history_len: usize,
    /// Escalation flag set via [`WorkerContext::with_escalation`]. It is independent of
    /// the escalation list kept in [`ContextStore::escalations`].
    pub has_escalation: bool,
    /// Candidate action names for this worker.
    pub candidates: Vec<String>,
    /// Free-form extras. No builder method writes this; assign the field directly.
    pub metadata: HashMap<String, Value>,
}

impl WorkerContext {
    /// Creates a context for `id` with zero counters, no last action, no escalation,
    /// no candidates and empty metadata.
    pub fn new(id: WorkerId) -> Self {
        Self {
            id,
            consecutive_failures: 0,
            last_action: None,
            last_success: None,
            history_len: 0,
            has_escalation: false,
            candidates: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Sets `consecutive_failures`.
    #[must_use]
    pub fn with_failures(mut self, count: u32) -> Self {
        self.consecutive_failures = count;
        self
    }

    /// Records the last action together with whether it succeeded.
    #[must_use]
    pub fn with_last_action(mut self, action: impl Into<String>, success: bool) -> Self {
        self.last_action = Some(action.into());
        self.last_success = Some(success);
        self
    }

    /// Sets `history_len`.
    #[must_use]
    pub fn with_history_len(mut self, len: usize) -> Self {
        self.history_len = len;
        self
    }

    /// Sets the `has_escalation` flag.
    #[must_use]
    pub fn with_escalation(mut self, has: bool) -> Self {
        self.has_escalation = has;
        self
    }

    /// Replaces the candidate action names.
    #[must_use]
    pub fn with_candidates(mut self, candidates: Vec<String>) -> Self {
        self.candidates = candidates;
        self
    }
}
/// Snapshot of a single manager's state.
#[derive(Debug, Clone)]
pub struct ManagerContext {
    /// Manager identifier; also the key under which [`ContextStore::with_manager`] stores it.
    pub id: ManagerId,
    /// Display name; defaults to `Manager_<id>` (see [`ManagerContext::new`]).
    pub name: String,
    /// Tick value supplied via [`ManagerContext::with_last_tick`]; `0` on construction.
    pub last_tick: u64,
    /// Free-form extras. No builder method writes this; assign the field directly.
    pub metadata: HashMap<String, Value>,
}

impl ManagerContext {
    /// Creates a context for `id`, named `Manager_<inner id>`, with `last_tick = 0`
    /// and empty metadata.
    pub fn new(id: ManagerId) -> Self {
        Self {
            id,
            name: format!("Manager_{}", id.0),
            last_tick: 0,
            metadata: HashMap::new(),
        }
    }

    /// Replaces the default name.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Sets `last_tick`.
    #[must_use]
    pub fn with_last_tick(mut self, tick: u64) -> Self {
        self.last_tick = tick;
        self
    }
}
/// Selector naming which agents' context is visible to a consumer.
#[derive(Debug, Clone)]
pub enum ContextView {
    /// View anchored on one manager.
    Global { manager_id: ManagerId },
    /// View centred on one worker, optionally listing neighbouring workers.
    Local {
        worker_id: WorkerId,
        neighbor_ids: Vec<WorkerId>,
    },
    /// Named view with explicit lists of visible workers and managers.
    Custom {
        name: String,
        visible_worker_ids: Vec<WorkerId>,
        visible_manager_ids: Vec<ManagerId>,
    },
}

impl ContextView {
    /// Builds a `Global` view for `manager_id`.
    pub fn global(manager_id: ManagerId) -> Self {
        ContextView::Global { manager_id }
    }

    /// Builds a `Local` view for `worker_id` with an empty neighbour list.
    pub fn local(worker_id: WorkerId) -> Self {
        Self::local_with_neighbors(worker_id, Vec::new())
    }

    /// Builds a `Local` view for `worker_id` that also lists `neighbor_ids`.
    pub fn local_with_neighbors(worker_id: WorkerId, neighbor_ids: Vec<WorkerId>) -> Self {
        ContextView::Local {
            worker_id,
            neighbor_ids,
        }
    }

    /// Builds a `Custom` view called `name` exposing exactly the given workers and managers.
    pub fn custom(
        name: impl Into<String>,
        visible_workers: Vec<WorkerId>,
        visible_managers: Vec<ManagerId>,
    ) -> Self {
        let name: String = name.into();
        ContextView::Custom {
            name,
            visible_worker_ids: visible_workers,
            visible_manager_ids: visible_managers,
        }
    }
}
/// One parameter of an [`ActionCandidate`].
#[derive(Debug, Clone)]
pub struct ActionParam {
    /// Parameter name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Whether the parameter must be supplied.
    pub required: bool,
}

/// An action that can be offered to an agent, copied from an action definition.
#[derive(Debug, Clone)]
pub struct ActionCandidate {
    /// Action name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Parameters in definition order.
    pub params: Vec<ActionParam>,
    /// Optional usage example.
    pub example: Option<String>,
}

impl ActionCandidate {
    /// Converts every definition yielded by `config.all_actions()` into a candidate,
    /// preserving the iteration order of both actions and their parameters.
    pub fn from_config(config: &ActionsConfig) -> Vec<Self> {
        let mut out = Vec::new();
        for action in config.all_actions() {
            let params: Vec<ActionParam> = action
                .params
                .iter()
                .map(|param| ActionParam {
                    name: param.name.clone(),
                    description: param.description.clone(),
                    required: param.required,
                })
                .collect();
            out.push(ActionCandidate {
                name: action.name.clone(),
                description: action.description.clone(),
                params,
                example: action.example.clone(),
            });
        }
        out
    }
}
/// Context prepared for a single target agent (a manager or a worker).
#[derive(Debug, Clone)]
pub struct ResolvedContext {
    /// Global state for the tick.
    pub global: GlobalContext,
    /// Worker contexts included for the target.
    pub visible_workers: Vec<WorkerContext>,
    /// Escalations included for the target, each paired with a worker id.
    pub escalations: Vec<(WorkerId, Escalation)>,
    /// Actions offered to the target.
    pub candidates: Vec<ActionCandidate>,
    /// Free-form extras.
    pub metadata: HashMap<String, Value>,
    /// The agent this context is addressed to.
    pub target: ContextTarget,
    /// The target's own previous output, if one was supplied.
    pub self_last_output: Option<String>,
    /// Instruction attached via [`ResolvedContext::with_manager_instruction`], if any.
    pub manager_instruction: Option<ManagerInstruction>,
}

/// Identifies the agent a [`ResolvedContext`] is addressed to.
// `Copy`/`Eq`/`Hash` are sound here: both id types are `Copy` and are used as
// `HashMap` keys elsewhere in this module, so they already implement these traits.
// This lets callers compare targets, key maps by them and copy them out freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContextTarget {
    /// A manager agent.
    Manager(ManagerId),
    /// A worker agent.
    Worker(WorkerId),
}

impl ResolvedContext {
    /// Creates a context for `target` with empty collections and no optional extras.
    pub fn new(global: GlobalContext, target: ContextTarget) -> Self {
        Self {
            global,
            visible_workers: Vec::new(),
            escalations: Vec::new(),
            candidates: Vec::new(),
            metadata: HashMap::new(),
            target,
            self_last_output: None,
            manager_instruction: None,
        }
    }

    /// Replaces the visible worker list.
    #[must_use]
    pub fn with_workers(mut self, workers: Vec<WorkerContext>) -> Self {
        self.visible_workers = workers;
        self
    }

    /// Replaces the escalation list.
    #[must_use]
    pub fn with_escalations(mut self, escalations: Vec<(WorkerId, Escalation)>) -> Self {
        self.escalations = escalations;
        self
    }

    /// Replaces the candidate actions.
    #[must_use]
    pub fn with_candidates(mut self, candidates: Vec<ActionCandidate>) -> Self {
        self.candidates = candidates;
        self
    }

    /// Replaces the candidate actions with those built by [`ActionCandidate::from_config`].
    #[must_use]
    pub fn with_actions_config(mut self, config: &ActionsConfig) -> Self {
        self.candidates = ActionCandidate::from_config(config);
        self
    }

    /// Replaces the metadata map.
    #[must_use]
    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
        self.metadata = metadata;
        self
    }

    /// Sets the target's previous output; passing `None` clears it.
    #[must_use]
    pub fn with_self_last_output(mut self, output: Option<String>) -> Self {
        self.self_last_output = output;
        self
    }

    /// Attaches a manager instruction, replacing any previous one.
    #[must_use]
    pub fn with_manager_instruction(mut self, instruction: ManagerInstruction) -> Self {
        self.manager_instruction = Some(instruction);
        self
    }

    /// Returns `true` if at least one escalation is present.
    pub fn has_escalations(&self) -> bool {
        !self.escalations.is_empty()
    }

    /// Returns `true` if the target is a manager.
    pub fn is_manager(&self) -> bool {
        matches!(self.target, ContextTarget::Manager(_))
    }

    /// Returns `true` if the target is a worker.
    pub fn is_worker(&self) -> bool {
        matches!(self.target, ContextTarget::Worker(_))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_store_builder() {
        let store = ContextStore::new(10)
            .with_worker(WorkerContext::new(WorkerId(0)))
            .with_worker(WorkerContext::new(WorkerId(1)))
            .with_manager(ManagerContext::new(ManagerId(0)))
            .insert("task", "Find the bug")
            .insert("count", 5);
        assert_eq!(store.global.tick, 10);
        assert_eq!(store.workers.len(), 2);
        assert_eq!(store.managers.len(), 1);
        assert!(store.escalations.is_empty());
        assert!(store.actions.is_none());
        assert_eq!(store.get_str("task"), Some("Find the bug"));
        // Non-string values are visible through `get` but not `get_str`.
        assert!(store.get("count").is_some());
        assert_eq!(store.get_str("count"), None);
        assert_eq!(store.get_str("missing"), None);
    }

    #[test]
    fn test_context_store_replaces_worker_with_same_id() {
        let store = ContextStore::new(0)
            .with_worker(WorkerContext::new(WorkerId(0)))
            .with_worker(WorkerContext::new(WorkerId(0)).with_failures(3));
        assert_eq!(store.workers.len(), 1);
        assert_eq!(store.workers[&WorkerId(0)].consecutive_failures, 3);
    }

    #[test]
    fn test_global_context_builder() {
        let global = GlobalContext::new(4)
            .with_max_ticks(50)
            .with_success_rate(0.75)
            .with_hint("check logs");
        assert_eq!(global.tick, 4);
        assert_eq!(global.max_ticks, 50);
        assert_eq!(global.success_rate, 0.75);
        assert_eq!(global.hint.as_deref(), Some("check logs"));
        assert!(global.task_description.is_none());
    }

    #[test]
    fn test_context_view_creation() {
        let global = ContextView::global(ManagerId(0));
        assert!(matches!(global, ContextView::Global { .. }));

        // `match` + `panic!` instead of `if let`, so a wrong variant fails the test
        // rather than silently skipping the assertions.
        match ContextView::local(WorkerId(3)) {
            ContextView::Local {
                worker_id,
                neighbor_ids,
            } => {
                assert_eq!(worker_id, WorkerId(3));
                assert!(neighbor_ids.is_empty());
            }
            other => panic!("expected ContextView::Local, got {:?}", other),
        }

        match ContextView::local_with_neighbors(WorkerId(0), vec![WorkerId(1), WorkerId(2)]) {
            ContextView::Local {
                worker_id,
                neighbor_ids,
            } => {
                assert_eq!(worker_id, WorkerId(0));
                assert_eq!(neighbor_ids, vec![WorkerId(1), WorkerId(2)]);
            }
            other => panic!("expected ContextView::Local, got {:?}", other),
        }

        match ContextView::custom("audit", vec![WorkerId(1)], vec![ManagerId(2)]) {
            ContextView::Custom {
                name,
                visible_worker_ids,
                visible_manager_ids,
            } => {
                assert_eq!(name, "audit");
                assert_eq!(visible_worker_ids, vec![WorkerId(1)]);
                assert_eq!(visible_manager_ids, vec![ManagerId(2)]);
            }
            other => panic!("expected ContextView::Custom, got {:?}", other),
        }
    }

    #[test]
    fn test_worker_context_builder() {
        let fresh = WorkerContext::new(WorkerId(5));
        assert_eq!(fresh.consecutive_failures, 0);
        assert_eq!(fresh.last_action, None);
        assert_eq!(fresh.last_success, None);
        assert!(!fresh.has_escalation);

        let ctx = WorkerContext::new(WorkerId(0))
            .with_failures(2)
            .with_last_action("read:/path", true)
            .with_history_len(10)
            .with_escalation(true)
            .with_candidates(vec!["read".into(), "grep".into()]);
        assert_eq!(ctx.id, WorkerId(0));
        assert_eq!(ctx.consecutive_failures, 2);
        assert_eq!(ctx.last_action, Some("read:/path".to_string()));
        assert_eq!(ctx.last_success, Some(true));
        assert_eq!(ctx.history_len, 10);
        assert!(ctx.has_escalation);
        assert_eq!(ctx.candidates, vec!["read".to_string(), "grep".to_string()]);
    }

    #[test]
    fn test_manager_context_builder() {
        let default_ctx = ManagerContext::new(ManagerId(3));
        assert_eq!(default_ctx.name, "Manager_3");
        assert_eq!(default_ctx.last_tick, 0);

        let ctx = ManagerContext::new(ManagerId(3))
            .with_name("lead")
            .with_last_tick(42);
        assert_eq!(ctx.id, ManagerId(3));
        assert_eq!(ctx.name, "lead");
        assert_eq!(ctx.last_tick, 42);
    }

    #[test]
    fn test_resolved_context() {
        let global = GlobalContext::new(5)
            .with_progress(0.5)
            .with_task("Test task");
        let candidates = vec![ActionCandidate {
            name: "action1".to_string(),
            description: "Test action".to_string(),
            params: vec![],
            example: None,
        }];
        let resolved = ResolvedContext::new(global, ContextTarget::Worker(WorkerId(0)))
            .with_workers(vec![WorkerContext::new(WorkerId(0))])
            .with_candidates(candidates);
        assert!(resolved.is_worker());
        assert!(!resolved.is_manager());
        assert!(!resolved.has_escalations());
        assert!(resolved.self_last_output.is_none());
        assert!(resolved.manager_instruction.is_none());
        assert_eq!(resolved.global.tick, 5);
        assert_eq!(resolved.visible_workers.len(), 1);
        assert_eq!(resolved.candidates.len(), 1);
    }

    #[test]
    fn test_resolved_context_manager_target() {
        let resolved = ResolvedContext::new(
            GlobalContext::new(0),
            ContextTarget::Manager(ManagerId(1)),
        )
        .with_self_last_output(Some("previous".to_string()));
        assert!(resolved.is_manager());
        assert!(!resolved.is_worker());
        assert_eq!(resolved.self_last_output.as_deref(), Some("previous"));
    }
}