1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, mpsc, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::tools::ToolExecutor;
16use bamboo_agent_core::{AgentEvent, Role, Session, SessionKind};
17use bamboo_domain::ProviderModelRef;
18use bamboo_infrastructure::{LLMProvider, ProviderModelRouter};
19
20use crate::runtime::Agent;
21use crate::runtime::ExecuteRequest;
22
23use super::child_completion::{ChildCompletion, ChildCompletionHandler};
24use super::event_forwarder::create_event_forwarder;
25use super::runner_lifecycle::{finalize_runner, try_reserve_runner, RunnerReservation};
26use super::runner_state::AgentRunner;
27use super::session_events::get_or_create_event_sender;
28
/// A request to run one child session on behalf of a parent session.
#[derive(Clone, Debug)]
pub struct SpawnJob {
    /// Id of the session that receives the child's forwarded events and its
    /// completion notification.
    pub parent_session_id: String,
    /// Id of the child session that is loaded from storage and executed.
    pub child_session_id: String,
    /// Model name assigned to the child session before it runs.
    pub model: String,
    /// Optional tool names to disable for a natively executed child run.
    pub disabled_tools: Option<Vec<String>>,
}
38
39#[async_trait::async_trait]
44pub trait ExternalChildRunner: Send + Sync {
45 async fn should_handle(&self, session: &Session) -> bool;
47
48 async fn execute_external_child(
50 &self,
51 session: &mut Session,
52 job: &SpawnJob,
53 event_tx: tokio::sync::mpsc::Sender<AgentEvent>,
54 cancel_token: CancellationToken,
55 ) -> crate::runtime::runner::Result<()>;
56}
57
58#[derive(Clone)]
59pub struct SpawnContext {
60 pub agent: Arc<Agent>,
61 pub tools: Arc<dyn ToolExecutor>,
62 pub sessions_cache: Arc<RwLock<HashMap<String, Session>>>,
63 pub agent_runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
64 pub session_event_senders: Arc<RwLock<HashMap<String, broadcast::Sender<AgentEvent>>>>,
65 pub external_child_runner: Option<Arc<dyn ExternalChildRunner>>,
66 pub provider_router: Option<Arc<ProviderModelRouter>>,
67 pub app_data_dir: Option<std::path::PathBuf>,
68 pub completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
73}
74
75#[derive(Clone)]
76pub struct SpawnScheduler {
77 tx: mpsc::Sender<SpawnJob>,
78}
79
80impl SpawnScheduler {
81 pub fn new(ctx: SpawnContext) -> Self {
82 let (tx, mut rx) = mpsc::channel::<SpawnJob>(128);
83
84 tokio::spawn(async move {
85 while let Some(job) = rx.recv().await {
86 if let Err(err) = run_spawn_job(ctx.clone(), job).await {
87 tracing::warn!("spawn job failed: {}", err);
88 }
89 }
90 });
91
92 Self { tx }
93 }
94
95 pub async fn enqueue(&self, job: SpawnJob) -> Result<(), String> {
96 self.tx
97 .send(job)
98 .await
99 .map_err(|_| "spawn scheduler is not running".to_string())
100 }
101}
102
103fn child_model_ref(session: &Session, model: &str) -> Option<ProviderModelRef> {
104 if let Some(model_ref) = session.model_ref.clone() {
105 let provider = model_ref.provider.trim();
106 let model_name = model_ref.model.trim();
107 if !provider.is_empty() && !model_name.is_empty() {
108 return Some(ProviderModelRef::new(provider, model_name));
109 }
110 }
111
112 let provider = session
113 .metadata
114 .get("provider_name")
115 .map(String::as_str)
116 .map(str::trim)
117 .filter(|value| !value.is_empty())?;
118 let model_name = model.trim();
119 if model_name.is_empty() {
120 return None;
121 }
122 Some(ProviderModelRef::new(provider, model_name))
123}
124
/// Limits applied by the child-session watchdog, all expressed in seconds.
#[derive(Clone, Copy, Debug)]
struct ChildWatchdogPolicy {
    /// Interval between two watchdog checks.
    check_interval_secs: i64,
    /// Longest total run time allowed before the child is cancelled.
    max_total_secs: i64,
    /// Longest period without any event before the child is cancelled.
    max_idle_secs: i64,
}

impl ChildWatchdogPolicy {
    /// Default check interval: 15 seconds.
    const DEFAULT_CHECK_INTERVAL_SECS: i64 = 15;
    /// Default total run time limit: one hour.
    const DEFAULT_MAX_TOTAL_SECS: i64 = 60 * 60;
    /// Default idle limit: fifteen minutes.
    const DEFAULT_MAX_IDLE_SECS: i64 = 15 * 60;
}

impl Default for ChildWatchdogPolicy {
    /// Builds the policy used when session metadata overrides nothing.
    fn default() -> Self {
        ChildWatchdogPolicy {
            check_interval_secs: Self::DEFAULT_CHECK_INTERVAL_SECS,
            max_total_secs: Self::DEFAULT_MAX_TOTAL_SECS,
            max_idle_secs: Self::DEFAULT_MAX_IDLE_SECS,
        }
    }
}
145
146fn metadata_i64(session: &Session, key: &str) -> Option<i64> {
147 session
148 .metadata
149 .get(key)
150 .and_then(|value| value.trim().parse::<i64>().ok())
151 .filter(|value| *value > 0)
152}
153
154fn watchdog_policy_for_session(session: &Session) -> ChildWatchdogPolicy {
155 let mut policy = ChildWatchdogPolicy::default();
156 if let Some(value) = metadata_i64(session, "child_watchdog.max_total_secs") {
157 policy.max_total_secs = value;
158 }
159 if let Some(value) = metadata_i64(session, "child_watchdog.max_idle_secs") {
160 policy.max_idle_secs = value;
161 }
162 if let Some(value) = metadata_i64(session, "child_watchdog.check_interval_secs") {
163 policy.check_interval_secs = value;
164 }
165 policy
166}
167
168async fn publish_child_completion(
169 parent_tx: &broadcast::Sender<AgentEvent>,
170 completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
171 completion: ChildCompletion,
172) {
173 let _ = parent_tx.send(AgentEvent::SubAgentCompleted {
174 parent_session_id: completion.parent_session_id.clone(),
175 child_session_id: completion.child_session_id.clone(),
176 status: completion.status.clone(),
177 error: completion.error.clone(),
178 });
179
180 if let Some(handler) = completion_handler {
181 handler.on_child_completed(completion).await;
182 }
183}
184
185async fn publish_child_completion_parts(
186 parent_tx: &broadcast::Sender<AgentEvent>,
187 completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
188 parent_session_id: String,
189 child_session_id: String,
190 status: String,
191 error: Option<String>,
192) {
193 publish_child_completion(
194 parent_tx,
195 completion_handler,
196 ChildCompletion {
197 parent_session_id,
198 child_session_id,
199 status,
200 error,
201 completed_at: Utc::now(),
202 },
203 )
204 .await;
205}
206
207async fn watch_child_liveness(
208 parent_session_id: String,
209 child_session_id: String,
210 runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
211 cancel_token: CancellationToken,
212 timeout_reason: Arc<RwLock<Option<String>>>,
213 done: CancellationToken,
214 policy: ChildWatchdogPolicy,
215) {
216 let mut ticker =
217 tokio::time::interval(Duration::from_secs(policy.check_interval_secs.max(1) as u64));
218 ticker.tick().await;
220
221 loop {
222 tokio::select! {
223 _ = done.cancelled() => return,
224 _ = ticker.tick() => {
225 if cancel_token.is_cancelled() {
226 return;
227 }
228
229 let snapshot = {
230 let guard = runners.read().await;
231 guard.get(&child_session_id).cloned()
232 };
233 let Some(runner) = snapshot else {
234 return;
235 };
236 if !matches!(runner.status, super::runner_state::AgentStatus::Running) {
237 return;
238 }
239
240 let now = Utc::now();
241 let total_secs = now.signed_duration_since(runner.started_at).num_seconds();
242 if total_secs >= policy.max_total_secs {
243 let reason = format!(
244 "Child session timed out after {} seconds (max_total_secs={})",
245 total_secs, policy.max_total_secs
246 );
247 tracing::warn!(
248 parent_session_id = %parent_session_id,
249 child_session_id = %child_session_id,
250 reason = %reason,
251 "child session total timeout; cancelling child runner"
252 );
253 *timeout_reason.write().await = Some(reason);
254 cancel_token.cancel();
255 return;
256 }
257
258 let last_activity_at = runner.last_event_at.unwrap_or(runner.started_at);
259 let idle_secs = now.signed_duration_since(last_activity_at).num_seconds();
260 if idle_secs >= policy.max_idle_secs {
261 let reason = format!(
262 "Child session idle timeout after {} seconds without events (max_idle_secs={})",
263 idle_secs, policy.max_idle_secs
264 );
265 tracing::warn!(
266 parent_session_id = %parent_session_id,
267 child_session_id = %child_session_id,
268 reason = %reason,
269 last_tool_name = ?runner.last_tool_name,
270 last_tool_phase = ?runner.last_tool_phase,
271 round_count = runner.round_count,
272 "child session idle timeout; cancelling child runner"
273 );
274 *timeout_reason.write().await = Some(reason);
275 cancel_token.cancel();
276 return;
277 }
278 }
279 }
280 }
281}
282
283fn resolve_child_provider_override(
284 router: Option<&Arc<ProviderModelRouter>>,
285 session: &Session,
286 model: &str,
287) -> (Option<Arc<dyn LLMProvider>>, Option<String>, Option<String>) {
288 let model_ref = child_model_ref(session, model);
289 let provider_name = model_ref
290 .as_ref()
291 .map(|model_ref| model_ref.provider.clone());
292 let provider_type = if let (Some(router), Some(model_ref)) = (router, model_ref.as_ref()) {
293 router.provider_type_for(model_ref)
294 } else {
295 provider_name.clone()
296 };
297 let provider = router.and_then(|router| {
298 let model_ref = model_ref.as_ref()?;
299 match router.route(model_ref) {
300 Ok(provider) => Some(provider),
301 Err(error) => {
302 tracing::warn!(
303 session_id = %session.id,
304 provider = %model_ref.provider,
305 model = %model_ref.model,
306 error = %error,
307 "failed to resolve provider override for child session; falling back to runtime provider"
308 );
309 None
310 }
311 }
312 });
313 (provider, provider_name, provider_type)
314}
315
316async fn run_spawn_job(ctx: SpawnContext, job: SpawnJob) -> Result<(), String> {
317 let parent_tx =
319 get_or_create_event_sender(&ctx.session_event_senders, &job.parent_session_id).await;
320 let child_tx =
321 get_or_create_event_sender(&ctx.session_event_senders, &job.child_session_id).await;
322
323 let mut session = match ctx
325 .agent
326 .storage()
327 .load_session(&job.child_session_id)
328 .await
329 {
330 Ok(Some(s)) => s,
331 Ok(None) => {
332 let error = "child session not found".to_string();
333 publish_child_completion_parts(
334 &parent_tx,
335 ctx.completion_handler.clone(),
336 job.parent_session_id.clone(),
337 job.child_session_id.clone(),
338 "error".to_string(),
339 Some(error.clone()),
340 )
341 .await;
342 return Err(error);
343 }
344 Err(e) => {
345 let error = format!("failed to load child session: {e}");
346 publish_child_completion_parts(
347 &parent_tx,
348 ctx.completion_handler.clone(),
349 job.parent_session_id.clone(),
350 job.child_session_id.clone(),
351 "error".to_string(),
352 Some(error.clone()),
353 )
354 .await;
355 return Err(error);
356 }
357 };
358
359 if let Some(ref ws) = session.workspace {
362 bamboo_agent_core::workspace_state::set_workspace(
363 &session.id,
364 std::path::PathBuf::from(ws),
365 );
366 }
367
368 if session.kind != SessionKind::Child {
369 let error = "spawn job child session is not kind=child".to_string();
370 publish_child_completion_parts(
371 &parent_tx,
372 ctx.completion_handler.clone(),
373 job.parent_session_id.clone(),
374 job.child_session_id.clone(),
375 "error".to_string(),
376 Some(error.clone()),
377 )
378 .await;
379 return Err(error);
380 }
381
382 let last_is_user = session
384 .messages
385 .last()
386 .map(|m| matches!(m.role, Role::User))
387 .unwrap_or(false);
388 if !last_is_user {
389 session
390 .metadata
391 .insert("last_run_status".to_string(), "skipped".to_string());
392 session.metadata.insert(
393 "last_run_error".to_string(),
394 "No pending message to execute".to_string(),
395 );
396 let _ = ctx
397 .agent
398 .persistence()
399 .save_runtime_session(&mut session)
400 .await;
401 {
402 let mut sessions = ctx.sessions_cache.write().await;
403 sessions.insert(job.child_session_id.clone(), session);
404 }
405 publish_child_completion_parts(
406 &parent_tx,
407 ctx.completion_handler.clone(),
408 job.parent_session_id.clone(),
409 job.child_session_id.clone(),
410 "skipped".to_string(),
411 Some("No pending message to execute".to_string()),
412 )
413 .await;
414 return Ok(());
415 }
416
417 session
419 .metadata
420 .insert("last_run_status".to_string(), "running".to_string());
421 session.metadata.remove("last_run_error");
422 let _ = ctx
423 .agent
424 .persistence()
425 .save_runtime_session(&mut session)
426 .await;
427
428 let Some(RunnerReservation { cancel_token, .. }) =
430 try_reserve_runner(&ctx.agent_runners, &job.child_session_id, &child_tx).await
431 else {
432 return Ok(());
433 };
434
435 let forwarder_done = CancellationToken::new();
437 {
438 let mut rx = child_tx.subscribe();
439 let parent_tx = parent_tx.clone();
440 let job_clone = job.clone();
441 let done = forwarder_done.clone();
442 tokio::spawn(async move {
443 loop {
444 tokio::select! {
445 _ = done.cancelled() => break,
446 evt = rx.recv() => {
447 match evt {
448 Ok(event) => {
449 let _ = parent_tx.send(AgentEvent::SubAgentEvent {
450 parent_session_id: job_clone.parent_session_id.clone(),
451 child_session_id: job_clone.child_session_id.clone(),
452 event: Box::new(event),
453 });
454 }
455 Err(broadcast::error::RecvError::Lagged(_)) => {
456 continue;
457 }
458 Err(_) => break,
459 }
460 }
461 }
462 }
463 });
464 }
465 {
466 let parent_tx = parent_tx.clone();
467 let job_clone = job.clone();
468 let done = forwarder_done.clone();
469 tokio::spawn(async move {
470 let mut ticker = tokio::time::interval(Duration::from_secs(5));
471 loop {
472 tokio::select! {
473 _ = done.cancelled() => break,
474 _ = ticker.tick() => {
475 let _ = parent_tx.send(AgentEvent::SubAgentHeartbeat {
476 parent_session_id: job_clone.parent_session_id.clone(),
477 child_session_id: job_clone.child_session_id.clone(),
478 timestamp: Utc::now(),
479 });
480 }
481 }
482 }
483 });
484 }
485
486 let (mpsc_tx, _forwarder_handle) = create_event_forwarder(
488 job.child_session_id.clone(),
489 child_tx.clone(),
490 ctx.agent_runners.clone(),
491 );
492
493 let timeout_reason = Arc::new(RwLock::new(None::<String>));
496 let watchdog_policy = watchdog_policy_for_session(&session);
497 tokio::spawn(watch_child_liveness(
498 job.parent_session_id.clone(),
499 job.child_session_id.clone(),
500 ctx.agent_runners.clone(),
501 cancel_token.clone(),
502 timeout_reason.clone(),
503 forwarder_done.clone(),
504 watchdog_policy,
505 ));
506
507 let model = job.model.clone();
509 let session_id_clone = job.child_session_id.clone();
510 let agent_runners_for_status = ctx.agent_runners.clone();
511 let sessions_cache = ctx.sessions_cache.clone();
512 let agent = ctx.agent.clone();
513 let tools = ctx.tools.clone();
514 let external_runner = ctx.external_child_runner.clone();
515 let done = forwarder_done.clone();
516 let parent_tx_for_done = parent_tx.clone();
517 let parent_id_for_done = job.parent_session_id.clone();
518 let child_id_for_done = job.child_session_id.clone();
519 let session_event_senders = ctx.session_event_senders.clone();
520 let provider_router = ctx.provider_router.clone();
521 let completion_handler = ctx.completion_handler.clone();
522
523 tokio::spawn(async move {
524 session.model = model.clone();
525
526 let wants_external = session
527 .metadata
528 .get("runtime.kind")
529 .is_some_and(|v| v == "external");
530
531 let result: crate::runtime::runner::Result<()> = if wants_external {
532 if let Some(runner) = external_runner {
533 if runner.should_handle(&session).await {
534 runner
535 .execute_external_child(&mut session, &job, mpsc_tx, cancel_token.clone())
536 .await
537 } else {
538 Err(bamboo_agent_core::AgentError::LLM(format!(
539 "No external runner matched child session runtime metadata: agent_id={:?}, protocol={:?}",
540 session.metadata.get("external.agent_id"),
541 session.metadata.get("external.protocol"),
542 )))
543 }
544 } else {
545 Err(bamboo_agent_core::AgentError::LLM(
546 "Child session requires external runtime, but no external runner is configured"
547 .to_string(),
548 ))
549 }
550 } else {
551 let (provider_override, provider_name, provider_type) =
552 resolve_child_provider_override(provider_router.as_ref(), &session, &model);
553 let disabled_tools: Option<std::collections::BTreeSet<String>> =
554 job.disabled_tools.map(|v| v.into_iter().collect());
555 agent
556 .execute(
557 &mut session,
558 ExecuteRequest {
559 initial_message: String::new(), event_tx: mpsc_tx,
561 cancel_token: cancel_token.clone(),
562 tools: Some(tools),
563 provider_override,
564 model: Some(model.clone()),
565 provider_name,
566 provider_type,
567 fast_model: None,
568 fast_model_provider: None,
569 background_model: None,
570 background_model_provider: None,
571 summarization_model: None,
572 summarization_model_provider: None,
573 reasoning_effort: None,
574 disabled_tools,
575 disabled_skill_ids: None,
576 selected_skill_ids: None,
577 selected_skill_mode: None,
578 image_fallback: None,
579 app_data_dir: ctx.app_data_dir.clone(),
580 },
581 )
582 .await
583 };
584
585 let timeout_error = timeout_reason.read().await.clone();
586 let (status, error) = if let Some(reason) = timeout_error {
587 ("timeout".to_string(), Some(reason))
588 } else {
589 match &result {
590 Ok(_) => ("completed".to_string(), None),
591 Err(e) if e.to_string().contains("cancelled") => {
592 ("cancelled".to_string(), Some(e.to_string()))
593 }
594 Err(e) => ("error".to_string(), Some(e.to_string())),
595 }
596 };
597
598 finalize_runner(&agent_runners_for_status, &session_id_clone, &result).await;
599
600 crate::runtime::runner::state_bridge::merge_pending_injected_messages(
603 &mut session,
604 Some(agent.storage()),
605 Some(agent.persistence()),
606 )
607 .await;
608
609 session
611 .metadata
612 .insert("last_run_status".to_string(), status.clone());
613 if let Some(err) = &error {
614 session
615 .metadata
616 .insert("last_run_error".to_string(), err.clone());
617 } else {
618 session.metadata.remove("last_run_error");
619 }
620 let _ = agent.persistence().save_runtime_session(&mut session).await;
621 {
622 let mut sessions = sessions_cache.write().await;
623 sessions.insert(session_id_clone.clone(), session);
624 }
625
626 done.cancel();
629 publish_child_completion_parts(
630 &parent_tx_for_done,
631 completion_handler,
632 parent_id_for_done,
633 child_id_for_done,
634 status,
635 error,
636 )
637 .await;
638
639 drop(session_event_senders);
641 });
642
643 Ok(())
644}