lib_client_openai/
client.rs1use crate::auth::AuthStrategy;
4use crate::error::{OpenAiError, Result};
5use crate::types::{
6 CreateChatCompletionRequest, CreateChatCompletionResponse, ErrorResponse, Model, ModelList,
7};
8use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
9use std::sync::Arc;
10
11const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
12
13pub struct Client {
15 http: reqwest::Client,
16 auth: Arc<dyn AuthStrategy>,
17 base_url: String,
18}
19
20impl Client {
21 pub fn builder() -> ClientBuilder<()> {
23 ClientBuilder::new()
24 }
25
26 pub async fn create_chat_completion(
28 &self,
29 request: CreateChatCompletionRequest,
30 ) -> Result<CreateChatCompletionResponse> {
31 let url = format!("{}/chat/completions", self.base_url);
32 self.post(&url, &request).await
33 }
34
35 pub async fn list_models(&self) -> Result<ModelList> {
37 let url = format!("{}/models", self.base_url);
38 self.get(&url).await
39 }
40
41 pub async fn get_model(&self, model_id: &str) -> Result<Model> {
43 let url = format!("{}/models/{}", self.base_url, model_id);
44 self.get(&url).await
45 }
46
47 async fn get<T>(&self, url: &str) -> Result<T>
49 where
50 T: serde::de::DeserializeOwned,
51 {
52 let mut headers = HeaderMap::new();
53 self.auth.apply(&mut headers).await?;
54
55 tracing::debug!(url = %url, "GET request");
56
57 let response = self.http.get(url).headers(headers).send().await?;
58
59 self.handle_response(response).await
60 }
61
62 async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
64 where
65 T: serde::de::DeserializeOwned,
66 B: serde::Serialize,
67 {
68 let mut headers = HeaderMap::new();
69 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
70 self.auth.apply(&mut headers).await?;
71
72 tracing::debug!(url = %url, "POST request");
73
74 let response = self
75 .http
76 .post(url)
77 .headers(headers)
78 .json(body)
79 .send()
80 .await?;
81
82 self.handle_response(response).await
83 }
84
85 async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
87 where
88 T: serde::de::DeserializeOwned,
89 {
90 let status = response.status();
91 let status_code = status.as_u16();
92
93 let retry_after = response
95 .headers()
96 .get("retry-after")
97 .and_then(|v| v.to_str().ok())
98 .and_then(|s| s.parse().ok());
99
100 if status.is_success() {
101 let body = response.text().await?;
102 tracing::debug!(status = %status_code, "Response received");
103 serde_json::from_str(&body).map_err(OpenAiError::from)
104 } else {
105 let body = response.text().await?;
106 tracing::warn!(status = %status_code, body = %body, "API error");
107
108 if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
110 let message = error_response.error.message;
111 let code = error_response.error.code.as_deref();
112
113 return Err(match status_code {
114 401 => OpenAiError::Unauthorized,
115 403 => OpenAiError::Forbidden(message),
116 404 => OpenAiError::NotFound(message),
117 429 => OpenAiError::RateLimited {
118 retry_after: retry_after.unwrap_or(60),
119 },
120 500..=599 => OpenAiError::ServerError(message),
121 _ => match code {
122 Some("context_length_exceeded") => {
123 OpenAiError::ContextLengthExceeded(message)
124 }
125 Some("invalid_request_error") => OpenAiError::InvalidRequest(message),
126 _ => OpenAiError::Api {
127 status: status_code,
128 message,
129 },
130 },
131 });
132 }
133
134 Err(OpenAiError::Api {
135 status: status_code,
136 message: body,
137 })
138 }
139 }
140}
141
/// Type-state builder for [`Client`].
///
/// `A` is the auth strategy selected so far; [`ClientBuilder::new`] starts with `()`.
#[must_use = "a ClientBuilder does nothing until `build()` is called"]
pub struct ClientBuilder<A> {
    /// Selected auth strategy (`()` before [`ClientBuilder::auth`] is called).
    auth: A,
    /// Base URL copied into the built [`Client`].
    base_url: String,
}

148impl ClientBuilder<()> {
149 pub fn new() -> Self {
151 Self {
152 auth: (),
153 base_url: DEFAULT_BASE_URL.to_string(),
154 }
155 }
156
157 pub fn auth<S: AuthStrategy + 'static>(self, strategy: S) -> ClientBuilder<S> {
159 ClientBuilder {
160 auth: strategy,
161 base_url: self.base_url,
162 }
163 }
164}
165
166impl Default for ClientBuilder<()> {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172impl<A: AuthStrategy + 'static> ClientBuilder<A> {
173 pub fn base_url(mut self, url: impl Into<String>) -> Self {
175 self.base_url = url.into();
176 self
177 }
178
179 pub fn build(self) -> Client {
181 Client {
182 http: reqwest::Client::new(),
183 auth: Arc::new(self.auth),
184 base_url: self.base_url,
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyAuth;
    use crate::types::Message;

    /// A custom base URL given to the builder ends up on the built client.
    #[test]
    fn test_builder() {
        let authed = Client::builder().auth(ApiKeyAuth::new("test-key"));
        let client = authed.base_url("https://custom.api.com").build();

        assert_eq!(client.base_url, "https://custom.api.com");
    }

    /// The request's `with_*` helpers populate the corresponding fields.
    #[test]
    fn test_create_chat_completion_request() {
        let messages = vec![Message::user("Hello")];
        let request = CreateChatCompletionRequest::new("gpt-4o", messages)
            .with_max_tokens(1024)
            .with_temperature(0.7);

        assert_eq!(request.model, "gpt-4o");
        assert_eq!(request.max_tokens, Some(1024));
        assert_eq!(request.temperature, Some(0.7));
    }
}