use std::{sync::Arc, time::Duration};
use http::status::StatusCode;
use motore::{
layer::Layer,
service::{BoxService, Service},
};
use volo::client::MkClient;
use super::{Client, ClientBuilder, ClientInner};
use crate::{
body::{Body, BodyConversion},
context::client::ClientContext,
error::client::{ClientError, Result, other_error},
request::{Request, RequestPartsExt},
response::Response,
};
/// Alias of [`MockTransport`], used as the innermost service type by
/// [`ClientBuilder::mock`].
pub type ClientMockService = MockTransport;

/// A fake transport that answers client requests without any real I/O.
///
/// It either replies with a default [`Response`] carrying a fixed status code,
/// or forwards every request to a boxed service.
pub enum MockTransport {
    /// Reply with `Response::default()` whose status is replaced by this code.
    Status(StatusCode),
    /// Delegate every request to the wrapped service.
    Service(BoxService<ClientContext, Request, Response, ClientError>),
}

impl Default for MockTransport {
    /// Returns a transport that answers every request with `200 OK`.
    fn default() -> Self {
        MockTransport::Status(StatusCode::OK)
    }
}
impl MockTransport {
pub fn new() -> Self {
Self::default()
}
pub fn status_code(status: StatusCode) -> Self {
Self::Status(status)
}
pub fn service<S>(service: S) -> Self
where
S: Service<ClientContext, Request, Response = Response, Error = ClientError>
+ Send
+ Sync
+ 'static,
{
Self::Service(BoxService::new(service))
}
#[cfg(feature = "server")]
pub fn server_service<S>(service: S) -> Self
where
S: Service<crate::context::ServerContext, Request, Response = Response>
+ Send
+ Sync
+ 'static,
S::Error: Into<crate::error::BoxError>,
{
Self::Service(BoxService::new(
crate::utils::test_helpers::ConvertService::new(service),
))
}
}
impl Service<ClientContext, Request> for MockTransport {
    type Response = Response;
    type Error = ClientError;

    /// Answers `req` according to the configured mock behaviour.
    ///
    /// In `Service` mode the request is forwarded unchanged; in `Status` mode the
    /// request is ignored and a default response with the fixed status is returned.
    async fn call(
        &self,
        cx: &mut ClientContext,
        req: Request,
    ) -> Result<Self::Response, Self::Error> {
        let status = match self {
            Self::Service(inner) => return inner.call(cx, req).await,
            Self::Status(code) => *code,
        };
        let mut response = Response::default();
        *response.status_mut() = status;
        Ok(response)
    }
}
impl<IL, OL, C, LB> ClientBuilder<IL, OL, C, LB> {
    /// Builds a client whose transport is `transport` instead of a real connection.
    ///
    /// The builder's inner layer is applied to the mock first, then the outer
    /// layer; the builder's timeout and headers are carried over to the client.
    ///
    /// # Errors
    ///
    /// Returns the error stored in the builder's `status`, if any.
    pub fn mock<ReqBody, RespBody>(self, transport: MockTransport) -> Result<C::Target>
    where
        IL: Layer<ClientMockService>,
        IL::Service: Send + Sync + 'static,
        OL: Layer<IL::Service>,
        OL::Service: Service<
                ClientContext,
                Request<ReqBody>,
                Response = Response<RespBody>,
                Error = ClientError,
            > + Send
            + Sync
            + 'static,
        C: MkClient<Client<ReqBody, RespBody>>,
        ReqBody: Send + 'static,
        RespBody: Send,
    {
        self.status?;
        let layered = self
            .outer_layer
            .layer(self.inner_layer.layer(transport));
        let inner = ClientInner {
            service: BoxService::new(layered),
            timeout: self.timeout,
            headers: self.headers,
        };
        let client = Client {
            inner: Arc::new(inner),
        };
        Ok(self.mk_client.mk_client(client))
    }
}
/// Layer that prints every request and response passing through it to stdout.
///
/// The variant selects how message bodies are printed.
#[derive(Default, Debug)]
pub enum DebugLayer {
    /// Print bodies as UTF-8 text (the default); a non-UTF-8 body yields an error.
    #[default]
    DumpString,
    /// Print bodies as a `Debug`-formatted byte slice.
    DumpBytes,
}
/// Prints the request line and headers of `parts` to stdout.
///
/// If the parts carry a URL (via `RequestPartsExt::url`), it is printed first as a
/// banner. Header values that `HeaderValue::to_str` accepts are printed as plain
/// text; any other value is printed with its `Debug` representation, so every
/// header in `parts` shows up in the dump.
fn dump_request_parts(parts: &http::request::Parts) {
    if let Some(url) = parts.url() {
        println!(" == {url} ==");
    }
    println!("{:?} {:?} {:?}", parts.method, parts.uri, parts.version);
    for (k, v) in parts.headers.iter() {
        match v.to_str() {
            Ok(text) => println!("{k}: {text}"),
            // Not representable as `&str`: show the raw value rather than hiding
            // the header from the debug output.
            Err(_) => println!("{k}: {v:?}"),
        }
    }
}
/// Prints the status line and headers of `parts` to stdout.
///
/// Header values are always printed with their `Debug` representation.
fn dump_response_parts(parts: &http::response::Parts) {
    let (version, status) = (parts.version, parts.status);
    println!("{version:?} {status}");
    parts
        .headers
        .iter()
        .for_each(|(name, value)| println!("{name}: {value:?}"));
}
impl DebugLayer {
    /// Buffers the request body, prints the whole request to stdout, and returns
    /// an equivalent request carrying the buffered body.
    ///
    /// # Errors
    ///
    /// Fails if the body cannot be collected, or, in [`DebugLayer::DumpString`]
    /// mode, if the body is not valid UTF-8.
    async fn dump_request(&self, req: Request) -> Result<Request> {
        let (parts, body) = req.into_parts();
        let bytes = body.into_bytes().await?;
        println!(" ==== DebugLayer::dump_request ====");
        dump_request_parts(&parts);
        println!();
        self.dump_body(bytes.as_ref())?;
        println!(" ==== DebugLayer::dump_request ====");
        Ok(Request::from_parts(parts, Body::from(bytes)))
    }

    /// Buffers the response body, prints the whole response to stdout, and
    /// returns an equivalent response carrying the buffered body.
    ///
    /// # Errors
    ///
    /// Fails if the body cannot be collected, or, in [`DebugLayer::DumpString`]
    /// mode, if the body is not valid UTF-8.
    async fn dump_response(&self, resp: Response) -> Result<Response> {
        let (parts, body) = resp.into_parts();
        let bytes = body.into_bytes().await?;
        println!(" ==== DebugLayer::dump_response ====");
        dump_response_parts(&parts);
        println!();
        self.dump_body(bytes.as_ref())?;
        println!(" ==== DebugLayer::dump_response ====");
        Ok(Response::from_parts(parts, Body::from(bytes)))
    }

    /// Prints a buffered body according to the selected dump mode.
    ///
    /// Shared by `dump_request` and `dump_response`, which previously each
    /// carried an identical copy of this logic.
    fn dump_body(&self, bytes: &[u8]) -> Result<()> {
        match self {
            DebugLayer::DumpString => {
                let s = std::str::from_utf8(bytes).map_err(other_error)?;
                println!("{s}");
            }
            DebugLayer::DumpBytes => println!("{bytes:?}"),
        }
        Ok(())
    }
}
impl<S> Layer<S> for DebugLayer {
    type Service = DebugService<S>;

    /// Wraps `inner` so that its traffic is dumped using this layer's mode.
    fn layer(self, inner: S) -> Self::Service {
        let config = self;
        DebugService { config, inner }
    }
}
/// Service produced by [`DebugLayer`]: dumps traffic around an inner service.
pub struct DebugService<S> {
    /// The wrapped service that actually handles each request.
    inner: S,
    /// Dump mode taken from the layer that built this service.
    config: DebugLayer,
}
impl<S> Service<ClientContext, Request> for DebugService<S>
where
S: Service<ClientContext, Request, Response = Response, Error = ClientError> + Send + Sync,
{
type Response = Response;
type Error = ClientError;
async fn call(
&self,
cx: &mut ClientContext,
req: Request,
) -> Result<Self::Response, Self::Error> {
let req = self.config.dump_request(req).await?;
let resp = self.inner.call(cx, req).await?;
self.config.dump_response(resp).await
}
}
/// Layer that re-sends a request while the response status is an error.
///
/// Which status classes trigger a retry (4xx and/or 5xx), how many retries are
/// attempted, and how long to wait before each retry are all configurable.
#[derive(Debug, Clone)]
pub struct RetryOnStatus {
    /// Retry when the response status is a client error (4xx).
    client_error: bool,
    /// Retry when the response status is a server error (5xx).
    server_error: bool,
    /// Maximum number of retries after the first attempt.
    max_retry: usize,
    /// Delay applied before each retry.
    sleep_time: Duration,
}

impl RetryOnStatus {
    const DEFAULT_MAX_RETRY: usize = 5;
    const DEFAULT_SLEEP_TIME: Duration = Duration::from_secs(1);

    /// Shared constructor: selects the retried status classes and applies the
    /// default retry count and delay.
    fn for_classes(client_error: bool, server_error: bool) -> Self {
        Self {
            client_error,
            server_error,
            max_retry: Self::DEFAULT_MAX_RETRY,
            sleep_time: Self::DEFAULT_SLEEP_TIME,
        }
    }

    /// Retries on both client (4xx) and server (5xx) errors.
    pub fn all() -> Self {
        Self::for_classes(true, true)
    }

    /// Retries on client errors (4xx) only.
    pub fn client_error() -> Self {
        Self::for_classes(true, false)
    }

    /// Retries on server errors (5xx) only.
    pub fn server_error() -> Self {
        Self::for_classes(false, true)
    }

    /// Sets the maximum number of retries after the first attempt.
    #[must_use]
    pub fn with_max_retry(mut self, max_retry: usize) -> Self {
        self.max_retry = max_retry;
        self
    }

    /// Sets the delay applied before each retry.
    #[must_use]
    pub fn with_sleep_time(mut self, sleep_time: Duration) -> Self {
        self.sleep_time = sleep_time;
        self
    }
}
impl<S> Layer<S> for RetryOnStatus {
    type Service = RetryOnStatusService<S>;

    /// Wraps `inner` with retry-on-status behaviour using this configuration.
    fn layer(self, inner: S) -> Self::Service {
        let config = self;
        RetryOnStatusService { config, inner }
    }
}
/// Service produced by [`RetryOnStatus`]: re-sends requests to `inner` while the
/// response status matches the configured error classes.
pub struct RetryOnStatusService<S> {
    /// The wrapped service each attempt is sent to.
    inner: S,
    /// Retry policy taken from the layer that built this service.
    config: RetryOnStatus,
}
impl<S> Service<ClientContext, Request> for RetryOnStatusService<S>
where
    S: Service<ClientContext, Request, Response = Response, Error = ClientError> + Send + Sync,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Sends `req`, re-sending it while the response status is a retryable error.
    ///
    /// The body is buffered once up front so every attempt gets an identical copy.
    /// Once `max_retry` retries have been made, the latest response is returned
    /// whatever its status. An error from the inner service is returned
    /// immediately, without retrying.
    async fn call(
        &self,
        cx: &mut ClientContext,
        req: Request,
    ) -> Result<Self::Response, Self::Error> {
        let (parts, body) = req.into_parts();
        let buffered = body.into_bytes().await?;
        let policy = &self.config;
        let mut retry: usize = 0;
        loop {
            let attempt = Request::from_parts(parts.clone(), Body::from(buffered.clone()));
            let resp = self.inner.call(cx, attempt).await?;
            let status = resp.status();
            let retryable = (policy.client_error && status.is_client_error())
                || (policy.server_error && status.is_server_error());
            if retry >= policy.max_retry || !retryable {
                return Ok(resp);
            }
            retry += 1;
            tokio::time::sleep(policy.sleep_time).await;
            println!("retry on \"{}\" for {retry} time(s)", parts.uri);
        }
    }
}
// Unit tests for `MockTransport`, driven through the public `ClientBuilder` API.
// Gated on `cfg(test)` so the test-only imports and the `#[tokio::test]` macro are
// not compiled into regular builds.
#[cfg(test)]
mod mock_transport_tests {
    use http::status::StatusCode;

    use super::MockTransport;
    use crate::{ClientBuilder, body::BodyConversion};

    /// The default transport answers `200 OK` with no headers and an empty body.
    #[tokio::test]
    async fn empty_response_test() {
        let client = ClientBuilder::new().mock(MockTransport::default()).unwrap();
        let resp = client.get("/").send().await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().is_empty());
        assert!(resp.into_body().into_vec().await.unwrap().is_empty());
    }
}