// rama_http/layer/classify/status_in_range_is_error.rs

use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier};
use rama_http_types::StatusCode;
use std::{fmt, ops::RangeInclusive};

/// A response classifier that reports a failure whenever the response's status code falls
/// inside a configured, inclusive range of status codes; every other response is a success.
#[derive(Clone, Debug)]
pub struct StatusInRangeAsFailures {
    /// Inclusive range of status codes that count as failures.
    range: RangeInclusive<u16>,
}

12impl StatusInRangeAsFailures {
13    /// Creates a new `StatusInRangeAsFailures`.
14    ///
15    /// # Panics
16    ///
17    /// Panics if the start or end of `range` aren't valid status codes as determined by
18    /// [`StatusCode::from_u16`].
19    ///
20    /// [`StatusCode::from_u16`]: https://docs.rs/http/latest/http/status/struct.StatusCode.html#method.from_u16
21    pub fn new(range: RangeInclusive<u16>) -> Self {
22        assert!(
23            StatusCode::from_u16(*range.start()).is_ok(),
24            "range start isn't a valid status code"
25        );
26        assert!(
27            StatusCode::from_u16(*range.end()).is_ok(),
28            "range end isn't a valid status code"
29        );
30
31        Self { range }
32    }
33
34    /// Creates a new `StatusInRangeAsFailures` that classifies client and server responses as
35    /// failures.
36    ///
37    /// This is a convenience for `StatusInRangeAsFailures::new(400..=599)`.
38    pub fn new_for_client_and_server_errors() -> Self {
39        Self::new(400..=599)
40    }
41
42    /// Convert this `StatusInRangeAsFailures` into a [`MakeClassifier`].
43    ///
44    /// [`MakeClassifier`]: super::MakeClassifier
45    pub fn into_make_classifier(self) -> SharedClassifier<Self> {
46        SharedClassifier::new(self)
47    }
48}
49
50impl ClassifyResponse for StatusInRangeAsFailures {
51    type FailureClass = StatusInRangeFailureClass;
52    type ClassifyEos = NeverClassifyEos<Self::FailureClass>;
53
54    fn classify_response<B>(
55        self,
56        res: &rama_http_types::Response<B>,
57    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
58        if self.range.contains(&res.status().as_u16()) {
59            let class = StatusInRangeFailureClass::StatusCode(res.status());
60            ClassifiedResponse::Ready(Err(class))
61        } else {
62            ClassifiedResponse::Ready(Ok(()))
63        }
64    }
65
66    fn classify_error<E>(self, error: &E) -> Self::FailureClass
67    where
68        E: std::fmt::Display,
69    {
70        StatusInRangeFailureClass::Error(error.to_string())
71    }
72}
73
74/// The failure class for [`StatusInRangeAsFailures`].
75#[derive(Debug)]
76pub enum StatusInRangeFailureClass {
77    /// A response was classified as a failure with the corresponding status.
78    StatusCode(StatusCode),
79    /// A response was classified as an error with the corresponding error description.
80    Error(String),
81}
82
83impl fmt::Display for StatusInRangeFailureClass {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            Self::StatusCode(code) => write!(f, "Status code: {}", code),
87            Self::Error(error) => write!(f, "Error: {}", error),
88        }
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use rama_http_types::Response;

    #[test]
    fn basic() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(200)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(400)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::BAD_REQUEST
            ))),
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(500)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::INTERNAL_SERVER_ERROR
            ))),
        ));
    }

    #[test]
    fn range_bounds_are_inclusive() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        // Just below the start of the range.
        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(399)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        // The end of the range is itself a failure.
        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(599)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(code)))
                if code.as_u16() == 599,
        ));

        // Just past the end of the range.
        assert!(matches!(
            classifier.classify_response(&response_with_status(600)),
            ClassifiedResponse::Ready(Ok(())),
        ));
    }

    #[test]
    fn client_and_server_errors_preset() {
        let classifier = StatusInRangeAsFailures::new_for_client_and_server_errors();

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(302)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(404)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::NOT_FOUND
            ))),
        ));
    }

    #[test]
    fn errors_are_classified_by_their_description() {
        let class = StatusInRangeAsFailures::new_for_client_and_server_errors()
            .classify_error(&"connection reset");

        assert!(matches!(
            class,
            StatusInRangeFailureClass::Error(ref description)
                if description == "connection reset"
        ));
    }

    #[test]
    fn failure_class_display() {
        assert_eq!(
            StatusInRangeFailureClass::StatusCode(StatusCode::NOT_FOUND).to_string(),
            "Status code: 404 Not Found"
        );
        assert_eq!(
            StatusInRangeFailureClass::Error("boom".to_owned()).to_string(),
            "Error: boom"
        );
    }

    #[test]
    #[should_panic(expected = "range start isn't a valid status code")]
    fn invalid_range_start_panics() {
        let _ = StatusInRangeAsFailures::new(99..=599);
    }

    #[test]
    #[should_panic(expected = "range end isn't a valid status code")]
    fn invalid_range_end_panics() {
        let _ = StatusInRangeAsFailures::new(400..=1000);
    }

    fn response_with_status(status: u16) -> Response<()> {
        Response::builder().status(status).body(()).unwrap()
    }
}
129}