Skip to main content

rig_core/test_utils/
http.rs

1//! HTTP client doubles for provider tests.
2
3use std::{
4    future::{self, Future},
5    sync::{Arc, Mutex, MutexGuard},
6};
7
8use bytes::Bytes;
9
10use crate::{
11    http_client::{
12        self, HttpClientExt, LazyBody, MultipartForm, Request, Response, StreamingResponse,
13    },
14    wasm_compat::WasmCompatSend,
15};
16
17/// Request data captured by [`RecordingHttpClient`].
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct CapturedHttpRequest {
20    /// Request URI.
21    pub uri: String,
22    /// Request body bytes.
23    pub body: Bytes,
24}
25
26/// Response scripted for [`RecordingHttpClient`].
27#[derive(Clone, Debug)]
28pub enum MockHttpResponse {
29    /// Return this body with a successful HTTP status.
30    Success(Bytes),
31    /// Return a status-code error with the given body text.
32    Error(http::StatusCode, String),
33}
34
35impl MockHttpResponse {
36    /// Create a successful response from bytes.
37    pub fn success(body: impl Into<Bytes>) -> Self {
38        Self::Success(body.into())
39    }
40
41    /// Create an error response with a status code and message.
42    pub fn error(status: http::StatusCode, message: impl Into<String>) -> Self {
43        Self::Error(status, message.into())
44    }
45}
46
47impl Default for MockHttpResponse {
48    fn default() -> Self {
49        Self::Success(Bytes::new())
50    }
51}
52
53/// An [`HttpClientExt`] implementation that records unary requests and returns
54/// a fixed response.
55#[derive(Clone, Debug, Default)]
56pub struct RecordingHttpClient {
57    requests: Arc<Mutex<Vec<CapturedHttpRequest>>>,
58    response: Arc<Mutex<MockHttpResponse>>,
59}
60
61impl RecordingHttpClient {
62    /// Create a client that returns `response_body` for unary requests.
63    pub fn new(response_body: impl Into<Bytes>) -> Self {
64        Self {
65            requests: Arc::new(Mutex::new(Vec::new())),
66            response: Arc::new(Mutex::new(MockHttpResponse::success(response_body))),
67        }
68    }
69
70    /// Create a client that returns an HTTP status error for unary requests.
71    pub fn with_error(status: http::StatusCode, message: impl Into<String>) -> Self {
72        Self {
73            requests: Arc::new(Mutex::new(Vec::new())),
74            response: Arc::new(Mutex::new(MockHttpResponse::error(status, message))),
75        }
76    }
77
78    /// Return the requests captured so far.
79    pub fn requests(&self) -> Vec<CapturedHttpRequest> {
80        self.requests_guard().clone()
81    }
82
83    /// Replace the scripted unary response.
84    pub fn set_response(&self, response: MockHttpResponse) {
85        *self.response_guard() = response;
86    }
87
88    fn requests_guard(&self) -> MutexGuard<'_, Vec<CapturedHttpRequest>> {
89        match self.requests.lock() {
90            Ok(guard) => guard,
91            Err(poisoned) => poisoned.into_inner(),
92        }
93    }
94
95    fn response_guard(&self) -> MutexGuard<'_, MockHttpResponse> {
96        match self.response.lock() {
97            Ok(guard) => guard,
98            Err(poisoned) => poisoned.into_inner(),
99        }
100    }
101}
102
103impl HttpClientExt for RecordingHttpClient {
104    fn send<T, U>(
105        &self,
106        req: Request<T>,
107    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
108    where
109        T: Into<Bytes> + WasmCompatSend,
110        U: From<Bytes> + WasmCompatSend + 'static,
111    {
112        let requests = Arc::clone(&self.requests);
113        let response = self.response_guard().clone();
114        let (parts, body) = req.into_parts();
115        let body = body.into();
116
117        match requests.lock() {
118            Ok(mut guard) => guard.push(CapturedHttpRequest {
119                uri: parts.uri.to_string(),
120                body,
121            }),
122            Err(poisoned) => poisoned.into_inner().push(CapturedHttpRequest {
123                uri: parts.uri.to_string(),
124                body,
125            }),
126        }
127
128        async move {
129            let response_body = match response {
130                MockHttpResponse::Success(response_body) => response_body,
131                MockHttpResponse::Error(status, message) => {
132                    return Err(http_client::Error::InvalidStatusCodeWithMessage(
133                        status, message,
134                    ));
135                }
136            };
137            let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
138            Response::builder()
139                .status(http::StatusCode::OK)
140                .body(body)
141                .map_err(http_client::Error::Protocol)
142        }
143    }
144
145    fn send_multipart<U>(
146        &self,
147        _req: Request<MultipartForm>,
148    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
149    where
150        U: From<Bytes> + WasmCompatSend + 'static,
151    {
152        future::ready(Err(http_client::Error::InvalidStatusCode(
153            http::StatusCode::NOT_IMPLEMENTED,
154        )))
155    }
156
157    fn send_streaming<T>(
158        &self,
159        _req: Request<T>,
160    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
161    where
162        T: Into<Bytes> + WasmCompatSend,
163    {
164        future::ready(Err(http_client::Error::InvalidStatusCode(
165            http::StatusCode::NOT_IMPLEMENTED,
166        )))
167    }
168}
169
170/// A mock HTTP client that returns pre-built SSE bytes from `send_streaming`.
171///
172/// `send` and `send_multipart` always return `NOT_IMPLEMENTED`.
173#[derive(Clone, Debug, Default)]
174pub struct MockStreamingClient {
175    /// Bytes returned as a single streaming response chunk.
176    pub sse_bytes: Bytes,
177}
178
179impl HttpClientExt for MockStreamingClient {
180    fn send<T, U>(
181        &self,
182        _req: Request<T>,
183    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
184    where
185        T: Into<Bytes> + WasmCompatSend,
186        U: From<Bytes> + WasmCompatSend + 'static,
187    {
188        future::ready(Err(http_client::Error::InvalidStatusCode(
189            http::StatusCode::NOT_IMPLEMENTED,
190        )))
191    }
192
193    fn send_multipart<U>(
194        &self,
195        _req: Request<MultipartForm>,
196    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
197    where
198        U: From<Bytes> + WasmCompatSend + 'static,
199    {
200        future::ready(Err(http_client::Error::InvalidStatusCode(
201            http::StatusCode::NOT_IMPLEMENTED,
202        )))
203    }
204
205    fn send_streaming<T>(
206        &self,
207        _req: Request<T>,
208    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
209    where
210        T: Into<Bytes> + WasmCompatSend,
211    {
212        let sse_bytes = self.sse_bytes.clone();
213        async move {
214            let byte_stream =
215                futures::stream::iter(vec![Ok::<Bytes, http_client::Error>(sse_bytes)]);
216            let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);
217
218            Response::builder()
219                .status(http::StatusCode::OK)
220                .header(http::header::CONTENT_TYPE, "text/event-stream")
221                .body(boxed_stream)
222                .map_err(http_client::Error::Protocol)
223        }
224    }
225}
226
227/// An [`HttpClientExt`] implementation that returns one scripted stream of byte
228/// chunks from `send_streaming`.
229#[derive(Debug, Clone, Default)]
230pub struct SequencedStreamingHttpClient {
231    chunks: Arc<Mutex<Option<Vec<http_client::Result<Bytes>>>>>,
232}
233
234impl SequencedStreamingHttpClient {
235    /// Create a streaming client from the chunks it should yield.
236    pub fn new(chunks: Vec<http_client::Result<Bytes>>) -> Self {
237        Self {
238            chunks: Arc::new(Mutex::new(Some(chunks))),
239        }
240    }
241}
242
243impl HttpClientExt for SequencedStreamingHttpClient {
244    fn send<T, U>(
245        &self,
246        _req: Request<T>,
247    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
248    where
249        T: Into<Bytes> + WasmCompatSend,
250        U: From<Bytes> + WasmCompatSend + 'static,
251    {
252        future::ready(Err(http_client::Error::InvalidStatusCode(
253            http::StatusCode::NOT_IMPLEMENTED,
254        )))
255    }
256
257    fn send_multipart<U>(
258        &self,
259        _req: Request<MultipartForm>,
260    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
261    where
262        U: From<Bytes> + WasmCompatSend + 'static,
263    {
264        future::ready(Err(http_client::Error::InvalidStatusCode(
265            http::StatusCode::NOT_IMPLEMENTED,
266        )))
267    }
268
269    fn send_streaming<T>(
270        &self,
271        _req: Request<T>,
272    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
273    where
274        T: Into<Bytes> + WasmCompatSend,
275    {
276        let chunks = match self.chunks.lock() {
277            Ok(mut guard) => guard.take(),
278            Err(poisoned) => poisoned.into_inner().take(),
279        };
280
281        async move {
282            let Some(chunks) = chunks else {
283                return Err(http_client::Error::InvalidStatusCodeWithMessage(
284                    http::StatusCode::INTERNAL_SERVER_ERROR,
285                    "streaming chunks should only be consumed once".to_string(),
286                ));
287            };
288
289            let byte_stream = futures::stream::iter(chunks);
290            let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);
291
292            Response::builder()
293                .status(http::StatusCode::OK)
294                .header(http::header::CONTENT_TYPE, "text/event-stream")
295                .body(boxed_stream)
296                .map_err(http_client::Error::Protocol)
297        }
298    }
299}