// openai_oxide/client.rs
1// OpenAI client
2
3use std::time::Duration;
4
5use crate::config::ClientConfig;
6use crate::error::{ErrorResponse, OpenAIError};
7use crate::resources::audio::Audio;
8use crate::resources::batches::Batches;
9use crate::resources::beta::assistants::Assistants;
10use crate::resources::beta::realtime::Realtime;
11use crate::resources::beta::runs::Runs;
12use crate::resources::beta::threads::Threads;
13use crate::resources::beta::vector_stores::VectorStores;
14use crate::resources::chat::Chat;
15use crate::resources::embeddings::Embeddings;
16use crate::resources::files::Files;
17use crate::resources::fine_tuning::FineTuning;
18use crate::resources::images::Images;
19use crate::resources::models::Models;
20use crate::resources::moderations::Moderations;
21use crate::resources::responses::Responses;
22use crate::resources::uploads::Uploads;
23
/// HTTP status codes treated as transient and retried:
/// 429 (rate limited), 500/502/503 (server errors), 504 (gateway timeout).
const RETRYABLE_STATUS_CODES: [u16; 5] = [429, 500, 502, 503, 504];
26
/// The main OpenAI client.
#[derive(Debug, Clone)]
pub struct OpenAI {
    // Underlying HTTP client; carries the configured request timeout.
    pub(crate) http: reqwest::Client,
    // Connection/auth settings (api_key, base_url, organization, project,
    // timeout_secs, max_retries) — read by the request/retry machinery below.
    pub(crate) config: ClientConfig,
}
33
34impl OpenAI {
35    /// Create a new client with the given API key.
36    pub fn new(api_key: impl Into<String>) -> Self {
37        Self::with_config(ClientConfig::new(api_key))
38    }
39
40    /// Create a client from a full config.
41    pub fn with_config(config: ClientConfig) -> Self {
42        let http = reqwest::Client::builder()
43            .timeout(Duration::from_secs(config.timeout_secs))
44            .build()
45            .expect("failed to build HTTP client");
46        Self { http, config }
47    }
48
49    /// Create a client using the `OPENAI_API_KEY` environment variable.
50    pub fn from_env() -> Result<Self, OpenAIError> {
51        Ok(Self::with_config(ClientConfig::from_env()?))
52    }
53
54    /// Access the Batches resource.
55    pub fn batches(&self) -> Batches<'_> {
56        Batches::new(self)
57    }
58
59    /// Access the Uploads resource.
60    pub fn uploads(&self) -> Uploads<'_> {
61        Uploads::new(self)
62    }
63
64    /// Access the Beta resources (Assistants, Threads, Runs, Vector Stores).
65    pub fn beta(&self) -> Beta<'_> {
66        Beta { client: self }
67    }
68
69    /// Access the Audio resource.
70    pub fn audio(&self) -> Audio<'_> {
71        Audio::new(self)
72    }
73
74    /// Access the Chat resource.
75    pub fn chat(&self) -> Chat<'_> {
76        Chat::new(self)
77    }
78
79    /// Access the Models resource.
80    pub fn models(&self) -> Models<'_> {
81        Models::new(self)
82    }
83
84    /// Access the Fine-tuning resource.
85    pub fn fine_tuning(&self) -> FineTuning<'_> {
86        FineTuning::new(self)
87    }
88
89    /// Access the Files resource.
90    pub fn files(&self) -> Files<'_> {
91        Files::new(self)
92    }
93
94    /// Access the Images resource.
95    pub fn images(&self) -> Images<'_> {
96        Images::new(self)
97    }
98
99    /// Access the Moderations resource.
100    pub fn moderations(&self) -> Moderations<'_> {
101        Moderations::new(self)
102    }
103
104    /// Access the Responses resource.
105    pub fn responses(&self) -> Responses<'_> {
106        Responses::new(self)
107    }
108
109    /// Access the Embeddings resource.
110    pub fn embeddings(&self) -> Embeddings<'_> {
111        Embeddings::new(self)
112    }
113
114    /// Build a request with auth headers.
115    pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
116        let url = format!("{}{}", self.config.base_url, path);
117        let mut req = self
118            .http
119            .request(method, &url)
120            .bearer_auth(&self.config.api_key);
121
122        if let Some(ref org) = self.config.organization {
123            req = req.header("OpenAI-Organization", org);
124        }
125        if let Some(ref project) = self.config.project {
126            req = req.header("OpenAI-Project", project);
127        }
128
129        req
130    }
131
132    /// Send a GET request and deserialize the response.
133    #[allow(dead_code)]
134    pub(crate) async fn get<T: serde::de::DeserializeOwned>(
135        &self,
136        path: &str,
137    ) -> Result<T, OpenAIError> {
138        self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
139            .await
140    }
141
142    /// Send a POST request with a JSON body and deserialize the response.
143    pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
144        &self,
145        path: &str,
146        body: &B,
147    ) -> Result<T, OpenAIError> {
148        self.send_with_retry(reqwest::Method::POST, path, Some(body))
149            .await
150    }
151
152    /// Send a POST request with a multipart form body and deserialize the response.
153    pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
154        &self,
155        path: &str,
156        form: reqwest::multipart::Form,
157    ) -> Result<T, OpenAIError> {
158        let response = self
159            .request(reqwest::Method::POST, path)
160            .multipart(form)
161            .send()
162            .await?;
163        Self::handle_response(response).await
164    }
165
166    /// Send a GET request and return raw bytes.
167    pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
168        let response = self.request(reqwest::Method::GET, path).send().await?;
169
170        let status = response.status();
171        if status.is_success() {
172            Ok(response.bytes().await?)
173        } else {
174            Err(Self::extract_error(status.as_u16(), response).await)
175        }
176    }
177
178    /// Send a POST request with JSON body and return raw bytes (for binary responses like audio).
179    pub(crate) async fn post_raw<B: serde::Serialize>(
180        &self,
181        path: &str,
182        body: &B,
183    ) -> Result<bytes::Bytes, OpenAIError> {
184        let response = self
185            .request(reqwest::Method::POST, path)
186            .json(body)
187            .send()
188            .await?;
189
190        let status = response.status();
191        if status.is_success() {
192            Ok(response.bytes().await?)
193        } else {
194            Err(Self::extract_error(status.as_u16(), response).await)
195        }
196    }
197
198    /// Send a DELETE request and deserialize the response.
199    #[allow(dead_code)]
200    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
201        &self,
202        path: &str,
203    ) -> Result<T, OpenAIError> {
204        self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
205            .await
206    }
207
208    /// Send a request with retry logic for transient errors.
209    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
210        &self,
211        method: reqwest::Method,
212        path: &str,
213        body: Option<&B>,
214    ) -> Result<T, OpenAIError> {
215        let max_retries = self.config.max_retries;
216        let mut last_error: Option<OpenAIError> = None;
217
218        for attempt in 0..=max_retries {
219            let mut req = self.request(method.clone(), path);
220            if let Some(b) = body {
221                req = req.json(b);
222            }
223
224            let response = match req.send().await {
225                Ok(resp) => resp,
226                Err(e) => {
227                    last_error = Some(OpenAIError::RequestError(e));
228                    if attempt < max_retries {
229                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
230                        continue;
231                    }
232                    break;
233                }
234            };
235
236            let status = response.status().as_u16();
237
238            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
239                return Self::handle_response(response).await;
240            }
241
242            // Retryable status — parse Retry-After and sleep
243            let retry_after = response
244                .headers()
245                .get("retry-after")
246                .and_then(|v| v.to_str().ok())
247                .and_then(|v| v.parse::<f64>().ok());
248
249            last_error = Some(Self::extract_error(status, response).await);
250            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
251        }
252
253        Err(last_error.unwrap_or_else(|| {
254            OpenAIError::InvalidArgument("retry loop exhausted without error".to_string())
255        }))
256    }
257
258    /// Calculate backoff delay: max(retry_after, 0.5 * 2^attempt) seconds.
259    fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
260        let exponential = 0.5 * 2.0_f64.powi(attempt as i32);
261        let secs = match retry_after_secs {
262            Some(ra) => ra.max(exponential),
263            None => exponential,
264        };
265        Duration::from_secs_f64(secs.min(60.0))
266    }
267
268    /// Handle API response: check status, parse errors or deserialize body.
269    pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
270        response: reqwest::Response,
271    ) -> Result<T, OpenAIError> {
272        let status = response.status();
273        if status.is_success() {
274            let body = response.text().await?;
275            let value: T = serde_json::from_str(&body)?;
276            Ok(value)
277        } else {
278            Err(Self::extract_error(status.as_u16(), response).await)
279        }
280    }
281
282    /// Extract an OpenAIError from a failed response.
283    async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
284        let body = response.text().await.unwrap_or_default();
285        if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
286            OpenAIError::ApiError {
287                status,
288                message: error_resp.error.message,
289                type_: error_resp.error.type_,
290                code: error_resp.error.code,
291            }
292        } else {
293            OpenAIError::ApiError {
294                status,
295                message: body,
296                type_: None,
297                code: None,
298            }
299        }
300    }
301}
302
/// Access beta endpoints (Assistants v2, Threads, Runs, Vector Stores).
pub struct Beta<'a> {
    // Borrowed handle to the parent client; every beta resource delegates
    // its HTTP work to this client.
    client: &'a OpenAI,
}
307
308impl<'a> Beta<'a> {
309    /// Access the Assistants resource.
310    pub fn assistants(&self) -> Assistants<'_> {
311        Assistants::new(self.client)
312    }
313
314    /// Access the Threads resource.
315    pub fn threads(&self) -> Threads<'_> {
316        Threads::new(self.client)
317    }
318
319    /// Access runs for a specific thread.
320    pub fn runs(&self, thread_id: &str) -> Runs<'_> {
321        Runs::new(self.client, thread_id.to_string())
322    }
323
324    /// Access the Vector Stores resource.
325    pub fn vector_stores(&self) -> VectorStores<'_> {
326        VectorStores::new(self.client)
327    }
328
329    /// Access the Realtime resource.
330    pub fn realtime(&self) -> Realtime<'_> {
331        Realtime::new(self.client)
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor stores the key and applies the library-default base URL.
    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key, "sk-test-key");
        assert_eq!(client.config.base_url, "https://api.openai.com/v1");
    }

    // Builder-style config fields are carried through into the client.
    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url, "https://custom.api.com");
        assert_eq!(client.config.organization.as_deref(), Some("org-123"));
        assert_eq!(client.config.timeout_secs, 30);
    }

    // Pins the backoff schedule: 0.5 * 2^attempt, Retry-After taken when
    // larger, hard cap at 60 seconds.
    #[test]
    fn test_backoff_delay() {
        // Attempt 0: 0.5s
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        // Attempt 1: 1.0s
        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        // Attempt 2: 2.0s
        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        // Retry-After takes precedence when larger
        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        // Exponential wins when larger than Retry-After
        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        // Capped at 60s
        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }

    // Happy-path GET: 200 + JSON body deserializes into the target type.
    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }

    // Happy-path POST: verifies auth + content-type headers are sent and the
    // JSON body round-trips.
    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }

    // Happy-path DELETE with a JSON confirmation body.
    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }

    // The standard {"error": {...}} envelope is parsed into ApiError fields.
    #[tokio::test]
    async fn test_api_error_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/nonexistent")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(Debug, serde::Deserialize)]
        struct Model {
            _id: String,
        }

        let err = client
            .get::<Model>("/models/nonexistent")
            .await
            .unwrap_err();
        match err {
            OpenAIError::ApiError {
                status,
                message,
                type_,
                code,
            } => {
                assert_eq!(status, 404);
                assert!(message.contains("does not exist"));
                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
                assert_eq!(code.as_deref(), Some("model_not_found"));
            }
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // Organization/Project headers are attached when configured.
    #[tokio::test]
    async fn test_auth_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("authorization", "Bearer sk-key")
            .match_header("OpenAI-Organization", "org-abc")
            .match_header("OpenAI-Project", "proj-xyz")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-key")
                .base_url(server.url())
                .organization("org-abc")
                .project("proj-xyz"),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // A retryable 429 followed by a 200 should succeed transparently.
    // NOTE(review): this relies on mockito serving the 429 mock before
    // mock_200 for the first request — confirm mockito's match order, since
    // the test would also pass if the 200 mock matched immediately.
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        // First request returns 429, second returns 200
        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }

    // When every attempt fails with 500, the final error is surfaced and at
    // least initial + one retry were made (max_retries = 1).
    #[tokio::test]
    async fn test_retry_exhausted_returns_last_error() {
        let mut server = mockito::Server::new_async().await;

        // All requests return 500
        let _mock = server
            .mock("GET", "/test")
            .with_status(500)
            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
            .expect_at_least(2)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(1),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    // Client errors (400) are not retryable: exactly one request expected.
    #[tokio::test]
    async fn test_no_retry_on_400() {
        let mut server = mockito::Server::new_async().await;

        // 400 should not be retried
        let mock = server
            .mock("GET", "/test")
            .with_status(400)
            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // max_retries(0) means a single attempt even for a retryable 429.
    #[tokio::test]
    async fn test_zero_retries_no_retry() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
}