rama_http/layer/classify/status_in_range_is_error.rs

1use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier};
2use rama_http_types::StatusCode;
3use std::{fmt, ops::RangeInclusive};
4
/// Classifies a response as a failure whenever its status code falls inside a
/// configured, inclusive range of status codes.
#[derive(Clone, Debug)]
pub struct StatusInRangeAsFailures {
    /// Inclusive range of numeric status codes that are treated as failures.
    range: RangeInclusive<u16>,
}
11
12impl StatusInRangeAsFailures {
13    /// Creates a new `StatusInRangeAsFailures`.
14    ///
15    /// # Panics
16    ///
17    /// Panics if the start or end of `range` aren't valid status codes as determined by
18    /// [`StatusCode::from_u16`].
19    ///
20    /// [`StatusCode::from_u16`]: https://docs.rs/http/latest/http/status/struct.StatusCode.html#method.from_u16
21    pub fn new(range: RangeInclusive<u16>) -> Self {
22        assert!(
23            StatusCode::from_u16(*range.start()).is_ok(),
24            "range start isn't a valid status code"
25        );
26        assert!(
27            StatusCode::from_u16(*range.end()).is_ok(),
28            "range end isn't a valid status code"
29        );
30
31        Self { range }
32    }
33
34    /// Creates a new `StatusInRangeAsFailures` that classifies client and server responses as
35    /// failures.
36    ///
37    /// This is a convenience for `StatusInRangeAsFailures::new(400..=599)`.
38    pub fn new_for_client_and_server_errors() -> Self {
39        Self::new(400..=599)
40    }
41
42    /// Convert this `StatusInRangeAsFailures` into a [`MakeClassifier`].
43    ///
44    /// [`MakeClassifier`]: super::MakeClassifier
45    pub fn into_make_classifier(self) -> SharedClassifier<Self> {
46        SharedClassifier::new(self)
47    }
48}
49
50impl ClassifyResponse for StatusInRangeAsFailures {
51    type FailureClass = StatusInRangeFailureClass;
52    type ClassifyEos = NeverClassifyEos<Self::FailureClass>;
53
54    fn classify_response<B>(
55        self,
56        res: &rama_http_types::Response<B>,
57    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
58        if self.range.contains(&res.status().as_u16()) {
59            let class = StatusInRangeFailureClass::StatusCode(res.status());
60            ClassifiedResponse::Ready(Err(class))
61        } else {
62            ClassifiedResponse::Ready(Ok(()))
63        }
64    }
65
66    fn classify_error<E>(self, error: &E) -> Self::FailureClass
67    where
68        E: std::fmt::Display,
69    {
70        StatusInRangeFailureClass::Error(error.to_string())
71    }
72}
73
74/// The failure class for [`StatusInRangeAsFailures`].
75#[derive(Debug)]
76pub enum StatusInRangeFailureClass {
77    /// A response was classified as a failure with the corresponding status.
78    StatusCode(StatusCode),
79    /// A response was classified as an error with the corresponding error description.
80    Error(String),
81}
82
83impl fmt::Display for StatusInRangeFailureClass {
84    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85        match self {
86            Self::StatusCode(code) => write!(f, "Status code: {}", code),
87            Self::Error(error) => write!(f, "Error: {}", error),
88        }
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use rama_http_types::Response;

    /// Builds an empty-bodied response carrying `status`.
    fn response_with_status(status: u16) -> Response<()> {
        Response::builder()
            .status(status)
            .body(())
            .expect("test status codes are valid")
    }

    #[test]
    fn basic() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(200)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(400)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::BAD_REQUEST
            ))),
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(500)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::INTERNAL_SERVER_ERROR
            ))),
        ));
    }

    #[test]
    fn range_bounds_are_inclusive() {
        let classifier = StatusInRangeAsFailures::new_for_client_and_server_errors();

        // Probe just outside and exactly on both ends of 400..=599.
        for (status, is_failure) in [(399, false), (400, true), (599, true), (600, false)] {
            match classifier
                .clone()
                .classify_response(&response_with_status(status))
            {
                ClassifiedResponse::Ready(Ok(())) => {
                    assert!(!is_failure, "status {status} should be a failure")
                }
                ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(code))) => {
                    assert!(is_failure, "status {status} should not be a failure");
                    assert_eq!(code.as_u16(), status);
                }
                other => panic!("unexpected classification for {status}: {other:?}"),
            }
        }
    }

    #[test]
    fn errors_are_classified_with_their_message() {
        let class =
            StatusInRangeAsFailures::new_for_client_and_server_errors().classify_error(&"boom");

        assert!(matches!(&class, StatusInRangeFailureClass::Error(msg) if msg == "boom"));
        assert_eq!(class.to_string(), "Error: boom");
    }

    #[test]
    fn status_code_class_display() {
        let rendered = StatusInRangeFailureClass::StatusCode(StatusCode::BAD_REQUEST).to_string();
        assert!(rendered.starts_with("Status code: 400"), "got {rendered:?}");
    }

    #[test]
    #[should_panic(expected = "range start isn't a valid status code")]
    fn new_rejects_invalid_start() {
        let _ = StatusInRangeAsFailures::new(99..=599);
    }

    #[test]
    #[should_panic(expected = "range end isn't a valid status code")]
    fn new_rejects_invalid_end() {
        let _ = StatusInRangeAsFailures::new(400..=1000);
    }
}
133}