1use std::pin::Pin;
7use std::task::{Context, Poll};
8use futures::Stream;
9use pin_project::pin_project;
10use reqwest::Response;
11use serde_json;
12use tokio::sync::broadcast;
13use tokio_stream::StreamExt;
14
15use crate::types::{MessageStreamEvent, AnthropicError, Result};
16
/// Tuning options for an SSE streaming session.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the `tokio::sync::broadcast` channel that backs
    /// `HttpStreamClient::subscribe`.
    pub buffer_size: usize,
    /// Per-event timeout. NOTE(review): the unit is not stated (presumably seconds) and the
    /// streaming code in this module does not read it — confirm where it is enforced.
    pub event_timeout: Option<u64>,
    /// Whether a failed stream should be retried. NOTE(review): not read by the streaming
    /// code in this module — confirm where it is honored.
    pub retry_on_error: bool,
    /// Upper bound on retry attempts. NOTE(review): not read by the streaming code in this
    /// module — confirm where it is honored.
    pub max_retries: Option<u32>,
}

impl Default for StreamConfig {
    /// Buffer of 1000 events, `event_timeout` of `Some(30)`, retries enabled, at most 3.
    fn default() -> Self {
        let buffer_size = 1000;
        let event_timeout = Some(30);
        let retry_on_error = true;
        let max_retries = Some(3);

        Self {
            buffer_size,
            event_timeout,
            retry_on_error,
            max_retries,
        }
    }
}
40
41#[pin_project]
46pub struct HttpStreamClient {
47 #[pin]
49 event_stream: Pin<Box<dyn Stream<Item = Result<MessageStreamEvent>> + Send>>,
50
51 event_sender: broadcast::Sender<MessageStreamEvent>,
53
54 config: StreamConfig,
56
57 ended: bool,
59
60 request_id: Option<String>,
62}
63
64impl HttpStreamClient {
65 pub async fn from_response(response: Response, config: StreamConfig) -> Result<Self> {
70 let request_id = response.headers()
71 .get("request-id")
72 .and_then(|v| v.to_str().ok())
73 .map(|s| s.to_string());
74
75 let (event_sender, _) = broadcast::channel(config.buffer_size);
77
78 let event_stream = Self::create_event_stream(response).await?;
80
81 Ok(Self {
82 event_stream: Box::pin(event_stream),
83 event_sender,
84 config,
85 ended: false,
86 request_id,
87 })
88 }
89
90 async fn create_event_stream(
92 response: Response,
93 ) -> Result<impl Stream<Item = Result<MessageStreamEvent>>> {
94 if !response.status().is_success() {
96 let status = response.status();
97 let text = response.text().await.unwrap_or_default();
98 return Err(AnthropicError::from_status(status.as_u16(), text));
99 }
100
101 let byte_stream = response.bytes_stream();
103
104 use eventsource_stream::Eventsource;
106
107 let sse_stream = byte_stream
108 .eventsource()
109 .map(|result| {
110 match result {
111 Ok(event) => {
112 match event.event.as_str() {
114 "message" | "" => {
116 match serde_json::from_str::<MessageStreamEvent>(&event.data) {
117 Ok(stream_event) => Ok(stream_event),
118 Err(e) => Err(AnthropicError::StreamError(
119 format!("Failed to parse SSE event: {}", e)
120 )),
121 }
122 }
123 "message_start" => {
125 match serde_json::from_str::<crate::types::Message>(&event.data) {
127 Ok(message) => Ok(MessageStreamEvent::MessageStart { message }),
128 Err(_) => {
129 match serde_json::from_str::<serde_json::Value>(&event.data) {
131 Ok(value) => {
132 if let Some(message_value) = value.get("message") {
133 match serde_json::from_value::<crate::types::Message>(message_value.clone()) {
134 Ok(message) => Ok(MessageStreamEvent::MessageStart { message }),
135 Err(e) => Err(AnthropicError::StreamError(
136 format!("Failed to parse nested message: {}", e)
137 )),
138 }
139 } else {
140 Err(AnthropicError::StreamError(
141 "message_start event missing message field".to_string()
142 ))
143 }
144 }
145 Err(e) => Err(AnthropicError::StreamError(
146 format!("Failed to parse message_start as JSON: {}", e)
147 )),
148 }
149 }
150 }
151 }
152 "content_block_start" => {
153 match serde_json::from_str::<serde_json::Value>(&event.data) {
155 Ok(value) => {
156 let index = value["index"].as_u64().unwrap_or(0) as usize;
157 match serde_json::from_value::<crate::types::ContentBlock>(value["content_block"].clone()) {
158 Ok(content_block) => Ok(MessageStreamEvent::ContentBlockStart { content_block, index }),
159 Err(e) => Err(AnthropicError::StreamError(
160 format!("Failed to parse content_block in content_block_start: {}", e)
161 )),
162 }
163 }
164 Err(e) => Err(AnthropicError::StreamError(
165 format!("Failed to parse content_block_start event: {}", e)
166 )),
167 }
168 }
169 "content_block_delta" => {
170 match serde_json::from_str::<serde_json::Value>(&event.data) {
172 Ok(value) => {
173 let index = value["index"].as_u64().unwrap_or(0) as usize;
174 match serde_json::from_value::<crate::types::ContentBlockDelta>(value["delta"].clone()) {
175 Ok(delta) => Ok(MessageStreamEvent::ContentBlockDelta { delta, index }),
176 Err(e) => Err(AnthropicError::StreamError(
177 format!("Failed to parse delta in content_block_delta: {}", e)
178 )),
179 }
180 }
181 Err(e) => Err(AnthropicError::StreamError(
182 format!("Failed to parse content_block_delta event: {}", e)
183 )),
184 }
185 }
186 "content_block_stop" => {
187 match serde_json::from_str::<serde_json::Value>(&event.data) {
189 Ok(value) => {
190 let index = value["index"].as_u64().unwrap_or(0) as usize;
191 Ok(MessageStreamEvent::ContentBlockStop { index })
192 }
193 Err(e) => Err(AnthropicError::StreamError(
194 format!("Failed to parse content_block_stop event: {}", e)
195 )),
196 }
197 }
198 "message_delta" => {
199 match serde_json::from_str::<serde_json::Value>(&event.data) {
201 Ok(value) => {
202 let delta = serde_json::from_value::<crate::types::MessageDelta>(value["delta"].clone())
203 .map_err(|e| AnthropicError::StreamError(format!("Failed to parse delta: {}", e)))?;
204 let usage = serde_json::from_value::<crate::types::MessageDeltaUsage>(value["usage"].clone())
205 .map_err(|e| AnthropicError::StreamError(format!("Failed to parse usage: {}", e)))?;
206 Ok(MessageStreamEvent::MessageDelta { delta, usage })
207 }
208 Err(e) => Err(AnthropicError::StreamError(
209 format!("Failed to parse message_delta event: {}", e)
210 )),
211 }
212 }
213 "message_stop" => {
214 Ok(MessageStreamEvent::MessageStop)
216 }
217 "ping" => {
219 Err(AnthropicError::StreamError("ping".to_string()))
221 }
222 event_type => {
223 tracing::debug!("Unknown SSE event type: {}", event_type);
225 Err(AnthropicError::StreamError(
226 format!("Unknown event type: {}", event_type)
227 ))
228 }
229 }
230 }
231 Err(e) => Err(AnthropicError::StreamError(
232 format!("SSE stream error: {}", e)
233 )),
234 }
235 })
236 .filter_map(|result| {
237 match result {
238 Ok(event) => Some(Ok(event)),
239 Err(e) if e.to_string().contains("ping") => None, Err(e) => Some(Err(e)),
241 }
242 });
243
244 Ok(sse_stream)
245 }
246
247 pub fn request_id(&self) -> Option<&str> {
249 self.request_id.as_deref()
250 }
251
252 pub fn config(&self) -> &StreamConfig {
254 &self.config
255 }
256
257 pub fn ended(&self) -> bool {
259 self.ended
260 }
261
262 pub fn subscribe(&self) -> broadcast::Receiver<MessageStreamEvent> {
266 self.event_sender.subscribe()
267 }
268}
269
270impl Stream for HttpStreamClient {
271 type Item = Result<MessageStreamEvent>;
272
273 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
274 let this = self.project();
275
276 match this.event_stream.poll_next(cx) {
277 Poll::Ready(Some(Ok(event))) => {
278 let _ = this.event_sender.send(event.clone());
280
281 if matches!(event, MessageStreamEvent::MessageStop) {
283 *this.ended = true;
284 }
285
286 Poll::Ready(Some(Ok(event)))
287 }
288 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
289 Poll::Ready(None) => {
290 *this.ended = true;
291 Poll::Ready(None)
292 }
293 Poll::Pending => Poll::Pending,
294 }
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct StreamRequestBuilder {
301 client: reqwest::Client,
303 base_url: String,
305 headers: reqwest::header::HeaderMap,
307 config: StreamConfig,
309}
310
311impl StreamRequestBuilder {
312 pub fn new(client: reqwest::Client, base_url: String) -> Self {
314 Self {
315 client,
316 base_url,
317 headers: reqwest::header::HeaderMap::new(),
318 config: StreamConfig::default(),
319 }
320 }
321
322 pub fn header(mut self, key: &str, value: &str) -> Self {
324 if let (Ok(key), Ok(value)) = (
325 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
326 reqwest::header::HeaderValue::from_str(value),
327 ) {
328 self.headers.insert(key, value);
329 }
330 self
331 }
332
333 pub fn config(mut self, config: StreamConfig) -> Self {
335 self.config = config;
336 self
337 }
338
339 pub async fn post_stream<T: serde::Serialize>(
341 self,
342 endpoint: &str,
343 body: &T,
344 ) -> Result<HttpStreamClient> {
345 let url = format!("{}/{}", self.base_url.trim_end_matches('/'), endpoint.trim_start_matches('/'));
346
347 let mut headers = self.headers;
348 headers.insert(
349 reqwest::header::ACCEPT,
350 reqwest::header::HeaderValue::from_static("text/event-stream"),
351 );
352 headers.insert(
353 reqwest::header::CACHE_CONTROL,
354 reqwest::header::HeaderValue::from_static("no-cache"),
355 );
356
357 let response = self
358 .client
359 .post(&url)
360 .headers(headers)
361 .json(body)
362 .send()
363 .await
364 .map_err(|e| AnthropicError::Connection { message: e.to_string() })?;
365
366 HttpStreamClient::from_response(response, self.config).await
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert_eq!(config.buffer_size, 1000);
        assert_eq!(config.event_timeout, Some(30));
        assert!(config.retry_on_error);
        assert_eq!(config.max_retries, Some(3));
    }

    #[test]
    fn test_stream_request_builder() {
        let client = reqwest::Client::new();
        let builder = StreamRequestBuilder::new(client, "https://api.anthropic.com".to_string())
            .header("Authorization", "Bearer test-key")
            .config(StreamConfig {
                buffer_size: 500,
                ..Default::default()
            });

        assert_eq!(builder.base_url, "https://api.anthropic.com");
        assert_eq!(builder.config.buffer_size, 500);
        assert!(builder.headers.contains_key("authorization"));
    }

    // Pure deserialization: no async runtime is needed, so this is a plain `#[test]`.
    #[test]
    fn test_sse_event_parsing() {
        let event_data = r#"{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-latest","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":null,"cache_read_input_tokens":null,"server_tool_use":null,"service_tier":null}}}"#;

        let parsed: std::result::Result<MessageStreamEvent, _> = serde_json::from_str(event_data);

        match parsed {
            Ok(MessageStreamEvent::MessageStart { message }) => {
                assert_eq!(message.id, "msg_123");
                assert_eq!(message.usage.input_tokens, 10);
            }
            Ok(_) => panic!("Expected MessageStart event"),
            Err(e) => panic!("Failed to parse message_start payload: {}", e),
        }
    }
}