1use std::collections::HashMap;
2
3use bamboo_agent_core::{AgentEvent, TokenUsage};
4use bamboo_infrastructure::a2a::types::{
5 A2ARole, PartContentWire, StreamResponse, TaskState, TaskStatus,
6};
7
8pub struct A2AMappedEvents {
10 pub events: Vec<AgentEvent>,
11 pub metadata_updates: HashMap<String, String>,
12}
13
/// Stateful translator from A2A streaming responses to [`AgentEvent`]s.
///
/// Remembers the latest task/context ids, whether a terminal task state has been seen,
/// and the concatenation of every piece of text emitted as a token.
#[derive(Debug, Default)]
pub struct A2AEventMapper {
    /// Set once a terminal task state (completed, failed, canceled, rejected) was mapped.
    terminal_sent: bool,
    /// Id of the most recently observed task.
    latest_task_id: Option<String>,
    /// Most recently observed context id.
    context_id: Option<String>,
    /// All text emitted as `Token` events so far, in arrival order.
    final_text: String,
}
22
23impl A2AEventMapper {
24 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn latest_task_id(&self) -> Option<&str> {
29 self.latest_task_id.as_deref()
30 }
31
32 pub fn context_id(&self) -> Option<&str> {
33 self.context_id.as_deref()
34 }
35
36 pub fn is_terminal(&self) -> bool {
37 self.terminal_sent
38 }
39
40 pub fn final_text(&self) -> &str {
41 &self.final_text
42 }
43
44 pub fn map_stream_response(&mut self, response: StreamResponse) -> A2AMappedEvents {
46 let mut events = Vec::new();
47 let mut metadata = HashMap::new();
48
49 if let Some(task) = response.task {
50 self.latest_task_id = Some(task.id.clone());
51 if let Some(ctx) = task.context_id.clone() {
52 self.context_id = Some(ctx);
53 }
54 metadata.insert("a2a.latest_task_id".to_string(), task.id.clone());
55 if let Some(ctx) = &task.context_id {
56 metadata.insert("a2a.context_id".to_string(), ctx.clone());
57 }
58 metadata.insert(
59 "a2a.last_state".to_string(),
60 task.status.state.as_proto_str().to_string(),
61 );
62 events.extend(self.map_status(&task.id, task.context_id.as_deref(), task.status));
63 }
64
65 if let Some(message) = response.message {
66 if message.role == A2ARole::Agent {
67 let text = text_from_parts(&message.parts);
68 if !text.is_empty() {
69 self.final_text.push_str(&text);
70 events.push(AgentEvent::Token { content: text });
71 }
72 }
73 }
74
75 if let Some(update) = response.status_update {
76 self.latest_task_id = Some(update.task_id.clone());
77 self.context_id = Some(update.context_id.clone());
78 metadata.insert("a2a.latest_task_id".to_string(), update.task_id.clone());
79 metadata.insert("a2a.context_id".to_string(), update.context_id.clone());
80 metadata.insert(
81 "a2a.last_state".to_string(),
82 update.status.state.as_proto_str().to_string(),
83 );
84 events.extend(self.map_status(
85 &update.task_id,
86 Some(&update.context_id),
87 update.status,
88 ));
89 }
90
91 if let Some(update) = response.artifact_update {
92 let preview =
93 handle_artifact_update(&update.artifact, update.append, update.last_chunk);
94 if !preview.is_empty() {
95 events.push(AgentEvent::Token {
96 content: preview.clone(),
97 });
98 self.final_text.push_str(&preview);
99 }
100 metadata.insert(
101 "a2a.last_artifacts_summary".to_string(),
102 serde_json::json!({
103 "artifact_id": update.artifact.artifact_id,
104 "name": update.artifact.name,
105 "append": update.append,
106 "last_chunk": update.last_chunk,
107 })
108 .to_string(),
109 );
110 }
111
112 A2AMappedEvents {
113 events,
114 metadata_updates: metadata,
115 }
116 }
117
118 fn map_status(
119 &mut self,
120 _task_id: &str,
121 _context_id: Option<&str>,
122 status: TaskStatus,
123 ) -> Vec<AgentEvent> {
124 let mut events = Vec::new();
125
126 match &status.state {
127 TaskState::Submitted => {
128 }
130 TaskState::Working => {
131 if let Some(msg) = &status.message {
132 let text = text_from_parts(&msg.parts);
133 if !text.is_empty() {
134 self.final_text.push_str(&text);
135 events.push(AgentEvent::Token { content: text });
136 }
137 }
138 }
139 TaskState::InputRequired => {
140 let question = question_from_status(&status);
141 events.push(AgentEvent::NeedClarification {
142 question,
143 options: None,
144 tool_call_id: None,
145 allow_custom: true,
146 });
147 }
148 TaskState::AuthRequired => {
149 let question = question_from_status(&status);
150 events.push(AgentEvent::NeedClarification {
151 question,
152 options: None,
153 tool_call_id: None,
154 allow_custom: true,
155 });
156 }
157 TaskState::Completed => {
158 self.terminal_sent = true;
159 if let Some(msg) = &status.message {
160 let text = text_from_parts(&msg.parts);
161 if !text.is_empty() {
162 self.final_text.push_str(&text);
163 events.push(AgentEvent::Token { content: text });
164 }
165 }
166 events.push(AgentEvent::Complete {
167 usage: TokenUsage::default(),
168 });
169 }
170 TaskState::Failed => {
171 self.terminal_sent = true;
172 let error_msg = status
173 .message
174 .as_ref()
175 .map(|m| text_from_parts(&m.parts))
176 .filter(|s| !s.is_empty())
177 .unwrap_or_else(|| "External agent reported failure".to_string());
178 events.push(AgentEvent::Error { message: error_msg });
179 }
180 TaskState::Canceled => {
181 self.terminal_sent = true;
182 events.push(AgentEvent::Error {
183 message: "External agent task was cancelled".to_string(),
184 });
185 }
186 TaskState::Rejected => {
187 self.terminal_sent = true;
188 events.push(AgentEvent::Error {
189 message: "External agent rejected the task".to_string(),
190 });
191 }
192 TaskState::Unspecified => {}
193 }
194
195 events
196 }
197}
198
199pub fn text_from_parts(parts: &[bamboo_infrastructure::a2a::types::Part]) -> String {
201 parts
202 .iter()
203 .filter_map(|part| match &part.content {
204 PartContentWire::Text { text } => Some(text.as_str()),
205 PartContentWire::Data { data } => data.get("summary").and_then(|v| v.as_str()),
206 _ => None,
207 })
208 .collect::<Vec<_>>()
209 .join("\n")
210}
211
212fn question_from_status(status: &TaskStatus) -> String {
214 status
215 .message
216 .as_ref()
217 .map(|m| text_from_parts(&m.parts))
218 .filter(|s| !s.trim().is_empty())
219 .unwrap_or_else(|| match status.state {
220 TaskState::InputRequired => "External agent requires additional input.".to_string(),
221 TaskState::AuthRequired => {
222 "External agent requires authentication or authorization.".to_string()
223 }
224 _ => format!("External agent state: {:?}", status.state),
225 })
226}
227
228fn handle_artifact_update(
230 artifact: &bamboo_infrastructure::a2a::types::Artifact,
231 _append: bool,
232 _last_chunk: bool,
233) -> String {
234 let text = text_from_parts(&artifact.parts);
235 if text.is_empty() {
236 if let Some(name) = &artifact.name {
237 format!("[Artifact: {}]", name)
238 } else {
239 format!("[Artifact: {}]", artifact.artifact_id)
240 }
241 } else {
242 let header = artifact
243 .name
244 .as_ref()
245 .map(|n| format!("--- Artifact: {} ---\n", n))
246 .unwrap_or_default();
247 format!("{}{}", header, text)
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_infrastructure::a2a::types::{
        A2ARole, Message, Part, Task, TaskStatus, TaskStatusUpdateEvent,
    };

    /// Builds a text part with no metadata or filename.
    fn text_part(text: &'static str, media_type: Option<&str>) -> Part {
        Part {
            content: PartContentWire::text(text),
            metadata: None,
            filename: None,
            media_type: media_type.map(|m| m.to_string()),
        }
    }

    /// Builds an agent-authored message (id `m1`) carrying a single text part.
    fn agent_message(text: &'static str, media_type: Option<&str>) -> Message {
        Message {
            message_id: "m1".to_string(),
            context_id: None,
            task_id: None,
            role: A2ARole::Agent,
            parts: vec![text_part(text, media_type)],
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        }
    }

    /// Wraps a status for task `task-1` / context `ctx-1` in a status-update response.
    fn status_update_response(state: TaskState, message: Option<Message>) -> StreamResponse {
        StreamResponse {
            task: None,
            message: None,
            status_update: Some(TaskStatusUpdateEvent {
                task_id: "task-1".to_string(),
                context_id: "ctx-1".to_string(),
                status: TaskStatus {
                    state,
                    message,
                    timestamp: None,
                },
                metadata: None,
            }),
            artifact_update: None,
        }
    }

    #[test]
    fn a2a_message_text_maps_to_token() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(StreamResponse {
            task: None,
            message: Some(agent_message("hello world", Some("text/plain"))),
            status_update: None,
            artifact_update: None,
        });

        assert_eq!(mapped.events.len(), 1);
        match &mapped.events[0] {
            AgentEvent::Token { content } => assert_eq!(content, "hello world"),
            other => panic!("expected Token, got {:?}", other),
        }
    }

    #[test]
    fn a2a_completed_status_maps_to_complete_and_metadata() {
        let mut mapper = A2AEventMapper::new();
        let task = Task {
            id: "task-1".to_string(),
            context_id: Some("ctx-1".to_string()),
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let mapped = mapper.map_stream_response(StreamResponse {
            task: Some(task),
            message: None,
            status_update: None,
            artifact_update: None,
        });

        assert!(mapper.is_terminal());
        let meta = &mapped.metadata_updates;
        assert_eq!(
            meta.get("a2a.latest_task_id").map(String::as_str),
            Some("task-1")
        );
        assert_eq!(meta.get("a2a.context_id").map(String::as_str), Some("ctx-1"));
        assert_eq!(
            meta.get("a2a.last_state").map(String::as_str),
            Some("TASK_STATE_COMPLETED")
        );
        match &mapped.events[0] {
            AgentEvent::Complete { .. } => {}
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn a2a_failed_status_maps_to_error() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::Failed,
            Some(agent_message("Something went wrong", None)),
        ));

        assert!(mapper.is_terminal());
        match &mapped.events[0] {
            AgentEvent::Error { message } => assert_eq!(message, "Something went wrong"),
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn a2a_input_required_maps_to_need_clarification() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::InputRequired,
            Some(agent_message("What is your API key?", None)),
        ));

        assert!(!mapper.is_terminal());
        match &mapped.events[0] {
            AgentEvent::NeedClarification { question, .. } => {
                assert_eq!(question, "What is your API key?");
            }
            other => panic!("expected NeedClarification, got {:?}", other),
        }
    }
}
410}