1use std::collections::HashMap;
2
3use bamboo_agent_core::{AgentEvent, TokenUsage};
4use bamboo_infrastructure::a2a::types::{
5 A2ARole, PartContentWire, StreamResponse, TaskState, TaskStatus,
6};
7
8pub struct A2AMappedEvents {
10 pub events: Vec<AgentEvent>,
11 pub metadata_updates: HashMap<String, String>,
12}
13
/// Stateful translator from A2A stream responses to agent events.
///
/// The mapper remembers the most recent task/context identifiers, whether a
/// terminal task state has been observed, and all text emitted as tokens so far.
#[derive(Debug, Default)]
pub struct A2AEventMapper {
    /// Set once a terminal task state (completed, failed, canceled, rejected) was mapped.
    terminal_sent: bool,
    /// Id of the most recent task seen in a task snapshot or status update.
    latest_task_id: Option<String>,
    /// Most recent context id seen in a task snapshot or status update.
    context_id: Option<String>,
    /// Concatenation of every text fragment emitted as a token event.
    final_text: String,
}
22
23impl A2AEventMapper {
24 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn latest_task_id(&self) -> Option<&str> {
29 self.latest_task_id.as_deref()
30 }
31
32 pub fn context_id(&self) -> Option<&str> {
33 self.context_id.as_deref()
34 }
35
36 pub fn is_terminal(&self) -> bool {
37 self.terminal_sent
38 }
39
40 pub fn final_text(&self) -> &str {
41 &self.final_text
42 }
43
44 pub fn map_stream_response(&mut self, response: StreamResponse) -> A2AMappedEvents {
46 let mut events = Vec::new();
47 let mut metadata = HashMap::new();
48
49 if let Some(task) = response.task {
50 self.latest_task_id = Some(task.id.clone());
51 if let Some(ctx) = task.context_id.clone() {
52 self.context_id = Some(ctx);
53 }
54 metadata.insert("a2a.latest_task_id".to_string(), task.id.clone());
55 if let Some(ctx) = &task.context_id {
56 metadata.insert("a2a.context_id".to_string(), ctx.clone());
57 }
58 metadata.insert(
59 "a2a.last_state".to_string(),
60 task.status.state.as_proto_str().to_string(),
61 );
62 events.extend(self.map_status(&task.id, task.context_id.as_deref(), task.status));
63 }
64
65 if let Some(message) = response.message {
66 if message.role == A2ARole::Agent {
67 let text = text_from_parts(&message.parts);
68 if !text.is_empty() {
69 self.final_text.push_str(&text);
70 events.push(AgentEvent::Token { content: text });
71 }
72 }
73 }
74
75 if let Some(update) = response.status_update {
76 self.latest_task_id = Some(update.task_id.clone());
77 self.context_id = Some(update.context_id.clone());
78 metadata.insert("a2a.latest_task_id".to_string(), update.task_id.clone());
79 metadata.insert("a2a.context_id".to_string(), update.context_id.clone());
80 metadata.insert(
81 "a2a.last_state".to_string(),
82 update.status.state.as_proto_str().to_string(),
83 );
84 events.extend(self.map_status(
85 &update.task_id,
86 Some(&update.context_id),
87 update.status,
88 ));
89 }
90
91 if let Some(update) = response.artifact_update {
92 let preview =
93 handle_artifact_update(&update.artifact, update.append, update.last_chunk);
94 if !preview.is_empty() {
95 events.push(AgentEvent::Token {
96 content: preview.clone(),
97 });
98 self.final_text.push_str(&preview);
99 }
100 metadata.insert(
101 "a2a.last_artifacts_summary".to_string(),
102 serde_json::json!({
103 "artifact_id": update.artifact.artifact_id,
104 "name": update.artifact.name,
105 "append": update.append,
106 "last_chunk": update.last_chunk,
107 })
108 .to_string(),
109 );
110 }
111
112 A2AMappedEvents {
113 events,
114 metadata_updates: metadata,
115 }
116 }
117
118 fn map_status(
119 &mut self,
120 _task_id: &str,
121 _context_id: Option<&str>,
122 status: TaskStatus,
123 ) -> Vec<AgentEvent> {
124 let mut events = Vec::new();
125
126 match &status.state {
127 TaskState::Submitted => {
128 }
130 TaskState::Working => {
131 if let Some(msg) = &status.message {
132 let text = text_from_parts(&msg.parts);
133 if !text.is_empty() {
134 self.final_text.push_str(&text);
135 events.push(AgentEvent::Token { content: text });
136 }
137 }
138 }
139 TaskState::InputRequired => {
140 let question = question_from_status(&status);
141 events.push(AgentEvent::NeedClarification {
142 question,
143 options: None,
144 tool_call_id: None,
145 tool_name: None,
146 allow_custom: true,
147 });
148 }
149 TaskState::AuthRequired => {
150 let question = question_from_status(&status);
151 events.push(AgentEvent::NeedClarification {
152 question,
153 options: None,
154 tool_call_id: None,
155 tool_name: None,
156 allow_custom: true,
157 });
158 }
159 TaskState::Completed => {
160 self.terminal_sent = true;
161 if let Some(msg) = &status.message {
162 let text = text_from_parts(&msg.parts);
163 if !text.is_empty() {
164 self.final_text.push_str(&text);
165 events.push(AgentEvent::Token { content: text });
166 }
167 }
168 events.push(AgentEvent::Complete {
169 usage: TokenUsage::default(),
170 });
171 }
172 TaskState::Failed => {
173 self.terminal_sent = true;
174 let error_msg = status
175 .message
176 .as_ref()
177 .map(|m| text_from_parts(&m.parts))
178 .filter(|s| !s.is_empty())
179 .unwrap_or_else(|| "External agent reported failure".to_string());
180 events.push(AgentEvent::Error { message: error_msg });
181 }
182 TaskState::Canceled => {
183 self.terminal_sent = true;
184 events.push(AgentEvent::Error {
185 message: "External agent task was cancelled".to_string(),
186 });
187 }
188 TaskState::Rejected => {
189 self.terminal_sent = true;
190 events.push(AgentEvent::Error {
191 message: "External agent rejected the task".to_string(),
192 });
193 }
194 TaskState::Unspecified => {}
195 }
196
197 events
198 }
199}
200
201pub fn text_from_parts(parts: &[bamboo_infrastructure::a2a::types::Part]) -> String {
203 parts
204 .iter()
205 .filter_map(|part| match &part.content {
206 PartContentWire::Text { text } => Some(text.as_str()),
207 PartContentWire::Data { data } => data.get("summary").and_then(|v| v.as_str()),
208 _ => None,
209 })
210 .collect::<Vec<_>>()
211 .join("\n")
212}
213
214fn question_from_status(status: &TaskStatus) -> String {
216 status
217 .message
218 .as_ref()
219 .map(|m| text_from_parts(&m.parts))
220 .filter(|s| !s.trim().is_empty())
221 .unwrap_or_else(|| match status.state {
222 TaskState::InputRequired => "External agent requires additional input.".to_string(),
223 TaskState::AuthRequired => {
224 "External agent requires authentication or authorization.".to_string()
225 }
226 _ => format!("External agent state: {:?}", status.state),
227 })
228}
229
230fn handle_artifact_update(
232 artifact: &bamboo_infrastructure::a2a::types::Artifact,
233 _append: bool,
234 _last_chunk: bool,
235) -> String {
236 let text = text_from_parts(&artifact.parts);
237 if text.is_empty() {
238 if let Some(name) = &artifact.name {
239 format!("[Artifact: {}]", name)
240 } else {
241 format!("[Artifact: {}]", artifact.artifact_id)
242 }
243 } else {
244 let header = artifact
245 .name
246 .as_ref()
247 .map(|n| format!("--- Artifact: {} ---\n", n))
248 .unwrap_or_default();
249 format!("{}{}", header, text)
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_infrastructure::a2a::types::{
        A2ARole, Message, Part, Task, TaskStatus, TaskStatusUpdateEvent,
    };

    /// Builds an agent-authored message (`m1`) holding a single text part.
    fn agent_text_message(text: &str, media_type: Option<&str>) -> Message {
        Message {
            message_id: "m1".to_string(),
            context_id: None,
            task_id: None,
            role: A2ARole::Agent,
            parts: vec![Part {
                content: PartContentWire::text(text),
                metadata: None,
                filename: None,
                media_type: media_type.map(|m| m.to_string()),
            }],
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        }
    }

    /// Builds a response carrying only a status update for `task-1` / `ctx-1`
    /// in the given state, with `text` as the status message.
    fn status_update_response(state: TaskState, text: &str) -> StreamResponse {
        StreamResponse {
            task: None,
            message: None,
            status_update: Some(TaskStatusUpdateEvent {
                task_id: "task-1".to_string(),
                context_id: "ctx-1".to_string(),
                status: TaskStatus {
                    state,
                    message: Some(agent_text_message(text, None)),
                    timestamp: None,
                },
                metadata: None,
            }),
            artifact_update: None,
        }
    }

    #[test]
    fn a2a_message_text_maps_to_token() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(StreamResponse {
            task: None,
            message: Some(agent_text_message("hello world", Some("text/plain"))),
            status_update: None,
            artifact_update: None,
        });

        assert_eq!(mapped.events.len(), 1);
        let first = &mapped.events[0];
        if let AgentEvent::Token { content } = first {
            assert_eq!(content, "hello world");
        } else {
            panic!("expected Token, got {:?}", first);
        }
    }

    #[test]
    fn a2a_completed_status_maps_to_complete_and_metadata() {
        let mut mapper = A2AEventMapper::new();
        let snapshot = Task {
            id: "task-1".to_string(),
            context_id: Some("ctx-1".to_string()),
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let mapped = mapper.map_stream_response(StreamResponse {
            task: Some(snapshot),
            message: None,
            status_update: None,
            artifact_update: None,
        });

        assert!(mapper.is_terminal());

        // Every identity/state key must be mirrored into the metadata updates.
        let expected = [
            ("a2a.latest_task_id", "task-1"),
            ("a2a.context_id", "ctx-1"),
            ("a2a.last_state", "TASK_STATE_COMPLETED"),
        ];
        for (key, value) in expected.iter() {
            assert_eq!(
                mapped.metadata_updates.get(*key),
                Some(&value.to_string())
            );
        }

        let first = &mapped.events[0];
        if let AgentEvent::Complete { .. } = first {
        } else {
            panic!("expected Complete, got {:?}", first);
        }
    }

    #[test]
    fn a2a_failed_status_maps_to_error() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::Failed,
            "Something went wrong",
        ));

        assert!(mapper.is_terminal());
        let first = &mapped.events[0];
        if let AgentEvent::Error { message } = first {
            assert_eq!(message, "Something went wrong");
        } else {
            panic!("expected Error, got {:?}", first);
        }
    }

    #[test]
    fn a2a_input_required_maps_to_need_clarification() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_response(
            TaskState::InputRequired,
            "What is your API key?",
        ));

        // Asking for input must not end the stream.
        assert!(!mapper.is_terminal());
        let first = &mapped.events[0];
        if let AgentEvent::NeedClarification { question, .. } = first {
            assert_eq!(question, "What is your API key?");
        } else {
            panic!("expected NeedClarification, got {:?}", first);
        }
    }
}