Skip to main content

lychee_lib/types/accept/
range.rs

1use std::{
2    collections::HashSet, fmt::Display, hash::BuildHasher, num::ParseIntError, ops::RangeInclusive,
3    str::FromStr, sync::LazyLock,
4};
5
6use http::StatusCode;
7use regex::Regex;
8use thiserror::Error;
9
10use crate::hint;
11
12/// Smallest accepted value
13const MIN: u16 = 100;
14
15/// Biggest accepted value
16const MAX: u16 = 999;
17
18/// Regex to compute range values
19static RANGE_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
20    Regex::new(
21        r"^(?<start>[0-9]+)?\.\.(?<until>(?<inclusive>=)?(?<end>[0-9]+))?$|^(?<single>[0-9]+)$",
22    )
23    .unwrap()
24});
25
26/// Indicates that the parsing process of an [`StatusRange`]  from a string
27/// failed due to various underlying reasons.
28#[derive(Debug, Error, PartialEq)]
29pub enum StatusRangeError {
30    /// The string input didn't contain any range pattern.
31    #[error("no range pattern found")]
32    NoRangePattern,
33
34    /// The start or end index could not be parsed as an integer.
35    #[error("failed to parse str as integer")]
36    ParseIntError(#[from] ParseIntError),
37
38    /// The start index is larger than the end index.
39    #[error("invalid range indices, only start < end supported")]
40    InvalidRangeIndices,
41
42    /// The u16 values must be representable as status code
43    #[error("values must represent valid status codes between {MIN} and {MAX} (inclusive)")]
44    InvalidStatusCodeValue,
45}
46
/// [`StatusRange`] specifies which HTTP status codes are accepted and
/// considered successful when checking a remote URL.
///
/// Only represents valid status codes,
/// invalid status codes (<100 or >999) are rejected.
///
/// Both bounds of the wrapped range are inclusive. The type is `Eq` and
/// `Hash`, so ranges can be deduplicated or used as map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StatusRange(RangeInclusive<u16>);
53
54impl FromStr for StatusRange {
55    type Err = StatusRangeError;
56
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        let captures = RANGE_PATTERN
59            .captures(s)
60            .ok_or(StatusRangeError::NoRangePattern)?;
61
62        if let Some(value) = captures.name("single") {
63            let value: u16 = value.as_str().parse()?;
64            Self::new(value, value)
65        } else {
66            let start: u16 = match captures.name("start") {
67                Some(start) => start.as_str().parse().unwrap_or_default(),
68                None => MIN,
69            };
70            if captures.name("until").is_none() {
71                return Self::new(start, MAX);
72            }
73
74            let inclusive = captures.name("inclusive").is_some();
75            let end: u16 = captures["end"].parse()?;
76
77            if inclusive {
78                Self::new(start, end)
79            } else {
80                if end == start + 1 {
81                    hint!(
82                        "Accept range `{s}` only matches the single status code {start}. \
83                         Did you mean `{start}..={end}` or `{start},{end}`?"
84                    );
85                }
86                Self::new(start, end - 1)
87            }
88        }
89    }
90}
91
92impl Display for StatusRange {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        write!(f, "{}..={}", self.start(), self.end())
95    }
96}
97
98impl StatusRange {
99    /// Creates a new [`StatusRange`] which matches values between `start` and
100    /// `end` (both inclusive).
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if `start` > `end`, `start` < 100 or `end` > 999.
105    pub const fn new(start: u16, end: u16) -> Result<Self, StatusRangeError> {
106        if start < MIN || end > MAX {
107            return Err(StatusRangeError::InvalidStatusCodeValue);
108        }
109
110        if start > end {
111            return Err(StatusRangeError::InvalidRangeIndices);
112        }
113
114        Ok(Self(RangeInclusive::new(start, end)))
115    }
116
117    /// Returns the `start` value of this [`StatusRange`].
118    #[must_use]
119    pub const fn start(&self) -> &u16 {
120        self.0.start()
121    }
122
123    /// Returns the `end` value of this [`StatusRange`].
124    #[must_use]
125    pub const fn end(&self) -> &u16 {
126        self.0.end()
127    }
128
129    /// Returns whether this [`StatusRange`] contains `value`.
130    #[must_use]
131    pub fn contains(&self, value: u16) -> bool {
132        self.0.contains(&value)
133    }
134
135    pub(crate) const fn update_start(&mut self, new_start: u16) -> Result<(), StatusRangeError> {
136        // Can't use `?` in const function as of 1.91.0
137        match Self::new(new_start, *self.end()) {
138            Ok(r) => {
139                self.0 = r.0;
140                Ok(())
141            }
142            Err(e) => Err(e),
143        }
144    }
145
146    pub(crate) const fn update_end(&mut self, new_end: u16) -> Result<(), StatusRangeError> {
147        // Can't use `?` in const function as of 1.91.0
148        match Self::new(*self.start(), new_end) {
149            Ok(r) => {
150                self.0 = r.0;
151                Ok(())
152            }
153            Err(e) => Err(e),
154        }
155    }
156
157    pub(crate) fn merge(&mut self, other: &Self) -> bool {
158        // Merge when the end value of self overlaps with other's start
159        if self.end() >= other.start() && other.end() >= self.end() {
160            // We can ignore the result here, as it is guaranteed that
161            // start < new_end
162            let _ = self.update_end(*other.end());
163            return true;
164        }
165
166        // Merge when the start value of self overlaps with other's end
167        if self.start() <= other.end() && other.start() <= self.start() {
168            // We can ignore the result here, as it is guaranteed that
169            // start < new_end
170            let _ = self.update_start(*other.start());
171            return true;
172        }
173
174        false
175    }
176}
177
178impl<S: BuildHasher + Default> From<StatusRange> for HashSet<StatusCode, S> {
179    fn from(value: StatusRange) -> Self {
180        value
181            .0
182            .map(StatusCode::from_u16)
183            // SAFETY: StatusRange represents only valid status codes
184            .map(|r| r.expect("StatusRange's invariant was violated, this is a bug"))
185            .collect()
186    }
187}
188
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Valid inputs must parse and contain / exclude the listed values.
    #[rstest]
    #[case("..", vec![MIN, 150, 200, MAX], vec![MIN - 1, MAX + 1])]
    #[case("100..", vec![100, 101, 150, 200, MAX], vec![0, 50, 99])]
    #[case("100..=200", vec![100, 150, 200], vec![0, 50, 99, 201, 250, MAX+1])]
    #[case("..=100", vec![100], vec![99, 101])]
    #[case("100..200", vec![100, 150, 199], vec![99, 200, 250])]
    #[case("..101", vec![100], vec![99, 101])]
    #[case("404", vec![404], vec![200, 304, 403, 405, 500])]
    fn test_from_str(
        #[case] input: &str,
        #[case] included_values: Vec<u16>,
        #[case] excluded_values: Vec<u16>,
    ) {
        let range = StatusRange::from_str(input).unwrap();

        for included in included_values {
            assert!(
                range.contains(included),
                "`{input}` should contain {included}"
            );
        }

        for excluded in excluded_values {
            assert!(
                !range.contains(excluded),
                "`{input}` should not contain {excluded}"
            );
        }
    }

    /// Invalid inputs must be rejected with the expected error variant.
    #[rstest]
    #[case("..100", StatusRangeError::InvalidRangeIndices)]
    #[case("200..=100", StatusRangeError::InvalidRangeIndices)]
    #[case("100..100", StatusRangeError::InvalidRangeIndices)]
    #[case("..=", StatusRangeError::NoRangePattern)]
    #[case("100..=", StatusRangeError::NoRangePattern)]
    #[case("-100..=100", StatusRangeError::NoRangePattern)]
    #[case("-100..100", StatusRangeError::NoRangePattern)]
    #[case("100..=-100", StatusRangeError::NoRangePattern)]
    #[case("100..-100", StatusRangeError::NoRangePattern)]
    #[case("abcd", StatusRangeError::NoRangePattern)]
    #[case("-1", StatusRangeError::NoRangePattern)]
    #[case("0", StatusRangeError::InvalidStatusCodeValue)]
    #[case("1..5", StatusRangeError::InvalidStatusCodeValue)]
    #[case("99..102", StatusRangeError::InvalidStatusCodeValue)]
    #[case("999..=1000", StatusRangeError::InvalidStatusCodeValue)]
    fn test_from_str_invalid(#[case] input: &str, #[case] error: StatusRangeError) {
        let range = StatusRange::from_str(input);
        assert_eq!(range, Err(error), "unexpected result for `{input}`");
    }

    /// `new` checks the value bounds before the ordering of the indices.
    #[rstest]
    #[case(99, 200, StatusRangeError::InvalidStatusCodeValue)]
    #[case(100, 1000, StatusRangeError::InvalidStatusCodeValue)]
    #[case(99, 50, StatusRangeError::InvalidStatusCodeValue)]
    #[case(300, 200, StatusRangeError::InvalidRangeIndices)]
    fn test_new_invalid(#[case] start: u16, #[case] end: u16, #[case] error: StatusRangeError) {
        assert_eq!(StatusRange::new(start, end), Err(error));
    }

    /// `Display` always uses the inclusive notation.
    #[rstest]
    #[case("100..200", "100..=199")]
    #[case("100..=200", "100..=200")]
    #[case("404", "404..=404")]
    #[case("..", "100..=999")]
    fn test_display(#[case] input: &str, #[case] rendered: &str) {
        let range = StatusRange::from_str(input).unwrap();
        assert_eq!(range.to_string(), rendered);
    }

    /// Overlapping ranges are merged, disjoint and adjacent ones are not.
    #[rstest]
    #[case("100..=200", "210..=300", "100..=200", false)]
    #[case("100..=200", "190..=300", "100..=300", true)]
    #[case("100..200", "200..300", "100..200", false)]
    #[case("100..200", "190..300", "100..300", true)]
    #[case("190..=300", "100..=200", "100..=300", true)]
    fn test_merge(
        #[case] range: &str,
        #[case] other: &str,
        #[case] result: &str,
        #[case] merged: bool,
    ) {
        let mut range = StatusRange::from_str(range).unwrap();
        let other = StatusRange::from_str(other).unwrap();

        let result = StatusRange::from_str(result).unwrap();
        let was_merged = range.merge(&other);

        assert_eq!(was_merged, merged, "unexpected merge outcome");
        assert_eq!(result, range);
    }
}