Skip to main content

bamboo_engine/runtime/execution/
runner_lifecycle.rs

1//! Runner lifecycle helpers for background agent execution.
2//!
3//! Provides shared implementations for:
4//! - Runner reservation (check existing → create new with cancel token)
5//! - Runner finalization (map execution result to `AgentStatus`)
6//! - Status mapping
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::AgentEvent;
16
17use super::runner_state::{AgentRunner, AgentStatus};
18
19/// Reservation result from `try_reserve_runner`.
20#[derive(Debug, Clone)]
21pub struct RunnerReservation {
22    pub cancel_token: CancellationToken,
23    pub run_id: String,
24}
25
26/// Try to reserve a runner for the given session.
27///
28/// If a runner with `Running` status already exists, returns `None`
29/// (caller should skip execution). The `AlreadyRunning` case is surfaced
30/// by the caller via `ExecuteResponse` with the *existing* runner's `run_id`
31/// so the frontend can correlate subsequent SSE events.
32///
33/// Otherwise removes any stale runner and inserts a fresh one, returning
34/// the associated `CancellationToken` and the new `run_id`.
35pub async fn try_reserve_runner(
36    runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
37    session_id: &str,
38    event_sender: &broadcast::Sender<AgentEvent>,
39) -> Option<RunnerReservation> {
40    let mut guard = runners.write().await;
41    if let Some(runner) = guard.get(session_id) {
42        if matches!(runner.status, AgentStatus::Running) {
43            tracing::debug!("[{}] Runner already running, skipping", session_id);
44            return None;
45        }
46    }
47
48    guard.remove(session_id);
49
50    let mut runner = AgentRunner::new();
51    runner.status = AgentStatus::Running;
52    runner.event_sender = event_sender.clone();
53    let reservation = RunnerReservation {
54        cancel_token: runner.cancel_token.clone(),
55        run_id: runner.run_id.clone(),
56    };
57
58    guard.insert(session_id.to_string(), runner);
59    Some(reservation)
60}
61
62/// Map an execution result to `AgentStatus`.
63pub fn status_from_execution_result<E>(result: &Result<(), E>) -> AgentStatus
64where
65    E: std::fmt::Display,
66{
67    match result {
68        Ok(_) => AgentStatus::Completed,
69        Err(error) if error.to_string().contains("cancelled") => AgentStatus::Cancelled,
70        Err(error) => AgentStatus::Error(error.to_string()),
71    }
72}
73
74/// Update a runner's terminal status and completion timestamp.
75pub async fn finalize_runner(
76    runners: &Arc<RwLock<HashMap<String, AgentRunner>>>,
77    session_id: &str,
78    result: &Result<(), impl std::fmt::Display>,
79) {
80    let mut guard = runners.write().await;
81    if let Some(runner) = guard.get_mut(session_id) {
82        runner.status = status_from_execution_result(result);
83        runner.completed_at = Some(Utc::now());
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty shared runner map, shaped like the one the lifecycle helpers take.
    fn new_runners() -> Arc<RwLock<HashMap<String, AgentRunner>>> {
        Arc::new(RwLock::new(HashMap::new()))
    }

    /// Broadcast sender whose receiver half is dropped; only the sender is needed.
    fn new_broadcaster() -> broadcast::Sender<AgentEvent> {
        broadcast::channel(100).0
    }

    #[tokio::test]
    async fn try_reserve_runner_creates_runner_with_running_status() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let reservation = try_reserve_runner(&runners, "s1", &tx)
            .await
            .expect("an empty map must yield a reservation");

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Running));
        // The reservation must identify the runner that was actually stored.
        assert_eq!(runner.run_id, reservation.run_id);
    }

    #[tokio::test]
    async fn try_reserve_runner_returns_none_when_already_running() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;
        let second = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn try_reserve_runner_replaces_completed_runner() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        {
            let mut guard = runners.write().await;
            let runner = guard.get_mut("s1").unwrap();
            runner.status = AgentStatus::Completed;
        }

        let second = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(second.is_some());
    }

    #[test]
    fn status_from_execution_result_maps_correctly() {
        let ok_result: Result<(), String> = Ok(());
        assert!(matches!(
            status_from_execution_result(&ok_result),
            AgentStatus::Completed
        ));

        let cancelled = Err("task cancelled".to_string());
        assert!(matches!(
            status_from_execution_result(&cancelled),
            AgentStatus::Cancelled
        ));

        let failed = Err("network error".to_string());
        match status_from_execution_result(&failed) {
            AgentStatus::Error(message) => assert!(message.contains("network error")),
            other => panic!("unexpected status: {other:?}"),
        }
    }

    #[tokio::test]
    async fn finalize_runner_records_terminal_status_and_timestamp() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        let outcome: Result<(), String> = Ok(());
        finalize_runner(&runners, "s1", &outcome).await;

        let guard = runners.read().await;
        let runner = guard.get("s1").unwrap();
        assert!(matches!(runner.status, AgentStatus::Completed));
        assert!(runner.completed_at.is_some());
    }

    #[tokio::test]
    async fn finalize_runner_allows_re_reservation_after_cancellation() {
        let runners = new_runners();
        let tx = new_broadcaster();
        let _ = try_reserve_runner(&runners, "s1", &tx).await;

        let outcome: Result<(), String> = Err("task cancelled".to_string());
        finalize_runner(&runners, "s1", &outcome).await;

        {
            let guard = runners.read().await;
            let runner = guard.get("s1").unwrap();
            assert!(matches!(runner.status, AgentStatus::Cancelled));
        }

        // A finalized runner is stale, so the slot can be reserved again.
        let second = try_reserve_runner(&runners, "s1", &tx).await;
        assert!(second.is_some());
    }

    #[tokio::test]
    async fn finalize_runner_ignores_unknown_session() {
        let runners = new_runners();

        let outcome: Result<(), String> = Ok(());
        finalize_runner(&runners, "missing", &outcome).await;

        assert!(runners.read().await.is_empty());
    }
}