// adk_ui/interop/ag_ui.rs

use super::surface::UiSurface;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{Value, json};

/// Name stamped on the AG-UI `CUSTOM` event that carries an ADK UI surface payload.
///
/// [`surface_to_custom_event`] uses this as the event `name`.
pub const ADK_UI_SURFACE_EVENT_NAME: &str = "adk.ui.surface";

8/// AG-UI event types from the protocol event model.
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
11pub enum AgUiEventType {
12    RunStarted,
13    RunFinished,
14    StepStarted,
15    StepFinished,
16    TextMessageStart,
17    TextMessageDelta,
18    TextMessageEnd,
19    ToolCallStart,
20    ToolCallArgs,
21    ToolCallEnd,
22    ToolCallResult,
23    StateSnapshot,
24    StateDelta,
25    Error,
26    Custom,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(rename_all = "camelCase")]
31pub struct AgUiRunStartedEvent {
32    #[serde(rename = "type")]
33    pub event_type: AgUiEventType,
34    pub thread_id: String,
35    pub run_id: String,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(rename_all = "camelCase")]
40pub struct AgUiRunFinishedEvent {
41    #[serde(rename = "type")]
42    pub event_type: AgUiEventType,
43    pub thread_id: String,
44    pub run_id: String,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49pub struct AgUiCustomEvent {
50    #[serde(rename = "type")]
51    pub event_type: AgUiEventType,
52    pub name: String,
53    pub value: Value,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub timestamp: Option<u64>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub raw_event: Option<Value>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61#[serde(rename_all = "camelCase")]
62pub struct AgUiStepEvent {
63    #[serde(rename = "type")]
64    pub event_type: AgUiEventType,
65    pub thread_id: String,
66    pub run_id: String,
67    pub step_id: String,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub name: Option<String>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct AgUiTextMessageStartEvent {
75    #[serde(rename = "type")]
76    pub event_type: AgUiEventType,
77    pub thread_id: String,
78    pub run_id: String,
79    pub message_id: String,
80    pub role: String,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(rename_all = "camelCase")]
85pub struct AgUiTextMessageDeltaEvent {
86    #[serde(rename = "type")]
87    pub event_type: AgUiEventType,
88    pub thread_id: String,
89    pub run_id: String,
90    pub message_id: String,
91    pub delta: String,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95#[serde(rename_all = "camelCase")]
96pub struct AgUiTextMessageEndEvent {
97    #[serde(rename = "type")]
98    pub event_type: AgUiEventType,
99    pub thread_id: String,
100    pub run_id: String,
101    pub message_id: String,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
105#[serde(rename_all = "camelCase")]
106pub struct AgUiToolCallStartEvent {
107    #[serde(rename = "type")]
108    pub event_type: AgUiEventType,
109    pub thread_id: String,
110    pub run_id: String,
111    pub tool_call_id: String,
112    pub name: String,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(rename_all = "camelCase")]
117pub struct AgUiToolCallArgsEvent {
118    #[serde(rename = "type")]
119    pub event_type: AgUiEventType,
120    pub thread_id: String,
121    pub run_id: String,
122    pub tool_call_id: String,
123    pub args: Value,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(rename_all = "camelCase")]
128pub struct AgUiToolCallEndEvent {
129    #[serde(rename = "type")]
130    pub event_type: AgUiEventType,
131    pub thread_id: String,
132    pub run_id: String,
133    pub tool_call_id: String,
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137#[serde(rename_all = "camelCase")]
138pub struct AgUiToolCallResultEvent {
139    #[serde(rename = "type")]
140    pub event_type: AgUiEventType,
141    pub thread_id: String,
142    pub run_id: String,
143    pub tool_call_id: String,
144    pub result: Value,
145    pub is_error: bool,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
149#[serde(rename_all = "camelCase")]
150pub struct AgUiStateSnapshotEvent {
151    #[serde(rename = "type")]
152    pub event_type: AgUiEventType,
153    pub thread_id: String,
154    pub run_id: String,
155    pub state: Value,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
159#[serde(rename_all = "camelCase")]
160pub struct AgUiStateDeltaEvent {
161    #[serde(rename = "type")]
162    pub event_type: AgUiEventType,
163    pub thread_id: String,
164    pub run_id: String,
165    pub delta: Value,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct AgUiErrorEvent {
171    #[serde(rename = "type")]
172    pub event_type: AgUiEventType,
173    pub thread_id: String,
174    pub run_id: String,
175    pub message: String,
176    pub recoverable: bool,
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub code: Option<String>,
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182#[serde(untagged)]
183pub enum AgUiEvent {
184    RunStarted(AgUiRunStartedEvent),
185    StepStarted(AgUiStepEvent),
186    StepFinished(AgUiStepEvent),
187    TextMessageStart(AgUiTextMessageStartEvent),
188    TextMessageDelta(AgUiTextMessageDeltaEvent),
189    TextMessageEnd(AgUiTextMessageEndEvent),
190    ToolCallStart(AgUiToolCallStartEvent),
191    ToolCallArgs(AgUiToolCallArgsEvent),
192    ToolCallEnd(AgUiToolCallEndEvent),
193    ToolCallResult(AgUiToolCallResultEvent),
194    StateSnapshot(AgUiStateSnapshotEvent),
195    StateDelta(AgUiStateDeltaEvent),
196    Error(AgUiErrorEvent),
197    Custom(AgUiCustomEvent),
198    RunFinished(AgUiRunFinishedEvent),
199}
200
201pub fn step_started_event(
202    thread_id: impl Into<String>,
203    run_id: impl Into<String>,
204    step_id: impl Into<String>,
205    name: Option<String>,
206) -> AgUiEvent {
207    AgUiEvent::StepStarted(AgUiStepEvent {
208        event_type: AgUiEventType::StepStarted,
209        thread_id: thread_id.into(),
210        run_id: run_id.into(),
211        step_id: step_id.into(),
212        name,
213    })
214}
215
216pub fn step_finished_event(
217    thread_id: impl Into<String>,
218    run_id: impl Into<String>,
219    step_id: impl Into<String>,
220    name: Option<String>,
221) -> AgUiEvent {
222    AgUiEvent::StepFinished(AgUiStepEvent {
223        event_type: AgUiEventType::StepFinished,
224        thread_id: thread_id.into(),
225        run_id: run_id.into(),
226        step_id: step_id.into(),
227        name,
228    })
229}
230
231pub fn text_message_events(
232    thread_id: impl Into<String>,
233    run_id: impl Into<String>,
234    message_id: impl Into<String>,
235    role: impl Into<String>,
236    delta: impl Into<String>,
237) -> Vec<AgUiEvent> {
238    let thread_id = thread_id.into();
239    let run_id = run_id.into();
240    let message_id = message_id.into();
241    let role = role.into();
242    let delta = delta.into();
243
244    vec![
245        AgUiEvent::TextMessageStart(AgUiTextMessageStartEvent {
246            event_type: AgUiEventType::TextMessageStart,
247            thread_id: thread_id.clone(),
248            run_id: run_id.clone(),
249            message_id: message_id.clone(),
250            role,
251        }),
252        AgUiEvent::TextMessageDelta(AgUiTextMessageDeltaEvent {
253            event_type: AgUiEventType::TextMessageDelta,
254            thread_id: thread_id.clone(),
255            run_id: run_id.clone(),
256            message_id: message_id.clone(),
257            delta,
258        }),
259        AgUiEvent::TextMessageEnd(AgUiTextMessageEndEvent {
260            event_type: AgUiEventType::TextMessageEnd,
261            thread_id,
262            run_id,
263            message_id,
264        }),
265    ]
266}
267
268pub fn tool_call_events(
269    thread_id: impl Into<String>,
270    run_id: impl Into<String>,
271    tool_call_id: impl Into<String>,
272    name: impl Into<String>,
273    args: Value,
274    result: Value,
275    is_error: bool,
276) -> Vec<AgUiEvent> {
277    let thread_id = thread_id.into();
278    let run_id = run_id.into();
279    let tool_call_id = tool_call_id.into();
280    let name = name.into();
281
282    vec![
283        AgUiEvent::ToolCallStart(AgUiToolCallStartEvent {
284            event_type: AgUiEventType::ToolCallStart,
285            thread_id: thread_id.clone(),
286            run_id: run_id.clone(),
287            tool_call_id: tool_call_id.clone(),
288            name,
289        }),
290        AgUiEvent::ToolCallArgs(AgUiToolCallArgsEvent {
291            event_type: AgUiEventType::ToolCallArgs,
292            thread_id: thread_id.clone(),
293            run_id: run_id.clone(),
294            tool_call_id: tool_call_id.clone(),
295            args,
296        }),
297        AgUiEvent::ToolCallEnd(AgUiToolCallEndEvent {
298            event_type: AgUiEventType::ToolCallEnd,
299            thread_id: thread_id.clone(),
300            run_id: run_id.clone(),
301            tool_call_id: tool_call_id.clone(),
302        }),
303        AgUiEvent::ToolCallResult(AgUiToolCallResultEvent {
304            event_type: AgUiEventType::ToolCallResult,
305            thread_id,
306            run_id,
307            tool_call_id,
308            result,
309            is_error,
310        }),
311    ]
312}
313
314pub fn state_snapshot_event(
315    thread_id: impl Into<String>,
316    run_id: impl Into<String>,
317    state: Value,
318) -> AgUiEvent {
319    AgUiEvent::StateSnapshot(AgUiStateSnapshotEvent {
320        event_type: AgUiEventType::StateSnapshot,
321        thread_id: thread_id.into(),
322        run_id: run_id.into(),
323        state,
324    })
325}
326
327pub fn state_delta_event(
328    thread_id: impl Into<String>,
329    run_id: impl Into<String>,
330    delta: Value,
331) -> AgUiEvent {
332    AgUiEvent::StateDelta(AgUiStateDeltaEvent {
333        event_type: AgUiEventType::StateDelta,
334        thread_id: thread_id.into(),
335        run_id: run_id.into(),
336        delta,
337    })
338}
339
340pub fn error_event(
341    thread_id: impl Into<String>,
342    run_id: impl Into<String>,
343    message: impl Into<String>,
344    code: Option<String>,
345    recoverable: bool,
346) -> AgUiEvent {
347    AgUiEvent::Error(AgUiErrorEvent {
348        event_type: AgUiEventType::Error,
349        thread_id: thread_id.into(),
350        run_id: run_id.into(),
351        message: message.into(),
352        recoverable,
353        code,
354    })
355}
356
357pub fn surface_to_custom_event(surface: &UiSurface) -> AgUiCustomEvent {
358    AgUiCustomEvent {
359        event_type: AgUiEventType::Custom,
360        name: ADK_UI_SURFACE_EVENT_NAME.to_string(),
361        value: json!({
362            "format": "adk-ui-surface-v1",
363            "surface": surface
364        }),
365        timestamp: None,
366        raw_event: None,
367    }
368}
369
370pub fn surface_to_event_stream(
371    surface: &UiSurface,
372    thread_id: impl Into<String>,
373    run_id: impl Into<String>,
374) -> Vec<AgUiEvent> {
375    let thread_id = thread_id.into();
376    let run_id = run_id.into();
377
378    vec![
379        AgUiEvent::RunStarted(AgUiRunStartedEvent {
380            event_type: AgUiEventType::RunStarted,
381            thread_id: thread_id.clone(),
382            run_id: run_id.clone(),
383        }),
384        AgUiEvent::Custom(surface_to_custom_event(surface)),
385        AgUiEvent::RunFinished(AgUiRunFinishedEvent {
386            event_type: AgUiEventType::RunFinished,
387            thread_id,
388            run_id,
389        }),
390    ]
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Single-root surface shared by the surface tests.
    fn sample_surface() -> UiSurface {
        UiSurface::new(
            "main",
            "catalog",
            vec![json!({"id":"root","component":{"Column":{"children":[]}}})],
        )
    }

    /// Serializes each event and collects its wire `type` tag, preserving order.
    fn wire_types(events: &[AgUiEvent]) -> Vec<String> {
        events
            .iter()
            .map(|event| {
                let encoded = serde_json::to_value(event).unwrap();
                encoded["type"].as_str().unwrap_or_default().to_owned()
            })
            .collect()
    }

    #[test]
    fn surface_custom_event_is_well_formed() {
        let event = surface_to_custom_event(&sample_surface());
        assert_eq!(event.event_type, AgUiEventType::Custom);
        assert_eq!(event.name, ADK_UI_SURFACE_EVENT_NAME);
        assert!(event.value.get("surface").is_some());
    }

    #[test]
    fn event_stream_wraps_custom_event_with_lifecycle() {
        let stream = surface_to_event_stream(&sample_surface(), "thread-1", "run-1");
        assert_eq!(wire_types(&stream), ["RUN_STARTED", "CUSTOM", "RUN_FINISHED"]);
    }

    #[test]
    fn text_message_helpers_emit_start_delta_end() {
        let events = text_message_events("thread-1", "run-1", "msg-1", "assistant", "hello");
        assert_eq!(
            wire_types(&events),
            ["TEXT_MESSAGE_START", "TEXT_MESSAGE_DELTA", "TEXT_MESSAGE_END"]
        );

        let delta = serde_json::to_value(&events[1]).unwrap();
        assert_eq!(delta["delta"], "hello");
    }

    #[test]
    fn tool_call_helpers_emit_lifecycle_and_result() {
        let events = tool_call_events(
            "thread-1",
            "run-1",
            "tool-1",
            "lookup_weather",
            json!({"city": "Nairobi"}),
            json!({"temp": 23}),
            false,
        );
        assert_eq!(
            wire_types(&events),
            ["TOOL_CALL_START", "TOOL_CALL_ARGS", "TOOL_CALL_END", "TOOL_CALL_RESULT"]
        );

        let outcome = serde_json::to_value(&events[3]).unwrap();
        assert_eq!(outcome["isError"], false);
    }

    #[test]
    fn state_and_error_helpers_emit_expected_shapes() {
        let [snapshot, delta, error] = [
            state_snapshot_event("thread-1", "run-1", json!({"phase": "planning"})),
            state_delta_event("thread-1", "run-1", json!({"phase": "acting"})),
            error_event("thread-1", "run-1", "tool timeout", Some("TIMEOUT".to_string()), true),
        ]
        .map(|event| serde_json::to_value(event).unwrap());

        assert_eq!(snapshot["type"], "STATE_SNAPSHOT");
        assert_eq!(snapshot["state"]["phase"], "planning");
        assert_eq!(delta["type"], "STATE_DELTA");
        assert_eq!(delta["delta"]["phase"], "acting");
        assert_eq!(error["type"], "ERROR");
        assert_eq!(error["code"], "TIMEOUT");
        assert_eq!(error["recoverable"], true);
    }
}