use async_trait::async_trait;
use regex::Regex;
use crate::{error::KumoError, extract::Response};
use super::{FetchRequest, Middleware};
/// Status codes that `StatusRetry::new` installs as its global code list:
/// 429 plus the 500, 502, 503 and 504 server-side errors.
const DEFAULT_RETRY_CODES: &[u16] = &[429, 500, 502, 503, 504];
/// Middleware that turns responses with selected HTTP status codes into
/// `KumoError::HttpStatus` errors.
///
/// A response is judged against the code list of the first URL pattern rule
/// that matches its URL; if no rule matches, the global code list is used.
// NOTE(review): the name suggests the resulting error is what triggers a
// retry elsewhere in the pipeline — confirm against the caller.
pub struct StatusRetry {
/// Global status codes, consulted only when no entry in `patterns` matches
/// the response URL.
codes: Vec<u16>,
/// URL-specific rules as `(regex, codes)` pairs, evaluated in insertion
/// order; the first regex that matches the URL decides the outcome.
patterns: Vec<(Regex, Vec<u16>)>,
}
impl StatusRetry {
    /// Creates a middleware whose global code list is `DEFAULT_RETRY_CODES`
    /// and which has no URL-specific rules.
    pub fn new() -> Self {
        Self::with_codes(DEFAULT_RETRY_CODES.to_vec())
    }

    /// Creates a middleware whose global code list is exactly `codes` and
    /// which has no URL-specific rules.
    pub fn with_codes(codes: Vec<u16>) -> Self {
        let patterns = Vec::new();
        Self { codes, patterns }
    }

    /// Appends a URL-specific rule and returns the updated middleware.
    ///
    /// `pattern` is compiled as a regular expression and stored together with
    /// `codes`. Rules are kept in the order they were added.
    ///
    /// # Panics
    ///
    /// Panics if `pattern` is not a valid regular expression.
    pub fn for_pattern(mut self, pattern: &str, codes: Vec<u16>) -> Self {
        match Regex::new(pattern) {
            Ok(compiled) => self.patterns.push((compiled, codes)),
            Err(e) => panic!("invalid StatusRetry pattern '{pattern}': {e}"),
        }
        self
    }
}
impl std::fmt::Debug for StatusRetry {
    /// Writes the global code list in full and summarizes the URL-specific
    /// rules by their count instead of printing each regex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rule_count = self.patterns.len();
        let mut out = f.debug_struct("StatusRetry");
        out.field("global_codes", &self.codes);
        out.field("pattern_rules", &rule_count);
        out.finish()
    }
}
impl Default for StatusRetry {
/// Equivalent to `StatusRetry::new()`.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for StatusRetry {
    /// Does nothing; this middleware only inspects responses.
    async fn before_request(&self, _request: &mut FetchRequest) -> Result<(), KumoError> {
        Ok(())
    }

    /// Rejects the response with `KumoError::HttpStatus` when its status is
    /// in the applicable code list, and accepts it otherwise.
    ///
    /// The applicable list belongs to the first pattern rule whose regex
    /// matches the response URL. Only when no rule matches is the global
    /// list consulted, so a matching rule fully overrides the global codes.
    async fn after_response(&self, response: &mut Response) -> Result<(), KumoError> {
        let status = response.status();

        // First matching rule wins; fall back to the global list otherwise.
        let applicable = self
            .patterns
            .iter()
            .find(|(pattern, _)| pattern.is_match(response.url()))
            .map_or(&self.codes, |(_, codes)| codes);

        if !applicable.contains(&status) {
            return Ok(());
        }

        Err(KumoError::HttpStatus {
            status,
            url: response.url().to_string(),
        })
    }
}