1use std::{
2 collections::HashSet, fmt::Display, hash::BuildHasher, num::ParseIntError, ops::RangeInclusive,
3 str::FromStr, sync::LazyLock,
4};
5
6use http::StatusCode;
7use regex::Regex;
8use thiserror::Error;
9
/// Lowest value a [`StatusRange`] bound may take (the smallest three-digit code).
const MIN: u16 = 100;

/// Highest value a [`StatusRange`] bound may take (the largest three-digit code).
const MAX: u16 = 999;
15
16static RANGE_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
18 Regex::new(
19 r"^(?<start>[0-9]+)?\.\.(?<until>(?<inclusive>=)?(?<end>[0-9]+))?$|^(?<single>[0-9]+)$",
20 )
21 .unwrap()
22});
23
24#[derive(Debug, Error, PartialEq)]
27pub enum StatusRangeError {
28 #[error("no range pattern found")]
30 NoRangePattern,
31
32 #[error("failed to parse str as integer")]
34 ParseIntError(#[from] ParseIntError),
35
36 #[error("invalid range indices, only start < end supported")]
38 InvalidRangeIndices,
39
40 #[error("values must represent valid status codes between {MIN} and {MAX} (inclusive)")]
42 InvalidStatusCodeValue,
43}
44
/// A non-empty, inclusive range of HTTP status codes.
///
/// Build it with `StatusRange::new` or by parsing a string. Both reject ranges
/// whose lower bound exceeds the upper bound, and ranges reaching outside
/// `MIN..=MAX`.
///
/// `Eq` and `Hash` are derived as well, since the wrapped `RangeInclusive<u16>`
/// supports both. This lets ranges be stored in hash-based collections.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StatusRange(RangeInclusive<u16>);
51
52impl FromStr for StatusRange {
53 type Err = StatusRangeError;
54
55 fn from_str(s: &str) -> Result<Self, Self::Err> {
56 let captures = RANGE_PATTERN
57 .captures(s)
58 .ok_or(StatusRangeError::NoRangePattern)?;
59
60 if let Some(value) = captures.name("single") {
61 let value: u16 = value.as_str().parse()?;
62 Self::new(value, value)
63 } else {
64 let start: u16 = match captures.name("start") {
65 Some(start) => start.as_str().parse().unwrap_or_default(),
66 None => MIN,
67 };
68 if captures.name("until").is_none() {
69 return Self::new(start, MAX);
70 }
71
72 let inclusive = captures.name("inclusive").is_some();
73 let end: u16 = captures["end"].parse()?;
74
75 if inclusive {
76 Self::new(start, end)
77 } else {
78 Self::new(start, end - 1)
79 }
80 }
81 }
82}
83
84impl Display for StatusRange {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}..={}", self.start(), self.end())
87 }
88}
89
90impl StatusRange {
91 pub const fn new(start: u16, end: u16) -> Result<Self, StatusRangeError> {
98 if start < MIN || end > MAX {
99 return Err(StatusRangeError::InvalidStatusCodeValue);
100 }
101
102 if start > end {
103 return Err(StatusRangeError::InvalidRangeIndices);
104 }
105
106 Ok(Self(RangeInclusive::new(start, end)))
107 }
108
109 #[must_use]
111 pub const fn start(&self) -> &u16 {
112 self.0.start()
113 }
114
115 #[must_use]
117 pub const fn end(&self) -> &u16 {
118 self.0.end()
119 }
120
121 #[must_use]
123 pub fn contains(&self, value: u16) -> bool {
124 self.0.contains(&value)
125 }
126
127 pub(crate) const fn update_start(&mut self, new_start: u16) -> Result<(), StatusRangeError> {
128 match Self::new(new_start, *self.end()) {
130 Ok(r) => {
131 self.0 = r.0;
132 Ok(())
133 }
134 Err(e) => Err(e),
135 }
136 }
137
138 pub(crate) const fn update_end(&mut self, new_end: u16) -> Result<(), StatusRangeError> {
139 match Self::new(*self.start(), new_end) {
141 Ok(r) => {
142 self.0 = r.0;
143 Ok(())
144 }
145 Err(e) => Err(e),
146 }
147 }
148
149 pub(crate) fn merge(&mut self, other: &Self) -> bool {
150 if self.end() >= other.start() && other.end() >= self.end() {
152 let _ = self.update_end(*other.end());
155 return true;
156 }
157
158 if self.start() <= other.end() && other.start() <= self.start() {
160 let _ = self.update_start(*other.start());
163 return true;
164 }
165
166 false
167 }
168}
169
170impl<S: BuildHasher + Default> From<StatusRange> for HashSet<StatusCode, S> {
171 fn from(value: StatusRange) -> Self {
172 value
173 .0
174 .map(StatusCode::from_u16)
175 .map(|r| r.expect("StatusRange's invariant was violated, this is a bug"))
177 .collect()
178 }
179}
180
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Valid inputs: every `included` code must be inside the parsed range and
    /// every `excluded` code outside it.
    #[rstest]
    #[case("..", vec![MIN, 150, 200, MAX], vec![MIN - 1, MAX + 1])]
    #[case("100..", vec![100, 101, 150, 200, MAX], vec![0, 50, 99])]
    #[case("100..=200", vec![100, 150, 200], vec![0, 50, 99, 201, 250, MAX+1])]
    #[case("..=100", vec![100], vec![99, 101])]
    #[case("100..200", vec![100, 150, 199], vec![99, 200, 250])]
    #[case("..101", vec![100], vec![99, 101])]
    #[case("404", vec![404], vec![200, 304, 403, 405, 500])]
    fn test_from_str(
        #[case] input: &str,
        #[case] included_values: Vec<u16>,
        #[case] excluded_values: Vec<u16>,
    ) {
        let parsed: StatusRange = input.parse().unwrap();

        for code in &included_values {
            assert!(parsed.contains(*code), "{code} should be in {parsed}");
        }
        for code in &excluded_values {
            assert!(!parsed.contains(*code), "{code} should not be in {parsed}");
        }
    }

    /// Invalid inputs and the exact error each one must produce.
    #[rstest]
    #[case("..100", StatusRangeError::InvalidRangeIndices)]
    #[case("200..=100", StatusRangeError::InvalidRangeIndices)]
    #[case("100..100", StatusRangeError::InvalidRangeIndices)]
    #[case("..=", StatusRangeError::NoRangePattern)]
    #[case("100..=", StatusRangeError::NoRangePattern)]
    #[case("-100..=100", StatusRangeError::NoRangePattern)]
    #[case("-100..100", StatusRangeError::NoRangePattern)]
    #[case("100..=-100", StatusRangeError::NoRangePattern)]
    #[case("100..-100", StatusRangeError::NoRangePattern)]
    #[case("abcd", StatusRangeError::NoRangePattern)]
    #[case("-1", StatusRangeError::NoRangePattern)]
    #[case("0", StatusRangeError::InvalidStatusCodeValue)]
    #[case("1..5", StatusRangeError::InvalidStatusCodeValue)]
    #[case("99..102", StatusRangeError::InvalidStatusCodeValue)]
    #[case("999..=1000", StatusRangeError::InvalidStatusCodeValue)]
    fn test_from_str_invalid(#[case] input: &str, #[case] error: StatusRangeError) {
        assert_eq!(input.parse::<StatusRange>(), Err(error));
    }

    /// Merging `other` into `range` must leave `range` equal to `result`.
    #[rstest]
    #[case("100..=200", "210..=300", "100..=200")]
    #[case("100..=200", "190..=300", "100..=300")]
    #[case("100..200", "200..300", "100..200")]
    #[case("100..200", "190..300", "100..300")]
    fn test_merge(#[case] range: &str, #[case] other: &str, #[case] result: &str) {
        let parse = |text: &str| StatusRange::from_str(text).unwrap();
        let (mut merged, other, expected) = (parse(range), parse(other), parse(result));

        merged.merge(&other);

        assert_eq!(expected, merged);
    }
}