use actix_cors::Cors;
use actix_web::http::header;
use actix_web::middleware::DefaultHeaders;
use std::collections::HashSet;
use tracing::info;
use tracing::warn;
/// Built-in Content-Security-Policy sent when `BAMBOO_CSP` is not set (or is invalid).
///
/// Everything defaults to the same origin; plugins (`object-src`) are disabled; the page
/// may not be framed (`frame-ancestors 'none'`); images may also come from `data:` and any
/// `https:` URL; fonts may also come from `data:`; connections may go to the same origin
/// and to any `ws:`/`wss:` endpoint. No `unsafe-inline`/`unsafe-eval` keywords are used.
const DEFAULT_CSP: &str = "default-src 'self'; \
                           base-uri 'self'; \
                           object-src 'none'; \
                           frame-ancestors 'none'; \
                           script-src 'self'; \
                           style-src 'self'; \
                           img-src 'self' data: https:; \
                           font-src 'self' data:; \
                           connect-src 'self' ws: wss:; \
                           form-action 'self';";
/// Turns the CSP override (or [`DEFAULT_CSP`] when `override_value` is `None`) into a
/// header value.
///
/// An override that is not a legal HTTP header value (for example one containing a
/// newline) is rejected with a warning and the built-in default is returned instead, so
/// a bad override never results in a missing CSP header.
fn resolve_csp_header_value(override_value: Option<&str>) -> header::HeaderValue {
    let requested = override_value.unwrap_or(DEFAULT_CSP);
    header::HeaderValue::from_str(requested).unwrap_or_else(|err| {
        warn!(
            "Invalid BAMBOO_CSP value ({}); falling back to DEFAULT_CSP",
            err
        );
        header::HeaderValue::from_static(DEFAULT_CSP)
    })
}
/// Extra CORS origins accepted on top of the built-in rules (see `BAMBOO_CORS_ALLOW_ORIGINS`).
#[derive(Clone, Debug, Default)]
struct CorsAllowlist {
    /// Origins (`scheme://host[:port]`) that must match the request origin exactly.
    exact_origins: HashSet<String>,
    /// Rules matched against the request origin's hostname only (any scheme, any port).
    hosts: Vec<HostPattern>,
}

/// A hostname rule from the CORS allowlist.
#[derive(Clone, Debug, Eq, PartialEq)]
enum HostPattern {
    /// The hostname must equal this value.
    Exact(String),
    /// The hostname must end with this value; stored with a leading dot
    /// (e.g. `.example.net`) so the bare parent domain itself does not match.
    Suffix(String),
}
/// Reduces an origin-like URL to the canonical `scheme://host[:port]` form.
///
/// The scheme and domain hosts are lowercased, IPv6 hosts are wrapped in `[...]`, and
/// everything after the authority (path, query, fragment) plus any userinfo is dropped.
/// A port equal to the default for `http` (80) or `https` (443) is omitted, so
/// `https://x:443` and `https://x` normalize identically.
///
/// Returns `None` when `origin` does not parse as a URL or has no host.
fn normalize_origin(origin: &str) -> Option<String> {
    let parsed = url::Url::parse(origin).ok()?;
    let scheme = parsed.scheme().to_ascii_lowercase();
    let host = match parsed.host()? {
        url::Host::Domain(name) => name.to_ascii_lowercase(),
        url::Host::Ipv4(addr) => addr.to_string(),
        url::Host::Ipv6(addr) => format!("[{addr}]"),
    };
    let implied_port: Option<u16> = match scheme.as_str() {
        "http" => Some(80),
        "https" => Some(443),
        _ => None,
    };
    // Keep only a port that differs from the scheme's implied default.
    let port_suffix = parsed
        .port()
        .filter(|&p| Some(p) != implied_port)
        .map(|p| format!(":{p}"))
        .unwrap_or_default();
    Some(format!("{scheme}://{host}{port_suffix}"))
}
/// Parses a comma-separated CORS allowlist (the `BAMBOO_CORS_ALLOW_ORIGINS` format).
///
/// Each trimmed, non-empty entry is one of:
/// - a full origin containing `://` (e.g. `https://app.example.com`), normalized with
///   [`normalize_origin`] and stored in `exact_origins`;
/// - a wildcard host `*.example.com`, stored as [`HostPattern::Suffix`] (`.example.com`),
///   which matches subdomains but not `example.com` itself;
/// - a bare host (e.g. `app.example.com`), stored lowercased as [`HostPattern::Exact`].
///
/// Entries that could never match are skipped with a warning instead of being silently
/// dropped or stored as dead patterns. These are unparseable origins, a bare `*`, a `*.`
/// with no domain, and host entries that include a path.
fn parse_cors_allowlist(raw: &str) -> CorsAllowlist {
    let mut allow = CorsAllowlist::default();
    for token in raw.split(',').map(str::trim).filter(|t| !t.is_empty()) {
        if token.contains("://") {
            match normalize_origin(token) {
                Some(origin) => {
                    allow.exact_origins.insert(origin);
                }
                None => {
                    warn!(
                        "Invalid CORS origin entry '{}'; expected an origin like https://app.example.com",
                        token
                    );
                }
            }
            continue;
        }

        let host = token.to_ascii_lowercase();
        if host == "*" {
            // A match-everything entry is intentionally unsupported; say so instead of
            // storing an `Exact("*")` pattern that never matches any hostname.
            warn!(
                "Ignoring CORS allowlist entry '*': allow-any is not supported; list origins or use a pattern like *.example.com"
            );
            continue;
        }
        if host.contains('/') {
            // Host patterns are compared against the bare hostname, so a path can never match.
            warn!(
                "Ignoring CORS host entry '{}': host entries must not contain a path; use a bare host or an origin like https://app.example.com",
                token
            );
            continue;
        }
        match host.strip_prefix("*.") {
            Some("") => {
                warn!(
                    "Ignoring CORS host entry '{}': wildcard needs a domain, e.g. *.example.com",
                    token
                );
            }
            Some(rest) => allow.hosts.push(HostPattern::Suffix(format!(".{rest}"))),
            None => allow.hosts.push(HostPattern::Exact(host)),
        }
    }
    allow
}
/// Reads `BAMBOO_CORS_ALLOW_ORIGINS` and parses it with [`parse_cors_allowlist`].
///
/// An unset (or non-Unicode) variable yields an empty allowlist. When at least one entry
/// is accepted, a one-line summary of the entry counts is logged.
fn parse_cors_allowlist_env() -> CorsAllowlist {
    const ENV_KEY: &str = "BAMBOO_CORS_ALLOW_ORIGINS";
    let Ok(raw) = std::env::var(ENV_KEY) else {
        return CorsAllowlist::default();
    };

    let allow = parse_cors_allowlist(&raw);
    let origin_count = allow.exact_origins.len();
    let host_count = allow.hosts.len();
    if origin_count + host_count > 0 {
        info!(
            "CORS allowlist enabled via BAMBOO_CORS_ALLOW_ORIGINS ({} exact origin(s), {} host pattern(s))",
            origin_count,
            host_count
        );
    }
    allow
}
/// Returns `true` when `origin` (a raw `Origin` header value) is accepted by `allow`.
///
/// The origin is accepted if:
/// - it equals an entry of `allow.exact_origins`, either verbatim or after
///   [`normalize_origin`]; or
/// - its hostname matches one of `allow.hosts` (scheme and port are ignored).
///
/// IPv6 hosts are compared without their `[...]` brackets, and as parsed addresses when
/// both sides are IP literals. An allowlist entry of `::1` or `[::1]` therefore matches
/// `http://[::1]:3000`.
fn is_allowed_by_allowlist(origin: &str, allow: &CorsAllowlist) -> bool {
    /// Removes the brackets around an IPv6 literal (`[::1]` -> `::1`); other hosts pass through.
    fn unbracket(host: &str) -> &str {
        host.strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host)
    }

    /// Compares two hostnames, treating equivalent IP-literal spellings as equal.
    fn same_host(a: &str, b: &str) -> bool {
        let (a, b) = (unbracket(a), unbracket(b));
        match (a.parse::<std::net::IpAddr>(), b.parse::<std::net::IpAddr>()) {
            (Ok(x), Ok(y)) => x == y,
            _ => a == b,
        }
    }

    if allow.exact_origins.contains(origin) {
        return true;
    }
    if let Some(normalized) = normalize_origin(origin) {
        if allow.exact_origins.contains(&normalized) {
            return true;
        }
    }

    let Ok(url) = url::Url::parse(origin) else {
        return false;
    };
    let Some(host) = url.host_str() else {
        return false;
    };
    // `host_str` renders IPv6 hosts as `[addr]`; `same_host` strips that before comparing.
    let host = host.to_ascii_lowercase();
    allow.hosts.iter().any(|pattern| match pattern {
        HostPattern::Exact(expected) => same_host(&host, expected),
        HostPattern::Suffix(suffix) => host.ends_with(suffix.as_str()),
    })
}
/// Returns `true` for `http`/`https` origins on `localhost`, `127.0.0.1` or `[::1]` that
/// carry an explicit port (e.g. a local dev server on `http://localhost:5173`).
///
/// Each prefix ends with `:`, so the host must be exactly one of those names; an origin
/// without a port (`http://localhost`) is not matched.
fn is_local_dev_origin(o: &str) -> bool {
    const LOOPBACK_PREFIXES: [&str; 6] = [
        "http://localhost:",
        "http://127.0.0.1:",
        "https://localhost:",
        "https://127.0.0.1:",
        "http://[::1]:",
        "https://[::1]:",
    ];
    LOOPBACK_PREFIXES
        .iter()
        .any(|&prefix| o.starts_with(prefix))
}
/// Builds the [`DefaultHeaders`] middleware that attaches baseline security headers to
/// every response.
///
/// The Content-Security-Policy is [`DEFAULT_CSP`] unless the `BAMBOO_CSP` environment
/// variable overrides it. An override that is not a valid header value falls back to the
/// default (see [`resolve_csp_header_value`]).
///
/// `X-XSS-Protection` is explicitly disabled (`0`). The legacy XSS auditor that
/// `1; mode=block` turned on has been removed from current browsers. In the old browsers
/// that still honor it, it can itself be abused to create XSS and cross-site leak issues.
/// The CSP is the actual XSS mitigation here.
pub fn build_security_headers() -> DefaultHeaders {
    let csp_override = std::env::var("BAMBOO_CSP").ok();
    let csp_value = resolve_csp_header_value(csp_override.as_deref());
    DefaultHeaders::new()
        .add(("X-Frame-Options", "DENY"))
        .add(("X-Content-Type-Options", "nosniff"))
        .add(("X-XSS-Protection", "0"))
        .add(("Referrer-Policy", "strict-origin-when-cross-origin"))
        .add((header::CONTENT_SECURITY_POLICY, csp_value))
}
/// Coarse classification of the server's bind address, used to pick the CORS origin policy.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BindScope {
    /// Loopback bind: `localhost`, `127.0.0.1` or `::1`.
    Loopback,
    /// Wildcard bind on every interface: `0.0.0.0` or `::`.
    AllInterfaces,
    /// Any other host or IP, stored lowercase and without IPv6 brackets.
    Custom(String),
}

/// Removes the brackets around an IPv6 literal (`[::1]` -> `::1`); other hosts pass through.
fn strip_host_brackets(host: &str) -> &str {
    host.strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host)
}

/// Classifies `bind_addr` (a hostname or IP literal, optionally bracketed) into a
/// [`BindScope`]. IP literals are compared as addresses, so e.g. `0:0:0:0:0:0:0:1` is
/// recognized as `::1` and `[::]` as the IPv6 wildcard.
fn classify_bind_addr(bind_addr: &str) -> BindScope {
    let host = strip_host_brackets(bind_addr.trim()).to_ascii_lowercase();
    let ip: Option<std::net::IpAddr> = host.parse().ok();
    match ip {
        Some(addr) if addr.is_unspecified() => BindScope::AllInterfaces,
        Some(addr)
            if addr == std::net::Ipv4Addr::LOCALHOST || addr == std::net::Ipv6Addr::LOCALHOST =>
        {
            BindScope::Loopback
        }
        None if host == "localhost" => BindScope::Loopback,
        _ => BindScope::Custom(host),
    }
}

/// Exact-match check for the fixed Tauri app origins accepted on local binds.
fn is_tauri_origin(origin: &str) -> bool {
    matches!(
        origin,
        "tauri://localhost" | "https://tauri.localhost" | "http://tauri.localhost"
    )
}

/// Returns `true` when the host of `origin` equals `bind_host` (lowercase, unbracketed).
/// Hostnames compare case-insensitively; IP literals compare as addresses so that
/// equivalent IPv6 spellings match (`Url::host_str` renders IPv6 hosts in brackets).
fn origin_host_matches_bind(origin: &str, bind_host: &str) -> bool {
    let Ok(url) = url::Url::parse(origin) else {
        return false;
    };
    let Some(host) = url.host_str() else {
        return false;
    };
    let host = strip_host_brackets(host);
    match (
        host.parse::<std::net::IpAddr>(),
        bind_host.parse::<std::net::IpAddr>(),
    ) {
        (Ok(a), Ok(b)) => a == b,
        _ => host.eq_ignore_ascii_case(bind_host),
    }
}

/// Builds the CORS middleware for a server bound to `bind_addr`:`port`.
///
/// Every bind also accepts origins listed in `BAMBOO_CORS_ALLOW_ORIGINS`. Beyond that:
/// - loopback binds (`127.0.0.1`, `localhost`, `::1`) and wildcard binds (`0.0.0.0`, `::`)
///   accept localhost/loopback origins with an explicit port ([`is_local_dev_origin`])
///   and the Tauri app origins;
/// - any other bind address accepts origins whose host equals that address.
///
/// Loopback binds deliberately do NOT allow every origin. Binding to `127.0.0.1` limits
/// which machines can connect, not which web pages the local browser is running. A
/// wildcard origin would let any website the user visits call this API and read the
/// responses. Extra front-end origins belong in `BAMBOO_CORS_ALLOW_ORIGINS`.
///
/// Accepted origins may use any method and header; preflight responses advertise a
/// `max_age` of 3600 seconds.
pub fn build_cors(bind_addr: &str, port: u16) -> Cors {
    let allowlist = parse_cors_allowlist_env();
    let scope = classify_bind_addr(bind_addr);
    match &scope {
        BindScope::Loopback => {
            info!(
                "CORS configured for loopback bind {}:{}: allowing localhost/loopback and Tauri origins (+ optional allowlist)",
                bind_addr, port
            );
        }
        BindScope::AllInterfaces => {
            info!(
                "CORS configured for all-interfaces bind {}:{}: allowing localhost/loopback and Tauri origins (+ optional allowlist)",
                bind_addr, port
            );
        }
        BindScope::Custom(_) => {
            info!(
                "CORS configured for custom bind address: {}:{} (+ optional allowlist)",
                bind_addr, port
            );
        }
    }

    Cors::default()
        .allowed_origin_fn(move |origin, _req_head| {
            // `to_str` fails unless the header is visible ASCII; such values are rejected.
            let Ok(o) = origin.to_str() else {
                return false;
            };
            if is_allowed_by_allowlist(o, &allowlist) {
                return true;
            }
            match &scope {
                BindScope::Loopback | BindScope::AllInterfaces => {
                    is_local_dev_origin(o) || is_tauri_origin(o)
                }
                BindScope::Custom(bind_host) => origin_host_matches_bind(o, bind_host),
            }
        })
        .allow_any_method()
        .allow_any_header()
        .max_age(3600)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_csp_has_no_unsafe_keywords() {
        for forbidden in ["unsafe-inline", "unsafe-eval"] {
            assert!(
                !DEFAULT_CSP.contains(forbidden),
                "DEFAULT_CSP must not contain {forbidden}"
            );
        }
    }

    #[test]
    fn invalid_override_falls_back_to_default() {
        // A newline is not a legal header-value byte, so this override must be rejected.
        let resolved = resolve_csp_header_value(Some("default-src 'self'\nscript-src 'self'"));
        assert_eq!(resolved, header::HeaderValue::from_static(DEFAULT_CSP));
    }

    #[test]
    fn cors_allowlist_parses_hosts_and_origins() {
        let parsed = parse_cors_allowlist(
            "https://app.example.com/, app.example2.com, *.example.net , http://localhost:5173",
        );

        for origin in ["https://app.example.com", "http://localhost:5173"] {
            assert!(
                parsed.exact_origins.contains(origin),
                "missing exact origin {origin}"
            );
        }

        let expected_hosts = [
            HostPattern::Exact("app.example2.com".to_string()),
            HostPattern::Suffix(".example.net".to_string()),
        ];
        for pattern in &expected_hosts {
            assert!(
                parsed.hosts.contains(pattern),
                "missing host pattern {pattern:?}"
            );
        }
    }

    #[test]
    fn cors_allowlist_matches_exact_and_wildcard_hosts() {
        let allowlist = CorsAllowlist {
            exact_origins: HashSet::from(["https://app.example.com".to_string()]),
            hosts: vec![
                HostPattern::Exact("app2.example.com".to_string()),
                HostPattern::Suffix(".example.net".to_string()),
            ],
        };

        for accepted in [
            "https://app.example.com",
            "https://app.example.com:443",
            "http://app2.example.com:5173",
            "https://x.example.net",
        ] {
            assert!(
                is_allowed_by_allowlist(accepted, &allowlist),
                "expected {accepted} to be allowed"
            );
        }

        for rejected in ["https://example.net", "https://evil.com"] {
            assert!(
                !is_allowed_by_allowlist(rejected, &allowlist),
                "expected {rejected} to be rejected"
            );
        }
    }
}