1use std::collections::HashMap;
2
3use bamboo_a2a::types::{A2ARole, PartContentWire, StreamResponse, TaskState, TaskStatus};
4use bamboo_agent_core::{AgentEvent, TokenUsage};
5
6pub struct A2AMappedEvents {
8 pub events: Vec<AgentEvent>,
9 pub metadata_updates: HashMap<String, String>,
10}
11
/// Stateful translator from an A2A event stream into bamboo agent events.
///
/// Remembers the latest task and context ids, whether a terminal task state has been
/// mapped, and the concatenation of all text emitted as tokens.
#[derive(Debug, Default)]
pub struct A2AEventMapper {
    /// Set once a terminal task state (completed, failed, canceled, rejected) is mapped.
    terminal_sent: bool,
    /// Id of the most recent task seen in a task snapshot or status update.
    latest_task_id: Option<String>,
    /// Most recent context id reported by the remote agent.
    context_id: Option<String>,
    /// All text emitted as token events so far, in arrival order.
    final_text: String,
}
20
21impl A2AEventMapper {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn latest_task_id(&self) -> Option<&str> {
27 self.latest_task_id.as_deref()
28 }
29
30 pub fn context_id(&self) -> Option<&str> {
31 self.context_id.as_deref()
32 }
33
34 pub fn is_terminal(&self) -> bool {
35 self.terminal_sent
36 }
37
38 pub fn final_text(&self) -> &str {
39 &self.final_text
40 }
41
42 pub fn map_stream_response(&mut self, response: StreamResponse) -> A2AMappedEvents {
44 let mut events = Vec::new();
45 let mut metadata = HashMap::new();
46
47 if let Some(task) = response.task {
48 self.latest_task_id = Some(task.id.clone());
49 if let Some(ctx) = task.context_id.clone() {
50 self.context_id = Some(ctx);
51 }
52 metadata.insert("a2a.latest_task_id".to_string(), task.id.clone());
53 if let Some(ctx) = &task.context_id {
54 metadata.insert("a2a.context_id".to_string(), ctx.clone());
55 }
56 metadata.insert(
57 "a2a.last_state".to_string(),
58 task.status.state.as_proto_str().to_string(),
59 );
60 events.extend(self.map_status(&task.id, task.context_id.as_deref(), task.status));
61 }
62
63 if let Some(message) = response.message {
64 if message.role == A2ARole::Agent {
65 let text = text_from_parts(&message.parts);
66 if !text.is_empty() {
67 self.final_text.push_str(&text);
68 events.push(AgentEvent::Token { content: text });
69 }
70 }
71 }
72
73 if let Some(update) = response.status_update {
74 self.latest_task_id = Some(update.task_id.clone());
75 self.context_id = Some(update.context_id.clone());
76 metadata.insert("a2a.latest_task_id".to_string(), update.task_id.clone());
77 metadata.insert("a2a.context_id".to_string(), update.context_id.clone());
78 metadata.insert(
79 "a2a.last_state".to_string(),
80 update.status.state.as_proto_str().to_string(),
81 );
82 events.extend(self.map_status(
83 &update.task_id,
84 Some(&update.context_id),
85 update.status,
86 ));
87 }
88
89 if let Some(update) = response.artifact_update {
90 let preview =
91 handle_artifact_update(&update.artifact, update.append, update.last_chunk);
92 if !preview.is_empty() {
93 events.push(AgentEvent::Token {
94 content: preview.clone(),
95 });
96 self.final_text.push_str(&preview);
97 }
98 metadata.insert(
99 "a2a.last_artifacts_summary".to_string(),
100 serde_json::json!({
101 "artifact_id": update.artifact.artifact_id,
102 "name": update.artifact.name,
103 "append": update.append,
104 "last_chunk": update.last_chunk,
105 })
106 .to_string(),
107 );
108 }
109
110 A2AMappedEvents {
111 events,
112 metadata_updates: metadata,
113 }
114 }
115
116 fn map_status(
117 &mut self,
118 _task_id: &str,
119 _context_id: Option<&str>,
120 status: TaskStatus,
121 ) -> Vec<AgentEvent> {
122 let mut events = Vec::new();
123
124 match &status.state {
125 TaskState::Submitted => {
126 }
128 TaskState::Working => {
129 if let Some(msg) = &status.message {
130 let text = text_from_parts(&msg.parts);
131 if !text.is_empty() {
132 self.final_text.push_str(&text);
133 events.push(AgentEvent::Token { content: text });
134 }
135 }
136 }
137 TaskState::InputRequired => {
138 let question = question_from_status(&status);
139 events.push(AgentEvent::NeedClarification {
140 question,
141 options: None,
142 tool_call_id: None,
143 tool_name: None,
144 allow_custom: true,
145 });
146 }
147 TaskState::AuthRequired => {
148 let question = question_from_status(&status);
149 events.push(AgentEvent::NeedClarification {
150 question,
151 options: None,
152 tool_call_id: None,
153 tool_name: None,
154 allow_custom: true,
155 });
156 }
157 TaskState::Completed => {
158 self.terminal_sent = true;
159 if let Some(msg) = &status.message {
160 let text = text_from_parts(&msg.parts);
161 if !text.is_empty() {
162 self.final_text.push_str(&text);
163 events.push(AgentEvent::Token { content: text });
164 }
165 }
166 events.push(AgentEvent::Complete {
167 usage: TokenUsage::default(),
168 });
169 }
170 TaskState::Failed => {
171 self.terminal_sent = true;
172 let error_msg = status
173 .message
174 .as_ref()
175 .map(|m| text_from_parts(&m.parts))
176 .filter(|s| !s.is_empty())
177 .unwrap_or_else(|| "External agent reported failure".to_string());
178 events.push(AgentEvent::Error { message: error_msg });
179 }
180 TaskState::Canceled => {
181 self.terminal_sent = true;
182 events.push(AgentEvent::Error {
183 message: "External agent task was cancelled".to_string(),
184 });
185 }
186 TaskState::Rejected => {
187 self.terminal_sent = true;
188 events.push(AgentEvent::Error {
189 message: "External agent rejected the task".to_string(),
190 });
191 }
192 TaskState::Unspecified => {}
193 }
194
195 events
196 }
197}
198
199pub fn text_from_parts(parts: &[bamboo_a2a::types::Part]) -> String {
201 parts
202 .iter()
203 .filter_map(|part| match &part.content {
204 PartContentWire::Text { text } => Some(text.as_str()),
205 PartContentWire::Data { data } => data.get("summary").and_then(|v| v.as_str()),
206 _ => None,
207 })
208 .collect::<Vec<_>>()
209 .join("\n")
210}
211
212fn question_from_status(status: &TaskStatus) -> String {
214 status
215 .message
216 .as_ref()
217 .map(|m| text_from_parts(&m.parts))
218 .filter(|s| !s.trim().is_empty())
219 .unwrap_or_else(|| match status.state {
220 TaskState::InputRequired => "External agent requires additional input.".to_string(),
221 TaskState::AuthRequired => {
222 "External agent requires authentication or authorization.".to_string()
223 }
224 _ => format!("External agent state: {:?}", status.state),
225 })
226}
227
228fn handle_artifact_update(
230 artifact: &bamboo_a2a::types::Artifact,
231 _append: bool,
232 _last_chunk: bool,
233) -> String {
234 let text = text_from_parts(&artifact.parts);
235 if text.is_empty() {
236 if let Some(name) = &artifact.name {
237 format!("[Artifact: {}]", name)
238 } else {
239 format!("[Artifact: {}]", artifact.artifact_id)
240 }
241 } else {
242 let header = artifact
243 .name
244 .as_ref()
245 .map(|n| format!("--- Artifact: {} ---\n", n))
246 .unwrap_or_default();
247 format!("{}{}", header, text)
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_a2a::types::{A2ARole, Message, Part, Task, TaskStatus, TaskStatusUpdateEvent};

    /// A single text part with an optional media type.
    fn text_part(text: &'static str, media_type: Option<&'static str>) -> Part {
        Part {
            content: PartContentWire::text(text),
            metadata: None,
            filename: None,
            media_type: media_type.map(|mt| mt.to_string()),
        }
    }

    /// Agent-authored message `m1` holding one text part.
    fn agent_message(text: &'static str, media_type: Option<&'static str>) -> Message {
        Message {
            message_id: "m1".to_string(),
            context_id: None,
            task_id: None,
            role: A2ARole::Agent,
            parts: vec![text_part(text, media_type)],
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        }
    }

    /// Stream response carrying only a `task-1` / `ctx-1` status update with a text message.
    fn status_update_with_text(state: TaskState, text: &'static str) -> StreamResponse {
        StreamResponse {
            task: None,
            message: None,
            status_update: Some(TaskStatusUpdateEvent {
                task_id: "task-1".to_string(),
                context_id: "ctx-1".to_string(),
                status: TaskStatus {
                    state,
                    message: Some(agent_message(text, None)),
                    timestamp: None,
                },
                metadata: None,
            }),
            artifact_update: None,
        }
    }

    #[test]
    fn a2a_message_text_maps_to_token() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(StreamResponse {
            task: None,
            message: Some(agent_message("hello world", Some("text/plain"))),
            status_update: None,
            artifact_update: None,
        });

        assert_eq!(mapped.events.len(), 1);
        if let AgentEvent::Token { content } = &mapped.events[0] {
            assert_eq!(content, "hello world");
        } else {
            panic!("expected Token, got {:?}", mapped.events[0]);
        }
    }

    #[test]
    fn a2a_completed_status_maps_to_complete_and_metadata() {
        let mut mapper = A2AEventMapper::new();
        let snapshot = Task {
            id: "task-1".to_string(),
            context_id: Some("ctx-1".to_string()),
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let mapped = mapper.map_stream_response(StreamResponse {
            task: Some(snapshot),
            message: None,
            status_update: None,
            artifact_update: None,
        });

        assert!(mapper.is_terminal());
        let meta = &mapped.metadata_updates;
        assert_eq!(meta.get("a2a.latest_task_id"), Some(&"task-1".to_string()));
        assert_eq!(meta.get("a2a.context_id"), Some(&"ctx-1".to_string()));
        assert_eq!(
            meta.get("a2a.last_state"),
            Some(&"TASK_STATE_COMPLETED".to_string())
        );
        assert!(
            matches!(mapped.events[0], AgentEvent::Complete { .. }),
            "expected Complete, got {:?}",
            mapped.events[0]
        );
    }

    #[test]
    fn a2a_failed_status_maps_to_error() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_with_text(
            TaskState::Failed,
            "Something went wrong",
        ));

        assert!(mapper.is_terminal());
        if let AgentEvent::Error { message } = &mapped.events[0] {
            assert_eq!(message, "Something went wrong");
        } else {
            panic!("expected Error, got {:?}", mapped.events[0]);
        }
    }

    #[test]
    fn a2a_input_required_maps_to_need_clarification() {
        let mut mapper = A2AEventMapper::new();
        let mapped = mapper.map_stream_response(status_update_with_text(
            TaskState::InputRequired,
            "What is your API key?",
        ));

        assert!(!mapper.is_terminal());
        if let AgentEvent::NeedClarification { question, .. } = &mapped.events[0] {
            assert_eq!(question, "What is your API key?");
        } else {
            panic!("expected NeedClarification, got {:?}", mapped.events[0]);
        }
    }
}