1use std::sync::Mutex;
2
3use reqwest::multipart::Form;
4use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response};
5use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
8pub mod chat;
9pub mod completions;
10pub mod edits;
11pub mod embeddings;
12pub mod files;
13pub mod models;
14pub mod moderations;
15
/// Base URL that request routes are appended to. An empty string means "not configured yet";
/// `get_base_url` fills in the default the first time it is called.
static BASE_URL: Mutex<String> = Mutex::new(String::new());
/// API key sent as a bearer token on each request; empty until `set_key` is called.
static API_KEY: Mutex<String> = Mutex::new(String::new());
18
19#[derive(Deserialize, Debug, Clone)]
20pub struct OpenAiError {
21 pub message: String,
22 #[serde(rename = "type")]
23 pub error_type: String,
24 pub param: Option<String>,
25 pub code: Option<String>,
26}
27
28impl OpenAiError {
29 fn new(message: String, error_type: String) -> OpenAiError {
30 OpenAiError {
31 message,
32 error_type,
33 param: None,
34 code: None,
35 }
36 }
37}
38
39impl std::fmt::Display for OpenAiError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.message)
42 }
43}
44
45impl std::error::Error for OpenAiError {}
46
47#[derive(Deserialize, Clone)]
48#[serde(untagged)]
49pub enum ApiResponse<T> {
50 Ok(T),
51 Err { error: OpenAiError },
52}
53
54#[derive(Deserialize, Clone, Copy, Debug)]
55pub struct Usage {
56 pub prompt_tokens: u32,
57 pub completion_tokens: u32,
58 pub total_tokens: u32,
59}
60
61pub type ApiResponseOrError<T> = Result<T, OpenAiError>;
62
63impl From<reqwest::Error> for OpenAiError {
64 fn from(value: reqwest::Error) -> Self {
65 OpenAiError::new(value.to_string(), "reqwest".to_string())
66 }
67}
68
69impl From<std::io::Error> for OpenAiError {
70 fn from(value: std::io::Error) -> Self {
71 OpenAiError::new(value.to_string(), "io".to_string())
72 }
73}
74
75async fn openai_request_json<F, T>(method: Method, route: &str, builder: F) -> ApiResponseOrError<T>
76where
77 F: FnOnce(RequestBuilder) -> RequestBuilder,
78 T: DeserializeOwned,
79{
80 let api_response = openai_request(method, route, builder).await?.json().await?;
81 match api_response {
82 ApiResponse::Ok(t) => Ok(t),
83 ApiResponse::Err { error } => Err(error),
84 }
85}
86
87async fn openai_request<F>(method: Method, route: &str, builder: F) -> ApiResponseOrError<Response>
88where
89 F: FnOnce(RequestBuilder) -> RequestBuilder,
90{
91 let client = Client::new();
92 let mut request = client.request(method, get_base_url().lock().unwrap().to_owned() + route);
93
94 request = builder(request);
95
96 let response = request
97 .header(AUTHORIZATION, format!("Bearer {}", API_KEY.lock().unwrap()))
98 .send()
99 .await?;
100 Ok(response)
101}
102
103async fn openai_request_stream<F>(
104 method: Method,
105 route: &str,
106 builder: F,
107) -> Result<EventSource, CannotCloneRequestError>
108where
109 F: FnOnce(RequestBuilder) -> RequestBuilder,
110{
111 let client = Client::new();
112 let mut request = client.request(method, get_base_url().lock().unwrap().to_owned() + route);
113
114 request = builder(request);
115
116 let stream = request
117 .header(AUTHORIZATION, format!("Bearer {}", API_KEY.lock().unwrap()))
118 .eventsource()?;
119
120 Ok(stream)
121}
122
123async fn openai_get<T>(route: &str) -> ApiResponseOrError<T>
124where
125 T: DeserializeOwned,
126{
127 openai_request_json(Method::GET, route, |request| request).await
128}
129
130async fn openai_delete<T>(route: &str) -> ApiResponseOrError<T>
131where
132 T: DeserializeOwned,
133{
134 openai_request_json(Method::DELETE, route, |request| request).await
135}
136
137async fn openai_post<J, T>(route: &str, json: &J) -> ApiResponseOrError<T>
138where
139 J: Serialize + ?Sized,
140 T: DeserializeOwned,
141{
142 openai_request_json(Method::POST, route, |request| request.json(json)).await
143}
144
145async fn openai_post_multipart<T>(route: &str, form: Form) -> ApiResponseOrError<T>
146where
147 T: DeserializeOwned,
148{
149 openai_request_json(Method::POST, route, |request| request.multipart(form)).await
150}
151
152pub fn set_key(value: String) {
167 *API_KEY.lock().unwrap() = value;
168}
169
170pub fn set_base_url(value: String) {
185 let base_url_mutex = get_base_url();
186 if value.is_empty() {
187 return;
188 }
189 let mut base_url = base_url_mutex.lock().unwrap();
190 *base_url = value;
191 if !base_url.ends_with('/') {
192 *base_url += "/";
193 }
194}
195
196fn get_base_url() -> &'static Mutex<String> {
199 let mut base_url = BASE_URL.lock().unwrap();
200 if base_url.is_empty() {
201 *base_url = String::from("https://api.openai.com/v1/");
202 }
203 &BASE_URL
204}
205
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Legacy completions model name; `pub`, presumably so tests in sibling modules can reuse
    /// it — verify against callers.
    pub const DEFAULT_LEGACY_MODEL: &str = "gpt-3.5-turbo-instruct";

    /// Serializes the tests in this module that read or modify the process-wide `BASE_URL`.
    ///
    /// The test harness runs tests on parallel threads. Without this lock, one test can
    /// observe the temporary `v2` URL set by another.
    /// NOTE(review): tests in other modules that read the base URL are not covered by it.
    static BASE_URL_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `BASE_URL_TEST_LOCK`, ignoring poisoning so that one failing test does not
    /// make the other base-URL tests panic too.
    fn lock_base_url_tests() -> std::sync::MutexGuard<'static, ()> {
        BASE_URL_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Copies the current base URL into an owned `String`.
    ///
    /// The value must be bound before asserting. Inside `assert_eq!`, a temporary
    /// `MutexGuard` lives until the comparison ends, so a failed assertion would panic while
    /// holding the `BASE_URL` lock and poison it for every later caller.
    fn current_base_url() -> String {
        get_base_url().lock().unwrap().clone()
    }

    #[test]
    fn test_get_base_url_default() {
        let _serial = lock_base_url_tests();

        assert_eq!(current_base_url(), "https://api.openai.com/v1/");

        // An empty value is ignored.
        set_base_url(String::new());
        assert_eq!(current_base_url(), "https://api.openai.com/v1/");

        // A missing trailing slash is appended.
        set_base_url(String::from("https://api.openai.com/v1"));
        assert_eq!(current_base_url(), "https://api.openai.com/v1/");
    }

    #[test]
    fn test_get_base_url_set() {
        let _serial = lock_base_url_tests();

        set_base_url(String::from("https://api.openai.com/v2/"));
        let observed = current_base_url();
        // Restore the default before asserting, so a failure cannot leave `v2` behind for
        // other tests.
        set_base_url(String::from("https://api.openai.com/v1"));

        assert_eq!(observed, "https://api.openai.com/v2/");
    }
}