Skip to main content

openresponses_rust/
streaming.rs

1use eventsource_stream::Eventsource;
2use futures::{Stream, StreamExt};
3use reqwest::{Client as ReqwestClient, header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue}};
4use thiserror::Error;
5
6use crate::types::{CreateResponseBody, StreamingEvent};
7
/// Endpoint used when the builder is given no explicit base URL.
///
/// It intentionally omits the `/v1` suffix; the builder appends that during normalization.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
9
10#[derive(Error, Debug)]
11pub enum StreamingError {
12    #[error("HTTP request failed: {0}")]
13    HttpError(#[from] reqwest::Error),
14
15    #[error("Event stream error: {0}")]
16    StreamError(String),
17
18    #[error("JSON parsing error: {0}")]
19    JsonError(#[from] serde_json::Error),
20
21    #[error("API error: {message}")]
22    ApiError { message: String },
23}
24
/// Configures and constructs a [`StreamingClient`].
///
/// Obtain one via `StreamingClientBuilder::new` or `StreamingClient::builder`, optionally set
/// a base URL, then call `build`.
#[derive(Clone)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct StreamingClientBuilder {
    /// API key sent as a bearer token on every request.
    api_key: String,
    /// Optional base-URL override; `None` selects the default endpoint.
    base_url: Option<String>,
}
29
30impl StreamingClientBuilder {
31    pub fn new(api_key: impl Into<String>) -> Self {
32        Self {
33            api_key: api_key.into(),
34            base_url: None,
35        }
36    }
37
38    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
39        self.base_url = Some(base_url.into());
40        self
41    }
42
43    pub fn build(self) -> StreamingClient {
44        let mut base_url = self.base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
45        
46        // Remove trailing slash if present
47        if base_url.ends_with('/') {
48            base_url.pop();
49        }
50        
51        // Automatically append /v1 if it's not present in the path
52        if !base_url.ends_with("/v1") {
53            base_url.push_str("/v1");
54        }
55        
56        let mut headers = HeaderMap::new();
57        headers.insert(
58            CONTENT_TYPE,
59            HeaderValue::from_static("application/json"),
60        );
61        headers.insert(
62            ACCEPT,
63            HeaderValue::from_static("text/event-stream"),
64        );
65        
66        let inner = ReqwestClient::builder()
67            .default_headers(headers)
68            .build()
69            .expect("Failed to create HTTP client");
70        
71        StreamingClient {
72            inner,
73            base_url,
74            api_key: self.api_key,
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct StreamingClient {
81    inner: ReqwestClient,
82    base_url: String,
83    api_key: String,
84}
85
86impl StreamingClient {
87    pub fn new(api_key: impl Into<String>) -> Self {
88        StreamingClientBuilder::new(api_key).build()
89    }
90
91    pub fn builder(api_key: impl Into<String>) -> StreamingClientBuilder {
92        StreamingClientBuilder::new(api_key)
93    }
94    
95    pub fn with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
96        StreamingClientBuilder::new(api_key).base_url(base_url).build()
97    }
98    
99    pub async fn stream_response(
100        &self,
101        mut request: CreateResponseBody,
102    ) -> Result<impl Stream<Item = Result<StreamingEvent, StreamingError>>, StreamingError> {
103        request.stream = Some(true);
104        
105        let url = format!("{}/responses", self.base_url);
106        
107        let response = self.inner
108            .post(&url)
109            .header(AUTHORIZATION, format!("Bearer {}", self.api_key))
110            .json(&request)
111            .send()
112            .await?;
113        
114        if !response.status().is_success() {
115            let error_text = response.text().await?;
116            return Err(StreamingError::ApiError { message: error_text });
117        }
118        
119        let stream = response.bytes_stream();
120        let eventsource = stream.eventsource();
121        
122        let event_stream = eventsource.map(|event| {
123            match event {
124                Ok(event) => {
125                    if event.data == "[DONE]" {
126                        Ok(StreamingEvent::Done)
127                    } else {
128                        serde_json::from_str::<StreamingEvent>(&event.data)
129                            .map_err(StreamingError::JsonError)
130                    }
131                }
132                Err(e) => Err(StreamingError::StreamError(e.to_string())),
133            }
134        });
135        
136        Ok(event_stream)
137    }
138    
139    pub async fn stream_response_lines(
140        &self,
141        mut request: CreateResponseBody,
142    ) -> Result<impl Stream<Item = Result<String, StreamingError>>, StreamingError> {
143        request.stream = Some(true);
144        
145        let url = format!("{}/responses", self.base_url);
146        
147        let response = self.inner
148            .post(&url)
149            .header(AUTHORIZATION, format!("Bearer {}", self.api_key))
150            .json(&request)
151            .send()
152            .await?;
153        
154        if !response.status().is_success() {
155            let error_text = response.text().await?;
156            return Err(StreamingError::ApiError { message: error_text });
157        }
158        
159        let stream = response.bytes_stream();
160        let eventsource = stream.eventsource();
161        
162        let line_stream = eventsource.map(|event| {
163            match event {
164                Ok(event) => Ok(event.data),
165                Err(e) => Err(StreamingError::StreamError(e.to_string())),
166            }
167        });
168        
169        Ok(line_stream)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_client_creation() {
        let client = StreamingClient::new("test-api-key");
        assert_eq!(client.api_key, "test-api-key");
        assert_eq!(client.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_streaming_client_with_base_url_normalization() {
        let client = StreamingClient::with_base_url("test-key", "https://openrouter.ai/api");
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    #[test]
    fn test_base_url_trailing_slash_is_removed() {
        let client = StreamingClient::with_base_url("test-key", "https://openrouter.ai/api/");
        assert_eq!(client.base_url, "https://openrouter.ai/api/v1");
    }

    #[test]
    fn test_base_url_existing_v1_is_not_duplicated() {
        for raw in ["https://api.example.com/v1", "https://api.example.com/v1/"] {
            let client = StreamingClient::with_base_url("test-key", raw);
            assert_eq!(client.base_url, "https://api.example.com/v1", "input: {raw}");
        }
    }

    #[test]
    fn test_builder_matches_with_base_url() {
        let built = StreamingClient::builder("k")
            .base_url("https://example.com")
            .build();
        let direct = StreamingClient::with_base_url("k", "https://example.com");
        assert_eq!(built.base_url, "https://example.com/v1");
        assert_eq!(built.base_url, direct.base_url);
        assert_eq!(built.api_key, "k");
    }
}