use std::net::IpAddr;
use tokio::net::lookup_host;
use tracing::debug;
use url::Url;
use super::{
FetcherError,
canonical::extract_canonical_url,
charset::{Detected, decode_to_utf8},
dns::SSRF_LEVEL,
ssrf::{self, SsrfLevel},
};
/// Result of a successful fetch performed by `fetch_url` / `fetch_url_conditional`.
///
/// Header-derived fields hold the raw header value. They are `None` when the header is
/// absent or its value is not visible ASCII (the fetcher uses `HeaderValue::to_str`).
#[derive(Debug, Clone)]
pub struct FetchedPage {
    /// URL reported by the response (reqwest's final URL after any redirects the
    /// client followed), re-parsed as a `url::Url`.
    pub final_url: Url,
    /// URL produced by `extract_canonical_url` from the decoded body, `final_url`,
    /// and the `Link` header.
    pub canonical_url: Url,
    /// Numeric HTTP status code. Non-2xx statuses are not turned into errors by the
    /// fetcher; callers must inspect this themselves.
    pub status: u16,
    /// Raw `Content-Type` header value.
    pub content_type: Option<String>,
    /// Response body decoded to UTF-8 by `decode_to_utf8`.
    pub body: String,
    /// Charset detection result returned by `decode_to_utf8`.
    pub charset: Detected,
    /// Raw `Link` header value.
    pub link_header: Option<String>,
    /// Raw `ETag` header value.
    pub etag: Option<String>,
    /// Raw `Last-Modified` header value.
    pub last_modified: Option<String>,
    /// Raw `Cache-Control` header value.
    pub cache_control: Option<String>,
    /// Raw `Expires` header value.
    pub expires: Option<String>,
    /// Raw `Retry-After` header value.
    pub retry_after: Option<String>,
}
/// Validators to send with a conditional GET.
///
/// Each field that is `Some` is sent verbatim as the corresponding request header.
/// The `Default` value (both `None`) produces an ordinary, unconditional GET.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConditionalGet {
    /// Sent as `If-None-Match` when present (e.g. an `ETag` from an earlier response).
    pub if_none_match: Option<String>,
    /// Sent as `If-Modified-Since` when present (e.g. a `Last-Modified` from an earlier
    /// response).
    pub if_modified_since: Option<String>,
}
/// Fetches `url` with a plain (unconditional) GET.
///
/// This is `fetch_url_conditional` with an empty [`ConditionalGet`], so no
/// `If-None-Match` / `If-Modified-Since` headers are sent. SSRF validation, HAR
/// recording and body decoding behave exactly as in `fetch_url_conditional`.
///
/// # Errors
/// Propagates every error returned by `fetch_url_conditional`.
pub async fn fetch_url(
    client: &reqwest::Client,
    url: &Url,
    level: SsrfLevel,
    project_root: Option<&std::path::Path>,
    har_recorder: Option<&std::sync::Arc<super::har::HarRecorder>>,
) -> Result<FetchedPage, FetcherError> {
    let unconditional = ConditionalGet::default();
    fetch_url_conditional(client, url, level, project_root, har_recorder, &unconditional).await
}
pub async fn fetch_url_conditional(
client: &reqwest::Client,
url: &Url,
level: SsrfLevel,
project_root: Option<&std::path::Path>,
har_recorder: Option<&std::sync::Arc<super::har::HarRecorder>>,
cond: &ConditionalGet,
) -> Result<FetchedPage, FetcherError> {
let start = std::time::Instant::now();
ssrf::validate_url_with_project_root(url, level, project_root)?;
let host = url
.host_str()
.ok_or(FetcherError::Ssrf(ssrf::SsrfError::NoHost))?;
let port = url.port_or_known_default().unwrap_or(0);
let addrs = resolve_host(host, port).await?;
ssrf::validate_addresses(&addrs, level)?;
let mut req = client.get(url.clone());
let mut request_headers_pairs: Vec<(String, String)> = Vec::new();
if let Some(etag) = &cond.if_none_match {
req = req.header(reqwest::header::IF_NONE_MATCH, etag);
request_headers_pairs.push(("if-none-match".into(), etag.clone()));
}
if let Some(lm) = &cond.if_modified_since {
req = req.header(reqwest::header::IF_MODIFIED_SINCE, lm);
request_headers_pairs.push(("if-modified-since".into(), lm.clone()));
}
let response = SSRF_LEVEL.scope(level, req.send()).await?;
let status = response.status().as_u16();
let final_url = Url::parse(response.url().as_str())?;
let response_headers_pairs: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let link_header = response
.headers()
.get(reqwest::header::LINK)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let etag = response
.headers()
.get(reqwest::header::ETAG)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let last_modified = response
.headers()
.get(reqwest::header::LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let cache_control = response
.headers()
.get(reqwest::header::CACHE_CONTROL)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let expires = response
.headers()
.get(reqwest::header::EXPIRES)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let retry_after = response
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let bytes = response.bytes().await?;
if let Some(recorder) = har_recorder {
let ex = super::har::RecordedExchange {
url: final_url.to_string(),
method: "GET".to_string(),
request_headers: request_headers_pairs,
response_status: status,
response_headers: response_headers_pairs,
response_body: bytes.to_vec(),
duration: start.elapsed(),
};
if let Err(e) = recorder.record(ex).await {
tracing::warn!(target: "rover::fetcher", error = ?e, "failed to record har entry");
}
}
let (body, charset) = decode_to_utf8(content_type.as_deref(), &bytes);
if let Some(ref ct) = content_type
&& ct.to_ascii_lowercase().contains("charset=")
{
debug!(
target: "rover::fetcher::charset",
http_charset = ct.as_str(),
detected = %charset.encoding.name(),
"charset detection complete"
);
}
let canonical_url = extract_canonical_url(&body, &final_url, link_header.as_deref());
Ok(FetchedPage {
final_url,
canonical_url,
status,
content_type,
body,
charset,
link_header,
etag,
last_modified,
cache_control,
expires,
retry_after,
})
}
/// Resolves `host` (paired with `port`) to the list of IP addresses it maps to.
///
/// The caller passes `Url::host_str()`, which wraps IPv6 literals in brackets. The
/// joined `"host:port"` string therefore parses as a socket address for IPv4, IPv6
/// and domain names alike.
///
/// # Errors
/// Returns [`FetcherError::Dns`] carrying `host` and the underlying I/O error when the
/// lookup fails.
async fn resolve_host(host: &str, port: u16) -> Result<Vec<IpAddr>, FetcherError> {
    let authority = format!("{host}:{port}");
    match lookup_host(authority.as_str()).await {
        Ok(socket_addrs) => Ok(socket_addrs.map(|addr| addr.ip()).collect()),
        Err(source) => Err(FetcherError::Dns {
            host: host.to_string(),
            source,
        }),
    }
}