// openai_oxide/client.rs

// OpenAI client

use std::time::Duration;

use crate::config::ClientConfig;
use crate::error::{ErrorResponse, OpenAIError};
use crate::resources::audio::Audio;
use crate::resources::batches::Batches;
use crate::resources::beta::assistants::Assistants;
use crate::resources::beta::runs::Runs;
use crate::resources::beta::threads::Threads;
use crate::resources::beta::vector_stores::VectorStores;
use crate::resources::chat::Chat;
use crate::resources::embeddings::Embeddings;
use crate::resources::files::Files;
use crate::resources::fine_tuning::FineTuning;
use crate::resources::images::Images;
use crate::resources::models::Models;
use crate::resources::moderations::Moderations;
use crate::resources::responses::Responses;
use crate::resources::uploads::Uploads;

/// HTTP status codes that trigger a retry: 429 (rate limited) and the
/// transient server/gateway errors 500, 502, 503, and 504.
///
/// 504 Gateway Timeout is included alongside 502/503: it signals the same
/// kind of transient upstream failure and is retried by the official
/// OpenAI SDKs.
const RETRYABLE_STATUS_CODES: [u16; 5] = [429, 500, 502, 503, 504];

/// The main OpenAI client.
///
/// Construct with [`OpenAI::new`], [`OpenAI::with_config`], or
/// [`OpenAI::from_env`], then access API groups through the resource
/// methods (`chat()`, `models()`, `beta()`, ...).
#[derive(Debug, Clone)]
pub struct OpenAI {
    // Underlying HTTP client, built once with the configured timeout.
    pub(crate) http: reqwest::Client,
    // Holds api_key, base_url, optional organization/project headers,
    // timeout_secs, and max_retries.
    pub(crate) config: ClientConfig,
}

impl OpenAI {
    /// Create a new client with the given API key and default configuration.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_config(ClientConfig::new(api_key))
    }

    /// Create a client from a full config.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be built.
    pub fn with_config(config: ClientConfig) -> Self {
        let http = reqwest::Client::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build()
            .expect("failed to build HTTP client");
        Self { http, config }
    }

    /// Create a client using the `OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ClientConfig::from_env` (e.g. the
    /// variable is not set).
    pub fn from_env() -> Result<Self, OpenAIError> {
        Ok(Self::with_config(ClientConfig::from_env()?))
    }

    /// Access the Batches resource.
    pub fn batches(&self) -> Batches<'_> {
        Batches::new(self)
    }

    /// Access the Uploads resource.
    pub fn uploads(&self) -> Uploads<'_> {
        Uploads::new(self)
    }

    /// Access the Beta resources (Assistants, Threads, Runs, Vector Stores).
    pub fn beta(&self) -> Beta<'_> {
        Beta { client: self }
    }

    /// Access the Audio resource.
    pub fn audio(&self) -> Audio<'_> {
        Audio::new(self)
    }

    /// Access the Chat resource.
    pub fn chat(&self) -> Chat<'_> {
        Chat::new(self)
    }

    /// Access the Models resource.
    pub fn models(&self) -> Models<'_> {
        Models::new(self)
    }

    /// Access the Fine-tuning resource.
    pub fn fine_tuning(&self) -> FineTuning<'_> {
        FineTuning::new(self)
    }

    /// Access the Files resource.
    pub fn files(&self) -> Files<'_> {
        Files::new(self)
    }

    /// Access the Images resource.
    pub fn images(&self) -> Images<'_> {
        Images::new(self)
    }

    /// Access the Moderations resource.
    pub fn moderations(&self) -> Moderations<'_> {
        Moderations::new(self)
    }

    /// Access the Responses resource.
    pub fn responses(&self) -> Responses<'_> {
        Responses::new(self)
    }

    /// Access the Embeddings resource.
    pub fn embeddings(&self) -> Embeddings<'_> {
        Embeddings::new(self)
    }

    /// Build a request with auth headers.
    ///
    /// The URL is `config.base_url` concatenated directly with `path`, so
    /// `path` is expected to start with `/` (as the resource modules and
    /// tests do). Always sets the `Authorization: Bearer` header; adds
    /// `OpenAI-Organization` / `OpenAI-Project` only when configured.
    pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
        let url = format!("{}{}", self.config.base_url, path);
        let mut req = self
            .http
            .request(method, &url)
            .bearer_auth(&self.config.api_key);

        if let Some(ref org) = self.config.organization {
            req = req.header("OpenAI-Organization", org);
        }
        if let Some(ref project) = self.config.project {
            req = req.header("OpenAI-Project", project);
        }

        req
    }

    /// Send a GET request and deserialize the response, with retries.
    #[allow(dead_code)]
    pub(crate) async fn get<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        // `None::<&()>` supplies a concrete (unused) body type for B.
        self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
            .await
    }

    /// Send a POST request with a JSON body and deserialize the response,
    /// with retries.
    pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::POST, path, Some(body))
            .await
    }

    /// Send a POST request with a multipart form body and deserialize the
    /// response.
    ///
    /// No retry: `reqwest::multipart::Form` is consumed by the request and
    /// cannot be resent.
    pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        form: reqwest::multipart::Form,
    ) -> Result<T, OpenAIError> {
        let response = self
            .request(reqwest::Method::POST, path)
            .multipart(form)
            .send()
            .await?;
        Self::handle_response(response).await
    }

    /// Send a GET request and return raw bytes (no retry, no JSON parsing).
    pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
        let response = self.request(reqwest::Method::GET, path).send().await?;

        let status = response.status();
        if status.is_success() {
            Ok(response.bytes().await?)
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }

    /// Send a POST request with JSON body and return raw bytes (for binary
    /// responses like audio). No retry is performed.
    pub(crate) async fn post_raw<B: serde::Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<bytes::Bytes, OpenAIError> {
        let response = self
            .request(reqwest::Method::POST, path)
            .json(body)
            .send()
            .await?;

        let status = response.status();
        if status.is_success() {
            Ok(response.bytes().await?)
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }

    /// Send a DELETE request and deserialize the response, with retries.
    #[allow(dead_code)]
    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
            .await
    }

    /// Send a request with retry logic for transient errors.
    ///
    /// Attempt 0 is the initial request; up to `config.max_retries`
    /// additional attempts follow when either the send itself fails or the
    /// response status is in `RETRYABLE_STATUS_CODES`. Between attempts the
    /// task sleeps for `backoff_delay` (optionally stretched by a
    /// `Retry-After` header). Non-retryable statuses — and a retryable
    /// status on the final attempt — are resolved by `handle_response`.
    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<T, OpenAIError> {
        let max_retries = self.config.max_retries;
        let mut last_error: Option<OpenAIError> = None;

        for attempt in 0..=max_retries {
            let mut req = self.request(method.clone(), path);
            if let Some(b) = body {
                req = req.json(b);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    // Transport-level failure: remember it, retry if any
                    // attempts remain, otherwise fall through to return it.
                    last_error = Some(OpenAIError::RequestError(e));
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();

            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Self::handle_response(response).await;
            }

            // Retryable status — parse Retry-After and sleep.
            // NOTE: only numeric (seconds) Retry-After values are honored;
            // HTTP-date forms fail the f64 parse and fall back to pure
            // exponential backoff.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());

            // Header must be read before extract_error consumes the response.
            last_error = Some(Self::extract_error(status, response).await);
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        // Reachable only after the send-error `break`; the fallback message
        // guards an invariant that should not occur.
        Err(last_error.unwrap_or_else(|| {
            OpenAIError::InvalidArgument("retry loop exhausted without error".to_string())
        }))
    }

    /// Calculate backoff delay: max(retry_after, 0.5 * 2^attempt) seconds,
    /// capped at 60 seconds.
    fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
        let exponential = 0.5 * 2.0_f64.powi(attempt as i32);
        let secs = match retry_after_secs {
            Some(ra) => ra.max(exponential),
            None => exponential,
        };
        Duration::from_secs_f64(secs.min(60.0))
    }

    /// Handle API response: check status, parse errors or deserialize body.
    ///
    /// # Errors
    ///
    /// On 2xx, returns a JSON error if the body does not deserialize into
    /// `T`; otherwise returns the API error extracted from the body.
    pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> Result<T, OpenAIError> {
        let status = response.status();
        if status.is_success() {
            let body = response.text().await?;
            let value: T = serde_json::from_str(&body)?;
            Ok(value)
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }

    /// Extract an OpenAIError from a failed response.
    ///
    /// Tries to parse the body as a structured `ErrorResponse`; if that
    /// fails (or the body is unreadable), the raw body text becomes the
    /// error message with no type/code.
    async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
        let body = response.text().await.unwrap_or_default();
        if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
            OpenAIError::ApiError {
                status,
                message: error_resp.error.message,
                type_: error_resp.error.type_,
                code: error_resp.error.code,
            }
        } else {
            OpenAIError::ApiError {
                status,
                message: body,
                type_: None,
                code: None,
            }
        }
    }
}

/// Access beta endpoints (Assistants v2, Threads, Runs, Vector Stores).
///
/// Obtained via [`OpenAI::beta`]; borrows the parent client for `'a`.
pub struct Beta<'a> {
    // Parent client used to construct each beta resource.
    client: &'a OpenAI,
}

307impl<'a> Beta<'a> {
308    /// Access the Assistants resource.
309    pub fn assistants(&self) -> Assistants<'_> {
310        Assistants::new(self.client)
311    }
312
313    /// Access the Threads resource.
314    pub fn threads(&self) -> Threads<'_> {
315        Threads::new(self.client)
316    }
317
318    /// Access runs for a specific thread.
319    pub fn runs(&self, thread_id: &str) -> Runs<'_> {
320        Runs::new(self.client, thread_id.to_string())
321    }
322
323    /// Access the Vector Stores resource.
324    pub fn vector_stores(&self) -> VectorStores<'_> {
325        VectorStores::new(self.client)
326    }
327}
328
// Unit tests: construction/config plumbing, backoff math, and HTTP
// round-trips (success, error mapping, auth headers, retry behavior)
// against a local mockito server.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key, "sk-test-key");
        // Default base URL comes from ClientConfig::new.
        assert_eq!(client.config.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url, "https://custom.api.com");
        assert_eq!(client.config.organization.as_deref(), Some("org-123"));
        assert_eq!(client.config.timeout_secs, 30);
    }

    #[test]
    fn test_backoff_delay() {
        // Attempt 0: 0.5s
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        // Attempt 1: 1.0s
        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        // Attempt 2: 2.0s
        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        // Retry-After takes precedence when larger
        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        // Exponential wins when larger than Retry-After
        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        // Capped at 60s
        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        // Minimal local type; extra JSON fields are ignored by serde.
        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        // Also verifies the bearer auth and JSON content-type headers.
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_api_error_response() {
        let mut server = mockito::Server::new_async().await;
        // Structured error body should map onto OpenAIError::ApiError fields.
        let mock = server
            .mock("GET", "/models/nonexistent")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(Debug, serde::Deserialize)]
        struct Model {
            _id: String,
        }

        let err = client
            .get::<Model>("/models/nonexistent")
            .await
            .unwrap_err();
        match err {
            OpenAIError::ApiError {
                status,
                message,
                type_,
                code,
            } => {
                assert_eq!(status, 404);
                assert!(message.contains("does not exist"));
                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
                assert_eq!(code.as_deref(), Some("model_not_found"));
            }
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_auth_headers() {
        let mut server = mockito::Server::new_async().await;
        // Mock only matches when all three auth-related headers are present.
        let mock = server
            .mock("GET", "/test")
            .match_header("authorization", "Bearer sk-key")
            .match_header("OpenAI-Organization", "org-abc")
            .match_header("OpenAI-Project", "proj-xyz")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-key")
                .base_url(server.url())
                .organization("org-abc")
                .project("proj-xyz"),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        // First request returns 429, second returns 200
        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }

    #[tokio::test]
    async fn test_retry_exhausted_returns_last_error() {
        let mut server = mockito::Server::new_async().await;

        // All requests return 500
        let _mock = server
            .mock("GET", "/test")
            .with_status(500)
            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
            .expect_at_least(2)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(1),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_no_retry_on_400() {
        let mut server = mockito::Server::new_async().await;

        // 400 should not be retried
        let mock = server
            .mock("GET", "/test")
            .with_status(400)
            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_zero_retries_no_retry() {
        let mut server = mockito::Server::new_async().await;

        // max_retries(0) means exactly one request even on a retryable status.
        let mock = server
            .mock("GET", "/test")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
}