use std::{future::Future, pin::Pin, sync::Arc};
use reqwest::{Request, Response};
use crate::{error::OpenAIError, retry::OpenAIRetry};
// Boxed-future and factory-closure aliases, declared once per target family.
// The native variants carry `Send` (and `Sync` for the closure) so values of
// these types may cross threads; the `wasm` variants omit those auto-trait bounds.

// ---- Native (non-wasm) targets ----

/// Boxed, type-erased future that yields a freshly built [`Request`].
#[cfg(not(target_family = "wasm"))]
type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + Send + 'static>>;

/// Boxed, type-erased future that resolves to the [`Response`] of an executed request.
#[cfg(not(target_family = "wasm"))]
pub(crate) type HttpFuture =
    Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;

/// Type-erased closure that returns a new [`RequestFuture`] each time it is called.
#[cfg(not(target_family = "wasm"))]
type RequestFn = dyn Fn() -> RequestFuture + Send + Sync + 'static;

// ---- wasm targets: same shapes, without `Send` / `Sync` ----

/// Boxed, type-erased future that yields a freshly built [`Request`].
#[cfg(target_family = "wasm")]
type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + 'static>>;

/// Boxed, type-erased future that resolves to the [`Response`] of an executed request.
#[cfg(target_family = "wasm")]
pub(crate) type HttpFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;

/// Type-erased closure that returns a new [`RequestFuture`] each time it is called.
#[cfg(target_family = "wasm")]
type RequestFn = dyn Fn() -> RequestFuture + 'static;
/// Marker trait (only with the `middleware` feature) that bundles the lifetime
/// and auto-trait bounds required on the current target.
///
/// On native targets it implies `Send + Sync + 'static`. Every type meeting
/// those bounds gets it through the blanket impl below.
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
pub trait MiddlewareInput: Send + Sync + 'static {}

#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<T: Send + Sync + 'static> MiddlewareInput for T {}

/// Marker trait (only with the `middleware` feature) that bundles the lifetime
/// and auto-trait bounds required on the current target.
///
/// On `wasm` targets only `'static` is required. Every such type gets it
/// through the blanket impl below.
#[cfg(all(feature = "middleware", target_family = "wasm"))]
pub trait MiddlewareInput: 'static {}

#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<T: 'static> MiddlewareInput for T {}
/// Reusable recipe for building a [`Request`].
///
/// Every call to [`HttpRequestFactory::build`] invokes the stored closure again,
/// so the same factory can produce a new request as many times as needed.
/// Cloning only bumps the reference count of the shared closure.
#[derive(Clone)]
pub struct HttpRequestFactory {
    /// Shared, type-erased closure returning a boxed request future per call.
    make_request: Arc<RequestFn>,
}

impl std::fmt::Debug for HttpRequestFactory {
    /// Prints only the type name; the wrapped closure has no useful representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("HttpRequestFactory");
        builder.finish_non_exhaustive()
    }
}
impl HttpRequestFactory {
#[cfg(not(target_family = "wasm"))]
pub fn new<F, Fut>(make_request: F) -> Self
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Request, OpenAIError>> + Send + 'static,
{
Self {
make_request: Arc::new(move || Box::pin(make_request())),
}
}
#[cfg(target_family = "wasm")]
pub fn new<F, Fut>(make_request: F) -> Self
where
F: Fn() -> Fut + 'static,
Fut: Future<Output = Result<Request, OpenAIError>> + 'static,
{
Self {
make_request: Arc::new(move || Box::pin(make_request())),
}
}
pub async fn build(&self) -> Result<Request, OpenAIError> {
(self.make_request)().await
}
}
/// Object-safe abstraction over sending an HTTP request and obtaining its response.
///
/// Executors receive an [`HttpRequestFactory`] rather than a built request, so
/// they build the request themselves and may rebuild it if needed. On native
/// targets the trait requires `Send + Sync`, so an `Arc<dyn HttpExecutor>` can be
/// shared between threads.
#[cfg(not(target_family = "wasm"))]
pub trait HttpExecutor: Send + Sync {
    /// Returns a future that executes `request` and resolves to its response.
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
}

/// Object-safe abstraction over sending an HTTP request and obtaining its response.
///
/// Executors receive an [`HttpRequestFactory`] rather than a built request, so
/// they build the request themselves and may rebuild it if needed. On `wasm`
/// targets no `Send` / `Sync` supertraits are imposed.
#[cfg(target_family = "wasm")]
pub trait HttpExecutor {
    /// Returns a future that executes `request` and resolves to its response.
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
}
/// [`tower::Service`] that builds a request from an [`HttpRequestFactory`] and
/// sends it with a [`reqwest::Client`].
///
/// `Default` wraps `reqwest::Client::default()`.
#[derive(Clone, Debug, Default)]
pub struct ReqwestService {
    /// Client used to send every request handled by this service.
    client: reqwest::Client,
}

impl ReqwestService {
    /// Creates a service that sends its requests through `client`.
    pub fn new(client: reqwest::Client) -> Self {
        ReqwestService { client }
    }
}
impl tower::Service<HttpRequestFactory> for ReqwestService {
    type Response = Response;
    type Error = OpenAIError;
    type Future = HttpFuture;

    /// Always reports readiness; this layer applies no back-pressure.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        use std::task::Poll;
        Poll::Ready(Ok(()))
    }

    /// Builds the request from `request`, then sends it with a clone of the client.
    ///
    /// Build failures are propagated unchanged; send failures are wrapped in
    /// [`OpenAIError::Reqwest`].
    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
        // Owned client handle so the returned future does not borrow `self`.
        let http = self.client.clone();
        let send = async move {
            let built = request.build().await?;
            http.execute(built).await.map_err(OpenAIError::Reqwest)
        };
        Box::pin(send)
    }
}
/// [`HttpExecutor`] backed by a [`ReqwestService`] wrapped in the crate's retry layer.
#[derive(Clone, Debug)]
pub(crate) struct ReqwestExecutor {
    /// The retry-wrapped reqwest service stack.
    service: OpenAIRetry<ReqwestService>,
}

impl ReqwestExecutor {
    /// Wraps `client` in a [`ReqwestService`] layered with a default-configured
    /// `OpenAIRetryLayer`.
    pub(crate) fn new(client: reqwest::Client) -> Self {
        let leaf = ReqwestService::new(client);
        let service = tower::ServiceBuilder::new()
            .layer(crate::retry::OpenAIRetryLayer::default())
            .service(leaf);
        Self { service }
    }
}
impl HttpExecutor for ReqwestExecutor {
fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
use tower::ServiceExt;
let service = self.service.clone();
Box::pin(async move { service.oneshot(request).await })
}
}
/// [`HttpExecutor`] adapter around an arbitrary tower service (only with the
/// `middleware` feature).
#[cfg(feature = "middleware")]
pub(crate) struct TowerExecutor<S> {
    /// Wrapped service; the `HttpExecutor` impls require it to be `Clone`.
    service: S,
}

#[cfg(feature = "middleware")]
impl<S> TowerExecutor<S> {
    /// Wraps `service` without imposing bounds; the `HttpExecutor` impls add
    /// the bounds they need.
    pub(crate) fn new(service: S) -> Self {
        TowerExecutor { service }
    }
}
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<OpenAIError> + Send + Sync + 'static,
{
    /// Drives `request` through a clone of the wrapped service (ready check
    /// followed by one call) and converts its error into [`OpenAIError`].
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        use tower::ServiceExt;
        // Owned clone so the boxed future is `'static` and independent of `self`.
        let svc = self.service.clone();
        Box::pin(async move {
            let outcome = svc.oneshot(request).await;
            outcome.map_err(Into::<OpenAIError>::into)
        })
    }
}
#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
    S::Future: 'static,
    S::Error: Into<OpenAIError> + 'static,
{
    /// Drives `request` through a clone of the wrapped service (ready check
    /// followed by one call) and converts its error into [`OpenAIError`].
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        use tower::ServiceExt;
        // Owned clone so the boxed future is `'static` and independent of `self`.
        let svc = self.service.clone();
        Box::pin(async move {
            let outcome = svc.oneshot(request).await;
            outcome.map_err(Into::<OpenAIError>::into)
        })
    }
}
/// Reference-counted, type-erased handle to an [`HttpExecutor`].
///
/// On native targets the trait's `Send + Sync` supertraits make this handle
/// shareable across threads.
pub(crate) type SharedExecutor = Arc<dyn HttpExecutor + 'static>;