Skip to main content

a2a/
event.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7use crate::types::*;
8
9// ---------------------------------------------------------------------------
10// StreamResponse (field-presence union — 4 variants)
11// ---------------------------------------------------------------------------
12
13/// A streaming event. Uses field-presence serialization for wire compatibility.
14#[derive(Debug, Clone, PartialEq)]
15pub enum StreamResponse {
16    Task(Task),
17    Message(Message),
18    StatusUpdate(TaskStatusUpdateEvent),
19    ArtifactUpdate(TaskArtifactUpdateEvent),
20}
21
22impl Serialize for StreamResponse {
23    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
24        use serde::ser::SerializeMap;
25        let mut map = serializer.serialize_map(Some(1))?;
26        match self {
27            StreamResponse::Task(t) => map.serialize_entry("task", t)?,
28            StreamResponse::Message(m) => map.serialize_entry("message", m)?,
29            StreamResponse::StatusUpdate(s) => map.serialize_entry("statusUpdate", s)?,
30            StreamResponse::ArtifactUpdate(a) => map.serialize_entry("artifactUpdate", a)?,
31        }
32        map.end()
33    }
34}
35
36impl<'de> Deserialize<'de> for StreamResponse {
37    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
38        let raw: HashMap<String, Value> = HashMap::deserialize(deserializer)?;
39        if let Some(v) = raw.get("message") {
40            Ok(StreamResponse::Message(
41                serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
42            ))
43        } else if let Some(v) = raw.get("task") {
44            Ok(StreamResponse::Task(
45                serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
46            ))
47        } else if let Some(v) = raw.get("statusUpdate") {
48            Ok(StreamResponse::StatusUpdate(
49                serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
50            ))
51        } else if let Some(v) = raw.get("artifactUpdate") {
52            Ok(StreamResponse::ArtifactUpdate(
53                serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)?,
54            ))
55        } else {
56            Err(serde::de::Error::custom("unknown StreamResponse variant"))
57        }
58    }
59}
60
61// ---------------------------------------------------------------------------
62// TaskStatusUpdateEvent
63// ---------------------------------------------------------------------------
64
65/// Event: a task's status has changed.
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67#[serde(rename_all = "camelCase")]
68pub struct TaskStatusUpdateEvent {
69    pub task_id: TaskId,
70    pub context_id: String,
71    pub status: TaskStatus,
72
73    #[serde(default, skip_serializing_if = "Option::is_none")]
74    pub metadata: Option<HashMap<String, Value>>,
75}
76
77// ---------------------------------------------------------------------------
78// TaskArtifactUpdateEvent
79// ---------------------------------------------------------------------------
80
81/// Event: an artifact has been generated or updated.
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
83#[serde(rename_all = "camelCase")]
84pub struct TaskArtifactUpdateEvent {
85    pub task_id: TaskId,
86    pub context_id: String,
87    pub artifact: Artifact,
88
89    #[serde(default, skip_serializing_if = "Option::is_none")]
90    pub append: Option<bool>,
91
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub last_chunk: Option<bool>,
94
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub metadata: Option<HashMap<String, Value>>,
97}
98
#[cfg(test)]
mod tests {
    //! Serde round-trip tests for the streaming event types.
    //!
    //! The `StreamResponse` tests check the wire key for each variant and that
    //! payload fields survive a serialize/deserialize cycle, not only that the
    //! variant tag comes back.

    use super::*;

    /// Builds a `TaskStatus` in `state` with no message and no timestamp.
    fn bare_status(state: TaskState) -> TaskStatus {
        TaskStatus {
            state,
            message: None,
            timestamp: None,
        }
    }

    /// Builds a minimal completed `Task` with every optional field unset.
    fn completed_task() -> Task {
        Task {
            id: "t1".into(),
            context_id: "c1".into(),
            status: bare_status(TaskState::Completed),
            artifacts: None,
            history: None,
            metadata: None,
        }
    }

    #[test]
    fn test_stream_response_status_update_serde() {
        let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            status: bare_status(TaskState::Working),
            metadata: None,
        });
        let json = serde_json::to_string(&event).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v.get("statusUpdate").is_some());

        let back: StreamResponse = serde_json::from_str(&json).unwrap();
        match back {
            StreamResponse::StatusUpdate(inner) => {
                assert_eq!(inner.context_id, "c1");
                assert_eq!(inner.status.state, TaskState::Working);
                assert!(inner.metadata.is_none());
            }
            other => panic!("expected StatusUpdate, got {:?}", other),
        }
    }

    #[test]
    fn test_stream_response_task_serde() {
        // The task is moved into the event; no clone is needed because it is
        // not used again afterwards.
        let event = StreamResponse::Task(completed_task());
        let json = serde_json::to_string(&event).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v.get("task").is_some());

        let back: StreamResponse = serde_json::from_str(&json).unwrap();
        match back {
            StreamResponse::Task(task) => {
                assert_eq!(task.status.state, TaskState::Completed);
            }
            other => panic!("expected Task, got {:?}", other),
        }
    }

    #[test]
    fn test_stream_response_message_serde() {
        let msg = Message::new(Role::Agent, vec![Part::text("hello")]);
        let event = StreamResponse::Message(msg);
        let json = serde_json::to_string(&event).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v.get("message").is_some());
        let back: StreamResponse = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, StreamResponse::Message(_)));
    }

    #[test]
    fn test_stream_response_artifact_update_serde() {
        let event = StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            artifact: Artifact {
                artifact_id: "a1".into(),
                name: None,
                description: None,
                parts: vec![],
                metadata: None,
                extensions: None,
            },
            append: Some(true),
            last_chunk: Some(false),
            metadata: None,
        });
        let json = serde_json::to_string(&event).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v.get("artifactUpdate").is_some());

        let back: StreamResponse = serde_json::from_str(&json).unwrap();
        match back {
            StreamResponse::ArtifactUpdate(inner) => {
                assert_eq!(inner.context_id, "c1");
                assert_eq!(inner.append, Some(true));
                assert_eq!(inner.last_chunk, Some(false));
                assert!(inner.metadata.is_none());
            }
            other => panic!("expected ArtifactUpdate, got {:?}", other),
        }
    }

    #[test]
    fn test_stream_response_unknown_variant() {
        let json = r#"{"unknown": {}}"#;
        let result = serde_json::from_str::<StreamResponse>(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_stream_response_rejects_empty_and_non_object() {
        // An object with no recognised key has no variant to select.
        assert!(serde_json::from_str::<StreamResponse>("{}").is_err());
        // Anything that is not a map fails before key lookup.
        assert!(serde_json::from_str::<StreamResponse>("[]").is_err());
        assert!(serde_json::from_str::<StreamResponse>("null").is_err());
    }

    #[test]
    fn test_stream_response_message_takes_precedence_over_task() {
        // Documents the current lookup order: when several variant keys are
        // present, `message` is checked before `task`.
        let msg = Message::new(Role::Agent, vec![Part::text("hello")]);
        let both = serde_json::json!({ "task": completed_task(), "message": msg });
        let back: StreamResponse = serde_json::from_value(both).unwrap();
        assert!(matches!(back, StreamResponse::Message(_)));
    }

    #[test]
    fn test_task_status_update_event_with_metadata() {
        let mut meta = HashMap::new();
        meta.insert("key".to_string(), Value::String("val".to_string()));
        let event = TaskStatusUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            status: TaskStatus {
                state: TaskState::Working,
                message: Some(Message::new(Role::Agent, vec![])),
                timestamp: None,
            },
            metadata: Some(meta.clone()),
        };
        let json = serde_json::to_string(&event).unwrap();
        let back: TaskStatusUpdateEvent = serde_json::from_str(&json).unwrap();
        // The metadata map must come back with its contents intact.
        assert_eq!(back.metadata, Some(meta));
        assert_eq!(back.status.state, TaskState::Working);
    }

    #[test]
    fn test_task_artifact_update_event_full() {
        let event = TaskArtifactUpdateEvent {
            task_id: "t1".into(),
            context_id: "c1".into(),
            artifact: Artifact {
                artifact_id: "a1".into(),
                name: Some("file.txt".into()),
                description: Some("A file".into()),
                parts: vec![Part::text("content")],
                metadata: None,
                extensions: None,
            },
            append: None,
            last_chunk: Some(true),
            metadata: None,
        };
        let json = serde_json::to_string(&event).unwrap();

        // `None` fields are skipped on output, `Some` fields use camelCase keys.
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v.get("append").is_none());
        assert!(v.get("lastChunk").is_some());

        let back: TaskArtifactUpdateEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.last_chunk, Some(true));
        assert!(back.append.is_none());
    }
}