1use crate::v1::helpers::check_status_code;
2use crate::v1::{error::APIError, resources::shared::Headers};
3use bytes::Bytes;
4#[cfg(feature = "stream")]
5use futures::{stream::StreamExt, Stream};
6use reqwest::{multipart::Form, Method, RequestBuilder};
7#[cfg(feature = "stream")]
8use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
9#[cfg(feature = "stream")]
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::collections::HashMap;
13#[cfg(feature = "stream")]
14use std::pin::Pin;
15
16use super::resources::shared::ResponseWrapper;
17
18const OPENAI_API_V1_ENDPOINT: &str = "https://api.openai.com/v1";
19const MIME_TYPE_APPLICATION_JSON: &str = "application/json";
20
21#[derive(Clone, Debug)]
22pub struct Client {
23 pub http_client: reqwest::Client,
24 pub base_url: String,
25 pub api_key: String,
26 pub headers: Option<HashMap<String, String>>,
27 pub organization: Option<String>,
28 pub project: Option<String>,
29}
30
31impl Client {
32 pub fn new(api_key: String) -> Self {
34 Self {
35 api_key,
36 ..Default::default()
37 }
38 }
39
40 #[deprecated(since = "0.7.0", note = "Please use `set_base_url` instead")]
42 pub fn new_with_base(base_url: &str, api_key: String) -> Self {
43 Self {
44 base_url: base_url.to_string(),
45 api_key,
46 ..Default::default()
47 }
48 }
49
50 pub fn new_from_env() -> Self {
52 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
53
54 Self {
55 api_key,
56 ..Default::default()
57 }
58 }
59
60 pub fn set_base_url(&mut self, base_url: &str) -> &mut Self {
62 self.base_url = base_url.to_string();
63
64 self
65 }
66
67 pub fn set_organization(&mut self, organization: &str) -> &mut Self {
69 self.organization = Some(organization.to_string());
70
71 self
72 }
73
74 pub fn set_project(&mut self, project: &str) -> &mut Self {
76 self.project = Some(project.to_string());
77
78 self
79 }
80
81 pub fn add_header(&mut self, key: &str, value: &str) -> &mut Self {
83 self.headers
84 .get_or_insert_with(HashMap::new)
85 .insert(key.to_string(), value.to_string());
86
87 self
88 }
89
90 fn build_request(
91 &self,
92 method: reqwest::Method,
93 path: &str,
94 content_type: Option<&str>,
95 ) -> RequestBuilder {
96 let url = format!("{}{}", &self.base_url, path);
97
98 let mut request = self
99 .http_client
100 .request(method, url)
101 .bearer_auth(&self.api_key);
102
103 if let Some(content_type) = content_type {
104 request = request.header(reqwest::header::CONTENT_TYPE, content_type);
105 }
106
107 if let Some(headers) = &self.headers {
108 for (key, value) in headers {
109 request = request.header(key, value);
110 }
111 }
112
113 if let Some(organization) = &self.organization {
114 request = request.header("OpenAI-Organization", organization);
115 }
116
117 if let Some(project) = &self.project {
118 request = request.header("OpenAI-Project", project);
119 }
120
121 request
122 }
123
124 pub(crate) async fn get(&self, path: &str) -> Result<String, APIError> {
125 let result = self
126 .build_request(Method::GET, path, Some(MIME_TYPE_APPLICATION_JSON))
127 .send()
128 .await;
129
130 let response = match check_status_code(result).await {
131 Ok(response) => response,
132 Err(error) => return Err(error),
133 };
134
135 let response_text = response
136 .text()
137 .await
138 .map_err(|error| APIError::ParseError(error.to_string()))?;
139
140 #[cfg(feature = "log")]
141 log::trace!("{response_text}");
142
143 Ok(response_text)
144 }
145
146 pub(crate) async fn get_with_query<Q>(&self, path: &str, query: &Q) -> Result<String, APIError>
147 where
148 Q: Serialize,
149 {
150 let encoded_query = serde_html_form::to_string(query).unwrap_or_else(|_| "".to_string());
151
152 let path = format!("{path}?{encoded_query}");
153
154 let result = self
155 .build_request(Method::GET, &path, Some(MIME_TYPE_APPLICATION_JSON))
156 .send()
157 .await;
158
159 let response = match check_status_code(result).await {
160 Ok(response) => response,
161 Err(error) => return Err(error),
162 };
163
164 let response_text = response
165 .text()
166 .await
167 .map_err(|error| APIError::ParseError(error.to_string()))?;
168
169 #[cfg(feature = "log")]
170 log::trace!("{response_text}");
171
172 Ok(response_text)
173 }
174
175 pub(crate) async fn post<T: Serialize>(
176 &self,
177 path: &str,
178 parameters: &T,
179 query_params: impl Into<Option<&HashMap<String, String>>>,
180 ) -> Result<ResponseWrapper<String>, APIError> {
181 let result = self
182 .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
183 .query(&query_params.into())
184 .json(¶meters)
185 .send()
186 .await;
187
188 let response = match check_status_code(result).await {
189 Ok(response) => response,
190 Err(error) => return Err(error),
191 };
192
193 let header_map = response.headers().clone();
194
195 let response_text = response
196 .text()
197 .await
198 .map_err(|error| APIError::ParseError(error.to_string()))?;
199 let response_headers: Headers = header_map.into();
200
201 #[cfg(feature = "log")]
202 log::trace!("{response_text}");
203
204 Ok(ResponseWrapper {
205 data: response_text.to_string(),
206 headers: response_headers,
207 })
208 }
209
210 pub(crate) async fn delete(&self, path: &str) -> Result<String, APIError> {
211 let result = self
212 .build_request(Method::DELETE, path, Some(MIME_TYPE_APPLICATION_JSON))
213 .send()
214 .await;
215
216 let response = match check_status_code(result).await {
217 Ok(response) => response,
218 Err(error) => return Err(error),
219 };
220
221 response
222 .text()
223 .await
224 .map_err(|error| APIError::ParseError(error.to_string()))
225 }
226
227 pub(crate) async fn post_with_form(&self, path: &str, form: Form) -> Result<String, APIError> {
228 let result = self
229 .build_request(Method::POST, path, None)
230 .multipart(form)
231 .send()
232 .await;
233
234 let response = match check_status_code(result).await {
235 Ok(response) => response,
236 Err(error) => return Err(error),
237 };
238
239 response
240 .text()
241 .await
242 .map_err(|error| APIError::ParseError(error.to_string()))
243 }
244
245 pub(crate) async fn post_raw<T: Serialize>(
246 &self,
247 path: &str,
248 parameters: &T,
249 ) -> Result<Bytes, APIError> {
250 let result = self
251 .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
252 .json(¶meters)
253 .send()
254 .await;
255
256 let response = match check_status_code(result).await {
257 Ok(response) => response,
258 Err(error) => return Err(error),
259 };
260
261 response
262 .bytes()
263 .await
264 .map_err(|error| APIError::ParseError(error.to_string()))
265 }
266
267 #[cfg(feature = "stream")]
268 pub(crate) async fn post_stream<I, O>(
269 &self,
270 path: &str,
271 parameters: &I,
272 query_params: impl Into<Option<&HashMap<String, String>>>,
273 ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
274 where
275 I: Serialize,
276 O: DeserializeOwned + std::marker::Send + 'static,
277 {
278 let event_source = self
279 .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
280 .json(¶meters)
281 .query(&query_params.into())
282 .eventsource()
283 .unwrap();
284
285 Client::process_stream::<O>(event_source).await
286 }
287
288 #[cfg(feature = "stream")]
289 pub(crate) async fn post_stream_raw<I>(
290 &self,
291 path: &str,
292 parameters: &I,
293 ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, APIError>> + Send>>, APIError>
294 where
295 I: Serialize,
296 {
297 let stream = self
298 .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
299 .json(¶meters)
300 .send()
301 .await
302 .unwrap()
303 .bytes_stream()
304 .map(|item| item.map_err(|error| APIError::StreamError(error.to_string())));
305
306 Ok(Box::pin(stream)
307 as Pin<
308 Box<dyn Stream<Item = Result<Bytes, APIError>> + Send>,
309 >)
310 }
311
312 #[cfg(feature = "stream")]
313 pub(crate) async fn process_stream<O>(
314 mut event_soure: EventSource,
315 ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
316 where
317 O: DeserializeOwned + Send + 'static,
318 {
319 use super::error::InvalidRequestError;
320
321 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
322
323 tokio::spawn(async move {
324 while let Some(event_result) = event_soure.next().await {
325 match event_result {
326 Ok(event) => match event {
327 Event::Open => continue,
328 Event::Message(message) => {
329 if message.data == "[DONE]" {
330 break;
331 }
332
333 let response = match serde_json::from_str::<O>(&message.data) {
334 Ok(result) => Ok(result),
335 Err(error) => {
336 match serde_json::from_str::<InvalidRequestError>(&message.data)
337 {
338 Ok(invalid_request_error) => Err(APIError::StreamError(
339 invalid_request_error.to_string(),
340 )),
341 Err(_) => Err(APIError::StreamError(format!(
342 "{} {}",
343 error, message.data
344 ))),
345 }
346 }
347 };
348
349 if let Err(_error) = tx.send(response) {
350 break;
351 }
352 }
353 },
354 Err(error) => {
355 if let Err(_error) = tx.send(Err(APIError::StreamError(error.to_string())))
356 {
357 break;
358 }
359 }
360 }
361 }
362
363 event_soure.close();
364 });
365
366 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
367 }
368}
369
370impl Default for Client {
371 fn default() -> Self {
372 Client {
373 http_client: reqwest::Client::new(),
374 base_url: OPENAI_API_V1_ENDPOINT.to_string(),
375 api_key: "".to_string(),
376 headers: None,
377 organization: None,
378 project: None,
379 }
380 }
381}