/// Reason a WebSocket upgrade was refused by [`evaluate_ws_upgrade`].
///
/// Each variant corresponds to one of that function's checks; the function
/// reports the first check that fails.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsUpgradeRefusal {
    /// Returned when `is_tls_edge` is `false`; the origin is not examined.
    NotTls,
    /// Returned when no `Origin` value was supplied (`None`).
    OriginMissing,
    /// Returned when the `Origin` value contains a CR or LF byte, or does not
    /// exactly equal any allowlist entry.
    OriginRejected,
}
/// Decides whether a WebSocket upgrade request may proceed.
///
/// The checks run in a fixed order, and the first one that fails decides the
/// refusal that is reported:
///
/// 1. `is_tls_edge` must be `true`; otherwise [`WsUpgradeRefusal::NotTls`] is
///    returned without looking at `origin` at all.
/// 2. `origin` must be present; otherwise [`WsUpgradeRefusal::OriginMissing`].
/// 3. An origin containing a carriage return (`\r`) or line feed (`\n`) is
///    refused with [`WsUpgradeRefusal::OriginRejected`], even if the same
///    string appears in `allowlist`.
/// 4. The origin must be byte-for-byte equal to some entry of `allowlist`.
///    There is no prefix, suffix, or case-insensitive matching, so an empty
///    allowlist refuses every present origin with
///    [`WsUpgradeRefusal::OriginRejected`].
///
/// # Errors
///
/// Returns the [`WsUpgradeRefusal`] variant of the first failing check.
pub fn evaluate_ws_upgrade(
    is_tls_edge: bool,
    origin: Option<&str>,
    allowlist: &[String],
) -> Result<(), WsUpgradeRefusal> {
    if !is_tls_edge {
        return Err(WsUpgradeRefusal::NotTls);
    }

    let candidate = origin.ok_or(WsUpgradeRefusal::OriginMissing)?;

    // Refuse CR/LF-smuggled values before the allowlist lookup, so that such a
    // value is never admitted even if an identical string were allowlisted.
    if candidate.contains(&['\r', '\n'][..]) {
        return Err(WsUpgradeRefusal::OriginRejected);
    }

    let is_listed = allowlist.iter().any(|entry| entry.as_str() == candidate);
    if is_listed {
        Ok(())
    } else {
        Err(WsUpgradeRefusal::OriginRejected)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`evaluate_ws_upgrade`]: check ordering (TLS, then
    //! presence, then CR/LF, then allowlist), exact-match allowlist semantics,
    //! and CR/LF rejection.
    use super::*;

    /// Allowlist shared by most tests: two HTTPS origins.
    fn allowlist() -> Vec<String> {
        vec![
            "https://app.example.com".to_string(),
            "https://admin.example.com".to_string(),
        ]
    }

    #[test]
    fn allowed_origin_over_tls_is_accepted() {
        assert_eq!(
            evaluate_ws_upgrade(true, Some("https://app.example.com"), &allowlist()),
            Ok(())
        );
    }

    #[test]
    fn every_allowlist_entry_is_consulted() {
        // Guards against a lookup that only ever compares the first entry.
        assert_eq!(
            evaluate_ws_upgrade(true, Some("https://admin.example.com"), &allowlist()),
            Ok(())
        );
    }

    #[test]
    fn non_tls_edge_is_refused_before_origin_is_consulted() {
        // The TLS refusal must win no matter what the origin check would have
        // produced: acceptance, a missing origin, a rejection, or CR/LF.
        for origin in [
            Some("https://app.example.com"),
            None,
            Some("https://evil.example.net"),
            Some("https://app.example.com\r\nX-Injected: 1"),
        ] {
            assert_eq!(
                evaluate_ws_upgrade(false, origin, &allowlist()),
                Err(WsUpgradeRefusal::NotTls),
                "origin {:?}",
                origin
            );
        }
    }

    #[test]
    fn missing_origin_is_refused() {
        assert_eq!(
            evaluate_ws_upgrade(true, None, &allowlist()),
            Err(WsUpgradeRefusal::OriginMissing)
        );
        // A missing origin is reported as such even when nothing is allowlisted.
        assert_eq!(
            evaluate_ws_upgrade(true, None, &[]),
            Err(WsUpgradeRefusal::OriginMissing)
        );
    }

    #[test]
    fn empty_allowlist_denies_every_origin() {
        assert_eq!(
            evaluate_ws_upgrade(true, Some("https://app.example.com"), &[]),
            Err(WsUpgradeRefusal::OriginRejected)
        );
    }

    #[test]
    fn origin_match_is_exact_not_prefix_or_suffix() {
        for near_miss in [
            "https://app.example.com.evil.com", // allowlisted value used as a prefix
            "https://app.example.co",           // truncated allowlisted value
            "https://app.example.com/",         // trailing slash
            "https://app.example.com:443",      // explicit default port
            "http://app.example.com",           // downgraded scheme
            "",                                 // empty header value
        ] {
            assert_eq!(
                evaluate_ws_upgrade(true, Some(near_miss), &allowlist()),
                Err(WsUpgradeRefusal::OriginRejected),
                "origin {:?}",
                near_miss
            );
        }
    }

    #[test]
    fn crlf_smuggled_origin_is_rejected() {
        for smuggled in [
            "https://app.example.com\r\nX-Injected: 1",
            "https://app.example.com\n",
            "https://app.example.com\r",
        ] {
            assert_eq!(
                evaluate_ws_upgrade(true, Some(smuggled), &allowlist()),
                Err(WsUpgradeRefusal::OriginRejected),
                "origin {:?}",
                smuggled
            );
        }
    }

    #[test]
    fn crlf_origin_is_rejected_even_if_allowlisted_verbatim() {
        // The CR/LF guard runs before the allowlist lookup, so a tainted
        // allowlist entry cannot be used to admit a smuggled value.
        let smuggled = "https://app.example.com\r\nX-Injected: 1";
        let tainted = vec![smuggled.to_string()];
        assert_eq!(
            evaluate_ws_upgrade(true, Some(smuggled), &tainted),
            Err(WsUpgradeRefusal::OriginRejected)
        );
    }
}