Skip to main content

lychee_lib/types/accept/
range.rs

1use std::{
2    collections::HashSet, fmt::Display, hash::BuildHasher, num::ParseIntError, ops::RangeInclusive,
3    str::FromStr, sync::LazyLock,
4};
5
6use http::StatusCode;
7use regex::Regex;
8use thiserror::Error;
9
/// Lower bound (inclusive) for any value stored in a [`StatusRange`].
///
/// Also used as the implicit start when a range string omits it (`..200`).
const MIN: u16 = 100;

/// Upper bound (inclusive) for any value stored in a [`StatusRange`].
///
/// Also used as the implicit end when a range string omits it (`200..`).
const MAX: u16 = 999;
15
16/// Regex to compute range values
17static RANGE_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
18    Regex::new(
19        r"^(?<start>[0-9]+)?\.\.(?<until>(?<inclusive>=)?(?<end>[0-9]+))?$|^(?<single>[0-9]+)$",
20    )
21    .unwrap()
22});
23
24/// Indicates that the parsing process of an [`StatusRange`]  from a string
25/// failed due to various underlying reasons.
26#[derive(Debug, Error, PartialEq)]
27pub enum StatusRangeError {
28    /// The string input didn't contain any range pattern.
29    #[error("no range pattern found")]
30    NoRangePattern,
31
32    /// The start or end index could not be parsed as an integer.
33    #[error("failed to parse str as integer")]
34    ParseIntError(#[from] ParseIntError),
35
36    /// The start index is larger than the end index.
37    #[error("invalid range indices, only start < end supported")]
38    InvalidRangeIndices,
39
40    /// The u16 values must be representable as status code
41    #[error("values must represent valid status codes between {MIN} and {MAX} (inclusive)")]
42    InvalidStatusCodeValue,
43}
44
/// An inclusive range of HTTP status codes that are accepted, i.e. treated
/// as successful, when checking a remote URL.
///
/// Every value held by a [`StatusRange`] is a valid status code: bounds
/// below 100 or above 999 are rejected on construction.
#[derive(Debug, Clone, PartialEq)]
pub struct StatusRange(RangeInclusive<u16>);
51
52impl FromStr for StatusRange {
53    type Err = StatusRangeError;
54
55    fn from_str(s: &str) -> Result<Self, Self::Err> {
56        let captures = RANGE_PATTERN
57            .captures(s)
58            .ok_or(StatusRangeError::NoRangePattern)?;
59
60        if let Some(value) = captures.name("single") {
61            let value: u16 = value.as_str().parse()?;
62            Self::new(value, value)
63        } else {
64            let start: u16 = match captures.name("start") {
65                Some(start) => start.as_str().parse().unwrap_or_default(),
66                None => MIN,
67            };
68            if captures.name("until").is_none() {
69                return Self::new(start, MAX);
70            }
71
72            let inclusive = captures.name("inclusive").is_some();
73            let end: u16 = captures["end"].parse()?;
74
75            if inclusive {
76                Self::new(start, end)
77            } else {
78                Self::new(start, end - 1)
79            }
80        }
81    }
82}
83
84impl Display for StatusRange {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        write!(f, "{}..={}", self.start(), self.end())
87    }
88}
89
90impl StatusRange {
91    /// Creates a new [`StatusRange`] which matches values between `start` and
92    /// `end` (both inclusive).
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if `start` > `end`, `start` < 100 or `end` > 999.
97    pub const fn new(start: u16, end: u16) -> Result<Self, StatusRangeError> {
98        if start < MIN || end > MAX {
99            return Err(StatusRangeError::InvalidStatusCodeValue);
100        }
101
102        if start > end {
103            return Err(StatusRangeError::InvalidRangeIndices);
104        }
105
106        Ok(Self(RangeInclusive::new(start, end)))
107    }
108
109    /// Returns the `start` value of this [`StatusRange`].
110    #[must_use]
111    pub const fn start(&self) -> &u16 {
112        self.0.start()
113    }
114
115    /// Returns the `end` value of this [`StatusRange`].
116    #[must_use]
117    pub const fn end(&self) -> &u16 {
118        self.0.end()
119    }
120
121    /// Returns whether this [`StatusRange`] contains `value`.
122    #[must_use]
123    pub fn contains(&self, value: u16) -> bool {
124        self.0.contains(&value)
125    }
126
127    pub(crate) const fn update_start(&mut self, new_start: u16) -> Result<(), StatusRangeError> {
128        // Can't use `?` in const function as of 1.91.0
129        match Self::new(new_start, *self.end()) {
130            Ok(r) => {
131                self.0 = r.0;
132                Ok(())
133            }
134            Err(e) => Err(e),
135        }
136    }
137
138    pub(crate) const fn update_end(&mut self, new_end: u16) -> Result<(), StatusRangeError> {
139        // Can't use `?` in const function as of 1.91.0
140        match Self::new(*self.start(), new_end) {
141            Ok(r) => {
142                self.0 = r.0;
143                Ok(())
144            }
145            Err(e) => Err(e),
146        }
147    }
148
149    pub(crate) fn merge(&mut self, other: &Self) -> bool {
150        // Merge when the end value of self overlaps with other's start
151        if self.end() >= other.start() && other.end() >= self.end() {
152            // We can ignore the result here, as it is guaranteed that
153            // start < new_end
154            let _ = self.update_end(*other.end());
155            return true;
156        }
157
158        // Merge when the start value of self overlaps with other's end
159        if self.start() <= other.end() && other.start() <= self.start() {
160            // We can ignore the result here, as it is guaranteed that
161            // start < new_end
162            let _ = self.update_start(*other.start());
163            return true;
164        }
165
166        false
167    }
168}
169
170impl<S: BuildHasher + Default> From<StatusRange> for HashSet<StatusCode, S> {
171    fn from(value: StatusRange) -> Self {
172        value
173            .0
174            .map(StatusCode::from_u16)
175            // SAFETY: StatusRange represents only valid status codes
176            .map(|r| r.expect("StatusRange's invariant was violated, this is a bug"))
177            .collect()
178    }
179}
180
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Every valid input must parse into a range that contains all of the
    /// listed included values and none of the listed excluded values.
    #[rstest]
    #[case("..", vec![MIN, 150, 200, MAX], vec![MIN - 1, MAX + 1])]
    #[case("100..", vec![100, 101, 150, 200, MAX], vec![0, 50, 99])]
    #[case("100..=200", vec![100, 150, 200], vec![0, 50, 99, 201, 250, MAX+1])]
    #[case("..=100", vec![100], vec![99, 101])]
    #[case("100..200", vec![100, 150, 199], vec![99, 200, 250])]
    #[case("..101", vec![100], vec![99, 101])]
    #[case("404", vec![404], vec![200, 304, 403, 405, 500])]
    fn test_from_str(
        #[case] input: &str,
        #[case] included_values: Vec<u16>,
        #[case] excluded_values: Vec<u16>,
    ) {
        let parsed: StatusRange = input.parse().unwrap();

        for code in included_values {
            assert!(parsed.contains(code));
        }

        for code in excluded_values {
            assert!(!parsed.contains(code));
        }
    }

    /// Every invalid input must be rejected with exactly the listed error.
    #[rstest]
    #[case("..100", StatusRangeError::InvalidRangeIndices)]
    #[case("200..=100", StatusRangeError::InvalidRangeIndices)]
    #[case("100..100", StatusRangeError::InvalidRangeIndices)]
    #[case("..=", StatusRangeError::NoRangePattern)]
    #[case("100..=", StatusRangeError::NoRangePattern)]
    #[case("-100..=100", StatusRangeError::NoRangePattern)]
    #[case("-100..100", StatusRangeError::NoRangePattern)]
    #[case("100..=-100", StatusRangeError::NoRangePattern)]
    #[case("100..-100", StatusRangeError::NoRangePattern)]
    #[case("abcd", StatusRangeError::NoRangePattern)]
    #[case("-1", StatusRangeError::NoRangePattern)]
    #[case("0", StatusRangeError::InvalidStatusCodeValue)]
    #[case("1..5", StatusRangeError::InvalidStatusCodeValue)]
    #[case("99..102", StatusRangeError::InvalidStatusCodeValue)]
    #[case("999..=1000", StatusRangeError::InvalidStatusCodeValue)]
    fn test_from_str_invalid(#[case] input: &str, #[case] error: StatusRangeError) {
        let outcome = input.parse::<StatusRange>();
        assert_eq!(outcome, Err(error));
    }

    /// Merging `other` into `range` must leave `range` equal to `result`.
    #[rstest]
    #[case("100..=200", "210..=300", "100..=200")]
    #[case("100..=200", "190..=300", "100..=300")]
    #[case("100..200", "200..300", "100..200")]
    #[case("100..200", "190..300", "100..300")]
    fn test_merge(#[case] range: &str, #[case] other: &str, #[case] result: &str) {
        let mut merged: StatusRange = range.parse().unwrap();
        let addition: StatusRange = other.parse().unwrap();
        let expected: StatusRange = result.parse().unwrap();

        merged.merge(&addition);

        assert_eq!(expected, merged);
    }
}