use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::Service;
use super::types::{CrawlRequest, CrawlResponse};
use crate::error::{CrawlError, classify_reqwest_error};
use crate::types::CrawlConfig;
/// Tower `Service` that fetches a single URL over HTTP and classifies the
/// result into a `CrawlResponse` or a typed `CrawlError`.
///
/// Cloning is cheap: the configuration lives behind an `Arc`, and the
/// reqwest client is designed to be cloned and share its connection pool.
#[derive(Clone)]
pub struct HttpFetchService {
    // Shared HTTP client used for all outgoing requests.
    client: reqwest::Client,
    // Crawl settings (retry policy, auth, custom headers), shared across clones.
    config: Arc<CrawlConfig>,
}
impl HttpFetchService {
pub fn new(client: reqwest::Client, config: CrawlConfig) -> Self {
Self {
client,
config: Arc::new(config),
}
}
}
/// Returns `true` for transient failures worth retrying: server errors,
/// rate limiting, and bad gateways. Everything else is treated as final.
fn is_retryable(e: &CrawlError) -> bool {
    match e {
        CrawlError::ServerError(_)
        | CrawlError::RateLimited(_)
        | CrawlError::BadGateway(_) => true,
        _ => false,
    }
}
/// Execute a single HTTP GET for `req` and translate the outcome into a
/// `CrawlResponse`, or a typed `CrawlError` for well-known failure statuses.
///
/// Header layering (reqwest appends on repeated `.header()` calls):
/// 1. a default User-Agent, unless the request already supplies one;
/// 2. configured authentication (basic / bearer / custom header);
/// 3. `config.custom_headers`;
/// 4. per-request `req.headers` (applied last).
///
/// # Errors
/// Maps 401/403/404/408/410/429/500/502/503 to dedicated error variants,
/// detects WAF challenge pages (on 403, and on suspiciously small 200
/// bodies), and reports truncated bodies as `DataLoss`.
async fn do_fetch(
    client: &reqwest::Client,
    config: &CrawlConfig,
    req: &CrawlRequest,
) -> Result<CrawlResponse, CrawlError> {
    let mut http_req = client.get(&req.url);
    // HTTP header names are case-insensitive, so check for an existing
    // User-Agent ignoring ASCII case. The previous exact-match lookup
    // (`contains_key("user-agent")`) missed keys like "User-Agent", and
    // because reqwest's `.header()` appends rather than replaces, that
    // caused a second UA header to be sent on top of the caller's.
    let has_user_agent = req
        .headers
        .keys()
        .any(|k| k.eq_ignore_ascii_case("user-agent"));
    if !has_user_agent {
        if let Some(ref ua) = config.user_agent {
            http_req = http_req.header(reqwest::header::USER_AGENT, ua.as_str());
        } else {
            // Crate-identifying fallback UA, e.g. "kreuzcrawl/0.3.1".
            http_req = http_req.header(
                reqwest::header::USER_AGENT,
                concat!("kreuzcrawl/", env!("CARGO_PKG_VERSION")),
            );
        }
    }
    // Attach configured authentication, if any.
    if let Some(ref auth) = config.auth {
        match auth {
            crate::types::AuthConfig::Basic { username, password } => {
                http_req = http_req.basic_auth(username, Some(password));
            }
            crate::types::AuthConfig::Bearer { token } => {
                http_req = http_req.bearer_auth(token);
            }
            crate::types::AuthConfig::Header { name, value } => {
                http_req = http_req.header(name.as_str(), value.as_str());
            }
        }
    }
    // Config-level headers first, then per-request headers.
    for (k, v) in &config.custom_headers {
        http_req = http_req.header(k.as_str(), v.as_str());
    }
    for (k, v) in &req.headers {
        http_req = http_req.header(k.as_str(), v.as_str());
    }
    let resp = http_req.send().await.map_err(|e| classify_reqwest_error(&e))?;
    let status = resp.status().as_u16();
    // If the server sent multiple Content-Type headers, keep the last one.
    let content_type = resp
        .headers()
        .get_all(reqwest::header::CONTENT_TYPE)
        .iter()
        .next_back()
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .to_owned();
    // Collect every response header; multi-valued headers keep all values.
    // `HeaderName` renders lowercase, so the lookups below use lowercase keys.
    let mut headers: HashMap<String, Vec<String>> = HashMap::new();
    for (name, value) in resp.headers().iter() {
        if let Ok(v) = value.to_str() {
            headers.entry(name.to_string()).or_default().push(v.to_string());
        }
    }
    match status {
        401 => return Err(CrawlError::Unauthorized("unauthorized".into())),
        403 => {
            // 403s served by WAFs/CDNs get a dedicated error so callers can
            // back off differently than for a genuine permission denial.
            let server = headers
                .get("server")
                .and_then(|v| v.first())
                .map(|s| s.to_lowercase())
                .unwrap_or_default();
            let body = resp.text().await.unwrap_or_default();
            if crate::http::is_waf_blocked(&server, &body, &headers) {
                let vendor = crate::http::detect_waf_vendor(&server, &body.to_lowercase());
                return Err(CrawlError::WafBlocked(format!("waf/blocked detected: {vendor}")));
            }
            return Err(CrawlError::Forbidden("forbidden".into()));
        }
        404 => return Err(CrawlError::NotFound("not_found".into())),
        408 => return Err(CrawlError::Timeout("timeout".into())),
        410 => return Err(CrawlError::Gone("gone".into())),
        429 => return Err(CrawlError::RateLimited("rate_limited".into())),
        500 => return Err(CrawlError::ServerError("server_error".into())),
        502 => return Err(CrawlError::BadGateway("bad_gateway".into())),
        503 => {
            return Err(CrawlError::ServerError("service unavailable".into()));
        }
        _ => {}
    }
    // Read the body, distinguishing "connection died mid-body" (data loss)
    // from other transport errors by inspecting the error chain text.
    let body_bytes = resp.bytes().await.map_err(|e| {
        let chain = crate::error::error_chain_string(&e);
        let is_body_error = chain.contains("content-length")
            || chain.contains("truncat")
            || chain.contains("incomplete")
            || chain.contains("end of file")
            || chain.contains("body error")
            || chain.contains("body from connection")
            || chain.contains("decoding response body")
            || chain.contains("error decoding");
        // `reqwest::Error::is_body` is only consulted off-wasm, matching the
        // rest of this file's wasm gating.
        #[cfg(not(target_arch = "wasm32"))]
        let is_body_error = is_body_error || e.is_body();
        if is_body_error {
            CrawlError::DataLoss(format!("data_loss: {e}"))
        } else {
            classify_reqwest_error(&e)
        }
    })?;
    let body_vec = body_bytes.to_vec();
    // Cross-check the advertised Content-Length against what was actually
    // received; tolerate small (<= 100 byte) shortfalls to avoid false
    // positives. NOTE(review): for compressed responses the advertised
    // length may describe the encoded body — verify against gzip responses.
    if let Some(expected) = headers
        .get("content-length")
        .and_then(|v| v.first())
        .and_then(|s| s.parse::<usize>().ok())
        && body_vec.len() < expected
        && expected - body_vec.len() > 100
    {
        return Err(CrawlError::DataLoss(format!(
            "data_loss: expected {} bytes, got {}",
            expected,
            body_vec.len()
        )));
    }
    let body = String::from_utf8_lossy(&body_vec).into_owned();
    #[cfg(not(target_arch = "wasm32"))]
    {
        // Some WAFs serve their challenge page with a 200; small bodies on
        // otherwise-successful responses get the same detection as 403s.
        let server = headers
            .get("server")
            .and_then(|v| v.first())
            .map(|s| s.to_lowercase())
            .unwrap_or_default();
        if status == 200 && body.len() < 5000 && crate::http::is_waf_blocked(&server, &body, &headers) {
            return Err(CrawlError::WafBlocked(format!(
                "waf/blocked: {status} with challenge content"
            )));
        }
    }
    Ok(CrawlResponse {
        status,
        content_type,
        body,
        body_bytes: body_vec,
        headers,
    })
}
impl Service<CrawlRequest> for HttpFetchService {
    type Response = CrawlResponse;
    type Error = CrawlError;
    type Future = Pin<Box<dyn Future<Output = Result<CrawlResponse, CrawlError>> + Send>>;

    /// Always ready: the reqwest client handles its own pooling/backpressure.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Fetch `req.url`, retrying up to `config.retry_count` additional times
    /// on retryable errors and on statuses listed in `config.retry_codes`,
    /// with exponential backoff (100ms, 200ms, 400ms, ...).
    fn call(&mut self, req: CrawlRequest) -> Self::Future {
        let client = self.client.clone();
        let config = self.config.clone();
        Box::pin(async move {
            let retry_count = config.retry_count;
            let retry_codes = &config.retry_codes;
            for attempt in 0..=retry_count {
                // Cap the shift so a very large configured retry_count cannot
                // overflow the u64 millisecond value (the previous unchecked
                // `1 << attempt` would wrap or panic in debug builds).
                let backoff = Duration::from_millis(100u64 << attempt.min(20));
                match do_fetch(&client, &config, &req).await {
                    Ok(resp) => {
                        // A completed HTTP exchange may still carry a status
                        // the caller wants retried — honor retry_codes.
                        if retry_codes.contains(&resp.status) && attempt < retry_count {
                            tokio::time::sleep(backoff).await;
                            continue;
                        }
                        return Ok(resp);
                    }
                    Err(e) if is_retryable(&e) && attempt < retry_count => {
                        tokio::time::sleep(backoff).await;
                        continue;
                    }
                    Err(e) => return Err(e),
                }
            }
            // Unreachable: the final iteration always returns above, but the
            // compiler cannot prove it, so keep a terminal error.
            Err(CrawlError::Other("retry exhausted".into()))
        })
    }
}