chat_gpt_lib_rs/
api.rs

1//! The `api` module contains low-level functions for making HTTP requests to the OpenAI API.
//! It handles authentication headers, organization headers, error parsing, JSON (de)serialization,
//! and line-by-line parsing of streaming (Server-Sent Events) responses.
3//!
4//! # Usage
5//!
6//! This module is not typically used directly. Instead, higher-level modules (e.g., for
7//! Completions, Chat, Embeddings, etc.) will call these functions to perform network requests.
8
9use crate::config::OpenAIClient;
10use crate::error::OpenAIError;
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13
14// Import for streaming support:
15use futures_util::stream::TryStreamExt;
16use tokio::io::{AsyncBufReadExt, BufReader};
17use tokio_stream::Stream; // Trait for streams.
18use tokio_stream::StreamExt as TokioStreamExt; // Needed for filter_map.
19use tokio_stream::wrappers::LinesStream;
20use tokio_util::io::StreamReader;
21
22/// Sends a POST request with a JSON body to the given `endpoint`.
23///
24/// # Parameters
25///
26/// - `client`: The [`OpenAIClient`](crate::config::OpenAIClient) holding base URL, API key, and a configured `reqwest::Client`.
27/// - `endpoint`: The relative path (e.g. `"completions"`) appended to the base URL.
28/// - `body`: A serializable request body (e.g. your request struct).
29///
30/// # Returns
31///
32/// A `Result` containing the response deserialized into type `R` on success, or an [`OpenAIError`]
33/// on failure (e.g. network, JSON parse, or API error).
34///
35/// # Errors
36///
37/// - [`OpenAIError::HTTPError`]: If the network request fails (e.g. timeout, DNS error).
38/// - [`OpenAIError::DeserializeError`]: If the response JSON can’t be parsed into `R`.
39/// - [`OpenAIError::APIError`]: If the OpenAI API indicates an error in the response body (e.g. invalid request).
40pub(crate) async fn post_json<T, R>(
41    client: &OpenAIClient,
42    endpoint: &str,
43    body: &T,
44) -> Result<R, OpenAIError>
45where
46    T: Serialize,
47    R: DeserializeOwned,
48{
49    let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
50    let mut request_builder = client.http_client.post(&url).bearer_auth(client.api_key());
51
52    // If an organization ID is configured, include that in the request headers.
53    if let Some(org_id) = client.organization() {
54        request_builder = request_builder.header("OpenAI-Organization", org_id);
55    }
56
57    let response = request_builder.json(body).send().await?;
58
59    handle_response(response).await
60}
61
62/// Sends a GET request to the given `endpoint`.
63///
64/// # Parameters
65///
66/// - `client`: The [`OpenAIClient`](crate::config::OpenAIClient) holding base URL, API key, and a configured `reqwest::Client`.
67/// - `endpoint`: The relative path (e.g. `"models"`) appended to the base URL.
68///
69/// # Returns
70///
71/// A `Result` containing the response deserialized into type `R` on success, or an [`OpenAIError`]
72/// on failure (e.g. network, JSON parse, or API error).
73///
74/// # Errors
75///
76/// - [`OpenAIError::HTTPError`]: If the network request fails (e.g. timeout, DNS error).
77/// - [`OpenAIError::DeserializeError`]: If the response JSON can’t be parsed into `R`.
78/// - [`OpenAIError::APIError`]: If the OpenAI API indicates an error in the response body (e.g. invalid request).
79pub(crate) async fn get_json<R>(client: &OpenAIClient, endpoint: &str) -> Result<R, OpenAIError>
80where
81    R: DeserializeOwned,
82{
83    let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
84    let mut request_builder = client.http_client.get(&url).bearer_auth(client.api_key());
85
86    // If an organization ID is configured, include that in the request headers.
87    if let Some(org_id) = client.organization() {
88        request_builder = request_builder.header("OpenAI-Organization", org_id);
89    }
90
91    let response = request_builder.send().await?;
92
93    handle_response(response).await
94}
95
96/// Parses the `reqwest::Response` from the OpenAI API, returning a successful `R` or an
97/// [`OpenAIError`].
98///
99/// # Parameters
100///
101/// - `response`: The raw HTTP response from `reqwest`.
102///
103/// # Returns
104///
105/// * `Ok(R)` if the response is `2xx` and can be deserialized into `R`.
106/// * `Err(OpenAIError::APIError)` if the response has a non-success status code and includes
107///   an OpenAI error message.
108/// * `Err(OpenAIError::DeserializeError)` if the JSON could not be deserialized into `R`.
109async fn handle_response<R>(response: reqwest::Response) -> Result<R, OpenAIError>
110where
111    R: DeserializeOwned,
112{
113    let status = response.status();
114    if status.is_success() {
115        // 1) Read raw text from the response
116        let text = response.text().await?;
117
118        // 2) Attempt to parse with serde_json. If it fails, map to `OpenAIError::DeserializeError`
119        let parsed: R = serde_json::from_str(&text).map_err(OpenAIError::from)?;
120
121        Ok(parsed)
122    } else {
123        parse_error_response(response).await
124    }
125}
126
127/// Attempts to parse the OpenAI error body. If successful, returns `Err(OpenAIError::APIError)`.
128/// Otherwise, returns a generic error based on the HTTP status code or raw text.
129pub async fn parse_error_response<R>(response: reqwest::Response) -> Result<R, OpenAIError> {
130    let status = response.status();
131    let text_body = response.text().await.unwrap_or_else(|_| "".to_string());
132
133    match serde_json::from_str::<crate::error::OpenAIAPIErrorBody>(&text_body) {
134        Ok(body) => Err(OpenAIError::from(body)),
135        Err(_) => {
136            let msg = format!(
137                "HTTP {} returned from OpenAI API; body: {}",
138                status, text_body
139            );
140            Err(OpenAIError::APIError {
141                message: msg,
142                err_type: None,
143                code: None,
144            })
145        }
146    }
147}
148
149/// Sends a POST request with a JSON body to the given `endpoint` and returns a stream of responses.
150/// This is designed for endpoints that support streaming responses (e.g., Chat Completions with `stream = true`).
151///
152/// # Parameters
153///
154/// - `client`: The [`OpenAIClient`](crate::config::OpenAIClient) holding base URL, API key, etc.
155/// - `endpoint`: The relative endpoint (e.g., `"chat/completions"`) appended to the base URL.
156/// - `body`: A serializable request body.
157///
158/// # Returns
159///
160/// A stream of deserialized items of type `R`. Each item represents a partial response from the server.
161///
162/// # Errors
163///
164/// Returns an [`OpenAIError`] if the initial request fails or if the HTTP response indicates an error.
165///
166/// # Dependencies
167///
168/// This function uses the latest versions of `tokio-stream` and `tokio-util`.
169pub async fn post_json_stream<T, R>(
170    client: &OpenAIClient,
171    endpoint: &str,
172    body: &T,
173) -> Result<impl Stream<Item = Result<R, OpenAIError>>, OpenAIError>
174where
175    T: Serialize,
176    R: DeserializeOwned + 'static,
177{
178    let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
179    let mut request_builder = client.http_client.post(&url).bearer_auth(client.api_key());
180
181    if let Some(org_id) = client.organization() {
182        request_builder = request_builder.header("OpenAI-Organization", org_id);
183    }
184
185    let response = request_builder.json(body).send().await?;
186
187    let status = response.status();
188    if !status.is_success() {
189        let text_body = response.text().await.unwrap_or_else(|_| "".to_string());
190        match serde_json::from_str::<crate::error::OpenAIAPIErrorBody>(&text_body) {
191            Ok(body_err) => return Err(OpenAIError::from(body_err)),
192            Err(_) => {
193                return Err(OpenAIError::APIError {
194                    message: format!(
195                        "HTTP {} returned from OpenAI API; body: {}",
196                        status, text_body
197                    ),
198                    err_type: None,
199                    code: None,
200                });
201            }
202        }
203    }
204
205    // Convert the response's byte stream into an async reader.
206    let byte_stream = response
207        .bytes_stream()
208        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
209    let stream_reader = StreamReader::new(byte_stream);
210    let buf_reader = BufReader::new(stream_reader);
211
212    // Create a stream of lines from the buffered reader.
213    let lines = LinesStream::new(buf_reader.lines());
214
215    // Process each line synchronously:
216    //   - Ignore empty lines or those that contain "[DONE]".
217    //   - Remove the "data:" prefix if present.
218    //   - Attempt to deserialize the remaining JSON into type `R`.
219    let stream = lines.filter_map(|line_result| {
220        match line_result {
221            Ok(line) => {
222                let trimmed = line.trim();
223                // Skip empty lines or termination markers.
224                if trimmed.is_empty() || trimmed.contains("[DONE]") {
225                    None
226                } else {
227                    // Remove the "data:" prefix if it exists.
228                    let data = if trimmed.starts_with("data:") {
229                        trimmed.trim_start_matches("data:").trim()
230                    } else {
231                        trimmed
232                    };
233                    // Attempt to deserialize the JSON.
234                    match serde_json::from_str::<R>(data) {
235                        Ok(parsed) => Some(Ok(parsed)),
236                        Err(e) => {
237                            eprintln!(
238                                "Warning: failed to deserialize chunk: {:?} (error: {})",
239                                data, e
240                            );
241                            None // Skip this chunk on deserialization error.
242                        }
243                    }
244                }
245            }
246            Err(e) => Some(Err(OpenAIError::from(e))),
247        }
248    });
249    Ok(stream)
250}
251
#[cfg(test)]
mod tests {
    //! # Tests for the `api` module
    //!
    //! These tests use [`wiremock`](https://crates.io/crates/wiremock) to **mock** HTTP responses
    //! from the OpenAI API. This lets us verify request building, JSON handling and error parsing
    //! without hitting real servers.

    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::{OpenAIError, OpenAIError::APIError};
    use serde::Deserialize;
    use tokio_stream::StreamExt; // Provides `map`, `collect`, etc. for the streaming tests.
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal response shape decoded by the tests below.
    #[derive(Debug, Deserialize)]
    struct MockResponse {
        foo: String,
        bar: i32,
    }

    /// Builds a client that targets `server` and authenticates with a dummy API key.
    fn client_for(server: &MockServer) -> OpenAIClient {
        OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&server.uri())
            .build()
            .expect("test client should build")
    }

    /// Tests that `post_json` correctly sends a JSON POST request and parses a successful JSON response.
    #[tokio::test]
    async fn test_post_json_success() {
        let mock_server = MockServer::start().await;
        let mock_data = serde_json::json!({ "foo": "hello", "bar": 42 });

        Mock::given(method("POST"))
            .and(path("/test-endpoint"))
            .respond_with(ResponseTemplate::new(200).set_body_json(mock_data))
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let request_body = serde_json::json!({ "dummy": true });

        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "test-endpoint", &request_body).await;

        let parsed = result.expect("Expected Ok, got Err");
        assert_eq!(parsed.foo, "hello");
        assert_eq!(parsed.bar, 42);
    }

    /// Tests that `post_json` handles non-2xx status codes and returns an `APIError`.
    #[tokio::test]
    async fn test_post_json_api_error() {
        let mock_server = MockServer::start().await;

        // The server returns a 400 with an OpenAI-style JSON error body.
        let error_body = serde_json::json!({
            "error": {
                "message": "Invalid request",
                "type": "invalid_request_error",
                "param": null,
                "code": "some_code"
            }
        });

        Mock::given(method("POST"))
            .and(path("/test-endpoint"))
            .respond_with(ResponseTemplate::new(400).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let request_body = serde_json::json!({ "dummy": true });

        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "test-endpoint", &request_body).await;

        match result {
            Err(APIError { message, .. }) => {
                assert!(
                    message.contains("Invalid request"),
                    "Expected error message about invalid request, got: {}",
                    message
                );
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    /// Tests that `post_json` surfaces a deserialization error if the JSON does not match `R`.
    #[tokio::test]
    async fn test_post_json_deserialize_error() {
        let mock_server = MockServer::start().await;

        // Valid JSON whose field types do not match `MockResponse`.
        let mismatched_json = r#"{"foo": 123, "bar": "not_an_integer"}"#;

        Mock::given(method("POST"))
            .and(path("/test-endpoint"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(mismatched_json, "application/json"),
            )
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let request_body = serde_json::json!({ "dummy": true });

        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "test-endpoint", &request_body).await;

        assert!(matches!(result, Err(OpenAIError::DeserializeError(_))));
    }

    /// Tests that `get_json` properly sends a GET request and parses a successful JSON response.
    #[tokio::test]
    async fn test_get_json_success() {
        let mock_server = MockServer::start().await;
        let mock_data = serde_json::json!({ "foo": "abc", "bar": 99 });

        Mock::given(method("GET"))
            .and(path("/test-get"))
            .respond_with(ResponseTemplate::new(200).set_body_json(mock_data))
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);

        let result: Result<MockResponse, OpenAIError> = get_json(&client, "test-get").await;

        let parsed = result.expect("Expected Ok, got Err");
        assert_eq!(parsed.foo, "abc");
        assert_eq!(parsed.bar, 99);
    }

    /// Tests that `get_json` handles a non-successful status code with an error body.
    #[tokio::test]
    async fn test_get_json_api_error() {
        let mock_server = MockServer::start().await;

        let error_body = serde_json::json!({
            "error": {
                "message": "Resource not found",
                "type": "not_found",
                "code": "missing_resource"
            }
        });

        Mock::given(method("GET"))
            .and(path("/test-get"))
            .respond_with(ResponseTemplate::new(404).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);

        let result: Result<MockResponse, OpenAIError> = get_json(&client, "test-get").await;

        match result {
            Err(APIError { message, .. }) => {
                assert!(message.contains("Resource not found"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    /// Tests that `post_json_stream` parses a Server-Sent Events body: blank separator lines,
    /// `data:` prefixes and the `[DONE]` marker.
    #[tokio::test]
    async fn test_post_json_stream_success() {
        let mock_server = MockServer::start().await;

        // Events are separated by blank lines, as in a real SSE stream.
        let body = "data: {\"foo\":\"first\",\"bar\":1}\n\n\
                    data: {\"foo\":\"second\",\"bar\":2}\n\n\
                    data: [DONE]\n\n";

        Mock::given(method("POST"))
            .and(path("/stream-endpoint"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "text/event-stream"))
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);

        // Any serialisable body works for this test.
        let req_body = serde_json::json!({ "stream": true });

        let stream = post_json_stream::<serde_json::Value, MockResponse>(
            &client,
            "stream-endpoint",
            &req_body,
        )
        .await
        .expect("stream should start OK");

        // Collect all chunks, unwrapping the individual `Result`s.
        let items: Vec<MockResponse> = stream
            .map(|r| r.expect("chunk should be Ok"))
            .collect()
            .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0].foo, "first");
        assert_eq!(items[0].bar, 1);
        assert_eq!(items[1].foo, "second");
        assert_eq!(items[1].bar, 2);
    }

    /// Tests that `post_json_stream` turns an HTTP error with a JSON body into an `APIError`.
    #[tokio::test]
    async fn test_post_json_stream_api_error() {
        let mock_server = MockServer::start().await;

        let error_body = serde_json::json!({
            "error": {
                "message": "Streaming not allowed",
                "type": "invalid_request_error",
                "code": "not_allowed"
            }
        });

        Mock::given(method("POST"))
            .and(path("/stream-endpoint"))
            .respond_with(ResponseTemplate::new(403).set_body_json(error_body))
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let req_body = serde_json::json!({ "stream": true });

        let result = post_json_stream::<serde_json::Value, MockResponse>(
            &client,
            "stream-endpoint",
            &req_body,
        )
        .await;

        match result {
            Err(APIError { message, .. }) => {
                assert!(message.contains("Streaming not allowed"));
            }
            Ok(_) => panic!("Expected APIError, got Ok(_)"),
            Err(other) => panic!("Expected APIError, got {other:?}"),
        }
    }

    /// Tests that `post_json_stream` falls back to the status and raw body when a failed response
    /// is not an OpenAI error object.
    #[tokio::test]
    async fn test_post_json_stream_plain_text_error_fallback() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/stream-endpoint"))
            .respond_with(
                ResponseTemplate::new(503).set_body_raw("Service unavailable", "text/plain"),
            )
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let req_body = serde_json::json!({ "stream": true });

        let result = post_json_stream::<serde_json::Value, MockResponse>(
            &client,
            "stream-endpoint",
            &req_body,
        )
        .await;

        match result {
            Err(APIError { message, .. }) => {
                assert!(message.contains("503"), "Unexpected message: {message}");
                assert!(
                    message.contains("Service unavailable"),
                    "Unexpected message: {message}"
                );
            }
            Ok(_) => panic!("Expected APIError, got Ok(_)"),
            Err(other) => panic!("Expected APIError, got {other:?}"),
        }
    }

    /// Tests the `parse_error_response` fallback branch (via `post_json`) with a plain-text body.
    #[tokio::test]
    async fn test_post_json_plain_text_error_fallback() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/plain-error"))
            .respond_with(
                ResponseTemplate::new(503).set_body_raw("Service unavailable", "text/plain"),
            )
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let req_body = serde_json::json!({});

        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "plain-error", &req_body).await;

        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(
                    message.contains("Service unavailable"),
                    "Unexpected message: {message}"
                );
            }
            other => panic!("Expected generic APIError, got {:?}", other),
        }
    }
}