// egg_mode_extras/limits/mod.rs

1mod method_limit;
2mod stream;
3
4use stream::ResponseFuture;
5pub use stream::{Pageable, TimelineScrollback};
6
7use super::method::Method;
8use egg_mode::{
9    error::{Error, TwitterErrors},
10    service::rate_limit_status,
11    Token,
12};
13use futures::{stream::LocalBoxStream, Stream, StreamExt, TryStreamExt};
14use std::time::Duration;
15
/// Number of seconds to pause before the next request after an over capacity error.
const OVER_CAPACITY_DELAY_SECS: u64 = 60;
/// Error code that identifies an over capacity error in a Twitter error response.
const OVER_CAPACITY_ERROR_CODE: i32 = 130;

19pub struct RateLimitTracker {
20    limits: method_limit::MethodLimitStore,
21    over_capacity_delay: Duration,
22}
23
24impl RateLimitTracker {
25    pub async fn new(token: Token) -> Result<Self, Error> {
26        let status = rate_limit_status(&token).await?.response;
27        let limits = method_limit::MethodLimitStore::from(status);
28        let over_capacity_delay = Duration::from_secs(OVER_CAPACITY_DELAY_SECS);
29
30        Ok(Self {
31            limits,
32            over_capacity_delay,
33        })
34    }
35
36    pub fn make_stream<'a, L: Pageable<'a> + 'a>(
37        &self,
38        loader: L,
39        method: &Method,
40    ) -> LocalBoxStream<'a, Result<L::Item, Error>> {
41        let limit = self.limits.get(method);
42        let over_capacity_delay = self.over_capacity_delay;
43
44        futures::stream::try_unfold(
45            (loader, false, false),
46            move |(mut this, is_done, is_over_capacity)| {
47                let limit = limit.clone();
48                async move {
49                    if is_done {
50                        let res: Result<Option<_>, Error> = Ok(None);
51                        res
52                    } else {
53                        if is_over_capacity {
54                            log::warn!(
55                                "Waiting for {:?} after over capacity error",
56                                over_capacity_delay
57                            );
58                            tokio::time::sleep(over_capacity_delay).await;
59                        }
60
61                        if let Some(delay) = limit.wait_duration() {
62                            log::warn!(
63                                "Waiting for {:?} for rate limit reset at {:?}",
64                                delay,
65                                limit.reset_time()
66                            );
67                            tokio::time::sleep(delay).await;
68                        }
69
70                        limit.decrement();
71                        let mut response = match this.load().await {
72                            Ok(response) => Ok(response),
73                            Err(Error::TwitterError(headers, TwitterErrors { errors })) => {
74                                if errors.len() == 1 && errors[0].code == OVER_CAPACITY_ERROR_CODE {
75                                    return Ok(Some((None, (this, false, true))));
76                                } else {
77                                    Err(Error::TwitterError(headers, TwitterErrors { errors }))
78                                }
79                            }
80                            Err(other) => Err(other),
81                        }?;
82
83                        let is_done = this.update(&mut response);
84
85                        limit.update(
86                            response.rate_limit_status.remaining,
87                            response.rate_limit_status.reset,
88                        );
89
90                        Ok(Some((
91                            Some(L::extract(response.response)),
92                            (this, is_done, false),
93                        )))
94                    }
95                }
96            },
97        )
98        .try_filter_map(futures::future::ok)
99        .map_ok(|items| futures::stream::iter(items).map(Ok))
100        .try_flatten()
101        .boxed_local()
102    }
103
104    pub fn wrap_stream<'a, T: 'a, S: Stream<Item = ResponseFuture<'a, T>> + 'a>(
105        &self,
106        stream: S,
107        method: &Method,
108        ignore_over_capacity_errors: bool,
109    ) -> LocalBoxStream<'a, Result<T, Error>> {
110        let limit = self.limits.get(method);
111        let over_capacity_delay = self.over_capacity_delay;
112
113        stream
114            .filter_map(move |future| {
115                let limit = limit.clone();
116                async move {
117                    if let Some(delay) = limit.wait_duration() {
118                        log::warn!(
119                            "Waiting for {:?} for rate limit reset at {:?}",
120                            delay,
121                            limit.reset_time()
122                        );
123                        tokio::time::sleep(delay).await;
124                    }
125
126                    limit.decrement();
127
128                    match future.await {
129                        Ok(response) => Some(Ok(response.response)),
130                        Err(Error::TwitterError(headers, TwitterErrors { errors }))
131                            if ignore_over_capacity_errors =>
132                        {
133                            if errors.len() == 1 && errors[0].code == OVER_CAPACITY_ERROR_CODE {
134                                log::warn!(
135                                    "Waiting for {:?} after over capacity error",
136                                    over_capacity_delay
137                                );
138                                tokio::time::sleep(over_capacity_delay).await;
139                                None
140                            } else {
141                                Some(Err(Error::TwitterError(headers, TwitterErrors { errors })))
142                            }
143                        }
144                        Err(error) => Some(Err(error)),
145                    }
146                }
147            })
148            .boxed_local()
149    }
150}