use std::pin::Pin;
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};
use crate::{
config::{Config, OpenAIConfig},
error::{map_deserialization_error, OpenAIError, WrappedError},
file::Files,
image::Images,
moderation::Moderations,
Assistants, Audio, Batches, Chat, Completions, Embeddings, FineTuning, Models, Threads,
VectorStores,
};
/// Container for the HTTP client, API configuration and retry policy used to make API calls.
///
/// The resource accessors on this type (`models()`, `chat()`, `files()`, ...) each borrow the
/// client, so one `Client` can serve every API group.
#[derive(Debug, Clone)]
pub struct Client<C: Config> {
    /// HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Builds request URLs and supplies the headers and query parameters attached to requests.
    config: C,
    /// Retry policy for non-streaming requests; only rate-limit (429) responses whose error type
    /// is not `insufficient_quota` are retried.
    backoff: backoff::ExponentialBackoff,
}
impl Client<OpenAIConfig> {
    /// Creates a client with `OpenAIConfig::default()`, a new `reqwest::Client` and the
    /// default backoff policy.
    pub fn new() -> Self {
        // Same construction as `with_config`; delegating keeps the two in sync.
        Self::with_config(OpenAIConfig::default())
    }
}

/// Equivalent to [`Client::new`].
impl Default for Client<OpenAIConfig> {
    fn default() -> Self {
        Self::new()
    }
}
impl<C: Config> Client<C> {
pub fn build(
http_client: reqwest::Client,
config: C,
backoff: backoff::ExponentialBackoff,
) -> Self {
Self {
http_client,
config,
backoff,
}
}
pub fn with_config(config: C) -> Self {
Self {
http_client: reqwest::Client::new(),
config,
backoff: Default::default(),
}
}
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = http_client;
self
}
pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
self.backoff = backoff;
self
}
pub fn models(&self) -> Models<C> {
Models::new(self)
}
pub fn completions(&self) -> Completions<C> {
Completions::new(self)
}
pub fn chat(&self) -> Chat<C> {
Chat::new(self)
}
pub fn images(&self) -> Images<C> {
Images::new(self)
}
pub fn moderations(&self) -> Moderations<C> {
Moderations::new(self)
}
pub fn files(&self) -> Files<C> {
Files::new(self)
}
pub fn fine_tuning(&self) -> FineTuning<C> {
FineTuning::new(self)
}
pub fn embeddings(&self) -> Embeddings<C> {
Embeddings::new(self)
}
pub fn audio(&self) -> Audio<C> {
Audio::new(self)
}
pub fn assistants(&self) -> Assistants<C> {
Assistants::new(self)
}
pub fn threads(&self) -> Threads<C> {
Threads::new(self)
}
pub fn vector_stores(&self) -> VectorStores<C> {
VectorStores::new(self)
}
pub fn batches(&self) -> Batches<C> {
Batches::new(self)
}
pub fn config(&self) -> &C {
&self.config
}
pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};
self.execute(request_maker).await
}
pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
Q: Serialize + ?Sized,
{
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.query(query)
.headers(self.config.headers())
.build()?)
};
self.execute(request_maker).await
}
pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.delete(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};
self.execute(request_maker).await
}
pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};
self.execute_raw(request_maker).await
}
pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
where
I: Serialize,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?)
};
self.execute_raw(request_maker).await
}
pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
where
I: Serialize,
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?)
};
self.execute(request_maker).await
}
pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
where
reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
F: Clone,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(async_convert::TryFrom::try_from(form.clone()).await?)
.build()?)
};
self.execute_raw(request_maker).await
}
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
F: Clone,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(async_convert::TryFrom::try_from(form.clone()).await?)
.build()?)
};
self.execute(request_maker).await
}
async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
where
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let client = self.http_client.clone();
backoff::future::retry(self.backoff.clone(), || async {
let request = request_maker().await.map_err(backoff::Error::Permanent)?;
let response = client
.execute(request)
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
if status.as_u16() == 429
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
}
}
Ok(bytes)
})
.await
}
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let bytes = self.execute_raw(request_maker).await?;
let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
Ok(response)
}
pub(crate) async fn post_stream<I, O>(
&self,
path: &str,
request: I,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.eventsource()
.unwrap();
stream(event_source).await
}
pub(crate) async fn post_stream_mapped_raw_events<I, O>(
&self,
path: &str,
request: I,
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.eventsource()
.unwrap();
stream_mapped_raw_events(event_source, event_mapper).await
}
pub(crate) async fn _get_stream<Q, O>(
&self,
path: &str,
query: &Q,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
Q: Serialize + ?Sized,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.get(self.config.url(path))
.query(query)
.query(&self.config.query())
.headers(self.config.headers())
.eventsource()
.unwrap();
stream(event_source).await
}
}
pub(crate) async fn stream<O>(
mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
O: DeserializeOwned + std::marker::Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
break;
}
}
Ok(event) => match event {
Event::Message(message) => {
if message.data == "[DONE]" {
break;
}
let response = match serde_json::from_str::<O>(&message.data) {
Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
Ok(output) => Ok(output),
};
if let Err(_e) = tx.send(response) {
break;
}
}
Event::Open => continue,
},
}
}
event_source.close();
});
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
pub(crate) async fn stream_mapped_raw_events<O>(
mut event_source: EventSource,
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
O: DeserializeOwned + std::marker::Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
break;
}
}
Ok(event) => match event {
Event::Message(message) => {
let mut done = false;
if message.data == "[DONE]" {
done = true;
}
let response = event_mapper(message);
if let Err(_e) = tx.send(response) {
break;
}
if done {
break;
}
}
Event::Open => continue,
},
}
}
event_source.close();
});
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}