1use serde::{Deserialize, Serialize};
2use std::fmt;
3
/// An HTTP status code matcher: either one exact code or a whole class of codes.
///
/// Variant order matters: the derived `PartialOrd`/`Ord` sorts every
/// `Code` before every `Range`.
#[derive(Clone, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum StatusCode {
    /// One exact status code, e.g. `200`.
    Code(u16),
    /// A class of status codes identified by its leading digit, e.g. `2` for `2XX`.
    Range(u16),
}
9
10impl fmt::Display for StatusCode {
11 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
12 match self {
13 StatusCode::Code(n) => write!(f, "{}", n),
14 StatusCode::Range(n) => write!(f, "{}XX", n),
15 }
16 }
17}
18
19impl<'de> Deserialize<'de> for StatusCode {
20 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
21 where
22 D: serde::Deserializer<'de>,
23 {
24 use serde::de::{self, Unexpected, Visitor};
25
26 struct StatusCodeVisitor;
27
28 impl<'de> Visitor<'de> for StatusCodeVisitor {
29 type Value = StatusCode;
30
31 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
32 formatter.write_str("number between 100 and 999 (as string or integer) or a string that matches `\\dXX`")
33 }
34
35 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
36 where
37 E: de::Error,
38 {
39 if value >= 100 && value < 1000 {
40 Ok(StatusCode::Code(value as u16))
41 } else {
42 Err(E::invalid_value(Unexpected::Signed(value), &self))
43 }
44 }
45
46 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
47 where
48 E: de::Error,
49 {
50 if value >= 100 && value < 1000 {
51 Ok(StatusCode::Code(value as u16))
52 } else {
53 Err(E::invalid_value(Unexpected::Unsigned(value), &self))
54 }
55 }
56
57 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58 where
59 E: de::Error,
60 {
61 if value.len() != 3 {
62 return Err(E::invalid_value(Unexpected::Str(value), &"length 3"));
63 }
64
65 if let Ok(number) = value.parse::<i64>() {
66 return self.visit_i64(number);
67 }
68
69 if !value.is_ascii() {
70 return Err(E::invalid_value(
71 Unexpected::Str(value),
72 &"ascii, format `\\dXX`",
73 ));
74 }
75
76 let v = value.as_bytes().to_ascii_uppercase();
77
78 match [v[0], v[1], v[2]] {
79 [n, b'X', b'X'] if n.is_ascii_digit() => {
80 Ok(StatusCode::Range(u16::from(n - b'0')))
81 }
82 _ => Err(E::invalid_value(Unexpected::Str(value), &"format `\\dXX`")),
83 }
84 }
85 }
86
87 deserializer.deserialize_any(StatusCodeVisitor)
88 }
89}
90
91impl Serialize for StatusCode {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: serde::Serializer,
95 {
96 serializer.serialize_str(&self.to_string())
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::StatusCode;
    use serde_yaml::{from_str, to_string};

    #[test]
    fn deserialize_strings_and_numbers() {
        assert_eq!(StatusCode::Code(200), from_str("200").unwrap());
        assert_eq!(StatusCode::Code(200), from_str("'200'").unwrap());
    }

    #[test]
    fn deserialize_code_boundaries() {
        assert_eq!(StatusCode::Code(100), from_str("100").unwrap());
        assert_eq!(StatusCode::Code(999), from_str("999").unwrap());
        assert!(from_str::<StatusCode>("99").is_err());
        assert!(from_str::<StatusCode>("1000").is_err());
        assert!(from_str::<StatusCode>("-200").is_err());
        assert!(from_str::<StatusCode>("'099'").is_err());
        assert!(from_str::<StatusCode>("'+20'").is_err());
    }

    #[test]
    #[should_panic = "expected length 3"]
    fn deserialize_invalid_code() {
        let _: StatusCode = from_str("'6666'").unwrap();
    }

    #[test]
    fn deserialize_ranges() {
        assert_eq!(StatusCode::Range(2), from_str("2XX").unwrap());
        assert_eq!(StatusCode::Range(4), from_str("'4xx'").unwrap());
    }

    #[test]
    #[should_panic = "invalid value"]
    fn deserialize_invalid_range() {
        let _: StatusCode = from_str("2XY").unwrap();
    }

    #[test]
    fn serialize_round_trips() {
        for code in &[StatusCode::Code(200), StatusCode::Range(4)] {
            let yaml = to_string(code).unwrap();
            assert_eq!(*code, from_str::<StatusCode>(&yaml).unwrap());
        }
    }
}