Skip to main content

bamboo_engine/external_agents/
mapping.rs

1use std::collections::HashMap;
2
3use bamboo_a2a::types::{A2ARole, PartContentWire, StreamResponse, TaskState, TaskStatus};
4use bamboo_agent_core::{AgentEvent, TokenUsage};
5
6/// Events and metadata updates produced by mapping a single A2A StreamResponse.
7pub struct A2AMappedEvents {
8    pub events: Vec<AgentEvent>,
9    pub metadata_updates: HashMap<String, String>,
10}
11
/// Stateful mapper that tracks the latest A2A task state across a stream.
#[derive(Debug, Default)]
pub struct A2AEventMapper {
    /// Set once a terminal task state (completed, failed, canceled, rejected) has been mapped.
    terminal_sent: bool,
    /// Most recent task id seen in a `task` or `status_update` payload.
    latest_task_id: Option<String>,
    /// Most recent context id seen in a `task` or `status_update` payload.
    context_id: Option<String>,
    /// Concatenation of every piece of text emitted as a `Token` event so far.
    final_text: String,
}
20
21impl A2AEventMapper {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn latest_task_id(&self) -> Option<&str> {
27        self.latest_task_id.as_deref()
28    }
29
30    pub fn context_id(&self) -> Option<&str> {
31        self.context_id.as_deref()
32    }
33
34    pub fn is_terminal(&self) -> bool {
35        self.terminal_sent
36    }
37
38    pub fn final_text(&self) -> &str {
39        &self.final_text
40    }
41
42    /// Map a single A2A StreamResponse to Bamboo AgentEvents and metadata.
43    pub fn map_stream_response(&mut self, response: StreamResponse) -> A2AMappedEvents {
44        let mut events = Vec::new();
45        let mut metadata = HashMap::new();
46
47        if let Some(task) = response.task {
48            self.latest_task_id = Some(task.id.clone());
49            if let Some(ctx) = task.context_id.clone() {
50                self.context_id = Some(ctx);
51            }
52            metadata.insert("a2a.latest_task_id".to_string(), task.id.clone());
53            if let Some(ctx) = &task.context_id {
54                metadata.insert("a2a.context_id".to_string(), ctx.clone());
55            }
56            metadata.insert(
57                "a2a.last_state".to_string(),
58                task.status.state.as_proto_str().to_string(),
59            );
60            events.extend(self.map_status(&task.id, task.context_id.as_deref(), task.status));
61        }
62
63        if let Some(message) = response.message {
64            if message.role == A2ARole::Agent {
65                let text = text_from_parts(&message.parts);
66                if !text.is_empty() {
67                    self.final_text.push_str(&text);
68                    events.push(AgentEvent::Token { content: text });
69                }
70            }
71        }
72
73        if let Some(update) = response.status_update {
74            self.latest_task_id = Some(update.task_id.clone());
75            self.context_id = Some(update.context_id.clone());
76            metadata.insert("a2a.latest_task_id".to_string(), update.task_id.clone());
77            metadata.insert("a2a.context_id".to_string(), update.context_id.clone());
78            metadata.insert(
79                "a2a.last_state".to_string(),
80                update.status.state.as_proto_str().to_string(),
81            );
82            events.extend(self.map_status(
83                &update.task_id,
84                Some(&update.context_id),
85                update.status,
86            ));
87        }
88
89        if let Some(update) = response.artifact_update {
90            let preview =
91                handle_artifact_update(&update.artifact, update.append, update.last_chunk);
92            if !preview.is_empty() {
93                events.push(AgentEvent::Token {
94                    content: preview.clone(),
95                });
96                self.final_text.push_str(&preview);
97            }
98            metadata.insert(
99                "a2a.last_artifacts_summary".to_string(),
100                serde_json::json!({
101                    "artifact_id": update.artifact.artifact_id,
102                    "name": update.artifact.name,
103                    "append": update.append,
104                    "last_chunk": update.last_chunk,
105                })
106                .to_string(),
107            );
108        }
109
110        A2AMappedEvents {
111            events,
112            metadata_updates: metadata,
113        }
114    }
115
116    fn map_status(
117        &mut self,
118        _task_id: &str,
119        _context_id: Option<&str>,
120        status: TaskStatus,
121    ) -> Vec<AgentEvent> {
122        let mut events = Vec::new();
123
124        match &status.state {
125            TaskState::Submitted => {
126                // Optionally emit a brief status token
127            }
128            TaskState::Working => {
129                if let Some(msg) = &status.message {
130                    let text = text_from_parts(&msg.parts);
131                    if !text.is_empty() {
132                        self.final_text.push_str(&text);
133                        events.push(AgentEvent::Token { content: text });
134                    }
135                }
136            }
137            TaskState::InputRequired => {
138                let question = question_from_status(&status);
139                events.push(AgentEvent::NeedClarification {
140                    question,
141                    options: None,
142                    tool_call_id: None,
143                    tool_name: None,
144                    allow_custom: true,
145                });
146            }
147            TaskState::AuthRequired => {
148                let question = question_from_status(&status);
149                events.push(AgentEvent::NeedClarification {
150                    question,
151                    options: None,
152                    tool_call_id: None,
153                    tool_name: None,
154                    allow_custom: true,
155                });
156            }
157            TaskState::Completed => {
158                self.terminal_sent = true;
159                if let Some(msg) = &status.message {
160                    let text = text_from_parts(&msg.parts);
161                    if !text.is_empty() {
162                        self.final_text.push_str(&text);
163                        events.push(AgentEvent::Token { content: text });
164                    }
165                }
166                events.push(AgentEvent::Complete {
167                    usage: TokenUsage::default(),
168                });
169            }
170            TaskState::Failed => {
171                self.terminal_sent = true;
172                let error_msg = status
173                    .message
174                    .as_ref()
175                    .map(|m| text_from_parts(&m.parts))
176                    .filter(|s| !s.is_empty())
177                    .unwrap_or_else(|| "External agent reported failure".to_string());
178                events.push(AgentEvent::Error { message: error_msg });
179            }
180            TaskState::Canceled => {
181                self.terminal_sent = true;
182                events.push(AgentEvent::Error {
183                    message: "External agent task was cancelled".to_string(),
184                });
185            }
186            TaskState::Rejected => {
187                self.terminal_sent = true;
188                events.push(AgentEvent::Error {
189                    message: "External agent rejected the task".to_string(),
190                });
191            }
192            TaskState::Unspecified => {}
193        }
194
195        events
196    }
197}
198
199/// Extract plain text from a slice of Parts.
200pub fn text_from_parts(parts: &[bamboo_a2a::types::Part]) -> String {
201    parts
202        .iter()
203        .filter_map(|part| match &part.content {
204            PartContentWire::Text { text } => Some(text.as_str()),
205            PartContentWire::Data { data } => data.get("summary").and_then(|v| v.as_str()),
206            _ => None,
207        })
208        .collect::<Vec<_>>()
209        .join("\n")
210}
211
212/// Build a human-readable question from a TaskStatus that requires input/auth.
213fn question_from_status(status: &TaskStatus) -> String {
214    status
215        .message
216        .as_ref()
217        .map(|m| text_from_parts(&m.parts))
218        .filter(|s| !s.trim().is_empty())
219        .unwrap_or_else(|| match status.state {
220            TaskState::InputRequired => "External agent requires additional input.".to_string(),
221            TaskState::AuthRequired => {
222                "External agent requires authentication or authorization.".to_string()
223            }
224            _ => format!("External agent state: {:?}", status.state),
225        })
226}
227
228/// Build a preview string from an artifact update.
229fn handle_artifact_update(
230    artifact: &bamboo_a2a::types::Artifact,
231    _append: bool,
232    _last_chunk: bool,
233) -> String {
234    let text = text_from_parts(&artifact.parts);
235    if text.is_empty() {
236        if let Some(name) = &artifact.name {
237            format!("[Artifact: {}]", name)
238        } else {
239            format!("[Artifact: {}]", artifact.artifact_id)
240        }
241    } else {
242        let header = artifact
243            .name
244            .as_ref()
245            .map(|n| format!("--- Artifact: {} ---\n", n))
246            .unwrap_or_default();
247        format!("{}{}", header, text)
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_a2a::types::{A2ARole, Message, Part, Task, TaskStatus, TaskStatusUpdateEvent};

    /// Builds a single-part agent message with id `m1` carrying `text`.
    fn agent_message(text: &'static str, media_type: Option<&str>) -> Message {
        Message {
            message_id: "m1".to_string(),
            context_id: None,
            task_id: None,
            role: A2ARole::Agent,
            parts: vec![Part {
                content: PartContentWire::text(text),
                metadata: None,
                filename: None,
                media_type: media_type.map(str::to_string),
            }],
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        }
    }

    /// Wraps a `task-1` / `ctx-1` status update carrying `text` in an otherwise empty response.
    fn status_update_response(state: TaskState, text: &'static str) -> StreamResponse {
        StreamResponse {
            task: None,
            message: None,
            status_update: Some(TaskStatusUpdateEvent {
                task_id: "task-1".to_string(),
                context_id: "ctx-1".to_string(),
                status: TaskStatus {
                    state,
                    message: Some(agent_message(text, None)),
                    timestamp: None,
                },
                metadata: None,
            }),
            artifact_update: None,
        }
    }

    #[test]
    fn a2a_message_text_maps_to_token() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(StreamResponse {
            task: None,
            message: Some(agent_message("hello world", Some("text/plain"))),
            status_update: None,
            artifact_update: None,
        });

        assert_eq!(mapped.events.len(), 1);
        match &mapped.events[0] {
            AgentEvent::Token { content } => assert_eq!(content, "hello world"),
            other => panic!("expected Token, got {:?}", other),
        }
    }

    #[test]
    fn a2a_completed_status_maps_to_complete_and_metadata() {
        let mut mapper = A2AEventMapper::new();
        let completed_task = Task {
            id: "task-1".to_string(),
            context_id: Some("ctx-1".to_string()),
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let mapped = mapper.map_stream_response(StreamResponse {
            task: Some(completed_task),
            message: None,
            status_update: None,
            artifact_update: None,
        });

        assert!(mapper.is_terminal());
        let meta = &mapped.metadata_updates;
        assert_eq!(
            meta.get("a2a.latest_task_id").map(String::as_str),
            Some("task-1")
        );
        assert_eq!(meta.get("a2a.context_id").map(String::as_str), Some("ctx-1"));
        assert_eq!(
            meta.get("a2a.last_state").map(String::as_str),
            Some("TASK_STATE_COMPLETED")
        );
        match &mapped.events[0] {
            AgentEvent::Complete { .. } => {}
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn a2a_failed_status_maps_to_error() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::Failed,
            "Something went wrong",
        ));

        assert!(mapper.is_terminal());
        match &mapped.events[0] {
            AgentEvent::Error { message } => assert_eq!(message, "Something went wrong"),
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn a2a_input_required_maps_to_need_clarification() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::InputRequired,
            "What is your API key?",
        ));

        assert!(!mapper.is_terminal());
        match &mapped.events[0] {
            AgentEvent::NeedClarification { question, .. } => {
                assert_eq!(question, "What is your API key?");
            }
            other => panic!("expected NeedClarification, got {:?}", other),
        }
    }
}
408}