bamboo_engine/runtime/execution/
spawn.rs1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use chrono::Utc;
12use tokio::sync::{broadcast, mpsc, RwLock};
13use tokio_util::sync::CancellationToken;
14
15use bamboo_agent_core::tools::ToolExecutor;
16use bamboo_agent_core::{AgentEvent, Role, Session, SessionKind};
17use bamboo_domain::ProviderModelRef;
18use bamboo_infrastructure::{LLMProvider, ProviderModelRouter};
19
20use crate::runtime::Agent;
21use crate::runtime::ExecuteRequest;
22
23use super::event_forwarder::create_event_forwarder;
24use super::runner_lifecycle::{finalize_runner, try_reserve_runner};
25use super::runner_state::AgentRunner;
26use super::session_events::get_or_create_event_sender;
27
/// A request to run one child session on behalf of a parent session.
#[derive(Clone, Debug)]
pub struct SpawnJob {
    /// Session that receives the sub-session lifecycle events for this job.
    pub parent_session_id: String,
    /// Session that is loaded from storage and executed.
    pub child_session_id: String,
    /// Model name written onto the child session before it is executed.
    pub model: String,
}
34
35#[async_trait::async_trait]
40pub trait ExternalChildRunner: Send + Sync {
41 async fn should_handle(&self, session: &Session) -> bool;
43
44 async fn execute_external_child(
46 &self,
47 session: &mut Session,
48 job: &SpawnJob,
49 event_tx: tokio::sync::mpsc::Sender<AgentEvent>,
50 cancel_token: CancellationToken,
51 ) -> crate::runtime::runner::Result<()>;
52}
53
54#[derive(Clone)]
55pub struct SpawnContext {
56 pub agent: Arc<Agent>,
57 pub tools: Arc<dyn ToolExecutor>,
58 pub sessions_cache: Arc<RwLock<HashMap<String, Session>>>,
59 pub agent_runners: Arc<RwLock<HashMap<String, AgentRunner>>>,
60 pub session_event_senders: Arc<RwLock<HashMap<String, broadcast::Sender<AgentEvent>>>>,
61 pub external_child_runner: Option<Arc<dyn ExternalChildRunner>>,
62 pub provider_router: Option<Arc<ProviderModelRouter>>,
63}
64
65#[derive(Clone)]
66pub struct SpawnScheduler {
67 tx: mpsc::Sender<SpawnJob>,
68}
69
70impl SpawnScheduler {
71 pub fn new(ctx: SpawnContext) -> Self {
72 let (tx, mut rx) = mpsc::channel::<SpawnJob>(128);
73
74 tokio::spawn(async move {
75 while let Some(job) = rx.recv().await {
76 if let Err(err) = run_spawn_job(ctx.clone(), job).await {
77 tracing::warn!("spawn job failed: {}", err);
78 }
79 }
80 });
81
82 Self { tx }
83 }
84
85 pub async fn enqueue(&self, job: SpawnJob) -> Result<(), String> {
86 self.tx
87 .send(job)
88 .await
89 .map_err(|_| "spawn scheduler is not running".to_string())
90 }
91}
92
93fn child_model_ref(session: &Session, model: &str) -> Option<ProviderModelRef> {
94 if let Some(model_ref) = session.model_ref.clone() {
95 let provider = model_ref.provider.trim();
96 let model_name = model_ref.model.trim();
97 if !provider.is_empty() && !model_name.is_empty() {
98 return Some(ProviderModelRef::new(provider, model_name));
99 }
100 }
101
102 let provider = session
103 .metadata
104 .get("provider_name")
105 .map(String::as_str)
106 .map(str::trim)
107 .filter(|value| !value.is_empty())?;
108 let model_name = model.trim();
109 if model_name.is_empty() {
110 return None;
111 }
112 Some(ProviderModelRef::new(provider, model_name))
113}
114
115fn resolve_child_provider_override(
116 router: Option<&Arc<ProviderModelRouter>>,
117 session: &Session,
118 model: &str,
119) -> (Option<Arc<dyn LLMProvider>>, Option<String>) {
120 let model_ref = child_model_ref(session, model);
121 let provider_name = model_ref
122 .as_ref()
123 .map(|model_ref| model_ref.provider.clone());
124 let provider = router.and_then(|router| {
125 let model_ref = model_ref.as_ref()?;
126 match router.route(model_ref) {
127 Ok(provider) => Some(provider),
128 Err(error) => {
129 tracing::warn!(
130 session_id = %session.id,
131 provider = %model_ref.provider,
132 model = %model_ref.model,
133 error = %error,
134 "failed to resolve provider override for child session; falling back to runtime provider"
135 );
136 None
137 }
138 }
139 });
140 (provider, provider_name)
141}
142
143async fn run_spawn_job(ctx: SpawnContext, job: SpawnJob) -> Result<(), String> {
144 let parent_tx =
146 get_or_create_event_sender(&ctx.session_event_senders, &job.parent_session_id).await;
147 let child_tx =
148 get_or_create_event_sender(&ctx.session_event_senders, &job.child_session_id).await;
149
150 let emit_error_completion = |error: String| {
151 let _ = parent_tx.send(AgentEvent::SubSessionCompleted {
152 parent_session_id: job.parent_session_id.clone(),
153 child_session_id: job.child_session_id.clone(),
154 status: "error".to_string(),
155 error: Some(error.clone()),
156 });
157 error
158 };
159
160 let mut session = match ctx
162 .agent
163 .storage()
164 .load_session(&job.child_session_id)
165 .await
166 {
167 Ok(Some(s)) => s,
168 Ok(None) => return Err(emit_error_completion("child session not found".to_string())),
169 Err(e) => {
170 return Err(emit_error_completion(format!(
171 "failed to load child session: {e}"
172 )))
173 }
174 };
175
176 if session.kind != SessionKind::Child {
177 return Err(emit_error_completion(
178 "spawn job child session is not kind=child".to_string(),
179 ));
180 }
181
182 let last_is_user = session
184 .messages
185 .last()
186 .map(|m| matches!(m.role, Role::User))
187 .unwrap_or(false);
188 if !last_is_user {
189 session
190 .metadata
191 .insert("last_run_status".to_string(), "skipped".to_string());
192 session.metadata.insert(
193 "last_run_error".to_string(),
194 "No pending message to execute".to_string(),
195 );
196 let _ = ctx.agent.storage().save_session(&session).await;
197 let _ = parent_tx.send(AgentEvent::SubSessionCompleted {
198 parent_session_id: job.parent_session_id.clone(),
199 child_session_id: job.child_session_id.clone(),
200 status: "skipped".to_string(),
201 error: Some("No pending message to execute".to_string()),
202 });
203 return Ok(());
204 }
205
206 session
208 .metadata
209 .insert("last_run_status".to_string(), "running".to_string());
210 session.metadata.remove("last_run_error");
211 let _ = ctx.agent.storage().save_session(&session).await;
212
213 let Some(cancel_token) =
215 try_reserve_runner(&ctx.agent_runners, &job.child_session_id, &child_tx).await
216 else {
217 return Ok(());
218 };
219
220 let forwarder_done = CancellationToken::new();
222 {
223 let mut rx = child_tx.subscribe();
224 let parent_tx = parent_tx.clone();
225 let job_clone = job.clone();
226 let done = forwarder_done.clone();
227 tokio::spawn(async move {
228 loop {
229 tokio::select! {
230 _ = done.cancelled() => break,
231 evt = rx.recv() => {
232 match evt {
233 Ok(event) => {
234 let _ = parent_tx.send(AgentEvent::SubSessionEvent {
235 parent_session_id: job_clone.parent_session_id.clone(),
236 child_session_id: job_clone.child_session_id.clone(),
237 event: Box::new(event),
238 });
239 }
240 Err(broadcast::error::RecvError::Lagged(_)) => {
241 continue;
242 }
243 Err(_) => break,
244 }
245 }
246 }
247 }
248 });
249 }
250 {
251 let parent_tx = parent_tx.clone();
252 let job_clone = job.clone();
253 let done = forwarder_done.clone();
254 tokio::spawn(async move {
255 let mut ticker = tokio::time::interval(Duration::from_secs(5));
256 loop {
257 tokio::select! {
258 _ = done.cancelled() => break,
259 _ = ticker.tick() => {
260 let _ = parent_tx.send(AgentEvent::SubSessionHeartbeat {
261 parent_session_id: job_clone.parent_session_id.clone(),
262 child_session_id: job_clone.child_session_id.clone(),
263 timestamp: Utc::now(),
264 });
265 }
266 }
267 }
268 });
269 }
270
271 let (mpsc_tx, _forwarder_handle) = create_event_forwarder(
273 job.child_session_id.clone(),
274 child_tx.clone(),
275 ctx.agent_runners.clone(),
276 );
277
278 let model = job.model.clone();
280 let session_id_clone = job.child_session_id.clone();
281 let agent_runners_for_status = ctx.agent_runners.clone();
282 let sessions_cache = ctx.sessions_cache.clone();
283 let agent = ctx.agent.clone();
284 let tools = ctx.tools.clone();
285 let external_runner = ctx.external_child_runner.clone();
286 let done = forwarder_done.clone();
287 let parent_tx_for_done = parent_tx.clone();
288 let parent_id_for_done = job.parent_session_id.clone();
289 let child_id_for_done = job.child_session_id.clone();
290 let session_event_senders = ctx.session_event_senders.clone();
291 let provider_router = ctx.provider_router.clone();
292
293 tokio::spawn(async move {
294 session.model = model.clone();
295
296 let wants_external = session
297 .metadata
298 .get("runtime.kind")
299 .is_some_and(|v| v == "external");
300
301 let result: crate::runtime::runner::Result<()> = if wants_external {
302 if let Some(runner) = external_runner {
303 if runner.should_handle(&session).await {
304 runner
305 .execute_external_child(&mut session, &job, mpsc_tx, cancel_token)
306 .await
307 } else {
308 Err(bamboo_agent_core::AgentError::LLM(format!(
309 "No external runner matched child session runtime metadata: agent_id={:?}, protocol={:?}",
310 session.metadata.get("external.agent_id"),
311 session.metadata.get("external.protocol"),
312 )))
313 }
314 } else {
315 Err(bamboo_agent_core::AgentError::LLM(
316 "Child session requires external runtime, but no external runner is configured"
317 .to_string(),
318 ))
319 }
320 } else {
321 let (provider_override, provider_name) =
322 resolve_child_provider_override(provider_router.as_ref(), &session, &model);
323 agent
324 .execute(
325 &mut session,
326 ExecuteRequest {
327 initial_message: String::new(), event_tx: mpsc_tx,
329 cancel_token,
330 tools: Some(tools),
331 provider_override,
332 model: Some(model.clone()),
333 provider_name,
334 background_model: None,
335 background_model_provider: None,
336 reasoning_effort: None,
337 disabled_tools: None,
338 disabled_skill_ids: None,
339 selected_skill_ids: None,
340 selected_skill_mode: None,
341 image_fallback: None,
342 },
343 )
344 .await
345 };
346
347 let (status, error) = match &result {
348 Ok(_) => ("completed".to_string(), None),
349 Err(e) if e.to_string().contains("cancelled") => {
350 ("cancelled".to_string(), Some(e.to_string()))
351 }
352 Err(e) => ("error".to_string(), Some(e.to_string())),
353 };
354
355 finalize_runner(&agent_runners_for_status, &session_id_clone, &result).await;
356
357 if let Ok(Some(latest)) = agent.storage().load_session(&session_id_clone).await {
360 if let Some(raw) = latest.metadata.get("pending_injected_messages") {
361 if let Ok(messages) = serde_json::from_str::<Vec<serde_json::Value>>(raw) {
362 for msg in messages {
363 if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
364 session
365 .add_message(bamboo_agent_core::Message::user(content.to_string()));
366 }
367 }
368 session.metadata.remove("pending_injected_messages");
369 }
370 }
371 }
372
373 session
375 .metadata
376 .insert("last_run_status".to_string(), status.clone());
377 if let Some(err) = &error {
378 session
379 .metadata
380 .insert("last_run_error".to_string(), err.clone());
381 } else {
382 session.metadata.remove("last_run_error");
383 }
384 let _ = agent.storage().save_session(&session).await;
385 {
386 let mut sessions = sessions_cache.write().await;
387 sessions.insert(session_id_clone.clone(), session);
388 }
389
390 done.cancel();
392 let _ = parent_tx_for_done.send(AgentEvent::SubSessionCompleted {
393 parent_session_id: parent_id_for_done,
394 child_session_id: child_id_for_done,
395 status,
396 error,
397 });
398
399 drop(session_event_senders);
401 });
402
403 Ok(())
404}