Skip to main content

roboticus_agent/
subagents.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use tokio::sync::{Mutex, Semaphore};
7use tracing::{debug, warn};
8
9use roboticus_core::{Result, RoboticusError};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum AgentRunState {
13    Idle,
14    Starting,
15    Running,
16    Stopped,
17    Error,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct AgentInstance {
22    pub id: String,
23    pub name: String,
24    pub model: String,
25    pub state: AgentRunState,
26    pub session_count: usize,
27    pub started_at: Option<DateTime<Utc>>,
28    pub last_error: Option<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct AgentInstanceConfig {
33    pub id: String,
34    pub name: String,
35    pub model: String,
36    #[serde(default)]
37    pub skills: Vec<String>,
38    #[serde(default)]
39    pub allowed_subagents: Vec<String>,
40    #[serde(default = "default_max_concurrent")]
41    pub max_concurrent: usize,
42}
43
/// Serde default for `AgentInstanceConfig::max_concurrent` when the field is omitted.
fn default_max_concurrent() -> usize {
    const DEFAULT_MAX_CONCURRENT: usize = 4;
    DEFAULT_MAX_CONCURRENT
}
47
48pub struct SubagentRegistry {
49    agents: Mutex<HashMap<String, AgentInstance>>,
50    concurrency: Arc<Semaphore>,
51    max_concurrent: usize,
52    allowed_ids: Vec<String>,
53}
54
55impl SubagentRegistry {
56    pub fn new(max_concurrent: usize, allowed_ids: Vec<String>) -> Self {
57        Self {
58            agents: Mutex::new(HashMap::new()),
59            concurrency: Arc::new(Semaphore::new(max_concurrent)),
60            max_concurrent,
61            allowed_ids,
62        }
63    }
64
65    pub fn is_allowed(&self, agent_id: &str) -> bool {
66        self.allowed_ids.is_empty() || self.allowed_ids.iter().any(|id| id == agent_id)
67    }
68
69    pub async fn register(&self, config: AgentInstanceConfig) -> Result<()> {
70        if !self.is_allowed(&config.id) {
71            return Err(RoboticusError::Config(format!(
72                "agent '{}' is not in the allowed list",
73                config.id
74            )));
75        }
76
77        let instance = AgentInstance {
78            id: config.id.clone(),
79            name: config.name,
80            model: config.model,
81            state: AgentRunState::Idle,
82            session_count: 0,
83            started_at: None,
84            last_error: None,
85        };
86
87        debug!(id = %config.id, "registered agent");
88        let mut agents = self.agents.lock().await;
89        agents.insert(config.id, instance);
90        Ok(())
91    }
92
93    pub async fn start_agent(&self, agent_id: &str) -> Result<()> {
94        let mut agents = self.agents.lock().await;
95        let agent = agents
96            .get_mut(agent_id)
97            .ok_or_else(|| RoboticusError::Config(format!("agent '{agent_id}' not found")))?;
98
99        if matches!(
100            agent.state,
101            AgentRunState::Running | AgentRunState::Starting
102        ) {
103            return Ok(());
104        }
105
106        agent.state = AgentRunState::Running;
107        agent.started_at = Some(Utc::now());
108        agent.last_error = None;
109        debug!(id = agent_id, "agent started");
110        Ok(())
111    }
112
113    pub async fn stop_agent(&self, agent_id: &str) -> Result<()> {
114        let mut agents = self.agents.lock().await;
115        let agent = agents
116            .get_mut(agent_id)
117            .ok_or_else(|| RoboticusError::Config(format!("agent '{agent_id}' not found")))?;
118
119        agent.state = AgentRunState::Stopped;
120        debug!(id = agent_id, "agent stopped");
121        Ok(())
122    }
123
124    pub async fn unregister(&self, agent_id: &str) -> bool {
125        let mut agents = self.agents.lock().await;
126        let removed = agents.remove(agent_id).is_some();
127        if removed {
128            debug!(id = agent_id, "agent unregistered");
129        }
130        removed
131    }
132
133    pub async fn mark_error(&self, agent_id: &str, error: String) {
134        let mut agents = self.agents.lock().await;
135        if let Some(agent) = agents.get_mut(agent_id) {
136            agent.state = AgentRunState::Error;
137            agent.last_error = Some(error);
138            warn!(id = agent_id, "agent errored");
139        }
140    }
141
142    pub async fn get_agent(&self, agent_id: &str) -> Option<AgentInstance> {
143        let agents = self.agents.lock().await;
144        agents.get(agent_id).cloned()
145    }
146
147    pub async fn list_agents(&self) -> Vec<AgentInstance> {
148        let agents = self.agents.lock().await;
149        agents.values().cloned().collect()
150    }
151
152    pub async fn running_count(&self) -> usize {
153        let agents = self.agents.lock().await;
154        agents
155            .values()
156            .filter(|a| a.state == AgentRunState::Running)
157            .count()
158    }
159
160    pub async fn agent_count(&self) -> usize {
161        let agents = self.agents.lock().await;
162        agents.len()
163    }
164
165    pub async fn acquire_slot(&self) -> Result<tokio::sync::OwnedSemaphorePermit> {
166        Arc::clone(&self.concurrency)
167            .acquire_owned()
168            .await
169            .map_err(|_| RoboticusError::Config("concurrency semaphore closed".into()))
170    }
171
172    pub fn max_concurrent(&self) -> usize {
173        self.max_concurrent
174    }
175
176    pub fn available_slots(&self) -> usize {
177        self.concurrency.available_permits()
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully-specified config for `id` (no serde defaults involved).
    fn test_config(id: &str) -> AgentInstanceConfig {
        AgentInstanceConfig {
            id: id.into(),
            name: format!("Agent {id}"),
            model: "test-model".into(),
            skills: vec![],
            allowed_subagents: vec![],
            max_concurrent: 4,
        }
    }

    #[test]
    fn allowed_empty_means_all() {
        let reg = SubagentRegistry::new(4, vec![]);
        assert!(reg.is_allowed("anything"));
    }

    #[test]
    fn allowed_filters() {
        let reg = SubagentRegistry::new(4, vec!["a".into(), "b".into()]);
        assert!(reg.is_allowed("a"));
        assert!(!reg.is_allowed("c"));
    }

    #[tokio::test]
    async fn register_and_list() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("agent-1")).await.unwrap();
        assert_eq!(reg.agent_count().await, 1);
        let agents = reg.list_agents().await;
        assert_eq!(agents[0].id, "agent-1");
        assert_eq!(agents[0].state, AgentRunState::Idle);
    }

    #[tokio::test]
    async fn register_disallowed_fails() {
        let reg = SubagentRegistry::new(4, vec!["allowed".into()]);
        let result = reg.register(test_config("not-allowed")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn reregister_resets_runtime_state() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();
        reg.start_agent("a").await.unwrap();
        // Registering the same id again replaces the entry rather than duplicating it.
        reg.register(test_config("a")).await.unwrap();
        assert_eq!(reg.agent_count().await, 1);
        let agent = reg.get_agent("a").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Idle);
        assert!(agent.started_at.is_none());
    }

    #[tokio::test]
    async fn start_and_stop() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();

        reg.start_agent("a").await.unwrap();
        let agent = reg.get_agent("a").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Running);
        assert!(agent.started_at.is_some());

        reg.stop_agent("a").await.unwrap();
        let agent = reg.get_agent("a").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Stopped);
    }

    #[tokio::test]
    async fn start_nonexistent_fails() {
        let reg = SubagentRegistry::new(4, vec![]);
        let result = reg.start_agent("nope").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn stop_nonexistent_fails() {
        let reg = SubagentRegistry::new(4, vec![]);
        assert!(reg.stop_agent("nope").await.is_err());
    }

    #[tokio::test]
    async fn start_already_running_is_idempotent() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("dup")).await.unwrap();
        reg.start_agent("dup").await.unwrap();
        // Second start should succeed silently, not error or leave stuck in Starting
        reg.start_agent("dup").await.unwrap();
        let agent = reg.get_agent("dup").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Running);
    }

    #[tokio::test]
    async fn concurrent_starts_all_reach_running() {
        let reg = std::sync::Arc::new(SubagentRegistry::new(4, vec![]));
        for i in 0..5 {
            reg.register(test_config(&format!("agent-{i}")))
                .await
                .unwrap();
        }
        // Start all agents concurrently
        let mut handles = vec![];
        for i in 0..5 {
            let reg = reg.clone();
            handles.push(tokio::spawn(async move {
                reg.start_agent(&format!("agent-{i}")).await.unwrap();
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        // All agents should be Running, none stuck in Starting
        for agent in reg.list_agents().await {
            assert_eq!(
                agent.state,
                AgentRunState::Running,
                "agent {} stuck in {:?}",
                agent.id,
                agent.state
            );
        }
    }

    #[tokio::test]
    async fn mark_error() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("e")).await.unwrap();
        reg.start_agent("e").await.unwrap();
        reg.mark_error("e", "something broke".into()).await;
        let agent = reg.get_agent("e").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Error);
        assert_eq!(agent.last_error.as_deref(), Some("something broke"));
    }

    #[tokio::test]
    async fn mark_error_unknown_agent_is_noop() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.mark_error("ghost", "ignored".into()).await;
        assert_eq!(reg.agent_count().await, 0);
        assert!(reg.get_agent("ghost").await.is_none());
    }

    #[tokio::test]
    async fn restart_after_error_clears_last_error() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("r")).await.unwrap();
        reg.mark_error("r", "boom".into()).await;
        reg.start_agent("r").await.unwrap();
        let agent = reg.get_agent("r").await.unwrap();
        assert_eq!(agent.state, AgentRunState::Running);
        assert!(agent.last_error.is_none());
    }

    #[tokio::test]
    async fn running_count() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();
        reg.register(test_config("b")).await.unwrap();
        reg.start_agent("a").await.unwrap();
        assert_eq!(reg.running_count().await, 1);
    }

    #[tokio::test]
    async fn unregister_removes_agent() {
        let reg = SubagentRegistry::new(4, vec![]);
        reg.register(test_config("a")).await.unwrap();
        assert_eq!(reg.agent_count().await, 1);
        assert!(reg.unregister("a").await);
        assert_eq!(reg.agent_count().await, 0);
        assert!(!reg.unregister("a").await);
    }

    #[tokio::test]
    async fn concurrency_slots() {
        let reg = SubagentRegistry::new(2, vec![]);
        assert_eq!(reg.available_slots(), 2);
        assert_eq!(reg.max_concurrent(), 2);
        let _permit = reg.acquire_slot().await.unwrap();
        assert_eq!(reg.available_slots(), 1);
    }

    #[tokio::test]
    async fn concurrency_slot_released_on_drop() {
        let reg = SubagentRegistry::new(2, vec![]);
        {
            let _permit = reg.acquire_slot().await.unwrap();
            assert_eq!(reg.available_slots(), 1);
        }
        assert_eq!(reg.available_slots(), 2);
    }

    #[test]
    fn agent_instance_config_defaults() {
        // Only the required fields are supplied, so the `#[serde(default ...)]` attributes
        // must fill in the rest. (Building the struct via `test_config` would set every field
        // explicitly and not exercise the defaults at all.)
        let cfg: AgentInstanceConfig =
            serde_json::from_str(r#"{"id":"test","name":"Agent test","model":"test-model"}"#)
                .unwrap();
        assert_eq!(cfg.id, "test");
        assert_eq!(cfg.max_concurrent, 4);
        assert_eq!(cfg.max_concurrent, default_max_concurrent());
        assert!(cfg.skills.is_empty());
        assert!(cfg.allowed_subagents.is_empty());
    }

    #[test]
    fn agent_run_state_serde() {
        for state in [
            AgentRunState::Idle,
            AgentRunState::Starting,
            AgentRunState::Running,
            AgentRunState::Stopped,
            AgentRunState::Error,
        ] {
            let json = serde_json::to_string(&state).unwrap();
            let back: AgentRunState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }
}