1use std::{
2 collections::HashSet, fmt::Display, hash::BuildHasher, num::ParseIntError, ops::RangeInclusive,
3 str::FromStr, sync::LazyLock,
4};
5
6use http::StatusCode;
7use regex::Regex;
8use thiserror::Error;
9
10use crate::hint;
11
/// Smallest value accepted as a bound of a status range (inclusive).
const MIN: u16 = 100;

/// Largest value accepted as a bound of a status range (inclusive).
const MAX: u16 = 999;
17
/// Regular expression recognising the textual notations of a status range.
///
/// The pattern has two alternatives, each anchored to the whole input:
///
/// - a range: optional `start` digits, a literal `..`, then an optional
///   `until` group consisting of an optional `=` (captured as `inclusive`)
///   followed by the mandatory `end` digits;
/// - a `single` run of digits.
///
/// Only ASCII digits are accepted, so signs such as `-` never match.
/// The expression is compiled lazily on first use; the `unwrap` can only
/// fail if the hard-coded pattern itself is invalid.
static RANGE_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"^(?<start>[0-9]+)?\.\.(?<until>(?<inclusive>=)?(?<end>[0-9]+))?$|^(?<single>[0-9]+)$",
    )
    .unwrap()
});
25
/// Errors produced while parsing or constructing a [`StatusRange`].
#[derive(Debug, Error, PartialEq)]
pub enum StatusRangeError {
    /// The input did not match any of the supported range notations.
    #[error("no range pattern found")]
    NoRangePattern,

    /// A numeric component of the input could not be parsed as an integer;
    /// wraps the underlying [`ParseIntError`].
    #[error("failed to parse str as integer")]
    ParseIntError(#[from] ParseIntError),

    /// The start of the range is greater than its end.
    ///
    /// Equal bounds are accepted by [`StatusRange::new`]; only `start > end`
    /// is rejected.
    #[error("invalid range indices, only start < end supported")]
    InvalidRangeIndices,

    /// At least one bound lies outside the accepted `MIN..=MAX` interval.
    #[error("values must represent valid status codes between {MIN} and {MAX} (inclusive)")]
    InvalidStatusCodeValue,
}
46
/// An inclusive range of status codes, stored as a `RangeInclusive<u16>`.
///
/// Values created through [`StatusRange::new`] have both bounds inside
/// `MIN..=MAX` and a start that is not greater than the end.
#[derive(Clone, Debug, PartialEq)]
pub struct StatusRange(RangeInclusive<u16>);
53
54impl FromStr for StatusRange {
55 type Err = StatusRangeError;
56
57 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 let captures = RANGE_PATTERN
59 .captures(s)
60 .ok_or(StatusRangeError::NoRangePattern)?;
61
62 if let Some(value) = captures.name("single") {
63 let value: u16 = value.as_str().parse()?;
64 Self::new(value, value)
65 } else {
66 let start: u16 = match captures.name("start") {
67 Some(start) => start.as_str().parse().unwrap_or_default(),
68 None => MIN,
69 };
70 if captures.name("until").is_none() {
71 return Self::new(start, MAX);
72 }
73
74 let inclusive = captures.name("inclusive").is_some();
75 let end: u16 = captures["end"].parse()?;
76
77 if inclusive {
78 Self::new(start, end)
79 } else {
80 if end == start + 1 {
81 hint!(
82 "Accept range `{s}` only matches the single status code {start}. \
83 Did you mean `{start}..={end}` or `{start},{end}`?"
84 );
85 }
86 Self::new(start, end - 1)
87 }
88 }
89 }
90}
91
92impl Display for StatusRange {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}..={}", self.start(), self.end())
95 }
96}
97
98impl StatusRange {
99 pub const fn new(start: u16, end: u16) -> Result<Self, StatusRangeError> {
106 if start < MIN || end > MAX {
107 return Err(StatusRangeError::InvalidStatusCodeValue);
108 }
109
110 if start > end {
111 return Err(StatusRangeError::InvalidRangeIndices);
112 }
113
114 Ok(Self(RangeInclusive::new(start, end)))
115 }
116
117 #[must_use]
119 pub const fn start(&self) -> &u16 {
120 self.0.start()
121 }
122
123 #[must_use]
125 pub const fn end(&self) -> &u16 {
126 self.0.end()
127 }
128
129 #[must_use]
131 pub fn contains(&self, value: u16) -> bool {
132 self.0.contains(&value)
133 }
134
135 pub(crate) const fn update_start(&mut self, new_start: u16) -> Result<(), StatusRangeError> {
136 match Self::new(new_start, *self.end()) {
138 Ok(r) => {
139 self.0 = r.0;
140 Ok(())
141 }
142 Err(e) => Err(e),
143 }
144 }
145
146 pub(crate) const fn update_end(&mut self, new_end: u16) -> Result<(), StatusRangeError> {
147 match Self::new(*self.start(), new_end) {
149 Ok(r) => {
150 self.0 = r.0;
151 Ok(())
152 }
153 Err(e) => Err(e),
154 }
155 }
156
157 pub(crate) fn merge(&mut self, other: &Self) -> bool {
158 if self.end() >= other.start() && other.end() >= self.end() {
160 let _ = self.update_end(*other.end());
163 return true;
164 }
165
166 if self.start() <= other.end() && other.start() <= self.start() {
168 let _ = self.update_start(*other.start());
171 return true;
172 }
173
174 false
175 }
176}
177
178impl<S: BuildHasher + Default> From<StatusRange> for HashSet<StatusCode, S> {
179 fn from(value: StatusRange) -> Self {
180 value
181 .0
182 .map(StatusCode::from_u16)
183 .map(|r| r.expect("StatusRange's invariant was violated, this is a bug"))
185 .collect()
186 }
187}
188
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Valid notations parse into a range that contains every listed
    /// "included" value and none of the listed "excluded" values.
    #[rstest]
    #[case("..", vec![MIN, 150, 200, MAX], vec![MIN - 1, MAX + 1])]
    #[case("100..", vec![100, 101, 150, 200, MAX], vec![0, 50, 99])]
    #[case("100..=200", vec![100, 150, 200], vec![0, 50, 99, 201, 250, MAX+1])]
    #[case("..=100", vec![100], vec![99, 101])]
    #[case("100..200", vec![100, 150, 199], vec![99, 200, 250])]
    #[case("..101", vec![100], vec![99, 101])]
    #[case("404", vec![404], vec![200, 304, 403, 405, 500])]
    fn test_from_str(
        #[case] input: &str,
        #[case] included_values: Vec<u16>,
        #[case] excluded_values: Vec<u16>,
    ) {
        let parsed: StatusRange = input.parse().unwrap();

        included_values
            .into_iter()
            .for_each(|code| assert!(parsed.contains(code)));

        excluded_values
            .into_iter()
            .for_each(|code| assert!(!parsed.contains(code)));
    }

    /// Malformed or out-of-bounds notations are rejected with the expected
    /// error variant.
    #[rstest]
    #[case("..100", StatusRangeError::InvalidRangeIndices)]
    #[case("200..=100", StatusRangeError::InvalidRangeIndices)]
    #[case("100..100", StatusRangeError::InvalidRangeIndices)]
    #[case("..=", StatusRangeError::NoRangePattern)]
    #[case("100..=", StatusRangeError::NoRangePattern)]
    #[case("-100..=100", StatusRangeError::NoRangePattern)]
    #[case("-100..100", StatusRangeError::NoRangePattern)]
    #[case("100..=-100", StatusRangeError::NoRangePattern)]
    #[case("100..-100", StatusRangeError::NoRangePattern)]
    #[case("abcd", StatusRangeError::NoRangePattern)]
    #[case("-1", StatusRangeError::NoRangePattern)]
    #[case("0", StatusRangeError::InvalidStatusCodeValue)]
    #[case("1..5", StatusRangeError::InvalidStatusCodeValue)]
    #[case("99..102", StatusRangeError::InvalidStatusCodeValue)]
    #[case("999..=1000", StatusRangeError::InvalidStatusCodeValue)]
    fn test_from_str_invalid(#[case] input: &str, #[case] error: StatusRangeError) {
        let outcome = input.parse::<StatusRange>();
        assert_eq!(outcome, Err(error));
    }

    /// Merging `other` into `range` yields `result`; ranges without a shared
    /// value leave `range` unchanged.
    #[rstest]
    #[case("100..=200", "210..=300", "100..=200")]
    #[case("100..=200", "190..=300", "100..=300")]
    #[case("100..200", "200..300", "100..200")]
    #[case("100..200", "190..300", "100..300")]
    fn test_merge(#[case] range: &str, #[case] other: &str, #[case] result: &str) {
        let expected: StatusRange = result.parse().unwrap();
        let addition: StatusRange = other.parse().unwrap();
        let mut merged: StatusRange = range.parse().unwrap();

        merged.merge(&addition);

        assert_eq!(expected, merged);
    }
}