openai_dive/v1/
api.rs

1use crate::v1::helpers::check_status_code;
2use crate::v1::{error::APIError, resources::shared::Headers};
3use bytes::Bytes;
4#[cfg(feature = "stream")]
5use futures::{stream::StreamExt, Stream};
6use reqwest::{multipart::Form, Method, RequestBuilder};
7#[cfg(feature = "stream")]
8use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
9#[cfg(feature = "stream")]
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::collections::HashMap;
13#[cfg(feature = "stream")]
14use std::pin::Pin;
15
16use super::resources::shared::ResponseWrapper;
17
/// Base URL that `Client::default` points at.
const OPENAI_API_V1_ENDPOINT: &str = "https://api.openai.com/v1";
/// `Content-Type` value passed to `build_request` by the JSON-based request helpers.
const MIME_TYPE_APPLICATION_JSON: &str = "application/json";
20
/// Client holding the HTTP connection handle and the settings applied to every request.
///
/// NOTE: the derived `Debug` implementation prints every field, including `api_key`,
/// so avoid writing a `Client` to logs with `{:?}`.
#[derive(Clone, Debug)]
pub struct Client {
    /// Underlying `reqwest` client used to send requests.
    pub http_client: reqwest::Client,
    /// Prefix that request paths are appended to verbatim.
    pub base_url: String,
    /// API key sent as a bearer token on every request.
    pub api_key: String,
    /// Extra headers added to every request; `None` until a header is added.
    pub headers: Option<HashMap<String, String>>,
    /// Value of the `OpenAI-Organization` header, sent only when set.
    pub organization: Option<String>,
    /// Value of the `OpenAI-Project` header, sent only when set.
    pub project: Option<String>,
}
30
31impl Client {
32    /// Create a new instance of the OpenAI client and set the API key.
33    pub fn new(api_key: String) -> Self {
34        Self {
35            api_key,
36            ..Default::default()
37        }
38    }
39
40    /// Create a new instance of the OpenAI client with a custom base URL and set the API key.
41    #[deprecated(since = "0.7.0", note = "Please use `set_base_url` instead")]
42    pub fn new_with_base(base_url: &str, api_key: String) -> Self {
43        Self {
44            base_url: base_url.to_string(),
45            api_key,
46            ..Default::default()
47        }
48    }
49
50    /// Create a new instance of the OpenAI client and set the API key from the environment variable `OPENAI_API_KEY`.
51    pub fn new_from_env() -> Self {
52        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
53
54        Self {
55            api_key,
56            ..Default::default()
57        }
58    }
59
60    /// Set the base URL for the OpenAI client.
61    pub fn set_base_url(&mut self, base_url: &str) -> &mut Self {
62        self.base_url = base_url.to_string();
63
64        self
65    }
66
67    /// Set the organization header for the OpenAI client.
68    pub fn set_organization(&mut self, organization: &str) -> &mut Self {
69        self.organization = Some(organization.to_string());
70
71        self
72    }
73
74    /// Set the project header for the OpenAI client.
75    pub fn set_project(&mut self, project: &str) -> &mut Self {
76        self.project = Some(project.to_string());
77
78        self
79    }
80
81    /// Add a custom header to the OpenAI client.
82    pub fn add_header(&mut self, key: &str, value: &str) -> &mut Self {
83        self.headers
84            .get_or_insert_with(HashMap::new)
85            .insert(key.to_string(), value.to_string());
86
87        self
88    }
89
90    fn build_request(
91        &self,
92        method: reqwest::Method,
93        path: &str,
94        content_type: Option<&str>,
95    ) -> RequestBuilder {
96        let url = format!("{}{}", &self.base_url, path);
97
98        let mut request = self
99            .http_client
100            .request(method, url)
101            .bearer_auth(&self.api_key);
102
103        if let Some(content_type) = content_type {
104            request = request.header(reqwest::header::CONTENT_TYPE, content_type);
105        }
106
107        if let Some(headers) = &self.headers {
108            for (key, value) in headers {
109                request = request.header(key, value);
110            }
111        }
112
113        if let Some(organization) = &self.organization {
114            request = request.header("OpenAI-Organization", organization);
115        }
116
117        if let Some(project) = &self.project {
118            request = request.header("OpenAI-Project", project);
119        }
120
121        request
122    }
123
124    pub(crate) async fn get(&self, path: &str) -> Result<String, APIError> {
125        let result = self
126            .build_request(Method::GET, path, Some(MIME_TYPE_APPLICATION_JSON))
127            .send()
128            .await;
129
130        let response = match check_status_code(result).await {
131            Ok(response) => response,
132            Err(error) => return Err(error),
133        };
134
135        let response_text = response
136            .text()
137            .await
138            .map_err(|error| APIError::ParseError(error.to_string()))?;
139
140        #[cfg(feature = "log")]
141        log::trace!("{response_text}");
142
143        Ok(response_text)
144    }
145
146    pub(crate) async fn get_with_query<Q>(&self, path: &str, query: &Q) -> Result<String, APIError>
147    where
148        Q: Serialize,
149    {
150        let encoded_query = serde_html_form::to_string(query).unwrap_or_else(|_| "".to_string());
151
152        let path = format!("{path}?{encoded_query}");
153
154        let result = self
155            .build_request(Method::GET, &path, Some(MIME_TYPE_APPLICATION_JSON))
156            .send()
157            .await;
158
159        let response = match check_status_code(result).await {
160            Ok(response) => response,
161            Err(error) => return Err(error),
162        };
163
164        let response_text = response
165            .text()
166            .await
167            .map_err(|error| APIError::ParseError(error.to_string()))?;
168
169        #[cfg(feature = "log")]
170        log::trace!("{response_text}");
171
172        Ok(response_text)
173    }
174
175    pub(crate) async fn post<T: Serialize>(
176        &self,
177        path: &str,
178        parameters: &T,
179        query_params: impl Into<Option<&HashMap<String, String>>>,
180    ) -> Result<ResponseWrapper<String>, APIError> {
181        let result = self
182            .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
183            .query(&query_params.into())
184            .json(&parameters)
185            .send()
186            .await;
187
188        let response = match check_status_code(result).await {
189            Ok(response) => response,
190            Err(error) => return Err(error),
191        };
192
193        let header_map = response.headers().clone();
194
195        let response_text = response
196            .text()
197            .await
198            .map_err(|error| APIError::ParseError(error.to_string()))?;
199        let response_headers: Headers = header_map.into();
200
201        #[cfg(feature = "log")]
202        log::trace!("{response_text}");
203
204        Ok(ResponseWrapper {
205            data: response_text.to_string(),
206            headers: response_headers,
207        })
208    }
209
210    pub(crate) async fn delete(&self, path: &str) -> Result<String, APIError> {
211        let result = self
212            .build_request(Method::DELETE, path, Some(MIME_TYPE_APPLICATION_JSON))
213            .send()
214            .await;
215
216        let response = match check_status_code(result).await {
217            Ok(response) => response,
218            Err(error) => return Err(error),
219        };
220
221        response
222            .text()
223            .await
224            .map_err(|error| APIError::ParseError(error.to_string()))
225    }
226
227    pub(crate) async fn post_with_form(&self, path: &str, form: Form) -> Result<String, APIError> {
228        let result = self
229            .build_request(Method::POST, path, None)
230            .multipart(form)
231            .send()
232            .await;
233
234        let response = match check_status_code(result).await {
235            Ok(response) => response,
236            Err(error) => return Err(error),
237        };
238
239        response
240            .text()
241            .await
242            .map_err(|error| APIError::ParseError(error.to_string()))
243    }
244
245    pub(crate) async fn post_raw<T: Serialize>(
246        &self,
247        path: &str,
248        parameters: &T,
249    ) -> Result<Bytes, APIError> {
250        let result = self
251            .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
252            .json(&parameters)
253            .send()
254            .await;
255
256        let response = match check_status_code(result).await {
257            Ok(response) => response,
258            Err(error) => return Err(error),
259        };
260
261        response
262            .bytes()
263            .await
264            .map_err(|error| APIError::ParseError(error.to_string()))
265    }
266
267    #[cfg(feature = "stream")]
268    pub(crate) async fn post_stream<I, O>(
269        &self,
270        path: &str,
271        parameters: &I,
272        query_params: impl Into<Option<&HashMap<String, String>>>,
273    ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
274    where
275        I: Serialize,
276        O: DeserializeOwned + std::marker::Send + 'static,
277    {
278        let event_source = self
279            .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
280            .json(&parameters)
281            .query(&query_params.into())
282            .eventsource()
283            .unwrap();
284
285        Client::process_stream::<O>(event_source).await
286    }
287
288    #[cfg(feature = "stream")]
289    pub(crate) async fn post_stream_raw<I>(
290        &self,
291        path: &str,
292        parameters: &I,
293    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, APIError>> + Send>>, APIError>
294    where
295        I: Serialize,
296    {
297        let stream = self
298            .build_request(Method::POST, path, Some(MIME_TYPE_APPLICATION_JSON))
299            .json(&parameters)
300            .send()
301            .await
302            .unwrap()
303            .bytes_stream()
304            .map(|item| item.map_err(|error| APIError::StreamError(error.to_string())));
305
306        Ok(Box::pin(stream)
307            as Pin<
308                Box<dyn Stream<Item = Result<Bytes, APIError>> + Send>,
309            >)
310    }
311
312    #[cfg(feature = "stream")]
313    pub(crate) async fn process_stream<O>(
314        mut event_soure: EventSource,
315    ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
316    where
317        O: DeserializeOwned + Send + 'static,
318    {
319        use super::error::InvalidRequestError;
320
321        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
322
323        tokio::spawn(async move {
324            while let Some(event_result) = event_soure.next().await {
325                match event_result {
326                    Ok(event) => match event {
327                        Event::Open => continue,
328                        Event::Message(message) => {
329                            if message.data == "[DONE]" {
330                                break;
331                            }
332
333                            let response = match serde_json::from_str::<O>(&message.data) {
334                                Ok(result) => Ok(result),
335                                Err(error) => {
336                                    match serde_json::from_str::<InvalidRequestError>(&message.data)
337                                    {
338                                        Ok(invalid_request_error) => Err(APIError::StreamError(
339                                            invalid_request_error.to_string(),
340                                        )),
341                                        Err(_) => Err(APIError::StreamError(format!(
342                                            "{} {}",
343                                            error, message.data
344                                        ))),
345                                    }
346                                }
347                            };
348
349                            if let Err(_error) = tx.send(response) {
350                                break;
351                            }
352                        }
353                    },
354                    Err(error) => {
355                        if let Err(_error) = tx.send(Err(APIError::StreamError(error.to_string())))
356                        {
357                            break;
358                        }
359                    }
360                }
361            }
362
363            event_soure.close();
364        });
365
366        Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
367    }
368}
369
370impl Default for Client {
371    fn default() -> Self {
372        Client {
373            http_client: reqwest::Client::new(),
374            base_url: OPENAI_API_V1_ENDPOINT.to_string(),
375            api_key: "".to_string(),
376            headers: None,
377            organization: None,
378            project: None,
379        }
380    }
381}