Skip to main content

agentik_sdk/http/
streaming.rs

1//! HTTP streaming client for Server-Sent Events (SSE) processing.
2//!
3//! This module handles the HTTP layer for streaming responses from the Anthropic API,
4//! parsing SSE events and converting them into MessageStreamEvent objects.
5
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use futures::Stream;
9use pin_project::pin_project;
10use reqwest::Response;
11use serde_json;
12use tokio::sync::broadcast;
13use tokio_stream::StreamExt;
14
15use crate::types::{MessageStreamEvent, AnthropicError, Result};
16
/// Tunables for an SSE streaming request.
///
/// Within this file only `buffer_size` is read (it sizes the broadcast channel
/// created by `HttpStreamClient::from_response`). The timeout and retry fields
/// are carried along and exposed through `HttpStreamClient::config`, but no
/// code in this module enforces them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamConfig {
    /// Capacity of the broadcast channel used to fan events out to subscribers.
    pub buffer_size: usize,
    /// Timeout for individual events, in seconds (`None` for no timeout).
    pub event_timeout: Option<u64>,
    /// Whether connection errors should be retried.
    pub retry_on_error: bool,
    /// Maximum number of retry attempts (`None` leaves the limit unspecified).
    pub max_retries: Option<u32>,
}
29
30impl Default for StreamConfig {
31    fn default() -> Self {
32        Self {
33            buffer_size: 1000,
34            event_timeout: Some(30),
35            retry_on_error: true,
36            max_retries: Some(3),
37        }
38    }
39}
40
41/// HTTP streaming client for processing Server-Sent Events.
42///
43/// This client handles the low-level HTTP streaming and SSE parsing,
44/// converting raw SSE data into structured MessageStreamEvent objects.
45#[pin_project]
46pub struct HttpStreamClient {
47    /// The underlying SSE event stream
48    #[pin]
49    event_stream: Pin<Box<dyn Stream<Item = Result<MessageStreamEvent>> + Send>>,
50    
51    /// Broadcast sender for distributing events
52    event_sender: broadcast::Sender<MessageStreamEvent>,
53    
54    /// Configuration for the stream
55    config: StreamConfig,
56    
57    /// Whether the stream has ended
58    ended: bool,
59    
60    /// Request ID from response headers
61    request_id: Option<String>,
62}
63
64impl HttpStreamClient {
65    /// Create a new HTTP stream client from a response.
66    ///
67    /// This method takes a reqwest Response (which should be from a streaming endpoint)
68    /// and converts it into a stream of MessageStreamEvent objects.
69    pub async fn from_response(response: Response, config: StreamConfig) -> Result<Self> {
70        let request_id = response.headers()
71            .get("request-id")
72            .and_then(|v| v.to_str().ok())
73            .map(|s| s.to_string());
74
75        // Create the event channel
76        let (event_sender, _) = broadcast::channel(config.buffer_size);
77
78        // Convert the HTTP response into an SSE stream
79        let event_stream = Self::create_event_stream(response).await?;
80
81        Ok(Self {
82            event_stream: Box::pin(event_stream),
83            event_sender,
84            config,
85            ended: false,
86            request_id,
87        })
88    }
89
90    /// Create a stream of MessageStreamEvent from an HTTP response.
91    async fn create_event_stream(
92        response: Response,
93    ) -> Result<impl Stream<Item = Result<MessageStreamEvent>>> {
94        // Check that we got a successful response
95        if !response.status().is_success() {
96            let status = response.status();
97            let text = response.text().await.unwrap_or_default();
98            return Err(AnthropicError::from_status(status.as_u16(), text));
99        }
100
101        // Convert the response into a byte stream
102        let byte_stream = response.bytes_stream();
103
104        // Use eventsource-stream to parse SSE events
105        use eventsource_stream::Eventsource;
106        
107        let sse_stream = byte_stream
108            .eventsource()
109            .map(|result| {
110                match result {
111                    Ok(event) => {
112                        // Parse the SSE event data based on event type
113                        match event.event.as_str() {
114                            // Handle Anthropic API format (event type is "message", data contains the event)
115                            "message" | "" => {
116                                match serde_json::from_str::<MessageStreamEvent>(&event.data) {
117                                    Ok(stream_event) => Ok(stream_event),
118                                    Err(e) => Err(AnthropicError::StreamError(
119                                        format!("Failed to parse SSE event: {}", e)
120                                    )),
121                                }
122                            }
123                            // Handle custom gateway format (event type IS the message event type)
124                            "message_start" => {
125                                // Parse the message data - handle both direct and nested formats
126                                match serde_json::from_str::<crate::types::Message>(&event.data) {
127                                    Ok(message) => Ok(MessageStreamEvent::MessageStart { message }),
128                                    Err(_) => {
129                                        // Try parsing as a wrapped message (custom gateway format)
130                                        match serde_json::from_str::<serde_json::Value>(&event.data) {
131                                            Ok(value) => {
132                                                if let Some(message_value) = value.get("message") {
133                                                    match serde_json::from_value::<crate::types::Message>(message_value.clone()) {
134                                                        Ok(message) => Ok(MessageStreamEvent::MessageStart { message }),
135                                                        Err(e) => Err(AnthropicError::StreamError(
136                                                            format!("Failed to parse nested message: {}", e)
137                                                        )),
138                                                    }
139                                                } else {
140                                                    Err(AnthropicError::StreamError(
141                                                        "message_start event missing message field".to_string()
142                                                    ))
143                                                }
144                                            }
145                                            Err(e) => Err(AnthropicError::StreamError(
146                                                format!("Failed to parse message_start as JSON: {}", e)
147                                            )),
148                                        }
149                                    }
150                                }
151                            }
152                            "content_block_start" => {
153                                // Parse as a generic JSON value first to extract index and content_block
154                                match serde_json::from_str::<serde_json::Value>(&event.data) {
155                                    Ok(value) => {
156                                        let index = value["index"].as_u64().unwrap_or(0) as usize;
157                                        match serde_json::from_value::<crate::types::ContentBlock>(value["content_block"].clone()) {
158                                            Ok(content_block) => Ok(MessageStreamEvent::ContentBlockStart { content_block, index }),
159                                            Err(e) => Err(AnthropicError::StreamError(
160                                                format!("Failed to parse content_block in content_block_start: {}", e)
161                                            )),
162                                        }
163                                    }
164                                    Err(e) => Err(AnthropicError::StreamError(
165                                        format!("Failed to parse content_block_start event: {}", e)
166                                    )),
167                                }
168                            }
169                            "content_block_delta" => {
170                                // Parse as a generic JSON value first to extract index and delta
171                                match serde_json::from_str::<serde_json::Value>(&event.data) {
172                                    Ok(value) => {
173                                        let index = value["index"].as_u64().unwrap_or(0) as usize;
174                                        match serde_json::from_value::<crate::types::ContentBlockDelta>(value["delta"].clone()) {
175                                            Ok(delta) => Ok(MessageStreamEvent::ContentBlockDelta { delta, index }),
176                                            Err(e) => Err(AnthropicError::StreamError(
177                                                format!("Failed to parse delta in content_block_delta: {}", e)
178                                            )),
179                                        }
180                                    }
181                                    Err(e) => Err(AnthropicError::StreamError(
182                                        format!("Failed to parse content_block_delta event: {}", e)
183                                    )),
184                                }
185                            }
186                            "content_block_stop" => {
187                                // Parse as a generic JSON value to extract index
188                                match serde_json::from_str::<serde_json::Value>(&event.data) {
189                                    Ok(value) => {
190                                        let index = value["index"].as_u64().unwrap_or(0) as usize;
191                                        Ok(MessageStreamEvent::ContentBlockStop { index })
192                                    }
193                                    Err(e) => Err(AnthropicError::StreamError(
194                                        format!("Failed to parse content_block_stop event: {}", e)
195                                    )),
196                                }
197                            }
198                            "message_delta" => {
199                                // Parse as a generic JSON value to extract delta and usage
200                                match serde_json::from_str::<serde_json::Value>(&event.data) {
201                                    Ok(value) => {
202                                        let delta = serde_json::from_value::<crate::types::MessageDelta>(value["delta"].clone())
203                                            .map_err(|e| AnthropicError::StreamError(format!("Failed to parse delta: {}", e)))?;
204                                        let usage = serde_json::from_value::<crate::types::MessageDeltaUsage>(value["usage"].clone())
205                                            .map_err(|e| AnthropicError::StreamError(format!("Failed to parse usage: {}", e)))?;
206                                        Ok(MessageStreamEvent::MessageDelta { delta, usage })
207                                    }
208                                    Err(e) => Err(AnthropicError::StreamError(
209                                        format!("Failed to parse message_delta event: {}", e)
210                                    )),
211                                }
212                            }
213                            "message_stop" => {
214                                // Message stop doesn't need data parsing
215                                Ok(MessageStreamEvent::MessageStop)
216                            }
217                            // Handle other event types
218                            "ping" => {
219                                // Ignore ping events
220                                Err(AnthropicError::StreamError("ping".to_string()))
221                            }
222                            event_type => {
223                                // Log unknown event types but don't fail
224                                tracing::debug!("Unknown SSE event type: {}", event_type);
225                                Err(AnthropicError::StreamError(
226                                    format!("Unknown event type: {}", event_type)
227                                ))
228                            }
229                        }
230                    }
231                    Err(e) => Err(AnthropicError::StreamError(
232                        format!("SSE stream error: {}", e)
233                    )),
234                }
235            })
236            .filter_map(|result| {
237                match result {
238                    Ok(event) => Some(Ok(event)),
239                    Err(e) if e.to_string().contains("ping") => None, // Filter out ping errors
240                    Err(e) => Some(Err(e)),
241                }
242            });
243
244        Ok(sse_stream)
245    }
246
247    /// Get the request ID from the response headers.
248    pub fn request_id(&self) -> Option<&str> {
249        self.request_id.as_deref()
250    }
251
252    /// Get the stream configuration.
253    pub fn config(&self) -> &StreamConfig {
254        &self.config
255    }
256
257    /// Check if the stream has ended.
258    pub fn ended(&self) -> bool {
259        self.ended
260    }
261
262    /// Get a receiver for the broadcast channel.
263    ///
264    /// This allows multiple consumers to receive the same stream events.
265    pub fn subscribe(&self) -> broadcast::Receiver<MessageStreamEvent> {
266        self.event_sender.subscribe()
267    }
268}
269
270impl Stream for HttpStreamClient {
271    type Item = Result<MessageStreamEvent>;
272
273    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
274        let this = self.project();
275
276        match this.event_stream.poll_next(cx) {
277            Poll::Ready(Some(Ok(event))) => {
278                // Broadcast the event to all subscribers
279                let _ = this.event_sender.send(event.clone());
280                
281                // Check if this is a terminal event
282                if matches!(event, MessageStreamEvent::MessageStop) {
283                    *this.ended = true;
284                }
285                
286                Poll::Ready(Some(Ok(event)))
287            }
288            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
289            Poll::Ready(None) => {
290                *this.ended = true;
291                Poll::Ready(None)
292            }
293            Poll::Pending => Poll::Pending,
294        }
295    }
296}
297
298/// Builder for creating HTTP streaming requests.
299#[derive(Debug, Clone)]
300pub struct StreamRequestBuilder {
301    /// HTTP client for making requests
302    client: reqwest::Client,
303    /// Base URL for the API
304    base_url: String,
305    /// Request headers
306    headers: reqwest::header::HeaderMap,
307    /// Stream configuration
308    config: StreamConfig,
309}
310
311impl StreamRequestBuilder {
312    /// Create a new stream request builder.
313    pub fn new(client: reqwest::Client, base_url: String) -> Self {
314        Self {
315            client,
316            base_url,
317            headers: reqwest::header::HeaderMap::new(),
318            config: StreamConfig::default(),
319        }
320    }
321
322    /// Add a header to the request.
323    pub fn header(mut self, key: &str, value: &str) -> Self {
324        if let (Ok(key), Ok(value)) = (
325            reqwest::header::HeaderName::from_bytes(key.as_bytes()),
326            reqwest::header::HeaderValue::from_str(value),
327        ) {
328            self.headers.insert(key, value);
329        }
330        self
331    }
332
333    /// Set the stream configuration.
334    pub fn config(mut self, config: StreamConfig) -> Self {
335        self.config = config;
336        self
337    }
338
339    /// Make a streaming POST request.
340    pub async fn post_stream<T: serde::Serialize>(
341        self,
342        endpoint: &str,
343        body: &T,
344    ) -> Result<HttpStreamClient> {
345        let url = format!("{}/{}", self.base_url.trim_end_matches('/'), endpoint.trim_start_matches('/'));
346        
347        let mut headers = self.headers;
348        headers.insert(
349            reqwest::header::ACCEPT,
350            reqwest::header::HeaderValue::from_static("text/event-stream"),
351        );
352        headers.insert(
353            reqwest::header::CACHE_CONTROL,
354            reqwest::header::HeaderValue::from_static("no-cache"),
355        );
356
357        let response = self
358            .client
359            .post(&url)
360            .headers(headers)
361            .json(body)
362            .send()
363            .await
364            .map_err(|e| AnthropicError::Connection { message: e.to_string() })?;
365
366        HttpStreamClient::from_response(response, self.config).await
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented values.
    #[test]
    fn test_stream_config_default() {
        let StreamConfig {
            buffer_size,
            event_timeout,
            retry_on_error,
            max_retries,
        } = StreamConfig::default();

        assert_eq!(buffer_size, 1000);
        assert_eq!(event_timeout, Some(30));
        assert!(retry_on_error);
        assert_eq!(max_retries, Some(3));
    }

    /// Builder methods store the base URL, headers and configuration.
    #[test]
    fn test_stream_request_builder() {
        let custom_config = StreamConfig {
            buffer_size: 500,
            ..Default::default()
        };
        let base_url = "https://api.anthropic.com".to_string();

        let builder = StreamRequestBuilder::new(reqwest::Client::new(), base_url)
            .header("Authorization", "Bearer test-key")
            .config(custom_config);

        assert_eq!(builder.base_url, "https://api.anthropic.com");
        assert_eq!(builder.config.buffer_size, 500);
        assert!(builder.headers.contains_key("authorization"));
    }

    /// A sample `message_start` payload deserializes into the matching event.
    #[tokio::test]
    async fn test_sse_event_parsing() {
        let event_data = r#"{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-latest","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":null,"cache_read_input_tokens":null,"server_tool_use":null,"service_tier":null}}}"#;

        let parsed = serde_json::from_str::<MessageStreamEvent>(event_data);
        assert!(parsed.is_ok());

        match parsed {
            Ok(MessageStreamEvent::MessageStart { message }) => {
                assert_eq!(message.id, "msg_123");
                assert_eq!(message.usage.input_tokens, 10);
            }
            _ => panic!("Expected MessageStart event"),
        }
    }
}