bamboo_engine/runtime/execution/
runner_lifecycle.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::AgentEvent;
16
17use super::runner_state::{AgentRunner, AgentStatus};
18
19#[derive(Debug, Clone)]
21pub struct RunnerReservation {
22 pub cancel_token: CancellationToken,
23 pub run_id: String,
24}
25
26pub async fn try_reserve_runner(
36 runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
37 session_id: &str,
38 event_sender: &broadcast::Sender<AgentEvent>,
39) -> Option<RunnerReservation> {
40 let mut guard = runners.write().await;
41 if let Some(runner) = guard.get(session_id) {
42 if matches!(runner.status, AgentStatus::Running) {
43 tracing::debug!("[{}] Runner already running, skipping", session_id);
44 return None;
45 }
46 }
47
48 guard.remove(session_id);
49
50 let mut runner = AgentRunner::new();
51 runner.status = AgentStatus::Running;
52 runner.event_sender = event_sender.clone();
53 let reservation = RunnerReservation {
54 cancel_token: runner.cancel_token.clone(),
55 run_id: runner.run_id.clone(),
56 };
57
58 guard.insert(session_id.to_string(), runner);
59 Some(reservation)
60}
61
62pub fn status_from_execution_result<E>(result: &Result<(), E>) -> AgentStatus
64where
65 E: std::fmt::Display,
66{
67 match result {
68 Ok(_) => AgentStatus::Completed,
69 Err(error) if error.to_string().contains("cancelled") => AgentStatus::Cancelled,
70 Err(error) => AgentStatus::Error(error.to_string()),
71 }
72}
73
74pub async fn finalize_runner(
76 runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
77 session_id: &str,
78 result: &Result<(), impl std::fmt::Display>,
79) {
80 let mut guard = runners.write().await;
81 if let Some(runner) = guard.get_mut(session_id) {
82 runner.status = status_from_execution_result(result);
83 runner.completed_at = Some(Utc::now());
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty, shareable runner registry.
    fn new_runners() -> Arc<RwLock<HashMap<String, AgentRunner>>> {
        let registry: HashMap<String, AgentRunner> = HashMap::new();
        Arc::new(RwLock::new(registry))
    }

    /// Builds a broadcast sender; the paired receiver is discarded.
    fn new_broadcaster() -> broadcast::Sender<AgentEvent> {
        let (sender, _receiver) = broadcast::channel(100);
        sender
    }

    #[tokio::test]
    async fn try_reserve_runner_creates_runner_with_running_status() {
        let runners = new_runners();
        let tx = new_broadcaster();

        let reservation = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(reservation.is_some());

        // The stored runner must have been marked as running.
        let registry = runners.read().await;
        let stored = registry.get("s1").unwrap();
        assert!(matches!(stored.status, AgentStatus::Running));
    }

    #[tokio::test]
    async fn try_reserve_runner_returns_none_when_already_running() {
        let runners = new_runners();
        let tx = new_broadcaster();

        // First reservation succeeds and leaves the runner in `Running`.
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        assert!(try_reserve_runner(&runners, "s1", &tx).await.is_none());
    }

    #[tokio::test]
    async fn try_reserve_runner_replaces_completed_runner() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        // Mark the runner as finished; the temporary write guard is released
        // at the end of this statement.
        runners.write().await.get_mut("s1").unwrap().status = AgentStatus::Completed;

        let replacement = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(replacement.is_some());
    }

    #[test]
    fn status_from_execution_result_maps_correctly() {
        // Success maps to `Completed`.
        let ok_result: Result<(), String> = Ok(());
        let completed = status_from_execution_result(&ok_result);
        assert!(matches!(completed, AgentStatus::Completed));

        // An error mentioning "cancelled" maps to `Cancelled`.
        let cancelled = Err("task cancelled".to_string());
        let cancelled_status = status_from_execution_result(&cancelled);
        assert!(matches!(cancelled_status, AgentStatus::Cancelled));

        // Any other error keeps its message inside `Error`.
        let failed = Err("network error".to_string());
        match status_from_execution_result(&failed) {
            AgentStatus::Error(message) => assert!(message.contains("network error")),
            other => panic!("unexpected status: {other:?}"),
        }
    }
}