openapiv3/
status_code.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3
/// Key identifying a response by HTTP status: either one concrete code or a
/// whole class of codes sharing the same leading digit.
///
/// Variants are ordered `Code` before `Range`; the derived ordering relies on
/// that declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum StatusCode {
    /// A single status code, rendered as its decimal number (e.g. `200`).
    Code(u16),
    /// A class of status codes identified by its leading digit, rendered as
    /// that digit followed by `XX` (e.g. `Range(2)` is `2XX`).
    Range(u16),
}
9
10impl fmt::Display for StatusCode {
11    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
12        match self {
13            StatusCode::Code(n) => write!(f, "{}", n),
14            StatusCode::Range(n) => write!(f, "{}XX", n),
15        }
16    }
17}
18
19impl<'de> Deserialize<'de> for StatusCode {
20    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
21    where
22        D: serde::Deserializer<'de>,
23    {
24        use serde::de::{self, Unexpected, Visitor};
25
26        struct StatusCodeVisitor;
27
28        impl<'de> Visitor<'de> for StatusCodeVisitor {
29            type Value = StatusCode;
30
31            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
32                formatter.write_str("number between 100 and 999 (as string or integer) or a string that matches `\\dXX`")
33            }
34
35            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
36            where
37                E: de::Error,
38            {
39                if value >= 100 && value < 1000 {
40                    Ok(StatusCode::Code(value as u16))
41                } else {
42                    Err(E::invalid_value(Unexpected::Signed(value), &self))
43                }
44            }
45
46            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
47            where
48                E: de::Error,
49            {
50                if value >= 100 && value < 1000 {
51                    Ok(StatusCode::Code(value as u16))
52                } else {
53                    Err(E::invalid_value(Unexpected::Unsigned(value), &self))
54                }
55            }
56
57            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58            where
59                E: de::Error,
60            {
61                if value.len() != 3 {
62                    return Err(E::invalid_value(Unexpected::Str(value), &"length 3"));
63                }
64
65                if let Ok(number) = value.parse::<i64>() {
66                    return self.visit_i64(number);
67                }
68
69                if !value.is_ascii() {
70                    return Err(E::invalid_value(
71                        Unexpected::Str(value),
72                        &"ascii, format `\\dXX`",
73                    ));
74                }
75
76                let v = value.as_bytes().to_ascii_uppercase();
77
78                match [v[0], v[1], v[2]] {
79                    [n, b'X', b'X'] if n.is_ascii_digit() => {
80                        Ok(StatusCode::Range(u16::from(n - b'0')))
81                    }
82                    _ => Err(E::invalid_value(Unexpected::Str(value), &"format `\\dXX`")),
83                }
84            }
85        }
86
87        deserializer.deserialize_any(StatusCodeVisitor)
88    }
89}
90
91impl Serialize for StatusCode {
92    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93    where
94        S: serde::Serializer,
95    {
96        serializer.serialize_str(&self.to_string())
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::StatusCode;
    use serde_yaml::from_str;

    /// Parses a YAML document as a `StatusCode`, panicking with the
    /// deserializer's error when the input is rejected.
    fn parse(yaml: &str) -> StatusCode {
        from_str(yaml).unwrap()
    }

    /// A code is accepted both as a YAML integer and as a quoted string.
    #[test]
    fn deserialize_strings_and_numbers() {
        assert_eq!(parse("200"), StatusCode::Code(200));
        assert_eq!(parse("'200'"), StatusCode::Code(200));
    }

    /// Strings that are not exactly three bytes long are rejected.
    #[test]
    #[should_panic = "expected length 3"]
    fn deserialize_invalid_code() {
        let _ = parse("'6666'");
    }

    /// Range wildcards are accepted in upper and lower case.
    #[test]
    fn deserialize_ranges() {
        assert_eq!(parse("2XX"), StatusCode::Range(2));
        assert_eq!(parse("'4xx'"), StatusCode::Range(4));
    }

    /// A three-byte string that is neither a number nor `\dXX` is rejected.
    #[test]
    #[should_panic = "invalid value"]
    fn deserialize_invalid_range() {
        let _ = parse("2XY");
    }
}