1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7use crate::types::*;
8
/// One event in a streaming response.
///
/// Serialized as a JSON object with a single key naming the variant —
/// `"task"`, `"message"`, `"statusUpdate"` or `"artifactUpdate"` — whose value
/// is the serialized payload (see the hand-written `Serialize` /
/// `Deserialize` impls that follow this type).
#[derive(Debug, Clone, PartialEq)]
pub enum StreamResponse {
    /// Carries a [`Task`]; serialized under the `"task"` key.
    Task(Task),
    /// Carries a [`Message`]; serialized under the `"message"` key.
    Message(Message),
    /// Carries a [`TaskStatusUpdateEvent`]; serialized under the `"statusUpdate"` key.
    StatusUpdate(TaskStatusUpdateEvent),
    /// Carries a [`TaskArtifactUpdateEvent`]; serialized under the `"artifactUpdate"` key.
    ArtifactUpdate(TaskArtifactUpdateEvent),
}
21
22impl Serialize for StreamResponse {
23 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
24 use serde::ser::SerializeMap;
25 let mut map = serializer.serialize_map(Some(1))?;
26 match self {
27 StreamResponse::Task(t) => map.serialize_entry("task", t)?,
28 StreamResponse::Message(m) => map.serialize_entry("message", m)?,
29 StreamResponse::StatusUpdate(s) => map.serialize_entry("statusUpdate", s)?,
30 StreamResponse::ArtifactUpdate(a) => map.serialize_entry("artifactUpdate", a)?,
31 }
32 map.end()
33 }
34}
35
36impl<'de> Deserialize<'de> for StreamResponse {
37 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
38 let raw: HashMap<String, Value> = HashMap::deserialize(deserializer)?;
39 if let Some(v) = raw.get("message") {
40 Ok(StreamResponse::Message(
41 serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
42 ))
43 } else if let Some(v) = raw.get("task") {
44 Ok(StreamResponse::Task(
45 serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
46 ))
47 } else if let Some(v) = raw.get("statusUpdate") {
48 Ok(StreamResponse::StatusUpdate(
49 serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
50 ))
51 } else if let Some(v) = raw.get("artifactUpdate") {
52 Ok(StreamResponse::ArtifactUpdate(
53 serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
54 ))
55 } else {
56 Err(serde::de::Error::custom("unknown StreamResponse variant"))
57 }
58 }
59}
60
/// Status-change event for a task, carried by `StreamResponse::StatusUpdate`.
///
/// Field names are serialized in camelCase (`taskId`, `contextId`, `status`,
/// `metadata`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskStatusUpdateEvent {
    /// ID of the task this event refers to.
    pub task_id: TaskId,
    /// Context ID associated with the task (serialized as `contextId`).
    pub context_id: String,
    /// The task status reported by this event.
    pub status: TaskStatus,

    /// Optional free-form metadata. Omitted from the output when `None`;
    /// defaults to `None` when the key is absent from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
76
/// Artifact event for a task, carried by `StreamResponse::ArtifactUpdate`.
///
/// Field names are serialized in camelCase (`taskId`, `contextId`,
/// `artifact`, `append`, `lastChunk`, `metadata`). Every `Option` field is
/// omitted from the output when `None` and defaults to `None` when absent
/// from the input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskArtifactUpdateEvent {
    /// ID of the task this event refers to.
    pub task_id: TaskId,
    /// Context ID associated with the task (serialized as `contextId`).
    pub context_id: String,
    /// The artifact carried by this event.
    pub artifact: Artifact,

    // NOTE(review): presumably `Some(true)` means "append these parts to the
    // previously sent artifact with the same id" — confirm against the spec.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub append: Option<bool>,

    // NOTE(review): presumably marks the final chunk of a streamed artifact
    // (serialized as `lastChunk`) — confirm against the spec.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_chunk: Option<bool>,

    /// Optional free-form metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `event`, asserts the JSON is an object with exactly one key
    /// equal to `expected_key`, and returns the value deserialized back from
    /// that JSON.
    fn round_trip(event: &StreamResponse, expected_key: &str) -> StreamResponse {
        let json = serde_json::to_string(event).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        let obj = v
            .as_object()
            .expect("StreamResponse must serialize to a JSON object");
        assert_eq!(obj.len(), 1, "expected a single key in {}", json);
        assert!(
            obj.contains_key(expected_key),
            "missing key `{}` in {}",
            expected_key,
            json
        );
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn test_stream_response_status_update_serde() {
        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            metadata: None,
        });
        let back = round_trip(&event, "statusUpdate");
        // Compare whole values: the payload, not just the variant, must survive.
        assert_eq!(back, event);
    }

    #[test]
    fn test_stream_response_task_serde() {
        let task = Task {
            id: "t1".into(),
            context_id: "c1".into(),
            status: TaskStatus {
                state: TaskState::Completed,
                message: None,
                timestamp: None,
            },
            artifacts: None,
            history: None,
            metadata: None,
        };
        let event = StreamResponse::Task(task);
        let back = round_trip(&event, "task");
        assert_eq!(back, event);
    }

    #[test]
    fn test_stream_response_message_serde() {
        let msg = Message::new(Role::Agent, vec![Part::text("hello")]);
        let event = StreamResponse::Message(msg);
        let back = round_trip(&event, "message");
        assert_eq!(back, event);
    }

    #[test]
    fn test_stream_response_artifact_update_serde() {
        let event = StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            artifact: Artifact {
                artifact_id: "a1".into(),
                name: None,
                description: None,
                parts: vec![],
                metadata: None,
                extensions: None,
            },
            append: Some(true),
            last_chunk: Some(false),
            metadata: None,
        });
        let back = round_trip(&event, "artifactUpdate");
        assert_eq!(back, event);
    }

    #[test]
    fn test_stream_response_unknown_variant() {
        let json = r#"{"unknown": {}}"#;
        let result = serde_json::from_str::<StreamResponse>(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_stream_response_rejects_empty_and_non_object() {
        assert!(serde_json::from_str::<StreamResponse>("{}").is_err());
        assert!(serde_json::from_str::<StreamResponse>("[]").is_err());
        assert!(serde_json::from_str::<StreamResponse>("null").is_err());
    }

    #[test]
    fn test_task_status_update_event_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("key".to_string(), Value::String("val".to_string()));
        let event = TaskStatusUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            status: TaskStatus {
                state: TaskState::Working,
                message: Some(Message::new(Role::Agent, vec![])),
                timestamp: None,
            },
            metadata: Some(meta),
        };
        let json = serde_json::to_string(&event).unwrap();
        let back: TaskStatusUpdateEvent = serde_json::from_str(&json).unwrap();
        assert!(back.metadata.is_some());
        assert_eq!(back.status.state, TaskState::Working);
        assert_eq!(back, event);
    }

    #[test]
    fn test_task_artifact_update_event_full() {
        let event = TaskArtifactUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            artifact: Artifact {
                artifact_id: "a1".into(),
                name: Some("file.txt".into()),
                description: Some("A file".into()),
                parts: vec![Part::text("content")],
                metadata: None,
                extensions: None,
            },
            append: None,
            last_chunk: Some(true),
            metadata: None,
        };
        let json = serde_json::to_string(&event).unwrap();
        let back: TaskArtifactUpdateEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.last_chunk, Some(true));
        assert!(back.append.is_none());
        assert_eq!(back, event);
    }
}