ricecoder_sessions/
background_agent.rs1use crate::error::{SessionError, SessionResult};
4use crate::models::{AgentStatus, BackgroundAgent};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tokio::task::JoinHandle;
9
10#[derive(Debug, Clone)]
12pub struct AgentCompletionEvent {
13 pub agent_id: String,
15 pub status: AgentStatus,
17 pub message: Option<String>,
19}
20
21#[derive(Debug, Clone)]
23pub struct BackgroundAgentManager {
24 agents: Arc<RwLock<HashMap<String, BackgroundAgent>>>,
26 tasks: Arc<RwLock<HashMap<String, JoinHandle<()>>>>,
28 completion_events: Arc<RwLock<Vec<AgentCompletionEvent>>>,
30}
31
32impl BackgroundAgentManager {
33 pub fn new() -> Self {
35 Self {
36 agents: Arc::new(RwLock::new(HashMap::new())),
37 tasks: Arc::new(RwLock::new(HashMap::new())),
38 completion_events: Arc::new(RwLock::new(Vec::new())),
39 }
40 }
41
42 pub async fn start_agent(&self, agent: BackgroundAgent) -> SessionResult<String> {
44 let agent_id = agent.id.clone();
45 let agent_type = agent.agent_type.clone();
46
47 {
49 let mut agents = self.agents.write().await;
50 agents.insert(agent_id.clone(), agent.clone());
51 }
52
53 let agent_id_clone = agent_id.clone();
55 let agents = Arc::clone(&self.agents);
56 let completion_events = Arc::clone(&self.completion_events);
57
58 let task = tokio::spawn(async move {
59 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
61
62 {
64 let mut agents_lock = agents.write().await;
65 if let Some(agent) = agents_lock.get_mut(&agent_id_clone) {
66 agent.status = AgentStatus::Completed;
67 agent.completed_at = Some(chrono::Utc::now());
68 }
69 }
70
71 {
73 let mut events = completion_events.write().await;
74 events.push(AgentCompletionEvent {
75 agent_id: agent_id_clone.clone(),
76 status: AgentStatus::Completed,
77 message: Some(format!("Agent {} completed successfully", agent_type)),
78 });
79 }
80 });
81
82 {
84 let mut tasks = self.tasks.write().await;
85 tasks.insert(agent_id.clone(), task);
86 }
87
88 Ok(agent_id)
89 }
90
91 pub async fn get_agent_status(&self, agent_id: &str) -> SessionResult<AgentStatus> {
93 let agents = self.agents.read().await;
94 agents
95 .get(agent_id)
96 .map(|agent| agent.status)
97 .ok_or_else(|| SessionError::AgentError(format!("Agent not found: {}", agent_id)))
98 }
99
100 pub async fn get_agent(&self, agent_id: &str) -> SessionResult<BackgroundAgent> {
102 let agents = self.agents.read().await;
103 agents
104 .get(agent_id)
105 .cloned()
106 .ok_or_else(|| SessionError::AgentError(format!("Agent not found: {}", agent_id)))
107 }
108
109 pub async fn pause_agent(&self, agent_id: &str) -> SessionResult<()> {
111 let mut agents = self.agents.write().await;
112 if let Some(agent) = agents.get_mut(agent_id) {
113 if agent.status == AgentStatus::Running {
114 agent.status = AgentStatus::Cancelled;
115 Ok(())
116 } else {
117 Err(SessionError::AgentError(format!(
118 "Cannot pause agent in {:?} state",
119 agent.status
120 )))
121 }
122 } else {
123 Err(SessionError::AgentError(format!(
124 "Agent not found: {}",
125 agent_id
126 )))
127 }
128 }
129
130 pub async fn cancel_agent(&self, agent_id: &str) -> SessionResult<()> {
132 let mut agents = self.agents.write().await;
133 if let Some(agent) = agents.get_mut(agent_id) {
134 agent.status = AgentStatus::Cancelled;
135 agent.completed_at = Some(chrono::Utc::now());
136
137 let mut events = self.completion_events.write().await;
139 events.push(AgentCompletionEvent {
140 agent_id: agent_id.to_string(),
141 status: AgentStatus::Cancelled,
142 message: Some("Agent was cancelled".to_string()),
143 });
144
145 Ok(())
146 } else {
147 Err(SessionError::AgentError(format!(
148 "Agent not found: {}",
149 agent_id
150 )))
151 }
152 }
153
154 pub async fn list_agents(&self) -> Vec<BackgroundAgent> {
156 let agents = self.agents.read().await;
157 agents.values().cloned().collect()
158 }
159
160 pub async fn get_completion_events(&self) -> Vec<AgentCompletionEvent> {
162 let events = self.completion_events.read().await;
163 events.clone()
164 }
165
166 pub async fn clear_completion_events(&self) {
168 let mut events = self.completion_events.write().await;
169 events.clear();
170 }
171
172 pub async fn is_agent_running(&self, agent_id: &str) -> bool {
174 if let Ok(status) = self.get_agent_status(agent_id).await {
175 status == AgentStatus::Running
176 } else {
177 false
178 }
179 }
180
181 pub async fn wait_for_agent(&self, agent_id: &str) -> SessionResult<AgentStatus> {
183 loop {
184 let status = self.get_agent_status(agent_id).await?;
185 if status != AgentStatus::Running {
186 return Ok(status);
187 }
188 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
189 }
190 }
191}
192
193impl Default for BackgroundAgentManager {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an agent of the given type with no task description.
    fn agent_named(name: &str) -> BackgroundAgent {
        BackgroundAgent::new(name.to_string(), None)
    }

    #[tokio::test]
    async fn test_start_agent() {
        let manager = BackgroundAgentManager::new();
        let described =
            BackgroundAgent::new("test_agent".to_string(), Some("test task".to_string()));

        let id = manager.start_agent(described).await.unwrap();
        assert!(!id.is_empty());

        // The placeholder work finishes by itself, so waiting must observe `Completed`.
        let final_status = manager.wait_for_agent(&id).await.unwrap();
        assert_eq!(final_status, AgentStatus::Completed);
    }

    #[tokio::test]
    async fn test_get_agent_status() {
        let manager = BackgroundAgentManager::new();
        let agent = agent_named("test_agent");
        let id = agent.id.clone();

        manager.start_agent(agent).await.unwrap();

        // Checked straight away, before the background task's delay has elapsed.
        let observed = manager.get_agent_status(&id).await.unwrap();
        assert_eq!(observed, AgentStatus::Running);
    }

    #[tokio::test]
    async fn test_cancel_agent() {
        let manager = BackgroundAgentManager::new();
        let agent = agent_named("test_agent");
        let id = agent.id.clone();

        manager.start_agent(agent).await.unwrap();
        manager.cancel_agent(&id).await.unwrap();

        let observed = manager.get_agent_status(&id).await.unwrap();
        assert_eq!(observed, AgentStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_list_agents() {
        let manager = BackgroundAgentManager::new();
        for name in ["agent1", "agent2"].iter() {
            manager.start_agent(agent_named(name)).await.unwrap();
        }

        assert_eq!(manager.list_agents().await.len(), 2);
    }

    #[tokio::test]
    async fn test_completion_events() {
        let manager = BackgroundAgentManager::new();
        manager.start_agent(agent_named("test_agent")).await.unwrap();

        // Comfortably longer than the 100 ms placeholder work.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        let events = manager.get_completion_events().await;
        let first = events.first().expect("a completion event should have been recorded");
        assert_eq!(first.status, AgentStatus::Completed);
    }
}