bamboo_engine/runtime/execution/
runner_lifecycle.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::AgentEvent;
16
17use super::runner_state::{AgentRunner, AgentStatus};
18
19pub async fn try_reserve_runner(
27 runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
28 session_id: &str,
29 event_sender: &broadcast::Sender<AgentEvent>,
30) -> Option<CancellationToken> {
31 let mut guard = runners.write().await;
32 if let Some(runner) = guard.get(session_id) {
33 if matches!(runner.status, AgentStatus::Running) {
34 tracing::debug!("[{}] Runner already running, skipping", session_id);
35 return None;
36 }
37 }
38
39 guard.remove(session_id);
40
41 let mut runner = AgentRunner::new();
42 runner.status = AgentStatus::Running;
43 runner.event_sender = event_sender.clone();
44 let cancel_token = runner.cancel_token.clone();
45
46 guard.insert(session_id.to_string(), runner);
47 Some(cancel_token)
48}
49
50pub fn status_from_execution_result<E>(result: &Result<(), E>) -> AgentStatus
52where
53 E: std::fmt::Display,
54{
55 match result {
56 Ok(_) => AgentStatus::Completed,
57 Err(error) if error.to_string().contains("cancelled") => AgentStatus::Cancelled,
58 Err(error) => AgentStatus::Error(error.to_string()),
59 }
60}
61
62pub async fn finalize_runner(
64 runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
65 session_id: &str,
66 result: &Result<(), impl std::fmt::Display>,
67) {
68 let mut guard = runners.write().await;
69 if let Some(runner) = guard.get_mut(session_id) {
70 runner.status = status_from_execution_result(result);
71 runner.completed_at = Some(Utc::now());
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry type shared by every test in this module.
    type RunnerMap = Arc<RwLock<HashMap<String, AgentRunner>>>;

    /// Builds an empty runner registry.
    fn new_runners() -> RunnerMap {
        let empty: HashMap<String, AgentRunner> = HashMap::new();
        Arc::new(RwLock::new(empty))
    }

    /// Builds a broadcast sender with capacity 100; the receiver half is discarded.
    fn new_broadcaster() -> broadcast::Sender<AgentEvent> {
        let (sender, _receiver) = broadcast::channel(100);
        sender
    }

    #[tokio::test]
    async fn try_reserve_runner_creates_runner_with_running_status() {
        let runners = new_runners();
        let tx = new_broadcaster();

        let reserved = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(reserved.is_some());

        let registry = runners.read().await;
        let stored = registry.get("s1").unwrap();
        assert!(matches!(stored.status, AgentStatus::Running));
    }

    #[tokio::test]
    async fn try_reserve_runner_returns_none_when_already_running() {
        let runners = new_runners();
        let tx = new_broadcaster();

        // First reservation leaves the session in the `Running` state.
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        let retry = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(retry.is_none());
    }

    #[tokio::test]
    async fn try_reserve_runner_replaces_completed_runner() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        // Mark the runner as finished; the write guard is a temporary and is
        // released at the end of this statement.
        runners.write().await.get_mut("s1").unwrap().status = AgentStatus::Completed;

        let retry = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(retry.is_some());
    }

    #[test]
    fn status_from_execution_result_maps_correctly() {
        let succeeded: Result<(), String> = Ok(());
        let succeeded_status = status_from_execution_result(&succeeded);
        assert!(matches!(succeeded_status, AgentStatus::Completed));

        let cancelled: Result<(), String> = Err(String::from("task cancelled"));
        let cancelled_status = status_from_execution_result(&cancelled);
        assert!(matches!(cancelled_status, AgentStatus::Cancelled));

        let failed: Result<(), String> = Err(String::from("network error"));
        match status_from_execution_result(&failed) {
            AgentStatus::Error(message) => assert!(message.contains("network error")),
            other => panic!("unexpected status: {other:?}"),
        }
    }
}