// openai_sdk_rs/client.rs

1use std::time::Duration;
2
3use reqwest::{header, Client as HttpClient, StatusCode, Url};
4use serde::de::DeserializeOwned;
5use futures_util::stream::BoxStream;
6use async_stream::try_stream;
7use futures_util::{StreamExt, TryStreamExt};
8
9use crate::error::{ApiError, ApiErrorEnvelope, Error};
10use crate::types::chat::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk};
11use crate::types::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
12use crate::types::responses::{ResponsesRequest, ResponsesResponse, ResponseStreamEvent};
13use crate::types::images::{ImageGenerationRequest, ImageGenerationResponse};
14use crate::types::files::{FileListResponse, FileObject, FileDeleteResponse};
15
16const DEFAULT_BASE_URL: &str = "https://api.openai.com";
17
18#[derive(Clone)]
19pub struct OpenAI {
20    http: HttpClient,
21    base_url: Url,
22    api_key: String,
23    org: Option<String>,
24    project: Option<String>,
25    max_retries: u32,
26    retry_base_delay_ms: u64,
27}
28
29impl std::fmt::Debug for OpenAI {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("OpenAI")
32            .field("base_url", &self.base_url)
33            .field("org", &self.org)
34            .field("project", &self.project)
35            .finish_non_exhaustive()
36    }
37}
38
39impl OpenAI {
40    pub fn base_url(&self) -> String {self.base_url.as_str().to_string()}
41
42    pub fn new<S: Into<String>>(api_key: S) -> Result<Self, Error> {
43        Self::builder().api_key(api_key.into()).build()
44    }
45
46    pub fn with_http_client<S: Into<String>>(http: HttpClient, api_key: S) -> Result<Self, Error> {
47        Self::builder().http_client(http).api_key(api_key.into()).build()
48    }
49
50    pub fn from_env() -> Result<Self, Error> {
51        let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| Error::MissingApiKey)?;
52        let mut b = Self::builder().api_key(api_key);
53        if let Ok(o) = std::env::var("OPENAI_ORG_ID") { b = b.org(o); }
54        if let Ok(p) = std::env::var("OPENAI_PROJECT_ID") { b = b.project(p); }
55        if let Ok(u) = std::env::var("OPENAI_BASE_URL") { b = b.base_url(u); }
56        b.build()
57    }
58
59    pub fn builder() -> OpenAIBuilder { OpenAIBuilder::default() }
60
61    pub async fn chat_completion(&self, req: ChatCompletionRequest) -> Result<ChatCompletionResponse, Error> {
62        self.post_json("/v1/chat/completions", &req).await
63    }
64
65    pub async fn embeddings(&self, req: EmbeddingsRequest) -> Result<EmbeddingsResponse, Error> {
66        self.post_json("/v1/embeddings", &req).await
67    }
68
69    pub async fn chat_completion_stream(&self, mut req: ChatCompletionRequest) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
70        req.stream = Some(true);
71        self.post_sse("/v1/chat/completions", &req).await
72    }
73
74    pub async fn responses(&self, req: ResponsesRequest) -> Result<ResponsesResponse, Error> {
75        self.post_json("/v1/responses", &req).await
76    }
77
78    pub async fn responses_stream(&self, mut req: ResponsesRequest) -> Result<BoxStream<'static, Result<ResponseStreamEvent, Error>>, Error> {
79        req.stream = Some(true);
80        self.post_sse("/v1/responses", &req).await
81    }
82
83    pub async fn images_generate(&self, req: ImageGenerationRequest) -> Result<ImageGenerationResponse, Error> {
84        self.post_json("/v1/images/generations", &req).await
85    }
86
87    pub async fn files_list(&self) -> Result<FileListResponse, Error> {
88        self.get_json("/v1/files").await
89    }
90
91    pub async fn files_upload_bytes(&self, filename: &str, bytes: Vec<u8>, purpose: &str) -> Result<FileObject, Error> {
92        let url = self.base_url.join("/v1/files").expect("valid path");
93        let form = reqwest::multipart::Form::new()
94            .text("purpose", purpose.to_string())
95            .part("file", reqwest::multipart::Part::bytes(bytes).file_name(filename.to_string()));
96
97        let mut req = self.http.post(url)
98            .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
99            .multipart(form);
100        if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
101        if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
102        let resp = self.execute_with_retry(|| req.try_clone().expect("req clone"), false).await?;
103        let status = resp.status();
104        if status.is_success() { Ok(resp.json::<FileObject>().await?) } else { Self::map_api_error(status, resp).await }
105    }
106
107    pub async fn files_download(&self, file_id: &str) -> Result<Vec<u8>, Error> {
108        let mk = || {
109            let url = self.base_url.join(&format!("/v1/files/{}/content", file_id)).expect("valid path");
110            let mut req = self.http.get(url)
111                .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key));
112            if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
113            if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
114            req
115        };
116        let resp = self.execute_with_retry(mk, false).await?;
117        let status = resp.status();
118        if status.is_success() { Ok(resp.bytes().await?.to_vec()) } else { Self::map_api_error(status, resp).await }
119    }
120
121    pub async fn files_delete(&self, file_id: &str) -> Result<FileDeleteResponse, Error> {
122        let mk = || {
123            let url = self.base_url.join(&format!("/v1/files/{}", file_id)).expect("valid path");
124            let mut req = self.http.delete(url)
125                .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key));
126            if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
127            if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
128            req
129        };
130        let resp = self.execute_with_retry(mk, false).await?;
131        let status = resp.status();
132        if status.is_success() { Ok(resp.json::<FileDeleteResponse>().await?) } else { Self::map_api_error(status, resp).await }
133    }
134
135    pub async fn chat_completion_stream_text(&self, req: ChatCompletionRequest) -> Result<String, Error> {
136        let mut stream = self.chat_completion_stream(req).await?;
137        let mut out = String::new();
138        while let Some(chunk) = stream.try_next().await? {
139            if let Some(text) = chunk.choices.get(0).and_then(|c| c.delta.content.as_deref()) {
140                out.push_str(text);
141            }
142        }
143        Ok(out)
144    }
145
146    pub async fn responses_stream_text(&self, req: ResponsesRequest) -> Result<String, Error> {
147        let mut stream = self.responses_stream(req).await?;
148
149        let mut out = String::new();
150        while let Some(ev) = stream.next().await{
151            let ev = ev?;
152            if let Some(text) = ev.clone().output_text.as_deref() {
153                out.push_str(text);
154            } else if let Some(d) = ev.delta.as_ref().and_then(|v| v.get("output_text")).and_then(|v| v.as_str()) {
155                out.push_str(d);
156            }
157        }
158
159        Ok(out)
160    }
161
162    async fn post_json<TReq: serde::Serialize, TResp: DeserializeOwned>(&self, path: &str, body: &TReq) -> Result<TResp, Error> {
163        let mk = || {
164            let url = self.base_url.join(path).expect("valid path");
165            let mut req = self.http.post(url)
166                .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
167                .json(body);
168            if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
169            if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
170            req
171        };
172
173        let resp = self.execute_with_retry(mk, false).await?;
174        let status = resp.status();
175        if status.is_success() {
176            Ok(resp.json::<TResp>().await?)
177        } else {
178            Self::map_api_error(status, resp).await
179        }
180    }
181
182    async fn get_json<TResp: DeserializeOwned>(&self, path: &str) -> Result<TResp, Error> {
183        let mk = || {
184            let url = self.base_url.join(path).expect("valid path");
185            let mut req = self.http.get(url)
186                .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key));
187            if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
188            if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
189            req
190        };
191        let resp = self.execute_with_retry(mk, false).await?;
192        let status = resp.status();
193        if status.is_success() { Ok(resp.json::<TResp>().await?) } else { Self::map_api_error(status, resp).await }
194    }
195
196    async fn post_sse<TReq: serde::Serialize, TEvent: DeserializeOwned + Send + 'static>(&self, path: &str, body: &TReq) -> Result<BoxStream<'static, Result<TEvent, Error>>, Error> {
197        let mk = || {
198            let url = self.base_url.join(path).expect("valid path");
199            let mut req = self.http.post(url)
200                .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
201                .header(header::ACCEPT, "text/event-stream")
202                .json(body);
203            if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
204            if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
205            req
206        };
207
208        let resp = self.execute_with_retry(mk, true).await?;
209        let status = resp.status();
210        if !status.is_success() {
211            return Self::map_api_error(status, resp).await;
212        }
213        let res = Self::sse_json_stream::<TEvent>(resp);
214        Ok(res)
215    }
216
217    fn sse_json_stream<T: DeserializeOwned + Send + 'static>(resp: reqwest::Response) -> BoxStream<'static, Result<T, Error>> {
218        let stream = try_stream! {
219            let mut buf: Vec<u8> = Vec::new();
220            let mut byte_stream = resp.bytes_stream();
221            while let Some(chunk) = futures_util::StreamExt::next(&mut byte_stream).await {
222                let chunk = chunk?;
223                buf.extend_from_slice(&chunk);
224
225                let mut start = 0usize;
226                for i in 0..buf.len() {
227                    if buf[i] == b'\n' {
228                        let mut line = &buf[start..i];
229                        start = i + 1;
230                        if !line.is_empty() && line[line.len()-1] == b'\r' {
231                            line = &line[..line.len()-1];
232                        }
233                        if line.is_empty() { continue; }
234                        if line[0] == b':' { continue; }
235                        if let Some(rest) = line.strip_prefix(b"data: ") {
236                            if rest == b"[DONE]" { return; }
237                            let text = String::from_utf8(rest.to_vec()).unwrap_or_default();
238                            let val: T = serde_json::from_str(&text)?;
239                            yield val;
240                        }
241                    }
242                }
243                if start > 0 { buf.drain(0..start); }
244            }
245        };
246        Box::pin(stream)
247    }
248
249    #[cfg(test)]
250    fn sse_extract_data_lines(text: &str) -> Vec<String> {
251        text.lines()
252            .filter_map(|l| {
253                let l = l.trim_end_matches('\r');
254                if l.is_empty() || l.starts_with(':') { return None; }
255                if let Some(rest) = l.strip_prefix("data: ") {
256                    if rest == "[DONE]" { return None; }
257                    return Some(rest.to_string());
258                }
259                None
260            })
261            .collect()
262    }
263
264    async fn map_api_error<TResp>(status: StatusCode, resp: reqwest::Response) -> Result<TResp, Error> {
265        let text = resp.text().await.unwrap_or_default();
266        if let Ok(env) = serde_json::from_str::<ApiErrorEnvelope>(&text) {
267            let mut api: ApiError = env.into();
268            api.status = Some(status.as_u16());
269            Err(Error::Api(api))
270        } else {
271            Err(Error::UnexpectedStatus { status: status.as_u16(), body: text })
272        }
273    }
274}
275
276impl OpenAI {
277    async fn execute_with_retry<F>(&self, mk: F, _sse: bool) -> Result<reqwest::Response, Error>
278    where F: Fn() -> reqwest::RequestBuilder
279    {
280        let mut attempt = 0u32;
281        loop {
282
283            let req = mk();
284            let res = req.send().await;
285            return match res {
286                Ok(resp) => {
287                    let status = resp.status();
288                    if status.is_success() {
289                        return Ok(resp);
290                    }
291                    if self.should_retry_status(status) && attempt < self.max_retries {
292                        let delay = self.retry_delay(attempt, resp.headers().get(header::RETRY_AFTER));
293                        attempt += 1;
294                        tokio::time::sleep(delay).await;
295                        continue;
296                    }
297                    Ok(resp)
298                }
299                Err(e) => {
300                    println!("Request error: {}", e);
301                    if self.is_retryable_error(&e) && attempt < self.max_retries {
302                        let delay = self.retry_delay(attempt, None);
303                        attempt += 1;
304                        tokio::time::sleep(delay).await;
305                        continue;
306                    }
307                    Err(Error::Http(e))
308                }
309            }
310        }
311    }
312
313    fn should_retry_status(&self, status: StatusCode) -> bool {
314        status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
315    }
316
317    fn retry_delay(&self, attempt: u32, retry_after: Option<&header::HeaderValue>) -> std::time::Duration {
318        if let Some(v) = retry_after {
319            if let Ok(s) = v.to_str() {
320                if let Ok(secs) = s.parse::<u64>() {
321                    return Duration::from_secs(secs);
322                }
323            }
324        }
325        let base = self.retry_base_delay_ms;
326        let backoff = base.saturating_mul(1u64 << attempt.min(8));
327        Duration::from_millis(backoff)
328    }
329
330    fn is_retryable_error(&self, e: &reqwest::Error) -> bool {
331        e.is_timeout() || e.is_connect() || e.is_request()
332    }
333}
334
335#[derive(Default)]
336pub struct OpenAIBuilder {
337    api_key: Option<String>,
338    base_url: Option<String>,
339    org: Option<String>,
340    project: Option<String>,
341    timeout: Option<Duration>,
342    user_agent: Option<String>,
343    max_retries: Option<u32>,
344    retry_base_delay_ms: Option<u64>,
345    http: Option<HttpClient>,
346    proxy: Option<String>,
347}
348
349impl OpenAIBuilder {
350    pub fn api_key(mut self, key: String) -> Self { self.api_key = Some(key); self }
351    pub fn base_url<S: Into<String>>(mut self, url: S) -> Self { self.base_url = Some(url.into()); self }
352    pub fn org<S: Into<String>>(mut self, org: S) -> Self { self.org = Some(org.into()); self }
353    pub fn project<S: Into<String>>(mut self, project: S) -> Self { self.project = Some(project.into()); self }
354    pub fn timeout(mut self, timeout: Duration) -> Self { self.timeout = Some(timeout); self }
355    pub fn user_agent<S: Into<String>>(mut self, ua: S) -> Self { self.user_agent = Some(ua.into()); self }
356    pub fn max_retries(mut self, n: u32) -> Self { self.max_retries = Some(n); self }
357    pub fn retry_base_delay(mut self, dur: Duration) -> Self { self.retry_base_delay_ms = Some(dur.as_millis() as u64); self }
358    pub fn http_client(mut self, client: HttpClient) -> Self { self.http = Some(client); self }
359    pub fn proxy<S: Into<String>>(mut self, url: S) -> Self { self.proxy = Some(url.into()); self }
360
361    pub fn build(self) -> Result<OpenAI, Error> {
362        let api_key = self.api_key.ok_or(Error::MissingApiKey)?;
363        let base_url_str = self.base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
364        let base_url = Url::parse(&base_url_str)?;
365
366        let http = if let Some(custom) = self.http {
367            custom
368        } else {
369            let mut headers = header::HeaderMap::new();
370            headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
371
372            let mut http = HttpClient::builder()
373                .default_headers(headers)
374                .gzip(true)
375                .brotli(true);
376
377            if let Some(t) = self.timeout { http = http.timeout(t); }
378            if let Some(ua) = self.user_agent {
379                http = http.user_agent(ua);
380            } else {
381                http = http.user_agent(format!("openai-sdk-rs/{} (+https://crates.io/crates/openai-sdk)", env!("CARGO_PKG_VERSION")));
382            }
383
384            if let Some(px) = self.proxy {
385                if let Ok(proxy) = reqwest::Proxy::all(px) {
386                    http = http.proxy(proxy);
387                }
388            }
389
390            http.build()?
391        };
392
393        Ok(OpenAI { 
394            http, 
395            base_url, 
396            api_key, 
397            org: self.org, 
398            project: self.project,
399            max_retries: self.max_retries.unwrap_or(3),
400            retry_base_delay_ms: self.retry_base_delay_ms.unwrap_or(200),
401        })
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::OpenAI;

    /// Only the JSON payload survives: the `event`/`retry` fields, the comment
    /// line, the blank line and the `[DONE]` sentinel are all dropped.
    #[test]
    fn sse_extracts_data_lines() {
        let input = "event: message\n:data line as comment\ndata: {\"a\":1}\n\nretry: 5000\ndata: [DONE]\n";
        let expected = vec!["{\"a\":1}".to_string()];
        let actual = OpenAI::sse_extract_data_lines(input);
        assert_eq!(actual, expected);
    }
}