1use crate::protocol::AgentEvent;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(tag = "type", rename_all = "kebab-case")]
18pub enum UiStreamPart {
19 Start {
21 #[serde(rename = "messageId")]
22 message_id: String,
23 },
24 Finish {},
25 StartStep {},
26 FinishStep {},
27 Abort {
28 reason: String,
29 },
30
31 TextStart {
33 id: String,
34 },
35 TextDelta {
36 id: String,
37 delta: String,
38 },
39 TextEnd {
40 id: String,
41 },
42
43 ReasoningStart {
45 id: String,
46 },
47 ReasoningDelta {
48 id: String,
49 delta: String,
50 },
51 ReasoningEnd {
52 id: String,
53 },
54
55 ToolInputStart {
57 #[serde(rename = "toolCallId")]
58 tool_call_id: String,
59 #[serde(rename = "toolName")]
60 tool_name: String,
61 },
62 ToolInputDelta {
63 #[serde(rename = "toolCallId")]
64 tool_call_id: String,
65 #[serde(rename = "inputTextDelta")]
66 input_text_delta: String,
67 },
68 ToolInputAvailable {
69 #[serde(rename = "toolCallId")]
70 tool_call_id: String,
71 #[serde(rename = "toolName")]
72 tool_name: String,
73 input: Value,
74 },
75 ToolOutputAvailable {
76 #[serde(rename = "toolCallId")]
77 tool_call_id: String,
78 output: Value,
79 },
80
81 Error {
83 #[serde(rename = "errorText")]
84 error_text: String,
85 },
86
87 #[serde(rename = "data-state-patch")]
89 DataStatePatch {
90 data: Value,
91 },
92 #[serde(rename = "data-approval-request")]
93 DataApprovalRequest {
94 data: Value,
95 },
96}
97
98pub fn to_ui_stream_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
104 match event {
105 AgentEvent::RunStarted { run_id, .. } => {
106 vec![UiStreamPart::Start {
107 message_id: run_id.clone(),
108 }]
109 }
110
111 AgentEvent::IterationStarted { .. } => {
112 vec![UiStreamPart::StartStep {}]
113 }
114
115 AgentEvent::ModelOutput { .. } => {
116 vec![UiStreamPart::FinishStep {}]
117 }
118
119 AgentEvent::TextDelta { run_id, delta, .. } => {
120 vec![UiStreamPart::TextDelta {
121 id: format!("{run_id}-text"),
122 delta: delta.clone(),
123 }]
124 }
125
126 AgentEvent::ToolCallRequested { call, .. } => {
127 let args_json = serde_json::to_string(&call.input).unwrap_or_default();
128 vec![
129 UiStreamPart::ToolInputStart {
130 tool_call_id: call.call_id.clone(),
131 tool_name: call.tool_name.clone(),
132 },
133 UiStreamPart::ToolInputDelta {
134 tool_call_id: call.call_id.clone(),
135 input_text_delta: args_json,
136 },
137 UiStreamPart::ToolInputAvailable {
138 tool_call_id: call.call_id.clone(),
139 tool_name: call.tool_name.clone(),
140 input: call.input.clone(),
141 },
142 ]
143 }
144
145 AgentEvent::ToolCallCompleted { result, .. } => {
146 vec![UiStreamPart::ToolOutputAvailable {
147 tool_call_id: result.call_id.clone(),
148 output: result.output.clone(),
149 }]
150 }
151
152 AgentEvent::ToolCallFailed { call_id, error, .. } => {
153 vec![UiStreamPart::ToolOutputAvailable {
154 tool_call_id: call_id.clone(),
155 output: serde_json::json!({ "error": error }),
156 }]
157 }
158
159 AgentEvent::StatePatched {
160 patch, revision, ..
161 } => {
162 vec![UiStreamPart::DataStatePatch {
163 data: serde_json::json!({
164 "patch": patch.patch,
165 "revision": revision,
166 }),
167 }]
168 }
169
170 AgentEvent::RunErrored { error, .. } => {
171 vec![UiStreamPart::Error {
172 error_text: error.clone(),
173 }]
174 }
175
176 AgentEvent::RunFinished {
177 run_id,
178 final_answer,
179 ..
180 } => {
181 let mut parts = Vec::new();
182 if let Some(answer) = final_answer {
183 if !answer.is_empty() {
184 let text_id = format!("{run_id}-text");
185 parts.push(UiStreamPart::TextDelta {
186 id: text_id.clone(),
187 delta: answer.clone(),
188 });
189 }
190 }
191 parts.push(UiStreamPart::Finish {});
192 parts
193 }
194
195 AgentEvent::ApprovalRequested {
196 approval_id,
197 call_id,
198 tool_name,
199 arguments,
200 risk,
201 ..
202 } => {
203 vec![UiStreamPart::DataApprovalRequest {
204 data: serde_json::json!({
205 "approvalId": approval_id,
206 "toolCallId": call_id,
207 "toolName": tool_name,
208 "arguments": arguments,
209 "risk": risk,
210 }),
211 }]
212 }
213
214 AgentEvent::ContextCompacted { .. } | AgentEvent::ApprovalResolved { .. } => {
216 vec![]
217 }
218 }
219}
220
221pub fn ui_stream_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
223 let json = serde_json::to_string(part)?;
224 Ok(format!("data: {json}\n\n"))
225}
226
/// Alternative name for [`UiStreamPart`]; the two are the same type.
///
/// NOTE(review): the test suite calls the `aisdk`-named API "deprecated", so
/// this is presumably kept for backward compatibility — confirm before
/// removing.
pub type AiSdkPart = UiStreamPart;
231
/// Wrapper around [`to_ui_stream_parts`]; returns exactly what it returns.
///
/// NOTE(review): the test suite refers to this name as deprecated — new code
/// should presumably call [`to_ui_stream_parts`] directly.
pub fn to_aisdk_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
    to_ui_stream_parts(event)
}
236
/// Wrapper around [`ui_stream_part_to_sse`]; returns exactly what it returns,
/// including any serialization error.
///
/// NOTE(review): named after the legacy `aisdk` API — new code should
/// presumably call [`ui_stream_part_to_sse`] directly.
pub fn aisdk_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
    ui_stream_part_to_sse(part)
}
241
#[cfg(test)]
mod tests {
    //! Unit tests for the event-to-part mapping and the JSON / SSE encodings.

    use super::*;
    use crate::protocol::{
        ModelStopReason, RunStopReason, StatePatch, StatePatchFormat, StatePatchSource, ToolCall,
        ToolResultSummary,
    };
    use serde_json::json;

    /// Converts `event` and returns its only part, failing the test when the
    /// conversion yields zero or several parts.
    fn single_part(event: &AgentEvent) -> UiStreamPart {
        let mut parts = to_ui_stream_parts(event);
        assert_eq!(parts.len(), 1, "expected exactly one part, got {parts:?}");
        parts.remove(0)
    }

    /// Serializes `part` to a JSON value so whole objects can be compared
    /// without depending on key order.
    fn wire(part: &UiStreamPart) -> serde_json::Value {
        serde_json::to_value(part).expect("UiStreamPart serializes to JSON")
    }

    #[test]
    fn run_started_maps_to_start() {
        let event = AgentEvent::RunStarted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            provider: "anthropic".to_string(),
            max_iterations: 10,
        };
        assert_eq!(
            single_part(&event),
            UiStreamPart::Start {
                message_id: "r1".to_string()
            }
        );
    }

    #[test]
    fn iteration_started_maps_to_start_step() {
        let event = AgentEvent::IterationStarted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
        };
        assert_eq!(single_part(&event), UiStreamPart::StartStep {});
    }

    #[test]
    fn model_output_maps_to_finish_step() {
        let event = AgentEvent::ModelOutput {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            stop_reason: ModelStopReason::EndTurn,
            directive_count: 0,
            usage: None,
        };
        assert_eq!(single_part(&event), UiStreamPart::FinishStep {});
    }

    #[test]
    fn text_delta_includes_id() {
        let event = AgentEvent::TextDelta {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            delta: "Hello ".to_string(),
        };
        assert_eq!(
            single_part(&event),
            UiStreamPart::TextDelta {
                id: "r1-text".to_string(),
                delta: "Hello ".to_string(),
            }
        );
    }

    #[test]
    fn tool_call_produces_input_start_delta_available() {
        let event = AgentEvent::ToolCallRequested {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            call: ToolCall {
                call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                input: json!({"path": "/tmp/test.rs"}),
            },
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 3);

        assert_eq!(
            parts[0],
            UiStreamPart::ToolInputStart {
                tool_call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
            }
        );
        match &parts[1] {
            UiStreamPart::ToolInputDelta {
                tool_call_id,
                input_text_delta,
            } => {
                assert_eq!(tool_call_id, "c1");
                assert!(input_text_delta.contains("path"));
            }
            other => panic!("Expected ToolInputDelta, got {other:?}"),
        }
        assert_eq!(
            parts[2],
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                input: json!({"path": "/tmp/test.rs"}),
            }
        );
    }

    #[test]
    fn tool_completed_maps_to_output_available() {
        let event = AgentEvent::ToolCallCompleted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            result: ToolResultSummary {
                call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                output: json!({"content": "file contents here"}),
            },
        };
        assert_eq!(
            single_part(&event),
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"content": "file contents here"}),
            }
        );
    }

    #[test]
    fn tool_failed_maps_to_output_with_error() {
        let event = AgentEvent::ToolCallFailed {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
            error: "command not found".to_string(),
        };
        assert_eq!(
            single_part(&event),
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"error": "command not found"}),
            }
        );
    }

    #[test]
    fn state_patched_maps_to_data_extension() {
        let event = AgentEvent::StatePatched {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            patch: StatePatch {
                format: StatePatchFormat::MergePatch,
                patch: json!({"cwd": "/new"}),
                source: StatePatchSource::System,
            },
            revision: 5,
        };
        assert_eq!(
            single_part(&event),
            UiStreamPart::DataStatePatch {
                data: json!({"patch": {"cwd": "/new"}, "revision": 5}),
            }
        );
    }

    #[test]
    fn run_errored_maps_to_error() {
        let event = AgentEvent::RunErrored {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            error: "provider timeout".to_string(),
        };
        assert_eq!(
            single_part(&event),
            UiStreamPart::Error {
                error_text: "provider timeout".to_string()
            }
        );
    }

    #[test]
    fn run_finished_maps_to_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 3,
            final_answer: None,
        };
        assert_eq!(single_part(&event), UiStreamPart::Finish {});
    }

    #[test]
    fn run_finished_with_final_answer_emits_text_then_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 1,
            final_answer: Some("Done!".to_string()),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 2);
        assert_eq!(
            parts[0],
            UiStreamPart::TextDelta {
                id: "r1-text".to_string(),
                delta: "Done!".to_string(),
            }
        );
        assert_eq!(parts[1], UiStreamPart::Finish {});
    }

    #[test]
    fn context_compacted_produces_empty() {
        let event = AgentEvent::ContextCompacted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            dropped_count: 5,
            tokens_before: 1000,
            tokens_after: 500,
        };
        assert!(to_ui_stream_parts(&event).is_empty());
    }

    #[test]
    fn approval_requested_maps_to_data_approval() {
        let event = AgentEvent::ApprovalRequested {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            approval_id: "appr-1".to_string(),
            call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
            arguments: json!({"command": "rm -rf /"}),
            risk: "high".to_string(),
        };
        match single_part(&event) {
            UiStreamPart::DataApprovalRequest { data } => {
                assert_eq!(data["approvalId"], "appr-1");
                assert_eq!(data["toolCallId"], "c1");
                assert_eq!(data["toolName"], "bash");
                // The tool arguments must reach the approver unchanged.
                assert_eq!(data["arguments"], json!({"command": "rm -rf /"}));
                assert_eq!(data["risk"], "high");
            }
            other => panic!("Expected DataApprovalRequest, got {other:?}"),
        }
    }

    #[test]
    fn v6_wire_format_serialization() {
        // Whole-object comparisons catch unexpected extra or missing fields,
        // which substring checks would let through.
        let start = UiStreamPart::Start {
            message_id: "m1".to_string(),
        };
        assert_eq!(wire(&start), json!({"type": "start", "messageId": "m1"}));

        let text = UiStreamPart::TextDelta {
            id: "t1".to_string(),
            delta: "hi".to_string(),
        };
        assert_eq!(
            wire(&text),
            json!({"type": "text-delta", "id": "t1", "delta": "hi"})
        );

        let tool = UiStreamPart::ToolInputStart {
            tool_call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
        };
        assert_eq!(
            wire(&tool),
            json!({"type": "tool-input-start", "toolCallId": "c1", "toolName": "bash"})
        );

        let error = UiStreamPart::Error {
            error_text: "boom".to_string(),
        };
        assert_eq!(wire(&error), json!({"type": "error", "errorText": "boom"}));

        let ext = UiStreamPart::DataStatePatch {
            data: json!({"patch": {}}),
        };
        assert_eq!(
            wire(&ext),
            json!({"type": "data-state-patch", "data": {"patch": {}}})
        );
    }

    #[test]
    fn sse_wire_format() {
        let part = UiStreamPart::TextDelta {
            id: "t1".to_string(),
            delta: "hello\nworld".to_string(),
        };
        let sse = ui_stream_part_to_sse(&part).unwrap();

        let payload = sse
            .strip_prefix("data: ")
            .and_then(|rest| rest.strip_suffix("\n\n"))
            .expect("frame must be `data: <json>` followed by a blank line");
        // A raw newline inside the payload would split the SSE event.
        assert!(!payload.contains('\n'));

        let decoded: serde_json::Value = serde_json::from_str(payload).unwrap();
        assert_eq!(
            decoded,
            json!({"type": "text-delta", "id": "t1", "delta": "hello\nworld"})
        );
    }

    #[test]
    fn round_trip_serialization() {
        // Covers every variant so a rename on any of them is detected.
        let parts = vec![
            UiStreamPart::Start {
                message_id: "m1".to_string(),
            },
            UiStreamPart::Finish {},
            UiStreamPart::StartStep {},
            UiStreamPart::FinishStep {},
            UiStreamPart::Abort {
                reason: "cancelled".to_string(),
            },
            UiStreamPart::TextStart {
                id: "t1".to_string(),
            },
            UiStreamPart::TextDelta {
                id: "t1".to_string(),
                delta: "hi".to_string(),
            },
            UiStreamPart::TextEnd {
                id: "t1".to_string(),
            },
            UiStreamPart::ReasoningStart {
                id: "g1".to_string(),
            },
            UiStreamPart::ReasoningDelta {
                id: "g1".to_string(),
                delta: "hmm".to_string(),
            },
            UiStreamPart::ReasoningEnd {
                id: "g1".to_string(),
            },
            UiStreamPart::ToolInputStart {
                tool_call_id: "c1".to_string(),
                tool_name: "bash".to_string(),
            },
            UiStreamPart::ToolInputDelta {
                tool_call_id: "c1".to_string(),
                input_text_delta: "{\"cmd\":".to_string(),
            },
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".to_string(),
                tool_name: "bash".to_string(),
                input: json!({"cmd": "ls"}),
            },
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"stdout": "a.txt"}),
            },
            UiStreamPart::Error {
                error_text: "oops".to_string(),
            },
            UiStreamPart::DataStatePatch {
                data: json!({"patch": {"cwd": "/new"}, "revision": 5}),
            },
            UiStreamPart::DataApprovalRequest {
                data: json!({"approvalId": "appr-1"}),
            },
        ];

        for part in &parts {
            let json = serde_json::to_string(part).unwrap();
            let decoded: UiStreamPart = serde_json::from_str(&json).unwrap();
            assert_eq!(*part, decoded, "round trip changed {json}");
        }
    }

    #[test]
    fn deprecated_to_aisdk_parts_still_works() {
        let event = AgentEvent::TextDelta {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            delta: "test".to_string(),
        };
        let parts: Vec<AiSdkPart> = to_aisdk_parts(&event);
        assert_eq!(parts.len(), 1);

        // The legacy names must stay exact aliases of the canonical API.
        assert_eq!(parts, to_ui_stream_parts(&event));
        assert_eq!(
            aisdk_part_to_sse(&parts[0]).unwrap(),
            ui_stream_part_to_sse(&parts[0]).unwrap()
        );
    }
}