lychee_lib/types/accept/
range.rs

1use std::{fmt::Display, num::ParseIntError, ops::RangeInclusive, str::FromStr, sync::LazyLock};
2
3use regex::Regex;
4use thiserror::Error;
5
6static RANGE_PATTERN: LazyLock<Regex> =
7    LazyLock::new(|| Regex::new(r"^([0-9]{3})?\.\.((=?)([0-9]{3}))?$|^([0-9]{3})$").unwrap());
8
9/// Indicates that the parsing process of an [`AcceptRange`]  from a string
10/// failed due to various underlying reasons.
11#[derive(Debug, Error, PartialEq)]
12pub enum AcceptRangeError {
13    /// The string input didn't contain any range pattern.
14    #[error("no range pattern found")]
15    NoRangePattern,
16
17    /// The start or end index could not be parsed as an integer.
18    #[error("failed to parse str as integer")]
19    ParseIntError(#[from] ParseIntError),
20
21    /// The start index is larger than the end index.
22    #[error("invalid range indices, only start < end supported")]
23    InvalidRangeIndices,
24}
25
/// An inclusive range of HTTP status codes which are accepted, i.e. treated
/// as successful, when checking a remote URL.
///
/// The range is stored as a `RangeInclusive<u16>`, so both bounds are part of
/// the range.
#[derive(Clone, Debug, PartialEq)]
pub struct AcceptRange(RangeInclusive<u16>);
30
31impl FromStr for AcceptRange {
32    type Err = AcceptRangeError;
33
34    fn from_str(s: &str) -> Result<Self, Self::Err> {
35        let captures = RANGE_PATTERN
36            .captures(s)
37            .ok_or(AcceptRangeError::NoRangePattern)?;
38
39        if let Some(value) = captures.get(5) {
40            let value: u16 = value.as_str().parse()?;
41            Self::new_from(value, value)
42        } else {
43            let start: u16 = match captures.get(1) {
44                Some(start) => start.as_str().parse().unwrap_or_default(),
45                None => 0,
46            };
47            if captures.get(2).is_none() {
48                return Self::new_from(start, u16::MAX);
49            }
50
51            let inclusive = !captures[3].is_empty();
52            let end: u16 = captures[4].parse()?;
53
54            if inclusive {
55                Self::new_from(start, end)
56            } else {
57                Self::new_from(start, end - 1)
58            }
59        }
60    }
61}
62
63impl Display for AcceptRange {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        write!(f, "{}..={}", self.start(), self.end())
66    }
67}
68
69impl AcceptRange {
70    /// Creates a new [`AcceptRange`] which matches values between `start` and
71    /// `end` (both inclusive).
72    #[must_use]
73    pub const fn new(start: u16, end: u16) -> Self {
74        Self(RangeInclusive::new(start, end))
75    }
76
77    /// Creates a new [`AcceptRange`] which matches values between `start` and
78    /// `end` (both inclusive). It additionally validates that `start` > `end`.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if `start` > `end`.
83    pub const fn new_from(start: u16, end: u16) -> Result<Self, AcceptRangeError> {
84        if start > end {
85            return Err(AcceptRangeError::InvalidRangeIndices);
86        }
87
88        Ok(Self::new(start, end))
89    }
90
91    /// Returns the `start` value of this [`AcceptRange`].
92    #[must_use]
93    pub const fn start(&self) -> &u16 {
94        self.0.start()
95    }
96
97    /// Returns the `end` value of this [`AcceptRange`].
98    #[must_use]
99    pub const fn end(&self) -> &u16 {
100        self.0.end()
101    }
102
103    /// Returns whether this [`AcceptRange`] contains `value`.
104    #[must_use]
105    pub fn contains(&self, value: u16) -> bool {
106        self.0.contains(&value)
107    }
108
109    /// Consumes self and returns the inner range.
110    #[must_use]
111    pub const fn inner(self) -> RangeInclusive<u16> {
112        self.0
113    }
114
115    pub(crate) const fn update_start(&mut self, new_start: u16) -> Result<(), AcceptRangeError> {
116        let end = *self.end();
117
118        if new_start > end {
119            return Err(AcceptRangeError::InvalidRangeIndices);
120        }
121
122        self.0 = RangeInclusive::new(new_start, end);
123        Ok(())
124    }
125
126    pub(crate) const fn update_end(&mut self, new_end: u16) -> Result<(), AcceptRangeError> {
127        let start = *self.start();
128
129        if start > new_end {
130            return Err(AcceptRangeError::InvalidRangeIndices);
131        }
132
133        self.0 = RangeInclusive::new(*self.start(), new_end);
134        Ok(())
135    }
136
137    pub(crate) fn merge(&mut self, other: &Self) -> bool {
138        // Merge when the end value of self overlaps with other's start
139        if self.end() >= other.start() && other.end() >= self.end() {
140            // We can ignore the result here, as it is guaranteed that
141            // start < new_end
142            let _ = self.update_end(*other.end());
143            return true;
144        }
145
146        // Merge when the start value of self overlaps with other's end
147        if self.start() <= other.end() && other.start() <= self.start() {
148            // We can ignore the result here, as it is guaranteed that
149            // start < new_end
150            let _ = self.update_start(*other.start());
151            return true;
152        }
153
154        false
155    }
156}
157
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Parses `input` and verifies which status codes the resulting range
    /// contains and which it does not.
    #[rstest]
    #[case("..", vec![0, 100, 150, 200, u16::MAX], vec![])]
    #[case("100..", vec![100, 101, 150, 200, u16::MAX], vec![0, 50, 99])]
    #[case("100..=200", vec![100, 150, 200], vec![0, 50, 99, 201, 250])]
    #[case("..=100", vec![0, 50, 100], vec![101, 150, 200])]
    #[case("100..200", vec![100, 150, 199], vec![99, 200, 250])]
    #[case("..100", vec![0, 50, 99], vec![100, 150])]
    #[case("404", vec![404], vec![200, 304, 403, 405, 500])]
    fn test_from_str(
        #[case] input: &str,
        #[case] valid_values: Vec<u16>,
        #[case] invalid_values: Vec<u16>,
    ) {
        let parsed: AcceptRange = input.parse().unwrap();

        for code in valid_values {
            assert!(parsed.contains(code));
        }

        for code in invalid_values {
            assert!(!parsed.contains(code));
        }
    }

    /// Verifies that malformed or empty ranges are rejected with the expected
    /// error.
    #[rstest]
    #[case("200..=100", AcceptRangeError::InvalidRangeIndices)]
    #[case("..=", AcceptRangeError::NoRangePattern)]
    #[case("100..=", AcceptRangeError::NoRangePattern)]
    #[case("-100..=100", AcceptRangeError::NoRangePattern)]
    #[case("-100..100", AcceptRangeError::NoRangePattern)]
    #[case("100..=-100", AcceptRangeError::NoRangePattern)]
    #[case("100..-100", AcceptRangeError::NoRangePattern)]
    #[case("0..0", AcceptRangeError::NoRangePattern)]
    #[case("abcd", AcceptRangeError::NoRangePattern)]
    #[case("-1", AcceptRangeError::NoRangePattern)]
    #[case("0", AcceptRangeError::NoRangePattern)]
    fn test_from_str_invalid(#[case] input: &str, #[case] error: AcceptRangeError) {
        let outcome = input.parse::<AcceptRange>();
        assert_eq!(outcome, Err(error));
    }

    /// Merges `other` into `range` and compares the outcome with `result`.
    #[rstest]
    #[case("100..=200", "210..=300", "100..=200")]
    #[case("100..=200", "190..=300", "100..=300")]
    #[case("100..200", "200..300", "100..200")]
    #[case("100..200", "190..300", "100..300")]
    fn test_merge(#[case] range: &str, #[case] other: &str, #[case] result: &str) {
        let mut merged: AcceptRange = range.parse().unwrap();
        let addition: AcceptRange = other.parse().unwrap();
        let expected: AcceptRange = result.parse().unwrap();

        // Only the resulting range is of interest, not whether a merge happened.
        let _ = merged.merge(&addition);

        assert_eq!(expected, merged);
    }
}
218}