1use std::collections::HashMap;
2use std::sync::Arc;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::sync::{Mutex, Semaphore};
7use tracing::{debug, warn};
8
9use roboticus_core::{Result, RoboticusError};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum AgentRunState {
13 Idle,
14 Starting,
15 Running,
16 Stopped,
17 Error,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct AgentInstance {
22 pub id: String,
23 pub name: String,
24 pub model: String,
25 pub state: AgentRunState,
26 pub session_count: usize,
27 pub started_at: Option<DateTime<Utc>>,
28 pub last_error: Option<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct AgentInstanceConfig {
33 pub id: String,
34 pub name: String,
35 pub model: String,
36 #[serde(default)]
37 pub skills: Vec<String>,
38 #[serde(default)]
39 pub allowed_subagents: Vec<String>,
40 #[serde(default = "default_max_concurrent")]
41 pub max_concurrent: usize,
42}
43
/// Concurrency limit applied when a deserialized `AgentInstanceConfig`
/// omits `max_concurrent`.
const DEFAULT_MAX_CONCURRENT: usize = 4;

/// Serde `default` hook for `AgentInstanceConfig::max_concurrent`.
fn default_max_concurrent() -> usize {
    DEFAULT_MAX_CONCURRENT
}
47
48pub struct SubagentRegistry {
49 agents: Mutex<HashMap<String, AgentInstance>>,
50 concurrency: Arc<Semaphore>,
51 max_concurrent: usize,
52 allowed_ids: Vec<String>,
53}
54
55impl SubagentRegistry {
56 pub fn new(max_concurrent: usize, allowed_ids: Vec<String>) -> Self {
57 Self {
58 agents: Mutex::new(HashMap::new()),
59 concurrency: Arc::new(Semaphore::new(max_concurrent)),
60 max_concurrent,
61 allowed_ids,
62 }
63 }
64
65 pub fn is_allowed(&self, agent_id: &str) -> bool {
66 self.allowed_ids.is_empty() || self.allowed_ids.iter().any(|id| id == agent_id)
67 }
68
69 pub async fn register(&self, config: AgentInstanceConfig) -> Result<()> {
70 if !self.is_allowed(&config.id) {
71 return Err(RoboticusError::Config(format!(
72 "agent '{}' is not in the allowed list",
73 config.id
74 )));
75 }
76
77 let instance = AgentInstance {
78 id: config.id.clone(),
79 name: config.name,
80 model: config.model,
81 state: AgentRunState::Idle,
82 session_count: 0,
83 started_at: None,
84 last_error: None,
85 };
86
87 debug!(id = %config.id, "registered agent");
88 let mut agents = self.agents.lock().await;
89 agents.insert(config.id, instance);
90 Ok(())
91 }
92
93 pub async fn start_agent(&self, agent_id: &str) -> Result<()> {
94 let mut agents = self.agents.lock().await;
95 let agent = agents
96 .get_mut(agent_id)
97 .ok_or_else(|| RoboticusError::Config(format!("agent '{agent_id}' not found")))?;
98
99 if matches!(
100 agent.state,
101 AgentRunState::Running | AgentRunState::Starting
102 ) {
103 return Ok(());
104 }
105
106 agent.state = AgentRunState::Running;
107 agent.started_at = Some(Utc::now());
108 agent.last_error = None;
109 debug!(id = agent_id, "agent started");
110 Ok(())
111 }
112
113 pub async fn stop_agent(&self, agent_id: &str) -> Result<()> {
114 let mut agents = self.agents.lock().await;
115 let agent = agents
116 .get_mut(agent_id)
117 .ok_or_else(|| RoboticusError::Config(format!("agent '{agent_id}' not found")))?;
118
119 agent.state = AgentRunState::Stopped;
120 debug!(id = agent_id, "agent stopped");
121 Ok(())
122 }
123
124 pub async fn unregister(&self, agent_id: &str) -> bool {
125 let mut agents = self.agents.lock().await;
126 let removed = agents.remove(agent_id).is_some();
127 if removed {
128 debug!(id = agent_id, "agent unregistered");
129 }
130 removed
131 }
132
133 pub async fn mark_error(&self, agent_id: &str, error: String) {
134 let mut agents = self.agents.lock().await;
135 if let Some(agent) = agents.get_mut(agent_id) {
136 agent.state = AgentRunState::Error;
137 agent.last_error = Some(error);
138 warn!(id = agent_id, "agent errored");
139 }
140 }
141
142 pub async fn get_agent(&self, agent_id: &str) -> Option<AgentInstance> {
143 let agents = self.agents.lock().await;
144 agents.get(agent_id).cloned()
145 }
146
147 pub async fn list_agents(&self) -> Vec<AgentInstance> {
148 let agents = self.agents.lock().await;
149 agents.values().cloned().collect()
150 }
151
152 pub async fn running_count(&self) -> usize {
153 let agents = self.agents.lock().await;
154 agents
155 .values()
156 .filter(|a| a.state == AgentRunState::Running)
157 .count()
158 }
159
160 pub async fn agent_count(&self) -> usize {
161 let agents = self.agents.lock().await;
162 agents.len()
163 }
164
165 pub async fn acquire_slot(&self) -> Result<tokio::sync::OwnedSemaphorePermit> {
166 Arc::clone(&self.concurrency)
167 .acquire_owned()
168 .await
169 .map_err(|_| RoboticusError::Config("concurrency semaphore closed".into()))
170 }
171
172 pub fn max_concurrent(&self) -> usize {
173 self.max_concurrent
174 }
175
176 pub fn available_slots(&self) -> usize {
177 self.concurrency.available_permits()
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully-specified config for `id` (no serde defaults involved).
    fn test_config(id: &str) -> AgentInstanceConfig {
        AgentInstanceConfig {
            id: id.into(),
            name: format!("Agent {id}"),
            model: "test-model".into(),
            skills: vec![],
            allowed_subagents: vec![],
            max_concurrent: 4,
        }
    }

    #[test]
    fn allowed_empty_means_all() {
        let reg = SubagentRegistry::new(4, vec![]);
        assert!(reg.is_allowed("anything"));
    }

    #[test]
    fn allowed_filters() {
        let reg = SubagentRegistry::new(4, vec!["a".into(), "b".into()]);
        assert!(reg.is_allowed("a"));
        assert!(!reg.is_allowed("c"));
    }

    #[tokio::test]
    async fn register_and_list() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("agent-1")).await.unwrap();
        assert_eq!(reg.agent_count().await, 1);
        let agents = reg.list_agents().await;
        assert_eq!(agents[0].id, "agent-1");
        assert_eq!(agents[0].state, AgentRunState::Idle);
    }

    #[tokio::test]
    async fn register_disallowed_fails() {
        let reg = SubagentRegistry::new(4, vec!["allowed".into()]);
        let result = reg.register(test_config("not-allowed")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn start_and_stop() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();

        reg.start_agent("a").await.unwrap();
        let agent = reg.get_agent("a").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Running);
        assert!(agent.started_at.is_some());

        reg.stop_agent("a").await.unwrap();
        let agent = reg.get_agent("a").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Stopped);
    }

    #[tokio::test]
    async fn start_nonexistent_fails() {
        let reg = SubagentRegistry::new(4, vec![]);
        let result = reg.start_agent("nope").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn start_already_running_is_idempotent() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("dup")).await.unwrap();
        reg.start_agent("dup").await.unwrap();
        let first_started_at = reg.get_agent("dup").await.unwrap().started_at;

        reg.start_agent("dup").await.unwrap();
        let agent = reg.get_agent("dup").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Running);
        // Idempotence: the second start must not re-stamp the start time.
        assert_eq!(agent.started_at, first_started_at);
    }

    #[tokio::test]
    async fn concurrent_starts_all_reach_running() {
        let reg = std::sync::Arc::new(SubagentRegistry::new(4, vec![]));
        for i in 0..5 {
            reg.register(test_config(&format!("agent-{i}")))
                .await
                .unwrap();
        }
        let mut handles = vec![];
        for i in 0..5 {
            let reg = reg.clone();
            handles.push(tokio::spawn(async move {
                reg.start_agent(&format!("agent-{i}")).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        for agent in reg.list_agents().await {
            assert_eq!(
                agent.state,
                AgentRunState::Running,
                "agent {} stuck in {:?}",
                agent.id,
                agent.state
            );
        }
    }

    #[tokio::test]
    async fn mark_error() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("e")).await.unwrap();
        reg.start_agent("e").await.unwrap();
        reg.mark_error("e", "something broke".into()).await;
        let agent = reg.get_agent("e").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Error);
        assert_eq!(agent.last_error.as_deref(), Some("something broke"));
    }

    #[tokio::test]
    async fn running_count() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();
        reg.register(test_config("b")).await.unwrap();
        reg.start_agent("a").await.unwrap();
        assert_eq!(reg.running_count().await, 1);
    }

    #[tokio::test]
    async fn unregister_removes_agent() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();
        assert_eq!(reg.agent_count().await, 1);
        assert!(reg.unregister("a").await);
        assert_eq!(reg.agent_count().await, 0);
        assert!(!reg.unregister("a").await);
    }

    #[tokio::test]
    async fn concurrency_slots() {
        let reg = SubagentRegistry::new(2, vec![]);
        assert_eq!(reg.available_slots(), 2);
        assert_eq!(reg.max_concurrent(), 2);
        let _permit = reg.acquire_slot().await.unwrap();
        assert_eq!(reg.available_slots(), 1);
    }

    #[test]
    fn agent_instance_config_defaults() {
        // Supply only the required fields so the `#[serde(default ...)]`
        // attributes are what populate the rest. (Building the struct via
        // `test_config` would set every field explicitly and test nothing.)
        let cfg: AgentInstanceConfig =
            serde_json::from_str(r#"{"id":"test","name":"Test","model":"test-model"}"#)
                .unwrap();
        assert_eq!(cfg.max_concurrent, default_max_concurrent());
        assert_eq!(cfg.max_concurrent, 4);
        assert!(cfg.skills.is_empty());
        assert!(cfg.allowed_subagents.is_empty());
    }

    #[test]
    fn agent_run_state_serde() {
        for state in [
            AgentRunState::Idle,
            AgentRunState::Starting,
            AgentRunState::Running,
            AgentRunState::Stopped,
            AgentRunState::Error,
        ] {
            let json = serde_json::to_string(&state).unwrap();
            let back: AgentRunState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }
}