use std::{
future::{self, Future},
sync::{Arc, Mutex, MutexGuard},
};
use bytes::Bytes;
use crate::{
http_client::{
self, HttpClientExt, LazyBody, MultipartForm, Request, Response, StreamingResponse,
},
wasm_compat::WasmCompatSend,
};
/// A snapshot of one request observed by [`RecordingHttpClient`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CapturedHttpRequest {
    /// The request URI, rendered through its `Display` implementation.
    pub uri: String,
    /// The request body after conversion into [`Bytes`].
    pub body: Bytes,
}
/// The canned outcome a [`RecordingHttpClient`] produces for `send`.
#[derive(Debug, Clone)]
pub enum MockHttpResponse {
    /// Resolve with a `200 OK` response whose body is built from these bytes.
    Success(Bytes),
    /// Fail with `http_client::Error::InvalidStatusCodeWithMessage(status, message)`.
    Error(http::StatusCode, String),
}
impl MockHttpResponse {
    /// Builds a [`MockHttpResponse::Success`] from anything convertible into [`Bytes`].
    pub fn success(body: impl Into<Bytes>) -> Self {
        let bytes: Bytes = body.into();
        Self::Success(bytes)
    }

    /// Builds a [`MockHttpResponse::Error`] carrying `status` and `message`.
    pub fn error(status: http::StatusCode, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::Error(status, message)
    }
}
/// The default canned response is a success with an empty body.
impl Default for MockHttpResponse {
    fn default() -> Self {
        Self::success(Bytes::new())
    }
}
/// A mock HTTP client that captures every `send` request and replies with a
/// configurable [`MockHttpResponse`].
///
/// Both the request log and the response live behind `Arc<Mutex<..>>`, so clones
/// of a client share the same log and the same configured response.
#[derive(Debug, Default, Clone)]
pub struct RecordingHttpClient {
    /// Requests captured so far, in the order they were pushed.
    requests: Arc<Mutex<Vec<CapturedHttpRequest>>>,
    /// The response handed out by subsequent `send` calls.
    response: Arc<Mutex<MockHttpResponse>>,
}
impl RecordingHttpClient {
pub fn new(response_body: impl Into<Bytes>) -> Self {
Self {
requests: Arc::new(Mutex::new(Vec::new())),
response: Arc::new(Mutex::new(MockHttpResponse::success(response_body))),
}
}
pub fn with_error(status: http::StatusCode, message: impl Into<String>) -> Self {
Self {
requests: Arc::new(Mutex::new(Vec::new())),
response: Arc::new(Mutex::new(MockHttpResponse::error(status, message))),
}
}
pub fn requests(&self) -> Vec<CapturedHttpRequest> {
self.requests_guard().clone()
}
pub fn set_response(&self, response: MockHttpResponse) {
*self.response_guard() = response;
}
fn requests_guard(&self) -> MutexGuard<'_, Vec<CapturedHttpRequest>> {
match self.requests.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
fn response_guard(&self) -> MutexGuard<'_, MockHttpResponse> {
match self.response.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
}
/// Records every `send` request and answers it with the configured [`MockHttpResponse`].
///
/// Multipart and streaming requests are not recorded. Both fail with
/// `InvalidStatusCode(NOT_IMPLEMENTED)`.
impl HttpClientExt for RecordingHttpClient {
    /// Captures the request URI and body, then resolves with the canned response.
    ///
    /// The request is recorded when `send` is called, not when the returned
    /// future is first polled.
    fn send<T, U>(
        &self,
        req: Request<T>,
    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
    where
        T: Into<Bytes> + WasmCompatSend,
        U: From<Bytes> + WasmCompatSend + 'static,
    {
        // Snapshot the response now: the returned future must own its data ('static),
        // and later `set_response` calls must not affect requests already issued.
        let response = self.response_guard().clone();
        let (parts, body) = req.into_parts();
        // Use the shared poison-tolerant helper rather than duplicating the push
        // in both arms of a hand-written `lock()` match.
        self.requests_guard().push(CapturedHttpRequest {
            uri: parts.uri.to_string(),
            body: body.into(),
        });

        async move {
            let response_body = match response {
                MockHttpResponse::Success(response_body) => response_body,
                MockHttpResponse::Error(status, message) => {
                    return Err(http_client::Error::InvalidStatusCodeWithMessage(
                        status, message,
                    ));
                }
            };
            // The lazy body resolves immediately with `U` built from the canned bytes.
            let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
            Response::builder()
                .status(http::StatusCode::OK)
                .body(body)
                .map_err(http_client::Error::Protocol)
        }
    }

    /// Not supported by this mock. The request is not recorded.
    fn send_multipart<U>(
        &self,
        _req: Request<MultipartForm>,
    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
    where
        U: From<Bytes> + WasmCompatSend + 'static,
    {
        future::ready(Err(http_client::Error::InvalidStatusCode(
            http::StatusCode::NOT_IMPLEMENTED,
        )))
    }

    /// Not supported by this mock. The request is not recorded.
    fn send_streaming<T>(
        &self,
        _req: Request<T>,
    ) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
    where
        T: Into<Bytes> + WasmCompatSend,
    {
        future::ready(Err(http_client::Error::InvalidStatusCode(
            http::StatusCode::NOT_IMPLEMENTED,
        )))
    }
}
/// A streaming mock that answers every `send_streaming` call with the same SSE payload.
#[derive(Debug, Default, Clone)]
pub struct MockStreamingClient {
    /// Raw bytes delivered as the single chunk of a `text/event-stream` response.
    pub sse_bytes: Bytes,
}
impl HttpClientExt for MockStreamingClient {
fn send<T, U>(
&self,
_req: Request<T>,
) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
where
T: Into<Bytes> + WasmCompatSend,
U: From<Bytes> + WasmCompatSend + 'static,
{
future::ready(Err(http_client::Error::InvalidStatusCode(
http::StatusCode::NOT_IMPLEMENTED,
)))
}
fn send_multipart<U>(
&self,
_req: Request<MultipartForm>,
) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
where
U: From<Bytes> + WasmCompatSend + 'static,
{
future::ready(Err(http_client::Error::InvalidStatusCode(
http::StatusCode::NOT_IMPLEMENTED,
)))
}
fn send_streaming<T>(
&self,
_req: Request<T>,
) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
where
T: Into<Bytes> + WasmCompatSend,
{
let sse_bytes = self.sse_bytes.clone();
async move {
let byte_stream =
futures::stream::iter(vec![Ok::<Bytes, http_client::Error>(sse_bytes)]);
let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);
Response::builder()
.status(http::StatusCode::OK)
.header(http::header::CONTENT_TYPE, "text/event-stream")
.body(boxed_stream)
.map_err(http_client::Error::Protocol)
}
}
}
/// A streaming mock that replays a fixed sequence of chunk results exactly once.
///
/// The chunks live behind `Arc<Mutex<Option<..>>>`, so clones share them. Whichever
/// handle calls `send_streaming` first takes them. A `Default` client starts with
/// `None`, so its `send_streaming` fails as though the chunks were already consumed.
#[derive(Clone, Default, Debug)]
pub struct SequencedStreamingHttpClient {
    /// Pending chunk results. `None` once they have been handed out.
    chunks: Arc<Mutex<Option<Vec<http_client::Result<Bytes>>>>>,
}
impl SequencedStreamingHttpClient {
    /// Creates a client whose first `send_streaming` call streams `chunks` in order.
    /// Later calls fail.
    pub fn new(chunks: Vec<http_client::Result<Bytes>>) -> Self {
        let pending = Mutex::new(Some(chunks));
        Self {
            chunks: Arc::new(pending),
        }
    }
}
impl HttpClientExt for SequencedStreamingHttpClient {
fn send<T, U>(
&self,
_req: Request<T>,
) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
where
T: Into<Bytes> + WasmCompatSend,
U: From<Bytes> + WasmCompatSend + 'static,
{
future::ready(Err(http_client::Error::InvalidStatusCode(
http::StatusCode::NOT_IMPLEMENTED,
)))
}
fn send_multipart<U>(
&self,
_req: Request<MultipartForm>,
) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
where
U: From<Bytes> + WasmCompatSend + 'static,
{
future::ready(Err(http_client::Error::InvalidStatusCode(
http::StatusCode::NOT_IMPLEMENTED,
)))
}
fn send_streaming<T>(
&self,
_req: Request<T>,
) -> impl Future<Output = http_client::Result<StreamingResponse>> + WasmCompatSend
where
T: Into<Bytes> + WasmCompatSend,
{
let chunks = match self.chunks.lock() {
Ok(mut guard) => guard.take(),
Err(poisoned) => poisoned.into_inner().take(),
};
async move {
let Some(chunks) = chunks else {
return Err(http_client::Error::InvalidStatusCodeWithMessage(
http::StatusCode::INTERNAL_SERVER_ERROR,
"streaming chunks should only be consumed once".to_string(),
));
};
let byte_stream = futures::stream::iter(chunks);
let boxed_stream: http_client::sse::BoxedStream = Box::pin(byte_stream);
Response::builder()
.status(http::StatusCode::OK)
.header(http::header::CONTENT_TYPE, "text/event-stream")
.body(boxed_stream)
.map_err(http_client::Error::Protocol)
}
}
}