use std::collections::HashSet;
use url::Url;
use crate::wac::AclAuthorization;
/// A web origin held as its serialised string, e.g. `https://example.org:8443`.
///
/// Values produced by [`Origin::from_url`] have a lowercase scheme and host and
/// carry a `:port` only when the URL reports an explicit non-default port, so
/// equality and hashing compare origins by that serialised text.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Origin(String);
impl Origin {
    /// Parses `s` as an absolute URL and returns that URL's origin.
    ///
    /// Leading and trailing whitespace is ignored. Returns `None` when the trimmed
    /// input is empty, is not a parseable URL, or the URL has no host (see
    /// [`Origin::from_url`]).
    pub fn parse(s: &str) -> Option<Self> {
        match s.trim() {
            "" => None,
            candidate => Url::parse(candidate)
                .ok()
                .and_then(|url| Self::from_url(&url)),
        }
    }

    /// Builds the origin of an already-parsed URL as `scheme://host[:port]`.
    ///
    /// Scheme and host are ASCII-lowercased. The `:port` part is appended only
    /// when `Url::port` returns `Some`, i.e. when the URL names a port other than
    /// the scheme's known default. Returns `None` when the URL has no host
    /// (for example a `data:` URL).
    pub fn from_url(url: &Url) -> Option<Self> {
        let host = url.host_str()?.to_ascii_lowercase();
        let scheme = url.scheme().to_ascii_lowercase();
        let mut serialised = format!("{scheme}://{host}");
        if let Some(explicit_port) = url.port() {
            serialised.push(':');
            serialised.push_str(&explicit_port.to_string());
        }
        Some(Origin(serialised))
    }

    /// Borrows the serialised origin text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Writes the serialised origin text verbatim (width/fill flags are not applied).
impl std::fmt::Display for Origin {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = self.0.as_str();
        f.write_str(text)
    }
}
/// One parsed `acl:origin` entry, used to decide whether a request origin is allowed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OriginPattern {
    /// Matches exactly one serialised origin.
    Exact(Origin),
    /// Matches hosts that are strict subdomains of `suffix` over `scheme`.
    ///
    /// `suffix` is the lowercased text after the `*.` prefix and may still carry
    /// a `:port` as written in the ACL.
    Wildcard { scheme: String, suffix: String },
    /// The `*` pattern, which matches every origin.
    Any,
}
impl OriginPattern {
    /// Parses a single `acl:origin` value.
    ///
    /// Accepted forms (surrounding whitespace is ignored):
    /// * `*` — [`OriginPattern::Any`].
    /// * `https://*.<suffix>` or `http://*.<suffix>`, with the scheme matched
    ///   ASCII-case-insensitively — [`OriginPattern::Wildcard`]. `<suffix>` may end
    ///   in `:<port>`.
    /// * An origin already in canonical form (no path, no trailing slash, no
    ///   explicit default port; letter case is ignored) — [`OriginPattern::Exact`].
    ///
    /// Returns `None` for empty input or any other shape.
    pub fn parse(s: &str) -> Option<Self> {
        let trimmed = s.trim();
        if trimmed.is_empty() {
            return None;
        }
        if trimmed == "*" {
            return Some(OriginPattern::Any);
        }
        // The exact-origin path below compares case-insensitively, so the wildcard
        // prefixes must too; otherwise `HTTPS://*.example.org` would fall through
        // and become an exact pattern for a literal `*` host.
        if let Some(rest) = Self::strip_prefix_ignore_ascii_case(trimmed, "https://*.") {
            return Self::parse_wildcard("https", rest);
        }
        if let Some(rest) = Self::strip_prefix_ignore_ascii_case(trimmed, "http://*.") {
            return Self::parse_wildcard("http", rest);
        }
        // `Origin::parse` would silently drop a trailing slash; reject it so only
        // canonical origins are accepted.
        if trimmed.ends_with('/') {
            return None;
        }
        let origin = Origin::parse(trimmed)?;
        // Require the input to already be canonical: an explicit default port or
        // a path makes the serialised origin differ from the input text.
        let lc = trimmed.to_ascii_lowercase();
        if origin.as_str() != lc {
            return None;
        }
        Some(OriginPattern::Exact(origin))
    }

    /// Builds a wildcard pattern from the text following `*.`.
    ///
    /// Rejects an empty suffix, or one containing whitespace, another `*`, or a
    /// `/`. The stored suffix is ASCII-lowercased.
    fn parse_wildcard(scheme: &str, suffix_part: &str) -> Option<Self> {
        if suffix_part.is_empty() {
            return None;
        }
        if suffix_part.contains(char::is_whitespace) || suffix_part.contains('*') {
            return None;
        }
        if suffix_part.contains('/') {
            return None;
        }
        Some(OriginPattern::Wildcard {
            scheme: scheme.to_string(),
            suffix: suffix_part.to_ascii_lowercase(),
        })
    }

    /// Reports whether `origin` is allowed by this pattern.
    ///
    /// * `Any` matches every origin.
    /// * `Exact` compares serialised origins for equality.
    /// * `Wildcard` requires the same scheme and a host that is a strict subdomain
    ///   of the pattern's host suffix; the bare suffix itself does not match.
    ///   If the pattern names no port, any port is accepted. If it names one, the
    ///   request's port must equal it, where the scheme's default port (80 for
    ///   `http`, 443 for `https`) corresponds to an origin without a port.
    ///   A pattern with an empty host suffix or a non-numeric/out-of-range port
    ///   matches nothing.
    pub fn matches(&self, origin: &Origin) -> bool {
        match self {
            OriginPattern::Any => true,
            OriginPattern::Exact(expected) => expected == origin,
            OriginPattern::Wildcard { scheme, suffix } => {
                Self::wildcard_matches(scheme, suffix, origin.as_str())
            }
        }
    }

    /// Wildcard matching against a serialised origin (`scheme://host[:port]`).
    fn wildcard_matches(scheme: &str, suffix: &str, serialised: &str) -> bool {
        let (req_scheme, req_authority) = match serialised.split_once("://") {
            Some(parts) => parts,
            None => return false,
        };
        if req_scheme != scheme {
            return false;
        }
        let (req_host, req_port) = Self::split_host_port(req_authority);
        let (pattern_host, pattern_port) = Self::split_host_port(suffix);
        if pattern_host.is_empty() {
            return false;
        }
        // Enforce the pattern's port when one is written: an ACL author who lists
        // `https://*.example.org:8443` means that port only.
        if let Some(port_text) = pattern_port {
            let default = Self::default_port(scheme);
            // Compare effective ports: a serialised origin omits the scheme's
            // default port, so both sides map the default to `None`.
            let wanted = match port_text.parse::<u16>() {
                Ok(p) if Some(p) == default => None,
                Ok(p) => Some(p),
                Err(_) => return false,
            };
            let actual = match req_port.map(|p| p.parse::<u16>()) {
                None => None,
                Some(Ok(p)) if Some(p) == default => None,
                Some(Ok(p)) => Some(p),
                Some(Err(_)) => return false,
            };
            if wanted != actual {
                return false;
            }
        }
        // Require a leading dot so the apex (`example.org`) and look-alikes
        // (`evilexample.org`) do not match.
        let needle = format!(".{pattern_host}");
        req_host.len() > needle.len() && req_host.ends_with(&needle)
    }

    /// Splits `host[:port]` into its host and optional port text.
    ///
    /// Only a non-empty, all-digit tail after the last `:` counts as a port, so a
    /// bracketed IPv6 literal without a port (`[::1]`) is left intact.
    fn split_host_port(authority: &str) -> (&str, Option<&str>) {
        match authority.rsplit_once(':') {
            Some((host, port))
                if !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()) =>
            {
                (host, Some(port))
            }
            _ => (authority, None),
        }
    }

    /// Default port of the schemes a wildcard pattern can carry.
    fn default_port(scheme: &str) -> Option<u16> {
        match scheme {
            "http" => Some(80),
            "https" => Some(443),
            _ => None,
        }
    }

    /// `str::strip_prefix`, but comparing the prefix ASCII-case-insensitively.
    fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
        let head = s.get(..prefix.len())?;
        if head.eq_ignore_ascii_case(prefix) {
            s.get(prefix.len()..)
        } else {
            None
        }
    }
}
/// Collects every `acl:origin` value of `auth` that parses as an [`OriginPattern`].
///
/// Values rejected by [`OriginPattern::parse`] are skipped. An authorization
/// without an `origin` yields an empty vector. Order follows the declared ids.
pub fn extract_origin_patterns(auth: &AclAuthorization) -> Vec<OriginPattern> {
    match &auth.origin {
        Some(declared) => iter_ids(declared)
            .into_iter()
            .filter_map(OriginPattern::parse)
            .collect(),
        None => Vec::new(),
    }
}
fn iter_ids(ids: &crate::wac::IdOrIds) -> Vec<&str> {
match ids {
crate::wac::IdOrIds::Single(r) => vec![r.id.as_str()],
crate::wac::IdOrIds::Multiple(v) => v.iter().map(|r| r.id.as_str()).collect(),
}
}
/// Outcome of evaluating an ACL's origin restrictions for one request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OriginDecision {
    /// No origin restriction applies.
    NoPolicySet,
    /// An origin restriction applies and the request origin matched it.
    Permitted,
    /// An origin restriction applies, the request supplied an origin, and it did not match.
    RejectedMismatch,
    /// An origin restriction applies but the request supplied no origin.
    RejectedNoOrigin,
}
pub fn check_origin(
acl: &crate::wac::AclDocument,
request_origin: Option<&Origin>,
) -> OriginDecision {
let graph = match acl.graph.as_ref() {
Some(g) => g,
None => return OriginDecision::NoPolicySet,
};
let mut any_patterns = false;
let mut matched = false;
let mut seen: HashSet<String> = HashSet::new();
for auth in graph {
for pattern in extract_origin_patterns(auth) {
let key = pattern_key(&pattern);
if !seen.insert(key) {
continue;
}
any_patterns = true;
if let Some(req) = request_origin {
if pattern.matches(req) {
matched = true;
}
}
}
}
if !any_patterns {
OriginDecision::NoPolicySet
} else if matched {
OriginDecision::Permitted
} else if request_origin.is_none() {
OriginDecision::RejectedNoOrigin
} else {
OriginDecision::RejectedMismatch
}
}
/// Builds a de-duplication key for a pattern.
///
/// The distinct prefixes (`*`, `exact:`, `wild:`) keep the three variants from
/// colliding with one another.
fn pattern_key(p: &OriginPattern) -> String {
    match p {
        OriginPattern::Any => String::from("*"),
        OriginPattern::Exact(origin) => ["exact:", origin.as_str()].concat(),
        OriginPattern::Wildcard { scheme, suffix } => {
            let mut key = String::from("wild:");
            key.push_str(scheme);
            key.push_str("://*.");
            key.push_str(suffix);
            key
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as an origin; the test inputs are all expected to succeed.
    fn origin(input: &str) -> Origin {
        Origin::parse(input).unwrap()
    }

    #[test]
    fn origin_parse_strips_default_https_port() {
        assert_eq!(
            origin("https://example.org:443/foo").as_str(),
            "https://example.org"
        );
    }

    #[test]
    fn origin_parse_preserves_non_default_port() {
        let parsed = origin("https://example.org:8443/foo");
        assert_eq!(parsed.as_str(), "https://example.org:8443");
    }

    #[test]
    fn origin_parse_lowercases_host_and_scheme() {
        assert_eq!(origin("HTTPS://Example.ORG").as_str(), "https://example.org");
    }

    #[test]
    fn origin_parse_rejects_empty_and_opaque() {
        for rejected in ["", "not a url", "data:text/plain,hello"] {
            assert!(Origin::parse(rejected).is_none());
        }
    }

    #[test]
    fn pattern_any_matches_everything() {
        let any = OriginPattern::parse("*").unwrap();
        for candidate in ["https://example.org", "http://foo.test:9000"] {
            assert!(any.matches(&origin(candidate)));
        }
    }

    #[test]
    fn pattern_exact_requires_canonical_input() {
        assert!(OriginPattern::parse("https://example.org/").is_none());
        if let OriginPattern::Exact(o) = OriginPattern::parse("https://example.org").unwrap() {
            assert_eq!(o.as_str(), "https://example.org");
        } else {
            panic!("expected Exact");
        }
    }

    #[test]
    fn pattern_wildcard_rejects_bare_apex() {
        let wildcard = OriginPattern::parse("https://*.example.org").unwrap();
        assert!(!wildcard.matches(&origin("https://example.org")));
        assert!(wildcard.matches(&origin("https://app.example.org")));
        assert!(wildcard.matches(&origin("https://a.b.example.org:8443")));
    }
}