1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, mpsc, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::tools::ToolExecutor;
16use bamboo_agent_core::{AgentEvent, Role, Session, SessionKind};
17use bamboo_domain::ProviderModelRef;
18use bamboo_infrastructure::{LLMProvider, ProviderModelRouter};
19
20use crate::runtime::Agent;
21use crate::runtime::ExecuteRequest;
22
23use super::child_completion::{ChildCompletion, ChildCompletionHandler};
24use super::event_forwarder::create_event_forwarder;
25use super::runner_lifecycle::{finalize_runner, try_reserve_runner, RunnerReservation};
26use super::runner_state::AgentRunner;
27use super::session_events::get_or_create_event_sender;
28
/// A queued request to run one child session on behalf of a parent session.
#[derive(Debug, Clone)]
pub struct SpawnJob {
    /// Session whose event stream receives the child's forwarded events,
    /// heartbeats and completion notice.
    pub parent_session_id: String,
    /// Session that is loaded from storage and executed.
    pub child_session_id: String,
    /// Model name written onto the child session before it is executed.
    pub model: String,
}
35
36#[async_trait::async_trait]
41pub trait ExternalChildRunner: Send + Sync {
42 async fn should_handle(&self, session: &Session) -> bool;
44
45 async fn execute_external_child(
47 &self,
48 session: &mut Session,
49 job: &SpawnJob,
50 event_tx: tokio::sync::mpsc::Sender<AgentEvent>,
51 cancel_token: CancellationToken,
52 ) -> crate::runtime::runner::Result<()>;
53}
54
55#[derive(Clone)]
56pub struct SpawnContext {
57 pub agent: Arc<Agent>,
58 pub tools: Arc<dyn ToolExecutor>,
59 pub sessions_cache: Arc<RwLock<HashMap<String, Session>>>,
60 pub agent_runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
61 pub session_event_senders: Arc<RwLock<HashMap<String, broadcast::Sender<AgentEvent>>>>,
62 pub external_child_runner: Option<Arc<dyn ExternalChildRunner>>,
63 pub provider_router: Option<Arc<ProviderModelRouter>>,
64 pub completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
69}
70
71#[derive(Clone)]
72pub struct SpawnScheduler {
73 tx: mpsc::Sender<SpawnJob>,
74}
75
76impl SpawnScheduler {
77 pub fn new(ctx: SpawnContext) -> Self {
78 let (tx, mut rx) = mpsc::channel::<SpawnJob>(128);
79
80 tokio::spawn(async move {
81 while let Some(job) = rx.recv().await {
82 if let Err(err) = run_spawn_job(ctx.clone(), job).await {
83 tracing::warn!("spawn job failed: {}", err);
84 }
85 }
86 });
87
88 Self { tx }
89 }
90
91 pub async fn enqueue(&self, job: SpawnJob) -> Result<(), String> {
92 self.tx
93 .send(job)
94 .await
95 .map_err(|_| "spawn scheduler is not running".to_string())
96 }
97}
98
99fn child_model_ref(session: &Session, model: &str) -> Option<ProviderModelRef> {
100 if let Some(model_ref) = session.model_ref.clone() {
101 let provider = model_ref.provider.trim();
102 let model_name = model_ref.model.trim();
103 if !provider.is_empty() && !model_name.is_empty() {
104 return Some(ProviderModelRef::new(provider, model_name));
105 }
106 }
107
108 let provider = session
109 .metadata
110 .get("provider_name")
111 .map(String::as_str)
112 .map(str::trim)
113 .filter(|value| !value.is_empty())?;
114 let model_name = model.trim();
115 if model_name.is_empty() {
116 return None;
117 }
118 Some(ProviderModelRef::new(provider, model_name))
119}
120
/// Time limits, in seconds, that the liveness watchdog enforces on a running
/// child session.
#[derive(Debug, Clone, Copy)]
struct ChildWatchdogPolicy {
    /// Seconds between two watchdog checks.
    check_interval_secs: i64,
    /// Maximum seconds since the runner started before it is cancelled.
    max_total_secs: i64,
    /// Maximum seconds since the runner's last event (or its start, when no
    /// event was recorded) before it is cancelled.
    max_idle_secs: i64,
}

impl Default for ChildWatchdogPolicy {
    /// Checks every 15 seconds, allows one hour in total and 15 idle minutes.
    fn default() -> Self {
        const SECS_PER_MINUTE: i64 = 60;

        Self {
            check_interval_secs: 15,
            max_total_secs: 60 * SECS_PER_MINUTE,
            max_idle_secs: 15 * SECS_PER_MINUTE,
        }
    }
}
141
142fn metadata_i64(session: &Session, key: &str) -> Option<i64> {
143 session
144 .metadata
145 .get(key)
146 .and_then(|value| value.trim().parse::<i64>().ok())
147 .filter(|value| *value > 0)
148}
149
150fn watchdog_policy_for_session(session: &Session) -> ChildWatchdogPolicy {
151 let mut policy = ChildWatchdogPolicy::default();
152 if let Some(value) = metadata_i64(session, "child_watchdog.max_total_secs") {
153 policy.max_total_secs = value;
154 }
155 if let Some(value) = metadata_i64(session, "child_watchdog.max_idle_secs") {
156 policy.max_idle_secs = value;
157 }
158 if let Some(value) = metadata_i64(session, "child_watchdog.check_interval_secs") {
159 policy.check_interval_secs = value;
160 }
161 policy
162}
163
164async fn publish_child_completion(
165 parent_tx: &broadcast::Sender<AgentEvent>,
166 completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
167 completion: ChildCompletion,
168) {
169 let _ = parent_tx.send(AgentEvent::SubSessionCompleted {
170 parent_session_id: completion.parent_session_id.clone(),
171 child_session_id: completion.child_session_id.clone(),
172 status: completion.status.clone(),
173 error: completion.error.clone(),
174 });
175
176 if let Some(handler) = completion_handler {
177 handler.on_child_completed(completion).await;
178 }
179}
180
181async fn publish_child_completion_parts(
182 parent_tx: &broadcast::Sender<AgentEvent>,
183 completion_handler: Option<Arc<dyn ChildCompletionHandler>>,
184 parent_session_id: String,
185 child_session_id: String,
186 status: String,
187 error: Option<String>,
188) {
189 publish_child_completion(
190 parent_tx,
191 completion_handler,
192 ChildCompletion {
193 parent_session_id,
194 child_session_id,
195 status,
196 error,
197 completed_at: Utc::now(),
198 },
199 )
200 .await;
201}
202
203async fn watch_child_liveness(
204 parent_session_id: String,
205 child_session_id: String,
206 runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
207 cancel_token: CancellationToken,
208 timeout_reason: Arc<RwLock<Option<String>>>,
209 done: CancellationToken,
210 policy: ChildWatchdogPolicy,
211) {
212 let mut ticker =
213 tokio::time::interval(Duration::from_secs(policy.check_interval_secs.max(1) as u64));
214 ticker.tick().await;
216
217 loop {
218 tokio::select! {
219 _ = done.cancelled() => return,
220 _ = ticker.tick() => {
221 if cancel_token.is_cancelled() {
222 return;
223 }
224
225 let snapshot = {
226 let guard = runners.read().await;
227 guard.get(&child_session_id).cloned()
228 };
229 let Some(runner) = snapshot else {
230 return;
231 };
232 if !matches!(runner.status, super::runner_state::AgentStatus::Running) {
233 return;
234 }
235
236 let now = Utc::now();
237 let total_secs = now.signed_duration_since(runner.started_at).num_seconds();
238 if total_secs >= policy.max_total_secs {
239 let reason = format!(
240 "Child session timed out after {} seconds (max_total_secs={})",
241 total_secs, policy.max_total_secs
242 );
243 tracing::warn!(
244 parent_session_id = %parent_session_id,
245 child_session_id = %child_session_id,
246 reason = %reason,
247 "child session total timeout; cancelling child runner"
248 );
249 *timeout_reason.write().await = Some(reason);
250 cancel_token.cancel();
251 return;
252 }
253
254 let last_activity_at = runner.last_event_at.unwrap_or(runner.started_at);
255 let idle_secs = now.signed_duration_since(last_activity_at).num_seconds();
256 if idle_secs >= policy.max_idle_secs {
257 let reason = format!(
258 "Child session idle timeout after {} seconds without events (max_idle_secs={})",
259 idle_secs, policy.max_idle_secs
260 );
261 tracing::warn!(
262 parent_session_id = %parent_session_id,
263 child_session_id = %child_session_id,
264 reason = %reason,
265 last_tool_name = ?runner.last_tool_name,
266 last_tool_phase = ?runner.last_tool_phase,
267 round_count = runner.round_count,
268 "child session idle timeout; cancelling child runner"
269 );
270 *timeout_reason.write().await = Some(reason);
271 cancel_token.cancel();
272 return;
273 }
274 }
275 }
276 }
277}
278
279fn resolve_child_provider_override(
280 router: Option<&Arc<ProviderModelRouter>>,
281 session: &Session,
282 model: &str,
283) -> (Option<Arc<dyn LLMProvider>>, Option<String>) {
284 let model_ref = child_model_ref(session, model);
285 let provider_name = model_ref
286 .as_ref()
287 .map(|model_ref| model_ref.provider.clone());
288 let provider = router.and_then(|router| {
289 let model_ref = model_ref.as_ref()?;
290 match router.route(model_ref) {
291 Ok(provider) => Some(provider),
292 Err(error) => {
293 tracing::warn!(
294 session_id = %session.id,
295 provider = %model_ref.provider,
296 model = %model_ref.model,
297 error = %error,
298 "failed to resolve provider override for child session; falling back to runtime provider"
299 );
300 None
301 }
302 }
303 });
304 (provider, provider_name)
305}
306
307async fn run_spawn_job(ctx: SpawnContext, job: SpawnJob) -> Result<(), String> {
308 let parent_tx =
310 get_or_create_event_sender(&ctx.session_event_senders, &job.parent_session_id).await;
311 let child_tx =
312 get_or_create_event_sender(&ctx.session_event_senders, &job.child_session_id).await;
313
314 let mut session = match ctx
316 .agent
317 .storage()
318 .load_session(&job.child_session_id)
319 .await
320 {
321 Ok(Some(s)) => s,
322 Ok(None) => {
323 let error = "child session not found".to_string();
324 publish_child_completion_parts(
325 &parent_tx,
326 ctx.completion_handler.clone(),
327 job.parent_session_id.clone(),
328 job.child_session_id.clone(),
329 "error".to_string(),
330 Some(error.clone()),
331 )
332 .await;
333 return Err(error);
334 }
335 Err(e) => {
336 let error = format!("failed to load child session: {e}");
337 publish_child_completion_parts(
338 &parent_tx,
339 ctx.completion_handler.clone(),
340 job.parent_session_id.clone(),
341 job.child_session_id.clone(),
342 "error".to_string(),
343 Some(error.clone()),
344 )
345 .await;
346 return Err(error);
347 }
348 };
349
350 if session.kind != SessionKind::Child {
351 let error = "spawn job child session is not kind=child".to_string();
352 publish_child_completion_parts(
353 &parent_tx,
354 ctx.completion_handler.clone(),
355 job.parent_session_id.clone(),
356 job.child_session_id.clone(),
357 "error".to_string(),
358 Some(error.clone()),
359 )
360 .await;
361 return Err(error);
362 }
363
364 let last_is_user = session
366 .messages
367 .last()
368 .map(|m| matches!(m.role, Role::User))
369 .unwrap_or(false);
370 if !last_is_user {
371 session
372 .metadata
373 .insert("last_run_status".to_string(), "skipped".to_string());
374 session.metadata.insert(
375 "last_run_error".to_string(),
376 "No pending message to execute".to_string(),
377 );
378 let _ = ctx
379 .agent
380 .persistence()
381 .save_runtime_session(&mut session)
382 .await;
383 {
384 let mut sessions = ctx.sessions_cache.write().await;
385 sessions.insert(job.child_session_id.clone(), session);
386 }
387 publish_child_completion_parts(
388 &parent_tx,
389 ctx.completion_handler.clone(),
390 job.parent_session_id.clone(),
391 job.child_session_id.clone(),
392 "skipped".to_string(),
393 Some("No pending message to execute".to_string()),
394 )
395 .await;
396 return Ok(());
397 }
398
399 session
401 .metadata
402 .insert("last_run_status".to_string(), "running".to_string());
403 session.metadata.remove("last_run_error");
404 let _ = ctx
405 .agent
406 .persistence()
407 .save_runtime_session(&mut session)
408 .await;
409
410 let Some(RunnerReservation { cancel_token, .. }) =
412 try_reserve_runner(&ctx.agent_runners, &job.child_session_id, &child_tx).await
413 else {
414 return Ok(());
415 };
416
417 let forwarder_done = CancellationToken::new();
419 {
420 let mut rx = child_tx.subscribe();
421 let parent_tx = parent_tx.clone();
422 let job_clone = job.clone();
423 let done = forwarder_done.clone();
424 tokio::spawn(async move {
425 loop {
426 tokio::select! {
427 _ = done.cancelled() => break,
428 evt = rx.recv() => {
429 match evt {
430 Ok(event) => {
431 let _ = parent_tx.send(AgentEvent::SubSessionEvent {
432 parent_session_id: job_clone.parent_session_id.clone(),
433 child_session_id: job_clone.child_session_id.clone(),
434 event: Box::new(event),
435 });
436 }
437 Err(broadcast::error::RecvError::Lagged(_)) => {
438 continue;
439 }
440 Err(_) => break,
441 }
442 }
443 }
444 }
445 });
446 }
447 {
448 let parent_tx = parent_tx.clone();
449 let job_clone = job.clone();
450 let done = forwarder_done.clone();
451 tokio::spawn(async move {
452 let mut ticker = tokio::time::interval(Duration::from_secs(5));
453 loop {
454 tokio::select! {
455 _ = done.cancelled() => break,
456 _ = ticker.tick() => {
457 let _ = parent_tx.send(AgentEvent::SubSessionHeartbeat {
458 parent_session_id: job_clone.parent_session_id.clone(),
459 child_session_id: job_clone.child_session_id.clone(),
460 timestamp: Utc::now(),
461 });
462 }
463 }
464 }
465 });
466 }
467
468 let (mpsc_tx, _forwarder_handle) = create_event_forwarder(
470 job.child_session_id.clone(),
471 child_tx.clone(),
472 ctx.agent_runners.clone(),
473 );
474
475 let timeout_reason = Arc::new(RwLock::new(None::<String>));
478 let watchdog_policy = watchdog_policy_for_session(&session);
479 tokio::spawn(watch_child_liveness(
480 job.parent_session_id.clone(),
481 job.child_session_id.clone(),
482 ctx.agent_runners.clone(),
483 cancel_token.clone(),
484 timeout_reason.clone(),
485 forwarder_done.clone(),
486 watchdog_policy,
487 ));
488
489 let model = job.model.clone();
491 let session_id_clone = job.child_session_id.clone();
492 let agent_runners_for_status = ctx.agent_runners.clone();
493 let sessions_cache = ctx.sessions_cache.clone();
494 let agent = ctx.agent.clone();
495 let tools = ctx.tools.clone();
496 let external_runner = ctx.external_child_runner.clone();
497 let done = forwarder_done.clone();
498 let parent_tx_for_done = parent_tx.clone();
499 let parent_id_for_done = job.parent_session_id.clone();
500 let child_id_for_done = job.child_session_id.clone();
501 let session_event_senders = ctx.session_event_senders.clone();
502 let provider_router = ctx.provider_router.clone();
503 let completion_handler = ctx.completion_handler.clone();
504
505 tokio::spawn(async move {
506 session.model = model.clone();
507
508 let wants_external = session
509 .metadata
510 .get("runtime.kind")
511 .is_some_and(|v| v == "external");
512
513 let result: crate::runtime::runner::Result<()> = if wants_external {
514 if let Some(runner) = external_runner {
515 if runner.should_handle(&session).await {
516 runner
517 .execute_external_child(&mut session, &job, mpsc_tx, cancel_token.clone())
518 .await
519 } else {
520 Err(bamboo_agent_core::AgentError::LLM(format!(
521 "No external runner matched child session runtime metadata: agent_id={:?}, protocol={:?}",
522 session.metadata.get("external.agent_id"),
523 session.metadata.get("external.protocol"),
524 )))
525 }
526 } else {
527 Err(bamboo_agent_core::AgentError::LLM(
528 "Child session requires external runtime, but no external runner is configured"
529 .to_string(),
530 ))
531 }
532 } else {
533 let (provider_override, provider_name) =
534 resolve_child_provider_override(provider_router.as_ref(), &session, &model);
535 agent
536 .execute(
537 &mut session,
538 ExecuteRequest {
539 initial_message: String::new(), event_tx: mpsc_tx,
541 cancel_token: cancel_token.clone(),
542 tools: Some(tools),
543 provider_override,
544 model: Some(model.clone()),
545 provider_name,
546 background_model: None,
547 background_model_provider: None,
548 reasoning_effort: None,
549 disabled_tools: None,
550 disabled_skill_ids: None,
551 selected_skill_ids: None,
552 selected_skill_mode: None,
553 image_fallback: None,
554 },
555 )
556 .await
557 };
558
559 let timeout_error = timeout_reason.read().await.clone();
560 let (status, error) = if let Some(reason) = timeout_error {
561 ("timeout".to_string(), Some(reason))
562 } else {
563 match &result {
564 Ok(_) => ("completed".to_string(), None),
565 Err(e) if e.to_string().contains("cancelled") => {
566 ("cancelled".to_string(), Some(e.to_string()))
567 }
568 Err(e) => ("error".to_string(), Some(e.to_string())),
569 }
570 };
571
572 finalize_runner(&agent_runners_for_status, &session_id_clone, &result).await;
573
574 if let Ok(Some(latest)) = agent.storage().load_session(&session_id_clone).await {
577 if let Some(raw) = latest.metadata.get("pending_injected_messages") {
578 if let Ok(messages) = serde_json::from_str::<Vec<serde_json::Value>>(raw) {
579 for msg in messages {
580 if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
581 session
582 .add_message(bamboo_agent_core::Message::user(content.to_string()));
583 }
584 }
585 session.metadata.remove("pending_injected_messages");
586 }
587 }
588 }
589
590 session
592 .metadata
593 .insert("last_run_status".to_string(), status.clone());
594 if let Some(err) = &error {
595 session
596 .metadata
597 .insert("last_run_error".to_string(), err.clone());
598 } else {
599 session.metadata.remove("last_run_error");
600 }
601 let _ = agent.persistence().save_runtime_session(&mut session).await;
602 {
603 let mut sessions = sessions_cache.write().await;
604 sessions.insert(session_id_clone.clone(), session);
605 }
606
607 done.cancel();
610 publish_child_completion_parts(
611 &parent_tx_for_done,
612 completion_handler,
613 parent_id_for_done,
614 child_id_for_done,
615 status,
616 error,
617 )
618 .await;
619
620 drop(session_event_senders);
622 });
623
624 Ok(())
625}