Skip to main content

mika_a2a/
streaming.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::types::{Artifact, Message, Task, TaskStatus};
5
6/// Server-Sent Event from an A2A streaming response.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(tag = "kind")]
9pub enum StreamEvent {
10    #[serde(rename = "task")]
11    Task(Task),
12    #[serde(rename = "message")]
13    Message(Message),
14    #[serde(rename = "status-update")]
15    StatusUpdate(TaskStatusUpdateEvent),
16    #[serde(rename = "artifact-update")]
17    ArtifactUpdate(TaskArtifactUpdateEvent),
18}
19
20/// A task status update event sent via SSE.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct TaskStatusUpdateEvent {
24    pub task_id: String,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub context_id: Option<String>,
27    pub status: TaskStatus,
28    #[serde(default)]
29    #[serde(rename = "final")]
30    pub is_final: bool,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub metadata: Option<HashMap<String, serde_json::Value>>,
33}
34
35/// An artifact update event sent via SSE.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct TaskArtifactUpdateEvent {
39    pub task_id: String,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub context_id: Option<String>,
42    pub artifact: Artifact,
43    #[serde(default)]
44    pub append: bool,
45    #[serde(default)]
46    pub last_chunk: bool,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub metadata: Option<HashMap<String, serde_json::Value>>,
49}
50
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Part, Role, TaskState};

    /// A non-final `Working` status update that carries a context id.
    fn make_status_update() -> StreamEvent {
        StreamEvent::StatusUpdate(TaskStatusUpdateEvent {
            task_id: "task-1".to_string(),
            context_id: Some("ctx-1".to_string()),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: Some("2025-01-01T00:00:00Z".to_string()),
            },
            is_final: false,
            metadata: None,
        })
    }

    /// A final `Completed` status update without a context id.
    fn make_final_status_update() -> StreamEvent {
        StreamEvent::StatusUpdate(TaskStatusUpdateEvent {
            task_id: "t".to_string(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            is_final: true,
            metadata: None,
        })
    }

    /// An artifact update with one text part, flagged as the last chunk.
    fn make_artifact_update() -> StreamEvent {
        StreamEvent::ArtifactUpdate(TaskArtifactUpdateEvent {
            task_id: "task-1".to_string(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "art-1".to_string(),
                name: Some("output.txt".to_string()),
                description: None,
                parts: vec![Part::Text {
                    text: "result".to_string(),
                    metadata: None,
                }],
                metadata: None,
                extensions: None,
            },
            append: false,
            last_chunk: true,
            metadata: None,
        })
    }

    fn make_task_event() -> StreamEvent {
        StreamEvent::Task(Task {
            id: "task-1".to_string(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: None,
            history: None,
            metadata: None,
            kind: "task".to_string(),
        })
    }

    fn make_message_event() -> StreamEvent {
        StreamEvent::Message(Message {
            message_id: "msg-1".to_string(),
            role: Role::Agent,
            parts: vec![Part::Text {
                text: "response".to_string(),
                metadata: None,
            }],
            context_id: None,
            task_id: Some("task-1".to_string()),
            metadata: None,
            reference_task_ids: None,
            extensions: None,
            kind: "message".to_string(),
        })
    }

    #[test]
    fn status_update_round_trip() {
        let event = make_status_update();
        let json = serde_json::to_string(&event).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["kind"], "status-update");
        match serde_json::from_str::<StreamEvent>(&json).unwrap() {
            StreamEvent::StatusUpdate(e) => {
                assert_eq!(e.task_id, "task-1");
                assert_eq!(e.context_id.as_deref(), Some("ctx-1"));
                assert_eq!(e.status.state, TaskState::Working);
                assert!(!e.is_final);
            }
            other => panic!("expected StatusUpdate, got {other:?}"),
        }
    }

    #[test]
    fn status_update_final_flag() {
        let json = serde_json::to_value(make_final_status_update()).unwrap();
        assert_eq!(json["final"], true);
    }

    #[test]
    fn status_update_final_defaults_to_false_when_absent() {
        let mut value = serde_json::to_value(make_final_status_update()).unwrap();
        let object = value
            .as_object_mut()
            .expect("status update serializes to a JSON object");
        assert_eq!(object.remove("final"), Some(serde_json::Value::Bool(true)));

        match serde_json::from_value::<StreamEvent>(value).unwrap() {
            StreamEvent::StatusUpdate(e) => {
                assert!(!e.is_final);
                assert_eq!(e.status.state, TaskState::Completed);
            }
            other => panic!("expected StatusUpdate, got {other:?}"),
        }
    }

    #[test]
    fn artifact_update_round_trip() {
        let event = make_artifact_update();
        let json = serde_json::to_string(&event).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["kind"], "artifact-update");
        match serde_json::from_str::<StreamEvent>(&json).unwrap() {
            StreamEvent::ArtifactUpdate(e) => {
                assert_eq!(e.task_id, "task-1");
                assert!(e.context_id.is_none());
                assert_eq!(e.artifact.artifact_id, "art-1");
                assert!(e.last_chunk);
                assert!(!e.append);
            }
            other => panic!("expected ArtifactUpdate, got {other:?}"),
        }
    }

    #[test]
    fn artifact_update_flags_default_to_false_when_absent() {
        let mut value = serde_json::to_value(make_artifact_update()).unwrap();
        let object = value
            .as_object_mut()
            .expect("artifact update serializes to a JSON object");
        assert_eq!(object.remove("lastChunk"), Some(serde_json::Value::Bool(true)));
        assert_eq!(object.remove("append"), Some(serde_json::Value::Bool(false)));

        match serde_json::from_value::<StreamEvent>(value).unwrap() {
            StreamEvent::ArtifactUpdate(e) => {
                assert!(!e.last_chunk);
                assert!(!e.append);
                assert_eq!(e.artifact.artifact_id, "art-1");
            }
            other => panic!("expected ArtifactUpdate, got {other:?}"),
        }
    }

    #[test]
    fn update_events_use_camel_case_keys_and_skip_none() {
        let status = serde_json::to_value(make_status_update()).unwrap();
        assert_eq!(status["taskId"], "task-1");
        assert_eq!(status["contextId"], "ctx-1");
        assert!(status.get("metadata").is_none());

        let artifact = serde_json::to_value(make_artifact_update()).unwrap();
        assert_eq!(artifact["taskId"], "task-1");
        assert!(artifact.get("contextId").is_none());
        assert!(artifact.get("metadata").is_none());
        assert_eq!(artifact["lastChunk"], true);
        assert_eq!(artifact["append"], false);
    }

    #[test]
    fn task_event_serializes_with_kind_tag() {
        // `Task` and `Message` carry their own `kind` field; these tests pin
        // the serialized shape of those variants.
        let event = make_task_event();
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(value["kind"], "task");
        assert_eq!(value["id"], "task-1");
        assert_eq!(value["status"]["state"], "completed");
    }

    #[test]
    fn message_event_serializes_with_kind_tag() {
        let event = make_message_event();
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(value["kind"], "message");
        assert_eq!(value["messageId"], "msg-1");
        assert_eq!(value["role"], "agent");
    }

    #[test]
    fn all_kind_discriminators() {
        let events = vec![
            (make_status_update(), "status-update"),
            (make_artifact_update(), "artifact-update"),
            (make_task_event(), "task"),
            (make_message_event(), "message"),
        ];
        for (event, expected_kind) in events {
            let value = serde_json::to_value(&event).unwrap();
            assert_eq!(
                value["kind"].as_str().unwrap(),
                expected_kind,
                "kind discriminator for {expected_kind}"
            );
        }
    }
}