Skip to main content

openresponses_rust/
streaming.rs

1use eventsource_stream::Eventsource;
2use futures::{Stream, StreamExt};
3use reqwest::{Client as ReqwestClient, header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue}};
4use thiserror::Error;
5
6use crate::types::{CreateResponseBody, StreamingEvent};
7
/// Base URL used by `StreamingClient::new`; request paths such as `/responses` are appended to it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
9
10#[derive(Error, Debug)]
11pub enum StreamingError {
12    #[error("HTTP request failed: {0}")]
13    HttpError(#[from] reqwest::Error),
14
15    #[error("Event stream error: {0}")]
16    StreamError(String),
17
18    #[error("JSON parsing error: {0}")]
19    JsonError(#[from] serde_json::Error),
20
21    #[error("API error: {message}")]
22    ApiError { message: String },
23}
24
25#[derive(Clone)]
26pub struct StreamingClient {
27    inner: ReqwestClient,
28    base_url: String,
29    api_key: String,
30}
31
32impl StreamingClient {
33    pub fn new(api_key: impl Into<String>) -> Self {
34        Self::with_base_url(api_key, DEFAULT_BASE_URL)
35    }
36    
37    pub fn with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
38        let api_key = api_key.into();
39        let base_url = base_url.into();
40        
41        let mut headers = HeaderMap::new();
42        headers.insert(
43            CONTENT_TYPE,
44            HeaderValue::from_static("application/json"),
45        );
46        headers.insert(
47            ACCEPT,
48            HeaderValue::from_static("text/event-stream"),
49        );
50        
51        let inner = ReqwestClient::builder()
52            .default_headers(headers)
53            .build()
54            .expect("Failed to create HTTP client");
55        
56        Self { inner, base_url, api_key }
57    }
58    
59    pub async fn stream_response(
60        &self,
61        mut request: CreateResponseBody,
62    ) -> Result<impl Stream<Item = Result<StreamingEvent, StreamingError>>, StreamingError> {
63        request.stream = Some(true);
64        
65        let url = format!("{}/responses", self.base_url);
66        
67        let response = self.inner
68            .post(&url)
69            .header(AUTHORIZATION, format!("Bearer {}", self.api_key))
70            .json(&request)
71            .send()
72            .await?;
73        
74        if !response.status().is_success() {
75            let error_text = response.text().await?;
76            return Err(StreamingError::ApiError { message: error_text });
77        }
78        
79        let stream = response.bytes_stream();
80        let eventsource = stream.eventsource();
81        
82        let event_stream = eventsource.map(|event| {
83            match event {
84                Ok(event) => {
85                    if event.data == "[DONE]" {
86                        Ok(StreamingEvent::Done)
87                    } else {
88                        serde_json::from_str::<StreamingEvent>(&event.data)
89                            .map_err(StreamingError::JsonError)
90                    }
91                }
92                Err(e) => Err(StreamingError::StreamError(e.to_string())),
93            }
94        });
95        
96        Ok(event_stream)
97    }
98    
99    pub async fn stream_response_lines(
100        &self,
101        mut request: CreateResponseBody,
102    ) -> Result<impl Stream<Item = Result<String, StreamingError>>, StreamingError> {
103        request.stream = Some(true);
104        
105        let url = format!("{}/responses", self.base_url);
106        
107        let response = self.inner
108            .post(&url)
109            .header(AUTHORIZATION, format!("Bearer {}", self.api_key))
110            .json(&request)
111            .send()
112            .await?;
113        
114        if !response.status().is_success() {
115            let error_text = response.text().await?;
116            return Err(StreamingError::ApiError { message: error_text });
117        }
118        
119        let stream = response.bytes_stream();
120        let eventsource = stream.eventsource();
121        
122        let line_stream = eventsource.map(|event| {
123            match event {
124                Ok(event) => Ok(event.data),
125                Err(e) => Err(StreamingError::StreamError(e.to_string())),
126            }
127        });
128        
129        Ok(line_stream)
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps the key verbatim and falls back to the default base URL.
    #[test]
    fn test_streaming_client_creation() {
        let StreamingClient { api_key, base_url, .. } = StreamingClient::new("test-api-key");

        assert_eq!(api_key, "test-api-key");
        assert_eq!(base_url, DEFAULT_BASE_URL);
    }

    /// `with_base_url` stores both arguments exactly as given.
    #[test]
    fn test_streaming_client_with_base_url() {
        let custom = "https://custom.api.com";
        let StreamingClient { api_key, base_url, .. } =
            StreamingClient::with_base_url("test-key", custom);

        assert_eq!(api_key, "test-key");
        assert_eq!(base_url, custom);
    }
}