Skip to main content

bamboo_engine/runtime/execution/
spawn.rs

1//! Sub-session spawn scheduler.
2//!
3//! Provides a background queue for spawning child sessions. Spawn is async
4//! (tool returns immediately), but the UI can observe child progress via
5//! events forwarded to the parent session stream.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, mpsc, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::tools::ToolExecutor;
16use bamboo_agent_core::{AgentEvent, Role, Session, SessionKind};
17use bamboo_domain::ProviderModelRef;
18use bamboo_infrastructure::{LLMProvider, ProviderModelRouter};
19
20use crate::runtime::Agent;
21use crate::runtime::ExecuteRequest;
22
23use super::child_completion::{ChildCompletion, ChildCompletionHandler};
24use super::event_forwarder::create_event_forwarder;
25use super::runner_lifecycle::{finalize_runner, try_reserve_runner, RunnerReservation};
26use super::runner_state::AgentRunner;
27use super::session_events::get_or_create_event_sender;
28
/// One queued request to run a child session on behalf of a parent session.
#[derive(Clone, Debug)]
pub struct SpawnJob {
    /// Session whose event stream receives the child's forwarded events,
    /// heartbeats and terminal completion notice.
    pub parent_session_id: String,
    /// Session that is loaded from storage and executed.
    pub child_session_id: String,
    /// Model name assigned to the child session for this run.
    pub model: String,
    /// Tool names to hide from the LLM schema for this child session.
    /// Computed from the child's `subagent_type` profile policy.
    pub disabled_tools: Option<Vec<String>>,
}
38
39/// Trait for external child session runtimes (e.g. A2A, CLI adapters).
40///
41/// Implementors are responsible for emitting AgentEvents via `event_tx`
42/// and respecting the `cancel_token`.
43#[async_trait::async_trait]
44pub trait ExternalChildRunner: Send + Sync {
45    /// Returns true if this runner should handle the given child session.
46    async fn should_handle(&self, session: &Session) -> bool;
47
48    /// Execute the child session using an external runtime.
49    async fn execute_external_child(
50        &self,
51        session: &mut Session,
52        job: &SpawnJob,
53        event_tx: tokio::sync::mpsc::Sender<AgentEvent>,
54        cancel_token: CancellationToken,
55    ) -> crate::runtime::runner::Result<()>;
56}
57
58#[derive(Clone)]
59pub struct SpawnContext {
60    pub agent: Arc<Agent>,
61    pub tools: Arc<dyn ToolExecutor>,
62    pub sessions_cache: Arc<RwLock<HashMap<String, Session>>>,
63    pub agent_runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
64    pub session_event_senders: Arc<RwLock<HashMap<String, broadcast::Sender<AgentEvent>>>>,
65    pub external_child_runner: Option<Arc<dyn ExternalChildRunner>>,
66    pub provider_router: Option<Arc<ProviderModelRouter>>,
67    pub app_data_dir: Option<std::path::PathBuf>,
68    /// Optional application-layer completion hook. The engine still emits
69    /// `SubAgentCompleted` to the parent stream itself; this hook lets the
70    /// server persist parent wait state and resume the parent runner without
71    /// introducing an engine -> AppState dependency.
72    pub completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
73}
74
75#[derive(Clone)]
76pub struct SpawnScheduler {
77    tx: mpsc::Sender<SpawnJob>,
78}
79
80impl SpawnScheduler {
81    pub fn new(ctx: SpawnContext) -> Self {
82        let (tx, mut rx) = mpsc::channel::<SpawnJob>(128);
83
84        tokio::spawn(async move {
85            while let Some(job) = rx.recv().await {
86                if let Err(err) = run_spawn_job(ctx.clone(), job).await {
87                    tracing::warn!("spawn job failed: {}", err);
88                }
89            }
90        });
91
92        Self { tx }
93    }
94
95    pub async fn enqueue(&self, job: SpawnJob) -> Result<(), String> {
96        self.tx
97            .send(job)
98            .await
99            .map_err(|_| "spawn scheduler is not running".to_string())
100    }
101}
102
103fn child_model_ref(session: &Session, model: &str) -> Option<ProviderModelRef> {
104    if let Some(model_ref) = session.model_ref.clone() {
105        let provider = model_ref.provider.trim();
106        let model_name = model_ref.model.trim();
107        if !provider.is_empty() && !model_name.is_empty() {
108            return Some(ProviderModelRef::new(provider, model_name));
109        }
110    }
111
112    let provider = session
113        .metadata
114        .get("provider_name")
115        .map(String::as_str)
116        .map(str::trim)
117        .filter(|value| !value.is_empty())?;
118    let model_name = model.trim();
119    if model_name.is_empty() {
120        return None;
121    }
122    Some(ProviderModelRef::new(provider, model_name))
123}
124
/// Limits enforced by the child liveness watchdog. All values are seconds.
#[derive(Debug, Clone, Copy)]
struct ChildWatchdogPolicy {
    /// How often the watchdog inspects the child runner.
    check_interval_secs: i64,
    /// Maximum wall-clock run time before the child is cancelled.
    max_total_secs: i64,
    /// Maximum time without any child event before the child is cancelled.
    max_idle_secs: i64,
}

impl ChildWatchdogPolicy {
    /// Default polling period.
    const DEFAULT_CHECK_INTERVAL_SECS: i64 = 15;
    /// Parent waits may be longer, but child execution owns its own
    /// liveness. A one hour total cap avoids indefinitely orphaned
    /// sub-session runners.
    const DEFAULT_MAX_TOTAL_SECS: i64 = 60 * 60;
    /// No child event for 15 minutes is considered stalled.
    const DEFAULT_MAX_IDLE_SECS: i64 = 15 * 60;
}

impl Default for ChildWatchdogPolicy {
    fn default() -> Self {
        Self {
            check_interval_secs: Self::DEFAULT_CHECK_INTERVAL_SECS,
            max_total_secs: Self::DEFAULT_MAX_TOTAL_SECS,
            max_idle_secs: Self::DEFAULT_MAX_IDLE_SECS,
        }
    }
}
145
146fn metadata_i64(session: &Session, key: &str) -> Option<i64> {
147    session
148        .metadata
149        .get(key)
150        .and_then(|value| value.trim().parse::<i64>().ok())
151        .filter(|value| *value > 0)
152}
153
154fn watchdog_policy_for_session(session: &Session) -> ChildWatchdogPolicy {
155    let mut policy = ChildWatchdogPolicy::default();
156    if let Some(value) = metadata_i64(session, "child_watchdog.max_total_secs") {
157        policy.max_total_secs = value;
158    }
159    if let Some(value) = metadata_i64(session, "child_watchdog.max_idle_secs") {
160        policy.max_idle_secs = value;
161    }
162    if let Some(value) = metadata_i64(session, "child_watchdog.check_interval_secs") {
163        policy.check_interval_secs = value;
164    }
165    policy
166}
167
168async fn publish_child_completion(
169    parent_tx: &broadcast::Sender<AgentEvent>,
170    completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
171    completion: ChildCompletion,
172) {
173    let _ = parent_tx.send(AgentEvent::SubAgentCompleted {
174        parent_session_id: completion.parent_session_id.clone(),
175        child_session_id: completion.child_session_id.clone(),
176        status: completion.status.clone(),
177        error: completion.error.clone(),
178    });
179
180    if let Some(handler) = completion_handler {
181        handler.on_child_completed(completion).await;
182    }
183}
184
185async fn publish_child_completion_parts(
186    parent_tx: &broadcast::Sender<AgentEvent>,
187    completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
188    parent_session_id: String,
189    child_session_id: String,
190    status: String,
191    error: Option<String>,
192) {
193    publish_child_completion(
194        parent_tx,
195        completion_handler,
196        ChildCompletion {
197            parent_session_id,
198            child_session_id,
199            status,
200            error,
201            completed_at: Utc::now(),
202        },
203    )
204    .await;
205}
206
207async fn watch_child_liveness(
208    parent_session_id: String,
209    child_session_id: String,
210    runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
211    cancel_token: CancellationToken,
212    timeout_reason: Arc<RwLock<Option<String>>>,
213    done: CancellationToken,
214    policy: ChildWatchdogPolicy,
215) {
216    let mut ticker =
217        tokio::time::interval(Duration::from_secs(policy.check_interval_secs.max(1) as u64));
218    // Skip the immediate tick.
219    ticker.tick().await;
220
221    loop {
222        tokio::select! {
223            _ = done.cancelled() => return,
224            _ = ticker.tick() => {
225                if cancel_token.is_cancelled() {
226                    return;
227                }
228
229                let snapshot = {
230                    let guard = runners.read().await;
231                    guard.get(&child_session_id).cloned()
232                };
233                let Some(runner) = snapshot else {
234                    return;
235                };
236                if !matches!(runner.status, super::runner_state::AgentStatus::Running) {
237                    return;
238                }
239
240                let now = Utc::now();
241                let total_secs = now.signed_duration_since(runner.started_at).num_seconds();
242                if total_secs >= policy.max_total_secs {
243                    let reason = format!(
244                        "Child session timed out after {} seconds (max_total_secs={})",
245                        total_secs, policy.max_total_secs
246                    );
247                    tracing::warn!(
248                        parent_session_id = %parent_session_id,
249                        child_session_id = %child_session_id,
250                        reason = %reason,
251                        "child session total timeout; cancelling child runner"
252                    );
253                    *timeout_reason.write().await = Some(reason);
254                    cancel_token.cancel();
255                    return;
256                }
257
258                let last_activity_at = runner.last_event_at.unwrap_or(runner.started_at);
259                let idle_secs = now.signed_duration_since(last_activity_at).num_seconds();
260                if idle_secs >= policy.max_idle_secs {
261                    let reason = format!(
262                        "Child session idle timeout after {} seconds without events (max_idle_secs={})",
263                        idle_secs, policy.max_idle_secs
264                    );
265                    tracing::warn!(
266                        parent_session_id = %parent_session_id,
267                        child_session_id = %child_session_id,
268                        reason = %reason,
269                        last_tool_name = ?runner.last_tool_name,
270                        last_tool_phase = ?runner.last_tool_phase,
271                        round_count = runner.round_count,
272                        "child session idle timeout; cancelling child runner"
273                    );
274                    *timeout_reason.write().await = Some(reason);
275                    cancel_token.cancel();
276                    return;
277                }
278            }
279        }
280    }
281}
282
283fn resolve_child_provider_override(
284    router: Option<&Arc<ProviderModelRouter>>,
285    session: &Session,
286    model: &str,
287) -> (Option<Arc<dyn LLMProvider>>, Option<String>) {
288    let model_ref = child_model_ref(session, model);
289    let provider_name = model_ref
290        .as_ref()
291        .map(|model_ref| model_ref.provider.clone());
292    let provider = router.and_then(|router| {
293        let model_ref = model_ref.as_ref()?;
294        match router.route(model_ref) {
295            Ok(provider) => Some(provider),
296            Err(error) => {
297                tracing::warn!(
298                    session_id = %session.id,
299                    provider = %model_ref.provider,
300                    model = %model_ref.model,
301                    error = %error,
302                    "failed to resolve provider override for child session; falling back to runtime provider"
303                );
304                None
305            }
306        }
307    });
308    (provider, provider_name)
309}
310
311async fn run_spawn_job(ctx: SpawnContext, job: SpawnJob) -> Result<(), String> {
312    // Ensure both session event streams exist.
313    let parent_tx =
314        get_or_create_event_sender(&ctx.session_event_senders, &job.parent_session_id).await;
315    let child_tx =
316        get_or_create_event_sender(&ctx.session_event_senders, &job.child_session_id).await;
317
318    // Load child session.
319    let mut session = match ctx
320        .agent
321        .storage()
322        .load_session(&job.child_session_id)
323        .await
324    {
325        Ok(Some(s)) => s,
326        Ok(None) => {
327            let error = "child session not found".to_string();
328            publish_child_completion_parts(
329                &parent_tx,
330                ctx.completion_handler.clone(),
331                job.parent_session_id.clone(),
332                job.child_session_id.clone(),
333                "error".to_string(),
334                Some(error.clone()),
335            )
336            .await;
337            return Err(error);
338        }
339        Err(e) => {
340            let error = format!("failed to load child session: {e}");
341            publish_child_completion_parts(
342                &parent_tx,
343                ctx.completion_handler.clone(),
344                job.parent_session_id.clone(),
345                job.child_session_id.clone(),
346                "error".to_string(),
347                Some(error.clone()),
348            )
349            .await;
350            return Err(error);
351        }
352    };
353
354    // Register the child's workspace in the global state so tools
355    // (Read, Glob, Grep, Bash, etc.) can resolve relative paths.
356    if let Some(ref ws) = session.workspace {
357        bamboo_agent_core::workspace_state::set_workspace(
358            &session.id,
359            std::path::PathBuf::from(ws),
360        );
361    }
362
363    if session.kind != SessionKind::Child {
364        let error = "spawn job child session is not kind=child".to_string();
365        publish_child_completion_parts(
366            &parent_tx,
367            ctx.completion_handler.clone(),
368            job.parent_session_id.clone(),
369            job.child_session_id.clone(),
370            "error".to_string(),
371            Some(error.clone()),
372        )
373        .await;
374        return Err(error);
375    }
376
377    // Ensure last message is user (otherwise nothing to do).
378    let last_is_user = session
379        .messages
380        .last()
381        .map(|m| matches!(m.role, Role::User))
382        .unwrap_or(false);
383    if !last_is_user {
384        session
385            .metadata
386            .insert("last_run_status".to_string(), "skipped".to_string());
387        session.metadata.insert(
388            "last_run_error".to_string(),
389            "No pending message to execute".to_string(),
390        );
391        let _ = ctx
392            .agent
393            .persistence()
394            .save_runtime_session(&mut session)
395            .await;
396        {
397            let mut sessions = ctx.sessions_cache.write().await;
398            sessions.insert(job.child_session_id.clone(), session);
399        }
400        publish_child_completion_parts(
401            &parent_tx,
402            ctx.completion_handler.clone(),
403            job.parent_session_id.clone(),
404            job.child_session_id.clone(),
405            "skipped".to_string(),
406            Some("No pending message to execute".to_string()),
407        )
408        .await;
409        return Ok(());
410    }
411
412    // Persist a running marker early so list_sessions can reconstruct status.
413    session
414        .metadata
415        .insert("last_run_status".to_string(), "running".to_string());
416    session.metadata.remove("last_run_error");
417    let _ = ctx
418        .agent
419        .persistence()
420        .save_runtime_session(&mut session)
421        .await;
422
423    // Insert runner status.
424    let Some(RunnerReservation { cancel_token, .. }) =
425        try_reserve_runner(&ctx.agent_runners, &job.child_session_id, &child_tx).await
426    else {
427        return Ok(());
428    };
429
430    // Forward ALL child events to parent.
431    let forwarder_done = CancellationToken::new();
432    {
433        let mut rx = child_tx.subscribe();
434        let parent_tx = parent_tx.clone();
435        let job_clone = job.clone();
436        let done = forwarder_done.clone();
437        tokio::spawn(async move {
438            loop {
439                tokio::select! {
440                    _ = done.cancelled() => break,
441                    evt = rx.recv() => {
442                        match evt {
443                            Ok(event) => {
444                                let _ = parent_tx.send(AgentEvent::SubAgentEvent {
445                                    parent_session_id: job_clone.parent_session_id.clone(),
446                                    child_session_id: job_clone.child_session_id.clone(),
447                                    event: Box::new(event),
448                                });
449                            }
450                            Err(broadcast::error::RecvError::Lagged(_)) => {
451                                continue;
452                            }
453                            Err(_) => break,
454                        }
455                    }
456                }
457            }
458        });
459    }
460    {
461        let parent_tx = parent_tx.clone();
462        let job_clone = job.clone();
463        let done = forwarder_done.clone();
464        tokio::spawn(async move {
465            let mut ticker = tokio::time::interval(Duration::from_secs(5));
466            loop {
467                tokio::select! {
468                    _ = done.cancelled() => break,
469                    _ = ticker.tick() => {
470                        let _ = parent_tx.send(AgentEvent::SubAgentHeartbeat {
471                            parent_session_id: job_clone.parent_session_id.clone(),
472                            child_session_id: job_clone.child_session_id.clone(),
473                            timestamp: Utc::now(),
474                        });
475                    }
476                }
477            }
478        });
479    }
480
481    // Create mpsc channel for agent loop → session events sender.
482    let (mpsc_tx, _forwarder_handle) = create_event_forwarder(
483        job.child_session_id.clone(),
484        child_tx.clone(),
485        ctx.agent_runners.clone(),
486    );
487
488    // Child liveness is owned by the child runner. The parent wait state can
489    // have a longer lease, but it should not poll or terminate children.
490    let timeout_reason = Arc::new(RwLock::new(None::<String>));
491    let watchdog_policy = watchdog_policy_for_session(&session);
492    tokio::spawn(watch_child_liveness(
493        job.parent_session_id.clone(),
494        job.child_session_id.clone(),
495        ctx.agent_runners.clone(),
496        cancel_token.clone(),
497        timeout_reason.clone(),
498        forwarder_done.clone(),
499        watchdog_policy,
500    ));
501
502    // Run child loop via unified spawn_session_execution.
503    let model = job.model.clone();
504    let session_id_clone = job.child_session_id.clone();
505    let agent_runners_for_status = ctx.agent_runners.clone();
506    let sessions_cache = ctx.sessions_cache.clone();
507    let agent = ctx.agent.clone();
508    let tools = ctx.tools.clone();
509    let external_runner = ctx.external_child_runner.clone();
510    let done = forwarder_done.clone();
511    let parent_tx_for_done = parent_tx.clone();
512    let parent_id_for_done = job.parent_session_id.clone();
513    let child_id_for_done = job.child_session_id.clone();
514    let session_event_senders = ctx.session_event_senders.clone();
515    let provider_router = ctx.provider_router.clone();
516    let completion_handler = ctx.completion_handler.clone();
517
518    tokio::spawn(async move {
519        session.model = model.clone();
520
521        let wants_external = session
522            .metadata
523            .get("runtime.kind")
524            .is_some_and(|v| v == "external");
525
526        let result: crate::runtime::runner::Result<()> = if wants_external {
527            if let Some(runner) = external_runner {
528                if runner.should_handle(&session).await {
529                    runner
530                        .execute_external_child(&mut session, &job, mpsc_tx, cancel_token.clone())
531                        .await
532                } else {
533                    Err(bamboo_agent_core::AgentError::LLM(format!(
534                        "No external runner matched child session runtime metadata: agent_id={:?}, protocol={:?}",
535                        session.metadata.get("external.agent_id"),
536                        session.metadata.get("external.protocol"),
537                    )))
538                }
539            } else {
540                Err(bamboo_agent_core::AgentError::LLM(
541                    "Child session requires external runtime, but no external runner is configured"
542                        .to_string(),
543                ))
544            }
545        } else {
546            let (provider_override, provider_name) =
547                resolve_child_provider_override(provider_router.as_ref(), &session, &model);
548            let disabled_tools: Option<std::collections::BTreeSet<String>> =
549                job.disabled_tools.map(|v| v.into_iter().collect());
550            agent
551                .execute(
552                    &mut session,
553                    ExecuteRequest {
554                        initial_message: String::new(), // handled by agent loop
555                        event_tx: mpsc_tx,
556                        cancel_token: cancel_token.clone(),
557                        tools: Some(tools),
558                        provider_override,
559                        model: Some(model.clone()),
560                        provider_name,
561                        background_model: None,
562                        background_model_provider: None,
563                        reasoning_effort: None,
564                        disabled_tools,
565                        disabled_skill_ids: None,
566                        selected_skill_ids: None,
567                        selected_skill_mode: None,
568                        image_fallback: None,
569                        app_data_dir: ctx.app_data_dir.clone(),
570                    },
571                )
572                .await
573        };
574
575        let timeout_error = timeout_reason.read().await.clone();
576        let (status, error) = if let Some(reason) = timeout_error {
577            ("timeout".to_string(), Some(reason))
578        } else {
579            match &result {
580                Ok(_) => ("completed".to_string(), None),
581                Err(e) if e.to_string().contains("cancelled") => {
582                    ("cancelled".to_string(), Some(e.to_string()))
583                }
584                Err(e) => ("error".to_string(), Some(e.to_string())),
585            }
586        };
587
588        finalize_runner(&agent_runners_for_status, &session_id_clone, &result).await;
589
590        // Merge any queued injected messages that the pipeline didn't pick up
591        // (e.g. if the loop exited before the next turn boundary).
592        crate::runtime::runner::state_bridge::merge_pending_injected_messages(
593            &mut session,
594            Some(agent.storage()),
595            Some(agent.persistence()),
596        )
597        .await;
598
599        // Persist final session snapshot.
600        session
601            .metadata
602            .insert("last_run_status".to_string(), status.clone());
603        if let Some(err) = &error {
604            session
605                .metadata
606                .insert("last_run_error".to_string(), err.clone());
607        } else {
608            session.metadata.remove("last_run_error");
609        }
610        let _ = agent.persistence().save_runtime_session(&mut session).await;
611        {
612            let mut sessions = sessions_cache.write().await;
613            sessions.insert(session_id_clone.clone(), session);
614        }
615
616        // Stop forwarding/heartbeats and emit terminal child status through the
617        // same durable completion path used by success/error/cancel/timeout.
618        done.cancel();
619        publish_child_completion_parts(
620            &parent_tx_for_done,
621            completion_handler,
622            parent_id_for_done,
623            child_id_for_done,
624            status,
625            error,
626        )
627        .await;
628
629        // Allow dead code: session_event_senders keeps the sender alive during the task.
630        drop(session_event_senders);
631    });
632
633    Ok(())
634}