use crate::error::{SdkError, SdkResult};
use crate::lifecycle::snapshot::SnapshotStore;
use crate::lifecycle::{AgentLifecycleEvent, AgentSnapshot, AgentStatus, MetricsSnapshot};
use crate::routing::RoutingControl;
use oxi_agent::{AgentConfig, AgentTool, ProviderResolver, ToolRegistry};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use tokio::sync::broadcast;
// Compact `u8` encoding of `AgentStatus`. It is kept in an `AtomicU8` so the status can be
// read and swapped without taking a lock; `u8_to_status` is the inverse mapping.

/// Encoded `AgentStatus::Created` (also the state a handle returns to after a successful run).
const STATUS_CREATED: u8 = 0;
/// Encoded `AgentStatus::Running`.
const STATUS_RUNNING: u8 = 1;
/// Encoded `AgentStatus::Suspended`.
const STATUS_SUSPENDED: u8 = 2;
/// Encoded `AgentStatus::Terminated`.
const STATUS_TERMINATED: u8 = 3;
/// Encoded `AgentStatus::Failed`.
const STATUS_FAILED: u8 = 4;

/// Decodes a `STATUS_*` byte back into an [`AgentStatus`].
///
/// Any byte that is not one of the known encodings decodes to `AgentStatus::Failed`
/// instead of panicking.
fn u8_to_status(v: u8) -> AgentStatus {
    match v {
        STATUS_RUNNING => AgentStatus::Running,
        STATUS_CREATED => AgentStatus::Created,
        STATUS_SUSPENDED => AgentStatus::Suspended,
        STATUS_TERMINATED => AgentStatus::Terminated,
        STATUS_FAILED => AgentStatus::Failed,
        _unknown => AgentStatus::Failed,
    }
}
/// Restart policy applied by `AgentSupervisor::restart`.
///
/// At most `max_restarts` restarts are allowed per agent within any rolling window of
/// `restart_window_secs` seconds. Restarts after the first one inside the window are
/// delayed according to `backoff`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SupervisorPolicy {
    /// Maximum restarts permitted inside the rolling window; `0` disables restarts entirely.
    pub max_restarts: usize,
    /// Length of the rolling window, in seconds, over which restarts are counted.
    pub restart_window_secs: u64,
    /// Delay strategy applied before a restart.
    pub backoff: RestartBackoff,
}

impl Default for SupervisorPolicy {
    /// Three restarts per 60 s window, with exponential backoff from 1 s capped at 30 s.
    fn default() -> Self {
        Self {
            max_restarts: 3,
            restart_window_secs: 60,
            backoff: RestartBackoff::Exponential {
                base_ms: 1000,
                max_ms: 30_000,
            },
        }
    }
}

impl SupervisorPolicy {
    /// A policy that never allows an agent to be restarted.
    pub fn no_restart() -> Self {
        Self {
            max_restarts: 0,
            restart_window_secs: 0,
            backoff: RestartBackoff::None,
        }
    }
}

/// Delay strategy used between restarts of the same agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RestartBackoff {
    /// Restart immediately.
    None,
    /// Wait a constant `delay_ms` milliseconds.
    Fixed { delay_ms: u64 },
    /// Wait `base_ms * 2^n` milliseconds (n = recent restarts), capped at `max_ms`.
    Exponential { base_ms: u64, max_ms: u64 },
}
/// Cheaply clonable handle to a single supervised agent.
///
/// The underlying agent, the status cell, the config and the metrics are all held behind
/// `Arc`s. Every clone therefore observes the same state: a status change made through one
/// clone is visible through all others.
#[derive(Clone)]
pub struct AgentHandle {
    // Identity.
    /// Config name, or a generated UUID when the config name was empty.
    agent_id: String,
    /// Id of the agent this one was spawned under, if any.
    parent_id: Option<String>,
    /// Creation time in milliseconds, taken from `AgentLifecycleEvent::now_ms`.
    created_at_ms: u64,

    // Runtime state.
    /// One of the `STATUS_*` encodings.
    status: Arc<AtomicU8>,
    agent: Arc<oxi_agent::Agent>,
    config: Arc<RwLock<AgentConfig>>,
    metrics: Arc<crate::metrics::AgentMetrics>,
    routing: RoutingControl,

    // Outbound events.
    lifecycle_tx: broadcast::Sender<AgentLifecycleEvent>,
}
impl AgentHandle {
    /// Wraps `agent` in a new handle in the `Created` state.
    ///
    /// The handle id is `config.name` when it is non-empty, otherwise a random UUID v4.
    /// Routing starts from `RoutingConfig::default()`. Lifecycle events are published on
    /// `lifecycle_tx`.
    pub(crate) fn new(
        agent: oxi_agent::Agent,
        config: AgentConfig,
        parent_id: Option<String>,
        lifecycle_tx: broadcast::Sender<AgentLifecycleEvent>,
    ) -> Self {
        let routing = RoutingControl::new(crate::routing::RoutingConfig::default());
        Self {
            agent_id: if config.name.is_empty() {
                uuid::Uuid::new_v4().to_string()
            } else {
                config.name.clone()
            },
            status: Arc::new(AtomicU8::new(STATUS_CREATED)),
            agent: Arc::new(agent),
            config: Arc::new(RwLock::new(config)),
            metrics: Arc::new(crate::metrics::AgentMetrics::new()),
            lifecycle_tx,
            created_at_ms: AgentLifecycleEvent::now_ms(),
            parent_id,
            routing,
        }
    }

    /// Returns this agent's id.
    pub fn agent_id(&self) -> &str {
        &self.agent_id
    }

    /// Returns the id of the parent agent, if this agent was spawned as a child.
    pub fn parent_id(&self) -> Option<&str> {
        self.parent_id.as_deref()
    }

    /// Returns the creation timestamp in milliseconds (from `AgentLifecycleEvent::now_ms`).
    pub fn created_at_ms(&self) -> u64 {
        self.created_at_ms
    }

    /// Returns a point-in-time copy of this agent's run metrics.
    pub fn metrics(&self) -> MetricsSnapshot {
        self.metrics.snapshot()
    }

    /// Returns `true` while the status is `Running`.
    pub fn is_running(&self) -> bool {
        self.status() == AgentStatus::Running
    }

    /// Returns the current lifecycle status.
    pub fn status(&self) -> AgentStatus {
        u8_to_status(self.status.load(Ordering::SeqCst))
    }

    /// Runs the agent on `prompt` and returns its response together with the emitted events.
    ///
    /// Only a `Created` or `Suspended` agent may start a run. The switch to `Running` is a
    /// compare-and-swap, so two concurrent callers cannot both start a run on the same handle.
    ///
    /// When the run finishes, the status becomes `Created` (success) or `Failed` (error), but
    /// only if it is still `Running`. A `suspend` or `terminate` that happened while the run
    /// was in flight is therefore not overwritten.
    ///
    /// # Errors
    /// - `SdkError::AgentNotRunnable` if the agent is not in a runnable state.
    /// - `SdkError::ExecutionFailed` if the underlying agent run fails.
    pub async fn run(
        &self,
        prompt: String,
    ) -> SdkResult<(oxi_agent::types::Response, Vec<oxi_agent::AgentEvent>)> {
        let claimed = self.transition_from(STATUS_CREATED, STATUS_RUNNING)
            || self.transition_from(STATUS_SUSPENDED, STATUS_RUNNING);
        if !claimed {
            return Err(SdkError::AgentNotRunnable {
                agent_id: self.agent_id.clone(),
                status: self.status().to_string(),
            });
        }
        self.emit(AgentLifecycleEvent::RunStart {
            agent_id: self.agent_id.clone(),
            timestamp_ms: AgentLifecycleEvent::now_ms(),
        });
        let start = std::time::Instant::now();
        let result = self.agent.run(prompt).await;
        let elapsed = start.elapsed();
        match result {
            Ok((response, events)) => {
                let agent_state = self.agent.state();
                let input_tokens = agent_state.input_tokens as u64;
                let output_tokens = agent_state.output_tokens as u64;
                let tool_count = events
                    .iter()
                    .filter(|e| matches!(e, oxi_agent::AgentEvent::ToolExecutionStart { .. }))
                    .count() as u64;
                self.metrics.record_success(
                    elapsed.as_millis() as u64,
                    input_tokens,
                    output_tokens,
                    tool_count,
                );
                // Conditional: a concurrent suspend/terminate wins over "back to Created".
                self.transition_from(STATUS_RUNNING, STATUS_CREATED);
                self.emit(AgentLifecycleEvent::RunEnd {
                    agent_id: self.agent_id.clone(),
                    timestamp_ms: AgentLifecycleEvent::now_ms(),
                    success: true,
                });
                Ok((response, events))
            }
            Err(e) => {
                // A run aborted by `suspend` (via `cancel`) must stay Suspended, not Failed.
                self.transition_from(STATUS_RUNNING, STATUS_FAILED);
                self.emit(AgentLifecycleEvent::RunEnd {
                    agent_id: self.agent_id.clone(),
                    timestamp_ms: AgentLifecycleEvent::now_ms(),
                    success: false,
                });
                Err(SdkError::ExecutionFailed {
                    reason: e.to_string(),
                })
            }
        }
    }

    /// Currently identical to [`Self::run`].
    pub async fn continue_with(
        &self,
        prompt: String,
    ) -> SdkResult<(oxi_agent::types::Response, Vec<oxi_agent::AgentEvent>)> {
        self.run(prompt).await
    }

    /// Forwards a cancellation request to the underlying agent.
    pub fn cancel(&self) {
        self.agent.cancel();
    }

    /// Captures a snapshot, marks the agent `Suspended` and emits a `Suspended` event.
    ///
    /// Suspension is allowed when the status is runnable (per `AgentStatus::is_runnable`)
    /// or `Running`. A running agent is first cancelled and given a 50 ms grace period
    /// before the snapshot is taken.
    ///
    /// # Errors
    /// `SdkError::AgentNotRunnable` if the agent is in any other state.
    pub async fn suspend(&self) -> SdkResult<AgentSnapshot> {
        let cur = self.status();
        if !cur.is_runnable() && cur != AgentStatus::Running {
            return Err(SdkError::AgentNotRunnable {
                agent_id: self.agent_id.clone(),
                status: cur.to_string(),
            });
        }
        if cur == AgentStatus::Running {
            self.cancel();
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        }
        let snapshot = self.capture_snapshot();
        self.transition(STATUS_SUSPENDED);
        self.emit(AgentLifecycleEvent::Suspended {
            agent_id: self.agent_id.clone(),
            snapshot: Box::new(snapshot.clone()),
            timestamp_ms: AgentLifecycleEvent::now_ms(),
        });
        Ok(snapshot)
    }

    /// Marks the agent `Terminated` and emits a `Terminated` event.
    ///
    /// This does not cancel an in-flight run. `AgentSupervisor::terminate` refuses running
    /// agents before calling it.
    ///
    /// # Errors
    /// `SdkError::AgentNotRunnable` if the status is already terminal.
    pub fn terminate(&self) -> SdkResult<()> {
        if self.status().is_terminal() {
            return Err(SdkError::AgentNotRunnable {
                agent_id: self.agent_id.clone(),
                status: self.status().to_string(),
            });
        }
        self.transition(STATUS_TERMINATED);
        self.emit(AgentLifecycleEvent::Terminated {
            agent_id: self.agent_id.clone(),
            timestamp_ms: AgentLifecycleEvent::now_ms(),
        });
        Ok(())
    }

    /// Captures a snapshot of the agent without changing its status.
    pub fn snapshot(&self) -> SdkResult<AgentSnapshot> {
        Ok(self.capture_snapshot())
    }

    /// Switches the underlying agent to `model_id` and records it in the handle's config.
    ///
    /// The config is only updated if the agent accepted the switch. On success a
    /// `ModelSwitched` event is emitted.
    pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> anyhow::Result<()> {
        let old = self.config.read().model_id.clone();
        self.agent.switch_model(model_id, api_key)?;
        self.config.write().model_id = model_id.to_string();
        self.emit(AgentLifecycleEvent::ModelSwitched {
            agent_id: self.agent_id.clone(),
            from_model: old,
            to_model: model_id.to_string(),
            timestamp_ms: AgentLifecycleEvent::now_ms(),
        });
        Ok(())
    }

    /// Replaces the system prompt in both the stored config and the running agent.
    pub fn set_system_prompt(&self, prompt: String) {
        self.config.write().system_prompt = Some(prompt.clone());
        self.agent.set_system_prompt(prompt);
    }

    /// Registers an additional tool with the underlying agent.
    pub fn add_tool(&self, tool: impl AgentTool + 'static) {
        self.agent.add_tool(tool);
    }

    /// Returns the routing controls for this agent.
    pub fn routing(&self) -> &RoutingControl {
        &self.routing
    }

    /// Disables model routing for this agent.
    pub fn disable_routing(&self) {
        self.routing.set_enabled(false);
    }

    /// Enables model routing for this agent.
    pub fn enable_routing(&self) {
        self.routing.set_enabled(true);
    }

    /// Excludes `model_id` from routing decisions.
    pub fn exclude_route_model(&self, model_id: &str) {
        self.routing.exclude_model(model_id);
    }

    /// Builds an `AgentSnapshot` from the current config, agent state and tools.
    fn capture_snapshot(&self) -> AgentSnapshot {
        AgentSnapshot::from_agent(
            self.agent_id.clone(),
            &self.config.read(),
            &self.agent.state(),
            &self.agent.tools(),
            self.parent_id.clone(),
            HashMap::new(),
        )
    }

    /// Unconditionally stores `new_status`.
    fn transition(&self, new_status: u8) {
        self.status.store(new_status, Ordering::SeqCst);
    }

    /// Atomically moves from `expected` to `new_status`.
    ///
    /// Returns `true` if the swap happened, i.e. the status was `expected`.
    fn transition_from(&self, expected: u8, new_status: u8) -> bool {
        self.status
            .compare_exchange(expected, new_status, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
    }

    /// Publishes a lifecycle event. Sending fails only when there are no subscribers,
    /// which is not an error here.
    fn emit(&self, event: AgentLifecycleEvent) {
        let _ = self.lifecycle_tx.send(event);
    }
}
/// Shows only the id and the live status; the agent, config and channels are omitted.
impl std::fmt::Debug for AgentHandle {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.status();
        let mut out = formatter.debug_struct("AgentHandle");
        out.field("agent_id", &self.agent_id);
        out.field("status", &current);
        out.finish()
    }
}
/// Registry and restart manager for a set of [`AgentHandle`]s.
///
/// Cloning is cheap. The agent registry and restart log are shared through `Arc`s, and all
/// clones send on the same lifecycle broadcast channel.
#[derive(Clone)]
pub struct AgentSupervisor {
    /// Resolves model ids and provider names whenever an agent is built.
    resolver: Arc<dyn ProviderResolver>,
    /// Persistence used by `suspend` (save) and `restore` (load).
    snapshot_store: Arc<dyn SnapshotStore>,
    /// Restart limits and backoff strategy.
    policy: SupervisorPolicy,
    /// Registered agents keyed by agent id.
    agents: Arc<RwLock<HashMap<String, AgentHandle>>>,
    /// Per-agent restart timestamps in milliseconds.
    restart_log: Arc<RwLock<HashMap<String, Vec<u64>>>>,
    /// Sender for lifecycle events; receivers come from `subscribe`.
    lifecycle_tx: broadcast::Sender<AgentLifecycleEvent>,
}
impl AgentSupervisor {
pub fn new(
resolver: Arc<dyn ProviderResolver>,
snapshot_store: Arc<dyn SnapshotStore>,
) -> Self {
Self::with_policy(resolver, snapshot_store, SupervisorPolicy::default())
}
pub fn with_policy(
resolver: Arc<dyn ProviderResolver>,
snapshot_store: Arc<dyn SnapshotStore>,
policy: SupervisorPolicy,
) -> Self {
let (tx, _) = broadcast::channel(1024);
Self {
agents: Arc::new(RwLock::new(HashMap::new())),
lifecycle_tx: tx,
snapshot_store,
policy,
restart_log: Arc::new(RwLock::new(HashMap::new())),
resolver,
}
}
pub fn subscribe(&self) -> broadcast::Receiver<AgentLifecycleEvent> {
self.lifecycle_tx.subscribe()
}
pub fn spawn(&self, config: AgentConfig) -> anyhow::Result<AgentHandle> {
let model = self
.resolver
.resolve_model(&config.model_id)
.ok_or_else(|| SdkError::ModelNotFound {
model_id: config.model_id.clone(),
})?;
let provider = self
.resolver
.resolve_provider(&model.provider)
.ok_or_else(|| SdkError::ProviderNotFound {
provider: model.provider.clone(),
})?;
let tools = Arc::new(ToolRegistry::new());
let agent = oxi_agent::Agent::new(provider, config.clone(), tools);
let handle = AgentHandle::new(agent, config.clone(), None, self.lifecycle_tx.clone());
self.agents
.write()
.insert(handle.agent_id().to_string(), handle.clone());
self.emit(AgentLifecycleEvent::Spawned {
agent_id: handle.agent_id().to_string(),
parent_id: None,
model_id: config.model_id.clone(),
timestamp_ms: AgentLifecycleEvent::now_ms(),
});
Ok(handle)
}
pub fn spawn_child(&self, parent_id: &str, config: AgentConfig) -> anyhow::Result<AgentHandle> {
let model = self
.resolver
.resolve_model(&config.model_id)
.ok_or_else(|| SdkError::ModelNotFound {
model_id: config.model_id.clone(),
})?;
let provider = self
.resolver
.resolve_provider(&model.provider)
.ok_or_else(|| SdkError::ProviderNotFound {
provider: model.provider.clone(),
})?;
let tools = Arc::new(ToolRegistry::new());
let agent = oxi_agent::Agent::new(provider, config.clone(), tools);
let handle = AgentHandle::new(
agent,
config.clone(),
Some(parent_id.to_string()),
self.lifecycle_tx.clone(),
);
self.agents
.write()
.insert(handle.agent_id().to_string(), handle.clone());
self.emit(AgentLifecycleEvent::Spawned {
agent_id: handle.agent_id().to_string(),
parent_id: Some(parent_id.to_string()),
model_id: config.model_id.clone(),
timestamp_ms: AgentLifecycleEvent::now_ms(),
});
Ok(handle)
}
pub fn get(&self, agent_id: &str) -> Option<AgentHandle> {
self.agents.read().get(agent_id).cloned()
}
pub fn list(&self) -> Vec<(String, AgentStatus)> {
self.agents
.read()
.iter()
.map(|(id, h)| (id.clone(), h.status()))
.collect()
}
pub fn count_by_status(&self) -> HashMap<AgentStatus, usize> {
let mut counts = HashMap::new();
for handle in self.agents.read().values() {
*counts.entry(handle.status()).or_insert(0) += 1;
}
counts
}
pub async fn suspend(&self, agent_id: &str) -> anyhow::Result<AgentSnapshot> {
let handle = self
.get(agent_id)
.ok_or_else(|| SdkError::SnapshotNotFound {
agent_id: agent_id.to_string(),
})?;
let snapshot = handle.suspend().await?;
self.snapshot_store.save(&snapshot).await?;
Ok(snapshot)
}
pub async fn restore(&self, agent_id: &str) -> anyhow::Result<AgentHandle> {
if let Some(handle) = self.get(agent_id) {
return Ok(handle);
}
let snapshot = self.snapshot_store.load(agent_id).await?.ok_or_else(|| {
SdkError::SnapshotNotFound {
agent_id: agent_id.to_string(),
}
})?;
self.restore_from_snapshot(snapshot).await
}
pub async fn restore_from_snapshot(
&self,
snapshot: AgentSnapshot,
) -> anyhow::Result<AgentHandle> {
let model = self
.resolver
.resolve_model(&snapshot.config.model_id)
.ok_or_else(|| SdkError::ModelNotFound {
model_id: snapshot.config.model_id.clone(),
})?;
let provider = self
.resolver
.resolve_provider(&model.provider)
.ok_or_else(|| SdkError::ProviderNotFound {
provider: model.provider.clone(),
})?;
let tools = Arc::new(ToolRegistry::new());
let agent = oxi_agent::Agent::new(provider, snapshot.config.clone(), tools);
let state_json = serde_json::to_value(&snapshot.state)?;
agent.import_state(state_json)?;
let handle = AgentHandle::new(
agent,
snapshot.config.clone(),
snapshot.parent_id.clone(),
self.lifecycle_tx.clone(),
);
self.agents
.write()
.insert(handle.agent_id().to_string(), handle.clone());
self.emit(AgentLifecycleEvent::Resumed {
agent_id: handle.agent_id().to_string(),
from_snapshot_id: Some(snapshot.agent_id.clone()),
timestamp_ms: AgentLifecycleEvent::now_ms(),
});
Ok(handle)
}
pub fn terminate(&self, agent_id: &str) -> anyhow::Result<()> {
let handle = self
.get(agent_id)
.ok_or_else(|| SdkError::SnapshotNotFound {
agent_id: agent_id.to_string(),
})?;
if handle.is_running() {
return Err(SdkError::AgentNotRunnable {
agent_id: agent_id.to_string(),
status: "running".to_string(),
}
.into());
}
handle.terminate()?;
self.agents.write().remove(agent_id);
self.restart_log.write().remove(agent_id);
Ok(())
}
pub fn can_restart(&self, agent_id: &str) -> bool {
if self.policy.max_restarts == 0 {
return false;
}
let now = AgentLifecycleEvent::now_ms();
let window_ms = self.policy.restart_window_secs * 1000;
let log = self.restart_log.read();
let restarts = log
.get(agent_id)
.map(|ts| {
ts.iter()
.filter(|&&t| now.saturating_sub(t) <= window_ms)
.count()
})
.unwrap_or(0);
restarts < self.policy.max_restarts
}
pub async fn restart(&self, agent_id: &str) -> anyhow::Result<AgentHandle> {
if !self.can_restart(agent_id) {
return Err(anyhow::anyhow!(
"Agent '{}' exceeded max restarts ({})",
agent_id,
self.policy.max_restarts
));
}
let old = self.agents.read().get(agent_id).cloned();
let config = match &old {
Some(h) => h.config.read().clone(),
None => {
return Err(SdkError::SnapshotNotFound {
agent_id: agent_id.to_string(),
}
.into())
}
};
if let Some(delay) = self.compute_backoff(agent_id) {
tokio::time::sleep(delay).await;
}
self.restart_log
.write()
.entry(agent_id.to_string())
.or_default()
.push(AgentLifecycleEvent::now_ms());
self.agents.write().remove(agent_id);
self.spawn(config)
}
fn compute_backoff(&self, agent_id: &str) -> Option<std::time::Duration> {
let count = self
.restart_log
.read()
.get(agent_id)
.map(|ts| ts.len())
.unwrap_or(0);
if count == 0 {
return None;
}
match &self.policy.backoff {
RestartBackoff::None => None,
RestartBackoff::Fixed { delay_ms } => Some(std::time::Duration::from_millis(*delay_ms)),
RestartBackoff::Exponential { base_ms, max_ms } => {
let delay = (*base_ms).saturating_mul(2u64.saturating_pow(count as u32));
Some(std::time::Duration::from_millis(delay.min(*max_ms)))
}
}
}
fn emit(&self, event: AgentLifecycleEvent) {
let _ = self.lifecycle_tx.send(event);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Resolver that resolves every model to a fixed Anthropic model and every provider
    /// to a provider whose `stream` always fails.
    fn mock_resolver() -> Arc<dyn ProviderResolver> {
        struct MockProvider;
        #[async_trait::async_trait]
        impl oxi_ai::Provider for MockProvider {
            fn name(&self) -> &str {
                "mock"
            }
            async fn stream(
                &self,
                _model: &oxi_ai::Model,
                _context: &oxi_ai::Context,
                _options: Option<oxi_ai::StreamOptions>,
            ) -> Result<
                std::pin::Pin<Box<dyn futures::Stream<Item = oxi_ai::ProviderEvent> + Send>>,
                oxi_ai::ProviderError,
            > {
                Err(oxi_ai::ProviderError::NotImplemented("mock".into()))
            }
        }
        struct Mock;
        impl ProviderResolver for Mock {
            fn resolve_provider(&self, _name: &str) -> Option<Arc<dyn oxi_ai::Provider>> {
                Some(Arc::new(MockProvider))
            }
            fn resolve_model(&self, _model_id: &str) -> Option<oxi_ai::Model> {
                Some(oxi_ai::Model::new(
                    "anthropic/claude-sonnet-4-20250514",
                    "Claude",
                    oxi_ai::Api::AnthropicMessages,
                    "anthropic",
                    "https://api.anthropic.com",
                ))
            }
        }
        Arc::new(Mock)
    }

    /// Snapshot store that stores nothing and never finds anything.
    struct NoopStore;
    #[async_trait::async_trait]
    impl SnapshotStore for NoopStore {
        async fn save(&self, _snapshot: &AgentSnapshot) -> anyhow::Result<()> {
            Ok(())
        }
        async fn load(&self, _agent_id: &str) -> anyhow::Result<Option<AgentSnapshot>> {
            Ok(None)
        }
        async fn list(&self) -> anyhow::Result<Vec<String>> {
            Ok(vec![])
        }
        async fn delete(&self, _agent_id: &str) -> anyhow::Result<()> {
            Ok(())
        }
    }

    fn make_supervisor() -> AgentSupervisor {
        AgentSupervisor::new(
            mock_resolver(),
            Arc::new(NoopStore) as Arc<dyn SnapshotStore>,
        )
    }

    /// Config with a unique name, so every spawn gets a distinct agent id.
    fn test_config() -> AgentConfig {
        AgentConfig {
            model_id: "anthropic/claude-sonnet-4-20250514".into(),
            name: uuid::Uuid::new_v4().to_string(),
            max_iterations: 10,
            ..Default::default()
        }
    }

    #[test]
    fn supervisor_policy_default() {
        let policy = SupervisorPolicy::default();
        assert_eq!(policy.max_restarts, 3);
        assert!(matches!(policy.backoff, RestartBackoff::Exponential { .. }));
    }

    #[test]
    fn supervisor_policy_no_restart() {
        let policy = SupervisorPolicy::no_restart();
        assert_eq!(policy.max_restarts, 0);
        assert!(matches!(policy.backoff, RestartBackoff::None));
    }

    #[test]
    fn supervisor_spawn_and_get() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        assert!(!handle.agent_id().is_empty());
        assert_eq!(handle.status(), AgentStatus::Created);
        assert_eq!(handle.parent_id(), None);
        assert!(supervisor.get(handle.agent_id()).is_some());
    }

    #[test]
    fn supervisor_spawn_child() {
        let supervisor = make_supervisor();
        let parent = supervisor.spawn(test_config()).unwrap();
        let child = supervisor
            .spawn_child(parent.agent_id(), test_config())
            .unwrap();
        assert_eq!(child.parent_id(), Some(parent.agent_id()));
        assert_eq!(supervisor.list().len(), 2);
    }

    #[test]
    fn supervisor_terminate() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        let id = handle.agent_id().to_string();
        supervisor.terminate(&id).unwrap();
        assert!(supervisor.get(&id).is_none());
    }

    #[test]
    fn supervisor_list_and_count() {
        let supervisor = make_supervisor();
        supervisor.spawn(test_config()).unwrap();
        supervisor.spawn(test_config()).unwrap();
        let list = supervisor.list();
        assert_eq!(list.len(), 2);
        let counts = supervisor.count_by_status();
        assert_eq!(counts.get(&AgentStatus::Created), Some(&2));
    }

    #[test]
    fn handle_status_transitions() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        handle.terminate().unwrap();
        assert_eq!(handle.status(), AgentStatus::Terminated);
        assert!(handle.status().is_terminal());
        assert!(handle.terminate().is_err());
    }

    #[test]
    fn handle_switch_model() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        let before = handle.config.read().model_id.clone();
        match handle.switch_model("openai/gpt-4o", None) {
            // On success the stored config must follow the new model...
            Ok(()) => assert_eq!(handle.config.read().model_id, "openai/gpt-4o"),
            // ...and on failure it must be left untouched.
            Err(_) => assert_eq!(handle.config.read().model_id, before),
        }
    }

    #[test]
    fn handle_set_system_prompt() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        handle.set_system_prompt("You are a test agent.".into());
        assert_eq!(
            handle.config.read().system_prompt,
            Some("You are a test agent.".into())
        );
    }

    #[test]
    fn handle_snapshot() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        let snap = handle.snapshot().unwrap();
        assert_eq!(snap.agent_id, handle.agent_id());
    }

    #[test]
    fn lifecycle_events_received() {
        let supervisor = make_supervisor();
        let mut rx = supervisor.subscribe();
        supervisor.spawn(test_config()).unwrap();
        let event = rx.try_recv().expect("should receive Spawned event");
        match event {
            AgentLifecycleEvent::Spawned { agent_id, .. } => {
                assert!(!agent_id.is_empty());
            }
            _ => panic!("Expected Spawned event"),
        }
    }

    #[test]
    fn handle_has_routing_control() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        assert!(handle.routing().is_enabled());
    }

    #[test]
    fn handle_routing_toggle() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        handle.disable_routing();
        assert!(!handle.routing().is_enabled());
        handle.enable_routing();
        assert!(handle.routing().is_enabled());
    }

    #[test]
    fn handle_routing_exclude_model() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        handle.exclude_route_model("openai/gpt-4o");
        assert!(handle
            .routing()
            .excluded_models()
            .contains(&"openai/gpt-4o".to_string()));
    }

    #[test]
    fn handle_routing_fallback_models() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        handle
            .routing()
            .set_fallback_models(vec!["anthropic/claude-sonnet-4-20250514".into()]);
        assert_eq!(handle.routing().fallback_models().len(), 1);
    }

    #[test]
    fn supervisor_can_restart_default_policy() {
        let supervisor = make_supervisor();
        let handle = supervisor.spawn(test_config()).unwrap();
        let id = handle.agent_id().to_string();
        assert!(supervisor.can_restart(&id));
    }

    #[test]
    fn supervisor_cannot_restart_no_restart_policy() {
        let supervisor = AgentSupervisor::with_policy(
            mock_resolver(),
            Arc::new(NoopStore) as Arc<dyn SnapshotStore>,
            SupervisorPolicy::no_restart(),
        );
        let handle = supervisor.spawn(test_config()).unwrap();
        let id = handle.agent_id().to_string();
        assert!(!supervisor.can_restart(&id));
    }

    #[tokio::test]
    async fn supervisor_restart_with_no_restart_policy_fails() {
        let supervisor = AgentSupervisor::with_policy(
            mock_resolver(),
            Arc::new(NoopStore) as Arc<dyn SnapshotStore>,
            SupervisorPolicy::no_restart(),
        );
        let handle = supervisor.spawn(test_config()).unwrap();
        let id = handle.agent_id().to_string();
        let result = supervisor.restart(&id).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn supervisor_restart_spawns_new_agent() {
        let policy = SupervisorPolicy {
            max_restarts: 3,
            restart_window_secs: 60,
            backoff: RestartBackoff::Fixed { delay_ms: 0 },
        };
        let supervisor = AgentSupervisor::with_policy(
            mock_resolver(),
            Arc::new(NoopStore) as Arc<dyn SnapshotStore>,
            policy,
        );
        let handle = supervisor.spawn(test_config()).unwrap();
        let old_id = handle.agent_id().to_string();
        let new_handle = supervisor.restart(&old_id).await.unwrap();
        // A named config keeps its id across restarts.
        assert_eq!(new_handle.agent_id(), old_id);
        assert!(supervisor.get(new_handle.agent_id()).is_some());
        assert_eq!(new_handle.status(), AgentStatus::Created);
        let log = supervisor.restart_log.read();
        assert_eq!(log.get(&old_id).map(Vec::len), Some(1));
    }

    #[tokio::test]
    async fn supervisor_restart_respects_max_restarts() {
        let policy = SupervisorPolicy {
            max_restarts: 1,
            restart_window_secs: 60,
            backoff: RestartBackoff::None,
        };
        let supervisor = AgentSupervisor::with_policy(
            mock_resolver(),
            Arc::new(NoopStore) as Arc<dyn SnapshotStore>,
            policy,
        );
        let handle = supervisor.spawn(test_config()).unwrap();
        let id = handle.agent_id().to_string();
        let first = supervisor.restart(&id).await.unwrap();
        let first_id = first.agent_id().to_string();
        assert!(!supervisor.can_restart(&first_id));
        let result = supervisor.restart(&first_id).await;
        assert!(result.is_err());
    }
}