use std::io::Read;
use async_trait::async_trait;
use http::{Request, Response, Uri};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::{
http::{AsyncHttpResolver, HttpResolverError, SyncHttpResolver},
Result,
};
/// An HTTP resolver wrapper that checks each request URI against an optional
/// host allow-list before delegating to `inner`.
#[derive(Debug)]
pub struct RestrictedResolver<T> {
    /// The resolver that performs the actual request once a URI is allowed.
    inner: T,
    /// `None` allows every URI; `Some(list)` allows only URIs matching at
    /// least one pattern (so an empty list allows nothing).
    allowed_hosts: Option<Vec<HostPattern>>,
}
impl<T> RestrictedResolver<T> {
    /// Wraps `inner` with no restriction: every URI is forwarded.
    pub fn new(inner: T) -> Self {
        Self {
            allowed_hosts: None,
            inner,
        }
    }

    /// Wraps `inner` so that only URIs matching at least one of
    /// `allowed_hosts` are forwarded. An empty list rejects every request.
    // NOTE(review): `dead_code` is silenced because no non-test caller is
    // visible here; callers may live elsewhere — confirm before removing.
    #[allow(dead_code)]
    pub fn with_allowed_hosts(inner: T, allowed_hosts: Vec<HostPattern>) -> Self {
        let mut resolver = Self::new(inner);
        resolver.allowed_hosts = Some(allowed_hosts);
        resolver
    }

    /// Replaces the allow-list. `None` lifts the restriction entirely, while
    /// `Some(vec![])` blocks every request.
    pub fn set_allowed_hosts(&mut self, allowed_hosts: Option<Vec<HostPattern>>) {
        self.allowed_hosts = allowed_hosts;
    }

    /// Returns the configured allow-list, or `None` when unrestricted.
    // NOTE(review): see `with_allowed_hosts` regarding `dead_code`.
    #[allow(dead_code)]
    pub fn allowed_hosts(&self) -> Option<&[HostPattern]> {
        self.allowed_hosts.as_ref().map(Vec::as_slice)
    }

    /// `true` when no allow-list is configured, otherwise `true` only if
    /// `uri` matches one of its patterns.
    fn is_uri_allowed(&self, uri: &Uri) -> bool {
        match &self.allowed_hosts {
            // Bare path resolves to the module-level helper, not this method.
            Some(hosts) => is_uri_allowed(hosts, uri),
            None => true,
        }
    }
}
impl<T: SyncHttpResolver> SyncHttpResolver for RestrictedResolver<T> {
    /// Forwards `request` to the wrapped resolver when its URI passes the
    /// allow-list; otherwise fails with `HttpResolverError::UriDisallowed`
    /// without ever calling the inner resolver.
    fn http_resolve(
        &self,
        request: Request<Vec<u8>>,
    ) -> Result<Response<Box<dyn Read>>, HttpResolverError> {
        if self.is_uri_allowed(request.uri()) {
            self.inner.http_resolve(request)
        } else {
            let uri = request.uri().to_string();
            Err(HttpResolverError::UriDisallowed { uri })
        }
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl<T: AsyncHttpResolver + Sync> AsyncHttpResolver for RestrictedResolver<T> {
    /// Async counterpart of the sync `http_resolve`: forwards `request` to
    /// the wrapped resolver when its URI passes the allow-list, otherwise
    /// fails with `HttpResolverError::UriDisallowed` without awaiting the
    /// inner resolver.
    async fn http_resolve_async(
        &self,
        request: Request<Vec<u8>>,
    ) -> Result<Response<Box<dyn Read>>, HttpResolverError> {
        if self.is_uri_allowed(request.uri()) {
            self.inner.http_resolve_async(request).await
        } else {
            let uri = request.uri().to_string();
            Err(HttpResolverError::UriDisallowed { uri })
        }
    }
}
#[cfg_attr(
feature = "json_schema",
derive(schemars::JsonSchema),
schemars(with = "String")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct HostPattern {
pattern: String,
scheme: Option<String>,
host: Option<String>,
port: Option<String>,
}
impl HostPattern {
    /// Parses an allow-list entry. The pattern is ASCII-lowercased up front,
    /// so matching is case-insensitive.
    ///
    /// Accepted forms:
    ///
    /// * `https://` / `http://` — scheme only: any host using that scheme.
    /// * `example.com`, `*.example.com` — host only: any scheme. A leading
    ///   `*.` matches any subdomain but not the bare domain itself.
    /// * Combinations such as `https://*.example.com`,
    ///   `http://example.com:8080` or `[::1]:8080`, each with an optional
    ///   `:port`.
    ///
    /// Parsing never fails. Malformed input yields a pattern that typically
    /// matches nothing (e.g. `https:// `).
    pub fn new(pattern: &str) -> Self {
        let pattern = pattern.to_ascii_lowercase();

        let (scheme, rest) = if let Some(rest) = pattern.strip_prefix("https://") {
            (Some("https".to_owned()), rest)
        } else if let Some(rest) = pattern.strip_prefix("http://") {
            (Some("http".to_owned()), rest)
        } else {
            (None, pattern.as_str())
        };

        // Split off a trailing `:port`. If a `]` follows the last colon, that
        // colon belongs to a bracketed IPv6 literal (`[::1]`), not a port.
        // Splitting it anyway would yield host `[:` + port `1]`, which can
        // never match.
        let (host, port) = match rest.rsplit_once(':') {
            Some((host, port)) if !port.contains(']') => (host, Some(port.to_owned())),
            _ => (rest, None),
        };
        let host = (!host.is_empty()).then(|| host.to_owned());

        Self {
            pattern,
            scheme,
            host,
            port,
        }
    }

    /// Returns `true` if `uri` is permitted by this pattern.
    ///
    /// * Host: must equal the pattern's host (case-insensitively), or, for a
    ///   `*.suffix` pattern, be a strict subdomain of `suffix`.
    /// * Port: must equal the URI's *explicit* port. A pattern without a port
    ///   only matches URIs without one; default ports are not inferred.
    /// * Scheme: if the pattern names one, the URI must carry the same scheme.
    /// * A scheme-only pattern (no host) matches any host with that scheme,
    ///   restricted to the pattern's port when it names one.
    pub fn matches(&self, uri: &Uri) -> bool {
        match &self.host {
            Some(allowed_host) => match uri.host() {
                Some(host) => {
                    Self::host_matches(allowed_host, host)
                        && self.port_matches(uri)
                        && self.scheme_matches(uri)
                }
                None => false,
            },
            // Without a host, only a scheme makes the pattern meaningful
            // (a bare `:443` matches nothing). A port given alongside the
            // scheme (`https://:8443`) must be honoured rather than ignored;
            // otherwise the allow-list is wider than what was written.
            None => {
                self.scheme.is_some()
                    && self.scheme_matches(uri)
                    && (self.port.is_none() || self.port_matches(uri))
            }
        }
    }

    /// Compares a URI host against a host pattern, honouring a leading `*.`
    /// wildcard. `allowed` is lowercase because `new` lowercases the pattern.
    fn host_matches(allowed: &str, host: &str) -> bool {
        match allowed.strip_prefix("*.") {
            Some(suffix) => {
                let (host, suffix) = (host.as_bytes(), suffix.as_bytes());
                // The host must be strictly longer than the suffix and end
                // with it. The byte just before the suffix must be a label
                // separator, so `*.example.com` rejects both `example.com`
                // and `fakeexample.com`.
                host.len() > suffix.len()
                    && host[host.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
                    && host[host.len() - suffix.len() - 1] == b'.'
            }
            None => allowed.eq_ignore_ascii_case(host),
        }
    }

    /// `true` when the pattern's port equals the URI's explicit port; both
    /// being absent also counts as equal.
    fn port_matches(&self, uri: &Uri) -> bool {
        let uri_port = uri.port();
        self.port.as_deref() == uri_port.as_ref().map(|port| port.as_str())
    }

    /// `true` when the pattern names no scheme, or the URI has that scheme.
    fn scheme_matches(&self, uri: &Uri) -> bool {
        match &self.scheme {
            Some(allowed) => {
                matches!(uri.scheme(), Some(scheme) if scheme.as_str() == allowed)
            }
            None => true,
        }
    }
}
impl From<&str> for HostPattern {
fn from(pattern: &str) -> Self {
Self::new(pattern)
}
}
impl Serialize for HostPattern {
    /// Serializes as the stored (lowercased) pattern string. Feeding that
    /// string back through `HostPattern::new` reproduces an equal value.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `&String` derefs to `&str`; no intermediate copy is needed.
        serializer.serialize_str(&self.pattern)
    }
}
impl<'de> Deserialize<'de> for HostPattern {
    /// Reads a string and parses it with `HostPattern::new`. Parsing itself
    /// cannot fail, so only an error from reading the string is propagated.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = String::deserialize(deserializer)?;
        Ok(Self::new(&raw))
    }
}
/// Returns `true` if at least one of `patterns` matches `uri`.
///
/// An empty slice matches nothing, so it denies every URI.
pub(crate) fn is_uri_allowed(patterns: &[HostPattern], uri: &Uri) -> bool {
    patterns.iter().any(|pattern| pattern.matches(uri))
}
#[cfg(test)]
mod test {
#![allow(clippy::panic, clippy::unwrap_used)]
use super::*;
struct NoopHttpResolver;
impl SyncHttpResolver for NoopHttpResolver {
fn http_resolve(
&self,
_request: Request<Vec<u8>>,
) -> Result<Response<Box<dyn Read>>, HttpResolverError> {
Ok(Response::new(Box::new(std::io::empty()) as Box<dyn Read>))
}
}
fn assert_allowed_uri(resolver: &impl SyncHttpResolver, uri: &'static str) {
let result = resolver.http_resolve(
Request::get(Uri::from_static(uri))
.body(Vec::new())
.unwrap(),
);
assert!(matches!(result, Ok(..)));
}
fn assert_disallowed_uri(resolver: &impl SyncHttpResolver, uri: &'static str) {
let result = resolver.http_resolve(
Request::get(Uri::from_static(uri))
.body(Vec::new())
.unwrap(),
);
assert!(matches!(
result,
Err(HttpResolverError::UriDisallowed { .. })
));
}
#[test]
fn allowed_http_request() {
let allowed_list = vec![
"*.prefix.contentauthenticity.org".into(),
"test.contentauthenticity.org".into(),
"fakecontentauthenticity.org".into(),
"https://*.contentauthenticity.org".into(),
"https://test.contentauthenticity.org".into(),
];
let restricted_resolver =
RestrictedResolver::with_allowed_hosts(NoopHttpResolver, allowed_list);
assert_allowed_uri(&restricted_resolver, "fakecontentauthenticity.org");
assert_allowed_uri(&restricted_resolver, "test.prefix.contentauthenticity.org");
assert_allowed_uri(&restricted_resolver, "https://test.contentauthenticity.org");
assert_allowed_uri(
&restricted_resolver,
"https://test2.contentauthenticity.org",
);
assert_disallowed_uri(&restricted_resolver, "test.test.contentauthenticity.org");
assert_disallowed_uri(
&restricted_resolver,
"https://test.prefix.fakecontentauthenticity.org",
);
assert_disallowed_uri(
&restricted_resolver,
"https://test.fakecontentauthenticity.org",
);
assert_disallowed_uri(&restricted_resolver, "https://contentauthenticity.org");
}
#[test]
fn allowed_none_http_request() {
let allowed_list = vec![];
let restricted_resolver =
RestrictedResolver::with_allowed_hosts(NoopHttpResolver, allowed_list);
assert_disallowed_uri(
&restricted_resolver,
"test.test.fakecontentauthenticity.org",
);
assert_disallowed_uri(
&restricted_resolver,
"https://test.prefix.fakecontentauthenticity.org",
);
assert_disallowed_uri(
&restricted_resolver,
"https://test.fakecontentauthenticity.org",
);
assert_disallowed_uri(&restricted_resolver, "https://contentauthenticity.org");
}
#[test]
fn wildcard_pattern() {
let pattern = HostPattern::new("*.contentauthenticity.org");
let uri = Uri::from_static("test.contentauthenticity.org");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("contentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("fakecontentauthenticity.org");
assert!(!pattern.matches(&uri));
}
#[test]
fn wildcard_pattern_with_scheme() {
let pattern = HostPattern::new("https://*.contentauthenticity.org");
let uri = Uri::from_static("test.contentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("contentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("fakecontentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("https://test.contentauthenticity.org");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("https://contentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("https://fakecontentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("http://test.contentauthenticity.org");
assert!(!pattern.matches(&uri));
}
#[test]
fn case_insensitive_pattern() {
let pattern = HostPattern::new("*.contentAuthenticity.org");
let uri = Uri::from_static("tEst.conTentauthenticity.orG");
assert!(pattern.matches(&uri));
}
#[test]
fn exact_pattern() {
let pattern = HostPattern::new("contentauthenticity.org");
let uri = Uri::from_static("contentauthenticity.org");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("https://contentauthenticity.org");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("http://contentauthenticity.org");
assert!(pattern.matches(&uri));
}
#[test]
fn exact_pattern_with_schema() {
let pattern = HostPattern::new("https://contentauthenticity.org");
let uri = Uri::from_static("https://contentauthenticity.org");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("http://contentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("contentauthenticity.org");
assert!(!pattern.matches(&uri));
}
#[test]
fn exact_pattern_ip_address() {
let pattern = HostPattern::new("192.0.2.1");
let uri = Uri::from_static("192.0.2.1");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("192.0.2.1.1");
assert!(!pattern.matches(&uri));
}
#[test]
fn exact_pattern_ip_address_with_port() {
let pattern = HostPattern::new("192.0.2.1:443");
let uri = Uri::from_static("192.0.2.1:443");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("192.0.2.1");
assert!(!pattern.matches(&uri));
}
#[test]
fn exact_pattern_hostname_with_port() {
let pattern = HostPattern::new("contentauthenticity.org:8080");
let uri = Uri::from_static("contentauthenticity.org:8080");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("contentauthenticity.org");
assert!(!pattern.matches(&uri));
}
#[test]
fn scheme_only_pattern() {
let pattern = HostPattern::new("https://");
let uri = Uri::from_static("https://contentauthenticity.org");
assert!(pattern.matches(&uri));
let uri = Uri::from_static("http://contentauthenticity.org");
assert!(!pattern.matches(&uri));
let uri = Uri::from_static("contentauthenticity.org");
assert!(!pattern.matches(&uri));
}
#[test]
fn invalid_pattern() {
let pattern = HostPattern::new("https:// ");
let uri = Uri::from_static("https://contentauthenticity.org");
assert!(!pattern.matches(&uri));
}
#[test]
fn test_restricted_generic_resolver() {
use crate::http::{HttpResolverError, SyncGenericResolver, SyncHttpResolver};
let inner = SyncGenericResolver::new();
let mut resolver = RestrictedResolver::new(inner);
resolver.set_allowed_hosts(Some(vec!["127.0.0.1".into()]));
let request = http::Request::get("http://127.0.0.1/test")
.body(vec![])
.unwrap();
let result = resolver.http_resolve(request);
assert!(!matches!(
result,
Err(HttpResolverError::UriDisallowed { .. })
));
let request = http::Request::get("http://example.com/test")
.body(vec![])
.unwrap();
let result = resolver.http_resolve(request);
assert!(matches!(
result,
Err(HttpResolverError::UriDisallowed { .. })
));
}
#[cfg(not(target_arch = "wasm32"))]
#[tokio::test]
async fn test_restricted_async_generic_resolver() {
use crate::http::{AsyncGenericResolver, AsyncHttpResolver, HttpResolverError};
let inner = AsyncGenericResolver::new();
let mut resolver = RestrictedResolver::new(inner);
resolver.set_allowed_hosts(Some(vec!["127.0.0.1".into()]));
let request = http::Request::get("http://127.0.0.1/test")
.body(vec![])
.unwrap();
let result = resolver.http_resolve_async(request).await;
assert!(!matches!(
result,
Err(HttpResolverError::UriDisallowed { .. })
));
let request = http::Request::get("http://example.com/test")
.body(vec![])
.unwrap();
let result = resolver.http_resolve_async(request).await;
assert!(matches!(
result,
Err(HttpResolverError::UriDisallowed { .. })
));
}
}