rate_limits/
types.rs

1use std::collections::HashMap;
2use std::str::FromStr;
3
4use crate::convert;
5use crate::error::{Error, Result};
6use headers::{HeaderMap, HeaderName, HeaderValue};
7use time::format_description::well_known::{Iso8601, Rfc2822};
8use time::{Duration, OffsetDateTime, PrimitiveDateTime};
9
/// Separator between a header name and its value in a raw `Name: value` line.
const HEADER_SEPARATOR: &str = ":";
11
/// How a rate limit reset time is encoded in a header value
///
/// Vendors disagree on the representation: some send a relative number of
/// seconds, others an absolute point in time in one of several formats.
/// Each variant names one known encoding.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ResetTimeKind {
    /// Number of seconds until the rate limit is lifted
    Seconds,
    /// Unix timestamp (seconds since the epoch) when the rate limit will be lifted
    Timestamp,
    /// RFC 2822 date when the rate limit will be lifted
    ImfFixdate,
    /// ISO 8601 date when the rate limit will be lifted
    Iso8601,
}
29
30/// Reset time of rate limiting
31///
32/// There are different variants on how to specify reset times
33/// in rate limit headers. The most common ones are seconds and datetime.
34#[derive(Copy, Clone, Debug, PartialEq)]
35pub enum ResetTime {
36    /// Number of seconds until rate limit is lifted
37    Seconds(usize),
38    /// Date when rate limit will be lifted
39    DateTime(OffsetDateTime),
40}
41
42impl ResetTime {
43    /// Create a new reset time from a header value and a reset time kind
44    ///
45    /// # Errors
46    ///
47    /// This function returns an error if the header value cannot be parsed
48    /// or if the reset time kind is unknown.
49    pub fn new(value: &HeaderValue, kind: ResetTimeKind) -> Result<Self> {
50        let value = value.to_str()?;
51        match kind {
52            ResetTimeKind::Seconds => Ok(ResetTime::Seconds(convert::to_usize(value)?)),
53            ResetTimeKind::Timestamp => Ok(Self::DateTime(
54                OffsetDateTime::from_unix_timestamp(convert::to_i64(value)?)
55                    .map_err(Error::Time)?,
56            )),
57            ResetTimeKind::Iso8601 => {
58                // https://github.com/time-rs/time/issues/378
59                let d = PrimitiveDateTime::parse(value, &Iso8601::PARSING).map_err(Error::Parse)?;
60                Ok(ResetTime::DateTime(d.assume_utc()))
61            }
62            ResetTimeKind::ImfFixdate => {
63                let d = PrimitiveDateTime::parse(value, &Rfc2822).map_err(Error::Parse)?;
64                Ok(ResetTime::DateTime(d.assume_utc()))
65            }
66        }
67    }
68
69    /// Get the number of seconds until the rate limit gets lifted.
70    #[must_use]
71    pub fn seconds(&self) -> usize {
72        match self {
73            ResetTime::Seconds(s) => *s,
74            // OffsetDateTime is not timezone aware, so we need to convert it to UTC
75            // and then convert it to seconds.
76            // There are no negative values in the seconds field, so we can safely
77            // cast it to usize.
78            #[allow(clippy::cast_possible_truncation)]
79            ResetTime::DateTime(d) => (*d - OffsetDateTime::now_utc()).whole_seconds() as usize,
80        }
81    }
82
83    /// Convert reset time to duration
84    #[must_use]
85    pub fn duration(&self) -> Duration {
86        match self {
87            ResetTime::Seconds(s) => Duration::seconds(*s as i64),
88            ResetTime::DateTime(d) => {
89                Duration::seconds((*d - OffsetDateTime::now_utc()).whole_seconds())
90            }
91        }
92    }
93}
94
/// Known vendors of rate limit headers
///
/// Vendors use different rate limit header formats, which determine how
/// their headers are parsed.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Vendor {
    /// Rate limit headers as defined in the `polli-ratelimit-headers-00` draft
    Standard,
    /// Reddit rate limit headers
    Reddit,
    /// Github API rate limit headers
    Github,
    /// Twitter API rate limit headers
    Twitter,
    /// Vimeo rate limit headers
    Vimeo,
    /// Gitlab rate limit headers
    Gitlab,
    /// Akamai rate limit headers
    Akamai,
}
116
117/// A variant defines all relevant fields for parsing headers from a given vendor
118#[derive(Clone, Debug, PartialEq)]
119pub struct RateLimitVariant {
120    /// Vendor of the rate limit headers (e.g. Github, Twitter, etc.)
121    pub vendor: Vendor,
122    /// Duration of the rate limit interval
123    pub duration: Option<Duration>,
124    /// Header name for the maximum number of requests
125    pub limit_header: Option<String>,
126    /// Header name for the number of used requests
127    pub used_header: Option<String>,
128    /// Header name for the number of remaining requests
129    pub remaining_header: String,
130    /// Header name for the reset time
131    pub reset_header: String,
132    /// Kind of reset time
133    pub reset_kind: ResetTimeKind,
134}
135
136impl RateLimitVariant {
137    /// Create a new rate limit variant
138    #[must_use]
139    pub const fn new(
140        vendor: Vendor,
141        duration: Option<Duration>,
142        limit_header: Option<String>,
143        used_header: Option<String>,
144        remaining_header: String,
145        reset_header: String,
146        reset_kind: ResetTimeKind,
147    ) -> Self {
148        Self {
149            vendor,
150            duration,
151            limit_header,
152            used_header,
153            remaining_header,
154            reset_header,
155            reset_kind,
156        }
157    }
158}
159
160/// A rate limit header
161#[derive(Clone, Copy, Debug, PartialEq)]
162pub struct Limit {
163    /// Maximum number of requests for the given interval
164    pub count: usize,
165}
166
167impl Limit {
168    /// Create a new limit header
169    ///
170    /// # Errors
171    ///
172    /// This function returns an error if the header value cannot be parsed
173    pub fn new<T: AsRef<str>>(value: T) -> Result<Self> {
174        Ok(Self {
175            count: convert::to_usize(value.as_ref())?,
176        })
177    }
178}
179
180impl From<usize> for Limit {
181    fn from(count: usize) -> Self {
182        Self { count }
183    }
184}
185
186/// A rate limit header for the number of used requests
187#[derive(Clone, Copy, Debug, PartialEq)]
188pub(crate) struct Used {
189    /// Number of used requests for the given interval
190    pub(crate) count: usize,
191}
192
193impl Used {
194    pub(crate) fn new(value: &str) -> Result<Self> {
195        Ok(Self {
196            count: convert::to_usize(value)?,
197        })
198    }
199}
200
201/// A rate limit header for the number of remaining requests
202#[derive(Clone, Copy, Debug, PartialEq)]
203pub struct Remaining {
204    /// Number of remaining requests for the given interval
205    pub count: usize,
206}
207
208impl Remaining {
209    /// Create a new remaining header
210    ///
211    /// # Errors
212    ///
213    /// This function returns an error if the header value cannot be parsed
214    pub fn new(value: &str) -> Result<Self> {
215        Ok(Self {
216            count: convert::to_usize(value)?,
217        })
218    }
219}
220
221pub(crate) trait HeaderMapExt {
222    fn from_raw(raw: &str) -> Result<HeaderMap>;
223}
224
225impl HeaderMapExt for HeaderMap {
226    fn from_raw(raw: &str) -> Result<HeaderMap> {
227        let mut headers = HeaderMap::new();
228
229        for line in raw.lines() {
230            if !line.contains(HEADER_SEPARATOR) {
231                return Err(Error::HeaderWithoutColon(line.to_string()));
232            }
233            if let Some((name, value)) = line.split_once(HEADER_SEPARATOR) {
234                headers.insert(
235                    HeaderName::from_str(name)?,
236                    HeaderValue::from_str(value.trim())?,
237                );
238            }
239        }
240        Ok(headers)
241    }
242}
243
244#[derive(Clone, Debug, PartialEq, Eq)]
245pub struct CaseSensitiveHeaderMap {
246    inner: HashMap<String, HeaderValue>,
247}
248
249impl Default for CaseSensitiveHeaderMap {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255impl CaseSensitiveHeaderMap {
256    pub fn new() -> Self {
257        Self {
258            inner: HashMap::new(),
259        }
260    }
261
262    pub fn insert(&mut self, name: String, value: HeaderValue) -> Option<HeaderValue> {
263        self.inner.insert(name, value)
264    }
265
266    pub fn get(&self, k: &str) -> Option<&HeaderValue> {
267        self.inner.get(k)
268    }
269}
270
271impl FromStr for CaseSensitiveHeaderMap {
272    type Err = Error;
273
274    fn from_str(headers: &str) -> Result<Self> {
275        Ok(CaseSensitiveHeaderMap {
276            inner: headers
277                .lines()
278                .filter_map(|line| line.split_once(HEADER_SEPARATOR))
279                .map(|(header, value)| {
280                    (
281                        header.to_string(),
282                        HeaderValue::from_str(value.trim()).unwrap(),
283                    )
284                })
285                .collect(),
286        })
287    }
288}
289
290impl From<HeaderMap> for CaseSensitiveHeaderMap {
291    fn from(headers: HeaderMap) -> Self {
292        let mut cs_map = CaseSensitiveHeaderMap::new();
293        for (name, value) in headers.iter() {
294            cs_map.insert(name.as_str().to_string(), value.clone());
295        }
296        cs_map
297    }
298}
299
300impl From<&HeaderMap> for CaseSensitiveHeaderMap {
301    fn from(headers: &HeaderMap) -> Self {
302        let mut cs_map = CaseSensitiveHeaderMap::new();
303        for (name, value) in headers.iter() {
304            cs_map.insert(name.as_str().to_string(), value.clone());
305        }
306        cs_map
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Converting a `HeaderMap` keeps every header, keyed by the lowercase
    /// names that `HeaderMap` stores.
    #[test]
    fn test_convert_from_header_map() {
        let mut headers = HeaderMap::new();
        headers.insert("X-RateLimit-Limit", HeaderValue::from_static("100"));
        headers.insert("X-RateLimit-Remaining", HeaderValue::from_static("99"));
        headers.insert("X-RateLimit-Reset", HeaderValue::from_static("1234567890"));

        let mut expected = CaseSensitiveHeaderMap::new();
        expected.insert(
            "x-ratelimit-limit".to_owned(),
            HeaderValue::from_static("100"),
        );
        expected.insert(
            "x-ratelimit-remaining".to_owned(),
            HeaderValue::from_static("99"),
        );
        expected.insert(
            "x-ratelimit-reset".to_owned(),
            HeaderValue::from_static("1234567890"),
        );

        let converted = CaseSensitiveHeaderMap::from(&headers);
        assert_eq!(converted, expected);
    }
}
342        );
343    }
344}