Skip to main content

bamboo_server/external_agents/
mapping.rs

1use std::collections::HashMap;
2
3use bamboo_agent_core::{AgentEvent, TokenUsage};
4use bamboo_infrastructure::a2a::types::{
5    A2ARole, PartContentWire, StreamResponse, TaskState, TaskStatus,
6};
7
8/// Events and metadata updates produced by mapping a single A2A StreamResponse.
9pub struct A2AMappedEvents {
10    pub events: Vec<AgentEvent>,
11    pub metadata_updates: HashMap<String, String>,
12}
13
/// Stateful mapper that tracks the latest A2A task state across a stream.
///
/// One instance accumulates, over successive stream responses, the most
/// recent task/context identifiers, whether a terminal status has been
/// observed, and all text that was surfaced as tokens.
#[derive(Debug, Default)]
pub struct A2AEventMapper {
    /// Set once a terminal task status has been mapped.
    terminal_sent: bool,
    /// Identifier of the most recently observed task, if any.
    latest_task_id: Option<String>,
    /// Most recently observed context identifier, if any.
    context_id: Option<String>,
    /// Concatenation of every text fragment emitted as a token so far.
    final_text: String,
}
22
23impl A2AEventMapper {
24    pub fn new() -> Self {
25        Self::default()
26    }
27
28    pub fn latest_task_id(&self) -> Option<&str> {
29        self.latest_task_id.as_deref()
30    }
31
32    pub fn context_id(&self) -> Option<&str> {
33        self.context_id.as_deref()
34    }
35
36    pub fn is_terminal(&self) -> bool {
37        self.terminal_sent
38    }
39
40    pub fn final_text(&self) -> &str {
41        &self.final_text
42    }
43
44    /// Map a single A2A StreamResponse to Bamboo AgentEvents and metadata.
45    pub fn map_stream_response(&mut self, response: StreamResponse) -> A2AMappedEvents {
46        let mut events = Vec::new();
47        let mut metadata = HashMap::new();
48
49        if let Some(task) = response.task {
50            self.latest_task_id = Some(task.id.clone());
51            if let Some(ctx) = task.context_id.clone() {
52                self.context_id = Some(ctx);
53            }
54            metadata.insert("a2a.latest_task_id".to_string(), task.id.clone());
55            if let Some(ctx) = &task.context_id {
56                metadata.insert("a2a.context_id".to_string(), ctx.clone());
57            }
58            metadata.insert(
59                "a2a.last_state".to_string(),
60                task.status.state.as_proto_str().to_string(),
61            );
62            events.extend(self.map_status(&task.id, task.context_id.as_deref(), task.status));
63        }
64
65        if let Some(message) = response.message {
66            if message.role == A2ARole::Agent {
67                let text = text_from_parts(&message.parts);
68                if !text.is_empty() {
69                    self.final_text.push_str(&text);
70                    events.push(AgentEvent::Token { content: text });
71                }
72            }
73        }
74
75        if let Some(update) = response.status_update {
76            self.latest_task_id = Some(update.task_id.clone());
77            self.context_id = Some(update.context_id.clone());
78            metadata.insert("a2a.latest_task_id".to_string(), update.task_id.clone());
79            metadata.insert("a2a.context_id".to_string(), update.context_id.clone());
80            metadata.insert(
81                "a2a.last_state".to_string(),
82                update.status.state.as_proto_str().to_string(),
83            );
84            events.extend(self.map_status(
85                &update.task_id,
86                Some(&update.context_id),
87                update.status,
88            ));
89        }
90
91        if let Some(update) = response.artifact_update {
92            let preview =
93                handle_artifact_update(&update.artifact, update.append, update.last_chunk);
94            if !preview.is_empty() {
95                events.push(AgentEvent::Token {
96                    content: preview.clone(),
97                });
98                self.final_text.push_str(&preview);
99            }
100            metadata.insert(
101                "a2a.last_artifacts_summary".to_string(),
102                serde_json::json!({
103                    "artifact_id": update.artifact.artifact_id,
104                    "name": update.artifact.name,
105                    "append": update.append,
106                    "last_chunk": update.last_chunk,
107                })
108                .to_string(),
109            );
110        }
111
112        A2AMappedEvents {
113            events,
114            metadata_updates: metadata,
115        }
116    }
117
118    fn map_status(
119        &mut self,
120        _task_id: &str,
121        _context_id: Option<&str>,
122        status: TaskStatus,
123    ) -> Vec<AgentEvent> {
124        let mut events = Vec::new();
125
126        match &status.state {
127            TaskState::Submitted => {
128                // Optionally emit a brief status token
129            }
130            TaskState::Working => {
131                if let Some(msg) = &status.message {
132                    let text = text_from_parts(&msg.parts);
133                    if !text.is_empty() {
134                        self.final_text.push_str(&text);
135                        events.push(AgentEvent::Token { content: text });
136                    }
137                }
138            }
139            TaskState::InputRequired => {
140                let question = question_from_status(&status);
141                events.push(AgentEvent::NeedClarification {
142                    question,
143                    options: None,
144                    tool_call_id: None,
145                    allow_custom: true,
146                });
147            }
148            TaskState::AuthRequired => {
149                let question = question_from_status(&status);
150                events.push(AgentEvent::NeedClarification {
151                    question,
152                    options: None,
153                    tool_call_id: None,
154                    allow_custom: true,
155                });
156            }
157            TaskState::Completed => {
158                self.terminal_sent = true;
159                if let Some(msg) = &status.message {
160                    let text = text_from_parts(&msg.parts);
161                    if !text.is_empty() {
162                        self.final_text.push_str(&text);
163                        events.push(AgentEvent::Token { content: text });
164                    }
165                }
166                events.push(AgentEvent::Complete {
167                    usage: TokenUsage::default(),
168                });
169            }
170            TaskState::Failed => {
171                self.terminal_sent = true;
172                let error_msg = status
173                    .message
174                    .as_ref()
175                    .map(|m| text_from_parts(&m.parts))
176                    .filter(|s| !s.is_empty())
177                    .unwrap_or_else(|| "External agent reported failure".to_string());
178                events.push(AgentEvent::Error { message: error_msg });
179            }
180            TaskState::Canceled => {
181                self.terminal_sent = true;
182                events.push(AgentEvent::Error {
183                    message: "External agent task was cancelled".to_string(),
184                });
185            }
186            TaskState::Rejected => {
187                self.terminal_sent = true;
188                events.push(AgentEvent::Error {
189                    message: "External agent rejected the task".to_string(),
190                });
191            }
192            TaskState::Unspecified => {}
193        }
194
195        events
196    }
197}
198
199/// Extract plain text from a slice of Parts.
200pub fn text_from_parts(parts: &[bamboo_infrastructure::a2a::types::Part]) -> String {
201    parts
202        .iter()
203        .filter_map(|part| match &part.content {
204            PartContentWire::Text { text } => Some(text.as_str()),
205            PartContentWire::Data { data } => data.get("summary").and_then(|v| v.as_str()),
206            _ => None,
207        })
208        .collect::<Vec<_>>()
209        .join("\n")
210}
211
212/// Build a human-readable question from a TaskStatus that requires input/auth.
213fn question_from_status(status: &TaskStatus) -> String {
214    status
215        .message
216        .as_ref()
217        .map(|m| text_from_parts(&m.parts))
218        .filter(|s| !s.trim().is_empty())
219        .unwrap_or_else(|| match status.state {
220            TaskState::InputRequired => "External agent requires additional input.".to_string(),
221            TaskState::AuthRequired => {
222                "External agent requires authentication or authorization.".to_string()
223            }
224            _ => format!("External agent state: {:?}", status.state),
225        })
226}
227
228/// Build a preview string from an artifact update.
229fn handle_artifact_update(
230    artifact: &bamboo_infrastructure::a2a::types::Artifact,
231    _append: bool,
232    _last_chunk: bool,
233) -> String {
234    let text = text_from_parts(&artifact.parts);
235    if text.is_empty() {
236        if let Some(name) = &artifact.name {
237            format!("[Artifact: {}]", name)
238        } else {
239            format!("[Artifact: {}]", artifact.artifact_id)
240        }
241    } else {
242        let header = artifact
243            .name
244            .as_ref()
245            .map(|n| format!("--- Artifact: {} ---\n", n))
246            .unwrap_or_default();
247        format!("{}{}", header, text)
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_infrastructure::a2a::types::{
        A2ARole, Message, Part, Task, TaskStatus, TaskStatusUpdateEvent,
    };

    /// Agent-authored message `m1` carrying a single text part.
    fn agent_message(text: &str, media_type: Option<&str>) -> Message {
        Message {
            message_id: "m1".to_string(),
            context_id: None,
            task_id: None,
            role: A2ARole::Agent,
            parts: vec![Part {
                content: PartContentWire::text(text),
                metadata: None,
                filename: None,
                media_type: media_type.map(str::to_string),
            }],
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        }
    }

    /// Response containing only a status update for `task-1` / `ctx-1`
    /// with the given state and an agent message holding `text`.
    fn status_update_response(state: TaskState, text: &str) -> StreamResponse {
        StreamResponse {
            task: None,
            message: None,
            status_update: Some(TaskStatusUpdateEvent {
                task_id: "task-1".to_string(),
                context_id: "ctx-1".to_string(),
                status: TaskStatus {
                    state,
                    message: Some(agent_message(text, None)),
                    timestamp: None,
                },
                metadata: None,
            }),
            artifact_update: None,
        }
    }

    #[test]
    fn a2a_message_text_maps_to_token() {
        let response = StreamResponse {
            task: None,
            message: Some(agent_message("hello world", Some("text/plain"))),
            status_update: None,
            artifact_update: None,
        };

        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(response);

        assert_eq!(mapped.events.len(), 1);
        match &mapped.events[0] {
            AgentEvent::Token { content } => assert_eq!(content, "hello world"),
            other => panic!("expected Token, got {:?}", other),
        }
    }

    #[test]
    fn a2a_completed_status_maps_to_complete_and_metadata() {
        let response = StreamResponse {
            task: Some(Task {
                id: "task-1".to_string(),
                context_id: Some("ctx-1".to_string()),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: None,
                },
                artifacts: vec![],
                history: vec![],
                metadata: None,
            }),
            message: None,
            status_update: None,
            artifact_update: None,
        };

        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(response);

        assert!(mapper.is_terminal());
        let expected_metadata = [
            ("a2a.latest_task_id", "task-1"),
            ("a2a.context_id", "ctx-1"),
            ("a2a.last_state", "TASK_STATE_COMPLETED"),
        ];
        for (key, value) in expected_metadata {
            assert_eq!(mapped.metadata_updates.get(key), Some(&value.to_string()));
        }
        match &mapped.events[0] {
            AgentEvent::Complete { .. } => {}
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn a2a_failed_status_maps_to_error() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::Failed,
            "Something went wrong",
        ));

        assert!(mapper.is_terminal());
        match &mapped.events[0] {
            AgentEvent::Error { message } => assert_eq!(message, "Something went wrong"),
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn a2a_input_required_maps_to_need_clarification() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::InputRequired,
            "What is your API key?",
        ));

        assert!(!mapper.is_terminal());
        match &mapped.events[0] {
            AgentEvent::NeedClarification { question, .. } => {
                assert_eq!(question, "What is your API key?");
            }
            other => panic!("expected NeedClarification, got {:?}", other),
        }
    }
}