Skip to main content

async_openai/
executor.rs

1use std::{future::Future, pin::Pin, sync::Arc};
2
3use reqwest::{Request, Response};
4
5use crate::{error::OpenAIError, retry::OpenAIRetry};
6
7#[cfg(not(target_family = "wasm"))]
8type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + Send + 'static>>;
9#[cfg(target_family = "wasm")]
10type RequestFuture = Pin<Box<dyn Future<Output = Result<Request, OpenAIError>> + 'static>>;
11
12#[cfg(not(target_family = "wasm"))]
13pub(crate) type HttpFuture =
14    Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + Send + 'static>>;
15#[cfg(target_family = "wasm")]
16pub(crate) type HttpFuture = Pin<Box<dyn Future<Output = Result<Response, OpenAIError>> + 'static>>;
17
18#[cfg(not(target_family = "wasm"))]
19type RequestFn = dyn Fn() -> RequestFuture + Send + Sync + 'static;
20#[cfg(target_family = "wasm")]
21type RequestFn = dyn Fn() -> RequestFuture + 'static;
22
23/// Cheaply cloneable request factory used to rebuild a request on demand.
24///
25/// This is the key boundary for middleware support:
26/// - the client captures the request inputs once
27/// - tower layers may clone the factory freely
28/// - retries rebuild the request instead of trying to clone an already-built
29///   `reqwest::Request`
30///
31/// The `Arc` is intentional. `tower::retry` needs to be able to clone the
32/// request handle without forcing the payload itself to be eagerly copied.
33/// The factory handle is cheap to clone; the request is only rebuilt when
34/// `build()` is actually called.
35#[derive(Clone)]
36pub struct HttpRequestFactory {
37    make_request: Arc<RequestFn>,
38}
39
40impl std::fmt::Debug for HttpRequestFactory {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("HttpRequestFactory").finish_non_exhaustive()
43    }
44}
45
46impl HttpRequestFactory {
47    /// Create a replayable request factory from an async request builder.
48    ///
49    /// The closure is stored behind an `Arc` so this value stays cheap to
50    /// clone when it is passed through tower layers.
51    #[cfg(not(target_family = "wasm"))]
52    pub fn new<F, Fut>(make_request: F) -> Self
53    where
54        F: Fn() -> Fut + Send + Sync + 'static,
55        Fut: Future<Output = Result<Request, OpenAIError>> + Send + 'static,
56    {
57        Self {
58            make_request: Arc::new(move || Box::pin(make_request())),
59        }
60    }
61
62    #[cfg(target_family = "wasm")]
63    pub fn new<F, Fut>(make_request: F) -> Self
64    where
65        F: Fn() -> Fut + 'static,
66        Fut: Future<Output = Result<Request, OpenAIError>> + 'static,
67    {
68        Self {
69            make_request: Arc::new(move || Box::pin(make_request())),
70        }
71    }
72
73    /// Rebuild the request for the current attempt.
74    ///
75    /// This is what makes retries possible for non-cloneable bodies. The
76    /// request is not cloned after construction; instead, the original request
77    /// inputs are replayed to produce a fresh `reqwest::Request` each time.
78    pub async fn build(&self) -> Result<Request, OpenAIError> {
79        (self.make_request)().await
80    }
81}
82
83/// Minimal request execution interface used by `Client`.
84///
85/// The executor sees the replayable factory rather than a built request so it
86/// can decide when to rebuild and send. That keeps the retry decision close to
87/// execution and avoids forcing every call site to know whether a
88/// request body is cloneable.
89#[cfg(not(target_family = "wasm"))]
90pub trait HttpExecutor: Send + Sync {
91    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
92}
93
94#[cfg(target_family = "wasm")]
95pub trait HttpExecutor {
96    fn execute(&self, request: HttpRequestFactory) -> HttpFuture;
97}
98
99/// Default tower-compatible service backed directly by `reqwest::Client`.
100///
101/// Users can layer retry, timeout, rate limiting, tracing, or any other tower
102/// middleware around this service and then install the composed service with
103/// `Client::with_http_service(...)`.
104#[derive(Clone, Debug, Default)]
105pub struct ReqwestService {
106    client: reqwest::Client,
107}
108
109impl ReqwestService {
110    pub fn new(client: reqwest::Client) -> Self {
111        Self { client }
112    }
113}
114
115impl tower::Service<HttpRequestFactory> for ReqwestService {
116    type Response = Response;
117    type Error = OpenAIError;
118    type Future = HttpFuture;
119
120    fn poll_ready(
121        &mut self,
122        _cx: &mut std::task::Context<'_>,
123    ) -> std::task::Poll<Result<(), Self::Error>> {
124        std::task::Poll::Ready(Ok(()))
125    }
126
127    fn call(&mut self, request: HttpRequestFactory) -> Self::Future {
128        let client = self.client.clone();
129        Box::pin(async move {
130            // This is the plain reqwest transport path. It intentionally does
131            // nothing beyond rebuilding the request and executing it.
132            let request = request.build().await?;
133            client.execute(request).await.map_err(OpenAIError::Reqwest)
134        })
135    }
136}
137
138#[derive(Clone, Debug)]
139pub(crate) struct ReqwestExecutor {
140    service: OpenAIRetry<ReqwestService>,
141}
142
143impl ReqwestExecutor {
144    pub(crate) fn new(client: reqwest::Client) -> Self {
145        Self {
146            service: tower::ServiceBuilder::new()
147                .layer(crate::retry::OpenAIRetryLayer::default())
148                .service(ReqwestService::new(client)),
149        }
150    }
151}
152
153impl HttpExecutor for ReqwestExecutor {
154    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
155        use tower::ServiceExt;
156
157        let service = self.service.clone();
158        Box::pin(async move { service.oneshot(request).await })
159    }
160}
161
/// Adapter that lets `Client` run requests through a user-supplied tower stack.
///
/// All interesting policy decisions live in the stack itself. This wrapper
/// exists only so that `Client` does not have to become generic over the
/// service type `S`.
#[cfg(feature = "middleware")]
pub(crate) struct TowerExecutor<S> {
    /// The user-supplied tower service stack.
    service: S,
}

#[cfg(feature = "middleware")]
impl<S> TowerExecutor<S> {
    /// Wrap `service` as-is, without inspecting or altering it.
    pub(crate) fn new(service: S) -> Self {
        TowerExecutor { service }
    }
}
176
#[cfg(all(feature = "middleware", not(target_family = "wasm")))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + Send + Sync + 'static,
    S::Future: Send + 'static,
    S::Error: Into<OpenAIError> + Send + Sync + 'static,
{
    /// Drive `request` through a clone of the user stack with `oneshot`, and
    /// convert the stack's error type into `OpenAIError`.
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        use tower::ServiceExt;

        let stack = self.service.clone();
        Box::pin(async move {
            // `oneshot` waits for readiness and then calls the service once.
            // The client therefore never manages readiness or buffering; how
            // the replayable factory is used is entirely up to the stack.
            let outcome = stack.oneshot(request).await;
            outcome.map_err(Into::into)
        })
    }
}
196
#[cfg(all(feature = "middleware", target_family = "wasm"))]
impl<S> HttpExecutor for TowerExecutor<S>
where
    S: tower::Service<HttpRequestFactory, Response = Response> + Clone + 'static,
    S::Future: 'static,
    S::Error: Into<OpenAIError> + 'static,
{
    /// Wasm variant: run `request` through a clone of the user stack with
    /// `oneshot` and convert its error into `OpenAIError`. No `Send`/`Sync`
    /// bounds are required here.
    fn execute(&self, request: HttpRequestFactory) -> HttpFuture {
        use tower::ServiceExt;

        let stack = self.service.clone();
        Box::pin(async move {
            let outcome = stack.oneshot(request).await;
            outcome.map_err(Into::into)
        })
    }
}
211
212pub(crate) type SharedExecutor = Arc<dyn HttpExecutor>;