Skip to main content

arcan_core/
aisdk.rs

1use crate::protocol::AgentEvent;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
// ─── AI SDK v6 UI Message Stream Protocol ───────────────────────
//
// Spec: https://sdk.vercel.ai/docs/ai-sdk-ui/stream-protocol
// Response header: `x-vercel-ai-ui-message-stream: v1`
// Transport: Server-Sent Events; each part is framed as `data: {json}\n\n`
// Termination: the stream ends with a final `data: [DONE]` event
11
12/// Vercel AI SDK v6 "UI Message Stream Protocol" part.
13///
14/// Each variant maps to a v6 stream part type. Custom Arcan extensions
15/// use the `data-*` namespace per spec.
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(tag = "type", rename_all = "kebab-case")]
18pub enum UiStreamPart {
19    // ── Control ──
20    Start {
21        #[serde(rename = "messageId")]
22        message_id: String,
23    },
24    Finish {},
25    StartStep {},
26    FinishStep {},
27    Abort {
28        reason: String,
29    },
30
31    // ── Text ──
32    TextStart {
33        id: String,
34    },
35    TextDelta {
36        id: String,
37        delta: String,
38    },
39    TextEnd {
40        id: String,
41    },
42
43    // ── Reasoning ──
44    ReasoningStart {
45        id: String,
46    },
47    ReasoningDelta {
48        id: String,
49        delta: String,
50    },
51    ReasoningEnd {
52        id: String,
53    },
54
55    // ── Tool ──
56    ToolInputStart {
57        #[serde(rename = "toolCallId")]
58        tool_call_id: String,
59        #[serde(rename = "toolName")]
60        tool_name: String,
61    },
62    ToolInputDelta {
63        #[serde(rename = "toolCallId")]
64        tool_call_id: String,
65        #[serde(rename = "inputTextDelta")]
66        input_text_delta: String,
67    },
68    ToolInputAvailable {
69        #[serde(rename = "toolCallId")]
70        tool_call_id: String,
71        #[serde(rename = "toolName")]
72        tool_name: String,
73        input: Value,
74    },
75    ToolOutputAvailable {
76        #[serde(rename = "toolCallId")]
77        tool_call_id: String,
78        output: Value,
79    },
80
81    // ── Error ──
82    Error {
83        #[serde(rename = "errorText")]
84        error_text: String,
85    },
86
87    // ── Arcan Extensions (data-* namespace) ──
88    #[serde(rename = "data-state-patch")]
89    DataStatePatch {
90        data: Value,
91    },
92    #[serde(rename = "data-approval-request")]
93    DataApprovalRequest {
94        data: Value,
95    },
96}
97
98/// Convert an `AgentEvent` into zero or more `UiStreamPart`s (v6 protocol).
99///
100/// This is a stateless mapping. Text boundary tracking (TextStart/TextEnd)
101/// is handled by the SSE bridge in server.rs, which wraps consecutive
102/// TextDelta events with boundary signals.
103pub fn to_ui_stream_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
104    match event {
105        AgentEvent::RunStarted { run_id, .. } => {
106            vec![UiStreamPart::Start {
107                message_id: run_id.clone(),
108            }]
109        }
110
111        AgentEvent::IterationStarted { .. } => {
112            vec![UiStreamPart::StartStep {}]
113        }
114
115        AgentEvent::ModelOutput { .. } => {
116            vec![UiStreamPart::FinishStep {}]
117        }
118
119        AgentEvent::TextDelta { run_id, delta, .. } => {
120            vec![UiStreamPart::TextDelta {
121                id: format!("{run_id}-text"),
122                delta: delta.clone(),
123            }]
124        }
125
126        AgentEvent::ToolCallRequested { call, .. } => {
127            let args_json = serde_json::to_string(&call.input).unwrap_or_default();
128            vec![
129                UiStreamPart::ToolInputStart {
130                    tool_call_id: call.call_id.clone(),
131                    tool_name: call.tool_name.clone(),
132                },
133                UiStreamPart::ToolInputDelta {
134                    tool_call_id: call.call_id.clone(),
135                    input_text_delta: args_json,
136                },
137                UiStreamPart::ToolInputAvailable {
138                    tool_call_id: call.call_id.clone(),
139                    tool_name: call.tool_name.clone(),
140                    input: call.input.clone(),
141                },
142            ]
143        }
144
145        AgentEvent::ToolCallCompleted { result, .. } => {
146            vec![UiStreamPart::ToolOutputAvailable {
147                tool_call_id: result.call_id.clone(),
148                output: result.output.clone(),
149            }]
150        }
151
152        AgentEvent::ToolCallFailed { call_id, error, .. } => {
153            vec![UiStreamPart::ToolOutputAvailable {
154                tool_call_id: call_id.clone(),
155                output: serde_json::json!({ "error": error }),
156            }]
157        }
158
159        AgentEvent::StatePatched {
160            patch, revision, ..
161        } => {
162            vec![UiStreamPart::DataStatePatch {
163                data: serde_json::json!({
164                    "patch": patch.patch,
165                    "revision": revision,
166                }),
167            }]
168        }
169
170        AgentEvent::RunErrored { error, .. } => {
171            vec![UiStreamPart::Error {
172                error_text: error.clone(),
173            }]
174        }
175
176        AgentEvent::RunFinished {
177            run_id,
178            final_answer,
179            ..
180        } => {
181            let mut parts = Vec::new();
182            if let Some(answer) = final_answer {
183                if !answer.is_empty() {
184                    let text_id = format!("{run_id}-text");
185                    parts.push(UiStreamPart::TextDelta {
186                        id: text_id.clone(),
187                        delta: answer.clone(),
188                    });
189                }
190            }
191            parts.push(UiStreamPart::Finish {});
192            parts
193        }
194
195        AgentEvent::ApprovalRequested {
196            approval_id,
197            call_id,
198            tool_name,
199            arguments,
200            risk,
201            ..
202        } => {
203            vec![UiStreamPart::DataApprovalRequest {
204                data: serde_json::json!({
205                    "approvalId": approval_id,
206                    "toolCallId": call_id,
207                    "toolName": tool_name,
208                    "arguments": arguments,
209                    "risk": risk,
210                }),
211            }]
212        }
213
214        // Events with no UI representation
215        AgentEvent::ContextCompacted { .. } | AgentEvent::ApprovalResolved { .. } => {
216            vec![]
217        }
218    }
219}
220
221/// Serialize a `UiStreamPart` to the SSE wire format.
222pub fn ui_stream_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
223    let json = serde_json::to_string(part)?;
224    Ok(format!("data: {json}\n\n"))
225}
226
227// ─── Deprecated v5 aliases (will be removed) ────────────────────
228
/// Deprecated: use `UiStreamPart` instead.
///
/// Legacy v5-era name for the same type; kept as a plain alias so existing
/// callers keep compiling, and slated for removal.
pub type AiSdkPart = UiStreamPart;
231
/// Deprecated: use `to_ui_stream_parts` instead.
///
/// Legacy v5-era entry point that forwards `event` unchanged to
/// `to_ui_stream_parts` and returns its result; slated for removal.
pub fn to_aisdk_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
    to_ui_stream_parts(event)
}
236
/// Deprecated: use `ui_stream_part_to_sse` instead.
///
/// Legacy v5-era entry point that forwards `part` unchanged to
/// `ui_stream_part_to_sse` and returns its result, including any
/// serialization error; slated for removal.
pub fn aisdk_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
    ui_stream_part_to_sse(part)
}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{
        ModelStopReason, RunStopReason, StatePatch, StatePatchFormat, StatePatchSource, ToolCall,
        ToolResultSummary,
    };
    use serde_json::json;

    /// Run id shared by every fixture event.
    const RUN: &str = "r1";
    /// Session id shared by every fixture event.
    const SESSION: &str = "s1";

    /// Maps `event`, asserts that exactly one part was produced, and returns it.
    fn sole_part(event: &AgentEvent) -> UiStreamPart {
        let mut parts = to_ui_stream_parts(event);
        assert_eq!(parts.len(), 1);
        parts.pop().expect("length was asserted to be 1 above")
    }

    /// Serializes `part` and asserts that every fragment occurs in the JSON.
    fn assert_wire_contains(part: &UiStreamPart, fragments: &[&str]) {
        let encoded = serde_json::to_string(part).unwrap();
        for &fragment in fragments {
            assert!(encoded.contains(fragment));
        }
    }

    #[test]
    fn run_started_maps_to_start() {
        let part = sole_part(&AgentEvent::RunStarted {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            provider: "anthropic".into(),
            max_iterations: 10,
        });
        assert_eq!(
            part,
            UiStreamPart::Start {
                message_id: "r1".into()
            }
        );
    }

    #[test]
    fn iteration_started_maps_to_start_step() {
        let part = sole_part(&AgentEvent::IterationStarted {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
        });
        assert_eq!(part, UiStreamPart::StartStep {});
    }

    #[test]
    fn model_output_maps_to_finish_step() {
        let part = sole_part(&AgentEvent::ModelOutput {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            stop_reason: ModelStopReason::EndTurn,
            directive_count: 0,
            usage: None,
        });
        assert_eq!(part, UiStreamPart::FinishStep {});
    }

    #[test]
    fn text_delta_includes_id() {
        let part = sole_part(&AgentEvent::TextDelta {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            delta: "Hello ".into(),
        });
        assert_eq!(
            part,
            UiStreamPart::TextDelta {
                id: "r1-text".into(),
                delta: "Hello ".into(),
            }
        );
    }

    #[test]
    fn tool_call_produces_input_start_delta_available() {
        let requested = AgentEvent::ToolCallRequested {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            call: ToolCall {
                call_id: "c1".into(),
                tool_name: "read_file".into(),
                input: json!({"path": "/tmp/test.rs"}),
            },
        };
        let parts = to_ui_stream_parts(&requested);
        assert_eq!(parts.len(), 3);

        assert_eq!(
            parts[0],
            UiStreamPart::ToolInputStart {
                tool_call_id: "c1".into(),
                tool_name: "read_file".into(),
            }
        );

        if let UiStreamPart::ToolInputDelta {
            tool_call_id,
            input_text_delta,
        } = &parts[1]
        {
            assert_eq!(tool_call_id, "c1");
            assert!(input_text_delta.contains("path"));
        } else {
            panic!("Expected ToolInputDelta, got {:?}", parts[1]);
        }

        assert_eq!(
            parts[2],
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".into(),
                tool_name: "read_file".into(),
                input: json!({"path": "/tmp/test.rs"}),
            }
        );
    }

    #[test]
    fn tool_completed_maps_to_output_available() {
        let part = sole_part(&AgentEvent::ToolCallCompleted {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            result: ToolResultSummary {
                call_id: "c1".into(),
                tool_name: "read_file".into(),
                output: json!({"content": "file contents here"}),
            },
        });
        assert_eq!(
            part,
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".into(),
                output: json!({"content": "file contents here"}),
            }
        );
    }

    #[test]
    fn tool_failed_maps_to_output_with_error() {
        let part = sole_part(&AgentEvent::ToolCallFailed {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            call_id: "c1".into(),
            tool_name: "bash".into(),
            error: "command not found".into(),
        });
        assert_eq!(
            part,
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".into(),
                output: json!({"error": "command not found"}),
            }
        );
    }

    #[test]
    fn state_patched_maps_to_data_extension() {
        let part = sole_part(&AgentEvent::StatePatched {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            patch: StatePatch {
                format: StatePatchFormat::MergePatch,
                patch: json!({"cwd": "/new"}),
                source: StatePatchSource::System,
            },
            revision: 5,
        });
        assert_eq!(
            part,
            UiStreamPart::DataStatePatch {
                data: json!({"patch": {"cwd": "/new"}, "revision": 5}),
            }
        );
    }

    #[test]
    fn run_errored_maps_to_error() {
        let part = sole_part(&AgentEvent::RunErrored {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            error: "provider timeout".into(),
        });
        assert_eq!(
            part,
            UiStreamPart::Error {
                error_text: "provider timeout".into()
            }
        );
    }

    #[test]
    fn run_finished_maps_to_finish() {
        let part = sole_part(&AgentEvent::RunFinished {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            reason: RunStopReason::Completed,
            total_iterations: 3,
            final_answer: None,
        });
        assert_eq!(part, UiStreamPart::Finish {});
    }

    #[test]
    fn run_finished_with_final_answer_emits_text_then_finish() {
        let finished = AgentEvent::RunFinished {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            reason: RunStopReason::Completed,
            total_iterations: 1,
            final_answer: Some("Done!".into()),
        };
        let parts = to_ui_stream_parts(&finished);
        assert_eq!(parts.len(), 2);
        assert_eq!(
            parts[0],
            UiStreamPart::TextDelta {
                id: "r1-text".into(),
                delta: "Done!".into(),
            }
        );
        assert_eq!(parts[1], UiStreamPart::Finish {});
    }

    #[test]
    fn context_compacted_produces_empty() {
        let compacted = AgentEvent::ContextCompacted {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            dropped_count: 5,
            tokens_before: 1000,
            tokens_after: 500,
        };
        assert!(to_ui_stream_parts(&compacted).is_empty());
    }

    #[test]
    fn approval_requested_maps_to_data_approval() {
        let part = sole_part(&AgentEvent::ApprovalRequested {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            approval_id: "appr-1".into(),
            call_id: "c1".into(),
            tool_name: "bash".into(),
            arguments: json!({"command": "rm -rf /"}),
            risk: "high".into(),
        });
        if let UiStreamPart::DataApprovalRequest { data } = &part {
            assert_eq!(data["approvalId"], "appr-1");
            assert_eq!(data["toolCallId"], "c1");
            assert_eq!(data["toolName"], "bash");
            assert_eq!(data["risk"], "high");
        } else {
            panic!("Expected DataApprovalRequest, got {:?}", part);
        }
    }

    #[test]
    fn v6_wire_format_serialization() {
        // Verify exact JSON shapes match v6 spec
        assert_wire_contains(
            &UiStreamPart::Start {
                message_id: "m1".into(),
            },
            &[r#""type":"start""#, r#""messageId":"m1""#],
        );

        assert_wire_contains(
            &UiStreamPart::TextDelta {
                id: "t1".into(),
                delta: "hi".into(),
            },
            &[r#""type":"text-delta""#, r#""delta":"hi""#],
        );

        assert_wire_contains(
            &UiStreamPart::ToolInputStart {
                tool_call_id: "c1".into(),
                tool_name: "bash".into(),
            },
            &[
                r#""type":"tool-input-start""#,
                r#""toolCallId":"c1""#,
                r#""toolName":"bash""#,
            ],
        );

        assert_wire_contains(
            &UiStreamPart::Error {
                error_text: "boom".into(),
            },
            &[r#""type":"error""#, r#""errorText":"boom""#],
        );

        assert_wire_contains(
            &UiStreamPart::DataStatePatch {
                data: json!({"patch": {}}),
            },
            &[r#""type":"data-state-patch""#],
        );
    }

    #[test]
    fn sse_wire_format() {
        let frame = ui_stream_part_to_sse(&UiStreamPart::TextDelta {
            id: "t1".into(),
            delta: "hello".into(),
        })
        .unwrap();
        assert!(frame.starts_with("data: "));
        assert!(frame.ends_with("\n\n"));
        assert!(frame.contains("text-delta"));
        assert!(frame.contains("hello"));
    }

    #[test]
    fn round_trip_serialization() {
        let samples = [
            UiStreamPart::Start {
                message_id: "m1".into(),
            },
            UiStreamPart::TextDelta {
                id: "t1".into(),
                delta: "hi".into(),
            },
            UiStreamPart::Finish {},
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".into(),
                tool_name: "bash".into(),
                input: json!({"cmd": "ls"}),
            },
            UiStreamPart::Error {
                error_text: "oops".into(),
            },
        ];

        for sample in &samples {
            let encoded = serde_json::to_string(sample).unwrap();
            let decoded: UiStreamPart = serde_json::from_str(&encoded).unwrap();
            assert_eq!(*sample, decoded);
        }
    }

    // ── Deprecated alias tests ──

    #[test]
    fn deprecated_to_aisdk_parts_still_works() {
        let text_event = AgentEvent::TextDelta {
            run_id: RUN.into(),
            session_id: SESSION.into(),
            iteration: 1,
            delta: "test".into(),
        };
        assert_eq!(to_aisdk_parts(&text_event).len(), 1);
    }
}