Skip to main content

twapi_v2/
retry.rs

1use reqwest::{RequestBuilder, StatusCode};
2use serde::de::DeserializeOwned;
3use std::time::Duration;
4use tokio::time::{sleep, timeout};
5
6use crate::{api::execute_twitter, error::Error, headers::Headers};
7
/// Hook that the retry helpers in this module call with every request before sending it.
pub trait RetryLogger {
    /// Called once per attempt with the request that is about to be executed.
    fn log(&self, builder: &RequestBuilder);
}
11
12pub async fn execute_retry_fn<T>(
13    f: impl Fn() -> RequestBuilder,
14    retry_count: usize,
15    retryable_status_codes: &[StatusCode],
16    retry_logger: Option<&impl RetryLogger>,
17    timeout_duration: Option<Duration>,
18    retry_delay_secound_count: Option<u64>,
19) -> Result<(T, Headers), Error>
20where
21    T: DeserializeOwned,
22{
23    let mut count: usize = 0;
24
25    loop {
26        let target = f();
27        if let Some(retry_logger) = retry_logger {
28            retry_logger.log(&target);
29        }
30
31        let error = if let Some(timeout_duration) = timeout_duration {
32            match timeout(timeout_duration, execute_twitter(target)).await {
33                Ok(res) => match res {
34                    Ok(res) => return Ok(res),
35                    Err(err) => match &err {
36                        Error::Twitter(twitter_error, _, _) => {
37                            if !retryable_status_codes.contains(&twitter_error.status_code) {
38                                return Err(err);
39                            }
40                            err
41                        }
42                        _ => return Err(err),
43                    },
44                },
45                Err(_) => Error::Timeout,
46            }
47        } else {
48            match execute_twitter(target).await {
49                Ok(res) => return Ok(res),
50                Err(err) => match &err {
51                    Error::Twitter(twitter_error, _, _) => {
52                        if !retryable_status_codes.contains(&twitter_error.status_code) {
53                            return Err(err);
54                        }
55                        err
56                    }
57                    _ => return Err(err),
58                },
59            }
60        };
61        if count >= retry_count {
62            return Err(error);
63        }
64        count += 1;
65        sleep_sec(retry_delay_secound_count, count).await;
66    }
67}
68
69pub async fn execute_retry<T>(
70    builder: RequestBuilder,
71    retry_count: usize,
72    retryable_status_codes: &[StatusCode],
73    retry_logger: Option<&impl RetryLogger>,
74    timeout_duration: Option<Duration>,
75    retry_delay_secound_count: Option<u64>,
76) -> Result<(T, Headers), Error>
77where
78    T: DeserializeOwned,
79{
80    let mut count: usize = 0;
81
82    loop {
83        let target = builder
84            .try_clone()
85            .ok_or(Error::Other("builder clone fail".to_owned(), None))?;
86        if let Some(retry_logger) = retry_logger {
87            retry_logger.log(&target);
88        }
89
90        let error = if let Some(timeout_duration) = timeout_duration {
91            match timeout(timeout_duration, execute_twitter(target)).await {
92                Ok(res) => match res {
93                    Ok(res) => return Ok(res),
94                    Err(err) => match &err {
95                        Error::Twitter(twitter_error, _, _) => {
96                            if !retryable_status_codes.contains(&twitter_error.status_code) {
97                                return Err(err);
98                            }
99                            err
100                        }
101                        _ => return Err(err),
102                    },
103                },
104                Err(_) => Error::Timeout,
105            }
106        } else {
107            match execute_twitter(target).await {
108                Ok(res) => return Ok(res),
109                Err(err) => match &err {
110                    Error::Twitter(twitter_error, _, _) => {
111                        if !retryable_status_codes.contains(&twitter_error.status_code) {
112                            return Err(err);
113                        }
114                        err
115                    }
116                    _ => return Err(err),
117                },
118            }
119        };
120        if count >= retry_count {
121            return Err(error);
122        }
123        count += 1;
124        sleep_sec(retry_delay_secound_count, count).await;
125    }
126}
127
128async fn sleep_sec(retry_delay_secound_count: Option<u64>, count: usize) {
129    let seconds = retry_delay_secound_count.unwrap_or(2_i64.pow(count as u32) as u64);
130    sleep(Duration::from_secs(seconds)).await;
131}
132
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use reqwest::{RequestBuilder, StatusCode};

    use crate::{
        api::{
            BearerAuthentication,
            get_2_tweets_id::{Api, Response},
            post_2_media_upload_initialize::{self, MediaCategory},
        },
        retry::{execute_retry, execute_retry_fn},
    };

    use super::RetryLogger;

    /// Test logger that prints every request before it is sent.
    struct Logger;
    impl RetryLogger for Logger {
        fn log(&self, builder: &RequestBuilder) {
            println!("{:?}", builder);
        }
    }

    // These tests call the live Twitter API, so they are ignored by default. Run them with:
    // BEARER_CODE=XXXXX TWEET_ID=XXXX cargo test --features retry -- --ignored --nocapture

    #[tokio::test]
    #[ignore = "requires BEARER_CODE, TWEET_ID and network access"]
    async fn it_works() {
        let bearer_code = std::env::var("BEARER_CODE").unwrap_or_default();
        let tweet_id = std::env::var("TWEET_ID").unwrap_or_default();
        let logger = Logger;

        let auth = BearerAuthentication::new(bearer_code);
        let builder: RequestBuilder = Api::open(&tweet_id).build(&auth);

        let res = execute_retry::<Response>(
            builder,
            2,
            &[StatusCode::UNAUTHORIZED],
            Some(&logger),
            Some(Duration::from_secs(10)),
            None,
        )
        .await;
        match res {
            Ok(res) => println!("{:?}", res),
            Err(err) => println!("{:?}", err),
        }
    }

    #[tokio::test]
    #[ignore = "requires BEARER_CODE and network access"]
    async fn it_works_fn() {
        let bearer_code = std::env::var("BEARER_CODE").unwrap_or_default();
        let bearer_auth = BearerAuthentication::new(bearer_code);

        let logger = Logger;

        let res = execute_retry_fn::<post_2_media_upload_initialize::Response>(
            || {
                let body = post_2_media_upload_initialize::Body {
                    media_category: Some(MediaCategory::TweetImage),
                    media_type: "image/jpeg".to_string(),
                    total_bytes: 10000,
                    ..Default::default()
                };
                post_2_media_upload_initialize::Api::new(body).build(&bearer_auth)
            },
            2,
            &[StatusCode::UNAUTHORIZED],
            Some(&logger),
            None,
            None,
        )
        .await
        .unwrap();

        println!("{:?}", res.0);
    }
}