lib_client_openai/
client.rs

1//! OpenAI API client implementation.
2
3use crate::auth::AuthStrategy;
4use crate::error::{OpenAiError, Result};
5use crate::types::{
6    CreateChatCompletionRequest, CreateChatCompletionResponse, ErrorResponse, Model, ModelList,
7};
8use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
9use std::sync::Arc;
10
/// Base URL of the OpenAI REST API, used when the builder is given no override.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
12
13/// OpenAI API client.
14pub struct Client {
15    http: reqwest::Client,
16    auth: Arc<dyn AuthStrategy>,
17    base_url: String,
18}
19
20impl Client {
21    /// Create a new client builder.
22    pub fn builder() -> ClientBuilder<()> {
23        ClientBuilder::new()
24    }
25
26    /// Create a chat completion.
27    pub async fn create_chat_completion(
28        &self,
29        request: CreateChatCompletionRequest,
30    ) -> Result<CreateChatCompletionResponse> {
31        let url = format!("{}/chat/completions", self.base_url);
32        self.post(&url, &request).await
33    }
34
35    /// List available models.
36    pub async fn list_models(&self) -> Result<ModelList> {
37        let url = format!("{}/models", self.base_url);
38        self.get(&url).await
39    }
40
41    /// Get a specific model.
42    pub async fn get_model(&self, model_id: &str) -> Result<Model> {
43        let url = format!("{}/models/{}", self.base_url, model_id);
44        self.get(&url).await
45    }
46
47    /// Send a GET request.
48    async fn get<T>(&self, url: &str) -> Result<T>
49    where
50        T: serde::de::DeserializeOwned,
51    {
52        let mut headers = HeaderMap::new();
53        self.auth.apply(&mut headers).await?;
54
55        tracing::debug!(url = %url, "GET request");
56
57        let response = self.http.get(url).headers(headers).send().await?;
58
59        self.handle_response(response).await
60    }
61
62    /// Send a POST request with JSON body.
63    async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
64    where
65        T: serde::de::DeserializeOwned,
66        B: serde::Serialize,
67    {
68        let mut headers = HeaderMap::new();
69        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
70        self.auth.apply(&mut headers).await?;
71
72        tracing::debug!(url = %url, "POST request");
73
74        let response = self
75            .http
76            .post(url)
77            .headers(headers)
78            .json(body)
79            .send()
80            .await?;
81
82        self.handle_response(response).await
83    }
84
85    /// Handle API response.
86    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
87    where
88        T: serde::de::DeserializeOwned,
89    {
90        let status = response.status();
91        let status_code = status.as_u16();
92
93        // Extract rate limit headers before consuming response
94        let retry_after = response
95            .headers()
96            .get("retry-after")
97            .and_then(|v| v.to_str().ok())
98            .and_then(|s| s.parse().ok());
99
100        if status.is_success() {
101            let body = response.text().await?;
102            tracing::debug!(status = %status_code, "Response received");
103            serde_json::from_str(&body).map_err(OpenAiError::from)
104        } else {
105            let body = response.text().await?;
106            tracing::warn!(status = %status_code, body = %body, "API error");
107
108            // Try to parse error response
109            if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
110                let message = error_response.error.message;
111                let code = error_response.error.code.as_deref();
112
113                return Err(match status_code {
114                    401 => OpenAiError::Unauthorized,
115                    403 => OpenAiError::Forbidden(message),
116                    404 => OpenAiError::NotFound(message),
117                    429 => OpenAiError::RateLimited {
118                        retry_after: retry_after.unwrap_or(60),
119                    },
120                    500..=599 => OpenAiError::ServerError(message),
121                    _ => match code {
122                        Some("context_length_exceeded") => {
123                            OpenAiError::ContextLengthExceeded(message)
124                        }
125                        Some("invalid_request_error") => OpenAiError::InvalidRequest(message),
126                        _ => OpenAiError::Api {
127                            status: status_code,
128                            message,
129                        },
130                    },
131                });
132            }
133
134            Err(OpenAiError::Api {
135                status: status_code,
136                message: body,
137            })
138        }
139    }
140}
141
/// Client builder.
///
/// The type parameter `A` tracks the authentication strategy: it is `()` for
/// a freshly created builder and becomes the concrete strategy type once
/// `auth` has been called. `base_url` and `build` are only implemented for
/// builders whose `A` is an `AuthStrategy`.
pub struct ClientBuilder<A> {
    /// Authentication strategy chosen so far (`()` while none is set).
    auth: A,
    /// Base URL the built client will send its requests to.
    base_url: String,
}
147
148impl ClientBuilder<()> {
149    /// Create a new client builder.
150    pub fn new() -> Self {
151        Self {
152            auth: (),
153            base_url: DEFAULT_BASE_URL.to_string(),
154        }
155    }
156
157    /// Set the authentication strategy.
158    pub fn auth<S: AuthStrategy + 'static>(self, strategy: S) -> ClientBuilder<S> {
159        ClientBuilder {
160            auth: strategy,
161            base_url: self.base_url,
162        }
163    }
164}
165
166impl Default for ClientBuilder<()> {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172impl<A: AuthStrategy + 'static> ClientBuilder<A> {
173    /// Set a custom base URL (for Azure OpenAI or proxies).
174    pub fn base_url(mut self, url: impl Into<String>) -> Self {
175        self.base_url = url.into();
176        self
177    }
178
179    /// Build the client.
180    pub fn build(self) -> Client {
181        Client {
182            http: reqwest::Client::new(),
183            auth: Arc::new(self.auth),
184            base_url: self.base_url,
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyAuth;
    use crate::types::Message;

    /// A base URL given to the builder is the one the built client stores.
    #[test]
    fn test_builder() {
        let custom_url = "https://custom.api.com";
        let builder = Client::builder().auth(ApiKeyAuth::new("test-key"));
        let client = builder.base_url(custom_url).build();

        assert_eq!(client.base_url, custom_url);
    }

    /// The request constructor and its `with_*` setters populate the
    /// corresponding fields.
    #[test]
    fn test_create_chat_completion_request() {
        let messages = vec![Message::user("Hello")];
        let request = CreateChatCompletionRequest::new("gpt-4o", messages)
            .with_max_tokens(1024)
            .with_temperature(0.7);

        assert_eq!(request.model, "gpt-4o");
        assert_eq!(request.max_tokens, Some(1024));
        assert_eq!(request.temperature, Some(0.7));
    }
}