rs_openai/
client.rs

1use crate::apis::{
2    audio, chat, completions, edits, embeddings, engines, files, fine_tunes, images, models,
3    moderations,
4};
5use crate::shared::response_wrapper::{ApiErrorResponse, OpenAIError, OpenAIResponse};
6use futures::{stream::StreamExt, Stream};
7use reqwest::{header::HeaderMap, multipart::Form, Client, Method, RequestBuilder};
8use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
9use serde::{de::DeserializeOwned, Serialize};
10use std::fs::File;
11use std::io::{self};
12use std::{fmt::Debug, pin::Pin};
13
/// Base URL of the OpenAI v1 REST API.
///
/// Request routes (e.g. `"/models"`) are appended to this string verbatim, so it carries no
/// trailing slash.
pub const API_BASE: &str = "https://api.openai.com/v1";
16
/// HTTP header name used to send the optional organization ID (`OpenAI::org_id`)
/// with each request.
pub const ORGANIZATION_HEADER: &str = "OpenAI-Organization";
19
/// Client credentials for the OpenAI API.
///
/// Every request built by this client authenticates with `api_key` and, when present,
/// identifies the organization with `org_id`.
pub struct OpenAI {
    /// API key, sent as a bearer token on every request.
    pub api_key: String,
    /// Optional organization ID, sent in the `OpenAI-Organization` header when set.
    pub org_id: Option<String>,
}
24
25impl OpenAI {
26    pub fn new(&self) -> Self {
27        Self {
28            api_key: self.api_key.to_owned(),
29            org_id: self.org_id.to_owned(),
30        }
31    }
32
33    fn headers(&self) -> HeaderMap {
34        let mut headers = HeaderMap::new();
35
36        if let Some(org_id) = &self.org_id {
37            headers.insert(ORGANIZATION_HEADER, org_id.parse().unwrap());
38        }
39
40        headers
41    }
42
43    fn openai_request<F>(&self, method: Method, route: &str, builder: F) -> RequestBuilder
44    where
45        F: FnOnce(RequestBuilder) -> RequestBuilder,
46    {
47        let client = Client::new();
48
49        let mut request = client
50            .request(method, API_BASE.to_string() + route)
51            .headers(self.headers())
52            .bearer_auth(&self.api_key);
53
54        request = builder(request);
55        request
56    }
57
58    async fn resolve_response<T>(request: RequestBuilder) -> OpenAIResponse<T>
59    where
60        T: DeserializeOwned + Debug,
61    {
62        let response = request.send().await?;
63        let status = response.status();
64        let bytes = response.bytes().await?;
65
66        if !status.is_success() {
67            let api_error: ApiErrorResponse =
68                serde_json::from_slice(bytes.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
69
70            return Err(OpenAIError::ApiError(api_error));
71        }
72
73        let data: T =
74            serde_json::from_slice(bytes.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
75
76        Ok(data)
77    }
78
79    async fn resolve_text_response(request: RequestBuilder) -> OpenAIResponse<String> {
80        let response = request.send().await?;
81        let status = response.status();
82        let text = response.text().await?;
83
84        if !status.is_success() {
85            let api_error: ApiErrorResponse =
86                serde_json::from_slice(text.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
87
88            return Err(OpenAIError::ApiError(api_error));
89        }
90
91        Ok(text)
92    }
93
94    async fn resolve_file_response(request: RequestBuilder, filename: &str) -> OpenAIResponse<()> {
95        let response = request.send().await?;
96        let status = response.status();
97        let text = response.text().await?;
98
99        if !status.is_success() {
100            let api_error: ApiErrorResponse =
101                serde_json::from_slice(text.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
102
103            return Err(OpenAIError::ApiError(api_error));
104        }
105
106        let mut file = File::create(filename).expect("failed to create file");
107        io::copy(&mut text.as_bytes(), &mut file).expect("failed to copy content");
108
109        Ok(())
110    }
111
112    pub(crate) async fn get<T, F>(&self, route: &str, query: &F) -> OpenAIResponse<T>
113    where
114        T: DeserializeOwned + Debug,
115        F: Serialize,
116    {
117        let request = self.openai_request(Method::GET, route, |request| request.query(query));
118        Self::resolve_response(request).await
119    }
120
121    pub(crate) async fn get_stream<T, F>(
122        &self,
123        route: &str,
124        query: &F,
125    ) -> Pin<Box<dyn Stream<Item = OpenAIResponse<T>> + Send>>
126    where
127        T: DeserializeOwned + Debug + Send + 'static,
128        F: Serialize,
129    {
130        let event_source = self
131            .openai_request(Method::GET, route, |request| request.query(query))
132            .eventsource()
133            .unwrap();
134        Self::stream_sse(event_source).await
135    }
136
137    pub(crate) async fn post<T, F>(&self, route: &str, json: &F) -> OpenAIResponse<T>
138    where
139        T: DeserializeOwned + Debug,
140        F: Serialize,
141    {
142        let request = self.openai_request(Method::POST, route, |request| request.json(json));
143        Self::resolve_response(request).await
144    }
145
146    pub(crate) async fn post_form<T>(&self, route: &str, form_data: Form) -> OpenAIResponse<T>
147    where
148        T: DeserializeOwned + Debug,
149    {
150        let request =
151            self.openai_request(Method::POST, route, |request| request.multipart(form_data));
152        Self::resolve_response(request).await
153    }
154
155    pub(crate) async fn post_form_with_text_response(
156        &self,
157        route: &str,
158        form_data: Form,
159    ) -> OpenAIResponse<String> {
160        let request =
161            self.openai_request(Method::POST, route, |request| request.multipart(form_data));
162        Self::resolve_text_response(request).await
163    }
164
165    pub(crate) async fn post_with_file_response<T>(
166        &self,
167        route: &str,
168        json: &T,
169        filename: &str,
170    ) -> OpenAIResponse<()>
171    where
172        T: Serialize,
173    {
174        let request = self.openai_request(Method::POST, route, |request| request.json(json));
175        Self::resolve_file_response(request, filename).await
176    }
177
178    pub(crate) async fn post_stream<T, F>(
179        &self,
180        route: &str,
181        json: &F,
182    ) -> Pin<Box<dyn Stream<Item = OpenAIResponse<T>> + Send>>
183    where
184        T: DeserializeOwned + Debug + Send + 'static,
185        F: Serialize,
186    {
187        let event_source = self
188            .openai_request(Method::POST, route, |request| request.json(json))
189            .eventsource()
190            .unwrap();
191        OpenAI::stream_sse(event_source).await
192    }
193
194    pub(crate) async fn delete<T, F>(&self, route: &str, json: &F) -> OpenAIResponse<T>
195    where
196        T: DeserializeOwned + Debug,
197        F: Serialize,
198    {
199        let request = self.openai_request(Method::DELETE, route, |request| request.json(json));
200        Self::resolve_response(request).await
201    }
202
203    async fn stream_sse<T>(
204        mut event_source: EventSource,
205    ) -> Pin<Box<dyn Stream<Item = OpenAIResponse<T>> + Send>>
206    where
207        T: DeserializeOwned + Debug + Send + 'static,
208    {
209        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<OpenAIResponse<T>>();
210
211        tokio::spawn(async move {
212            while let Some(evt) = event_source.next().await {
213                match evt {
214                    Err(e) => {
215                        if tx
216                            .send(Err(OpenAIError::StreamError(e.to_string())))
217                            .is_err()
218                        {
219                            break;
220                        }
221                    }
222                    Ok(evt) => match evt {
223                        Event::Message(message) => {
224                            if message.data == "[DONE]" {
225                                break;
226                            }
227
228                            let response = match serde_json::from_str::<T>(&message.data) {
229                                Err(e) => Err(OpenAIError::JSONDeserialize(e)),
230                                Ok(output) => Ok(output),
231                            };
232
233                            if tx.send(response).is_err() {
234                                break;
235                            }
236                        }
237                        Event::Open => continue,
238                    },
239                }
240            }
241
242            event_source.close();
243        });
244
245        Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
246    }
247
248    pub fn audio(&self) -> audio::Audio {
249        audio::Audio::new(self)
250    }
251
252    pub fn chat(&self) -> chat::Chat {
253        chat::Chat::new(self)
254    }
255
256    pub fn completions(&self) -> completions::Completions {
257        completions::Completions::new(self)
258    }
259
260    pub fn edits(&self) -> edits::Edits {
261        edits::Edits::new(self)
262    }
263
264    pub fn embeddings(&self) -> embeddings::Embeddings {
265        embeddings::Embeddings::new(self)
266    }
267
268    pub fn engines(&self) -> engines::Engines {
269        engines::Engines::new(self)
270    }
271
272    pub fn files(&self) -> files::Files {
273        files::Files::new(self)
274    }
275
276    pub fn fine_tunes(&self) -> fine_tunes::FineTunes {
277        fine_tunes::FineTunes::new(self)
278    }
279
280    pub fn images(&self) -> images::Images {
281        images::Images::new(self)
282    }
283
284    pub fn models(&self) -> models::Models {
285        models::Models::new(self)
286    }
287
288    pub fn moderations(&self) -> moderations::Moderations {
289        moderations::Moderations::new(self)
290    }
291}