use reqwest::multipart::Form;
use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response};
use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::env;
use std::env::VarError;
use std::sync::{LazyLock, RwLock};
pub mod chat;
pub mod completions;
pub mod edits;
pub mod embeddings;
pub mod files;
pub mod models;
pub mod moderations;
/// Base URL of the official OpenAI REST API, including its trailing `/`.
///
/// Used whenever no other base URL is supplied (e.g. an empty `base_url` passed to
/// `Credentials::new`, or an unset `OPENAI_BASE_URL` in `Credentials::from_env`).
pub static DEFAULT_BASE_URL: LazyLock<String> =
    LazyLock::new(|| "https://api.openai.com/v1/".to_owned());
/// Process-wide fallback credentials, used by the request helpers when they are given `None`.
///
/// Initialized lazily from the environment (`Credentials::from_env`) on first access, so a
/// missing `OPENAI_KEY` only panics once a request actually needs the defaults. The deprecated
/// `set_key` / `set_base_url` functions mutate it through the `RwLock`.
static DEFAULT_CREDENTIALS: LazyLock<RwLock<Credentials>> = LazyLock::new(|| {
    let from_environment = Credentials::from_env();
    RwLock::new(from_environment)
});
/// API key and base URL used to address and authenticate requests.
///
/// The request helpers prefix every route with `base_url` and send `api_key` as a
/// `Bearer` token in the `Authorization` header. Fields are private; build values with
/// [`Credentials::new`] or [`Credentials::from_env`], which normalize `base_url` to end in `/`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Credentials {
/// Secret key sent as `Authorization: Bearer <api_key>`.
api_key: String,
/// URL prefix that routes are appended to directly, so it must end with `/`.
base_url: String,
}
impl Credentials {
pub fn new(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
let base_url = base_url.into();
let base_url = if base_url.is_empty() {
DEFAULT_BASE_URL.clone()
} else {
parse_base_url(base_url)
};
Self {
api_key: api_key.into(),
base_url,
}
}
pub fn from_env() -> Credentials {
let api_key = env::var("OPENAI_KEY").unwrap();
let base_url_unparsed = env::var("OPENAI_BASE_URL").unwrap_or_else(|e| match e {
VarError::NotPresent => DEFAULT_BASE_URL.clone(),
VarError::NotUnicode(v) => panic!("OPENAI_BASE_URL is not unicode: {v:#?}"),
});
let base_url = parse_base_url(base_url_unparsed);
Credentials { api_key, base_url }
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn base_url(&self) -> &str {
&self.base_url
}
}
/// Error type for every API helper in this module.
///
/// It is deserialized from the `error` object of an API response body, see `ApiResponse::Err`.
/// Local transport and I/O failures are also wrapped in it via the `From` impls below, with
/// `error_type` set to `"reqwest"` or `"io"`. `Display` prints only `message`.
#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct OpenAiError {
/// Human-readable error description.
pub message: String,
/// Error category; read from the JSON key `type` (a Rust keyword, hence the rename).
#[serde(rename = "type")]
pub error_type: String,
/// Optional `param` value reported by the API; `None` for locally constructed errors.
pub param: Option<String>,
/// Optional `code` value reported by the API; `None` for locally constructed errors.
pub code: Option<String>,
}
/// Optional pagination parameters for list-style requests.
///
/// Every field is omitted from the serialized output when it is `None`, so
/// `RequestPagination::default()` serializes to no parameters at all.
/// NOTE(review): presumably sent as a query string (e.g. via `openai_get_with_query`); confirm at call sites.
#[derive(Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct RequestPagination {
/// Maximum number of items to request.
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u32>,
/// Cursor value serialized as `after`.
#[serde(skip_serializing_if = "Option::is_none")]
pub after: Option<String>,
/// Sort order, serialized as `"asc"` or `"desc"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub order: Option<RequestOrder>,
}
/// Sort direction for paginated requests.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub enum RequestOrder {
/// Serialized as `"asc"`.
#[serde(rename = "asc")]
Ascending,
/// Serialized as `"desc"`.
#[serde(rename = "desc")]
Descending,
}
impl OpenAiError {
fn new(message: String, error_type: String) -> OpenAiError {
OpenAiError {
message,
error_type,
param: None,
code: None,
}
}
}
/// Displays only the human-readable `message`; formatter width/fill flags are ignored.
impl std::fmt::Display for OpenAiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

/// Allows `OpenAiError` to be used with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for OpenAiError {}
/// Raw shape of an API response body: either an `{"error": {...}}` object or a success payload `T`.
///
/// `openai_request_json` converts this into `Result<T, OpenAiError>`.
#[derive(Deserialize, Clone)]
// Untagged: serde tries the variants in declaration order, so `Err` must stay first.
// Otherwise a permissive `T` could swallow error bodies.
#[serde(untagged)]
pub enum ApiResponse<T> {
/// Body containing an `error` object.
Err { error: OpenAiError },
/// Any body that deserializes as `T`.
Ok(T),
}
/// Token counts deserialized from a response's `prompt_tokens`, `completion_tokens`
/// and `total_tokens` fields.
#[derive(Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
/// Result type returned by the request helpers in this module.
pub type ApiResponseOrError<T> = Result<T, OpenAiError>;

/// Wraps an HTTP-client (transport or decoding) failure; `error_type` becomes `"reqwest"`.
impl From<reqwest::Error> for OpenAiError {
    fn from(value: reqwest::Error) -> Self {
        let message = value.to_string();
        Self::new(message, String::from("reqwest"))
    }
}

/// Wraps a local I/O failure; `error_type` becomes `"io"`.
impl From<std::io::Error> for OpenAiError {
    fn from(value: std::io::Error) -> Self {
        let message = value.to_string();
        Self::new(message, String::from("io"))
    }
}
async fn openai_request_json<F, T>(
method: Method,
route: &str,
builder: F,
credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
F: FnOnce(RequestBuilder) -> RequestBuilder,
T: DeserializeOwned,
{
let api_response = openai_request(method, route, builder, credentials_opt)
.await?
.json()
.await?;
match api_response {
ApiResponse::Ok(t) => Ok(t),
ApiResponse::Err { error } => Err(error),
}
}
/// Builds and sends an authenticated request to `{base_url}{route}`.
///
/// `builder` customizes the request (query, JSON body, multipart, ...) before the
/// `Authorization: Bearer` header is attached. When `credentials_opt` is `None`,
/// the process-wide `DEFAULT_CREDENTIALS` are used.
///
/// # Errors
///
/// Returns the transport error (wrapped as `OpenAiError`) if sending fails.
async fn openai_request<F>(
    method: Method,
    route: &str,
    builder: F,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<Response>
where
    F: FnOnce(RequestBuilder) -> RequestBuilder,
{
    let client = Client::new();
    // The read guard is a temporary inside the closure, so it is released before any `.await`.
    let credentials =
        credentials_opt.unwrap_or_else(|| DEFAULT_CREDENTIALS.read().unwrap().clone());
    let url = format!("{}{route}", credentials.base_url);
    let bearer = format!("Bearer {}", credentials.api_key);
    let request = builder(client.request(method, url)).header(AUTHORIZATION, bearer);
    Ok(request.send().await?)
}
/// Builds an authenticated request to `{base_url}{route}` and turns it into an [`EventSource`].
///
/// Request construction mirrors `openai_request`: `builder` runs first, then the
/// `Authorization: Bearer` header is attached. `DEFAULT_CREDENTIALS` is used when
/// `credentials_opt` is `None`. Nothing in the body awaits; it is `async`, presumably for
/// call-site symmetry with the other helpers.
///
/// # Errors
///
/// Propagates the [`CannotCloneRequestError`] returned by `RequestBuilderExt::eventsource`.
async fn openai_request_stream<F>(
    method: Method,
    route: &str,
    builder: F,
    credentials_opt: Option<Credentials>,
) -> Result<EventSource, CannotCloneRequestError>
where
    F: FnOnce(RequestBuilder) -> RequestBuilder,
{
    let client = Client::new();
    let credentials =
        credentials_opt.unwrap_or_else(|| DEFAULT_CREDENTIALS.read().unwrap().clone());
    let url = format!("{}{route}", credentials.base_url);
    let bearer = format!("Bearer {}", credentials.api_key);
    builder(client.request(method, url))
        .header(AUTHORIZATION, bearer)
        .eventsource()
}
/// Sends a plain `GET` to `route` and decodes the JSON response as `T`.
async fn openai_get<T>(route: &str, credentials_opt: Option<Credentials>) -> ApiResponseOrError<T>
where
    T: DeserializeOwned,
{
    // No customization needed: pass the builder through unchanged.
    openai_request_json(Method::GET, route, std::convert::identity, credentials_opt).await
}
async fn openai_get_with_query<T, Query>(
route: &str,
query: &Query,
credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
T: DeserializeOwned,
Query: Serialize + ?Sized,
{
openai_request_json(
Method::GET,
route,
|request| request.query(query),
credentials_opt,
)
.await
}
/// Sends a plain `DELETE` to `route` and decodes the JSON response as `T`.
async fn openai_delete<T>(
    route: &str,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
    T: DeserializeOwned,
{
    // No customization needed: pass the builder through unchanged.
    openai_request_json(Method::DELETE, route, std::convert::identity, credentials_opt).await
}
async fn openai_post<J, T>(
route: &str,
json: &J,
credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
J: Serialize + ?Sized,
T: DeserializeOwned,
{
openai_request_json(
Method::POST,
route,
|request| request.json(json),
credentials_opt,
)
.await
}
async fn openai_post_multipart<T>(
route: &str,
form: Form,
credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
T: DeserializeOwned,
{
openai_request_json(
Method::POST,
route,
|request| request.multipart(form),
credentials_opt,
)
.await
}
/// Replaces the API key in the process-wide default credentials.
///
/// Affects only requests made without explicit [`Credentials`].
#[deprecated(
    since = "1.0.0-alpha.16",
    note = "use the `Credentials` struct instead"
)]
pub fn set_key(value: String) {
    DEFAULT_CREDENTIALS.write().unwrap().api_key = value;
}
/// Replaces the base URL in the process-wide default credentials.
///
/// The value is normalized to end with `/`. An empty string is ignored and the current
/// base URL is kept. Note that this differs from `Credentials::new`, where an empty
/// base URL selects `DEFAULT_BASE_URL`.
#[deprecated(
    since = "1.0.0-alpha.16",
    note = "use the `Credentials` struct instead"
)]
pub fn set_base_url(value: String) {
    if !value.is_empty() {
        // Normalize before taking the write lock so the lock is held only for the store.
        let normalized = parse_base_url(value);
        DEFAULT_CREDENTIALS.write().unwrap().base_url = normalized;
    }
}
/// Ensures `value` ends with exactly one added `/` if it had none, so routes can be appended directly.
///
/// An empty input becomes `"/"`; callers that want a default for empty input must handle it first.
fn parse_base_url(value: String) -> String {
    if value.ends_with('/') {
        return value;
    }
    let mut normalized = value;
    normalized.push('/');
    normalized
}
// Shared test fixtures; `pub` so the endpoint submodules' tests can reference them.
#[cfg(test)]
pub mod tests {
/// Model name for tests; NOTE(review): presumably used by the legacy (non-chat) completions tests — confirm in submodules.
pub const DEFAULT_LEGACY_MODEL: &str = "gpt-3.5-turbo-instruct";
}