#![allow(clippy::module_name_repetitions)]
use std::time::Duration;
use reqwest::{IntoUrl, Method};
use url::Url;
use crate::error::{Error, Result};
/// Allowlist used when `REPOSIX_ALLOWED_ORIGINS` is unset or blank: plain-HTTP loopback
/// (`127.0.0.1` and `localhost`) on any port.
pub const DEFAULT_ALLOWLIST_RAW: &str = "http://127.0.0.1:*,http://localhost:*";
/// Name of the environment variable holding the comma-separated origin allowlist.
///
/// Each entry has the form `scheme://host`, `scheme://host:port`, or `scheme://host:*`
/// (the last meaning "any port").
pub const ALLOWLIST_ENV_VAR: &str = "REPOSIX_ALLOWED_ORIGINS";
/// Options consumed by [`client`] when building an [`HttpClient`].
#[derive(Debug, Clone)]
pub struct ClientOpts {
    /// Timeout handed to `reqwest::ClientBuilder::timeout`.
    pub total_timeout: Duration,
    /// Value for the `User-Agent` header; `None` means no user agent is set on the builder.
    pub user_agent: Option<String>,
}
impl Default for ClientOpts {
    /// A five-second timeout and a `reposix/<crate version>` user agent.
    fn default() -> Self {
        let agent = String::from(concat!("reposix/", env!("CARGO_PKG_VERSION")));
        Self {
            total_timeout: Duration::from_secs(5),
            user_agent: Some(agent),
        }
    }
}
/// One parsed allowlist entry: a scheme, a host, and an optional required port.
///
/// A `port` of `None` accepts any port (the entry was written with a `:*` suffix).
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OriginGlob {
    /// URL scheme, compared exactly against `Url::scheme`.
    scheme: String,
    /// Host as rendered by `Url::host_str` (IPv6 literals keep their brackets, e.g. `[::1]`).
    host: String,
    /// Required effective port, or `None` for "any port".
    port: Option<u16>,
}
impl OriginGlob {
    /// Return `true` when `url` has this entry's scheme and host and an acceptable port.
    ///
    /// The port compared is `url.port_or_known_default()`, so `http://h/` satisfies an
    /// entry pinned to port 80. A URL without a host never matches.
    pub(crate) fn matches(&self, url: &Url) -> bool {
        let scheme_ok = url.scheme() == self.scheme;
        let host_ok = url.host_str() == Some(self.host.as_str());
        let port_ok = self.port.is_none() || url.port_or_known_default() == self.port;
        scheme_ok && host_ok && port_ok
    }
}
/// Parse a comma-separated allowlist string into origin globs.
///
/// Leading/trailing whitespace is ignored. A blank (empty or whitespace-only) string
/// selects [`DEFAULT_ALLOWLIST_RAW`] instead.
///
/// # Errors
///
/// Returns [`Error::Other`] if any entry is empty or cannot be parsed.
pub(crate) fn parse_allowlist(raw: &str) -> Result<Vec<OriginGlob>> {
    let trimmed = raw.trim();
    let source = if trimmed.is_empty() {
        DEFAULT_ALLOWLIST_RAW
    } else {
        trimmed
    };
    parse_allowlist_inner(source)
}
fn parse_allowlist_inner(raw: &str) -> Result<Vec<OriginGlob>> {
let mut out = Vec::new();
for (idx, entry) in raw.split(',').enumerate() {
let entry = entry.trim();
if entry.is_empty() {
return Err(Error::Other(format!(
"REPOSIX_ALLOWED_ORIGINS: entry {idx} is empty"
)));
}
out.push(parse_one(entry)?);
}
Ok(out)
}
/// Parse one trimmed, non-empty allowlist entry into an [`OriginGlob`].
///
/// Accepted shapes are `scheme://host`, `scheme://host:port` and `scheme://host:*`, where
/// `scheme` is `http` or `https`. Without `:*`, the explicit port (or the scheme's
/// default) is pinned; with `:*`, any port is accepted. Only scheme, host and port are
/// kept. A path, query, fragment or userinfo in the entry is ignored by matching.
///
/// # Errors
///
/// Returns [`Error::Other`] when the entry:
/// - does not parse as a URL,
/// - uses a scheme other than `http`/`https`,
/// - has no host, or
/// - writes both an explicit port and the `:*` wildcard (e.g. `http://h:8080:*`), which
///   would otherwise silently widen the pinned port to "any port".
fn parse_one(entry: &str) -> Result<OriginGlob> {
    let (url_src, wildcard_port) = match entry.strip_suffix(":*") {
        Some(stripped) => (stripped, true),
        None => (entry, false),
    };
    // Normalise to a trailing `/` before parsing. For http/https this does not affect the
    // scheme, host or port extracted below.
    let mut to_parse = url_src.to_owned();
    if !to_parse.ends_with('/') {
        to_parse.push('/');
    }
    let parsed = Url::parse(&to_parse).map_err(|e| {
        Error::Other(format!(
            "REPOSIX_ALLOWED_ORIGINS: entry {entry:?} failed to parse: {e}"
        ))
    })?;
    let scheme = parsed.scheme().to_owned();
    if scheme != "http" && scheme != "https" {
        return Err(Error::Other(format!(
            "REPOSIX_ALLOWED_ORIGINS: entry {entry:?} scheme must be http or https"
        )));
    }
    // A missing host and an empty host are reported identically.
    let Some(host) = parsed.host_str().filter(|h| !h.is_empty()) else {
        return Err(Error::Other(format!(
            "REPOSIX_ALLOWED_ORIGINS: entry {entry:?} has empty host"
        )));
    };
    // `http://h:8080:*` strips to `http://h:8080`. The url crate would happily parse that,
    // and the wildcard would then discard 8080. Refuse the ambiguity instead of guessing.
    if wildcard_port && has_explicit_port(url_src) {
        return Err(Error::Other(format!(
            "REPOSIX_ALLOWED_ORIGINS: entry {entry:?} combines an explicit port with the :* wildcard"
        )));
    }
    let port = if wildcard_port {
        None
    } else {
        parsed.port_or_known_default()
    };
    Ok(OriginGlob {
        scheme,
        host: host.to_owned(),
        port,
    })
}

/// Return `true` when `src` (an entry with its `:*` suffix removed) still ends in
/// `:<digits>`, optionally followed by slashes.
///
/// This is a textual check, so it also catches a default port such as `:80`, which
/// `Url::port` would normalise away. IPv6 literals do not trip it: for `http://[::1]` the
/// text after the last `:` is `1]`, which is not all digits.
fn has_explicit_port(src: &str) -> bool {
    match src.trim_end_matches('/').rsplit_once(':') {
        Some((_, tail)) => !tail.is_empty() && tail.bytes().all(|b| b.is_ascii_digit()),
        None => false,
    }
}
pub(crate) fn load_allowlist_from_env() -> Result<Vec<OriginGlob>> {
match std::env::var(ALLOWLIST_ENV_VAR) {
Ok(v) => parse_allowlist(&v),
Err(_) => parse_allowlist(""),
}
}
/// HTTP client that refuses to send requests to origins outside the allowlist.
///
/// Every request method checks the target URL against the allowlist read from the
/// environment before sending. Instances come from [`client`], which disables redirects
/// (`Policy::none`), so a 3xx response is returned to the caller rather than followed to
/// an unchecked origin.
#[derive(Debug, Clone)]
pub struct HttpClient {
/// The configured reqwest client that actually performs requests.
inner: reqwest::Client,
}
impl HttpClient {
pub async fn request<U: IntoUrl>(&self, method: Method, url: U) -> Result<reqwest::Response> {
self.request_with_headers(method, url, &[]).await
}
pub async fn request_with_headers<U: IntoUrl>(
&self,
method: Method,
url: U,
headers: &[(&str, &str)],
) -> Result<reqwest::Response> {
self.request_with_headers_and_body(method, url, headers, None::<&[u8]>)
.await
}
pub async fn request_with_headers_and_body<U, B>(
&self,
method: Method,
url: U,
headers: &[(&str, &str)],
body: Option<B>,
) -> Result<reqwest::Response>
where
U: IntoUrl,
B: Into<reqwest::Body>,
{
let parsed = url
.into_url()
.map_err(|e| Error::InvalidOrigin(format!("{e}")))?;
let allowlist = load_allowlist_from_env()?;
if !allowlist.iter().any(|g| g.matches(&parsed)) {
return Err(Error::InvalidOrigin(parsed.to_string()));
}
let mut builder = self.inner.request(method, parsed);
for (k, v) in headers {
builder = builder.header(*k, *v);
}
if let Some(body) = body {
builder = builder.body(body);
}
let resp = builder.send().await?;
Ok(resp)
}
pub async fn get<U: IntoUrl>(&self, url: U) -> Result<reqwest::Response> {
self.request(Method::GET, url).await
}
pub async fn post<U: IntoUrl>(&self, url: U) -> Result<reqwest::Response> {
self.request(Method::POST, url).await
}
pub async fn patch<U: IntoUrl>(&self, url: U) -> Result<reqwest::Response> {
self.request(Method::PATCH, url).await
}
pub async fn delete<U: IntoUrl>(&self, url: U) -> Result<reqwest::Response> {
self.request(Method::DELETE, url).await
}
}
/// Build an [`HttpClient`] from `opts`.
///
/// The allowlist is loaded once up front, so a malformed `REPOSIX_ALLOWED_ORIGINS` fails
/// here rather than on the first request. The resulting client never follows redirects
/// and applies `opts.total_timeout` to every request.
///
/// # Errors
///
/// Propagates allowlist loading errors and any error from building the reqwest client.
pub fn client(opts: ClientOpts) -> Result<HttpClient> {
    load_allowlist_from_env()?;
    // NOTE(review): constructing the builder directly appears to be banned crate-wide via
    // clippy `disallowed_methods`, with this function as the sanctioned exception. Confirm
    // against clippy.toml.
    #[allow(clippy::disallowed_methods)]
    let base = reqwest::ClientBuilder::new()
        .redirect(reqwest::redirect::Policy::none())
        .timeout(opts.total_timeout);
    let configured = match opts.user_agent {
        Some(ua) => base.user_agent(ua),
        None => base,
    };
    Ok(HttpClient {
        inner: configured.build()?,
    })
}
// Unit tests for ClientOpts defaults, allowlist parsing, and OriginGlob matching.
#[cfg(test)]
mod tests {
use super::*;
// --- ClientOpts defaults ---
#[test]
fn client_opts_default_is_5s_timeout() {
let opts = ClientOpts::default();
assert_eq!(opts.total_timeout, Duration::from_secs(5));
assert!(opts.user_agent.as_deref().unwrap().starts_with("reposix/"));
}
// --- parse_allowlist: default list and blank input ---
#[test]
fn parse_allowlist_default_has_two_entries() {
let entries = parse_allowlist("http://127.0.0.1:*,http://localhost:*").unwrap();
assert_eq!(entries.len(), 2);
}
#[test]
fn parse_allowlist_empty_input_returns_default() {
let entries = parse_allowlist("").unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.iter().any(|g| g.host == "127.0.0.1"));
assert!(entries.iter().any(|g| g.host == "localhost"));
}
#[test]
fn parse_allowlist_whitespace_only_returns_default() {
let entries = parse_allowlist(" \t ").unwrap();
assert_eq!(entries.len(), 2);
}
// --- parse_allowlist: malformed entries are rejected with Error::Other ---
#[test]
fn parse_allowlist_bad_input_errors() {
let err = parse_allowlist("not a url").unwrap_err();
assert!(matches!(err, Error::Other(_)), "got {err:?}");
}
#[test]
fn parse_allowlist_bad_scheme_errors() {
assert!(matches!(
parse_allowlist("ftp://127.0.0.1:*"),
Err(Error::Other(_))
));
}
#[test]
fn parse_allowlist_empty_host_errors() {
assert!(matches!(
parse_allowlist("http://:80"),
Err(Error::Other(_))
));
}
#[test]
fn parse_allowlist_bad_port_errors() {
assert!(matches!(
parse_allowlist("http://127.0.0.1:notaport"),
Err(Error::Other(_))
));
}
// --- OriginGlob::matches: scheme, host, and port must all agree ---
#[test]
fn origin_glob_matches_loopback_any_port() {
let glob = &parse_allowlist("http://127.0.0.1:*").unwrap()[0];
let url = Url::parse("http://127.0.0.1:7878").unwrap();
assert!(glob.matches(&url));
}
#[test]
fn origin_glob_rejects_https_when_http_configured() {
let glob = &parse_allowlist("http://127.0.0.1:*").unwrap()[0];
let url = Url::parse("https://127.0.0.1:7878").unwrap();
assert!(!glob.matches(&url));
}
#[test]
fn origin_glob_rejects_non_loopback_host() {
let glob = &parse_allowlist("http://127.0.0.1:*").unwrap()[0];
let url = Url::parse("http://evil.example:80").unwrap();
assert!(!glob.matches(&url));
}
#[test]
fn origin_glob_matches_exact_port() {
let glob = &parse_allowlist("http://127.0.0.1:80").unwrap()[0];
let url = Url::parse("http://127.0.0.1:80").unwrap();
assert!(glob.matches(&url));
}
#[test]
fn origin_glob_rejects_wrong_exact_port() {
let glob = &parse_allowlist("http://127.0.0.1:80").unwrap()[0];
let url = Url::parse("http://127.0.0.1:81").unwrap();
assert!(!glob.matches(&url));
}
// --- IPv6: hosts are stored with their brackets, as Url::host_str renders them ---
#[test]
fn parse_allowlist_accepts_ipv6_with_explicit_port() {
let entries = parse_allowlist("http://[::1]:7777").unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].host, "[::1]");
assert_eq!(entries[0].port, Some(7777));
}
#[test]
fn parse_allowlist_accepts_ipv6_with_wildcard_port() {
let entries = parse_allowlist("http://[::1]:*").unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].host, "[::1]");
assert_eq!(entries[0].port, None);
}
#[test]
fn origin_glob_matches_ipv6_loopback_any_port() {
let glob = &parse_allowlist("http://[::1]:*").unwrap()[0];
let url = Url::parse("http://[::1]:7777/").unwrap();
assert!(glob.matches(&url));
}
#[test]
fn origin_glob_matches_ipv6_loopback_exact_port() {
let glob = &parse_allowlist("http://[::1]:7777").unwrap()[0];
let url = Url::parse("http://[::1]:7777/").unwrap();
assert!(glob.matches(&url));
}
#[test]
fn origin_glob_ipv6_rejects_wrong_port() {
let glob = &parse_allowlist("http://[::1]:7777").unwrap()[0];
let url = Url::parse("http://[::1]:7778/").unwrap();
assert!(!glob.matches(&url));
}
// --- Wildcard port with a hostname and https scheme ---
#[test]
fn parse_allowlist_localhost_wildcard_still_parses() {
let entries = parse_allowlist("https://localhost:*").unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].scheme, "https");
assert_eq!(entries[0].host, "localhost");
assert_eq!(entries[0].port, None);
}
}