tower_a2a/codec/
sse.rs

//! Server-Sent Events (SSE) codec for streaming A2A responses.
//!
//! This codec parses SSE event streams whose events carry JSON-RPC 2.0 responses.
4
5use eventsource_stream::Eventsource;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::protocol::error::A2AError;
11
12/// SSE streaming event containing A2A protocol data
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SseEvent {
15    /// Event kind (e.g., "artifact-update", "status-update")
16    pub kind: String,
17
18    /// Event payload
19    pub payload: Value,
20
21    /// Whether this is the final event in the stream
22    #[serde(default)]
23    pub final_event: bool,
24}
25
26impl SseEvent {
27    /// Check if this event represents a terminal state
28    pub fn is_terminal(&self) -> bool {
29        if self.final_event {
30            return true;
31        }
32
33        // Check for terminal states in the payload
34        if let Some(state) = self.payload.get("state").and_then(|s| s.as_str()) {
35            matches!(state, "completed" | "failed" | "canceled" | "rejected")
36        } else {
37            false
38        }
39    }
40
41    /// Check if this event represents an error state
42    pub fn is_error(&self) -> bool {
43        if let Some(state) = self.payload.get("state").and_then(|s| s.as_str()) {
44            matches!(state, "failed" | "canceled" | "rejected")
45        } else {
46            false
47        }
48    }
49}
50
/// Decoder that turns a raw SSE byte stream into [`SseEvent`]s.
///
/// The codec is a unit struct and holds no state of its own.
#[derive(Debug, Clone)]
pub struct SseCodec;
54
55impl SseCodec {
56    /// Create a new SSE codec
57    pub fn new() -> Self {
58        Self
59    }
60
61    /// Parse an SSE byte stream into a stream of events
62    ///
63    /// This method takes a byte stream (typically from reqwest) and parses it
64    /// into individual SSE events containing JSON-RPC responses.
65    pub fn parse_stream<S>(&self, byte_stream: S) -> impl Stream<Item = Result<SseEvent, A2AError>>
66    where
67        S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
68    {
69        byte_stream.eventsource().map(|result| {
70            match result {
71                Ok(event) => {
72                    // Parse the event data as JSON-RPC response
73                    let jsonrpc: Value = serde_json::from_str(&event.data).map_err(|e| {
74                        A2AError::Protocol(format!("Failed to parse SSE event data: {}", e))
75                    })?;
76
77                    // Check for JSON-RPC error
78                    if let Some(error) = jsonrpc.get("error") {
79                        let error_msg = error
80                            .get("message")
81                            .and_then(|m| m.as_str())
82                            .unwrap_or("Unknown error");
83                        return Err(A2AError::Protocol(format!(
84                            "SSE stream error: {}",
85                            error_msg
86                        )));
87                    }
88
89                    // Extract result from JSON-RPC response
90                    let result = jsonrpc.get("result").ok_or_else(|| {
91                        A2AError::Protocol("SSE event missing 'result' field".to_string())
92                    })?;
93
94                    // Determine if this is a final event
95                    let final_event = result
96                        .get("final")
97                        .and_then(|f| f.as_bool())
98                        .unwrap_or(false);
99
100                    // Extract event kind
101                    let kind = result
102                        .get("kind")
103                        .and_then(|k| k.as_str())
104                        .unwrap_or("event")
105                        .to_string();
106
107                    Ok(SseEvent {
108                        kind,
109                        payload: result.clone(),
110                        final_event,
111                    })
112                }
113                Err(e) => Err(A2AError::Transport(format!("SSE stream error: {}", e))),
114            }
115        })
116    }
117}
118
119impl Default for SseCodec {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use serde_json::json;

    use super::*;

    /// Build an event fixture from its three fields.
    fn event(kind: &str, payload: Value, final_event: bool) -> SseEvent {
        SseEvent {
            kind: kind.to_string(),
            payload,
            final_event,
        }
    }

    /// Wrap a static SSE body in a byte stream that yields it as one chunk.
    fn single_chunk(
        body: &'static str,
    ) -> impl futures::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static {
        futures::stream::once(futures::future::ready(Ok::<bytes::Bytes, reqwest::Error>(
            bytes::Bytes::from(body),
        )))
    }

    #[test]
    fn test_sse_event_is_terminal() {
        // A terminal state in the payload is enough on its own.
        let completed = event("status-update", json!({ "state": "completed" }), false);
        assert!(completed.is_terminal());

        // So is the final flag, even with an empty payload.
        let flagged = event("artifact-update", json!({}), true);
        assert!(flagged.is_terminal());

        // A non-terminal state without the final flag is not terminal.
        let running = event("status-update", json!({ "state": "running" }), false);
        assert!(!running.is_terminal());
    }

    #[test]
    fn test_sse_event_is_error() {
        let failed = event("status-update", json!({ "state": "failed" }), false);
        assert!(failed.is_error());

        let completed = event("status-update", json!({ "state": "completed" }), false);
        assert!(!completed.is_error());
    }

    #[tokio::test]
    async fn test_parse_sse_stream() {
        // Two SSE events delivered in a single chunk.
        let sse_data = "data: {\"jsonrpc\":\"2.0\",\"result\":{\"kind\":\"status-update\",\"state\":\"running\"},\"id\":\"1\"}\n\n\
                        data: {\"jsonrpc\":\"2.0\",\"result\":{\"kind\":\"artifact-update\",\"final\":true},\"id\":\"2\"}\n\n";

        let codec = SseCodec::new();
        let mut events = Box::pin(codec.parse_stream(single_chunk(sse_data)));

        // First event
        let first = events.next().await.unwrap().unwrap();
        assert_eq!(first.kind, "status-update");
        assert!(!first.final_event);

        // Second event
        let second = events.next().await.unwrap().unwrap();
        assert_eq!(second.kind, "artifact-update");
        assert!(second.final_event);
    }

    #[tokio::test]
    async fn test_parse_sse_error() {
        let sse_data = "data: {\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\"message\":\"Invalid Request\"},\"id\":\"1\"}\n\n";

        let codec = SseCodec::new();
        let mut events = Box::pin(codec.parse_stream(single_chunk(sse_data)));

        let outcome = events.next().await.unwrap();
        assert!(outcome.is_err());

        // The JSON-RPC error message must surface in the protocol error.
        match outcome {
            Err(A2AError::Protocol(msg)) => {
                assert!(msg.contains("Invalid Request"));
            }
            _ => panic!("Expected Protocol error"),
        }
    }
}