// async_openai/executor.rs

use std::{future::Future, pin::Pin, sync::Arc};

use reqwest::{Request, Response};

use crate::{error::OpenAIError, retry::OpenAIRetry};
6
7#[cfg(not(target_family = "wasm"))]
8type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + Send + 'static>>;
9#[cfg(target_family = "wasm")]
10type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + 'static>>;
11
12#[cfg(not(target_family = "wasm"))]
13pub(crate) type HttpFuture =
14    Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16pub(crate) type HttpFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
18#[cfg(not(target_family = "wasm"))]
19type RequestFn = dyn Fn() -> RequestFuture + Send + Sync + 'static;
20#[cfg(target_family = "wasm")]
21type RequestFn = dyn Fn() -> RequestFuture + 'static;
22
/// Bound for the input types accepted by the `*_byot` macro methods when the
/// `middleware` feature is enabled.
///
/// This is a marker trait with a blanket implementation: every type that is
/// `Send + Sync + 'static` implements it automatically, so it never has to be
/// implemented by hand. (Presumably the bounds mirror what
/// `HttpRequestFactory::new` requires of a closure capturing the input —
/// NOTE(review): confirm against the macro definitions.)
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
pub trait MiddlewareInput: Send + Sync + 'static {}
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<T> MiddlewareInput for T where T: Send + Sync + 'static {}

/// Bound for the input types accepted by the `*_byot` macro methods when the
/// `middleware` feature is enabled (wasm targets).
///
/// Blanket-implemented for every `'static` type; unlike native targets, wasm
/// does not require the input to be `Send + Sync`.
#[cfg(all(feature = "middleware", target_family = "wasm"))]
pub trait MiddlewareInput: 'static {}
#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<T> MiddlewareInput for T where T: 'static {}
33
34/// Cheaply cloneable request factory used to rebuild a request on demand.
35///
36/// This is the key boundary for middleware support:
37/// - the client captures the request inputs once
38/// - tower layers may clone the factory freely
39/// - retries rebuild the request instead of trying to clone an already-built
40///   `reqwest::Request`
41///
42/// The `Arc` is intentional. `tower::retry` needs to be able to clone the
43/// request handle without forcing the payload itself to be eagerly copied.
44/// The factory handle is cheap to clone; the request is only rebuilt when
45/// `build()` is actually called.
46#[derive(Clone)]
47pub struct HttpRequestFactory {
48    make_request: Arc<RequestFn>,
49}
50
51impl std::fmt::Debug for HttpRequestFactory {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("HttpRequestFactory").finish_non_exhaustive()
54    }
55}
56
57impl HttpRequestFactory {
58    /// Create a replayable request factory from an async request builder.
59    ///
60    /// The closure is stored behind an `Arc` so this value stays cheap to
61    /// clone when it is passed through tower layers.
62    #[cfg(not(target_family = "wasm"))]
63    pub fn new<F, Fut>(make_request: F) -> Self
64    where
65        F: Fn() -> Fut + Send + Sync + 'static,
66        Fut: Future<Output = Result<Request, OpenAIError>> + Send + 'static,
67    {
68        Self {
69            make_request: Arc::new(move || Box::pin(make_request())),
70        }
71    }
72
73    #[cfg(target_family = "wasm")]
74    pub fn new<F, Fut>(make_request: F) -> Self
75    where
76        F: Fn() -> Fut + 'static,
77        Fut: Future<Output = Result<Request, OpenAIError>> + 'static,
78    {
79        Self {
80            make_request: Arc::new(move || Box::pin(make_request())),
81        }
82    }
83
84    /// Rebuild the request for the current attempt.
85    ///
86    /// This is what makes retries possible for non-cloneable bodies. The
87    /// request is not cloned after construction; instead, the original request
88    /// inputs are replayed to produce a fresh `reqwest::Request` each time.
89    pub async fn build(&self) -> Result<Request, OpenAIError> {
90        (self.make_request)().await
91    }
92}
93
94/// Minimal request execution interface used by `Client`.
95///
96/// The executor sees the replayable factory rather than a built request so it
97/// can decide when to rebuild and send. That keeps the retry decision close to
98/// execution and avoids forcing every call site to know whether a
99/// request body is cloneable.
100#[cfg(not(target_family = "wasm"))]
101pub trait HttpExecutor: Send + Sync {
102    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
103}
104
105#[cfg(target_family = "wasm")]
106pub trait HttpExecutor {
107    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
108}
109
110/// Default tower-compatible service backed directly by `reqwest::Client`.
111///
112/// Users can layer retry, timeout, rate limiting, tracing, or any other tower
113/// middleware around this service and then install the composed service with
114/// `Client::with_http_service(...)`.
115#[derive(Clone, Debug, Default)]
116pub struct ReqwestService {
117    client: reqwest::Client,
118}
119
120impl ReqwestService {
121    pub fn new(client: reqwest::Client) -> Self {
122        Self { client }
123    }
124}
125
126impl tower::Service<HttpRequestFactory> for ReqwestService {
127    type Response = Response;
128    type Error = OpenAIError;
129    type Future = HttpFuture;
130
131    fn poll_ready(
132        &mut self,
133        _cx: &mut std::task::Context<'_>,
134    ) -> std::task::Poll<Result<(), Self::Error>> {
135        std::task::Poll::Ready(Ok(()))
136    }
137
138    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
139        let client = self.client.clone();
140        Box::pin(async move {
141            // This is the plain reqwest transport path. It intentionally does
142            // nothing beyond rebuilding the request and executing it.
143            let request = request.build().await?;
144            client.execute(request).await.map_err(OpenAIError::Reqwest)
145        })
146    }
147}
148
149#[derive(Clone, Debug)]
150pub(crate) struct ReqwestExecutor {
151    service: OpenAIRetry<ReqwestService>,
152}
153
154impl ReqwestExecutor {
155    pub(crate) fn new(client: reqwest::Client) -> Self {
156        Self {
157            service: tower::ServiceBuilder::new()
158                .layer(crate::retry::OpenAIRetryLayer::default())
159                .service(ReqwestService::new(client)),
160        }
161    }
162}
163
164impl HttpExecutor for ReqwestExecutor {
165    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
166        use tower::ServiceExt;
167
168        let service = self.service.clone();
169        Box::pin(async move { service.oneshot(request).await })
170    }
171}
172
/// Adapter that lets a user-supplied tower stack act as an `HttpExecutor`.
///
/// This wrapper keeps `Client` from becoming generic over the service type
/// `S`. Every policy decision (retry, timeout, ...) lives in the stack itself.
#[cfg(feature = "middleware")]
pub(crate) struct TowerExecutor<S> {
    service: S,
}

#[cfg(feature = "middleware")]
impl<S> TowerExecutor<S> {
    /// Stores `service` as-is; it is neither inspected nor modified.
    pub(crate) fn new(service: S) -> Self {
        TowerExecutor { service }
    }
}
187
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<OpenAIError> + Send + Sync + 'static,
{
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        // A fresh clone per request lets `oneshot` handle readiness, so the
        // client never manages readiness or buffering itself. The stack alone
        // decides how to use the replayable factory.
        let stack = self.service.clone();
        Box::pin(async move {
            tower::ServiceExt::oneshot(stack, request)
                .await
                .map_err(Into::<OpenAIError>::into)
        })
    }
}
207
#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
    S::Future: 'static,
    S::Error: Into<OpenAIError> + 'static,
{
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        // wasm: no `Send`/`Sync` requirements. Clone the stack and drive it
        // exactly once with `oneshot`, converting the error into `OpenAIError`.
        let stack = self.service.clone();
        Box::pin(async move {
            tower::ServiceExt::oneshot(stack, request)
                .await
                .map_err(Into::<OpenAIError>::into)
        })
    }
}
222
223pub(crate) type SharedExecutor = Arc<dyn HttpExecutor>;