Skip to main content

bamboo_engine/external_agents/a2a_adapter.rs

1use std::time::Duration;
2
3use crate::runtime::execution::{ExternalChildRunner, SpawnJob};
4use async_trait::async_trait;
5use bamboo_a2a::types::{
6    A2ARole, CancelTaskRequest, GetTaskRequest, Message, Part, PartContentWire,
7    SendMessageConfiguration, SendMessageRequest,
8};
9use bamboo_a2a::{
10    validate_agent_card_for_jsonrpc_mvp, A2AAuth, A2AClient, A2AClientConfig, A2AJsonRpcClient,
11};
12use bamboo_agent_core::{AgentError, AgentEvent, Role, TokenUsage};
13use futures::StreamExt;
14use tokio::sync::mpsc;
15use tokio_util::sync::CancellationToken;
16
17use super::config::ExternalAgentProfile;
18use super::mapping::{A2AEventMapper, A2AMappedEvents};
19
20/// A2A external child runner that delegates child session execution to a remote
21/// A2A-compliant agent via JSON-RPC + SSE.
22pub struct A2AExternalChildRunner {
23    client: A2AJsonRpcClient,
24    profile: ExternalAgentProfile,
25}
26
27impl A2AExternalChildRunner {
28    pub fn new(client: A2AJsonRpcClient, profile: ExternalAgentProfile) -> Self {
29        Self { client, profile }
30    }
31
32    /// Build an [`A2AClientConfig`] from an [`ExternalAgentProfile`] and auth token.
33    pub fn build_client_config(
34        profile: &ExternalAgentProfile,
35        auth_token: Option<String>,
36    ) -> bamboo_a2a::A2AClientResult<A2AClientConfig> {
37        let auth = match auth_token {
38            Some(token) => A2AAuth::Bearer(token),
39            None => A2AAuth::None,
40        };
41
42        let agent_card_url = profile.agent_card_url.clone().ok_or_else(|| {
43            bamboo_a2a::A2AClientError::InvalidAgentCard(format!(
44                "Profile {} has no agent_card_url",
45                profile.agent_id
46            ))
47        })?;
48
49        Ok(A2AClientConfig {
50            profile_id: profile.agent_id.clone(),
51            agent_card_url,
52            rpc_url_override: profile.rpc_url_override.clone(),
53            auth,
54            tenant: profile.tenant.clone(),
55            request_timeout: Duration::from_secs(120),
56            protocol_version: "1.0".to_string(),
57            extensions: Vec::new(),
58        })
59    }
60
61    /// Build a `SendMessageRequest` from the current child session state.
62    fn build_send_message_request(
63        &self,
64        session: &bamboo_agent_core::Session,
65    ) -> SendMessageRequest {
66        let mut message = build_a2a_message(session);
67
68        // Carry over context_id if we have one from a previous task in this session.
69        message.context_id = session.metadata.get("a2a.context_id").cloned();
70
71        // Build reference_task_ids from previous attempts.
72        message.reference_task_ids = session
73            .metadata
74            .get("a2a.reference_task_ids")
75            .and_then(|v| serde_json::from_str::<Vec<String>>(v).ok())
76            .unwrap_or_default();
77
78        let configuration = Some(SendMessageConfiguration {
79            accepted_output_modes: Some(vec!["text/plain".to_string()]),
80            history_length: Some(0),
81            blocking: Some(false),
82            extra: Default::default(),
83        });
84
85        let mut metadata = serde_json::json!({
86            "bamboo_session_id": session.id,
87            "bamboo_attempt": session.metadata.get("a2a.attempt").unwrap_or(&"1".to_string()),
88        });
89
90        if let Some(skill) = &self.profile.skill {
91            metadata["skill"] = serde_json::json!(skill);
92        }
93
94        SendMessageRequest {
95            tenant: self.profile.tenant.clone(),
96            message,
97            configuration,
98            metadata: Some(metadata),
99        }
100    }
101}
102
103#[async_trait]
104impl ExternalChildRunner for A2AExternalChildRunner {
105    async fn should_handle(&self, session: &bamboo_agent_core::Session) -> bool {
106        let kind = session.metadata.get("runtime.kind");
107        let protocol = session.metadata.get("external.protocol");
108        let agent_id = session.metadata.get("external.agent_id");
109
110        kind == Some(&"external".to_string())
111            && protocol == Some(&"a2a_jsonrpc".to_string())
112            && agent_id == Some(&self.profile.agent_id)
113    }
114
115    async fn execute_external_child(
116        &self,
117        session: &mut bamboo_agent_core::Session,
118        _job: &SpawnJob,
119        event_tx: mpsc::Sender<AgentEvent>,
120        cancel_token: CancellationToken,
121    ) -> crate::runtime::runner::Result<()> {
122        // Increment attempt counter.
123        let attempt: u64 = session
124            .metadata
125            .get("a2a.attempt")
126            .and_then(|v| v.parse().ok())
127            .unwrap_or(0)
128            + 1;
129        session
130            .metadata
131            .insert("a2a.attempt".to_string(), attempt.to_string());
132
133        // Validate agent card before sending. If non-streaming fallback is allowed,
134        // do not require streaming capability; use the card to select the execution path.
135        let card_validation = match self.client.fetch_agent_card().await {
136            Ok(card) => match validate_agent_card_for_jsonrpc_mvp(
137                &card,
138                !self.profile.allow_non_streaming_fallback,
139                self.profile.skill.as_deref(),
140            ) {
141                Ok(validation) => validation,
142                Err(e) => {
143                    return Err(AgentError::LLM(format!(
144                        "A2A agent card validation failed for profile {}: {}",
145                        self.profile.agent_id, e
146                    )));
147                }
148            },
149            Err(e) => {
150                return Err(AgentError::LLM(format!(
151                    "A2A agent card fetch failed for profile {}: {}",
152                    self.profile.agent_id, e
153                )));
154            }
155        };
156
157        let request = self.build_send_message_request(session);
158
159        let fallback_to_non_streaming = self.profile.allow_non_streaming_fallback;
160        if !card_validation.streaming_supported {
161            if fallback_to_non_streaming {
162                tracing::info!(
163                    "A2A profile {} does not advertise streaming; using non-streaming fallback.",
164                    self.profile.agent_id
165                );
166                return handle_non_streaming(
167                    &self.client,
168                    request,
169                    event_tx,
170                    session,
171                    self.profile.tenant.clone(),
172                )
173                .await;
174            }
175
176            return Err(AgentError::LLM(format!(
177                "A2A profile {} does not support streaming and non-streaming fallback is disabled",
178                self.profile.agent_id
179            )));
180        }
181
182        // Try streaming first.
183        let stream_result = self.client.send_streaming_message(request.clone()).await;
184
185        let stream = match stream_result {
186            Ok(stream) => stream,
187            Err(e) => {
188                if fallback_to_non_streaming {
189                    tracing::warn!(
190                        "A2A streaming failed for profile {}: {}. Falling back to non-streaming.",
191                        self.profile.agent_id,
192                        e
193                    );
194                    return handle_non_streaming(
195                        &self.client,
196                        request,
197                        event_tx,
198                        session,
199                        self.profile.tenant.clone(),
200                    )
201                    .await;
202                }
203                return Err(AgentError::LLM(format!(
204                    "A2A streaming failed and fallback disabled: {}",
205                    e
206                )));
207            }
208        };
209
210        handle_streaming(
211            &self.client,
212            stream,
213            event_tx,
214            cancel_token,
215            session,
216            self.profile.tenant.clone(),
217        )
218        .await
219    }
220}
221
222/// Consume a non-streaming A2A response and map to AgentEvents.
223async fn handle_non_streaming(
224    client: &A2AJsonRpcClient,
225    request: SendMessageRequest,
226    event_tx: mpsc::Sender<AgentEvent>,
227    session: &mut bamboo_agent_core::Session,
228    tenant: Option<String>,
229) -> crate::runtime::runner::Result<()> {
230    let response = client
231        .send_message(request)
232        .await
233        .map_err(|e| AgentError::LLM(format!("A2A send_message failed: {}", e)))?;
234
235    let mut mapper = A2AEventMapper::new();
236
237    // Build a synthetic StreamResponse from the non-streaming result.
238    let synthetic = bamboo_a2a::types::StreamResponse {
239        task: response.task.clone(),
240        message: response.message.clone(),
241        status_update: None,
242        artifact_update: None,
243    };
244
245    let mapped = mapper.map_stream_response(synthetic);
246    apply_mapped_events(&event_tx, session, mapped).await?;
247
248    // If task is present, emit its status as well.
249    if let Some(task) = &response.task {
250        let status_update = bamboo_a2a::types::StreamResponse {
251            task: None,
252            message: None,
253            status_update: Some(bamboo_a2a::types::TaskStatusUpdateEvent {
254                task_id: task.id.clone(),
255                context_id: task.context_id.clone().unwrap_or_default(),
256                status: task.status.clone(),
257                metadata: None,
258            }),
259            artifact_update: None,
260        };
261        let mapped = mapper.map_stream_response(status_update);
262        apply_mapped_events(&event_tx, session, mapped).await?;
263    }
264
265    if mapper.is_terminal() {
266        append_reference_task_id(session, &mapper);
267        add_final_assistant_message(session, &mapper);
268        return Ok(());
269    }
270
271    // Non-terminal: if there's a message but no task, it's a simple text response.
272    if response.message.is_some() && response.task.is_none() {
273        let _ = event_tx
274            .send(AgentEvent::Complete {
275                usage: TokenUsage::default(),
276            })
277            .await;
278        append_reference_task_id(session, &mapper);
279        add_final_assistant_message(session, &mapper);
280        return Ok(());
281    }
282
283    // Non-terminal with a task: attempt recovery via GetTask.
284    if let Some(task) = &response.task {
285        match recover_task_state(client, &task.id, tenant.clone()).await {
286            Ok(recovered) => {
287                let mapped = mapper.map_stream_response(recovered);
288                apply_mapped_events(&event_tx, session, mapped).await?;
289                if mapper.is_terminal() {
290                    append_reference_task_id(session, &mapper);
291                    add_final_assistant_message(session, &mapper);
292                    return Ok(());
293                }
294            }
295            Err(e) => {
296                tracing::warn!("A2A GetTask recovery after non-streaming failed: {}", e);
297            }
298        }
299    }
300
301    let msg = "A2A non-streaming response did not reach terminal state".to_string();
302    let _ = event_tx
303        .send(AgentEvent::Error {
304            message: msg.clone(),
305        })
306        .await;
307    append_reference_task_id(session, &mapper);
308    add_final_assistant_message(session, &mapper);
309    Err(AgentError::LLM(msg))
310}
311
312/// Consume an SSE stream, map to AgentEvents, handle cancellation and recovery.
313async fn handle_streaming(
314    client: &A2AJsonRpcClient,
315    mut stream: bamboo_a2a::A2AStream,
316    event_tx: mpsc::Sender<AgentEvent>,
317    cancel_token: CancellationToken,
318    session: &mut bamboo_agent_core::Session,
319    tenant: Option<String>,
320) -> crate::runtime::runner::Result<()> {
321    let mut mapper = A2AEventMapper::new();
322
323    loop {
324        tokio::select! {
325            _ = cancel_token.cancelled() => {
326                if let Some(task_id) = mapper.latest_task_id() {
327                    let _ = client.cancel_task(CancelTaskRequest {
328                        tenant: tenant.clone(),
329                        id: task_id.to_string(),
330                        metadata: Some(serde_json::json!({"cancelledBy": "bamboo"})),
331                    }).await;
332                }
333                return Err(AgentError::Cancelled);
334            }
335            item = stream.next() => {
336                match item {
337                    None => {
338                        // Stream closed.
339                        break;
340                    }
341                    Some(Err(e)) => {
342                        tracing::warn!("A2A stream error: {}", e);
343                        if !mapper.is_terminal() {
344                            // Try recovery via GetTask.
345                            if let Some(task_id) = mapper.latest_task_id() {
346                                match recover_task_state(client, task_id, tenant.clone()).await {
347                                    Ok(recovered) => {
348                                        let mapped = mapper.map_stream_response(recovered);
349                                        let _ = apply_mapped_events(&event_tx, session, mapped).await;
350                                    }
351                                    Err(recovery_err) => {
352                                        tracing::warn!("A2A GetTask recovery failed: {}", recovery_err);
353                                    }
354                                }
355                            }
356                        }
357                        return Err(AgentError::LLM(format!("A2A stream error: {}", e)));
358                    }
359                    Some(Ok(response)) => {
360                        let mapped = mapper.map_stream_response(response);
361                        apply_mapped_events(&event_tx, session, mapped).await?;
362
363                        if mapper.is_terminal() {
364                            append_reference_task_id(session, &mapper);
365                            add_final_assistant_message(session, &mapper);
366                            return Ok(());
367                        }
368                    }
369                }
370            }
371        }
372    }
373
374    // Stream ended without terminal state — attempt recovery.
375    if !mapper.is_terminal() {
376        if let Some(task_id) = mapper.latest_task_id() {
377            match recover_task_state(client, task_id, tenant.clone()).await {
378                Ok(recovered) => {
379                    let mapped = mapper.map_stream_response(recovered);
380                    apply_mapped_events(&event_tx, session, mapped).await?;
381                }
382                Err(e) => {
383                    tracing::warn!("A2A GetTask recovery after stream close failed: {}", e);
384                }
385            }
386        }
387
388        if !mapper.is_terminal() {
389            let msg = "A2A stream closed without terminal state".to_string();
390            let _ = event_tx
391                .send(AgentEvent::Error {
392                    message: msg.clone(),
393                })
394                .await;
395            append_reference_task_id(session, &mapper);
396            add_final_assistant_message(session, &mapper);
397            return Err(AgentError::LLM(msg));
398        }
399    }
400
401    append_reference_task_id(session, &mapper);
402    add_final_assistant_message(session, &mapper);
403
404    Ok(())
405}
406
407/// Call `GetTask` to recover the final state of a task after stream disruption.
408async fn recover_task_state(
409    client: &A2AJsonRpcClient,
410    task_id: &str,
411    tenant: Option<String>,
412) -> bamboo_a2a::A2AClientResult<bamboo_a2a::types::StreamResponse> {
413    let task = client
414        .get_task(GetTaskRequest {
415            tenant,
416            id: task_id.to_string(),
417            history_length: Some(0),
418        })
419        .await?;
420
421    Ok(bamboo_a2a::types::StreamResponse {
422        task: Some(task),
423        message: None,
424        status_update: None,
425        artifact_update: None,
426    })
427}
428
429/// Send mapped events through the event channel and apply metadata updates.
430async fn apply_mapped_events(
431    event_tx: &mpsc::Sender<AgentEvent>,
432    session: &mut bamboo_agent_core::Session,
433    mapped: A2AMappedEvents,
434) -> crate::runtime::runner::Result<()> {
435    for (k, v) in mapped.metadata_updates {
436        session.metadata.insert(k, v);
437    }
438
439    for event in mapped.events {
440        event_tx
441            .send(event)
442            .await
443            .map_err(|_| AgentError::Tool("event channel closed".to_string()))?;
444    }
445    Ok(())
446}
447
448/// Append latest task_id to reference_task_ids for future retries.
449fn append_reference_task_id(session: &mut bamboo_agent_core::Session, mapper: &A2AEventMapper) {
450    if let Some(task_id) = mapper.latest_task_id() {
451        let mut refs: Vec<String> = session
452            .metadata
453            .get("a2a.reference_task_ids")
454            .and_then(|v| serde_json::from_str(v).ok())
455            .unwrap_or_default();
456        if !refs.contains(&task_id.to_string()) {
457            refs.push(task_id.to_string());
458            session.metadata.insert(
459                "a2a.reference_task_ids".to_string(),
460                serde_json::to_string(&refs).unwrap_or_default(),
461            );
462        }
463    }
464}
465
466/// Append an assistant message with the final accumulated text.
467fn add_final_assistant_message(session: &mut bamboo_agent_core::Session, mapper: &A2AEventMapper) {
468    let text = mapper.final_text();
469    if text.is_empty() {
470        return;
471    }
472    session.messages.push(bamboo_agent_core::Message {
473        id: uuid::Uuid::new_v4().to_string(),
474        role: Role::Assistant,
475        content: text.to_string(),
476        reasoning: None,
477        content_parts: None,
478        image_ocr: None,
479        phase: None,
480        tool_calls: None,
481        tool_call_id: None,
482        tool_success: None,
483        compressed: false,
484        compressed_by_event_id: None,
485        never_compress: false,
486        compression_level: 0,
487        created_at: chrono::Utc::now(),
488        metadata: None,
489    });
490}
491
492/// Convert Bamboo session messages into a single A2A `Message`.
493///
494/// For MVP we send the latest user message as the primary message.  If the
495/// session has no user message we synthesise one from the session title.
496fn build_a2a_message(session: &bamboo_agent_core::Session) -> Message {
497    let content = session
498        .messages
499        .iter()
500        .rev()
501        .find(|m| matches!(m.role, Role::User))
502        .map(|m| m.content.clone())
503        .unwrap_or_else(|| {
504            session
505                .metadata
506                .get("title")
507                .cloned()
508                .unwrap_or_else(|| "Execute task".to_string())
509        });
510
511    Message {
512        message_id: uuid::Uuid::new_v4().to_string(),
513        context_id: session.metadata.get("a2a.context_id").cloned(),
514        task_id: None,
515        role: A2ARole::User,
516        parts: vec![Part {
517            content: PartContentWire::Text { text: content },
518            metadata: None,
519            filename: None,
520            media_type: Some("text/plain".to_string()),
521        }],
522        metadata: None,
523        extensions: Vec::new(),
524        reference_task_ids: Vec::new(),
525    }
526}