Skip to main content

bamboo_server/external_agents/
mapping.rs

1use std::collections::HashMap;
2
3use bamboo_agent_core::{AgentEvent, TokenUsage};
4use bamboo_infrastructure::a2a::types::{
5    A2ARole, PartContentWire, StreamResponse, TaskState, TaskStatus,
6};
7
8/// Events and metadata updates produced by mapping a single A2A StreamResponse.
9pub struct A2AMappedEvents {
10    pub events: Vec<AgentEvent>,
11    pub metadata_updates: HashMap<String, String>,
12}
13
/// Stateful mapper that tracks the latest A2A task state across a stream.
#[derive(Debug, Default)]
pub struct A2AEventMapper {
    /// Set once a terminal task state (completed/failed/canceled/rejected) is mapped.
    terminal_sent: bool,
    /// ID of the most recently reported task.
    latest_task_id: Option<String>,
    /// Most recently reported context ID.
    context_id: Option<String>,
    /// Concatenated text of every `Token` event emitted so far.
    final_text: String,
}
22
23impl A2AEventMapper {
24    pub fn new() -> Self {
25        Self::default()
26    }
27
28    pub fn latest_task_id(&self) -> Option<&str> {
29        self.latest_task_id.as_deref()
30    }
31
32    pub fn context_id(&self) -> Option<&str> {
33        self.context_id.as_deref()
34    }
35
36    pub fn is_terminal(&self) -> bool {
37        self.terminal_sent
38    }
39
40    pub fn final_text(&self) -> &str {
41        &self.final_text
42    }
43
44    /// Map a single A2A StreamResponse to Bamboo AgentEvents and metadata.
45    pub fn map_stream_response(&mut self, response: StreamResponse) -> A2AMappedEvents {
46        let mut events = Vec::new();
47        let mut metadata = HashMap::new();
48
49        if let Some(task) = response.task {
50            self.latest_task_id = Some(task.id.clone());
51            if let Some(ctx) = task.context_id.clone() {
52                self.context_id = Some(ctx);
53            }
54            metadata.insert("a2a.latest_task_id".to_string(), task.id.clone());
55            if let Some(ctx) = &task.context_id {
56                metadata.insert("a2a.context_id".to_string(), ctx.clone());
57            }
58            metadata.insert(
59                "a2a.last_state".to_string(),
60                task.status.state.as_proto_str().to_string(),
61            );
62            events.extend(self.map_status(&task.id, task.context_id.as_deref(), task.status));
63        }
64
65        if let Some(message) = response.message {
66            if message.role == A2ARole::Agent {
67                let text = text_from_parts(&message.parts);
68                if !text.is_empty() {
69                    self.final_text.push_str(&text);
70                    events.push(AgentEvent::Token { content: text });
71                }
72            }
73        }
74
75        if let Some(update) = response.status_update {
76            self.latest_task_id = Some(update.task_id.clone());
77            self.context_id = Some(update.context_id.clone());
78            metadata.insert("a2a.latest_task_id".to_string(), update.task_id.clone());
79            metadata.insert("a2a.context_id".to_string(), update.context_id.clone());
80            metadata.insert(
81                "a2a.last_state".to_string(),
82                update.status.state.as_proto_str().to_string(),
83            );
84            events.extend(self.map_status(
85                &update.task_id,
86                Some(&update.context_id),
87                update.status,
88            ));
89        }
90
91        if let Some(update) = response.artifact_update {
92            let preview =
93                handle_artifact_update(&update.artifact, update.append, update.last_chunk);
94            if !preview.is_empty() {
95                events.push(AgentEvent::Token {
96                    content: preview.clone(),
97                });
98                self.final_text.push_str(&preview);
99            }
100            metadata.insert(
101                "a2a.last_artifacts_summary".to_string(),
102                serde_json::json!({
103                    "artifact_id": update.artifact.artifact_id,
104                    "name": update.artifact.name,
105                    "append": update.append,
106                    "last_chunk": update.last_chunk,
107                })
108                .to_string(),
109            );
110        }
111
112        A2AMappedEvents {
113            events,
114            metadata_updates: metadata,
115        }
116    }
117
118    fn map_status(
119        &mut self,
120        _task_id: &str,
121        _context_id: Option<&str>,
122        status: TaskStatus,
123    ) -> Vec<AgentEvent> {
124        let mut events = Vec::new();
125
126        match &status.state {
127            TaskState::Submitted => {
128                // Optionally emit a brief status token
129            }
130            TaskState::Working => {
131                if let Some(msg) = &status.message {
132                    let text = text_from_parts(&msg.parts);
133                    if !text.is_empty() {
134                        self.final_text.push_str(&text);
135                        events.push(AgentEvent::Token { content: text });
136                    }
137                }
138            }
139            TaskState::InputRequired => {
140                let question = question_from_status(&status);
141                events.push(AgentEvent::NeedClarification {
142                    question,
143                    options: None,
144                    tool_call_id: None,
145                    tool_name: None,
146                    allow_custom: true,
147                });
148            }
149            TaskState::AuthRequired => {
150                let question = question_from_status(&status);
151                events.push(AgentEvent::NeedClarification {
152                    question,
153                    options: None,
154                    tool_call_id: None,
155                    tool_name: None,
156                    allow_custom: true,
157                });
158            }
159            TaskState::Completed => {
160                self.terminal_sent = true;
161                if let Some(msg) = &status.message {
162                    let text = text_from_parts(&msg.parts);
163                    if !text.is_empty() {
164                        self.final_text.push_str(&text);
165                        events.push(AgentEvent::Token { content: text });
166                    }
167                }
168                events.push(AgentEvent::Complete {
169                    usage: TokenUsage::default(),
170                });
171            }
172            TaskState::Failed => {
173                self.terminal_sent = true;
174                let error_msg = status
175                    .message
176                    .as_ref()
177                    .map(|m| text_from_parts(&m.parts))
178                    .filter(|s| !s.is_empty())
179                    .unwrap_or_else(|| "External agent reported failure".to_string());
180                events.push(AgentEvent::Error { message: error_msg });
181            }
182            TaskState::Canceled => {
183                self.terminal_sent = true;
184                events.push(AgentEvent::Error {
185                    message: "External agent task was cancelled".to_string(),
186                });
187            }
188            TaskState::Rejected => {
189                self.terminal_sent = true;
190                events.push(AgentEvent::Error {
191                    message: "External agent rejected the task".to_string(),
192                });
193            }
194            TaskState::Unspecified => {}
195        }
196
197        events
198    }
199}
200
201/// Extract plain text from a slice of Parts.
202pub fn text_from_parts(parts: &[bamboo_infrastructure::a2a::types::Part]) -> String {
203    parts
204        .iter()
205        .filter_map(|part| match &part.content {
206            PartContentWire::Text { text } => Some(text.as_str()),
207            PartContentWire::Data { data } => data.get("summary").and_then(|v| v.as_str()),
208            _ => None,
209        })
210        .collect::<Vec<_>>()
211        .join("\n")
212}
213
214/// Build a human-readable question from a TaskStatus that requires input/auth.
215fn question_from_status(status: &TaskStatus) -> String {
216    status
217        .message
218        .as_ref()
219        .map(|m| text_from_parts(&m.parts))
220        .filter(|s| !s.trim().is_empty())
221        .unwrap_or_else(|| match status.state {
222            TaskState::InputRequired => "External agent requires additional input.".to_string(),
223            TaskState::AuthRequired => {
224                "External agent requires authentication or authorization.".to_string()
225            }
226            _ => format!("External agent state: {:?}", status.state),
227        })
228}
229
230/// Build a preview string from an artifact update.
231fn handle_artifact_update(
232    artifact: &bamboo_infrastructure::a2a::types::Artifact,
233    _append: bool,
234    _last_chunk: bool,
235) -> String {
236    let text = text_from_parts(&artifact.parts);
237    if text.is_empty() {
238        if let Some(name) = &artifact.name {
239            format!("[Artifact: {}]", name)
240        } else {
241            format!("[Artifact: {}]", artifact.artifact_id)
242        }
243    } else {
244        let header = artifact
245            .name
246            .as_ref()
247            .map(|n| format!("--- Artifact: {} ---\n", n))
248            .unwrap_or_default();
249        format!("{}{}", header, text)
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_infrastructure::a2a::types::{
        A2ARole, Message, Part, Task, TaskStatus, TaskStatusUpdateEvent,
    };

    /// Agent-authored message `m1` containing a single text part.
    fn agent_text_message(text: &'static str, media_type: Option<&'static str>) -> Message {
        Message {
            message_id: "m1".to_string(),
            context_id: None,
            task_id: None,
            role: A2ARole::Agent,
            parts: vec![Part {
                content: PartContentWire::text(text),
                metadata: None,
                filename: None,
                media_type: media_type.map(|mt| mt.to_string()),
            }],
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        }
    }

    /// StreamResponse carrying only a status update for `task-1` in `ctx-1`.
    fn status_update_response(state: TaskState, message: Option<Message>) -> StreamResponse {
        StreamResponse {
            task: None,
            message: None,
            status_update: Some(TaskStatusUpdateEvent {
                task_id: "task-1".to_string(),
                context_id: "ctx-1".to_string(),
                status: TaskStatus {
                    state,
                    message,
                    timestamp: None,
                },
                metadata: None,
            }),
            artifact_update: None,
        }
    }

    #[test]
    fn a2a_message_text_maps_to_token() {
        let response = StreamResponse {
            task: None,
            message: Some(agent_text_message("hello world", Some("text/plain"))),
            status_update: None,
            artifact_update: None,
        };
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(response);

        assert_eq!(mapped.events.len(), 1);
        if let AgentEvent::Token { content } = &mapped.events[0] {
            assert_eq!(content, "hello world");
        } else {
            panic!("expected Token, got {:?}", mapped.events[0]);
        }
    }

    #[test]
    fn a2a_completed_status_maps_to_complete_and_metadata() {
        let task = Task {
            id: "task-1".to_string(),
            context_id: Some("ctx-1".to_string()),
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let response = StreamResponse {
            task: Some(task),
            message: None,
            status_update: None,
            artifact_update: None,
        };
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(response);

        assert!(mapper.is_terminal());
        let expected_metadata = [
            ("a2a.latest_task_id", "task-1"),
            ("a2a.context_id", "ctx-1"),
            ("a2a.last_state", "TASK_STATE_COMPLETED"),
        ];
        for &(key, value) in expected_metadata.iter() {
            assert_eq!(mapped.metadata_updates.get(key), Some(&value.to_string()));
        }
        assert!(
            matches!(mapped.events[0], AgentEvent::Complete { .. }),
            "expected Complete, got {:?}",
            mapped.events[0]
        );
    }

    #[test]
    fn a2a_failed_status_maps_to_error() {
        let response = status_update_response(
            TaskState::Failed,
            Some(agent_text_message("Something went wrong", None)),
        );
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(response);

        assert!(mapper.is_terminal());
        if let AgentEvent::Error { message } = &mapped.events[0] {
            assert_eq!(message, "Something went wrong");
        } else {
            panic!("expected Error, got {:?}", mapped.events[0]);
        }
    }

    #[test]
    fn a2a_input_required_maps_to_need_clarification() {
        let response = status_update_response(
            TaskState::InputRequired,
            Some(agent_text_message("What is your API key?", None)),
        );
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(response);

        assert!(!mapper.is_terminal());
        if let AgentEvent::NeedClarification { question, .. } = &mapped.events[0] {
            assert_eq!(question, "What is your API key?");
        } else {
            panic!("expected NeedClarification, got {:?}", mapped.events[0]);
        }
    }
}