use super::config::FetcherConfig;
use super::error::UrlFetcherError;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use texting_robots::Robot;
use tokio::sync::Mutex;
use url::Url;
/// Outcome of a successful fetch: the response body plus the metadata needed to interpret it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchResult {
    /// Response body bytes as returned by the HTTP client.
    pub bytes: Vec<u8>,
    /// Value of the `Content-Type` response header; empty when the header is missing or
    /// cannot be read as a string.
    pub content_type: String,
    /// URL the response was served from (after any redirects the client followed).
    pub url: String,
}
/// How long a parsed robots.txt stays cached for an origin before it is fetched again (1 h).
const ROBOTS_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
/// Time budget for retrieving an origin's robots.txt.
const ROBOTS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
/// A parsed robots.txt for a single origin, as kept in `UrlFetcher`'s cache.
struct RobotsCacheEntry {
    /// Rules parsed for the fetcher's configured user agent.
    robot: Robot,
    /// `Crawl-delay` taken from robots.txt, capped at the configured maximum.
    /// `None` means the fetcher falls back to its configured default delay.
    crawl_delay: Option<Duration>,
    /// When this entry was stored; compared against `ROBOTS_CACHE_TTL` to expire it.
    fetched_at: Instant,
}
/// HTTP fetcher that can honour robots.txt and spaces out requests to the same origin.
pub struct UrlFetcher {
    /// HTTP client built from `config` (timeout, user agent, redirect policy).
    client: Arc<Client>,
    /// Settings this fetcher was constructed with.
    config: FetcherConfig,
    /// Parsed robots.txt per origin, keyed by the origin's Unicode serialization.
    robots_cache: Arc<Mutex<HashMap<String, RobotsCacheEntry>>>,
    /// Per-origin request timestamps used to enforce the crawl delay.
    last_fetch: Arc<Mutex<HashMap<String, Instant>>>,
}
impl UrlFetcher {
    /// Creates a fetcher using [`FetcherConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns [`UrlFetcherError::HttpError`] if the HTTP client cannot be built.
    pub fn new() -> Result<Self, UrlFetcherError> {
        Self::with_config(FetcherConfig::default())
    }

    /// Creates a fetcher from an explicit configuration.
    ///
    /// The HTTP client uses the configured timeout and user agent. Redirects are followed up
    /// to `max_redirects` times when `follow_redirects` is set, and not at all otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`UrlFetcherError::HttpError`] if the HTTP client cannot be built.
    pub fn with_config(config: FetcherConfig) -> Result<Self, UrlFetcherError> {
        let redirect_policy = if config.follow_redirects {
            reqwest::redirect::Policy::limited(config.max_redirects)
        } else {
            reqwest::redirect::Policy::none()
        };
        let client = Client::builder()
            .timeout(config.timeout)
            .user_agent(&config.user_agent)
            .redirect(redirect_policy)
            .build()
            .map_err(|e| UrlFetcherError::HttpError(e.to_string()))?;
        Ok(Self {
            client: Arc::new(client),
            config,
            robots_cache: Arc::new(Mutex::new(HashMap::new())),
            last_fetch: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Fetches `url` with GET and returns the body, its content type and the final URL.
    ///
    /// When `respect_robots_txt` is enabled, the URL is first checked against the origin's
    /// robots.txt. Every attempt waits for the per-origin crawl delay. Errors classified as
    /// transient by `should_retry` are retried (`max_retries: 2`) with exponential backoff.
    ///
    /// # Errors
    ///
    /// - An invalid-URL error if `url` cannot be parsed.
    /// - [`UrlFetcherError::RobotsDisallowed`] if robots.txt forbids the URL.
    /// - [`UrlFetcherError::HttpStatus`] for a non-success status (429/5xx only after retries).
    /// - Transport errors from sending the request or reading the body.
    pub async fn fetch_with_metadata(&self, url: &str) -> Result<FetchResult, UrlFetcherError> {
        let parsed_url = Url::parse(url)?;
        if self.config.respect_robots_txt {
            self.check_robots_txt(&parsed_url).await?;
        }
        let retry_config = cognee_utils::RetryConfig {
            max_retries: 2,
            initial_delay_ms: 500,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            jitter_factor: None,
        };
        // Everything the attempt needs is a shared reference (`self`, `url`, `parsed`), so
        // each attempt's future borrows from this call instead of cloning per attempt.
        let parsed = &parsed_url;
        cognee_utils::retry_with_backoff(
            retry_config,
            move || async move {
                self.respect_rate_limit(parsed).await;
                let response = self
                    .client
                    .get(url)
                    .send()
                    .await
                    .map_err(UrlFetcherError::from)?;
                let status = response.status();
                if !status.is_success() {
                    return Err(UrlFetcherError::HttpStatus(
                        status.as_u16(),
                        format!("Failed to fetch URL: {url}"),
                    ));
                }
                let final_url = response.url().to_string();
                let content_type = Self::content_type_header(&response)
                    .unwrap_or("")
                    .to_string();
                let bytes = response
                    .bytes()
                    .await
                    .map_err(|e| UrlFetcherError::HttpError(e.to_string()))?
                    .to_vec();
                Ok(FetchResult {
                    bytes,
                    content_type,
                    url: final_url,
                })
            },
            should_retry,
        )
        .await
    }

    /// Fetches `url` and returns its body as a UTF-8 string.
    ///
    /// # Errors
    ///
    /// Anything [`Self::fetch_with_metadata`] can return, plus
    /// [`UrlFetcherError::ParseError`] if the body is not valid UTF-8.
    pub async fn fetch(&self, url: &str) -> Result<String, UrlFetcherError> {
        let result = self.fetch_with_metadata(url).await?;
        String::from_utf8(result.bytes)
            .map_err(|e| UrlFetcherError::ParseError(format!("Invalid UTF-8 response: {e}")))
    }

    /// Streams the body of `url` to `callback`, one chunk at a time, without retries.
    ///
    /// Honours robots.txt (when enabled) and the per-origin crawl delay before sending.
    ///
    /// # Errors
    ///
    /// - An invalid-URL error, [`UrlFetcherError::RobotsDisallowed`], or
    ///   [`UrlFetcherError::HttpStatus`] for a non-success status.
    /// - [`UrlFetcherError::HttpError`] if reading a chunk fails.
    /// - An error converted from `std::io::Error` ("Callback error") if `callback` fails.
    pub async fn fetch_streaming<F, Fut, E>(
        &self,
        url: &str,
        mut callback: F,
    ) -> Result<(), UrlFetcherError>
    where
        F: FnMut(&[u8]) -> Fut,
        Fut: std::future::Future<Output = Result<(), E>>,
        E: From<UrlFetcherError> + From<std::io::Error>,
    {
        use futures_util::StreamExt;
        let parsed_url = Url::parse(url)?;
        if self.config.respect_robots_txt {
            self.check_robots_txt(&parsed_url).await?;
        }
        self.respect_rate_limit(&parsed_url).await;
        let response = self.client.get(url).send().await?;
        let status = response.status();
        if !status.is_success() {
            return Err(UrlFetcherError::HttpStatus(
                status.as_u16(),
                format!("Failed to fetch URL: {url}"),
            ));
        }
        let mut stream = response.bytes_stream();
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result
                .map_err(|e: reqwest::Error| UrlFetcherError::HttpError(e.to_string()))?;
            // NOTE(review): the callback's own error value is discarded; callers only see a
            // generic "Callback error".
            callback(&chunk)
                .await
                .map_err(|_e| UrlFetcherError::from(std::io::Error::other("Callback error")))?;
        }
        Ok(())
    }

    /// Returns `Ok(())` if robots.txt for `url`'s origin allows the configured user agent.
    ///
    /// Parsed robots.txt files are cached per origin for `ROBOTS_CACHE_TTL`; expired entries
    /// are evicted and fetched again.
    ///
    /// # Errors
    ///
    /// Returns [`UrlFetcherError::RobotsDisallowed`] if the URL is disallowed.
    async fn check_robots_txt(&self, url: &Url) -> Result<(), UrlFetcherError> {
        let origin = url.origin().unicode_serialization();
        let robot_allowed = {
            let mut cache = self.robots_cache.lock().await;
            if let Some(entry) = cache.get(&origin)
                && entry.fetched_at.elapsed() >= ROBOTS_CACHE_TTL
            {
                cache.remove(&origin);
            }
            if let Some(entry) = cache.get(&origin) {
                entry.robot.allowed(url.as_str())
            } else {
                // The cache is one mutex shared by all origins; release it during the download.
                drop(cache);
                let (robot, crawl_delay) = self.fetch_robots_txt(&origin).await;
                let allowed = robot.allowed(url.as_str());
                let mut cache = self.robots_cache.lock().await;
                // Another task may have cached this origin meanwhile; its entry is kept.
                cache.entry(origin).or_insert(RobotsCacheEntry {
                    robot,
                    crawl_delay,
                    fetched_at: Instant::now(),
                });
                allowed
            }
        };
        if robot_allowed {
            Ok(())
        } else {
            Err(UrlFetcherError::RobotsDisallowed(url.to_string()))
        }
    }

    /// Downloads and parses `{origin}/robots.txt`, returning the rules and the capped crawl delay.
    ///
    /// Any failure (timeout, transport error, non-success status, unreadable body) is treated
    /// as an empty robots.txt, which allows everything.
    async fn fetch_robots_txt(&self, origin: &str) -> (Robot, Option<Duration>) {
        let robots_url = format!("{origin}/robots.txt");
        // The timeout covers both the request and the body read, so a slow body cannot
        // stretch the robots.txt check beyond ROBOTS_FETCH_TIMEOUT.
        let download = async {
            let resp = self.client.get(&robots_url).send().await.ok()?;
            if !resp.status().is_success() {
                return None;
            }
            resp.bytes().await.ok().map(|b| b.to_vec())
        };
        let body = tokio::time::timeout(ROBOTS_FETCH_TIMEOUT, download)
            .await
            .ok()
            .flatten()
            .unwrap_or_default();
        let robot = Robot::new(&self.config.user_agent, &body).unwrap_or_else(|_| {
            #[allow(clippy::expect_used, reason = "an empty robots.txt always parses")]
            Robot::new(&self.config.user_agent, b"").expect("empty robots.txt should always parse")
        });
        let crawl_delay = robot
            .delay
            .and_then(|secs| Self::crawl_delay_from_secs(secs, self.config.max_crawl_delay));
        (robot, crawl_delay)
    }

    /// Converts a robots.txt `Crawl-delay` (seconds) into a `Duration` capped at `max`.
    ///
    /// The value comes from an untrusted file, and `Duration::from_secs_f32` panics on
    /// negative, NaN or overflowing input. Negative and NaN values are ignored (`None`).
    /// Values too large to represent, including infinity, are capped at `max`.
    fn crawl_delay_from_secs(secs: f32, max: Duration) -> Option<Duration> {
        if secs.is_nan() || secs < 0.0 {
            return None;
        }
        // After the check above, `abs` only changes `-0.0`, turning it into `0.0`.
        Some(Duration::try_from_secs_f32(secs.abs()).map_or(max, |d| d.min(max)))
    }

    /// Waits until a request to `url`'s origin respects the crawl delay, then records it.
    ///
    /// The delay is the robots.txt `Crawl-delay` if one is cached, else `config.crawl_delay`.
    /// The request's slot is reserved while the lock is held. Concurrent callers for the same
    /// origin therefore queue up one delay apart, instead of all waking after the same wait.
    async fn respect_rate_limit(&self, url: &Url) {
        let origin = url.origin().unicode_serialization();
        let robots_delay = {
            let cache = self.robots_cache.lock().await;
            cache.get(&origin).and_then(|entry| entry.crawl_delay)
        };
        let effective_delay = robots_delay.unwrap_or(self.config.crawl_delay);
        let wait = {
            let mut last = self.last_fetch.lock().await;
            let now = Instant::now();
            let wait = last.get(&origin).map_or(Duration::ZERO, |prev| {
                // `prev` is in the future when another caller has already reserved a slot.
                let pending = prev.saturating_duration_since(now);
                let elapsed = now.saturating_duration_since(*prev);
                pending.saturating_add(effective_delay.saturating_sub(elapsed))
            });
            // Overflow here needs an absurdly large delay; fall back to "now" rather than panic.
            last.insert(origin, now.checked_add(wait).unwrap_or(now));
            wait
        };
        if !wait.is_zero() {
            tokio::time::sleep(wait).await;
        }
    }

    /// Returns the `Content-Type` reported by a HEAD request to `url`, or `"text/html"` when
    /// the header is missing or unreadable.
    ///
    /// Like the GET methods, this honours robots.txt (when enabled) and the per-origin crawl
    /// delay. The response status is deliberately not checked, so a header from an error
    /// response is returned as-is.
    ///
    /// # Errors
    ///
    /// Returns an invalid-URL error, [`UrlFetcherError::RobotsDisallowed`], or the transport
    /// error if the request fails.
    pub async fn get_content_type(&self, url: &str) -> Result<String, UrlFetcherError> {
        let parsed_url = Url::parse(url)?;
        if self.config.respect_robots_txt {
            self.check_robots_txt(&parsed_url).await?;
        }
        self.respect_rate_limit(&parsed_url).await;
        let response = self.client.head(url).send().await?;
        Ok(Self::content_type_header(&response)
            .unwrap_or("text/html")
            .to_string())
    }

    /// Reads the `Content-Type` header of `response`, if present and readable as a string.
    fn content_type_header(response: &reqwest::Response) -> Option<&str> {
        response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
    }
}
impl Default for UrlFetcher {
    /// Equivalent to [`UrlFetcher::new`]: a fetcher built from `FetcherConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built, because `Default` cannot report errors.
    #[allow(
        clippy::expect_used,
        reason = "Default cannot return an error; the failure is documented under # Panics"
    )]
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create default UrlFetcher")
    }
}
/// Decides whether a failed fetch attempt is worth retrying.
///
/// These are treated as transient and retried:
/// - HTTP 429 (rate limited) and any status >= 500;
/// - timeouts;
/// - transport-level HTTP errors.
///
/// Every other error aborts immediately. The match is exhaustive, so adding a new error
/// variant forces this classification to be revisited.
fn should_retry(err: &UrlFetcherError) -> cognee_utils::RetryDecision {
    let transient = match err {
        UrlFetcherError::HttpStatus(code, _) => *code == 429 || *code >= 500,
        UrlFetcherError::Timeout(_) | UrlFetcherError::HttpError(_) => true,
        UrlFetcherError::RobotsDisallowed(_)
        | UrlFetcherError::InvalidUrl(_)
        | UrlFetcherError::ParseError(_)
        | UrlFetcherError::IoError(_) => false,
    };
    if transient {
        cognee_utils::RetryDecision::Retry
    } else {
        cognee_utils::RetryDecision::Abort
    }
}