1use std::time::Duration;
2
3use reqwest::{header, Client as HttpClient, StatusCode, Url};
4use serde::de::DeserializeOwned;
5use futures_util::stream::BoxStream;
6use async_stream::try_stream;
7use futures_util::{StreamExt, TryStreamExt};
8
9use crate::error::{ApiError, ApiErrorEnvelope, Error};
10use crate::types::chat::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk};
11use crate::types::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
12use crate::types::responses::{ResponsesRequest, ResponsesResponse, ResponseStreamEvent};
13use crate::types::images::{ImageGenerationRequest, ImageGenerationResponse};
14use crate::types::files::{FileListResponse, FileObject, FileDeleteResponse};
15
/// Base URL used when neither the builder nor `OPENAI_BASE_URL` supplies one.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
17
18#[derive(Clone)]
19pub struct OpenAI {
20 http: HttpClient,
21 base_url: Url,
22 api_key: String,
23 org: Option<String>,
24 project: Option<String>,
25 max_retries: u32,
26 retry_base_delay_ms: u64,
27}
28
29impl std::fmt::Debug for OpenAI {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("OpenAI")
32 .field("base_url", &self.base_url)
33 .field("org", &self.org)
34 .field("project", &self.project)
35 .finish_non_exhaustive()
36 }
37}
38
39impl OpenAI {
40 pub fn base_url(&self) -> String {self.base_url.as_str().to_string()}
41
42 pub fn new<S: Into<String>>(api_key: S) -> Result<Self, Error> {
43 Self::builder().api_key(api_key.into()).build()
44 }
45
46 pub fn with_http_client<S: Into<String>>(http: HttpClient, api_key: S) -> Result<Self, Error> {
47 Self::builder().http_client(http).api_key(api_key.into()).build()
48 }
49
50 pub fn from_env() -> Result<Self, Error> {
51 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| Error::MissingApiKey)?;
52 let mut b = Self::builder().api_key(api_key);
53 if let Ok(o) = std::env::var("OPENAI_ORG_ID") { b = b.org(o); }
54 if let Ok(p) = std::env::var("OPENAI_PROJECT_ID") { b = b.project(p); }
55 if let Ok(u) = std::env::var("OPENAI_BASE_URL") { b = b.base_url(u); }
56 b.build()
57 }
58
59 pub fn builder() -> OpenAIBuilder { OpenAIBuilder::default() }
60
61 pub async fn chat_completion(&self, req: ChatCompletionRequest) -> Result<ChatCompletionResponse, Error> {
62 self.post_json("/v1/chat/completions", &req).await
63 }
64
65 pub async fn embeddings(&self, req: EmbeddingsRequest) -> Result<EmbeddingsResponse, Error> {
66 self.post_json("/v1/embeddings", &req).await
67 }
68
69 pub async fn chat_completion_stream(&self, mut req: ChatCompletionRequest) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
70 req.stream = Some(true);
71 self.post_sse("/v1/chat/completions", &req).await
72 }
73
74 pub async fn responses(&self, req: ResponsesRequest) -> Result<ResponsesResponse, Error> {
75 self.post_json("/v1/responses", &req).await
76 }
77
78 pub async fn responses_stream(&self, mut req: ResponsesRequest) -> Result<BoxStream<'static, Result<ResponseStreamEvent, Error>>, Error> {
79 req.stream = Some(true);
80 self.post_sse("/v1/responses", &req).await
81 }
82
83 pub async fn images_generate(&self, req: ImageGenerationRequest) -> Result<ImageGenerationResponse, Error> {
84 self.post_json("/v1/images/generations", &req).await
85 }
86
87 pub async fn files_list(&self) -> Result<FileListResponse, Error> {
88 self.get_json("/v1/files").await
89 }
90
91 pub async fn files_upload_bytes(&self, filename: &str, bytes: Vec<u8>, purpose: &str) -> Result<FileObject, Error> {
92 let url = self.base_url.join("/v1/files").expect("valid path");
93 let form = reqwest::multipart::Form::new()
94 .text("purpose", purpose.to_string())
95 .part("file", reqwest::multipart::Part::bytes(bytes).file_name(filename.to_string()));
96
97 let mut req = self.http.post(url)
98 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
99 .multipart(form);
100 if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
101 if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
102 let resp = self.execute_with_retry(|| req.try_clone().expect("req clone"), false).await?;
103 let status = resp.status();
104 if status.is_success() { Ok(resp.json::<FileObject>().await?) } else { Self::map_api_error(status, resp).await }
105 }
106
107 pub async fn files_download(&self, file_id: &str) -> Result<Vec<u8>, Error> {
108 let mk = || {
109 let url = self.base_url.join(&format!("/v1/files/{}/content", file_id)).expect("valid path");
110 let mut req = self.http.get(url)
111 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key));
112 if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
113 if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
114 req
115 };
116 let resp = self.execute_with_retry(mk, false).await?;
117 let status = resp.status();
118 if status.is_success() { Ok(resp.bytes().await?.to_vec()) } else { Self::map_api_error(status, resp).await }
119 }
120
121 pub async fn files_delete(&self, file_id: &str) -> Result<FileDeleteResponse, Error> {
122 let mk = || {
123 let url = self.base_url.join(&format!("/v1/files/{}", file_id)).expect("valid path");
124 let mut req = self.http.delete(url)
125 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key));
126 if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
127 if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
128 req
129 };
130 let resp = self.execute_with_retry(mk, false).await?;
131 let status = resp.status();
132 if status.is_success() { Ok(resp.json::<FileDeleteResponse>().await?) } else { Self::map_api_error(status, resp).await }
133 }
134
135 pub async fn chat_completion_stream_text(&self, req: ChatCompletionRequest) -> Result<String, Error> {
136 let mut stream = self.chat_completion_stream(req).await?;
137 let mut out = String::new();
138 while let Some(chunk) = stream.try_next().await? {
139 if let Some(text) = chunk.choices.get(0).and_then(|c| c.delta.content.as_deref()) {
140 out.push_str(text);
141 }
142 }
143 Ok(out)
144 }
145
146 pub async fn responses_stream_text(&self, req: ResponsesRequest) -> Result<String, Error> {
147 let mut stream = self.responses_stream(req).await?;
148
149 let mut out = String::new();
150 while let Some(ev) = stream.next().await{
151 let ev = ev?;
152 if let Some(text) = ev.clone().output_text.as_deref() {
153 out.push_str(text);
154 } else if let Some(d) = ev.delta.as_ref().and_then(|v| v.get("output_text")).and_then(|v| v.as_str()) {
155 out.push_str(d);
156 }
157 }
158
159 Ok(out)
160 }
161
162 async fn post_json<TReq: serde::Serialize, TResp: DeserializeOwned>(&self, path: &str, body: &TReq) -> Result<TResp, Error> {
163 let mk = || {
164 let url = self.base_url.join(path).expect("valid path");
165 let mut req = self.http.post(url)
166 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
167 .json(body);
168 if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
169 if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
170 req
171 };
172
173 let resp = self.execute_with_retry(mk, false).await?;
174 let status = resp.status();
175 if status.is_success() {
176 Ok(resp.json::<TResp>().await?)
177 } else {
178 Self::map_api_error(status, resp).await
179 }
180 }
181
182 async fn get_json<TResp: DeserializeOwned>(&self, path: &str) -> Result<TResp, Error> {
183 let mk = || {
184 let url = self.base_url.join(path).expect("valid path");
185 let mut req = self.http.get(url)
186 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key));
187 if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
188 if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
189 req
190 };
191 let resp = self.execute_with_retry(mk, false).await?;
192 let status = resp.status();
193 if status.is_success() { Ok(resp.json::<TResp>().await?) } else { Self::map_api_error(status, resp).await }
194 }
195
196 async fn post_sse<TReq: serde::Serialize, TEvent: DeserializeOwned + Send + 'static>(&self, path: &str, body: &TReq) -> Result<BoxStream<'static, Result<TEvent, Error>>, Error> {
197 let mk = || {
198 let url = self.base_url.join(path).expect("valid path");
199 let mut req = self.http.post(url)
200 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
201 .header(header::ACCEPT, "text/event-stream")
202 .json(body);
203 if let Some(org) = &self.org { req = req.header("OpenAI-Organization", org); }
204 if let Some(project) = &self.project { req = req.header("OpenAI-Project", project); }
205 req
206 };
207
208 let resp = self.execute_with_retry(mk, true).await?;
209 let status = resp.status();
210 if !status.is_success() {
211 return Self::map_api_error(status, resp).await;
212 }
213 let res = Self::sse_json_stream::<TEvent>(resp);
214 Ok(res)
215 }
216
217 fn sse_json_stream<T: DeserializeOwned + Send + 'static>(resp: reqwest::Response) -> BoxStream<'static, Result<T, Error>> {
218 let stream = try_stream! {
219 let mut buf: Vec<u8> = Vec::new();
220 let mut byte_stream = resp.bytes_stream();
221 while let Some(chunk) = futures_util::StreamExt::next(&mut byte_stream).await {
222 let chunk = chunk?;
223 buf.extend_from_slice(&chunk);
224
225 let mut start = 0usize;
226 for i in 0..buf.len() {
227 if buf[i] == b'\n' {
228 let mut line = &buf[start..i];
229 start = i + 1;
230 if !line.is_empty() && line[line.len()-1] == b'\r' {
231 line = &line[..line.len()-1];
232 }
233 if line.is_empty() { continue; }
234 if line[0] == b':' { continue; }
235 if let Some(rest) = line.strip_prefix(b"data: ") {
236 if rest == b"[DONE]" { return; }
237 let text = String::from_utf8(rest.to_vec()).unwrap_or_default();
238 let val: T = serde_json::from_str(&text)?;
239 yield val;
240 }
241 }
242 }
243 if start > 0 { buf.drain(0..start); }
244 }
245 };
246 Box::pin(stream)
247 }
248
249 #[cfg(test)]
250 fn sse_extract_data_lines(text: &str) -> Vec<String> {
251 text.lines()
252 .filter_map(|l| {
253 let l = l.trim_end_matches('\r');
254 if l.is_empty() || l.starts_with(':') { return None; }
255 if let Some(rest) = l.strip_prefix("data: ") {
256 if rest == "[DONE]" { return None; }
257 return Some(rest.to_string());
258 }
259 None
260 })
261 .collect()
262 }
263
264 async fn map_api_error<TResp>(status: StatusCode, resp: reqwest::Response) -> Result<TResp, Error> {
265 let text = resp.text().await.unwrap_or_default();
266 if let Ok(env) = serde_json::from_str::<ApiErrorEnvelope>(&text) {
267 let mut api: ApiError = env.into();
268 api.status = Some(status.as_u16());
269 Err(Error::Api(api))
270 } else {
271 Err(Error::UnexpectedStatus { status: status.as_u16(), body: text })
272 }
273 }
274}
275
276impl OpenAI {
277 async fn execute_with_retry<F>(&self, mk: F, _sse: bool) -> Result<reqwest::Response, Error>
278 where F: Fn() -> reqwest::RequestBuilder
279 {
280 let mut attempt = 0u32;
281 loop {
282
283 let req = mk();
284 let res = req.send().await;
285 return match res {
286 Ok(resp) => {
287 let status = resp.status();
288 if status.is_success() {
289 return Ok(resp);
290 }
291 if self.should_retry_status(status) && attempt < self.max_retries {
292 let delay = self.retry_delay(attempt, resp.headers().get(header::RETRY_AFTER));
293 attempt += 1;
294 tokio::time::sleep(delay).await;
295 continue;
296 }
297 Ok(resp)
298 }
299 Err(e) => {
300 println!("Request error: {}", e);
301 if self.is_retryable_error(&e) && attempt < self.max_retries {
302 let delay = self.retry_delay(attempt, None);
303 attempt += 1;
304 tokio::time::sleep(delay).await;
305 continue;
306 }
307 Err(Error::Http(e))
308 }
309 }
310 }
311 }
312
313 fn should_retry_status(&self, status: StatusCode) -> bool {
314 status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
315 }
316
317 fn retry_delay(&self, attempt: u32, retry_after: Option<&header::HeaderValue>) -> std::time::Duration {
318 if let Some(v) = retry_after {
319 if let Ok(s) = v.to_str() {
320 if let Ok(secs) = s.parse::<u64>() {
321 return Duration::from_secs(secs);
322 }
323 }
324 }
325 let base = self.retry_base_delay_ms;
326 let backoff = base.saturating_mul(1u64 << attempt.min(8));
327 Duration::from_millis(backoff)
328 }
329
330 fn is_retryable_error(&self, e: &reqwest::Error) -> bool {
331 e.is_timeout() || e.is_connect() || e.is_request()
332 }
333}
334
335#[derive(Default)]
336pub struct OpenAIBuilder {
337 api_key: Option<String>,
338 base_url: Option<String>,
339 org: Option<String>,
340 project: Option<String>,
341 timeout: Option<Duration>,
342 user_agent: Option<String>,
343 max_retries: Option<u32>,
344 retry_base_delay_ms: Option<u64>,
345 http: Option<HttpClient>,
346 proxy: Option<String>,
347}
348
349impl OpenAIBuilder {
350 pub fn api_key(mut self, key: String) -> Self { self.api_key = Some(key); self }
351 pub fn base_url<S: Into<String>>(mut self, url: S) -> Self { self.base_url = Some(url.into()); self }
352 pub fn org<S: Into<String>>(mut self, org: S) -> Self { self.org = Some(org.into()); self }
353 pub fn project<S: Into<String>>(mut self, project: S) -> Self { self.project = Some(project.into()); self }
354 pub fn timeout(mut self, timeout: Duration) -> Self { self.timeout = Some(timeout); self }
355 pub fn user_agent<S: Into<String>>(mut self, ua: S) -> Self { self.user_agent = Some(ua.into()); self }
356 pub fn max_retries(mut self, n: u32) -> Self { self.max_retries = Some(n); self }
357 pub fn retry_base_delay(mut self, dur: Duration) -> Self { self.retry_base_delay_ms = Some(dur.as_millis() as u64); self }
358 pub fn http_client(mut self, client: HttpClient) -> Self { self.http = Some(client); self }
359 pub fn proxy<S: Into<String>>(mut self, url: S) -> Self { self.proxy = Some(url.into()); self }
360
361 pub fn build(self) -> Result<OpenAI, Error> {
362 let api_key = self.api_key.ok_or(Error::MissingApiKey)?;
363 let base_url_str = self.base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
364 let base_url = Url::parse(&base_url_str)?;
365
366 let http = if let Some(custom) = self.http {
367 custom
368 } else {
369 let mut headers = header::HeaderMap::new();
370 headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
371
372 let mut http = HttpClient::builder()
373 .default_headers(headers)
374 .gzip(true)
375 .brotli(true);
376
377 if let Some(t) = self.timeout { http = http.timeout(t); }
378 if let Some(ua) = self.user_agent {
379 http = http.user_agent(ua);
380 } else {
381 http = http.user_agent(format!("openai-sdk-rs/{} (+https://crates.io/crates/openai-sdk)", env!("CARGO_PKG_VERSION")));
382 }
383
384 if let Some(px) = self.proxy {
385 if let Ok(proxy) = reqwest::Proxy::all(px) {
386 http = http.proxy(proxy);
387 }
388 }
389
390 http.build()?
391 };
392
393 Ok(OpenAI {
394 http,
395 base_url,
396 api_key,
397 org: self.org,
398 project: self.project,
399 max_retries: self.max_retries.unwrap_or(3),
400 retry_base_delay_ms: self.retry_base_delay_ms.unwrap_or(200),
401 })
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::OpenAI;

    /// Only `data` payloads survive. Comments, other fields and the `[DONE]` sentinel are dropped.
    #[test]
    fn sse_extracts_data_lines() {
        let input = concat!(
            "event: message\n",
            ":data line as comment\n",
            "data: {\"a\":1}\n",
            "\n",
            "retry: 5000\n",
            "data: [DONE]\n",
        );
        let expected: Vec<String> = vec![String::from("{\"a\":1}")];
        assert_eq!(OpenAI::sse_extract_data_lines(input), expected);
    }
}