1use crate::apis::{
2 audio, chat, completions, edits, embeddings, engines, files, fine_tunes, images, models,
3 moderations,
4};
5use crate::shared::response_wrapper::{ApiErrorResponse, OpenAIError, OpenAIResponse};
6use futures::{stream::StreamExt, Stream};
7use reqwest::{header::HeaderMap, multipart::Form, Client, Method, RequestBuilder};
8use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
9use serde::{de::DeserializeOwned, Serialize};
10use std::fs::File;
11use std::io::{self};
12use std::{fmt::Debug, pin::Pin};
13
/// Name of the HTTP header used to select the OpenAI organization for a request.
pub const ORGANIZATION_HEADER: &str = "OpenAI-Organization";

/// Base URL that every API route is appended to.
pub const API_BASE: &str = "https://api.openai.com/v1";
19
/// Credentials used to call the OpenAI HTTP API.
///
/// `api_key` is sent as a bearer token on every request. `org_id`, when present, is
/// sent in the organization header.
// `Debug` is intentionally not derived: it would print `api_key`.
#[derive(Clone)]
pub struct OpenAI {
    /// Secret API key, sent as `Authorization: Bearer <api_key>`.
    pub api_key: String,
    /// Optional organization id sent with each request.
    pub org_id: Option<String>,
}
24
25impl OpenAI {
26 pub fn new(&self) -> Self {
27 Self {
28 api_key: self.api_key.to_owned(),
29 org_id: self.org_id.to_owned(),
30 }
31 }
32
33 fn headers(&self) -> HeaderMap {
34 let mut headers = HeaderMap::new();
35
36 if let Some(org_id) = &self.org_id {
37 headers.insert(ORGANIZATION_HEADER, org_id.parse().unwrap());
38 }
39
40 headers
41 }
42
43 fn openai_request<F>(&self, method: Method, route: &str, builder: F) -> RequestBuilder
44 where
45 F: FnOnce(RequestBuilder) -> RequestBuilder,
46 {
47 let client = Client::new();
48
49 let mut request = client
50 .request(method, API_BASE.to_string() + route)
51 .headers(self.headers())
52 .bearer_auth(&self.api_key);
53
54 request = builder(request);
55 request
56 }
57
58 async fn resolve_response<T>(request: RequestBuilder) -> OpenAIResponse<T>
59 where
60 T: DeserializeOwned + Debug,
61 {
62 let response = request.send().await?;
63 let status = response.status();
64 let bytes = response.bytes().await?;
65
66 if !status.is_success() {
67 let api_error: ApiErrorResponse =
68 serde_json::from_slice(bytes.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
69
70 return Err(OpenAIError::ApiError(api_error));
71 }
72
73 let data: T =
74 serde_json::from_slice(bytes.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
75
76 Ok(data)
77 }
78
79 async fn resolve_text_response(request: RequestBuilder) -> OpenAIResponse<String> {
80 let response = request.send().await?;
81 let status = response.status();
82 let text = response.text().await?;
83
84 if !status.is_success() {
85 let api_error: ApiErrorResponse =
86 serde_json::from_slice(text.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
87
88 return Err(OpenAIError::ApiError(api_error));
89 }
90
91 Ok(text)
92 }
93
94 async fn resolve_file_response(request: RequestBuilder, filename: &str) -> OpenAIResponse<()> {
95 let response = request.send().await?;
96 let status = response.status();
97 let text = response.text().await?;
98
99 if !status.is_success() {
100 let api_error: ApiErrorResponse =
101 serde_json::from_slice(text.as_ref()).map_err(OpenAIError::JSONDeserialize)?;
102
103 return Err(OpenAIError::ApiError(api_error));
104 }
105
106 let mut file = File::create(filename).expect("failed to create file");
107 io::copy(&mut text.as_bytes(), &mut file).expect("failed to copy content");
108
109 Ok(())
110 }
111
112 pub(crate) async fn get<T, F>(&self, route: &str, query: &F) -> OpenAIResponse<T>
113 where
114 T: DeserializeOwned + Debug,
115 F: Serialize,
116 {
117 let request = self.openai_request(Method::GET, route, |request| request.query(query));
118 Self::resolve_response(request).await
119 }
120
121 pub(crate) async fn get_stream<T, F>(
122 &self,
123 route: &str,
124 query: &F,
125 ) -> Pin<Box<dyn Stream<Item = OpenAIResponse<T>> + Send>>
126 where
127 T: DeserializeOwned + Debug + Send + 'static,
128 F: Serialize,
129 {
130 let event_source = self
131 .openai_request(Method::GET, route, |request| request.query(query))
132 .eventsource()
133 .unwrap();
134 Self::stream_sse(event_source).await
135 }
136
137 pub(crate) async fn post<T, F>(&self, route: &str, json: &F) -> OpenAIResponse<T>
138 where
139 T: DeserializeOwned + Debug,
140 F: Serialize,
141 {
142 let request = self.openai_request(Method::POST, route, |request| request.json(json));
143 Self::resolve_response(request).await
144 }
145
146 pub(crate) async fn post_form<T>(&self, route: &str, form_data: Form) -> OpenAIResponse<T>
147 where
148 T: DeserializeOwned + Debug,
149 {
150 let request =
151 self.openai_request(Method::POST, route, |request| request.multipart(form_data));
152 Self::resolve_response(request).await
153 }
154
155 pub(crate) async fn post_form_with_text_response(
156 &self,
157 route: &str,
158 form_data: Form,
159 ) -> OpenAIResponse<String> {
160 let request =
161 self.openai_request(Method::POST, route, |request| request.multipart(form_data));
162 Self::resolve_text_response(request).await
163 }
164
165 pub(crate) async fn post_with_file_response<T>(
166 &self,
167 route: &str,
168 json: &T,
169 filename: &str,
170 ) -> OpenAIResponse<()>
171 where
172 T: Serialize,
173 {
174 let request = self.openai_request(Method::POST, route, |request| request.json(json));
175 Self::resolve_file_response(request, filename).await
176 }
177
178 pub(crate) async fn post_stream<T, F>(
179 &self,
180 route: &str,
181 json: &F,
182 ) -> Pin<Box<dyn Stream<Item = OpenAIResponse<T>> + Send>>
183 where
184 T: DeserializeOwned + Debug + Send + 'static,
185 F: Serialize,
186 {
187 let event_source = self
188 .openai_request(Method::POST, route, |request| request.json(json))
189 .eventsource()
190 .unwrap();
191 OpenAI::stream_sse(event_source).await
192 }
193
194 pub(crate) async fn delete<T, F>(&self, route: &str, json: &F) -> OpenAIResponse<T>
195 where
196 T: DeserializeOwned + Debug,
197 F: Serialize,
198 {
199 let request = self.openai_request(Method::DELETE, route, |request| request.json(json));
200 Self::resolve_response(request).await
201 }
202
203 async fn stream_sse<T>(
204 mut event_source: EventSource,
205 ) -> Pin<Box<dyn Stream<Item = OpenAIResponse<T>> + Send>>
206 where
207 T: DeserializeOwned + Debug + Send + 'static,
208 {
209 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<OpenAIResponse<T>>();
210
211 tokio::spawn(async move {
212 while let Some(evt) = event_source.next().await {
213 match evt {
214 Err(e) => {
215 if tx
216 .send(Err(OpenAIError::StreamError(e.to_string())))
217 .is_err()
218 {
219 break;
220 }
221 }
222 Ok(evt) => match evt {
223 Event::Message(message) => {
224 if message.data == "[DONE]" {
225 break;
226 }
227
228 let response = match serde_json::from_str::<T>(&message.data) {
229 Err(e) => Err(OpenAIError::JSONDeserialize(e)),
230 Ok(output) => Ok(output),
231 };
232
233 if tx.send(response).is_err() {
234 break;
235 }
236 }
237 Event::Open => continue,
238 },
239 }
240 }
241
242 event_source.close();
243 });
244
245 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
246 }
247
248 pub fn audio(&self) -> audio::Audio {
249 audio::Audio::new(self)
250 }
251
252 pub fn chat(&self) -> chat::Chat {
253 chat::Chat::new(self)
254 }
255
256 pub fn completions(&self) -> completions::Completions {
257 completions::Completions::new(self)
258 }
259
260 pub fn edits(&self) -> edits::Edits {
261 edits::Edits::new(self)
262 }
263
264 pub fn embeddings(&self) -> embeddings::Embeddings {
265 embeddings::Embeddings::new(self)
266 }
267
268 pub fn engines(&self) -> engines::Engines {
269 engines::Engines::new(self)
270 }
271
272 pub fn files(&self) -> files::Files {
273 files::Files::new(self)
274 }
275
276 pub fn fine_tunes(&self) -> fine_tunes::FineTunes {
277 fine_tunes::FineTunes::new(self)
278 }
279
280 pub fn images(&self) -> images::Images {
281 images::Images::new(self)
282 }
283
284 pub fn models(&self) -> models::Models {
285 models::Models::new(self)
286 }
287
288 pub fn moderations(&self) -> moderations::Moderations {
289 moderations::Moderations::new(self)
290 }
291}