bamboo_engine/runtime/execution/
runner_lifecycle.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::{AgentError, AgentEvent};
16
17use super::runner_state::{AgentRunner, AgentStatus};
18
19#[derive(Debug, Clone)]
21pub struct RunnerReservation {
22 pub cancel_token: CancellationToken,
23 pub run_id: String,
24}
25
26pub async fn try_reserve_runner(
36 runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
37 session_id: &str,
38 event_sender: &broadcast::Sender<AgentEvent>,
39) -> Option<RunnerReservation> {
40 let mut guard = runners.write().await;
41 if let Some(runner) = guard.get(session_id) {
42 if matches!(runner.status, AgentStatus::Running) {
43 tracing::debug!("[{}] Runner already running, skipping", session_id);
44 return None;
45 }
46 }
47
48 guard.remove(session_id);
49
50 let mut runner = AgentRunner::new();
51 runner.status = AgentStatus::Running;
52 runner.event_sender = event_sender.clone();
53 let reservation = RunnerReservation {
54 cancel_token: runner.cancel_token.clone(),
55 run_id: runner.run_id.clone(),
56 };
57
58 guard.insert(session_id.to_string(), runner);
59 Some(reservation)
60}
61
62pub fn status_from_execution_result(result: &Result<(), AgentError>) -> AgentStatus {
64 match result {
65 Ok(_) => AgentStatus::Completed,
66 Err(error) if error.is_cancelled() => AgentStatus::Cancelled,
67 Err(error) => AgentStatus::Error(error.to_string()),
68 }
69}
70
71pub async fn finalize_runner(
73 runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
74 session_id: &str,
75 result: &Result<(), AgentError>,
76) {
77 let mut guard = runners.write().await;
78 if let Some(runner) = guard.get_mut(session_id) {
79 runner.status = status_from_execution_result(result);
80 runner.completed_at = Some(Utc::now());
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    fn new_runners() -> Arc<RwLock<HashMap<String, AgentRunner>>> {
        Arc::new(RwLock::new(HashMap::new()))
    }

    fn new_broadcaster() -> broadcast::Sender<AgentEvent> {
        broadcast::channel(100).0
    }

    #[tokio::test]
    async fn try_reserve_runner_creates_runner_with_running_status() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let reservation = try_reserve_runner(&runners, "s1", &tx)
            .await
            .expect("an empty map should allow a reservation");

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Running));
        assert_eq!(runner.run_id, reservation.run_id);

        // The reservation holds a clone of the stored runner's token, so a
        // cancel through one handle must be visible through the other.
        reservation.cancel_token.cancel();
        assert!(runner.cancel_token.is_cancelled());
    }

    #[tokio::test]
    async fn try_reserve_runner_returns_none_when_already_running() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;
        let second = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn try_reserve_runner_replaces_completed_runner() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        {
            let mut guard = runners.write().await;
            let runner = guard.get_mut("s1").unwrap();
            runner.status = AgentStatus::Completed;
        }

        let second = try_reserve_runner(&runners, "s1", &tx)
            .await
            .expect("a completed runner should be replaceable");

        // The slot must now hold the new runner, back in the Running state.
        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Running));
        assert_eq!(runner.run_id, second.run_id);
    }

    #[test]
    fn status_from_execution_result_maps_correctly() {
        let ok_result: Result<(), AgentError> = Ok(());
        assert!(matches!(
            status_from_execution_result(&ok_result),
            AgentStatus::Completed
        ));

        let cancelled: Result<(), AgentError> = Err(AgentError::Cancelled);
        assert!(matches!(
            status_from_execution_result(&cancelled),
            AgentStatus::Cancelled
        ));

        let failed: Result<(), AgentError> = Err(AgentError::LLM("network error".to_string()));
        match status_from_execution_result(&failed) {
            AgentStatus::Error(message) => assert!(message.contains("network error")),
            other => panic!("unexpected status: {other:?}"),
        }
    }

    #[tokio::test]
    async fn finalize_runner_records_completed_status_and_timestamp() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        finalize_runner(&runners, "s1", &Ok(())).await;

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Completed));
        assert!(runner.completed_at.is_some());
    }

    #[tokio::test]
    async fn finalize_runner_records_error_status() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        let failed: Result<(), AgentError> = Err(AgentError::LLM("boom".to_string()));
        finalize_runner(&runners, "s1", &failed).await;

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        match &runner.status {
            AgentStatus::Error(message) => assert!(message.contains("boom")),
            other => panic!("unexpected status: {other:?}"),
        }
        assert!(runner.completed_at.is_some());
    }

    #[tokio::test]
    async fn finalize_runner_ignores_unknown_session() {
        let runners = new_runners();
        finalize_runner(&runners, "missing", &Ok(())).await;
        assert!(runners.read().await.is_empty());
    }
}