1use reqwest::{RequestBuilder, StatusCode};
2use serde::de::DeserializeOwned;
3use std::time::Duration;
4use tokio::time::{sleep, timeout};
5
6use crate::{api::execute_twitter, error::Error, headers::Headers};
7
8pub trait RetryLogger {
9 fn log(&self, builder: &RequestBuilder);
10}
11
12pub async fn execute_retry_fn<T>(
13 f: impl Fn() -> RequestBuilder,
14 retry_count: usize,
15 retryable_status_codes: &[StatusCode],
16 retry_logger: Option<&impl RetryLogger>,
17 timeout_duration: Option<Duration>,
18 retry_delay_secound_count: Option<u64>,
19) -> Result<(T, Headers), Error>
20where
21 T: DeserializeOwned,
22{
23 let mut count: usize = 0;
24
25 loop {
26 let target = f();
27 if let Some(retry_logger) = retry_logger {
28 retry_logger.log(&target);
29 }
30
31 let error = if let Some(timeout_duration) = timeout_duration {
32 match timeout(timeout_duration, execute_twitter(target)).await {
33 Ok(res) => match res {
34 Ok(res) => return Ok(res),
35 Err(err) => match &err {
36 Error::Twitter(twitter_error, _, _) => {
37 if !retryable_status_codes.contains(&twitter_error.status_code) {
38 return Err(err);
39 }
40 err
41 }
42 _ => return Err(err),
43 },
44 },
45 Err(_) => Error::Timeout,
46 }
47 } else {
48 match execute_twitter(target).await {
49 Ok(res) => return Ok(res),
50 Err(err) => match &err {
51 Error::Twitter(twitter_error, _, _) => {
52 if !retryable_status_codes.contains(&twitter_error.status_code) {
53 return Err(err);
54 }
55 err
56 }
57 _ => return Err(err),
58 },
59 }
60 };
61 if count >= retry_count {
62 return Err(error);
63 }
64 count += 1;
65 sleep_sec(retry_delay_secound_count, count).await;
66 }
67}
68
69pub async fn execute_retry<T>(
70 builder: RequestBuilder,
71 retry_count: usize,
72 retryable_status_codes: &[StatusCode],
73 retry_logger: Option<&impl RetryLogger>,
74 timeout_duration: Option<Duration>,
75 retry_delay_secound_count: Option<u64>,
76) -> Result<(T, Headers), Error>
77where
78 T: DeserializeOwned,
79{
80 let mut count: usize = 0;
81
82 loop {
83 let target = builder
84 .try_clone()
85 .ok_or(Error::Other("builder clone fail".to_owned(), None))?;
86 if let Some(retry_logger) = retry_logger {
87 retry_logger.log(&target);
88 }
89
90 let error = if let Some(timeout_duration) = timeout_duration {
91 match timeout(timeout_duration, execute_twitter(target)).await {
92 Ok(res) => match res {
93 Ok(res) => return Ok(res),
94 Err(err) => match &err {
95 Error::Twitter(twitter_error, _, _) => {
96 if !retryable_status_codes.contains(&twitter_error.status_code) {
97 return Err(err);
98 }
99 err
100 }
101 _ => return Err(err),
102 },
103 },
104 Err(_) => Error::Timeout,
105 }
106 } else {
107 match execute_twitter(target).await {
108 Ok(res) => return Ok(res),
109 Err(err) => match &err {
110 Error::Twitter(twitter_error, _, _) => {
111 if !retryable_status_codes.contains(&twitter_error.status_code) {
112 return Err(err);
113 }
114 err
115 }
116 _ => return Err(err),
117 },
118 }
119 };
120 if count >= retry_count {
121 return Err(error);
122 }
123 count += 1;
124 sleep_sec(retry_delay_secound_count, count).await;
125 }
126}
127
128async fn sleep_sec(retry_delay_secound_count: Option<u64>, count: usize) {
129 let seconds = retry_delay_secound_count.unwrap_or(2_i64.pow(count as u32) as u64);
130 sleep(Duration::from_secs(seconds)).await;
131}
132
// Integration tests for the retry helpers. They read `BEARER_CODE` (and
// `TWEET_ID`) from the environment and execute the built requests, so they
// presumably need network access and valid credentials. `it_works` only prints
// a successful result; `it_works_fn` unwraps and therefore fails without them.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use reqwest::{RequestBuilder, StatusCode};

    use crate::{
        api::{
            BearerAuthentication,
            get_2_tweets_id::{Api, Response},
            post_2_media_upload_initialize::{self, MediaCategory},
        },
        retry::{execute_retry, execute_retry_fn},
    };

    use super::RetryLogger;

    /// Test logger that prints every attempted request.
    struct Logger;

    impl RetryLogger for Logger {
        fn log(&self, builder: &RequestBuilder) {
            println!("{:?}", builder);
        }
    }

    #[tokio::test]
    async fn it_works() {
        let bearer_code = std::env::var("BEARER_CODE").unwrap_or_default();
        let tweet_id = std::env::var("TWEET_ID").unwrap_or_default();
        let logger = Logger;

        let auth = BearerAuthentication::new(bearer_code);
        let builder: RequestBuilder = Api::open(&tweet_id).build(&auth);

        let res = execute_retry::<Response>(
            builder,
            2,
            &[StatusCode::UNAUTHORIZED],
            Some(&logger),
            Some(Duration::from_secs(10)),
            None,
        )
        .await;

        // Failures are tolerated here; only a successful response is printed.
        if let Ok(res) = res {
            println!("{:?}", res);
        }
    }

    #[tokio::test]
    async fn it_works_fn() {
        let bearer_code = std::env::var("BEARER_CODE").unwrap_or_default();
        let bearer_auth = BearerAuthentication::new(bearer_code);

        let logger = Logger;

        let res = execute_retry_fn::<post_2_media_upload_initialize::Response>(
            || {
                let body = post_2_media_upload_initialize::Body {
                    media_category: Some(MediaCategory::TweetImage),
                    media_type: "image/jpeg".to_string(),
                    total_bytes: 10000,
                    ..Default::default()
                };
                post_2_media_upload_initialize::Api::new(body).build(&bearer_auth)
            },
            2,
            &[StatusCode::UNAUTHORIZED],
            Some(&logger),
            None,
            None,
        )
        .await
        .unwrap();

        println!("{:?}", res.0);
    }
}