use std::any::Any;
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::Duration;
use rayon::prelude::*;
use crate::async_task::TaskStatus;
use crate::extensions::Extensions;
use crate::online_stats::SwarmStats;
use crate::types::{AgentId, TaskId, WorkerId};
/// Running totals for LLM invocations.
#[derive(Debug, Clone, Default)]
pub struct LlmStats {
    /// Number of invocations recorded via `record` (or set directly).
    pub invocations: u64,
    /// Number of those invocations that were recorded as failures.
    pub errors: u64,
    /// Sum of the durations passed to `record`.
    pub total_duration: Duration,
}

impl LlmStats {
    /// Returns the fraction of invocations that succeeded.
    ///
    /// Returns `1.0` when no invocations have been recorded. The fields are
    /// public, so direct mutation can make `errors` larger than
    /// `invocations`. In that case the success count saturates at zero
    /// (rate `0.0`) instead of underflowing. Underflow would panic in debug
    /// builds and produce a nonsensical rate far above `1.0` in release
    /// builds.
    pub fn success_rate(&self) -> f64 {
        if self.invocations == 0 {
            return 1.0;
        }
        let successes = self.invocations.saturating_sub(self.errors);
        successes as f64 / self.invocations as f64
    }

    /// Records one invocation that took `duration`. It is counted as an
    /// error when `success` is `false`.
    pub fn record(&mut self, success: bool, duration: Duration) {
        self.invocations += 1;
        self.total_duration += duration;
        if !success {
            self.errors += 1;
        }
    }
}
pub struct SwarmState {
pub shared: SharedState,
pub workers: WorkerStates,
}
impl SwarmState {
pub fn new(worker_count: usize) -> Self {
Self {
shared: SharedState::default(),
workers: WorkerStates::new(worker_count),
}
}
pub fn advance_tick(&mut self) {
self.shared.tick += 1;
}
}
#[derive(Default)]
pub struct SharedState {
pub environment: Environment,
pub stats: SwarmStats,
pub tick: u64,
pub shared_data: SharedData,
pub extensions: Extensions,
pub avg_tick_duration_ns: u64,
pub done_workers: HashSet<WorkerId>,
pub environment_done: bool,
pub llm_stats: LlmStats,
}
impl SharedState {
pub fn mark_worker_done(&mut self, worker_id: WorkerId) {
self.done_workers.insert(worker_id);
self.environment_done = true;
}
pub fn is_worker_done(&self, worker_id: WorkerId) -> bool {
self.done_workers.contains(&worker_id)
}
pub fn is_environment_done(&self) -> bool {
self.environment_done
}
pub fn llm_invocations(&self) -> u64 {
self.llm_stats.invocations
}
pub fn llm_errors(&self) -> u64 {
self.llm_stats.errors
}
}
/// Named string variables and boolean flags shared across the swarm.
#[derive(Default)]
pub struct Environment {
    /// String variables keyed by name.
    pub variables: HashMap<String, String>,
    /// Boolean flags keyed by name.
    pub flags: HashMap<String, bool>,
}
#[derive(Debug, Clone)]
pub struct TickSnapshot {
pub tick: u64,
pub duration: std::time::Duration,
pub manager_phase: Option<ManagerPhaseSnapshot>,
pub worker_results: Vec<WorkerResultSnapshot>,
}
#[derive(Debug, Clone)]
pub struct ManagerPhaseSnapshot {
pub batch_request: crate::agent::BatchDecisionRequest,
pub responses: Vec<(crate::types::WorkerId, crate::agent::DecisionResponse)>,
pub guidances: std::collections::HashMap<crate::types::WorkerId, crate::agent::Guidance>,
pub llm_errors: u64,
}
/// One worker's entry in a `TickSnapshot`.
#[derive(Clone, Debug)]
pub struct WorkerResultSnapshot {
    /// Worker this entry belongs to.
    pub worker_id: crate::types::WorkerId,
    /// Guidance the worker received, if any.
    pub guidance_received: Option<crate::agent::Guidance>,
    /// Outcome the worker reported.
    pub result: WorkResultSnapshot,
}
/// Cloneable snapshot of a worker's outcome, with one variant per outcome
/// kind.
#[derive(Clone, Debug)]
pub enum WorkResultSnapshot {
    /// The worker performed an action.
    Acted {
        action_result: ActionResultSnapshot,
        state_delta: Option<crate::agent::WorkerStateDelta>,
    },
    /// Work is ongoing. `progress` is presumably a fraction in 0.0..=1.0.
    /// NOTE(review): confirm with the producer.
    Continuing {
        progress: f32,
    },
    /// The worker asked for guidance.
    NeedsGuidance {
        reason: String,
        context: crate::agent::GuidanceContext,
    },
    /// The worker escalated.
    Escalate {
        reason: crate::agent::EscalationReason,
        context: Option<String>,
    },
    /// The worker did nothing.
    Idle,
    /// The worker finished.
    Done {
        success: bool,
        message: Option<String>,
    },
}
#[derive(Debug, Clone)]
pub struct ActionResultSnapshot {
pub success: bool,
pub output_debug: Option<String>,
pub duration: std::time::Duration,
pub error: Option<String>,
}
impl ActionResultSnapshot {
pub fn from_action_result(result: &crate::types::ActionResult) -> Self {
Self {
success: result.success,
output_debug: result.output.as_ref().map(|o| o.as_text()),
duration: result.duration,
error: result.error.clone(),
}
}
}
/// Default cap on how many tick-suffixed `env:` keys `SharedData` retains.
const DEFAULT_MAX_ENV_ENTRIES: usize = 500;

/// Untyped key/value store shared by the swarm, plus the list of finished
/// async tasks.
pub struct SharedData {
    /// Raw byte values keyed by string. Keys that start with `env:` and end
    /// in `:<u64 tick>` are pruned by `cleanup_env_entries`.
    pub kv: HashMap<String, Vec<u8>>,
    pub completed_async_tasks: Vec<CompletedAsyncTask>,
    /// Maximum number of tick-suffixed `env:` keys kept by cleanup.
    max_env_entries: usize,
}

impl Default for SharedData {
    fn default() -> Self {
        Self {
            kv: HashMap::new(),
            completed_async_tasks: Vec::new(),
            max_env_entries: DEFAULT_MAX_ENV_ENTRIES,
        }
    }
}

impl SharedData {
    /// Evicts the oldest `env:` entries from `kv` so that at most
    /// `max_env_entries` remain.
    ///
    /// A key is considered only if it starts with `env:` and its last
    /// `:`-separated segment parses as a `u64` tick. Other keys are never
    /// removed. Entries are evicted in ascending tick order. Equal ticks are
    /// ordered by key, so eviction does not depend on `HashMap` iteration
    /// order, which is randomized per process.
    pub fn cleanup_env_entries(&mut self) {
        let mut env_entries: Vec<(&String, u64)> = self
            .kv
            .keys()
            .filter(|key| key.starts_with("env:"))
            .filter_map(|key| {
                let tick = key.rsplit(':').next()?.parse::<u64>().ok()?;
                Some((key, tick))
            })
            .collect();
        if env_entries.len() <= self.max_env_entries {
            return;
        }
        let remove_count = env_entries.len() - self.max_env_entries;
        env_entries.sort_unstable_by(|(key_a, tick_a), (key_b, tick_b)| {
            tick_a.cmp(tick_b).then_with(|| key_a.cmp(key_b))
        });
        // Clone only the keys being evicted so the borrow of `kv` ends
        // before mutation.
        let evicted: Vec<String> = env_entries
            .into_iter()
            .take(remove_count)
            .map(|(key, _)| key.clone())
            .collect();
        for key in evicted {
            self.kv.remove(&key);
        }
    }

    /// Sets the retention limit used by `cleanup_env_entries`. This does not
    /// run a cleanup by itself.
    pub fn set_max_env_entries(&mut self, max: usize) {
        self.max_env_entries = max;
    }
}
/// Record of a finished async task, as stored in
/// `SharedData::completed_async_tasks`.
#[derive(Clone, Debug)]
pub struct CompletedAsyncTask {
    pub task_id: TaskId,
    /// Associated worker, if any. Presumably `None` for tasks not tied to a
    /// worker. TODO confirm.
    pub worker_id: Option<WorkerId>,
    /// Free-form task type label.
    pub task_type: String,
    /// Tick recorded at completion.
    pub completed_at_tick: u64,
    pub status: TaskStatus,
    /// Error text, if one was recorded.
    pub error: Option<String>,
}
/// Fixed-size table of per-worker state. Position `i` holds the worker with
/// `AgentId(i)`.
pub struct WorkerStates {
    states: Vec<WorkerState>,
}

impl WorkerStates {
    /// Creates `count` workers with ids `AgentId(0)` through
    /// `AgentId(count - 1)`.
    pub fn new(count: usize) -> Self {
        let mut states = Vec::with_capacity(count);
        for index in 0..count {
            states.push(WorkerState::new(AgentId(index)));
        }
        WorkerStates { states }
    }

    /// Mutable access to the worker at `id.0`, or `None` if out of range.
    pub fn get_mut(&mut self, id: AgentId) -> Option<&mut WorkerState> {
        let index = id.0;
        self.states.get_mut(index)
    }

    /// Shared access to the worker at `id.0`, or `None` if out of range.
    pub fn get(&self, id: AgentId) -> Option<&WorkerState> {
        let index = id.0;
        self.states.get(index)
    }

    /// Number of workers.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// True when there are no workers.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over workers in id order.
    pub fn iter(&self) -> impl Iterator<Item = &WorkerState> {
        self.states.iter()
    }

    /// Mutably iterates over workers in id order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut WorkerState> {
        self.states.iter_mut()
    }

    /// Mutably iterates over workers in parallel using rayon.
    pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut WorkerState> {
        self.states.par_iter_mut()
    }
}
/// Why a worker escalated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EscalationReason {
    /// Carries the consecutive-failure count at the time of escalation.
    ConsecutiveFailures(u32),
    ResourceExhausted,
    Timeout,
    AgentRequested(String),
    Unknown(String),
}

/// An escalation raised by a worker, stamped with the tick it was raised at.
#[derive(Debug, Clone)]
pub struct Escalation {
    pub reason: EscalationReason,
    pub raised_at_tick: u64,
    /// Optional free-form detail.
    pub context: Option<String>,
}

impl Escalation {
    /// Builds a `ConsecutiveFailures(count)` escalation raised at `tick`,
    /// with no context.
    pub fn consecutive_failures(count: u32, tick: u64) -> Self {
        let reason = EscalationReason::ConsecutiveFailures(count);
        Escalation {
            reason,
            raised_at_tick: tick,
            context: None,
        }
    }

    /// Returns this escalation with `context` set to `ctx`, replacing any
    /// previous context.
    pub fn with_context(self, ctx: impl Into<String>) -> Self {
        Escalation {
            context: Some(ctx.into()),
            ..self
        }
    }
}
/// Mutable state owned by a single worker.
pub struct WorkerState {
    pub id: AgentId,
    /// Type-erased state stored by `set_state` and read by
    /// `get_state` / `get_state_mut`.
    internal_state: Option<Box<dyn Any + Send + Sync>>,
    pub history: ActionHistory,
    pub cache: LocalCache,
    pub pending_tasks: HashSet<TaskId>,
    /// Currently raised escalation, if any.
    pub escalation: Option<Escalation>,
    /// Failures since the last success or escalation clear.
    pub consecutive_failures: u32,
    pub last_output: Option<String>,
}

impl WorkerState {
    /// Creates an empty worker state for `id`.
    pub fn new(id: AgentId) -> Self {
        WorkerState {
            id,
            internal_state: None,
            history: Default::default(),
            cache: Default::default(),
            pending_tasks: Default::default(),
            escalation: None,
            consecutive_failures: 0,
            last_output: None,
        }
    }

    /// Stores `escalation`, overwriting any existing one.
    pub fn raise_escalation(&mut self, escalation: Escalation) {
        self.escalation = Some(escalation);
    }

    /// Clears the escalation and resets the consecutive-failure counter.
    pub fn clear_escalation(&mut self) {
        self.consecutive_failures = 0;
        self.escalation = None;
    }

    /// Counts one more failure. If the count has reached `threshold` and no
    /// escalation is active, raises a `ConsecutiveFailures` escalation at
    /// `tick` and returns `true`. Otherwise returns `false`.
    pub fn record_failure(&mut self, tick: u64, threshold: u32) -> bool {
        self.consecutive_failures += 1;
        let should_escalate =
            self.escalation.is_none() && self.consecutive_failures >= threshold;
        if should_escalate {
            let count = self.consecutive_failures;
            self.raise_escalation(Escalation::consecutive_failures(count, tick));
        }
        should_escalate
    }

    /// Resets the consecutive-failure counter. Any active escalation is left
    /// in place.
    pub fn record_success(&mut self) {
        self.consecutive_failures = 0;
    }

    /// Replaces the type-erased internal state with `state`.
    pub fn set_state<T: Any + Send + Sync + 'static>(&mut self, state: T) {
        let boxed: Box<dyn Any + Send + Sync> = Box::new(state);
        self.internal_state = Some(boxed);
    }

    /// Borrows the internal state as `T`. Returns `None` if nothing is stored
    /// or the stored value is a different type.
    pub fn get_state<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
        self.internal_state
            .as_deref()
            .and_then(|state| state.downcast_ref::<T>())
    }

    /// Mutably borrows the internal state as `T`. Returns `None` if nothing is
    /// stored or the stored value is a different type.
    pub fn get_state_mut<T: Any + Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        self.internal_state
            .as_deref_mut()
            .and_then(|state| state.downcast_mut::<T>())
    }

    /// Adds `task_id` to the pending set.
    pub fn add_pending_task(&mut self, task_id: TaskId) {
        let pending = &mut self.pending_tasks;
        pending.insert(task_id);
    }

    /// Removes `task_id` from the pending set, if present.
    pub fn complete_task(&mut self, task_id: TaskId) {
        let pending = &mut self.pending_tasks;
        pending.remove(&task_id);
    }
}
/// FIFO log of a worker's actions, optionally bounded in length.
pub struct ActionHistory {
    entries: VecDeque<HistoryEntry>,
    /// Length limit. `0` means unbounded (see `push`).
    max_entries: usize,
}

impl Default for ActionHistory {
    /// A history that keeps the 100 most recent entries.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 100;
        ActionHistory::new(DEFAULT_CAPACITY)
    }
}

impl ActionHistory {
    /// Creates an empty history limited to `max_entries` (`0` = unbounded).
    pub fn new(max_entries: usize) -> Self {
        let entries = VecDeque::with_capacity(max_entries);
        ActionHistory { entries, max_entries }
    }

    /// Appends `entry`. When bounded and full, the oldest entry is dropped
    /// first.
    pub fn push(&mut self, entry: HistoryEntry) {
        let is_bounded = self.max_entries != 0;
        let is_full = self.entries.len() >= self.max_entries;
        if is_bounded && is_full {
            self.entries.pop_front();
        }
        self.entries.push_back(entry);
    }

    /// Most recently pushed entry, if any.
    pub fn latest(&self) -> Option<&HistoryEntry> {
        self.entries.back()
    }

    /// Number of retained entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// True when no entries are retained.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates from oldest to newest retained entry.
    pub fn iter(&self) -> impl Iterator<Item = &HistoryEntry> {
        self.entries.iter()
    }
}

/// One recorded action.
#[derive(Debug, Clone)]
pub struct HistoryEntry {
    pub tick: u64,
    pub action_name: String,
    pub success: bool,
}
/// Per-worker byte cache whose entries expire at a given tick.
#[derive(Default)]
pub struct LocalCache {
    data: HashMap<String, CacheEntry>,
}

impl LocalCache {
    /// Stores `value` under `key`, replacing any previous entry.
    ///
    /// NOTE(review): despite its name, `ttl_ticks` is stored verbatim as the
    /// absolute expiry tick; no current tick is added. A caller that wants a
    /// relative TTL must pass `current_tick + ttl`. Confirm against call
    /// sites.
    pub fn set(&mut self, key: impl Into<String>, value: Vec<u8>, ttl_ticks: u64) {
        let entry = CacheEntry {
            value,
            expires_at_tick: ttl_ticks,
        };
        self.data.insert(key.into(), entry);
    }

    /// Returns the value for `key` if present and its expiry tick is strictly
    /// after `current_tick`.
    pub fn get(&self, key: &str, current_tick: u64) -> Option<&[u8]> {
        self.data
            .get(key)
            .filter(|entry| entry.expires_at_tick > current_tick)
            .map(|entry| entry.value.as_slice())
    }

    /// Drops every entry whose expiry tick is at or before `current_tick`.
    pub fn cleanup(&mut self, current_tick: u64) {
        self.data
            .retain(|_, entry| current_tick < entry.expires_at_tick);
    }
}

/// A cached value and the tick at which it stops being returned.
struct CacheEntry {
    value: Vec<u8>,
    expires_at_tick: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `HistoryEntry`.
    fn history_entry(tick: u64, name: &str, success: bool) -> HistoryEntry {
        HistoryEntry {
            tick,
            action_name: name.to_string(),
            success,
        }
    }

    #[test]
    fn test_swarm_state_creation() {
        let state = SwarmState::new(3);
        assert_eq!(state.workers.len(), 3);
        assert_eq!(state.shared.tick, 0);
    }

    #[test]
    fn test_swarm_state_advance_tick() {
        let mut state = SwarmState::new(1);
        assert_eq!(state.shared.tick, 0);
        for expected_tick in 1..=2u64 {
            state.advance_tick();
            assert_eq!(state.shared.tick, expected_tick);
        }
    }

    #[test]
    fn test_worker_states_access() {
        let mut states = WorkerStates::new(3);
        assert_eq!(states.len(), 3);
        assert!(!states.is_empty());
        let worker = states.get_mut(AgentId(1)).unwrap();
        assert_eq!(worker.id.0, 1);
        assert!(states.get(AgentId(10)).is_none());
    }

    #[test]
    fn test_worker_state_internal() {
        let mut ws = WorkerState::new(AgentId(0));
        assert!(ws.get_state::<i32>().is_none());
        ws.set_state(42i32);
        assert_eq!(ws.get_state::<i32>(), Some(&42));
        if let Some(value) = ws.get_state_mut::<i32>() {
            *value = 100;
        }
        assert_eq!(ws.get_state::<i32>(), Some(&100));
        // A stored i32 must not downcast to an unrelated type.
        assert!(ws.get_state::<String>().is_none());
    }

    #[test]
    fn test_worker_state_pending_tasks() {
        let mut ws = WorkerState::new(AgentId(0));
        assert!(ws.pending_tasks.is_empty());
        ws.add_pending_task(TaskId(1));
        ws.add_pending_task(TaskId(2));
        assert_eq!(ws.pending_tasks.len(), 2);
        assert!(ws.pending_tasks.contains(&TaskId(1)));
        assert!(ws.pending_tasks.contains(&TaskId(2)));

        ws.complete_task(TaskId(1));
        assert_eq!(ws.pending_tasks.len(), 1);
        assert!(!ws.pending_tasks.contains(&TaskId(1)));
        assert!(ws.pending_tasks.contains(&TaskId(2)));
    }

    #[test]
    fn test_action_history() {
        let mut history = ActionHistory::new(3);
        history.push(history_entry(0, "action1", true));
        history.push(history_entry(1, "action2", false));
        assert_eq!(history.len(), 2);
        assert_eq!(history.latest().unwrap().action_name, "action2");

        history.push(history_entry(2, "action3", true));
        history.push(history_entry(3, "action4", true));
        assert_eq!(history.len(), 3);
        // The oldest entry ("action1") was evicted.
        assert_eq!(history.iter().next().unwrap().action_name, "action2");
    }

    #[test]
    fn test_local_cache() {
        let mut cache = LocalCache::default();
        cache.set("key1", vec![1, 2, 3], 10);
        cache.set("key2", vec![4, 5, 6], 5);

        let first: &[u8] = &[1, 2, 3];
        let second: &[u8] = &[4, 5, 6];
        assert_eq!(cache.get("key1", 0), Some(first));
        assert_eq!(cache.get("key2", 4), Some(second));
        assert!(cache.get("key2", 5).is_none());
        assert!(cache.get("key2", 10).is_none());
        assert_eq!(cache.get("key1", 9), Some(first));

        cache.cleanup(6);
        assert!(cache.get("key1", 0).is_some());
        cache.cleanup(11);
        assert!(cache.get("key1", 0).is_none());
    }

    #[test]
    fn test_environment() {
        let mut env = Environment::default();
        env.variables
            .insert("PATH".to_string(), "/usr/bin".to_string());
        env.flags.insert("debug".to_string(), true);
        assert_eq!(env.variables.get("PATH"), Some(&"/usr/bin".to_string()));
        assert_eq!(env.flags.get("debug"), Some(&true));
    }
}