1use crate::{LastFmError, Result};
2use std::future::Future;
3
/// Tuning knobs for retrying an operation that was rejected with a rate limit.
///
/// Both delay fields are expressed in seconds.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many times a rate-limited operation is retried before the error is returned.
    pub max_retries: u32,
    /// Base delay in seconds; it is doubled for every retry already performed.
    pub base_delay: u64,
    /// Ceiling in seconds applied to the exponentially growing delay.
    pub max_delay: u64,
}
14
15impl Default for RetryConfig {
16 fn default() -> Self {
17 Self {
18 max_retries: 3,
19 base_delay: 5,
20 max_delay: 300, }
22 }
23}
24
/// Outcome of an operation that eventually succeeded, together with retry bookkeeping.
#[derive(Debug)]
pub struct RetryResult<T> {
    /// Value produced by the successful call.
    pub result: T,
    /// Number of retries that were performed; 0 when the first call succeeded.
    pub attempts_made: u32,
    /// Sum, in seconds, of all delays waited between calls.
    pub total_retry_time: u64,
}
35
36pub async fn retry_with_backoff<T, F, Fut, OnRateLimit>(
50 config: RetryConfig,
51 operation_name: &str,
52 mut operation: F,
53 mut on_rate_limit: OnRateLimit,
54) -> Result<RetryResult<T>>
55where
56 F: FnMut() -> Fut,
57 Fut: Future<Output = Result<T>>,
58 OnRateLimit: FnMut(u64, &str),
59{
60 let mut retries = 0;
61 let mut total_retry_time = 0;
62
63 loop {
64 match operation().await {
65 Ok(result) => {
66 return Ok(RetryResult {
67 result,
68 attempts_made: retries,
69 total_retry_time,
70 });
71 }
72 Err(LastFmError::RateLimit { retry_after }) => {
73 if retries >= config.max_retries {
74 log::warn!(
75 "Max retries ({}) exceeded for {} operation",
76 config.max_retries,
77 operation_name
78 );
79 return Err(LastFmError::RateLimit { retry_after });
80 }
81
82 let base_backoff = config.base_delay * 2_u64.pow(retries);
84 let delay = std::cmp::min(
85 std::cmp::min(retry_after + base_backoff, config.max_delay),
86 retry_after + (retries as u64 * 30), );
88
89 log::info!(
90 "{} rate limited. Waiting {} seconds before retry {} of {}",
91 operation_name,
92 delay,
93 retries + 1,
94 config.max_retries
95 );
96
97 on_rate_limit(delay, operation_name);
99
100 tokio::time::sleep(std::time::Duration::from_secs(delay)).await;
101 retries += 1;
102 total_retry_time += delay;
103 }
104 Err(other_error) => {
105 return Err(other_error);
106 }
107 }
108 }
109}
110
111pub async fn retry_operation<T, F, Fut>(
113 config: RetryConfig,
114 operation_name: &str,
115 operation: F,
116) -> Result<RetryResult<T>>
117where
118 F: FnMut() -> Fut,
119 Fut: Future<Output = Result<T>>,
120{
121 retry_with_backoff(config, operation_name, operation, |delay, op_name| {
122 log::debug!("Rate limited during {op_name}: waiting {delay} seconds");
123 })
124 .await
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds the configuration shared by these tests: base delay 1, maximum delay 60.
    fn quick_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay: 1,
            max_delay: 60,
        }
    }

    /// An operation that succeeds on the first call reports no retries and no waiting.
    #[tokio::test]
    async fn test_successful_operation() {
        let always_ok = || async { Ok::<i32, LastFmError>(42) };

        let result = retry_operation(quick_config(3), "test", always_ok).await;

        assert!(result.is_ok());
        let RetryResult {
            result: value,
            attempts_made,
            total_retry_time,
        } = result.unwrap();
        assert_eq!(value, 42);
        assert_eq!(attempts_made, 0);
        assert_eq!(total_retry_time, 0);
    }

    /// Two rate-limited calls followed by a success yield the value after two retries.
    #[tokio::test]
    async fn test_retry_on_rate_limit() {
        let call_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&call_count);

        let result = retry_operation(quick_config(2), "test", move || {
            // `fetch_add` returns the value before the increment, i.e. the prior call count.
            let prior_calls = counter.fetch_add(1, Ordering::SeqCst);
            async move {
                match prior_calls {
                    0 | 1 => Err(LastFmError::RateLimit { retry_after: 1 }),
                    _ => Ok::<i32, LastFmError>(42),
                }
            }
        })
        .await;

        assert!(result.is_ok());
        let RetryResult {
            result: value,
            attempts_made,
            total_retry_time,
        } = result.unwrap();
        assert_eq!(value, 42);
        assert_eq!(attempts_made, 2);
        assert!(total_retry_time >= 2);
    }

    /// An operation that is always rate limited surfaces the rate-limit error.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let always_limited =
            || async { Err::<i32, LastFmError>(LastFmError::RateLimit { retry_after: 1 }) };

        let result = retry_operation(quick_config(1), "test", always_limited).await;

        assert!(result.is_err());
        let other = result.unwrap_err();
        if !matches!(other, LastFmError::RateLimit { .. }) {
            panic!("Expected rate limit error, got: {other:?}");
        }
    }
}