Skip to main content

arcan_core/
aisdk.rs

1use crate::protocol::AgentEvent;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5// ─── AI SDK v6 UI Message Stream Protocol ───────────────────────
6//
7// Spec: https://sdk.vercel.ai/docs/ai-sdk-ui/stream-protocol
8// Header: x-vercel-ai-ui-message-stream: v1
9// Transport: SSE, data: {json}\n\n
10// Termination: data: [DONE]
11
12/// Vercel AI SDK v6 "UI Message Stream Protocol" part.
13///
14/// Each variant maps to a v6 stream part type. Custom Arcan extensions
15/// use the `data-*` namespace per spec.
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(tag = "type", rename_all = "kebab-case")]
18pub enum UiStreamPart {
19    // ── Control ──
20    Start {
21        #[serde(rename = "messageId")]
22        message_id: String,
23    },
24    Finish {},
25    StartStep {},
26    FinishStep {},
27    Abort {
28        reason: String,
29    },
30
31    // ── Text ──
32    TextStart {
33        id: String,
34    },
35    TextDelta {
36        id: String,
37        delta: String,
38    },
39    TextEnd {
40        id: String,
41    },
42
43    // ── Reasoning ──
44    ReasoningStart {
45        id: String,
46    },
47    ReasoningDelta {
48        id: String,
49        delta: String,
50    },
51    ReasoningEnd {
52        id: String,
53    },
54
55    // ── Tool ──
56    ToolInputStart {
57        #[serde(rename = "toolCallId")]
58        tool_call_id: String,
59        #[serde(rename = "toolName")]
60        tool_name: String,
61    },
62    ToolInputDelta {
63        #[serde(rename = "toolCallId")]
64        tool_call_id: String,
65        #[serde(rename = "inputTextDelta")]
66        input_text_delta: String,
67    },
68    ToolInputAvailable {
69        #[serde(rename = "toolCallId")]
70        tool_call_id: String,
71        #[serde(rename = "toolName")]
72        tool_name: String,
73        input: Value,
74    },
75    ToolOutputAvailable {
76        #[serde(rename = "toolCallId")]
77        tool_call_id: String,
78        output: Value,
79    },
80
81    // ── Error ──
82    Error {
83        #[serde(rename = "errorText")]
84        error_text: String,
85    },
86
87    // ── Arcan Extensions (data-* namespace) ──
88    #[serde(rename = "data-state-patch")]
89    DataStatePatch {
90        data: Value,
91    },
92    #[serde(rename = "data-approval-request")]
93    DataApprovalRequest {
94        data: Value,
95    },
96}
97
98/// Convert an `AgentEvent` into zero or more `UiStreamPart`s (v6 protocol).
99///
100/// This is a stateless mapping. Text boundary tracking (TextStart/TextEnd)
101/// is handled by the SSE bridge in server.rs, which wraps consecutive
102/// TextDelta events with boundary signals.
103pub fn to_ui_stream_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
104    match event {
105        AgentEvent::RunStarted { run_id, .. } => {
106            vec![UiStreamPart::Start {
107                message_id: run_id.clone(),
108            }]
109        }
110
111        AgentEvent::IterationStarted { .. } => {
112            vec![UiStreamPart::StartStep {}]
113        }
114
115        AgentEvent::ModelOutput { .. } => {
116            vec![UiStreamPart::FinishStep {}]
117        }
118
119        AgentEvent::TextDelta { run_id, delta, .. } => {
120            vec![UiStreamPart::TextDelta {
121                id: format!("{run_id}-text"),
122                delta: delta.clone(),
123            }]
124        }
125
126        AgentEvent::ToolCallRequested { call, .. } => {
127            let args_json = serde_json::to_string(&call.input).unwrap_or_default();
128            vec![
129                UiStreamPart::ToolInputStart {
130                    tool_call_id: call.call_id.clone(),
131                    tool_name: call.tool_name.clone(),
132                },
133                UiStreamPart::ToolInputDelta {
134                    tool_call_id: call.call_id.clone(),
135                    input_text_delta: args_json,
136                },
137                UiStreamPart::ToolInputAvailable {
138                    tool_call_id: call.call_id.clone(),
139                    tool_name: call.tool_name.clone(),
140                    input: call.input.clone(),
141                },
142            ]
143        }
144
145        AgentEvent::ToolCallCompleted { result, .. } => {
146            vec![UiStreamPart::ToolOutputAvailable {
147                tool_call_id: result.call_id.clone(),
148                output: result.output.clone(),
149            }]
150        }
151
152        AgentEvent::ToolCallFailed { call_id, error, .. } => {
153            vec![UiStreamPart::ToolOutputAvailable {
154                tool_call_id: call_id.clone(),
155                output: serde_json::json!({ "error": error }),
156            }]
157        }
158
159        AgentEvent::StatePatched {
160            patch, revision, ..
161        } => {
162            vec![UiStreamPart::DataStatePatch {
163                data: serde_json::json!({
164                    "patch": patch.patch,
165                    "revision": revision,
166                }),
167            }]
168        }
169
170        AgentEvent::RunErrored { error, .. } => {
171            vec![UiStreamPart::Error {
172                error_text: error.clone(),
173            }]
174        }
175
176        AgentEvent::RunFinished {
177            run_id,
178            final_answer,
179            ..
180        } => {
181            let mut parts = Vec::new();
182            if let Some(answer) = final_answer
183                && !answer.is_empty()
184            {
185                let text_id = format!("{run_id}-text");
186                parts.push(UiStreamPart::TextDelta {
187                    id: text_id.clone(),
188                    delta: answer.clone(),
189                });
190            }
191            parts.push(UiStreamPart::Finish {});
192            parts
193        }
194
195        AgentEvent::ApprovalRequested {
196            approval_id,
197            call_id,
198            tool_name,
199            arguments,
200            risk,
201            ..
202        } => {
203            vec![UiStreamPart::DataApprovalRequest {
204                data: serde_json::json!({
205                    "approvalId": approval_id,
206                    "toolCallId": call_id,
207                    "toolName": tool_name,
208                    "arguments": arguments,
209                    "risk": risk,
210                }),
211            }]
212        }
213
214        // Events with no UI representation
215        AgentEvent::ContextCompacted { .. } | AgentEvent::ApprovalResolved { .. } => {
216            vec![]
217        }
218        // Forward-compatible catch-all: new variants added to AgentEvent that
219        // don't yet have a UI mapping are silently dropped.
220        #[allow(unreachable_patterns)]
221        _ => vec![],
222    }
223}
224
225/// Serialize a `UiStreamPart` to the SSE wire format.
226pub fn ui_stream_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
227    let json = serde_json::to_string(part)?;
228    Ok(format!("data: {json}\n\n"))
229}
230
231// ─── Deprecated v5 aliases (will be removed) ────────────────────
232
233/// Deprecated: use `UiStreamPart` instead.
234pub type AiSdkPart = UiStreamPart;
235
236/// Deprecated: use `to_ui_stream_parts` instead.
237pub fn to_aisdk_parts(event: &AgentEvent) -> Vec<UiStreamPart> {
238    to_ui_stream_parts(event)
239}
240
241/// Deprecated: use `ui_stream_part_to_sse` instead.
242pub fn aisdk_part_to_sse(part: &UiStreamPart) -> Result<String, serde_json::Error> {
243    ui_stream_part_to_sse(part)
244}
245
#[cfg(test)]
mod tests {
    // Unit tests pinning the `AgentEvent` → `UiStreamPart` mapping and the
    // exact JSON / SSE wire shapes produced for AI SDK v6 clients.
    use super::*;
    use crate::protocol::{
        ModelStopReason, RunStopReason, StatePatch, StatePatchFormat, StatePatchSource, ToolCall,
        ToolResultSummary,
    };
    use serde_json::json;

    #[test]
    fn run_started_maps_to_start() {
        let event = AgentEvent::RunStarted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            provider: "anthropic".to_string(),
            max_iterations: 10,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::Start {
                message_id: "r1".to_string()
            }
        );
    }

    #[test]
    fn iteration_started_maps_to_start_step() {
        let event = AgentEvent::IterationStarted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], UiStreamPart::StartStep {});
    }

    #[test]
    fn model_output_maps_to_finish_step() {
        let event = AgentEvent::ModelOutput {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            stop_reason: ModelStopReason::EndTurn,
            directive_count: 0,
            usage: None,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], UiStreamPart::FinishStep {});
    }

    #[test]
    fn text_delta_includes_id() {
        let event = AgentEvent::TextDelta {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            delta: "Hello ".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::TextDelta {
                id: "r1-text".to_string(),
                delta: "Hello ".to_string(),
            }
        );
    }

    #[test]
    fn tool_call_produces_input_start_delta_available() {
        let event = AgentEvent::ToolCallRequested {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            call: ToolCall {
                call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                input: json!({"path": "/tmp/test.rs"}),
            },
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 3);

        assert_eq!(
            parts[0],
            UiStreamPart::ToolInputStart {
                tool_call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
            }
        );
        match &parts[1] {
            UiStreamPart::ToolInputDelta {
                tool_call_id,
                input_text_delta,
            } => {
                assert_eq!(tool_call_id, "c1");
                // The whole input is sent as a single compact JSON delta.
                assert_eq!(input_text_delta, r#"{"path":"/tmp/test.rs"}"#);
            }
            other => panic!("Expected ToolInputDelta, got {:?}", other),
        }
        assert_eq!(
            parts[2],
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                input: json!({"path": "/tmp/test.rs"}),
            }
        );
    }

    #[test]
    fn tool_completed_maps_to_output_available() {
        let event = AgentEvent::ToolCallCompleted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            result: ToolResultSummary {
                call_id: "c1".to_string(),
                tool_name: "read_file".to_string(),
                output: json!({"content": "file contents here"}),
            },
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"content": "file contents here"}),
            }
        );
    }

    #[test]
    fn tool_failed_maps_to_output_with_error() {
        let event = AgentEvent::ToolCallFailed {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
            error: "command not found".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"error": "command not found"}),
            }
        );
    }

    #[test]
    fn state_patched_maps_to_data_extension() {
        let event = AgentEvent::StatePatched {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            patch: StatePatch {
                format: StatePatchFormat::MergePatch,
                patch: json!({"cwd": "/new"}),
                source: StatePatchSource::System,
            },
            revision: 5,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::DataStatePatch {
                data: json!({"patch": {"cwd": "/new"}, "revision": 5}),
            }
        );
    }

    #[test]
    fn run_errored_maps_to_error() {
        let event = AgentEvent::RunErrored {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            error: "provider timeout".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(
            parts[0],
            UiStreamPart::Error {
                error_text: "provider timeout".to_string()
            }
        );
    }

    #[test]
    fn run_finished_maps_to_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 3,
            final_answer: None,
            usage: None,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], UiStreamPart::Finish {});
    }

    #[test]
    fn run_finished_with_final_answer_emits_text_then_finish() {
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 1,
            final_answer: Some("Done!".to_string()),
            usage: None,
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 2);
        assert_eq!(
            parts[0],
            UiStreamPart::TextDelta {
                id: "r1-text".to_string(),
                delta: "Done!".to_string(),
            }
        );
        assert_eq!(parts[1], UiStreamPart::Finish {});
    }

    #[test]
    fn run_finished_with_empty_final_answer_emits_only_finish() {
        // An empty answer must not produce an empty text-delta part.
        let event = AgentEvent::RunFinished {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            reason: RunStopReason::Completed,
            total_iterations: 1,
            final_answer: Some(String::new()),
            usage: None,
        };
        assert_eq!(to_ui_stream_parts(&event), vec![UiStreamPart::Finish {}]);
    }

    #[test]
    fn context_compacted_produces_empty() {
        let event = AgentEvent::ContextCompacted {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            dropped_count: 5,
            tokens_before: 1000,
            tokens_after: 500,
        };
        assert!(to_ui_stream_parts(&event).is_empty());
    }

    #[test]
    fn approval_requested_maps_to_data_approval() {
        let event = AgentEvent::ApprovalRequested {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            approval_id: "appr-1".to_string(),
            call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
            arguments: json!({"command": "rm -rf /"}),
            risk: "high".to_string(),
        };
        let parts = to_ui_stream_parts(&event);
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            UiStreamPart::DataApprovalRequest { data } => {
                assert_eq!(data["approvalId"], "appr-1");
                assert_eq!(data["toolCallId"], "c1");
                assert_eq!(data["toolName"], "bash");
                assert_eq!(data["arguments"], json!({"command": "rm -rf /"}));
                assert_eq!(data["risk"], "high");
            }
            other => panic!("Expected DataApprovalRequest, got {:?}", other),
        }
    }

    #[test]
    fn v6_wire_format_serialization() {
        // Verify exact JSON shapes match v6 spec
        let start = UiStreamPart::Start {
            message_id: "m1".to_string(),
        };
        let json = serde_json::to_string(&start).unwrap();
        assert!(json.contains(r#""type":"start""#));
        assert!(json.contains(r#""messageId":"m1""#));

        let text = UiStreamPart::TextDelta {
            id: "t1".to_string(),
            delta: "hi".to_string(),
        };
        let json = serde_json::to_string(&text).unwrap();
        assert!(json.contains(r#""type":"text-delta""#));
        assert!(json.contains(r#""delta":"hi""#));

        let tool = UiStreamPart::ToolInputStart {
            tool_call_id: "c1".to_string(),
            tool_name: "bash".to_string(),
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"tool-input-start""#));
        assert!(json.contains(r#""toolCallId":"c1""#));
        assert!(json.contains(r#""toolName":"bash""#));

        let error = UiStreamPart::Error {
            error_text: "boom".to_string(),
        };
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains(r#""type":"error""#));
        assert!(json.contains(r#""errorText":"boom""#));

        let ext = UiStreamPart::DataStatePatch {
            data: json!({"patch": {}}),
        };
        let json = serde_json::to_string(&ext).unwrap();
        assert!(json.contains(r#""type":"data-state-patch""#));
    }

    #[test]
    fn v6_wire_format_step_and_tool_parts() {
        // Payload-less control parts serialize to just their tag.
        assert_eq!(
            serde_json::to_string(&UiStreamPart::Finish {}).unwrap(),
            r#"{"type":"finish"}"#
        );
        assert_eq!(
            serde_json::to_string(&UiStreamPart::StartStep {}).unwrap(),
            r#"{"type":"start-step"}"#
        );
        assert_eq!(
            serde_json::to_string(&UiStreamPart::FinishStep {}).unwrap(),
            r#"{"type":"finish-step"}"#
        );

        let delta = UiStreamPart::ToolInputDelta {
            tool_call_id: "c1".to_string(),
            input_text_delta: "abc".to_string(),
        };
        let json = serde_json::to_string(&delta).unwrap();
        assert!(json.contains(r#""type":"tool-input-delta""#));
        assert!(json.contains(r#""toolCallId":"c1""#));
        assert!(json.contains(r#""inputTextDelta":"abc""#));

        let output = UiStreamPart::ToolOutputAvailable {
            tool_call_id: "c1".to_string(),
            output: json!({"ok": true}),
        };
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains(r#""type":"tool-output-available""#));
        assert!(json.contains(r#""toolCallId":"c1""#));

        let approval = UiStreamPart::DataApprovalRequest { data: json!({}) };
        let json = serde_json::to_string(&approval).unwrap();
        assert!(json.contains(r#""type":"data-approval-request""#));
    }

    #[test]
    fn sse_wire_format() {
        let part = UiStreamPart::TextDelta {
            id: "t1".to_string(),
            delta: "hello".to_string(),
        };
        let sse = ui_stream_part_to_sse(&part).unwrap();
        assert!(sse.starts_with("data: "));
        assert!(sse.ends_with("\n\n"));
        assert!(sse.contains("text-delta"));
        assert!(sse.contains("hello"));
        // Exact frame: tag first, then fields in declaration order.
        assert_eq!(
            sse,
            "data: {\"type\":\"text-delta\",\"id\":\"t1\",\"delta\":\"hello\"}\n\n"
        );
    }

    #[test]
    fn round_trip_serialization() {
        // Every variant must survive serialize → deserialize unchanged.
        let parts = vec![
            UiStreamPart::Start {
                message_id: "m1".to_string(),
            },
            UiStreamPart::Finish {},
            UiStreamPart::StartStep {},
            UiStreamPart::FinishStep {},
            UiStreamPart::Abort {
                reason: "user cancelled".to_string(),
            },
            UiStreamPart::TextStart {
                id: "t1".to_string(),
            },
            UiStreamPart::TextDelta {
                id: "t1".to_string(),
                delta: "hi".to_string(),
            },
            UiStreamPart::TextEnd {
                id: "t1".to_string(),
            },
            UiStreamPart::ReasoningStart {
                id: "rs1".to_string(),
            },
            UiStreamPart::ReasoningDelta {
                id: "rs1".to_string(),
                delta: "thinking".to_string(),
            },
            UiStreamPart::ReasoningEnd {
                id: "rs1".to_string(),
            },
            UiStreamPart::ToolInputStart {
                tool_call_id: "c1".to_string(),
                tool_name: "bash".to_string(),
            },
            UiStreamPart::ToolInputDelta {
                tool_call_id: "c1".to_string(),
                input_text_delta: r#"{"cmd":"ls"}"#.to_string(),
            },
            UiStreamPart::ToolInputAvailable {
                tool_call_id: "c1".to_string(),
                tool_name: "bash".to_string(),
                input: json!({"cmd": "ls"}),
            },
            UiStreamPart::ToolOutputAvailable {
                tool_call_id: "c1".to_string(),
                output: json!({"stdout": "a.txt"}),
            },
            UiStreamPart::Error {
                error_text: "oops".to_string(),
            },
            UiStreamPart::DataStatePatch {
                data: json!({"patch": {"cwd": "/new"}}),
            },
            UiStreamPart::DataApprovalRequest {
                data: json!({"approvalId": "appr-1", "risk": "high"}),
            },
        ];

        for part in &parts {
            let json = serde_json::to_string(part).unwrap();
            let decoded: UiStreamPart = serde_json::from_str(&json).unwrap();
            assert_eq!(*part, decoded);
        }
    }

    // ── Deprecated alias tests ──

    #[test]
    fn deprecated_to_aisdk_parts_still_works() {
        let event = AgentEvent::TextDelta {
            run_id: "r1".to_string(),
            session_id: "s1".to_string(),
            iteration: 1,
            delta: "test".to_string(),
        };
        let parts = to_aisdk_parts(&event);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts, to_ui_stream_parts(&event));
    }

    #[test]
    fn deprecated_aisdk_part_to_sse_matches_new_name() {
        let part: AiSdkPart = UiStreamPart::Finish {};
        assert_eq!(
            aisdk_part_to_sse(&part).unwrap(),
            ui_stream_part_to_sse(&part).unwrap()
        );
    }
}