1use crate::protocol::AgentEvent;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(tag = "type", rename_all = "kebab-case")]
18pub enum UiStreamPart {
19 Start {
21 #[serde(rename = "messageId")]
22 message_id: String,
23 },
24 Finish {},
25 StartStep {},
26 FinishStep {},
27 Abort {
28 reason: String,
29 },
30
31 TextStart {
33 id: String,
34 },
35 TextDelta {
36 id: String,
37 delta: String,
38 },
39 TextEnd {
40 id: String,
41 },
42
43 ReasoningStart {
45 id: String,
46 },
47 ReasoningDelta {
48 id: String,
49 delta: String,
50 },
51 ReasoningEnd {
52 id: String,
53 },
54
55 ToolInputStart {
57 #[serde(rename = "toolCallId")]
58 tool_call_id: String,
59 #[serde(rename = "toolName")]
60 tool_name: String,
61 },
62 ToolInputDelta {
63 #[serde(rename = "toolCallId")]
64 tool_call_id: String,
65 #[serde(rename = "inputTextDelta")]
66 input_text_delta: String,
67 },
68 ToolInputAvailable {
69 #[serde(rename = "toolCallId")]
70 tool_call_id: String,
71 #[serde(rename = "toolName")]
72 tool_name: String,
73 input: Value,
74 },
75 ToolOutputAvailable {
76 #[serde(rename = "toolCallId")]
77 tool_call_id: String,
78 output: Value,
79 },
80
81 Error {
83 #[serde(rename = "errorText")]
84 error_text: String,
85 },
86
87 #[serde(rename = "data-state-patch")]
89 DataStatePatch {
90 data: Value,
91 },
92 #[serde(rename = "data-approval-request")]
93 DataApprovalRequest {
94 data: Value,
95 },
96}
97
98pub fn to_ui_stream_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
104 match event {
105 AgentEvent::RunStarted { run_id, .. } => {
106 vec![UiStreamPart::Start {
107 message_id: run_id.clone(),
108 }]
109 }
110
111 AgentEvent::IterationStarted { .. } => {
112 vec![UiStreamPart::StartStep {}]
113 }
114
115 AgentEvent::ModelOutput { .. } => {
116 vec![UiStreamPart::FinishStep {}]
117 }
118
119 AgentEvent::TextDelta { run_id, delta, .. } => {
120 vec![UiStreamPart::TextDelta {
121 id: format!("{run_id}-text"),
122 delta: delta.clone(),
123 }]
124 }
125
126 AgentEvent::ToolCallRequested { call, .. } => {
127 let args_json = serde_json::to_string(&call.input).unwrap_or_default();
128 vec![
129 UiStreamPart::ToolInputStart {
130 tool_call_id: call.call_id.clone(),
131 tool_name: call.tool_name.clone(),
132 },
133 UiStreamPart::ToolInputDelta {
134 tool_call_id: call.call_id.clone(),
135 input_text_delta: args_json,
136 },
137 UiStreamPart::ToolInputAvailable {
138 tool_call_id: call.call_id.clone(),
139 tool_name: call.tool_name.clone(),
140 input: call.input.clone(),
141 },
142 ]
143 }
144
145 AgentEvent::ToolCallCompleted { result, .. } => {
146 vec![UiStreamPart::ToolOutputAvailable {
147 tool_call_id: result.call_id.clone(),
148 output: result.output.clone(),
149 }]
150 }
151
152 AgentEvent::ToolCallFailed { call_id, error, .. } => {
153 vec![UiStreamPart::ToolOutputAvailable {
154 tool_call_id: call_id.clone(),
155 output: serde_json::json!({ "error": error }),
156 }]
157 }
158
159 AgentEvent::StatePatched {
160 patch, revision, ..
161 } => {
162 vec![UiStreamPart::DataStatePatch {
163 data: serde_json::json!({
164 "patch": patch.patch,
165 "revision": revision,
166 }),
167 }]
168 }
169
170 AgentEvent::RunErrored { error, .. } => {
171 vec![UiStreamPart::Error {
172 error_text: error.clone(),
173 }]
174 }
175
176 AgentEvent::RunFinished {
177 run_id,
178 final_answer,
179 ..
180 } => {
181 let mut parts = Vec::new();
182 if let Some(answer) = final_answer
183 && !answer.is_empty()
184 {
185 let text_id = format!("{run_id}-text");
186 parts.push(UiStreamPart::TextDelta {
187 id: text_id.clone(),
188 delta: answer.clone(),
189 });
190 }
191 parts.push(UiStreamPart::Finish {});
192 parts
193 }
194
195 AgentEvent::ApprovalRequested {
196 approval_id,
197 call_id,
198 tool_name,
199 arguments,
200 risk,
201 ..
202 } => {
203 vec![UiStreamPart::DataApprovalRequest {
204 data: serde_json::json!({
205 "approvalId": approval_id,
206 "toolCallId": call_id,
207 "toolName": tool_name,
208 "arguments": arguments,
209 "risk": risk,
210 }),
211 }]
212 }
213
214 AgentEvent::ContextCompacted { .. } | AgentEvent::ApprovalResolved { .. } => {
216 vec![]
217 }
218 #[allow(unreachable_patterns)]
221 _ => vec![],
222 }
223}
224
225pub fn ui_stream_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
227 let json = serde_json::to_string(part)?;
228 Ok(format!("data: {json}\n\n"))
229}
230
/// Legacy name for [`UiStreamPart`]; both names refer to the same type.
///
/// Kept for code written against the `aisdk` naming. Prefer `UiStreamPart` in new code.
pub type AiSdkPart = UiStreamPart;
235
236pub fn to_aisdk_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
238 to_ui_stream_parts(event)
239}
240
241pub fn aisdk_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
243 ui_stream_part_to_sse(part)
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{
        ModelStopReason, RunStopReason, StatePatch, StatePatchFormat, StatePatchSource, ToolCall,
        ToolResultSummary,
    };
    use serde_json::json;

    /// One sample of every `UiStreamPart` variant, paired with the exact JSON object it must
    /// serialize to. Keep this list in sync with the enum: one entry per variant.
    fn wire_cases() -> Vec<(UiStreamPart, serde_json::Value)> {
        vec![
            (
                UiStreamPart::Start {
                    message_id: "m1".to_string(),
                },
                json!({"type": "start", "messageId": "m1"}),
            ),
            (UiStreamPart::Finish {}, json!({"type": "finish"})),
            (UiStreamPart::StartStep {}, json!({"type": "start-step"})),
            (UiStreamPart::FinishStep {}, json!({"type": "finish-step"})),
            (
                UiStreamPart::Abort {
                    reason: "cancelled".to_string(),
                },
                json!({"type": "abort", "reason": "cancelled"}),
            ),
            (
                UiStreamPart::TextStart {
                    id: "t1".to_string(),
                },
                json!({"type": "text-start", "id": "t1"}),
            ),
            (
                UiStreamPart::TextDelta {
                    id: "t1".to_string(),
                    delta: "hi".to_string(),
                },
                json!({"type": "text-delta", "id": "t1", "delta": "hi"}),
            ),
            (
                UiStreamPart::TextEnd {
                    id: "t1".to_string(),
                },
                json!({"type": "text-end", "id": "t1"}),
            ),
            (
                UiStreamPart::ReasoningStart {
                    id: "r1".to_string(),
                },
                json!({"type": "reasoning-start", "id": "r1"}),
            ),
            (
                UiStreamPart::ReasoningDelta {
                    id: "r1".to_string(),
                    delta: "thinking".to_string(),
                },
                json!({"type": "reasoning-delta", "id": "r1", "delta": "thinking"}),
            ),
            (
                UiStreamPart::ReasoningEnd {
                    id: "r1".to_string(),
                },
                json!({"type": "reasoning-end", "id": "r1"}),
            ),
            (
                UiStreamPart::ToolInputStart {
                    tool_call_id: "c1".to_string(),
                    tool_name: "bash".to_string(),
                },
                json!({"type": "tool-input-start", "toolCallId": "c1", "toolName": "bash"}),
            ),
            (
                UiStreamPart::ToolInputDelta {
                    tool_call_id: "c1".to_string(),
                    input_text_delta: r#"{"cmd":"ls"}"#.to_string(),
                },
                json!({
                    "type": "tool-input-delta",
                    "toolCallId": "c1",
                    "inputTextDelta": r#"{"cmd":"ls"}"#,
                }),
            ),
            (
                UiStreamPart::ToolInputAvailable {
                    tool_call_id: "c1".to_string(),
                    tool_name: "bash".to_string(),
                    input: json!({"cmd": "ls"}),
                },
                json!({
                    "type": "tool-input-available",
                    "toolCallId": "c1",
                    "toolName": "bash",
                    "input": {"cmd": "ls"},
                }),
            ),
            (
                UiStreamPart::ToolOutputAvailable {
                    tool_call_id: "c1".to_string(),
                    output: json!({"ok": true}),
                },
                json!({"type": "tool-output-available", "toolCallId": "c1", "output": {"ok": true}}),
            ),
            (
                UiStreamPart::Error {
                    error_text: "boom".to_string(),
                },
                json!({"type": "error", "errorText": "boom"}),
            ),
            (
                UiStreamPart::DataStatePatch {
                    data: json!({"patch": {}, "revision": 1}),
                },
                json!({"type": "data-state-patch", "data": {"patch": {}, "revision": 1}}),
            ),
            (
                UiStreamPart::DataApprovalRequest {
                    data: json!({"approvalId": "a1"}),
                },
                json!({"type": "data-approval-request", "data": {"approvalId": "a1"}}),
            ),
        ]
    }

    #[test]
    fn run_started_maps_to_start() {
        let event = AgentEvent::RunStarted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            provider: "anthropic".to_string(),
            max_iterations: 10,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::Start {
                message_id: "r1".to_string()
            }
        );
    }

    #[test]
    fn iteration_started_maps_to_start_step() {
        let event = AgentEvent::IterationStarted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], UiStreamPart::StartStep {});
    }

    #[test]
    fn model_output_maps_to_finish_step() {
        let event = AgentEvent::ModelOutput {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            stop_reason: ModelStopReason::EndTurn,
            directive_count: 0,
            usage: None,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], UiStreamPart::FinishStep {});
    }

    #[test]
    fn text_delta_includes_id() {
        let event = AgentEvent::TextDelta {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            delta: "Hello ".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::TextDelta {
                id: "r1-text".to_string(),
                delta: "Hello ".to_string(),
            }
        );
    }

    #[test]
    fn tool_call_produces_input_start_delta_available() {
        let event = AgentEvent::ToolCallRequested {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            call: ToolCall {
                call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                input: json!({"path": "/tmp/test.rs"}),
            },
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 3);

        assert_eq!(
            parts[0],
            UiStreamPart::ToolInputStart {
                tool_call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
            }
        );
        match &parts[1] {
            UiStreamPart::ToolInputDelta {
                tool_call_id,
                input_text_delta,
            } => {
                assert_eq!(tool_call_id, "c1");
                // The single delta must be the complete input encoded as JSON text.
                let parsed: serde_json::Value = serde_json::from_str(input_text_delta)
                    .expect("tool-input-delta should carry the input as JSON text");
                assert_eq!(parsed, json!({"path": "/tmp/test.rs"}));
            }
            other => panic!("Expected ToolInputDelta, got {:?}", other),
        }
        assert_eq!(
            parts[2],
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                input: json!({"path": "/tmp/test.rs"}),
            }
        );
    }

    #[test]
    fn tool_completed_maps_to_output_available() {
        let event = AgentEvent::ToolCallCompleted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            result: ToolResultSummary {
                call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                output: json!({"content": "file contents here"}),
            },
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"content": "file contents here"}),
            }
        );
    }

    #[test]
    fn tool_failed_maps_to_output_with_error() {
        let event = AgentEvent::ToolCallFailed {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
            error: "command not found".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"error": "command not found"}),
            }
        );
    }

    #[test]
    fn state_patched_maps_to_data_extension() {
        let event = AgentEvent::StatePatched {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            patch: StatePatch {
                format: StatePatchFormat::MergePatch,
                patch: json!({"cwd": "/new"}),
                source: StatePatchSource::System,
            },
            revision: 5,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::DataStatePatch {
                data: json!({"patch": {"cwd": "/new"}, "revision": 5}),
            }
        );
    }

    #[test]
    fn run_errored_maps_to_error() {
        let event = AgentEvent::RunErrored {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            error: "provider timeout".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::Error {
                error_text: "provider timeout".to_string()
            }
        );
    }

    #[test]
    fn run_finished_maps_to_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 3,
            final_answer: None,
            usage: None,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], UiStreamPart::Finish {});
    }

    #[test]
    fn run_finished_with_empty_final_answer_emits_only_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 1,
            final_answer: Some(String::new()),
            usage: None,
        };
        assert_eq!(to_ui_stream_parts(&event), vec![UiStreamPart::Finish {}]);
    }

    #[test]
    fn run_finished_with_final_answer_emits_text_then_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 1,
            final_answer: Some("Done!".to_string()),
            usage: None,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 2);
        assert_eq!(
            parts[0],
            UiStreamPart::TextDelta {
                id: "r1-text".to_string(),
                delta: "Done!".to_string(),
            }
        );
        assert_eq!(parts[1], UiStreamPart::Finish {});
    }

    #[test]
    fn context_compacted_produces_empty() {
        let event = AgentEvent::ContextCompacted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            dropped_count: 5,
            tokens_before: 1000,
            tokens_after: 500,
        };
        assert!(to_ui_stream_parts(&event).is_empty());
    }

    #[test]
    fn approval_requested_maps_to_data_approval() {
        let event = AgentEvent::ApprovalRequested {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            approval_id: "appr-1".to_string(),
            call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
            arguments: json!({"command": "rm -rf /"}),
            risk: "high".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            UiStreamPart::DataApprovalRequest { data } => {
                assert_eq!(data["approvalId"], "appr-1");
                assert_eq!(data["toolCallId"], "c1");
                assert_eq!(data["toolName"], "bash");
                assert_eq!(data["arguments"], json!({"command": "rm -rf /"}));
                assert_eq!(data["risk"], "high");
            }
            other => panic!("Expected DataApprovalRequest, got {:?}", other),
        }
    }

    #[test]
    fn v6_wire_format_serialization() {
        // Exact comparison pins both the `type` tag and every field name of each variant.
        for (part, expected) in wire_cases() {
            let actual = serde_json::to_value(&part).unwrap();
            assert_eq!(actual, expected, "unexpected wire JSON for {part:?}");
        }
    }

    #[test]
    fn sse_wire_format() {
        let part = UiStreamPart::TextDelta {
            id: "t1".to_string(),
            delta: "hello\nworld".to_string(),
        };
        let sse = ui_stream_part_to_sse(&part).unwrap();
        let payload = sse
            .strip_prefix("data: ")
            .and_then(|rest| rest.strip_suffix("\n\n"))
            .expect("SSE frame must be `data: <json>` followed by a blank line");
        // Newlines inside string values are JSON-escaped, so the frame holds one data line.
        assert!(!payload.contains('\n'));
        let decoded: UiStreamPart = serde_json::from_str(payload).unwrap();
        assert_eq!(decoded, part);
    }

    #[test]
    fn round_trip_serialization() {
        for (part, wire) in wire_cases() {
            let encoded = serde_json::to_string(&part).unwrap();
            let decoded: UiStreamPart = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, part);

            // The documented wire form must also deserialize back to the same part.
            let from_wire: UiStreamPart = serde_json::from_value(wire).unwrap();
            assert_eq!(from_wire, part);
        }
    }

    #[test]
    fn deprecated_to_aisdk_parts_still_works() {
        let event = AgentEvent::TextDelta {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            delta: "test".to_string(),
        };
        let parts: Vec<AiSdkPart> = to_aisdk_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts, to_ui_stream_parts(&event));
        assert_eq!(
            aisdk_part_to_sse(&parts[0]).unwrap(),
            ui_stream_part_to_sse(&parts[0]).unwrap()
        );
    }
}