use governor::{
Quota, RateLimiter, clock::DefaultClock, middleware::NoOpMiddleware, state::InMemoryState,
state::NotKeyed,
};
use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use super::config::{EdgarConfig, EdgarUrls};
use super::error::{EdgarError, Result};
/// Retries permitted after the initial attempt, so a request is tried at most
/// `MAX_RETRIES + 1` times in total.
const MAX_RETRIES: u32 = 5;
/// Delay before the first retry, in milliseconds; each later retry doubles it.
const INITIAL_BACKOFF_MS: u64 = 1_000;
/// Direct (unkeyed) `governor` rate limiter with in-memory state, the default
/// clock and no middleware; used to throttle every outgoing request.
type Governor = RateLimiter<
    NotKeyed,
    InMemoryState,
    DefaultClock,
    NoOpMiddleware,
>;
/// Rate-limited HTTP client for the SEC EDGAR endpoints.
///
/// All clones share one rate limiter, because it is held behind an [`Arc`].
#[derive(Clone, Debug)]
pub struct Edgar {
    /// HTTP client preconfigured with the `User-Agent` default header and request timeout.
    pub(crate) client: reqwest::Client,
    /// Request throttle shared by every clone of this value.
    pub(crate) rate_limiter: Arc<Governor>,
    /// Base URL for archive requests.
    pub(crate) edgar_archives_url: String,
    /// Base URL for data requests.
    pub(crate) edgar_data_url: String,
    /// Base URL for file requests.
    pub(crate) edgar_files_url: String,
    /// Base URL for search requests.
    pub(crate) edgar_search_url: String,
}
impl Edgar {
pub fn new(user_agent: &str) -> Result<Self> {
let config = EdgarConfig {
user_agent: user_agent.to_string(),
rate_limit: 10,
timeout: Duration::from_secs(30),
base_urls: EdgarUrls::default(),
};
Self::with_config(config)
}
pub fn with_config(config: EdgarConfig) -> Result<Self> {
let mut headers = HeaderMap::new();
headers.insert(
USER_AGENT,
HeaderValue::from_str(&config.user_agent)
.map_err(|e| EdgarError::ConfigError(format!("Invalid user agent: {}", e)))?,
);
let client = reqwest::Client::builder()
.default_headers(headers)
.timeout(config.timeout)
.build()
.map_err(|e| EdgarError::ConfigError(format!("Failed to build HTTP client: {}", e)))?;
let rate_limiter = Arc::new(RateLimiter::direct(Quota::per_second(
NonZeroU32::new(config.rate_limit).ok_or_else(|| {
EdgarError::ConfigError("Rate limit must be greater than zero".to_string())
})?,
)));
Ok(Edgar {
client,
rate_limiter,
edgar_archives_url: config.base_urls.archives,
edgar_data_url: config.base_urls.data,
edgar_files_url: config.base_urls.files,
edgar_search_url: config.base_urls.search,
})
}
fn calculate_backoff(retry: u32) -> Duration {
let backoff_ms = INITIAL_BACKOFF_MS * (2_u64.pow(retry));
let jitter = (backoff_ms as f64 * 0.2 * (fastrand::f64() - 0.5)) as i64;
Duration::from_millis((backoff_ms as i64 + jitter) as u64)
}
pub async fn get_bytes(&self, url: &str) -> Result<Vec<u8>> {
let mut retries = 0;
loop {
self.rate_limiter.until_ready().await;
let response = self
.client
.get(url)
.send()
.await
.map_err(EdgarError::RequestError)?;
match response.status() {
reqwest::StatusCode::OK => {
return response
.bytes()
.await
.map(|b| b.to_vec())
.map_err(EdgarError::RequestError);
}
reqwest::StatusCode::NOT_FOUND => {
return Err(EdgarError::NotFound);
}
reqwest::StatusCode::TOO_MANY_REQUESTS => {
if retries >= MAX_RETRIES {
return Err(EdgarError::RateLimitExceeded);
}
let retry_after = Self::calculate_backoff(retries);
sleep(retry_after).await;
retries += 1;
continue;
}
status => {
return Err(EdgarError::InvalidResponse(format!(
"Unexpected status code: {}",
status
)));
}
}
}
}
pub async fn get(&self, url: &str) -> Result<String> {
let mut retries = 0;
loop {
self.rate_limiter.until_ready().await;
let response_result = self.client.get(url).send().await;
match response_result {
Ok(response) => {
let status = response.status();
let headers = response.headers().clone();
if url.ends_with(".json") && status.is_success() {
if let Some(ct) = headers
.get(reqwest::header::CONTENT_TYPE)
.and_then(|val| val.to_str().ok())
{
if ct.to_lowercase().contains("text/html") {
let body_text = response
.text()
.await
.unwrap_or_else(|_| "Failed to read response body".to_string());
if body_text.trim_start().starts_with('{')
|| body_text.trim_start().starts_with('[')
{
tracing::warn!(
"Received text/html content-type for .json URL, but content appears to be JSON: {}",
url
);
return Ok(body_text);
}
let body_preview = body_text.chars().take(200).collect::<String>();
return Err(EdgarError::UnexpectedContentType {
url: url.to_string(),
expected_pattern: "application/json".to_string(),
got_content_type: ct.to_string(),
content_preview: body_preview,
});
}
}
}
match status {
reqwest::StatusCode::OK => {
return response.text().await.map_err(EdgarError::RequestError);
}
reqwest::StatusCode::NOT_FOUND => {
return Err(EdgarError::NotFound);
}
reqwest::StatusCode::TOO_MANY_REQUESTS => {
if retries >= MAX_RETRIES {
return Err(EdgarError::RateLimitExceeded);
}
let retry_after_duration = headers
.get("retry-after")
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.map(Duration::from_secs)
.unwrap_or_else(|| Self::calculate_backoff(retries));
tracing::warn!(
"Rate limit hit (429) for {}. Attempt {}/{}. Waiting for {:?} before retry.",
url,
retries + 1,
MAX_RETRIES + 1, retry_after_duration
);
sleep(retry_after_duration).await;
retries += 1;
continue; }
other_status => {
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Failed to read error body".to_string());
return Err(EdgarError::InvalidResponse(format!(
"Unexpected status code: {} for URL: {}. Response preview: {}",
other_status,
url,
error_body.chars().take(200).collect::<String>()
)));
}
}
}
Err(e) => {
if retries >= MAX_RETRIES {
return Err(EdgarError::RequestError(e));
}
let backoff_duration = Self::calculate_backoff(retries);
tracing::warn!(
"Request failed for {}: {:?}. Attempt {}/{}. Retrying in {:?}.",
url,
e,
retries + 1,
MAX_RETRIES + 1, backoff_duration
);
sleep(backoff_duration).await;
retries += 1;
continue; }
}
}
}
pub fn archives_url(&self) -> &str {
&self.edgar_archives_url
}
pub fn data_url(&self) -> &str {
&self.edgar_data_url
}
pub fn files_url(&self) -> &str {
&self.edgar_files_url
}
pub fn search_url(&self) -> &str {
&self.edgar_search_url
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backoff must grow strictly with each retry. Each delay must stay within
    /// ±20% of `INITIAL_BACKOFF_MS * 2^retry`.
    #[test]
    fn test_calculate_backoff() {
        let delays: Vec<_> = (0..3).map(Edgar::calculate_backoff).collect();

        assert!(delays.windows(2).all(|pair| pair[0] < pair[1]));

        let expected_ranges = [(800, 1200), (1600, 2400), (3200, 4800)];
        for (delay, &(low, high)) in delays.iter().zip(expected_ranges.iter()) {
            assert!((low..=high).contains(&delay.as_millis()));
        }
    }
}