Skip to main content

rig_core/test_utils/
http.rs

1//! HTTP client doubles for provider tests.
2
3use std::{
4    future::{self, Future},
5    sync::{Arc, Mutex, MutexGuard},
6};
7
8use bytes::Bytes;
9
10use crate::{
11    http_client::{
12        self, HttpClientExt, LazyBody, MultipartForm, Request, Response, StreamingResponse,
13    },
14    wasm_compat::WasmCompatSend,
15};
16
17/// Request data captured by [`RecordingHttpClient`].
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct CapturedHttpRequest {
20    /// Request URI.
21    pub uri: String,
22    /// Request headers.
23    pub headers: http::HeaderMap,
24    /// Request body bytes.
25    pub body: Bytes,
26}
27
28/// Response scripted for [`RecordingHttpClient`].
29#[derive(Clone, Debug)]
30pub enum MockHttpResponse {
31    /// Return this body with a successful HTTP status.
32    Success(Bytes),
33    /// Return a status-code error with the given body text.
34    Error(http::StatusCode, String),
35}
36
37impl MockHttpResponse {
38    /// Create a successful response from bytes.
39    pub fn success(body: impl Into<Bytes>) -> Self {
40        Self::Success(body.into())
41    }
42
43    /// Create an error response with a status code and message.
44    pub fn error(status: http::StatusCode, message: impl Into<String>) -> Self {
45        Self::Error(status, message.into())
46    }
47}
48
49impl Default for MockHttpResponse {
50    fn default() -> Self {
51        Self::Success(Bytes::new())
52    }
53}
54
55/// An [`HttpClientExt`] implementation that records unary requests and returns
56/// a fixed response.
57#[derive(Clone, Debug, Default)]
58pub struct RecordingHttpClient {
59    requests: Arc<Mutex<Vec<CapturedHttpRequest>>>,
60    response: Arc<Mutex<MockHttpResponse>>,
61}
62
63impl RecordingHttpClient {
64    /// Create a client that returns `response_body` for unary requests.
65    pub fn new(response_body: impl Into<Bytes>) -> Self {
66        Self {
67            requests: Arc::new(Mutex::new(Vec::new())),
68            response: Arc::new(Mutex::new(MockHttpResponse::success(response_body))),
69        }
70    }
71
72    /// Create a client that returns an HTTP status error for unary requests.
73    pub fn with_error(status: http::StatusCode, message: impl Into<String>) -> Self {
74        Self {
75            requests: Arc::new(Mutex::new(Vec::new())),
76            response: Arc::new(Mutex::new(MockHttpResponse::error(status, message))),
77        }
78    }
79
80    /// Return the requests captured so far.
81    pub fn requests(&self) -> Vec<CapturedHttpRequest> {
82        self.requests_guard().clone()
83    }
84
85    /// Replace the scripted unary response.
86    pub fn set_response(&self, response: MockHttpResponse) {
87        *self.response_guard() = response;
88    }
89
90    fn requests_guard(&self) -> MutexGuard<'_, Vec<CapturedHttpRequest>> {
91        match self.requests.lock() {
92            Ok(guard) => guard,
93            Err(poisoned) => poisoned.into_inner(),
94        }
95    }
96
97    fn response_guard(&self) -> MutexGuard<'_, MockHttpResponse> {
98        match self.response.lock() {
99            Ok(guard) => guard,
100            Err(poisoned) => poisoned.into_inner(),
101        }
102    }
103}
104
105impl HttpClientExt for RecordingHttpClient {
106    fn send<T, U>(
107        &self,
108        req: Request<T>,
109    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
110    where
111        T: Into<Bytes> + WasmCompatSend,
112        U: From<Bytes> + WasmCompatSend + 'static,
113    {
114        let requests = Arc::clone(&self.requests);
115        let response = self.response_guard().clone();
116        let (parts, body) = req.into_parts();
117        let uri = parts.uri.to_string();
118        let headers = parts.headers;
119        let body = body.into();
120
121        match requests.lock() {
122            Ok(mut guard) => guard.push(CapturedHttpRequest { uri, headers, body }),
123            Err(poisoned) => poisoned
124                .into_inner()
125                .push(CapturedHttpRequest { uri, headers, body }),
126        }
127
128        async move {
129            let response_body = match response {
130                MockHttpResponse::Success(response_body) => response_body,
131                MockHttpResponse::Error(status, message) => {
132                    return Err(http_client::Error::InvalidStatusCodeWithMessage(
133                        status, message,
134                    ));
135                }
136            };
137            let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
138            Response::builder()
139                .status(http::StatusCode::OK)
140                .body(body)
141                .map_err(http_client::Error::Protocol)
142        }
143    }
144
145    fn send_multipart<U>(
146        &self,
147        _req: Request<MultipartForm>,
148    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
149    where
150        U: From<Bytes> + WasmCompatSend + 'static,
151    {
152        future::ready(Err(http_client::Error::InvalidStatusCode(
153            http::StatusCode::NOT_IMPLEMENTED,
154        )))
155    }
156
157    fn send_streaming<T>(
158        &self,
159        _req: Request<T>,
160    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
161    where
162        T: Into<Bytes> + WasmCompatSend,
163    {
164        future::ready(Err(http_client::Error::InvalidStatusCode(
165            http::StatusCode::NOT_IMPLEMENTED,
166        )))
167    }
168}
169
170/// A mock HTTP client that returns pre-built SSE bytes from `send_streaming`.
171///
172/// `send` and `send_multipart` always return `NOT_IMPLEMENTED`.
173#[derive(Clone, Debug, Default)]
174pub struct MockStreamingClient {
175    /// Bytes returned as a single streaming response chunk.
176    pub sse_bytes: Bytes,
177}
178
179impl HttpClientExt for MockStreamingClient {
180    fn send<T, U>(
181        &self,
182        _req: Request<T>,
183    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
184    where
185        T: Into<Bytes> + WasmCompatSend,
186        U: From<Bytes> + WasmCompatSend + 'static,
187    {
188        future::ready(Err(http_client::Error::InvalidStatusCode(
189            http::StatusCode::NOT_IMPLEMENTED,
190        )))
191    }
192
193    fn send_multipart<U>(
194        &self,
195        _req: Request<MultipartForm>,
196    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
197    where
198        U: From<Bytes> + WasmCompatSend + 'static,
199    {
200        future::ready(Err(http_client::Error::InvalidStatusCode(
201            http::StatusCode::NOT_IMPLEMENTED,
202        )))
203    }
204
205    fn send_streaming<T>(
206        &self,
207        _req: Request<T>,
208    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
209    where
210        T: Into<Bytes> + WasmCompatSend,
211    {
212        let sse_bytes = self.sse_bytes.clone();
213        async move {
214            let byte_stream =
215                futures::stream::iter(vec![Ok::<Bytes, http_client::Error>(sse_bytes)]);
216            let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);
217
218            Response::builder()
219                .status(http::StatusCode::OK)
220                .header(http::header::CONTENT_TYPE, "text/event-stream")
221                .body(boxed_stream)
222                .map_err(http_client::Error::Protocol)
223        }
224    }
225}
226
227/// An [`HttpClientExt`] implementation that returns one scripted stream of byte
228/// chunks from `send_streaming`.
229#[derive(Debug, Clone, Default)]
230pub struct SequencedStreamingHttpClient {
231    chunks: Arc<Mutex<Option<Vec<http_client::Result<Bytes>>>>>,
232}
233
234impl SequencedStreamingHttpClient {
235    /// Create a streaming client from the chunks it should yield.
236    pub fn new(chunks: Vec<http_client::Result<Bytes>>) -> Self {
237        Self {
238            chunks: Arc::new(Mutex::new(Some(chunks))),
239        }
240    }
241}
242
243impl HttpClientExt for SequencedStreamingHttpClient {
244    fn send<T, U>(
245        &self,
246        _req: Request<T>,
247    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
248    where
249        T: Into<Bytes> + WasmCompatSend,
250        U: From<Bytes> + WasmCompatSend + 'static,
251    {
252        future::ready(Err(http_client::Error::InvalidStatusCode(
253            http::StatusCode::NOT_IMPLEMENTED,
254        )))
255    }
256
257    fn send_multipart<U>(
258        &self,
259        _req: Request<MultipartForm>,
260    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
261    where
262        U: From<Bytes> + WasmCompatSend + 'static,
263    {
264        future::ready(Err(http_client::Error::InvalidStatusCode(
265            http::StatusCode::NOT_IMPLEMENTED,
266        )))
267    }
268
269    fn send_streaming<T>(
270        &self,
271        _req: Request<T>,
272    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
273    where
274        T: Into<Bytes> + WasmCompatSend,
275    {
276        let chunks = match self.chunks.lock() {
277            Ok(mut guard) => guard.take(),
278            Err(poisoned) => poisoned.into_inner().take(),
279        };
280
281        async move {
282            let Some(chunks) = chunks else {
283                return Err(http_client::Error::InvalidStatusCodeWithMessage(
284                    http::StatusCode::INTERNAL_SERVER_ERROR,
285                    "streaming chunks should only be consumed once".to_string(),
286                ));
287            };
288
289            let byte_stream = futures::stream::iter(chunks);
290            let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);
291
292            Response::builder()
293                .status(http::StatusCode::OK)
294                .header(http::header::CONTENT_TYPE, "text/event-stream")
295                .body(boxed_stream)
296                .map_err(http_client::Error::Protocol)
297        }
298    }
299}