use std::collections::HashSet;
/// Allowlist of SSH destinations (host + port), assembled with a builder-style API.
///
/// The `Default` value holds no host patterns and no explicit ports.
#[derive(Clone, Debug, Default)]
pub struct SshAllowlist {
    /// Host patterns: exact hostnames / IP literals, or `*.`-prefixed wildcards.
    patterns: HashSet<String>,
    /// Explicitly permitted ports; when empty, only port 22 is accepted.
    allowed_ports: HashSet<u16>,
    /// When set, every host on every port is accepted.
    allow_all: bool,
}
/// Verdict produced when an SSH destination is checked against an allowlist.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SshMatch {
    /// The connection is permitted.
    Allowed,
    /// The connection is refused.
    Blocked {
        /// Human-readable explanation of why the connection was refused.
        reason: String,
    },
}
impl SshAllowlist {
    /// Creates an empty allowlist.
    ///
    /// With no host patterns, [`check`](Self::check) blocks every host; with no explicit
    /// ports, only port 22 is accepted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an allowlist that permits every host on every port.
    ///
    /// Patterns and ports added afterwards have no effect on [`check`](Self::check), because
    /// the `allow_all` flag short-circuits it.
    pub fn allow_all() -> Self {
        Self {
            patterns: HashSet::new(),
            allowed_ports: HashSet::new(),
            allow_all: true,
        }
    }

    /// Adds a host pattern.
    ///
    /// A pattern is either an exact hostname / IP literal (`db.example.com`, `10.0.0.1`) or a
    /// `*.` wildcard (`*.example.com`). A wildcard matches hosts with one or more non-empty
    /// labels in front of the suffix, but not the bare suffix itself. Hostnames are compared
    /// ASCII case-insensitively.
    pub fn allow(mut self, pattern: impl Into<String>) -> Self {
        self.patterns.insert(pattern.into());
        self
    }

    /// Adds several host patterns at once; see [`allow`](Self::allow) for the syntax.
    pub fn allow_many(mut self, patterns: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.patterns.extend(patterns.into_iter().map(Into::into));
        self
    }

    /// Adds a permitted destination port.
    ///
    /// Once any port has been added, port 22 is no longer implicitly allowed; add it
    /// explicitly if it is still wanted.
    pub fn allow_port(mut self, port: u16) -> Self {
        self.allowed_ports.insert(port);
        self
    }

    /// Decides whether an SSH connection to `host:port` is permitted.
    ///
    /// The port is checked first, then the host against every pattern. A
    /// [`SshMatch::Blocked`] result carries a human-readable reason.
    pub fn check(&self, host: &str, port: u16) -> SshMatch {
        if self.allow_all {
            return SshMatch::Allowed;
        }
        if !self.is_port_allowed(port) {
            return SshMatch::Blocked {
                reason: format!("SSH port {} is not allowed", port),
            };
        }
        if self.patterns.is_empty() {
            return SshMatch::Blocked {
                reason: "no SSH hosts are allowed (empty allowlist)".to_string(),
            };
        }
        if self
            .patterns
            .iter()
            .any(|pattern| Self::matches_pattern(host, pattern))
        {
            SshMatch::Allowed
        } else {
            SshMatch::Blocked {
                reason: format!("SSH host '{}' is not in allowlist", host),
            }
        }
    }

    /// Convenience wrapper around [`check`](Self::check) returning only the boolean verdict.
    pub fn is_allowed(&self, host: &str, port: u16) -> bool {
        matches!(self.check(host, port), SshMatch::Allowed)
    }

    /// Returns `true` if this allowlist can permit anything, i.e. it either allows all
    /// destinations or has at least one host pattern.
    pub fn is_enabled(&self) -> bool {
        self.allow_all || !self.patterns.is_empty()
    }

    /// Returns `true` if `port` is permitted: any port under `allow_all`, otherwise the
    /// explicit port set, or port 22 alone when no ports were configured.
    fn is_port_allowed(&self, port: u16) -> bool {
        if self.allow_all {
            return true;
        }
        if self.allowed_ports.is_empty() {
            return port == 22;
        }
        self.allowed_ports.contains(&port)
    }

    /// Returns `true` if `host` matches a single allowlist `pattern`.
    ///
    /// DNS names are case-insensitive (RFC 4343), so all comparisons ignore ASCII case.
    /// A `*.suffix` pattern requires `host` to end in `.suffix` preceded by one or more
    /// non-empty labels. For example, `*.example.com` matches `a.example.com` and
    /// `a.b.example.com`, but not `example.com`, `.example.com`, or `a..example.com`.
    fn matches_pattern(host: &str, pattern: &str) -> bool {
        if host.eq_ignore_ascii_case(pattern) {
            return true;
        }
        let suffix = match pattern.strip_prefix("*.") {
            // An empty suffix (pattern "*.") would otherwise match any name ending in '.'.
            Some(suffix) if !suffix.is_empty() => suffix,
            _ => return false,
        };
        // Need room for at least one label character plus its dot ("x.") before the suffix.
        if host.len() < suffix.len() + 2 {
            return false;
        }
        let split = host.len() - suffix.len();
        // A split inside a multi-byte character cannot line up with the suffix; this also
        // keeps `split_at` from panicking on non-ASCII input.
        if !host.is_char_boundary(split) {
            return false;
        }
        let (prefix, tail) = host.split_at(split);
        if !tail.eq_ignore_ascii_case(suffix) {
            return false;
        }
        // The prefix must be "label." or "label.label." etc. with no empty labels.
        match prefix.strip_suffix('.') {
            Some(labels) => labels.split('.').all(|label| !label.is_empty()),
            None => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `list` permits `host:port`.
    #[track_caller]
    fn expect_allowed(list: &SshAllowlist, host: &str, port: u16) {
        assert_eq!(list.check(host, port), SshMatch::Allowed);
    }

    /// Fails the calling test unless `list` rejects `host:port` (for any reason).
    #[track_caller]
    fn expect_blocked(list: &SshAllowlist, host: &str, port: u16) {
        let verdict = list.check(host, port);
        assert!(
            matches!(verdict, SshMatch::Blocked { .. }),
            "expected Blocked for {}:{}, got {:?}",
            host,
            port,
            verdict
        );
    }

    #[test]
    fn test_empty_allowlist_blocks_all() {
        expect_blocked(&SshAllowlist::new(), "example.com", 22);
    }

    #[test]
    fn test_allow_all() {
        let list = SshAllowlist::allow_all();
        for &port in &[22u16, 2222] {
            expect_allowed(&list, "anything.com", port);
        }
    }

    #[test]
    fn test_exact_host_match() {
        let list = SshAllowlist::new().allow("db.supabase.co");
        expect_allowed(&list, "db.supabase.co", 22);
        for &host in &["other.supabase.co", "evil.com"] {
            expect_blocked(&list, host, 22);
        }
    }

    #[test]
    fn test_wildcard_pattern() {
        let list = SshAllowlist::new().allow("*.supabase.co");
        for &host in &[
            "db.supabase.co",
            "staging.supabase.co",
            "deep.nested.supabase.co",
        ] {
            expect_allowed(&list, host, 22);
        }
        for &host in &["supabase.co", "evil.com"] {
            expect_blocked(&list, host, 22);
        }
    }

    #[test]
    fn test_port_restriction_default() {
        let list = SshAllowlist::new().allow("example.com");
        expect_allowed(&list, "example.com", 22);
        expect_blocked(&list, "example.com", 2222);
    }

    #[test]
    fn test_port_restriction_custom() {
        let list = SshAllowlist::new()
            .allow("example.com")
            .allow_port(22)
            .allow_port(2222);
        for &port in &[22u16, 2222] {
            expect_allowed(&list, "example.com", port);
        }
        expect_blocked(&list, "example.com", 3333);
    }

    #[test]
    fn test_ip_address() {
        let list = SshAllowlist::new().allow("192.168.1.100");
        expect_allowed(&list, "192.168.1.100", 22);
        expect_blocked(&list, "192.168.1.101", 22);
    }

    #[test]
    fn test_multiple_patterns() {
        let list = SshAllowlist::new()
            .allow("*.supabase.co")
            .allow("bastion.example.com")
            .allow("10.0.0.1");
        for &host in &["db.supabase.co", "bastion.example.com", "10.0.0.1"] {
            expect_allowed(&list, host, 22);
        }
        expect_blocked(&list, "evil.com", 22);
    }

    #[test]
    fn test_is_enabled() {
        let cases = [
            (SshAllowlist::new(), false),
            (SshAllowlist::new().allow("x.com"), true),
            (SshAllowlist::allow_all(), true),
        ];
        for (list, expected) in cases.iter() {
            assert_eq!(list.is_enabled(), *expected);
        }
    }
}