use std::{error::Error, fmt};
use http::status::StatusCode;
use motore::{layer::Layer, service::Service};
use url::Url;
use volo::context::Context;
use crate::{
error::{ClientError, client::request_error},
request::RequestPartsExt,
response::Response,
};
/// Layer that turns HTTP responses carrying an error status into a client error.
///
/// Client errors (4xx) and server errors (5xx) are opted into independently. The
/// [`Default`] value enables neither class, so a default-constructed layer lets every
/// response through unchanged.
#[derive(Clone, Debug, Default)]
pub struct FailOnStatus {
    /// Reject responses whose status is a client error (4xx).
    client_error: bool,
    /// Reject responses whose status is a server error (5xx).
    server_error: bool,
    /// Attach the request URL to the produced error.
    detailed: bool,
}
impl FailOnStatus {
    /// Creates a layer that fails on both client errors (4xx) and server errors (5xx).
    #[must_use]
    pub fn all() -> Self {
        Self {
            client_error: true,
            server_error: true,
            ..Self::default()
        }
    }

    /// Creates a layer that fails only on client errors (4xx).
    #[must_use]
    pub fn client_error() -> Self {
        Self {
            client_error: true,
            ..Self::default()
        }
    }

    /// Creates a layer that fails only on server errors (5xx).
    #[must_use]
    pub fn server_error() -> Self {
        Self {
            server_error: true,
            ..Self::default()
        }
    }

    /// Requests that the produced error records the request URL, when one is available.
    ///
    /// The layer is taken by value and returned with the flag set. Discarding the return
    /// value loses the configuration, which is why the result is `#[must_use]`.
    #[must_use]
    pub fn detailed(self) -> Self {
        Self {
            detailed: true,
            ..self
        }
    }
}
impl<S> Layer<S> for FailOnStatus {
    type Service = FailOnStatusService<S>;

    /// Wraps `inner` so that each of its responses is checked against this configuration.
    fn layer(self, inner: S) -> Self::Service {
        let fail_on = self;
        FailOnStatusService { fail_on, inner }
    }
}
/// Service produced by the [`FailOnStatus`] layer.
///
/// Every request is delegated to `inner`. Responses whose status falls into a class
/// enabled in `fail_on` are converted into errors.
pub struct FailOnStatusService<S> {
    /// The wrapped service that actually performs the request.
    inner: S,
    /// Which status classes are converted into errors.
    fail_on: FailOnStatus,
}
impl<Cx, Req, S, B> Service<Cx, Req> for FailOnStatusService<S>
where
    Cx: Context + Send,
    Req: RequestPartsExt + Send,
    S: Service<Cx, Req, Response = Response<B>, Error = ClientError> + Send + Sync,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Forwards `req` to the inner service and rejects the response if its status class
    /// was enabled on the layer.
    ///
    /// # Errors
    ///
    /// Any error from the inner service is propagated unchanged. Otherwise, a 4xx response
    /// with `client_error` enabled or a 5xx response with `server_error` enabled becomes an
    /// error built by `request_error` from a [`StatusCodeError`]. That error is tagged with
    /// the callee endpoint taken from `cx`.
    async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
        // `req` is moved into the inner service below, so the URL must be captured first.
        // It is only extracted when detailed errors were requested.
        let url = self.fail_on.detailed.then(|| req.url()).flatten();

        let resp = self.inner.call(cx, req).await?;
        let status = resp.status();

        // 4xx and 5xx are disjoint ranges, so at most one flag is relevant per status.
        let rejected = if status.is_client_error() {
            self.fail_on.client_error
        } else if status.is_server_error() {
            self.fail_on.server_error
        } else {
            false
        };

        if !rejected {
            return Ok(resp);
        }

        let endpoint = cx.rpc_info().callee();
        Err(request_error(StatusCodeError { status, url }).with_endpoint(endpoint))
    }
}
/// Error produced by [`FailOnStatusService`] when a response carries a rejected status.
pub struct StatusCodeError {
    /// Status code of the rejected response.
    status: StatusCode,
    /// Request URL. It is only recorded when the layer was configured with
    /// [`FailOnStatus::detailed`], and may still be absent if no URL could be obtained.
    url: Option<Url>,
}
impl StatusCodeError {
    /// Returns the status code of the rejected response.
    pub fn status(&self) -> StatusCode {
        let Self { status, .. } = self;
        *status
    }

    /// Returns the request URL, if one was recorded.
    ///
    /// The URL is only captured when the layer was built with [`FailOnStatus::detailed`].
    pub fn url(&self) -> Option<&Url> {
        Option::as_ref(&self.url)
    }
}
impl fmt::Debug for StatusCodeError {
    /// Debug representation listing every field.
    ///
    /// `url` is included so that `{:?}` output (e.g. from `unwrap()`) carries at least as
    /// much information as the `Display` output, which already prints the URL when present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StatusCodeError")
            .field("status", &self.status)
            .field("url", &self.url)
            .finish()
    }
}
impl fmt::Display for StatusCodeError {
    /// Describes the rejected status, followed by the request URL when one was recorded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "client received an error status `{}`", self.status)?;
        match &self.url {
            Some(url) => write!(f, " for `{url}`"),
            None => Ok(()),
        }
    }
}

// The status code is itself the root cause, so there is no underlying `source` error.
impl Error for StatusCodeError {}
#[cfg(test)]
mod fail_on_status_tests {
    use http::status::StatusCode;
    use motore::service::Service;

    use super::FailOnStatus;
    use crate::{
        ClientBuilder, body::Body, client::test_helpers::MockTransport, context::ClientContext,
        error::ClientError, request::Request, response::Response,
    };

    /// Mock transport that replies with the status code spelled out in the request path.
    ///
    /// For example, a request to `/404` yields an empty-bodied response with status 404.
    struct ReturnStatus;

    impl Service<ClientContext, Request> for ReturnStatus {
        type Response = Response;
        type Error = ClientError;

        fn call(
            &self,
            _: &mut ClientContext,
            req: Request,
        ) -> impl std::future::Future<Output = Result<Self::Response, Self::Error>> + Send {
            let path = req.uri().path();
            assert_eq!(&path[..1], "/");
            let code = path[1..].parse::<u16>().expect("invalid uri");
            let status = StatusCode::from_u16(code).expect("invalid status code");

            let mut response = Response::new(Body::empty());
            *response.status_mut() = status;
            async { Ok(response) }
        }
    }

    #[tokio::test]
    async fn fail_on_status_test() {
        // (layer under test, should a 400 fail?, should a 500 fail?)
        let cases = [
            (FailOnStatus::all(), true, true),
            (FailOnStatus::client_error(), true, false),
            (FailOnStatus::server_error(), false, true),
        ];

        for (fail_on, fails_on_400, fails_on_500) in cases {
            let client = ClientBuilder::new()
                .layer_outer_front(fail_on)
                .mock(MockTransport::service(ReturnStatus))
                .unwrap();
            assert_eq!(client.get("/400").send().await.is_err(), fails_on_400);
            assert_eq!(client.get("/500").send().await.is_err(), fails_on_500);
        }
    }
}