ricecoder_sessions/
background_agent.rs

1//! Background agent management
2
3use crate::error::{SessionError, SessionResult};
4use crate::models::{AgentStatus, BackgroundAgent};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tokio::task::JoinHandle;
9
10/// Event emitted when a background agent completes
11#[derive(Debug, Clone)]
12pub struct AgentCompletionEvent {
13    /// ID of the completed agent
14    pub agent_id: String,
15    /// Final status of the agent
16    pub status: AgentStatus,
17    /// Optional result message
18    pub message: Option<String>,
19}
20
21/// Manages background agents running in sessions
22#[derive(Debug, Clone)]
23pub struct BackgroundAgentManager {
24    /// All background agents indexed by ID
25    agents: Arc<RwLock<HashMap<String, BackgroundAgent>>>,
26    /// Running tasks indexed by agent ID
27    tasks: Arc<RwLock<HashMap<String, JoinHandle<()>>>>,
28    /// Completion events for agents
29    completion_events: Arc<RwLock<Vec<AgentCompletionEvent>>>,
30}
31
32impl BackgroundAgentManager {
33    /// Create a new background agent manager
34    pub fn new() -> Self {
35        Self {
36            agents: Arc::new(RwLock::new(HashMap::new())),
37            tasks: Arc::new(RwLock::new(HashMap::new())),
38            completion_events: Arc::new(RwLock::new(Vec::new())),
39        }
40    }
41
42    /// Start a background agent asynchronously
43    pub async fn start_agent(&self, agent: BackgroundAgent) -> SessionResult<String> {
44        let agent_id = agent.id.clone();
45        let agent_type = agent.agent_type.clone();
46
47        // Store the agent
48        {
49            let mut agents = self.agents.write().await;
50            agents.insert(agent_id.clone(), agent.clone());
51        }
52
53        // Spawn a task to simulate agent execution
54        let agent_id_clone = agent_id.clone();
55        let agents = Arc::clone(&self.agents);
56        let completion_events = Arc::clone(&self.completion_events);
57
58        let task = tokio::spawn(async move {
59            // Simulate agent work
60            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
61
62            // Update agent status to completed
63            {
64                let mut agents_lock = agents.write().await;
65                if let Some(agent) = agents_lock.get_mut(&agent_id_clone) {
66                    agent.status = AgentStatus::Completed;
67                    agent.completed_at = Some(chrono::Utc::now());
68                }
69            }
70
71            // Emit completion event
72            {
73                let mut events = completion_events.write().await;
74                events.push(AgentCompletionEvent {
75                    agent_id: agent_id_clone.clone(),
76                    status: AgentStatus::Completed,
77                    message: Some(format!("Agent {} completed successfully", agent_type)),
78                });
79            }
80        });
81
82        // Store the task
83        {
84            let mut tasks = self.tasks.write().await;
85            tasks.insert(agent_id.clone(), task);
86        }
87
88        Ok(agent_id)
89    }
90
91    /// Get the status of a background agent
92    pub async fn get_agent_status(&self, agent_id: &str) -> SessionResult<AgentStatus> {
93        let agents = self.agents.read().await;
94        agents
95            .get(agent_id)
96            .map(|agent| agent.status)
97            .ok_or_else(|| SessionError::AgentError(format!("Agent not found: {}", agent_id)))
98    }
99
100    /// Get a background agent by ID
101    pub async fn get_agent(&self, agent_id: &str) -> SessionResult<BackgroundAgent> {
102        let agents = self.agents.read().await;
103        agents
104            .get(agent_id)
105            .cloned()
106            .ok_or_else(|| SessionError::AgentError(format!("Agent not found: {}", agent_id)))
107    }
108
109    /// Pause a background agent
110    pub async fn pause_agent(&self, agent_id: &str) -> SessionResult<()> {
111        let mut agents = self.agents.write().await;
112        if let Some(agent) = agents.get_mut(agent_id) {
113            if agent.status == AgentStatus::Running {
114                agent.status = AgentStatus::Cancelled;
115                Ok(())
116            } else {
117                Err(SessionError::AgentError(format!(
118                    "Cannot pause agent in {:?} state",
119                    agent.status
120                )))
121            }
122        } else {
123            Err(SessionError::AgentError(format!(
124                "Agent not found: {}",
125                agent_id
126            )))
127        }
128    }
129
130    /// Cancel a background agent
131    pub async fn cancel_agent(&self, agent_id: &str) -> SessionResult<()> {
132        let mut agents = self.agents.write().await;
133        if let Some(agent) = agents.get_mut(agent_id) {
134            agent.status = AgentStatus::Cancelled;
135            agent.completed_at = Some(chrono::Utc::now());
136
137            // Emit cancellation event
138            let mut events = self.completion_events.write().await;
139            events.push(AgentCompletionEvent {
140                agent_id: agent_id.to_string(),
141                status: AgentStatus::Cancelled,
142                message: Some("Agent was cancelled".to_string()),
143            });
144
145            Ok(())
146        } else {
147            Err(SessionError::AgentError(format!(
148                "Agent not found: {}",
149                agent_id
150            )))
151        }
152    }
153
154    /// List all background agents
155    pub async fn list_agents(&self) -> Vec<BackgroundAgent> {
156        let agents = self.agents.read().await;
157        agents.values().cloned().collect()
158    }
159
160    /// Get all completion events
161    pub async fn get_completion_events(&self) -> Vec<AgentCompletionEvent> {
162        let events = self.completion_events.read().await;
163        events.clone()
164    }
165
166    /// Clear completion events
167    pub async fn clear_completion_events(&self) {
168        let mut events = self.completion_events.write().await;
169        events.clear();
170    }
171
172    /// Check if an agent is running
173    pub async fn is_agent_running(&self, agent_id: &str) -> bool {
174        if let Ok(status) = self.get_agent_status(agent_id).await {
175            status == AgentStatus::Running
176        } else {
177            false
178        }
179    }
180
181    /// Wait for an agent to complete
182    pub async fn wait_for_agent(&self, agent_id: &str) -> SessionResult<AgentStatus> {
183        loop {
184            let status = self.get_agent_status(agent_id).await?;
185            if status != AgentStatus::Running {
186                return Ok(status);
187            }
188            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
189        }
190    }
191}
192
193impl Default for BackgroundAgentManager {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    /// Poll until at least one completion event has been recorded, failing the
    /// test if none shows up within a generous timeout. This replaces a fixed
    /// sleep, so the test does not depend on exact scheduling.
    async fn wait_for_events(manager: &BackgroundAgentManager) -> Vec<AgentCompletionEvent> {
        tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                let events = manager.get_completion_events().await;
                if !events.is_empty() {
                    return events;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("no completion event was recorded in time")
    }

    #[tokio::test]
    async fn test_start_agent() {
        let manager = BackgroundAgentManager::new();
        let agent = BackgroundAgent::new("test_agent".to_string(), Some("test task".to_string()));

        let agent_id = manager.start_agent(agent).await.unwrap();
        assert!(!agent_id.is_empty());

        // Wait for agent to complete
        let status = manager.wait_for_agent(&agent_id).await.unwrap();
        assert_eq!(status, AgentStatus::Completed);
    }

    #[tokio::test]
    async fn test_get_agent_status() {
        let manager = BackgroundAgentManager::new();
        let agent = BackgroundAgent::new("test_agent".to_string(), None);
        let agent_id = agent.id.clone();

        manager.start_agent(agent).await.unwrap();

        let status = manager.get_agent_status(&agent_id).await.unwrap();
        assert_eq!(status, AgentStatus::Running);

        // Unknown IDs are reported as errors rather than a default status.
        assert!(manager.get_agent_status("does-not-exist").await.is_err());
    }

    #[tokio::test]
    async fn test_cancel_agent() {
        let manager = BackgroundAgentManager::new();
        let agent = BackgroundAgent::new("test_agent".to_string(), None);
        let agent_id = agent.id.clone();

        manager.start_agent(agent).await.unwrap();
        manager.cancel_agent(&agent_id).await.unwrap();

        let status = manager.get_agent_status(&agent_id).await.unwrap();
        assert_eq!(status, AgentStatus::Cancelled);

        // Cancelling records exactly one cancellation event for this agent.
        let events = manager.get_completion_events().await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].agent_id, agent_id);
        assert_eq!(events[0].status, AgentStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_pause_unknown_agent_fails() {
        let manager = BackgroundAgentManager::new();
        assert!(manager.pause_agent("does-not-exist").await.is_err());
    }

    #[tokio::test]
    async fn test_list_agents() {
        let manager = BackgroundAgentManager::new();
        let agent1 = BackgroundAgent::new("agent1".to_string(), None);
        let agent2 = BackgroundAgent::new("agent2".to_string(), None);
        let mut expected = vec![agent1.id.clone(), agent2.id.clone()];

        manager.start_agent(agent1).await.unwrap();
        manager.start_agent(agent2).await.unwrap();

        let agents = manager.list_agents().await;
        assert_eq!(agents.len(), 2);

        // `list_agents` iterates a HashMap, so compare IDs order-independently.
        let mut ids: Vec<String> = agents.into_iter().map(|agent| agent.id).collect();
        ids.sort();
        expected.sort();
        assert_eq!(ids, expected);
    }

    #[tokio::test]
    async fn test_completion_events() {
        let manager = BackgroundAgentManager::new();
        let agent = BackgroundAgent::new("test_agent".to_string(), None);

        let agent_id = manager.start_agent(agent).await.unwrap();

        let events = wait_for_events(&manager).await;
        assert_eq!(events[0].agent_id, agent_id);
        assert_eq!(events[0].status, AgentStatus::Completed);
    }
}