Skip to main content

bamboo_engine/runtime/execution/
runner_lifecycle.rs

1//! Runner lifecycle helpers for background agent execution.
2//!
3//! Provides shared implementations for:
4//! - Runner reservation (check existing → create new with cancel token)
5//! - Runner finalization (map execution result to `AgentStatus`)
6//! - Status mapping
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::{AgentError, AgentEvent};
16
17use super::runner_state::{AgentRunner, AgentStatus};
18
19/// Reservation result from `try_reserve_runner`.
20#[derive(Debug, Clone)]
21pub struct RunnerReservation {
22    pub cancel_token: CancellationToken,
23    pub run_id: String,
24}
25
26/// Try to reserve a runner for the given session.
27///
28/// If a runner with `Running` status already exists, returns `None`
29/// (caller should skip execution). The `AlreadyRunning` case is surfaced
30/// by the caller via `ExecuteResponse` with the *existing* runner's `run_id`
31/// so the frontend can correlate subsequent SSE events.
32///
33/// Otherwise removes any stale runner and inserts a fresh one, returning
34/// the associated `CancellationToken` and the new `run_id`.
35pub async fn try_reserve_runner(
36    runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
37    session_id: &str,
38    event_sender: &broadcast::Sender<AgentEvent>,
39) -> Option<RunnerReservation> {
40    let mut guard = runners.write().await;
41    if let Some(runner) = guard.get(session_id) {
42        if matches!(runner.status, AgentStatus::Running) {
43            tracing::debug!("[{}] Runner already running, skipping", session_id);
44            return None;
45        }
46    }
47
48    guard.remove(session_id);
49
50    let mut runner = AgentRunner::new();
51    runner.status = AgentStatus::Running;
52    runner.event_sender = event_sender.clone();
53    let reservation = RunnerReservation {
54        cancel_token: runner.cancel_token.clone(),
55        run_id: runner.run_id.clone(),
56    };
57
58    guard.insert(session_id.to_string(), runner);
59    Some(reservation)
60}
61
62/// Map an execution result to `AgentStatus`.
63pub fn status_from_execution_result(result: &Result<(), AgentError>) -> AgentStatus {
64    match result {
65        Ok(_) => AgentStatus::Completed,
66        Err(error) if error.is_cancelled() => AgentStatus::Cancelled,
67        Err(error) => AgentStatus::Error(error.to_string()),
68    }
69}
70
71/// Update a runner's terminal status and completion timestamp.
72pub async fn finalize_runner(
73    runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
74    session_id: &str,
75    result: &Result<(), AgentError>,
76) {
77    let mut guard = runners.write().await;
78    if let Some(runner) = guard.get_mut(session_id) {
79        runner.status = status_from_execution_result(result);
80        runner.completed_at = Some(Utc::now());
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    fn new_runners() -> Arc<RwLock<HashMap<String, AgentRunner>>> {
        Arc::new(RwLock::new(HashMap::new()))
    }

    fn new_broadcaster() -> broadcast::Sender<AgentEvent> {
        broadcast::channel(100).0
    }

    #[tokio::test]
    async fn try_reserve_runner_creates_runner_with_running_status() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let reservation = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(reservation.is_some());

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Running));
    }

    #[tokio::test]
    async fn try_reserve_runner_reservation_matches_stored_runner() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let reservation = try_reserve_runner(&runners, "s1", &tx)
            .await
            .expect("an empty map must yield a reservation");

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert_eq!(reservation.run_id, runner.run_id);

        // The reservation holds a clone of the runner's token; clones share
        // state, so cancelling one must be visible through the other.
        reservation.cancel_token.cancel();
        assert!(runner.cancel_token.is_cancelled());
    }

    #[tokio::test]
    async fn try_reserve_runner_returns_none_when_already_running() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;
        let second = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn try_reserve_runner_replaces_completed_runner() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        {
            let mut guard = runners.write().await;
            let runner = guard.get_mut("s1").unwrap();
            runner.status = AgentStatus::Completed;
        }

        let second = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(second.is_some());

        let guard = runners.read().await;
        assert!(matches!(
            guard.get("s1").unwrap().status,
            AgentStatus::Running
        ));
    }

    #[tokio::test]
    async fn finalize_runner_records_terminal_status_and_timestamp() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        finalize_runner(&runners, "s1", &Ok(())).await;

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Completed));
        assert!(runner.completed_at.is_some());
    }

    #[tokio::test]
    async fn finalize_runner_with_cancellation_allows_new_reservation() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        let cancelled: Result<(), AgentError> = Err(AgentError::Cancelled);
        finalize_runner(&runners, "s1", &cancelled).await;

        {
            let guard = runners.read().await;
            assert!(matches!(
                guard.get("s1").unwrap().status,
                AgentStatus::Cancelled
            ));
        }

        // Once the runner is no longer `Running`, the session can be reserved again.
        let next = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(next.is_some());
    }

    #[tokio::test]
    async fn finalize_runner_ignores_unknown_session() {
        let runners = new_runners();
        finalize_runner(&runners, "missing", &Ok(())).await;
        assert!(runners.read().await.is_empty());
    }

    #[test]
    fn status_from_execution_result_maps_correctly() {
        let ok_result: Result<(), AgentError> = Ok(());
        assert!(matches!(
            status_from_execution_result(&ok_result),
            AgentStatus::Completed
        ));

        // Cancellation is detected structurally (via `AgentError::is_cancelled`),
        // not by substring-matching the display message, which is fragile:
        // the `Cancelled` variant's message does not even contain a lowercase
        // "cancelled".
        let cancelled: Result<(), AgentError> = Err(AgentError::Cancelled);
        assert!(matches!(
            status_from_execution_result(&cancelled),
            AgentStatus::Cancelled
        ));

        let failed: Result<(), AgentError> = Err(AgentError::LLM("network error".to_string()));
        match status_from_execution_result(&failed) {
            AgentStatus::Error(message) => assert!(message.contains("network error")),
            other => panic!("unexpected status: {other:?}"),
        }
    }
}