// bamboo_engine/runtime/execution/spawn.rs
1//! Sub-session spawn scheduler.
2//!
3//! Provides a background queue for spawning child sessions. Spawn is async
4//! (tool returns immediately), but the UI can observe child progress via
5//! events forwarded to the parent session stream.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, mpsc, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::tools::ToolExecutor;
16use bamboo_agent_core::{AgentEvent, Role, Session, SessionKind};
17use bamboo_domain::ProviderModelRef;
18use bamboo_infrastructure::{LLMProvider, ProviderModelRouter};
19
20use crate::runtime::Agent;
21use crate::runtime::ExecuteRequest;
22
23use super::event_forwarder::create_event_forwarder;
24use super::runner_lifecycle::{finalize_runner, try_reserve_runner};
25use super::runner_state::AgentRunner;
26use super::session_events::get_or_create_event_sender;
27
/// A unit of work for the spawn scheduler: execute one child session on
/// behalf of a parent session.
#[derive(Debug, Clone)]
pub struct SpawnJob {
    /// Session that requested the spawn; receives forwarded child events.
    pub parent_session_id: String,
    /// Child session to load and execute.
    pub child_session_id: String,
    /// Model identifier the child session should run with.
    pub model: String,
}
34
/// Trait for external child session runtimes (e.g. A2A, CLI adapters).
///
/// Implementors are responsible for emitting AgentEvents via `event_tx`
/// and respecting the `cancel_token`.
#[async_trait::async_trait]
pub trait ExternalChildRunner: Send + Sync {
    /// Returns true if this runner should handle the given child session.
    ///
    /// Only consulted for sessions whose `runtime.kind` metadata equals
    /// `"external"`; answering `false` surfaces as an error completion on
    /// the child run.
    async fn should_handle(&self, session: &Session) -> bool;

    /// Execute the child session using an external runtime.
    ///
    /// The implementation may mutate `session` (messages, metadata); the
    /// caller persists the session after this returns. An `Err` is mapped
    /// to a terminal "error" status (or "cancelled" when the error text
    /// contains "cancelled").
    async fn execute_external_child(
        &self,
        session: &mut Session,
        job: &SpawnJob,
        event_tx: tokio::sync::mpsc::Sender<AgentEvent>,
        cancel_token: CancellationToken,
    ) -> crate::runtime::runner::Result<()>;
}
53
/// Everything a spawn job needs to execute a child session: the agent
/// runtime, tool executor, shared session/runner state, and optional
/// external-runtime and provider-routing hooks.
#[derive(Clone)]
pub struct SpawnContext {
    /// Agent runtime used to run the child loop and access storage.
    pub agent: Arc<Agent>,
    /// Tool executor handed to the child's execute request.
    pub tools: Arc<dyn ToolExecutor>,
    /// In-memory session cache, updated with the final child snapshot.
    pub sessions_cache: Arc<RwLock<HashMap<String, Session>>>,
    /// Active runner registry; used to reserve/finalize the child's slot.
    pub agent_runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
    /// Per-session broadcast senders; parent and child streams live here.
    pub session_event_senders: Arc<RwLock<HashMap<String, broadcast::Sender<AgentEvent>>>>,
    /// Optional runtime for children marked `runtime.kind = "external"`.
    pub external_child_runner: Option<Arc<dyn ExternalChildRunner>>,
    /// Optional router used to resolve a per-child provider override.
    pub provider_router: Option<Arc<ProviderModelRouter>>,
}
64
/// Cheap-to-clone handle to the background spawn queue. When every clone
/// is dropped the channel closes and the worker task exits.
#[derive(Clone)]
pub struct SpawnScheduler {
    // Sender side of the bounded job queue drained by the worker spawned
    // in `SpawnScheduler::new`.
    tx: mpsc::Sender<SpawnJob>,
}
69
70impl SpawnScheduler {
71    pub fn new(ctx: SpawnContext) -> Self {
72        let (tx, mut rx) = mpsc::channel::<SpawnJob>(128);
73
74        tokio::spawn(async move {
75            while let Some(job) = rx.recv().await {
76                if let Err(err) = run_spawn_job(ctx.clone(), job).await {
77                    tracing::warn!("spawn job failed: {}", err);
78                }
79            }
80        });
81
82        Self { tx }
83    }
84
85    pub async fn enqueue(&self, job: SpawnJob) -> Result<(), String> {
86        self.tx
87            .send(job)
88            .await
89            .map_err(|_| "spawn scheduler is not running".to_string())
90    }
91}
92
93fn child_model_ref(session: &Session, model: &str) -> Option<ProviderModelRef> {
94    if let Some(model_ref) = session.model_ref.clone() {
95        let provider = model_ref.provider.trim();
96        let model_name = model_ref.model.trim();
97        if !provider.is_empty() && !model_name.is_empty() {
98            return Some(ProviderModelRef::new(provider, model_name));
99        }
100    }
101
102    let provider = session
103        .metadata
104        .get("provider_name")
105        .map(String::as_str)
106        .map(str::trim)
107        .filter(|value| !value.is_empty())?;
108    let model_name = model.trim();
109    if model_name.is_empty() {
110        return None;
111    }
112    Some(ProviderModelRef::new(provider, model_name))
113}
114
115fn resolve_child_provider_override(
116    router: Option<&Arc<ProviderModelRouter>>,
117    session: &Session,
118    model: &str,
119) -> (Option<Arc<dyn LLMProvider>>, Option<String>) {
120    let model_ref = child_model_ref(session, model);
121    let provider_name = model_ref
122        .as_ref()
123        .map(|model_ref| model_ref.provider.clone());
124    let provider = router.and_then(|router| {
125        let model_ref = model_ref.as_ref()?;
126        match router.route(model_ref) {
127            Ok(provider) => Some(provider),
128            Err(error) => {
129                tracing::warn!(
130                    session_id = %session.id,
131                    provider = %model_ref.provider,
132                    model = %model_ref.model,
133                    error = %error,
134                    "failed to resolve provider override for child session; falling back to runtime provider"
135                );
136                None
137            }
138        }
139    });
140    (provider, provider_name)
141}
142
/// Run a single spawn job: load and validate the child session, reserve a
/// runner slot, wire event forwarding and heartbeats to the parent stream,
/// execute the child (external runtime or built-in agent loop), and
/// persist the terminal status.
///
/// Returns `Err` only for pre-flight failures (child session missing,
/// unloadable, or not `kind=child`); in each such case a
/// `SubSessionCompleted { status: "error" }` event has already been sent
/// to the parent. Once the execution task is spawned this function returns
/// `Ok(())` immediately — completion is reported via events only.
async fn run_spawn_job(ctx: SpawnContext, job: SpawnJob) -> Result<(), String> {
    // Ensure both session event streams exist.
    let parent_tx =
        get_or_create_event_sender(&ctx.session_event_senders, &job.parent_session_id).await;
    let child_tx =
        get_or_create_event_sender(&ctx.session_event_senders, &job.child_session_id).await;

    // Notify the parent of an early failure and hand the message back so it
    // can double as this function's Err value. Send failures are ignored:
    // the parent stream may have no subscribers yet.
    let emit_error_completion = |error: String| {
        let _ = parent_tx.send(AgentEvent::SubSessionCompleted {
            parent_session_id: job.parent_session_id.clone(),
            child_session_id: job.child_session_id.clone(),
            status: "error".to_string(),
            error: Some(error.clone()),
        });
        error
    };

    // Load child session.
    let mut session = match ctx
        .agent
        .storage()
        .load_session(&job.child_session_id)
        .await
    {
        Ok(Some(s)) => s,
        Ok(None) => return Err(emit_error_completion("child session not found".to_string())),
        Err(e) => {
            return Err(emit_error_completion(format!(
                "failed to load child session: {e}"
            )))
        }
    };

    if session.kind != SessionKind::Child {
        return Err(emit_error_completion(
            "spawn job child session is not kind=child".to_string(),
        ));
    }

    // Ensure last message is user (otherwise nothing to do). A "skipped"
    // run is reported to the parent but still returns Ok.
    let last_is_user = session
        .messages
        .last()
        .map(|m| matches!(m.role, Role::User))
        .unwrap_or(false);
    if !last_is_user {
        session
            .metadata
            .insert("last_run_status".to_string(), "skipped".to_string());
        session.metadata.insert(
            "last_run_error".to_string(),
            "No pending message to execute".to_string(),
        );
        let _ = ctx.agent.storage().save_session(&session).await;
        let _ = parent_tx.send(AgentEvent::SubSessionCompleted {
            parent_session_id: job.parent_session_id.clone(),
            child_session_id: job.child_session_id.clone(),
            status: "skipped".to_string(),
            error: Some("No pending message to execute".to_string()),
        });
        return Ok(());
    }

    // Persist a running marker early so list_sessions can reconstruct status.
    session
        .metadata
        .insert("last_run_status".to_string(), "running".to_string());
    session.metadata.remove("last_run_error");
    let _ = ctx.agent.storage().save_session(&session).await;

    // Insert runner status.
    // NOTE(review): if reservation fails we return Ok(()) without reverting
    // the "running" marker or emitting a completion event to the parent —
    // confirm try_reserve_runner reports this case on the child stream.
    let Some(cancel_token) =
        try_reserve_runner(&ctx.agent_runners, &job.child_session_id, &child_tx).await
    else {
        return Ok(());
    };

    // Forward ALL child events to parent, wrapped as SubSessionEvent, until
    // `forwarder_done` is cancelled or the child broadcast channel closes.
    let forwarder_done = CancellationToken::new();
    {
        let mut rx = child_tx.subscribe();
        let parent_tx = parent_tx.clone();
        let job_clone = job.clone();
        let done = forwarder_done.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = done.cancelled() => break,
                    evt = rx.recv() => {
                        match evt {
                            Ok(event) => {
                                let _ = parent_tx.send(AgentEvent::SubSessionEvent {
                                    parent_session_id: job_clone.parent_session_id.clone(),
                                    child_session_id: job_clone.child_session_id.clone(),
                                    event: Box::new(event),
                                });
                            }
                            // A lagged subscriber has missed events; keep
                            // forwarding from wherever the stream resumes.
                            Err(broadcast::error::RecvError::Lagged(_)) => {
                                continue;
                            }
                            Err(_) => break,
                        }
                    }
                }
            }
        });
    }
    // Heartbeat task: lets the parent observe liveness even when the child
    // emits no events for a while. Stops with `forwarder_done`.
    {
        let parent_tx = parent_tx.clone();
        let job_clone = job.clone();
        let done = forwarder_done.clone();
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(Duration::from_secs(5));
            loop {
                tokio::select! {
                    _ = done.cancelled() => break,
                    _ = ticker.tick() => {
                        let _ = parent_tx.send(AgentEvent::SubSessionHeartbeat {
                            parent_session_id: job_clone.parent_session_id.clone(),
                            child_session_id: job_clone.child_session_id.clone(),
                            timestamp: Utc::now(),
                        });
                    }
                }
            }
        });
    }

    // Create mpsc channel for agent loop → session events sender.
    let (mpsc_tx, _forwarder_handle) = create_event_forwarder(
        job.child_session_id.clone(),
        child_tx.clone(),
        ctx.agent_runners.clone(),
    );

    // Run child loop via unified spawn_session_execution.
    let model = job.model.clone();
    let session_id_clone = job.child_session_id.clone();
    let agent_runners_for_status = ctx.agent_runners.clone();
    let sessions_cache = ctx.sessions_cache.clone();
    let agent = ctx.agent.clone();
    let tools = ctx.tools.clone();
    let external_runner = ctx.external_child_runner.clone();
    let done = forwarder_done.clone();
    let parent_tx_for_done = parent_tx.clone();
    let parent_id_for_done = job.parent_session_id.clone();
    let child_id_for_done = job.child_session_id.clone();
    let session_event_senders = ctx.session_event_senders.clone();
    let provider_router = ctx.provider_router.clone();

    tokio::spawn(async move {
        session.model = model.clone();

        // Children marked `runtime.kind = "external"` are delegated to the
        // configured ExternalChildRunner instead of the built-in agent loop.
        let wants_external = session
            .metadata
            .get("runtime.kind")
            .is_some_and(|v| v == "external");

        let result: crate::runtime::runner::Result<()> = if wants_external {
            if let Some(runner) = external_runner {
                if runner.should_handle(&session).await {
                    runner
                        .execute_external_child(&mut session, &job, mpsc_tx, cancel_token)
                        .await
                } else {
                    Err(bamboo_agent_core::AgentError::LLM(format!(
                        "No external runner matched child session runtime metadata: agent_id={:?}, protocol={:?}",
                        session.metadata.get("external.agent_id"),
                        session.metadata.get("external.protocol"),
                    )))
                }
            } else {
                Err(bamboo_agent_core::AgentError::LLM(
                    "Child session requires external runtime, but no external runner is configured"
                        .to_string(),
                ))
            }
        } else {
            let (provider_override, provider_name) =
                resolve_child_provider_override(provider_router.as_ref(), &session, &model);
            agent
                .execute(
                    &mut session,
                    ExecuteRequest {
                        initial_message: String::new(), // handled by agent loop
                        event_tx: mpsc_tx,
                        cancel_token,
                        tools: Some(tools),
                        provider_override,
                        model: Some(model.clone()),
                        provider_name,
                        background_model: None,
                        background_model_provider: None,
                        reasoning_effort: None,
                        disabled_tools: None,
                        disabled_skill_ids: None,
                        selected_skill_ids: None,
                        selected_skill_mode: None,
                        image_fallback: None,
                    },
                )
                .await
        };

        // Map the result to a terminal status string. NOTE(review):
        // cancellation is detected by substring match on the error text —
        // brittle if error wording ever changes.
        let (status, error) = match &result {
            Ok(_) => ("completed".to_string(), None),
            Err(e) if e.to_string().contains("cancelled") => {
                ("cancelled".to_string(), Some(e.to_string()))
            }
            Err(e) => ("error".to_string(), Some(e.to_string())),
        };

        finalize_runner(&agent_runners_for_status, &session_id_clone, &result).await;

        // Merge any queued injected messages that the pipeline didn't pick up
        // (e.g. if the loop exited before the next turn boundary).
        // NOTE(review): the queue is read from the freshly loaded `latest`
        // but appended to (and removed from) the in-memory `session`, which
        // is saved below — any other fields changed in storage since
        // `session` was loaded are overwritten. Confirm this is intended.
        if let Ok(Some(latest)) = agent.storage().load_session(&session_id_clone).await {
            if let Some(raw) = latest.metadata.get("pending_injected_messages") {
                if let Ok(messages) = serde_json::from_str::<Vec<serde_json::Value>>(raw) {
                    for msg in messages {
                        if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
                            session
                                .add_message(bamboo_agent_core::Message::user(content.to_string()));
                        }
                    }
                    session.metadata.remove("pending_injected_messages");
                }
            }
        }

        // Persist final session snapshot.
        session
            .metadata
            .insert("last_run_status".to_string(), status.clone());
        if let Some(err) = &error {
            session
                .metadata
                .insert("last_run_error".to_string(), err.clone());
        } else {
            session.metadata.remove("last_run_error");
        }
        let _ = agent.storage().save_session(&session).await;
        {
            let mut sessions = sessions_cache.write().await;
            sessions.insert(session_id_clone.clone(), session);
        }

        // Stop forwarding/heartbeats and emit terminal child status.
        done.cancel();
        let _ = parent_tx_for_done.send(AgentEvent::SubSessionCompleted {
            parent_session_id: parent_id_for_done,
            child_session_id: child_id_for_done,
            status,
            error,
        });

        // Dropped explicitly at the end: holding this Arc clone keeps the
        // senders map (and thus both event channels) alive for the whole
        // lifetime of the task.
        drop(session_event_senders);
    });

    Ok(())
}