use std::collections::HashMap;
use std::time::Duration;
use keyhog_core::{HttpMethod, VerificationResult};
use reqwest::Client;
use crate::ssrf::{is_private_ip_addr, is_private_url};
/// Error text used when a request is refused by the HTTPS-only policy.
pub(crate) const HTTPS_ONLY_ERROR: &str = "blocked: HTTPS only";
/// Error text used when the target URL (or an address it resolves to) is private.
pub(crate) const PRIVATE_URL_ERROR: &str = "blocked: private URL";
/// A parsed target URL paired with the HTTP client that should be used to reach it.
///
/// Produced by `resolved_client_for_url`.
pub(crate) struct ResolvedTarget {
    /// The parsed request URL.
    pub url: reqwest::Url,
    /// Client to send the request with (possibly with DNS pinned for `url`'s host).
    pub client: Client,
}
/// Outcome of preparing a request for one verification step.
pub(crate) enum RequestBuildResult {
    /// A fully configured request builder that is ready to be sent.
    Ready(reqwest::RequestBuilder),
    /// No request should be sent; the verification outcome is already decided.
    Final {
        /// Whether the decided outcome should be treated as transient.
        transient: bool,
        /// Extra key/value details accompanying the outcome.
        metadata: HashMap<String, String>,
        /// The verification outcome itself.
        result: VerificationResult,
    },
}
/// Failure from sending a request, already translated into a verification outcome.
pub(crate) struct RequestError {
    /// True when the failure may succeed on retry (`execute_request` sets this for
    /// timeouts and connection failures).
    pub transient: bool,
    /// The verification outcome describing the failure.
    pub result: VerificationResult,
}
pub(crate) async fn resolved_client_for_url(
base_client: &Client,
raw_url: &str,
timeout: Duration,
allow_private_ips: bool,
allow_http: bool,
) -> std::result::Result<ResolvedTarget, VerificationResult> {
let url = match reqwest::Url::parse(raw_url) {
Ok(url) => url,
Err(e) => return Err(VerificationResult::Error(format!("invalid URL: {}", e))),
};
if !allow_private_ips && is_private_url(url.as_str()) {
return Err(VerificationResult::Error(PRIVATE_URL_ERROR.into()));
}
let mut pinned_addrs: Vec<std::net::SocketAddr> = Vec::new();
let host = url.host_str().unwrap_or_default().to_string();
let port = url.port_or_known_default().unwrap_or(443);
if !host.is_empty() {
let target = format!("{host}:{port}");
let addrs: std::result::Result<Vec<std::net::SocketAddr>, std::io::Error> =
tokio::net::lookup_host(target.as_str())
.await
.map(|iter| iter.collect());
match addrs {
Ok(addrs) if addrs.is_empty() => {
return Err(VerificationResult::Error(
"blocked: DNS returned no addresses".into(),
));
}
Ok(addrs) => {
if !allow_private_ips && addrs.iter().any(|addr| is_private_ip_addr(&addr.ip())) {
return Err(VerificationResult::Error(PRIVATE_URL_ERROR.into()));
}
pinned_addrs = addrs;
}
Err(_) => {
return Err(VerificationResult::Error(
"blocked: DNS resolution failed".into(),
));
}
}
}
if !allow_http && url.scheme() != "https" {
return Err(VerificationResult::Error(HTTPS_ONLY_ERROR.into()));
}
let client = if !pinned_addrs.is_empty() {
match Client::builder()
.timeout(timeout)
.resolve_to_addrs(&host, &pinned_addrs)
.build()
{
Ok(c) => c,
Err(_) => {
let _ = base_client;
base_client.clone()
}
}
} else {
base_client.clone()
};
Ok(ResolvedTarget { client, url })
}
/// Builds the request for one verification step.
///
/// Creates a request for `method` against `url` with a per-request `timeout`, then hands it
/// to `crate::verify::auth::build_request_for_auth`, which applies `auth` using `credential`
/// and `companions`.
pub(crate) async fn build_request_for_step(
    client: &Client,
    method: &HttpMethod,
    auth: &keyhog_core::AuthSpec,
    url: reqwest::Url,
    credential: &str,
    companions: &HashMap<String, String>,
    timeout: Duration,
) -> RequestBuildResult {
    let unauthenticated = request_for_method(client, method, url);
    let with_deadline = unauthenticated.timeout(timeout);
    let outcome = crate::verify::auth::build_request_for_auth(
        with_deadline,
        auth,
        credential,
        companions,
        timeout,
        client,
    );
    outcome.await
}
/// Starts a request builder on `client` for `url` using the HTTP verb named by `method`.
fn request_for_method(
    client: &Client,
    method: &HttpMethod,
    url: reqwest::Url,
) -> reqwest::RequestBuilder {
    // `Client::get`/`post`/... are thin wrappers over `Client::request`; map the verb once
    // and make a single call.
    let verb = match method {
        HttpMethod::Get => reqwest::Method::GET,
        HttpMethod::Post => reqwest::Method::POST,
        HttpMethod::Put => reqwest::Method::PUT,
        HttpMethod::Delete => reqwest::Method::DELETE,
        HttpMethod::Patch => reqwest::Method::PATCH,
        HttpMethod::Head => reqwest::Method::HEAD,
    };
    client.request(verb, url)
}
pub(crate) async fn execute_request(
request: reqwest::RequestBuilder,
) -> std::result::Result<reqwest::Response, RequestError> {
request.send().await.map_err(|e| RequestError {
result: if e.is_timeout() {
VerificationResult::Error("timeout".into())
} else if e.is_redirect() {
VerificationResult::Error("too many redirects".into())
} else if e.is_connect() {
VerificationResult::Error("connection failed".into())
} else {
VerificationResult::Error("request failed".into())
},
transient: e.is_timeout() || e.is_connect(),
})
}