use std::{collections::HashMap, sync::Arc};
use super::Fetcher;
use crate::{
error::KumoError,
extract::{Response, response::ResponseBody},
middleware::FetchRequest,
};
use reqwest::Client;
use tokio::sync::RwLock;
#[derive(Debug)]
pub struct HttpFetcher {
client: Client,
default_user_agent: String,
proxy_clients: Arc<RwLock<HashMap<String, Client>>>,
}
impl HttpFetcher {
pub fn new(client: Client, default_user_agent: impl Into<String>) -> Self {
Self {
client,
default_user_agent: default_user_agent.into(),
proxy_clients: Arc::new(RwLock::new(HashMap::new())),
}
}
async fn client_for(&self, request: &FetchRequest) -> Result<Client, KumoError> {
let Some(ref proxy_url) = request.proxy else {
return Ok(self.client.clone());
};
{
let cache = self.proxy_clients.read().await;
if let Some(client) = cache.get(proxy_url) {
return Ok(client.clone());
}
}
let proxy = reqwest::Proxy::all(proxy_url.as_str()).map_err(KumoError::Fetch)?;
let new_client = Client::builder()
.cookie_store(true)
.user_agent(&self.default_user_agent)
.proxy(proxy)
.build()
.map_err(KumoError::Fetch)?;
let mut cache = self.proxy_clients.write().await;
Ok(cache.entry(proxy_url.clone()).or_insert(new_client).clone())
}
}
#[async_trait::async_trait]
impl Fetcher for HttpFetcher {
    /// Sends `request` and buffers the complete response.
    ///
    /// The outgoing request uses:
    /// - `request.method` and `request.url()`;
    /// - every entry of `request.headers`;
    /// - `request.body`, when present.
    ///
    /// The body is read as text (`ResponseBody::Text`) when
    /// `super::is_text_content_type` accepts the response's `Content-Type`. Otherwise
    /// it is read as raw bytes (`ResponseBody::Bytes`). The reported elapsed time spans
    /// from just before sending until the body has been fully read.
    ///
    /// NOTE(review): the returned `Response` is labelled with `request.url()`. If the
    /// client follows redirects, this is not the final URL reached. Confirm this is
    /// intended.
    ///
    /// # Errors
    ///
    /// Propagates errors from `client_for`. Returns `KumoError::Fetch` when sending the
    /// request or reading the body fails.
    async fn fetch(&self, request: &FetchRequest) -> Result<Response, KumoError> {
        let client = self.client_for(request).await?;

        let initial = client.request(request.method.clone(), request.url());
        let with_headers = request
            .headers
            .iter()
            .fold(initial, |rb, (name, value)| rb.header(name, value));
        let outgoing = match &request.body {
            Some(payload) => with_headers.body(payload.clone()),
            None => with_headers,
        };

        let started = std::time::Instant::now();
        let resp = outgoing.send().await.map_err(KumoError::Fetch)?;

        let status = resp.status().as_u16();
        let headers = resp.headers().clone();
        let content_type = headers
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok());

        // The logged size is the length of the buffered payload: the byte length of
        // the decoded `String` for text, and the raw byte count otherwise.
        let (body, byte_count) = if super::is_text_content_type(content_type) {
            let text = resp.text().await.map_err(KumoError::Fetch)?;
            let len = text.len() as u64;
            (ResponseBody::Text(text), len)
        } else {
            let raw = resp.bytes().await.map_err(KumoError::Fetch)?;
            let len = raw.len() as u64;
            (ResponseBody::Bytes(raw), len)
        };
        let elapsed = started.elapsed();

        tracing::debug!(
            url = %request.url(),
            status,
            bytes = byte_count,
            elapsed_ms = elapsed.as_millis(),
            "response"
        );

        Ok(Response::new(
            request.url().to_string(),
            status,
            headers,
            elapsed,
            body,
        ))
    }
}