// bamboo_engine/runtime/execution/spawn.rs

1//! Sub-session spawn scheduler.
2//!
3//! Provides a background queue for spawning child sessions. Spawn is async
4//! (tool returns immediately), but the UI can observe child progress via
5//! events forwarded to the parent session stream.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, mpsc, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::tools::ToolExecutor;
16use bamboo_agent_core::{AgentEvent, Role, Session, SessionKind};
17use bamboo_domain::ProviderModelRef;
18use bamboo_infrastructure::{LLMProvider, ProviderModelRouter};
19
20use crate::runtime::Agent;
21use crate::runtime::ExecuteRequest;
22
23use super::child_completion::{ChildCompletion, ChildCompletionHandler};
24use super::event_forwarder::create_event_forwarder;
25use super::runner_lifecycle::{finalize_runner, try_reserve_runner, RunnerReservation};
26use super::runner_state::AgentRunner;
27use super::session_events::get_or_create_event_sender;
28
/// A queued request to execute one child session on behalf of a parent.
#[derive(Debug, Clone)]
pub struct SpawnJob {
    // Session id of the parent that spawned (and observes) the child.
    pub parent_session_id: String,
    // Session id of the child to execute.
    pub child_session_id: String,
    // Model name requested for the child run.
    pub model: String,
}
35
/// Trait for external child session runtimes (e.g. A2A, CLI adapters).
///
/// Implementors are responsible for emitting AgentEvents via `event_tx`
/// and respecting the `cancel_token`.
#[async_trait::async_trait]
pub trait ExternalChildRunner: Send + Sync {
    /// Returns true if this runner should handle the given child session.
    async fn should_handle(&self, session: &Session) -> bool;

    /// Execute the child session using an external runtime.
    ///
    /// `session` is mutated in place by the runner; `job` carries the
    /// parent/child session ids and the requested model. Progress events
    /// must be sent through `event_tx`, and the runner is expected to stop
    /// promptly once `cancel_token` fires.
    async fn execute_external_child(
        &self,
        session: &mut Session,
        job: &SpawnJob,
        event_tx: tokio::sync::mpsc::Sender<AgentEvent>,
        cancel_token: CancellationToken,
    ) -> crate::runtime::runner::Result<()>;
}
54
/// Shared dependencies the spawn worker needs to execute child sessions.
#[derive(Clone)]
pub struct SpawnContext {
    /// Agent runtime used to run the child loop and access session storage.
    pub agent: Arc<Agent>,
    /// Tool executor handed to the child's agent loop.
    pub tools: Arc<dyn ToolExecutor>,
    /// In-memory session cache; updated with the final child snapshot.
    pub sessions_cache: Arc<RwLock<HashMap<String, Session>>>,
    /// Per-session runner state (status, timing, last activity).
    pub agent_runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
    /// Per-session broadcast senders used to stream AgentEvents.
    pub session_event_senders: Arc<RwLock<HashMap<String, broadcast::Sender<AgentEvent>>>>,
    /// Optional runtime for children whose metadata requests an external runner.
    pub external_child_runner: Option<Arc<dyn ExternalChildRunner>>,
    /// Optional router used to resolve a provider override for the child model.
    pub provider_router: Option<Arc<ProviderModelRouter>>,
    /// Optional application-layer completion hook. The engine still emits
    /// `SubSessionCompleted` to the parent stream itself; this hook lets the
    /// server persist parent wait state and resume the parent runner without
    /// introducing an engine -> AppState dependency.
    pub completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
}
70
/// Handle for enqueueing [`SpawnJob`]s onto the background spawn worker.
///
/// Cloning is cheap; all clones feed the same job queue.
#[derive(Clone)]
pub struct SpawnScheduler {
    // Sender side of the worker's job queue; a send error means the worker
    // task has stopped.
    tx: mpsc::Sender<SpawnJob>,
}
75
76impl SpawnScheduler {
77    pub fn new(ctx: SpawnContext) -> Self {
78        let (tx, mut rx) = mpsc::channel::<SpawnJob>(128);
79
80        tokio::spawn(async move {
81            while let Some(job) = rx.recv().await {
82                if let Err(err) = run_spawn_job(ctx.clone(), job).await {
83                    tracing::warn!("spawn job failed: {}", err);
84                }
85            }
86        });
87
88        Self { tx }
89    }
90
91    pub async fn enqueue(&self, job: SpawnJob) -> Result<(), String> {
92        self.tx
93            .send(job)
94            .await
95            .map_err(|_| "spawn scheduler is not running".to_string())
96    }
97}
98
99fn child_model_ref(session: &Session, model: &str) -> Option<ProviderModelRef> {
100    if let Some(model_ref) = session.model_ref.clone() {
101        let provider = model_ref.provider.trim();
102        let model_name = model_ref.model.trim();
103        if !provider.is_empty() && !model_name.is_empty() {
104            return Some(ProviderModelRef::new(provider, model_name));
105        }
106    }
107
108    let provider = session
109        .metadata
110        .get("provider_name")
111        .map(String::as_str)
112        .map(str::trim)
113        .filter(|value| !value.is_empty())?;
114    let model_name = model.trim();
115    if model_name.is_empty() {
116        return None;
117    }
118    Some(ProviderModelRef::new(provider, model_name))
119}
120
/// Liveness limits for a running child session, all in seconds.
///
/// Defaults may be overridden per session via the `child_watchdog.*`
/// metadata keys (see `watchdog_policy_for_session`).
#[derive(Debug, Clone, Copy)]
struct ChildWatchdogPolicy {
    // How often the watchdog wakes up to inspect the runner.
    check_interval_secs: i64,
    // Hard cap on total child runtime.
    max_total_secs: i64,
    // Cap on time without any observed child event.
    max_idle_secs: i64,
}

impl Default for ChildWatchdogPolicy {
    fn default() -> Self {
        // Parent waits may be longer, but child execution owns its own
        // liveness: a one-hour total cap (3600 s) avoids indefinitely
        // orphaned sub-session runners, and fifteen minutes (900 s) without
        // a child event is considered stalled.
        Self {
            check_interval_secs: 15,
            max_total_secs: 3_600,
            max_idle_secs: 900,
        }
    }
}
141
142fn metadata_i64(session: &Session, key: &str) -> Option<i64> {
143    session
144        .metadata
145        .get(key)
146        .and_then(|value| value.trim().parse::<i64>().ok())
147        .filter(|value| *value > 0)
148}
149
150fn watchdog_policy_for_session(session: &Session) -> ChildWatchdogPolicy {
151    let mut policy = ChildWatchdogPolicy::default();
152    if let Some(value) = metadata_i64(session, "child_watchdog.max_total_secs") {
153        policy.max_total_secs = value;
154    }
155    if let Some(value) = metadata_i64(session, "child_watchdog.max_idle_secs") {
156        policy.max_idle_secs = value;
157    }
158    if let Some(value) = metadata_i64(session, "child_watchdog.check_interval_secs") {
159        policy.check_interval_secs = value;
160    }
161    policy
162}
163
164async fn publish_child_completion(
165    parent_tx: &broadcast::Sender<AgentEvent>,
166    completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
167    completion: ChildCompletion,
168) {
169    let _ = parent_tx.send(AgentEvent::SubSessionCompleted {
170        parent_session_id: completion.parent_session_id.clone(),
171        child_session_id: completion.child_session_id.clone(),
172        status: completion.status.clone(),
173        error: completion.error.clone(),
174    });
175
176    if let Some(handler) = completion_handler {
177        handler.on_child_completed(completion).await;
178    }
179}
180
181async fn publish_child_completion_parts(
182    parent_tx: &broadcast::Sender<AgentEvent>,
183    completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
184    parent_session_id: String,
185    child_session_id: String,
186    status: String,
187    error: Option<String>,
188) {
189    publish_child_completion(
190        parent_tx,
191        completion_handler,
192        ChildCompletion {
193            parent_session_id,
194            child_session_id,
195            status,
196            error,
197            completed_at: Utc::now(),
198        },
199    )
200    .await;
201}
202
/// Background watchdog for a running child session.
///
/// Wakes up every `policy.check_interval_secs` and inspects the child's
/// runner entry. When either the total runtime or the idle time (no events
/// observed on the runner) exceeds the policy limits, it records a
/// human-readable reason in `timeout_reason` and fires `cancel_token` so the
/// completion path can report `"timeout"` instead of a generic cancellation.
///
/// The task exits when any of these happens first: `done` fires (child
/// finished normally), the token is already cancelled, the runner entry has
/// disappeared, the runner is no longer `Running`, or a timeout triggers.
async fn watch_child_liveness(
    parent_session_id: String,
    child_session_id: String,
    runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
    cancel_token: CancellationToken,
    timeout_reason: Arc<RwLock<Option<String>>>,
    done: CancellationToken,
    policy: ChildWatchdogPolicy,
) {
    // `.max(1)` guards the i64 -> u64 cast against a non-positive interval.
    let mut ticker =
        tokio::time::interval(Duration::from_secs(policy.check_interval_secs.max(1) as u64));
    // Skip the immediate tick.
    ticker.tick().await;

    loop {
        tokio::select! {
            _ = done.cancelled() => return,
            _ = ticker.tick() => {
                if cancel_token.is_cancelled() {
                    return;
                }

                // Clone the runner entry so the read lock is released before
                // the timing checks (and tracing) below.
                let snapshot = {
                    let guard = runners.read().await;
                    guard.get(&child_session_id).cloned()
                };
                let Some(runner) = snapshot else {
                    return;
                };
                if !matches!(runner.status, super::runner_state::AgentStatus::Running) {
                    return;
                }

                // Total wall-clock cap, measured from runner start.
                let now = Utc::now();
                let total_secs = now.signed_duration_since(runner.started_at).num_seconds();
                if total_secs >= policy.max_total_secs {
                    let reason = format!(
                        "Child session timed out after {} seconds (max_total_secs={})",
                        total_secs, policy.max_total_secs
                    );
                    tracing::warn!(
                        parent_session_id = %parent_session_id,
                        child_session_id = %child_session_id,
                        reason = %reason,
                        "child session total timeout; cancelling child runner"
                    );
                    // Record the reason BEFORE cancelling so the completion
                    // path observes it when the runner unwinds.
                    *timeout_reason.write().await = Some(reason);
                    cancel_token.cancel();
                    return;
                }

                // Idle cap: time since the last observed event, falling back
                // to the runner start time when no event was seen yet.
                let last_activity_at = runner.last_event_at.unwrap_or(runner.started_at);
                let idle_secs = now.signed_duration_since(last_activity_at).num_seconds();
                if idle_secs >= policy.max_idle_secs {
                    let reason = format!(
                        "Child session idle timeout after {} seconds without events (max_idle_secs={})",
                        idle_secs, policy.max_idle_secs
                    );
                    tracing::warn!(
                        parent_session_id = %parent_session_id,
                        child_session_id = %child_session_id,
                        reason = %reason,
                        last_tool_name = ?runner.last_tool_name,
                        last_tool_phase = ?runner.last_tool_phase,
                        round_count = runner.round_count,
                        "child session idle timeout; cancelling child runner"
                    );
                    *timeout_reason.write().await = Some(reason);
                    cancel_token.cancel();
                    return;
                }
            }
        }
    }
}
278
279fn resolve_child_provider_override(
280    router: Option<&Arc<ProviderModelRouter>>,
281    session: &Session,
282    model: &str,
283) -> (Option<Arc<dyn LLMProvider>>, Option<String>) {
284    let model_ref = child_model_ref(session, model);
285    let provider_name = model_ref
286        .as_ref()
287        .map(|model_ref| model_ref.provider.clone());
288    let provider = router.and_then(|router| {
289        let model_ref = model_ref.as_ref()?;
290        match router.route(model_ref) {
291            Ok(provider) => Some(provider),
292            Err(error) => {
293                tracing::warn!(
294                    session_id = %session.id,
295                    provider = %model_ref.provider,
296                    model = %model_ref.model,
297                    error = %error,
298                    "failed to resolve provider override for child session; falling back to runtime provider"
299                );
300                None
301            }
302        }
303    });
304    (provider, provider_name)
305}
306
/// Execute one spawn job end to end: load and validate the child session,
/// wire child -> parent event forwarding and heartbeats, run the child (via
/// an external runtime or the built-in agent loop), and publish a terminal
/// completion status ("completed"/"error"/"cancelled"/"timeout"/"skipped")
/// to the parent.
///
/// `Err` is only returned for pre-execution failures (missing or invalid
/// child session). Once a runner slot is reserved, the actual execution runs
/// in a detached task and this function returns `Ok(())` immediately.
async fn run_spawn_job(ctx: SpawnContext, job: SpawnJob) -> Result<(), String> {
    // Ensure both session event streams exist.
    let parent_tx =
        get_or_create_event_sender(&ctx.session_event_senders, &job.parent_session_id).await;
    let child_tx =
        get_or_create_event_sender(&ctx.session_event_senders, &job.child_session_id).await;

    // Load child session. Both failure shapes (absent and storage error)
    // still publish an "error" completion so the parent is never left waiting.
    let mut session = match ctx
        .agent
        .storage()
        .load_session(&job.child_session_id)
        .await
    {
        Ok(Some(s)) => s,
        Ok(None) => {
            let error = "child session not found".to_string();
            publish_child_completion_parts(
                &parent_tx,
                ctx.completion_handler.clone(),
                job.parent_session_id.clone(),
                job.child_session_id.clone(),
                "error".to_string(),
                Some(error.clone()),
            )
            .await;
            return Err(error);
        }
        Err(e) => {
            let error = format!("failed to load child session: {e}");
            publish_child_completion_parts(
                &parent_tx,
                ctx.completion_handler.clone(),
                job.parent_session_id.clone(),
                job.child_session_id.clone(),
                "error".to_string(),
                Some(error.clone()),
            )
            .await;
            return Err(error);
        }
    };

    // Refuse to run non-child sessions through the spawn path.
    if session.kind != SessionKind::Child {
        let error = "spawn job child session is not kind=child".to_string();
        publish_child_completion_parts(
            &parent_tx,
            ctx.completion_handler.clone(),
            job.parent_session_id.clone(),
            job.child_session_id.clone(),
            "error".to_string(),
            Some(error.clone()),
        )
        .await;
        return Err(error);
    }

    // Ensure last message is user (otherwise nothing to do). A child with no
    // pending user message completes as "skipped" rather than erroring.
    let last_is_user = session
        .messages
        .last()
        .map(|m| matches!(m.role, Role::User))
        .unwrap_or(false);
    if !last_is_user {
        session
            .metadata
            .insert("last_run_status".to_string(), "skipped".to_string());
        session.metadata.insert(
            "last_run_error".to_string(),
            "No pending message to execute".to_string(),
        );
        let _ = ctx.agent.storage().save_session(&session).await;
        {
            let mut sessions = ctx.sessions_cache.write().await;
            sessions.insert(job.child_session_id.clone(), session);
        }
        publish_child_completion_parts(
            &parent_tx,
            ctx.completion_handler.clone(),
            job.parent_session_id.clone(),
            job.child_session_id.clone(),
            "skipped".to_string(),
            Some("No pending message to execute".to_string()),
        )
        .await;
        return Ok(());
    }

    // Persist a running marker early so list_sessions can reconstruct status.
    session
        .metadata
        .insert("last_run_status".to_string(), "running".to_string());
    session.metadata.remove("last_run_error");
    let _ = ctx.agent.storage().save_session(&session).await;

    // Insert runner status. A failed reservation means another runner is
    // already active for this session; bail out silently without completing.
    let Some(RunnerReservation { cancel_token, .. }) =
        try_reserve_runner(&ctx.agent_runners, &job.child_session_id, &child_tx).await
    else {
        return Ok(());
    };

    // Forward ALL child events to parent, wrapped as SubSessionEvent so the
    // parent UI can attribute them. Lagged receives are skipped (events
    // dropped by the broadcast ring), any other error ends the forwarder.
    let forwarder_done = CancellationToken::new();
    {
        let mut rx = child_tx.subscribe();
        let parent_tx = parent_tx.clone();
        let job_clone = job.clone();
        let done = forwarder_done.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = done.cancelled() => break,
                    evt = rx.recv() => {
                        match evt {
                            Ok(event) => {
                                let _ = parent_tx.send(AgentEvent::SubSessionEvent {
                                    parent_session_id: job_clone.parent_session_id.clone(),
                                    child_session_id: job_clone.child_session_id.clone(),
                                    event: Box::new(event),
                                });
                            }
                            Err(broadcast::error::RecvError::Lagged(_)) => {
                                continue;
                            }
                            Err(_) => break,
                        }
                    }
                }
            }
        });
    }
    // Heartbeat task: a SubSessionHeartbeat every 5 seconds until `done`
    // fires. Note `tokio::time::interval`'s first tick completes immediately,
    // so one heartbeat is emitted as soon as forwarding starts.
    {
        let parent_tx = parent_tx.clone();
        let job_clone = job.clone();
        let done = forwarder_done.clone();
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(Duration::from_secs(5));
            loop {
                tokio::select! {
                    _ = done.cancelled() => break,
                    _ = ticker.tick() => {
                        let _ = parent_tx.send(AgentEvent::SubSessionHeartbeat {
                            parent_session_id: job_clone.parent_session_id.clone(),
                            child_session_id: job_clone.child_session_id.clone(),
                            timestamp: Utc::now(),
                        });
                    }
                }
            }
        });
    }

    // Create mpsc channel for agent loop → session events sender.
    let (mpsc_tx, _forwarder_handle) = create_event_forwarder(
        job.child_session_id.clone(),
        child_tx.clone(),
        ctx.agent_runners.clone(),
    );

    // Child liveness is owned by the child runner. The parent wait state can
    // have a longer lease, but it should not poll or terminate children.
    // `forwarder_done` doubles as the watchdog's stop signal.
    let timeout_reason = Arc::new(RwLock::new(None::<String>));
    let watchdog_policy = watchdog_policy_for_session(&session);
    tokio::spawn(watch_child_liveness(
        job.parent_session_id.clone(),
        job.child_session_id.clone(),
        ctx.agent_runners.clone(),
        cancel_token.clone(),
        timeout_reason.clone(),
        forwarder_done.clone(),
        watchdog_policy,
    ));

    // Run child loop via unified spawn_session_execution.
    let model = job.model.clone();
    let session_id_clone = job.child_session_id.clone();
    let agent_runners_for_status = ctx.agent_runners.clone();
    let sessions_cache = ctx.sessions_cache.clone();
    let agent = ctx.agent.clone();
    let tools = ctx.tools.clone();
    let external_runner = ctx.external_child_runner.clone();
    let done = forwarder_done.clone();
    let parent_tx_for_done = parent_tx.clone();
    let parent_id_for_done = job.parent_session_id.clone();
    let child_id_for_done = job.child_session_id.clone();
    let session_event_senders = ctx.session_event_senders.clone();
    let provider_router = ctx.provider_router.clone();
    let completion_handler = ctx.completion_handler.clone();

    tokio::spawn(async move {
        session.model = model.clone();

        // Children tagged runtime.kind=external are routed to the configured
        // ExternalChildRunner instead of the built-in agent loop.
        let wants_external = session
            .metadata
            .get("runtime.kind")
            .is_some_and(|v| v == "external");

        let result: crate::runtime::runner::Result<()> = if wants_external {
            if let Some(runner) = external_runner {
                if runner.should_handle(&session).await {
                    runner
                        .execute_external_child(&mut session, &job, mpsc_tx, cancel_token.clone())
                        .await
                } else {
                    Err(bamboo_agent_core::AgentError::LLM(format!(
                        "No external runner matched child session runtime metadata: agent_id={:?}, protocol={:?}",
                        session.metadata.get("external.agent_id"),
                        session.metadata.get("external.protocol"),
                    )))
                }
            } else {
                Err(bamboo_agent_core::AgentError::LLM(
                    "Child session requires external runtime, but no external runner is configured"
                        .to_string(),
                ))
            }
        } else {
            let (provider_override, provider_name) =
                resolve_child_provider_override(provider_router.as_ref(), &session, &model);
            agent
                .execute(
                    &mut session,
                    ExecuteRequest {
                        initial_message: String::new(), // handled by agent loop
                        event_tx: mpsc_tx,
                        cancel_token: cancel_token.clone(),
                        tools: Some(tools),
                        provider_override,
                        model: Some(model.clone()),
                        provider_name,
                        background_model: None,
                        background_model_provider: None,
                        reasoning_effort: None,
                        disabled_tools: None,
                        disabled_skill_ids: None,
                        selected_skill_ids: None,
                        selected_skill_mode: None,
                        image_fallback: None,
                    },
                )
                .await
        };

        // A recorded watchdog reason takes precedence: a timeout surfaces as
        // a cancellation error from the loop, so check the reason first.
        let timeout_error = timeout_reason.read().await.clone();
        let (status, error) = if let Some(reason) = timeout_error {
            ("timeout".to_string(), Some(reason))
        } else {
            match &result {
                Ok(_) => ("completed".to_string(), None),
                Err(e) if e.to_string().contains("cancelled") => {
                    ("cancelled".to_string(), Some(e.to_string()))
                }
                Err(e) => ("error".to_string(), Some(e.to_string())),
            }
        };

        finalize_runner(&agent_runners_for_status, &session_id_clone, &result).await;

        // Merge any queued injected messages that the pipeline didn't pick up
        // (e.g. if the loop exited before the next turn boundary).
        // NOTE(review): the pending messages are read from the freshly loaded
        // `latest` snapshot but appended to the in-memory `session`, which is
        // saved below — confirm this cannot overwrite messages the run itself
        // persisted after `session` was captured.
        if let Ok(Some(latest)) = agent.storage().load_session(&session_id_clone).await {
            if let Some(raw) = latest.metadata.get("pending_injected_messages") {
                if let Ok(messages) = serde_json::from_str::<Vec<serde_json::Value>>(raw) {
                    for msg in messages {
                        if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
                            session
                                .add_message(bamboo_agent_core::Message::user(content.to_string()));
                        }
                    }
                    session.metadata.remove("pending_injected_messages");
                }
            }
        }

        // Persist final session snapshot.
        session
            .metadata
            .insert("last_run_status".to_string(), status.clone());
        if let Some(err) = &error {
            session
                .metadata
                .insert("last_run_error".to_string(), err.clone());
        } else {
            session.metadata.remove("last_run_error");
        }
        let _ = agent.storage().save_session(&session).await;
        {
            let mut sessions = sessions_cache.write().await;
            sessions.insert(session_id_clone.clone(), session);
        }

        // Stop forwarding/heartbeats and emit terminal child status through the
        // same durable completion path used by success/error/cancel/timeout.
        done.cancel();
        publish_child_completion_parts(
            &parent_tx_for_done,
            completion_handler,
            parent_id_for_done,
            child_id_for_done,
            status,
            error,
        )
        .await;

        // Held (then dropped) here so the senders map clone stays alive for
        // the duration of the task, keeping the event channels registered.
        drop(session_event_senders);
    });

    Ok(())
}