openai_fork/
lib.rs

1use std::sync::Mutex;
2
3use reqwest::multipart::Form;
4use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response};
5use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
8pub mod chat;
9pub mod completions;
10pub mod edits;
11pub mod embeddings;
12pub mod files;
13pub mod models;
14pub mod moderations;
15
/// API key sent as the bearer token of every request. Empty until `set_key` is called.
static API_KEY: Mutex<String> = Mutex::new(String::new());

/// Base URL that request routes are appended to. An empty string means "not configured
/// yet"; `get_base_url` replaces it with the default on first access.
static BASE_URL: Mutex<String> = Mutex::new(String::new());
18
19#[derive(Deserialize, Debug, Clone)]
20pub struct OpenAiError {
21    pub message: String,
22    #[serde(rename = "type")]
23    pub error_type: String,
24    pub param: Option<String>,
25    pub code: Option<String>,
26}
27
28impl OpenAiError {
29    fn new(message: String, error_type: String) -> OpenAiError {
30        OpenAiError {
31            message,
32            error_type,
33            param: None,
34            code: None,
35        }
36    }
37}
38
39impl std::fmt::Display for OpenAiError {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.message)
42    }
43}
44
45impl std::error::Error for OpenAiError {}
46
47#[derive(Deserialize, Clone)]
48#[serde(untagged)]
49pub enum ApiResponse<T> {
50    Ok(T),
51    Err { error: OpenAiError },
52}
53
54#[derive(Deserialize, Clone, Copy, Debug)]
55pub struct Usage {
56    pub prompt_tokens: u32,
57    pub completion_tokens: u32,
58    pub total_tokens: u32,
59}
60
61pub type ApiResponseOrError<T> = Result<T, OpenAiError>;
62
63impl From<reqwest::Error> for OpenAiError {
64    fn from(value: reqwest::Error) -> Self {
65        OpenAiError::new(value.to_string(), "reqwest".to_string())
66    }
67}
68
69impl From<std::io::Error> for OpenAiError {
70    fn from(value: std::io::Error) -> Self {
71        OpenAiError::new(value.to_string(), "io".to_string())
72    }
73}
74
75async fn openai_request_json<F, T>(method: Method, route: &str, builder: F) -> ApiResponseOrError<T>
76where
77    F: FnOnce(RequestBuilder) -> RequestBuilder,
78    T: DeserializeOwned,
79{
80    let api_response = openai_request(method, route, builder).await?.json().await?;
81    match api_response {
82        ApiResponse::Ok(t) => Ok(t),
83        ApiResponse::Err { error } => Err(error),
84    }
85}
86
87async fn openai_request<F>(method: Method, route: &str, builder: F) -> ApiResponseOrError<Response>
88where
89    F: FnOnce(RequestBuilder) -> RequestBuilder,
90{
91    let client = Client::new();
92    let mut request = client.request(method, get_base_url().lock().unwrap().to_owned() + route);
93
94    request = builder(request);
95
96    let response = request
97        .header(AUTHORIZATION, format!("Bearer {}", API_KEY.lock().unwrap()))
98        .send()
99        .await?;
100    Ok(response)
101}
102
103async fn openai_request_stream<F>(
104    method: Method,
105    route: &str,
106    builder: F,
107) -> Result<EventSource, CannotCloneRequestError>
108where
109    F: FnOnce(RequestBuilder) -> RequestBuilder,
110{
111    let client = Client::new();
112    let mut request = client.request(method, get_base_url().lock().unwrap().to_owned() + route);
113
114    request = builder(request);
115
116    let stream = request
117        .header(AUTHORIZATION, format!("Bearer {}", API_KEY.lock().unwrap()))
118        .eventsource()?;
119
120    Ok(stream)
121}
122
123async fn openai_get<T>(route: &str) -> ApiResponseOrError<T>
124where
125    T: DeserializeOwned,
126{
127    openai_request_json(Method::GET, route, |request| request).await
128}
129
130async fn openai_delete<T>(route: &str) -> ApiResponseOrError<T>
131where
132    T: DeserializeOwned,
133{
134    openai_request_json(Method::DELETE, route, |request| request).await
135}
136
137async fn openai_post<J, T>(route: &str, json: &J) -> ApiResponseOrError<T>
138where
139    J: Serialize + ?Sized,
140    T: DeserializeOwned,
141{
142    openai_request_json(Method::POST, route, |request| request.json(json)).await
143}
144
145async fn openai_post_multipart<T>(route: &str, form: Form) -> ApiResponseOrError<T>
146where
147    T: DeserializeOwned,
148{
149    openai_request_json(Method::POST, route, |request| request.multipart(form)).await
150}
151
152/// Sets the key for all OpenAI API functions.
153///
154/// ## Examples
155///
156/// Use environment variable `OPENAI_KEY` defined from `.env` file:
157///
158/// ```rust
159/// use openai::set_key;
160/// use dotenvy::dotenv;
161/// use std::env;
162///
163/// dotenv().ok();
164/// set_key(env::var("OPENAI_KEY").unwrap());
165/// ```
166pub fn set_key(value: String) {
167    *API_KEY.lock().unwrap() = value;
168}
169
170/// Sets the base url for all OpenAI API functions.
171///
172/// ## Examples
173///
174/// Use environment variable `OPENAI_BASE_URL` defined from `.env` file:
175///
176/// ```rust
177/// use openai::set_base_url;
178/// use dotenvy::dotenv;
179/// use std::env;
180///
181/// dotenv().ok();
182/// set_base_url(env::var("OPENAI_BASE_URL").unwrap_or_default());
183/// ```
184pub fn set_base_url(value: String) {
185    let base_url_mutex = get_base_url();
186    if value.is_empty() {
187        return;
188    }
189    let mut base_url = base_url_mutex.lock().unwrap();
190    *base_url = value;
191    if !base_url.ends_with('/') {
192        *base_url += "/";
193    }
194}
195
196/// Returns the base url for all OpenAI API functions.
197/// Defaults to `https://api.openai.com/v1/`.
198fn get_base_url() -> &'static Mutex<String> {
199    let mut base_url = BASE_URL.lock().unwrap();
200    if base_url.is_empty() {
201        *base_url = String::from("https://api.openai.com/v1/");
202    }
203    &BASE_URL
204}
205
#[cfg(test)]
pub mod tests {
    use super::*;

    pub const DEFAULT_LEGACY_MODEL: &str = "gpt-3.5-turbo-instruct";

    /// Serializes the tests below. They all read and write the process-wide base URL,
    /// and the test harness runs tests on parallel threads by default, so without this
    /// lock one test can observe the URL another test has temporarily set.
    static BASE_URL_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the serialization lock. A poisoned lock (an earlier test panicked while
    /// holding it) is still usable because it guards no data of its own.
    fn serialize_base_url_tests() -> std::sync::MutexGuard<'static, ()> {
        BASE_URL_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Snapshot of the base URL currently in effect.
    fn current_base_url() -> String {
        get_base_url().lock().unwrap().to_owned()
    }

    #[test]
    fn test_get_base_url_default() {
        let _serialized = serialize_base_url_tests();

        assert_eq!(
            current_base_url(),
            String::from("https://api.openai.com/v1/")
        );

        // empty env var
        set_base_url(String::from(""));
        assert_eq!(
            current_base_url(),
            String::from("https://api.openai.com/v1/")
        );

        // appends slash
        set_base_url(String::from("https://api.openai.com/v1"));
        assert_eq!(
            current_base_url(),
            String::from("https://api.openai.com/v1/")
        );
    }

    #[test]
    fn test_get_base_url_set() {
        let _serialized = serialize_base_url_tests();

        set_base_url(String::from("https://api.openai.com/v2/"));
        assert_eq!(
            current_base_url(),
            String::from("https://api.openai.com/v2/")
        );
        // need this here to reset the base url for other tests
        set_base_url(String::from("https://api.openai.com/v1"));
    }
}
244}