use super::{ProxyAction, ProxyModule, server::SipServerRef};
use crate::call::{TransactionCookie, TrunkContext};
use crate::{
config::ProxyConfig,
proxy::routing::{TrunkConfig, source_addr_ip},
};
use anyhow::Result;
use async_trait::async_trait;
use rsipstack::sip::prelude::HeadersExt;
use rsipstack::transaction::transaction::Transaction;
use std::{
collections::{HashMap, HashSet},
net::IpAddr,
str::FromStr,
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
/// A CIDR network: a base address plus the number of leading bits (`prefix_len`)
/// an address must share with it to be considered inside.
#[derive(Debug, Clone)]
struct IpNetwork {
    network: IpAddr,
    prefix_len: u8,
}

impl IpNetwork {
    /// Creates a network from a base address and prefix length.
    ///
    /// The prefix is not validated here. [`IpNetwork::contains`] clamps it to the
    /// address width (32 for IPv4, 128 for IPv6), so an out-of-range value behaves as a
    /// full-length (single host) prefix instead of overflowing a shift.
    fn new(network: IpAddr, prefix_len: u8) -> Self {
        Self {
            network,
            prefix_len,
        }
    }

    /// Returns `true` if `ip` falls inside this network.
    ///
    /// Only the leading `prefix_len` bits are compared, so a prefix of 0 matches every
    /// address of the same family. An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, which
    /// is typically how dual-stack sockets report IPv4 peers) is compared against IPv4
    /// networks using its embedded IPv4 address. Any other cross-family comparison is
    /// `false`.
    fn contains(&self, ip: &IpAddr) -> bool {
        // Unwrap IPv4-mapped IPv6 peers so IPv4 rules (including denies) still apply.
        let candidate = match (self.network, *ip) {
            (IpAddr::V4(_), IpAddr::V6(v6)) => v6.to_ipv4_mapped().map_or(*ip, IpAddr::V4),
            _ => *ip,
        };
        match (self.network, candidate) {
            (IpAddr::V4(network), IpAddr::V4(addr)) => {
                let mask = Self::v4_mask(self.prefix_len);
                (u32::from(network) & mask) == (u32::from(addr) & mask)
            }
            (IpAddr::V6(network), IpAddr::V6(addr)) => {
                let mask = Self::v6_mask(self.prefix_len);
                (u128::from(network) & mask) == (u128::from(addr) & mask)
            }
            _ => false,
        }
    }

    /// IPv4 netmask with the top `prefix_len` bits set (prefix clamped to 32).
    fn v4_mask(prefix_len: u8) -> u32 {
        let prefix = u32::from(prefix_len.min(32));
        // Shifting by the full width (prefix 0) yields `None`, i.e. an all-zero mask.
        u32::MAX.checked_shl(32 - prefix).unwrap_or(0)
    }

    /// IPv6 netmask with the top `prefix_len` bits set (prefix clamped to 128).
    fn v6_mask(prefix_len: u8) -> u128 {
        let prefix = u32::from(prefix_len.min(128));
        // Shifting by the full width (prefix 0) yields `None`, i.e. an all-zero mask.
        u128::MAX.checked_shl(128 - prefix).unwrap_or(0)
    }
}
/// Verdict applied when an [`AclRule`] matches a source address.
#[derive(Debug, Clone)]
enum AclAction {
    Allow,
    Deny,
}

/// One parsed ACL line such as `allow 10.0.0.0/8` or `deny all`.
#[derive(Debug, Clone)]
struct AclRule {
    action: AclAction,
    /// `None` means the rule targets `all` and matches every address.
    network: Option<IpNetwork>,
}

impl AclRule {
    /// Parses a rule of the form `<allow|deny> <all|ip|ip/prefix>`.
    ///
    /// Both the action and the `all` keyword are case-insensitive. Previously only the
    /// action was, so a rule like `deny ALL` was silently dropped. Tokens after the
    /// second one are ignored.
    ///
    /// Returns `None` when a token is missing, the action is unknown, or the address /
    /// network cannot be parsed by `parse_network`.
    fn new(rule: &str) -> Option<Self> {
        let mut tokens = rule.split_whitespace();
        let verb = tokens.next()?;
        let target = tokens.next()?;
        let action = match verb.to_lowercase().as_str() {
            "allow" => AclAction::Allow,
            "deny" => AclAction::Deny,
            _ => return None,
        };
        let network = if target.eq_ignore_ascii_case("all") {
            None
        } else {
            let (network, prefix_len) = parse_network(target).ok()?;
            Some(IpNetwork::new(network, prefix_len))
        };
        Some(Self { action, network })
    }
}
/// Per-source-IP DoS tracking state, stored in `AclModuleInner::dos_data`.
struct DosPerIpData {
/// Timestamps of requests admitted by `dos_check_and_track`; entries older than one
/// second are pruned on each check.
recent: Vec<Instant>,
/// Transactions admitted by `dos_check_and_track` and not yet released by `dos_release`.
concurrent: usize,
/// Set when the CPS or scan threshold trips; new requests are rejected until this instant.
blocked_until: Option<Instant>,
}
/// State shared by every clone of [`AclModule`].
struct AclModuleInner {
/// Proxy configuration: UA lists, fallback ACL rules, static trunks, URI and DoS limits.
config: Arc<ProxyConfig>,
/// When present, ACL rules and trunks are read from the server's data context
/// (static `config` trunks are still consulted afterwards).
server: Option<SipServerRef>,
/// Exact-match User-Agent allow list; when empty, every non-blacklisted UA is allowed.
ua_white_list: HashSet<String>,
/// Exact-match User-Agent deny list; checked before the allow list.
ua_black_list: HashSet<String>,
/// Raw rule strings used when no server is attached (`config.acl_rules`, or
/// `allow all` / `deny all` by default).
fallback_rules: Vec<String>,
/// Per-IP DoS counters. NOTE(review): nothing in this module removes entries, so the
/// map grows with the number of distinct source IPs seen.
dos_data: Arc<RwLock<HashMap<IpAddr, DosPerIpData>>>,
}
/// SIP proxy module that screens new transactions by User-Agent lists, From-URI sanity,
/// per-IP DoS limits, trunk source IPs and ordered IP ACL rules.
/// Cloning is cheap: clones share `inner` through an `Arc`.
#[derive(Clone)]
pub struct AclModule {
inner: Arc<AclModuleInner>,
}
impl AclModule {
    /// Creates the ACL module bound to `server` and boxes it as a [`ProxyModule`].
    ///
    /// With a server attached, ACL rules and trunks are read from the server's data
    /// context on each lookup.
    ///
    /// # Errors
    /// This implementation always returns `Ok`.
    pub fn create(server: SipServerRef, config: Arc<ProxyConfig>) -> Result<Box<dyn ProxyModule>> {
        let module = AclModule::with_server(config, Some(server));
        Ok(Box::new(module))
    }

    /// Shared constructor. Snapshots the User-Agent lists and fallback ACL rules from
    /// `config` and starts with empty DoS tracking state.
    fn with_server(config: Arc<ProxyConfig>, server: Option<SipServerRef>) -> Self {
        let fallback_rules = resolve_base_rules(&config);
        let ua_white_list = config
            .ua_white_list
            .as_ref()
            .map_or_else(HashSet::new, |list| list.iter().cloned().collect());
        let ua_black_list = config
            .ua_black_list
            .as_ref()
            .map_or_else(HashSet::new, |list| list.iter().cloned().collect());
        Self {
            inner: Arc::new(AclModuleInner {
                config,
                server,
                ua_white_list,
                ua_black_list,
                fallback_rules,
                dos_data: Arc::new(RwLock::new(HashMap::new())),
            }),
        }
    }

    /// Creates a module without a server. ACL rules and trunk lookups then use only
    /// `config` (`acl_rules`, or the built-in `allow all` / `deny all` default).
    pub fn new(config: Arc<ProxyConfig>) -> Self {
        Self::with_server(config, None)
    }

    /// Returns the remote peer IP of `tx`. Returns `None` if the transaction has no
    /// connection or its remote address yields no IP.
    fn extract_ip(tx: &Transaction) -> Option<IpAddr> {
        tx.connection
            .as_ref()
            .and_then(|conn| conn.get_remote_addr())
            .and_then(source_addr_ip)
    }

    /// Admits or rejects one new transaction from `ip` under the DoS limits in
    /// `config`, recording it when admitted.
    ///
    /// Checks, in order:
    /// 1. an active block. An expired block resets the IP's counters;
    /// 2. requests admitted in the last second >= `dos_max_cps_per_ip`, which blocks
    ///    the IP for `dos_scan_block_duration_secs`;
    /// 3. in-flight transactions >= `dos_max_concurrent_per_ip` (rejected, not blocked);
    /// 4. requests admitted in the last second >= `dos_scan_probe_threshold`, which
    ///    also blocks the IP.
    ///
    /// On success, the request time is recorded and the concurrency counter is
    /// incremented. The slot is given back by [`AclModule::dos_release`].
    ///
    /// # Errors
    /// Returns an error naming the limit that rejected the request.
    async fn dos_check_and_track(&self, ip: IpAddr) -> Result<()> {
        // Width of the sliding window shared by the CPS and scan checks.
        const WINDOW: Duration = Duration::from_secs(1);

        let cfg = &self.inner.config;
        let now = Instant::now();
        let block_duration = Duration::from_secs(cfg.dos_scan_block_duration_secs);

        // NOTE(review): entries are never evicted, so this map grows with every distinct
        // source IP. Consider an amortized sweep of idle entries.
        let mut map = self.inner.dos_data.write().await;
        let entry = map.entry(ip).or_insert_with(|| DosPerIpData {
            recent: Vec::new(),
            concurrent: 0,
            blocked_until: None,
        });

        if let Some(until) = entry.blocked_until {
            if now < until {
                return Err(anyhow::anyhow!("IP blocked until {:?}", until));
            }
            // Block expired: start over. NOTE(review): this also zeroes `concurrent`,
            // although transactions admitted before the block may still be in flight;
            // their later `dos_release` calls saturate at zero.
            entry.recent.clear();
            entry.concurrent = 0;
            entry.blocked_until = None;
        }

        // Keep only timestamps inside the window. This is equivalent to `t > now - WINDOW`,
        // but `Instant - Duration` panics when the result is not representable, whereas
        // `saturating_duration_since` cannot.
        entry
            .recent
            .retain(|seen| now.saturating_duration_since(*seen) < WINDOW);

        if entry.recent.len() >= cfg.dos_max_cps_per_ip as usize {
            entry.blocked_until = Some(now + block_duration);
            return Err(anyhow::anyhow!("CPS limit exceeded"));
        }
        if entry.concurrent >= cfg.dos_max_concurrent_per_ip as usize {
            return Err(anyhow::anyhow!("Concurrent limit exceeded"));
        }
        if entry.recent.len() >= cfg.dos_scan_probe_threshold as usize {
            entry.blocked_until = Some(now + block_duration);
            return Err(anyhow::anyhow!("Scan detected"));
        }

        entry.recent.push(now);
        entry.concurrent += 1;
        Ok(())
    }

    /// Releases one concurrency slot for `ip`. Saturates at zero and does nothing for
    /// IPs without a tracking entry.
    async fn dos_release(&self, ip: IpAddr) {
        let mut map = self.inner.dos_data.write().await;
        if let Some(entry) = map.get_mut(&ip) {
            entry.concurrent = entry.concurrent.saturating_sub(1);
        }
    }

    /// Validates the `From` header when `uri_reject_malformed` is enabled. Rejects a
    /// missing/malformed header, a malformed URI, or a URI longer than `uri_max_length`.
    /// Only the `From` header is inspected.
    ///
    /// # Errors
    /// Returns an error (after logging a warning) describing the failed check.
    fn check_uri_normalization(&self, tx: &Transaction) -> Result<()> {
        let cfg = &self.inner.config;
        if !cfg.uri_reject_malformed {
            return Ok(());
        }
        let from = match tx.original.from_header() {
            Ok(f) => f,
            Err(e) => {
                warn!("Normalization: missing/malformed From: {}", e);
                return Err(anyhow::anyhow!("malformed From header"));
            }
        };
        match from.uri() {
            Ok(uri) => {
                if uri.to_string().len() > cfg.uri_max_length {
                    warn!("Normalization: From URI too long");
                    return Err(anyhow::anyhow!("From URI too long"));
                }
            }
            Err(e) => {
                warn!("Normalization: malformed From URI: {}", e);
                return Err(anyhow::anyhow!("malformed From URI"));
            }
        }
        Ok(())
    }

    /// Returns a [`TrunkContext`] if `addr` matches the inbound IPs of a trunk.
    ///
    /// Trunks from the server's data context (when a server is attached) are checked
    /// first, then trunks from the static `config`. The first match wins.
    pub async fn is_from_trunk_context(&self, addr: &IpAddr) -> Option<TrunkContext> {
        if let Some(server) = &self.inner.server {
            let trunks = server.data_context.trunks_snapshot();
            for (name, trunk) in trunks.iter() {
                if trunk.matches_inbound_ip(addr).await {
                    return Some(TrunkContext {
                        id: trunk.id,
                        name: name.clone(),
                        tenant_id: None,
                        did_numbers: trunk.did_numbers.clone(),
                    });
                }
            }
        }
        // Iterate the static configuration by reference. Only the matching trunk's
        // fields are cloned, instead of deep-copying every trunk config per request.
        for (name, trunk) in self.inner.config.trunks.iter() {
            if trunk.matches_inbound_ip(addr).await {
                return Some(TrunkContext {
                    id: trunk.id,
                    name: name.clone(),
                    tenant_id: None,
                    did_numbers: trunk.did_numbers.clone(),
                });
            }
        }
        None
    }

    /// Evaluates the ACL rules in order. The first rule whose network contains `addr`
    /// (or that targets `all`) decides the outcome. Returns `false` when no rule matches.
    pub(crate) async fn is_ip_allowed(&self, addr: &IpAddr) -> bool {
        let rules = self.load_rules().await;
        rules
            .iter()
            .find(|rule| match &rule.network {
                Some(network) => network.contains(addr),
                None => true,
            })
            .map_or(false, |rule| matches!(rule.action, AclAction::Allow))
    }

    /// Returns `false` for a blacklisted User-Agent. Otherwise returns `true` if the
    /// white list is empty or contains `ua` (exact match).
    pub fn is_ua_allowed(&self, ua: &str) -> bool {
        if self.inner.ua_black_list.contains(ua) {
            return false;
        }
        if self.inner.ua_white_list.is_empty() {
            return true;
        }
        self.inner.ua_white_list.contains(ua)
    }
}
/// Parses `addr` as a bare IP address or as CIDR notation (`ip/prefix`).
///
/// A bare address gets a full-length prefix (32 for IPv4, 128 for IPv6). For CIDR
/// input, host bits beyond the prefix are zeroed, so `192.168.1.77/24` yields
/// `(192.168.1.0, 24)`.
///
/// # Errors
/// Fails if the input has more than one `/`, the address does not parse, the prefix
/// is not a `u8`, or the prefix exceeds the address width.
fn parse_network(addr: &str) -> Result<(IpAddr, u8)> {
    let (ip_text, prefix_text) = match addr.split_once('/') {
        None => {
            let ip = IpAddr::from_str(addr)?;
            let full_len = if ip.is_ipv4() { 32 } else { 128 };
            return Ok((ip, full_len));
        }
        // More than one '/' separator.
        Some((_, rest)) if rest.contains('/') => {
            return Err(anyhow::anyhow!("invalid network address"));
        }
        Some(pair) => pair,
    };

    let ip = IpAddr::from_str(ip_text)?;
    let prefix_len = prefix_text
        .parse::<u8>()
        .map_err(|_| anyhow::anyhow!("invalid network address"))?;

    let network = match ip {
        IpAddr::V4(v4) => {
            if prefix_len > 32 {
                return Err(anyhow::anyhow!("invalid network address prefix > 32"));
            }
            // A shift by the full width (prefix 0) overflows and yields an all-zero mask.
            let mask = u32::MAX.checked_shl(32 - u32::from(prefix_len)).unwrap_or(0);
            IpAddr::V4((u32::from(v4) & mask).into())
        }
        IpAddr::V6(v6) => {
            if prefix_len > 128 {
                return Err(anyhow::anyhow!("invalid network address prefix > 128"));
            }
            let mask = u128::MAX.checked_shl(128 - u32::from(prefix_len)).unwrap_or(0);
            IpAddr::V6((u128::from(v6) & mask).into())
        }
    };
    Ok((network, prefix_len))
}
#[async_trait]
impl ProxyModule for AclModule {
    fn name(&self) -> &str {
        "acl"
    }

    /// Loads the active rule set once and logs how many rules parsed successfully.
    async fn on_start(&mut self) -> Result<()> {
        let rule_count = self.load_rules().await.len();
        debug!("ACL module started with {} rules", rule_count);
        Ok(())
    }

    async fn on_stop(&self) -> Result<()> {
        debug!("ACL module stopped");
        Ok(())
    }

    /// Screens a new transaction. Checks run in this order, and the first failure
    /// aborts the transaction:
    ///
    /// 1. User-Agent black/white lists. A missing header is rejected only when a
    ///    white list is configured;
    /// 2. From-header sanity (`check_uri_normalization`);
    /// 3. per-IP DoS limits, when `dos_enabled`;
    /// 4. trunk match. A source IP belonging to a trunk bypasses the ACL, and its
    ///    [`TrunkContext`] is attached to the cookie;
    /// 5. the ordered IP ACL rules.
    ///
    /// # Errors
    /// Returns an error if the transaction has no usable transport source address.
    async fn on_transaction_begin(
        &self,
        _token: CancellationToken,
        tx: &mut Transaction,
        cookie: TransactionCookie,
    ) -> Result<ProxyAction> {
        if let Some(ua_header) = tx.original.user_agent_header() {
            let ua = ua_header.value();
            if !self.is_ua_allowed(ua) {
                info!(
                    method = tx.original.method().to_string(),
                    ua = ua,
                    "User-Agent is denied by acl module"
                );
                cookie.mark_as_spam(crate::call::cookie::SpamResult::UaBlacklist);
                return Ok(ProxyAction::Abort);
            }
        } else if !self.inner.ua_white_list.is_empty() {
            info!(
                method = tx.original.method().to_string(),
                "Missing User-Agent header, denied by acl module"
            );
            cookie.mark_as_spam(crate::call::cookie::SpamResult::Spam);
            return Ok(ProxyAction::Abort);
        }

        // The normalization check logs its own reason for rejecting.
        if self.check_uri_normalization(tx).is_err() {
            return Ok(ProxyAction::Abort);
        }

        // Resolve the source once; DoS tracking, trunk matching and the ACL all use it.
        let source_ip = Self::extract_ip(tx);

        if self.inner.config.dos_enabled {
            if let Some(ip) = source_ip {
                if let Err(e) = self.dos_check_and_track(ip).await {
                    warn!("DoS blocked {}: {}", ip, e);
                    return Ok(ProxyAction::Abort);
                }
            }
        }

        let from_addr = match source_ip {
            Some(ip) => ip,
            None => return Err(anyhow::anyhow!("missing transport source address")),
        };

        if let Some(ctx) = self.is_from_trunk_context(&from_addr).await {
            debug!(
                method = tx.original.method().to_string(),
                source_ip = %from_addr,
                "IP is from trunk, bypassing acl check"
            );
            cookie.insert_extension(ctx);
            return Ok(ProxyAction::Continue);
        }

        if !self.is_ip_allowed(&from_addr).await {
            info!(
                method = tx.original.method().to_string(),
                source_ip = %from_addr,
                "IP is denied by acl module"
            );
            cookie.mark_as_spam(crate::call::cookie::SpamResult::IpBlacklist);
            return Ok(ProxyAction::Abort);
        }
        Ok(ProxyAction::Continue)
    }

    /// Releases the DoS concurrency slot for the transaction's source IP when DoS
    /// protection is enabled.
    /// NOTE(review): this releases a slot for any ended transaction with a resolvable
    /// source, including ones the DoS check rejected, if the proxy invokes this hook
    /// for aborted transactions. Confirm.
    async fn on_transaction_end(&self, tx: &mut Transaction) -> Result<()> {
        if !self.inner.config.dos_enabled {
            return Ok(());
        }
        if let Some(ip) = Self::extract_ip(tx) {
            self.dos_release(ip).await;
        }
        Ok(())
    }
}
impl AclModule {
    /// Returns the ACL rules currently in effect, parsed into [`AclRule`]s.
    ///
    /// With a server attached, the server's live rule snapshot is used. Otherwise the
    /// rules captured from the config at construction are used. Unparsable lines are
    /// dropped by `parse_rules`.
    async fn load_rules(&self) -> Vec<AclRule> {
        let raw_rules = match &self.inner.server {
            Some(server) => server.data_context.acl_rules_snapshot(),
            None => self.inner.fallback_rules.clone(),
        };
        parse_rules(raw_rules)
    }
}
/// Returns the configured `acl_rules`. When none are configured, returns the default
/// `allow all` followed by `deny all`; the first rule always matches, so the default
/// admits every address.
fn resolve_base_rules(config: &ProxyConfig) -> Vec<String> {
    match &config.acl_rules {
        Some(rules) => rules.clone(),
        None => vec!["allow all".to_string(), "deny all".to_string()],
    }
}
/// Parses each raw rule string, preserving order and silently skipping lines that
/// `AclRule::new` rejects.
fn parse_rules(rules: Vec<String>) -> Vec<AclRule> {
    let mut parsed = Vec::with_capacity(rules.len());
    for raw in &rules {
        if let Some(rule) = AclRule::new(raw) {
            parsed.push(rule);
        }
    }
    parsed
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
// Builds a config whose `acl_rules` are exactly `rules`; every other field is default.
fn create_test_config(rules: Vec<String>) -> Arc<ProxyConfig> {
Arc::new(ProxyConfig {
acl_rules: Some(rules),
..Default::default()
})
}
// Builds a config that overrides only the DoS settings. The arguments map to
// `dos_enabled`, `dos_max_cps_per_ip`, `dos_max_concurrent_per_ip`,
// `dos_scan_probe_threshold` and `dos_scan_block_duration_secs`.
fn create_dos_config(
dos_enabled: bool,
max_cps: u32,
max_concurrent: u32,
scan_threshold: u32,
block_secs: u64,
) -> Arc<ProxyConfig> {
Arc::new(ProxyConfig {
dos_enabled,
dos_max_cps_per_ip: max_cps,
dos_max_concurrent_per_ip: max_concurrent,
dos_scan_probe_threshold: scan_threshold,
dos_scan_block_duration_secs: block_secs,
..Default::default()
})
}
// CIDR input keeps the given prefix; bare addresses get a full-length prefix.
#[test]
fn test_parse_network() {
let (net, prefix) = parse_network("192.168.1.0/24").unwrap();
assert_eq!(net, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 0)));
assert_eq!(prefix, 24);
let (net, prefix) = parse_network("192.168.1.1").unwrap();
assert_eq!(net, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
assert_eq!(prefix, 32);
let (net, prefix) = parse_network("2001:db8::/32").unwrap();
assert_eq!(net, IpAddr::V6(Ipv6Addr::from_str("2001:db8::").unwrap()));
assert_eq!(prefix, 32);
let (net, prefix) = parse_network("2001:db8::1").unwrap();
assert_eq!(net, IpAddr::V6(Ipv6Addr::from_str("2001:db8::1").unwrap()));
assert_eq!(prefix, 128);
}
// Rules are evaluated top to bottom and the first match wins, so the host-specific
// deny must precede the /24 allow for 192.168.1.100 to be rejected.
#[tokio::test]
async fn test_acl_rules() {
let config = create_test_config(vec![
"deny 192.168.1.100".to_string(),
"allow 192.168.1.0/24".to_string(),
"allow 10.0.0.0/8".to_string(),
"deny all".to_string(),
]);
let acl = AclModule::new(config);
assert!(acl.is_ip_allowed(&"192.168.1.1".parse().unwrap()).await);
assert!(acl.is_ip_allowed(&"10.2.3.4".parse().unwrap()).await);
assert!(!acl.is_ip_allowed(&"192.168.1.100".parse().unwrap()).await);
assert!(!acl.is_ip_allowed(&"172.16.1.1".parse().unwrap()).await);
}
// Without configured rules, the default `allow all` / `deny all` pair admits everyone.
#[tokio::test]
async fn test_default_rules() {
let config = Arc::new(ProxyConfig {
acl_rules: None,
..Default::default()
});
let acl = AclModule::new(config);
assert!(acl.is_ip_allowed(&"192.168.1.1".parse().unwrap()).await);
assert!(acl.is_ip_allowed(&"10.0.0.1".parse().unwrap()).await);
}
// CPS limit 3: the fourth request within one second is rejected.
#[tokio::test]
async fn test_dos_cps_limit() {
let config = create_dos_config(true, 3, 100, 100, 60);
let acl = AclModule::new(config);
let ip: IpAddr = "10.0.0.1".parse().unwrap();
for _ in 0..3 {
assert!(acl.dos_check_and_track(ip).await.is_ok());
}
assert!(acl.dos_check_and_track(ip).await.is_err());
}
// Two concurrent slots: a third request fails until one slot is released, after
// which exactly one more request fits.
#[tokio::test]
async fn test_dos_concurrent_limit_and_release() {
let config = create_dos_config(true, 100, 2, 100, 60);
let acl = AclModule::new(config);
let ip: IpAddr = "10.0.0.2".parse().unwrap();
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_err());
acl.dos_release(ip).await;
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_err());
}
// Scan threshold 3, with CPS and concurrency limits high enough not to interfere.
#[tokio::test]
async fn test_dos_scan_detection() {
let config = create_dos_config(true, 100, 100, 3, 60);
let acl = AclModule::new(config);
let ip: IpAddr = "10.0.0.3".parse().unwrap();
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_err());
}
// A CPS violation blocks the IP for 1s; after waiting that long it is admitted again.
// The sleep is real time because the module measures time with `std::time::Instant`.
#[tokio::test]
async fn test_dos_block_clears_after_window() {
let config = create_dos_config(true, 1, 100, 100, 1); let acl = AclModule::new(config);
let ip: IpAddr = "10.0.0.4".parse().unwrap();
assert!(acl.dos_check_and_track(ip).await.is_ok());
assert!(acl.dos_check_and_track(ip).await.is_err());
tokio::time::sleep(Duration::from_secs(1)).await;
assert!(acl.dos_check_and_track(ip).await.is_ok());
}
// DoS state is kept per source IP: blocking one IP does not affect another.
#[tokio::test]
async fn test_dos_independent_per_ip() {
let config = create_dos_config(true, 1, 100, 100, 60);
let acl = AclModule::new(config);
let ip1: IpAddr = "10.0.0.5".parse().unwrap();
let ip2: IpAddr = "10.0.0.6".parse().unwrap();
assert!(acl.dos_check_and_track(ip1).await.is_ok());
assert!(acl.dos_check_and_track(ip1).await.is_err());
assert!(acl.dos_check_and_track(ip2).await.is_ok()); }
}