use crate::error::{GatewayError, Result};
use crate::middleware::ip_matcher::IpMatcher;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// RAII guard representing one admitted TCP connection.
///
/// Handed out by `TcpFilter::check_connection`. While the permit is alive it counts
/// toward the filter's active-connection total; dropping it releases that slot.
/// Discarding it right away (e.g. `filter.check_connection(addr)?;`) frees the slot
/// immediately and silently defeats the connection cap, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "dropping the permit immediately releases the connection slot"]
pub struct TcpPermit {
    /// Counter shared with the issuing `TcpFilter`; decremented when this permit drops.
    counter: Arc<AtomicUsize>,
}

impl Drop for TcpPermit {
    /// Gives the slot back by decrementing the shared connection counter.
    fn drop(&mut self) {
        self.counter.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Admission filter for inbound TCP connections.
///
/// Combines an optional cap on simultaneously admitted connections with an
/// IP allow-list; both are consulted by `check_connection`.
pub struct TcpFilter {
    /// Upper bound on outstanding permits; `None` means unlimited.
    max_connections: Option<u32>,
    /// Live permit count, shared with every issued `TcpPermit` so its drop can decrement it.
    current_connections: Arc<AtomicUsize>,
    /// Allow-list of client addresses; when empty, `check_connection` skips IP filtering.
    ip_matcher: IpMatcher,
}
impl TcpFilter {
    /// Builds a filter from an optional connection cap and an IP allow-list.
    ///
    /// When the matcher built from `allowed_ips` is empty, IP filtering is skipped
    /// by `check_connection`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `IpMatcher::new` for `allowed_ips`.
    pub fn new(max_connections: Option<u32>, allowed_ips: &[String]) -> Result<Self> {
        let ip_matcher = IpMatcher::new(allowed_ips)?;
        Ok(Self {
            ip_matcher,
            max_connections,
            current_connections: Arc::new(AtomicUsize::new(0)),
        })
    }

    /// Admits or rejects an incoming TCP connection from `addr`.
    ///
    /// The IP allow-list is checked first (skipped when it is empty). Then, if a cap
    /// is configured, a slot is reserved atomically. On success the returned
    /// [`TcpPermit`] holds the slot until it is dropped.
    ///
    /// # Errors
    ///
    /// Returns `GatewayError::MiddlewareRejected` when `addr` is not allowed by the
    /// IP filter, or when `max_connections` slots are already in use.
    pub fn check_connection(&self, addr: &str) -> Result<TcpPermit> {
        if !self.ip_matcher.is_empty() && !self.ip_matcher.is_allowed(addr) {
            return Err(GatewayError::MiddlewareRejected(format!(
                "TCP connection from {} denied by IP filter",
                addr
            )));
        }

        match self.max_connections {
            Some(max) => {
                // u32 -> usize is a widening conversion on 32- and 64-bit targets.
                let limit = max as usize;
                // Compare-and-swap so the counter is only incremented while it is below
                // the cap. The previous fetch_add + compensating fetch_sub briefly pushed
                // the counter past `max`. Concurrent callers could then be rejected even
                // though real usage was below the cap, and `active_connections` could
                // report more than `max`.
                self.current_connections
                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
                        n.checked_add(1).filter(|&next| next <= limit)
                    })
                    .map_err(|current| {
                        GatewayError::MiddlewareRejected(format!(
                            "TCP connection limit reached ({}/{})",
                            current, max
                        ))
                    })?;
            }
            None => {
                self.current_connections.fetch_add(1, Ordering::SeqCst);
            }
        }

        Ok(TcpPermit {
            counter: Arc::clone(&self.current_connections),
        })
    }

    /// Returns a point-in-time snapshot of the number of outstanding permits.
    // NOTE(review): `dead_code` is presumably allowed because only the tests call this
    // today — confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn active_connections(&self) -> usize {
        self.current_connections.load(Ordering::SeqCst)
    }
}
/// Unit tests for `TcpFilter` admission (IP allow-list and connection cap) and
/// `TcpPermit` release-on-drop.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_filter() {
        let filter = TcpFilter::new(None, &[]).unwrap();
        let permit = filter.check_connection("10.0.0.1").unwrap();
        assert_eq!(filter.active_connections(), 1);
        drop(permit);
        assert_eq!(filter.active_connections(), 0);
    }

    #[test]
    fn test_ip_allowed() {
        let ips = vec!["10.0.0.0/8".to_string()];
        let filter = TcpFilter::new(None, &ips).unwrap();
        assert!(filter.check_connection("10.0.0.1").is_ok());
    }

    #[test]
    fn test_ip_denied() {
        let ips = vec!["10.0.0.0/8".to_string()];
        let filter = TcpFilter::new(None, &ips).unwrap();
        let result = filter.check_connection("192.168.1.1");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("denied"));
    }

    #[test]
    fn test_connection_limit() {
        let filter = TcpFilter::new(Some(2), &[]).unwrap();
        let p1 = filter.check_connection("10.0.0.1").unwrap();
        let p2 = filter.check_connection("10.0.0.2").unwrap();
        assert_eq!(filter.active_connections(), 2);
        let result = filter.check_connection("10.0.0.3");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("limit reached"));
        drop(p1);
        assert_eq!(filter.active_connections(), 1);
        let p3 = filter.check_connection("10.0.0.3").unwrap();
        assert_eq!(filter.active_connections(), 2);
        drop(p2);
        drop(p3);
    }

    #[test]
    fn test_permit_decrements_on_drop() {
        let filter = TcpFilter::new(Some(10), &[]).unwrap();
        {
            let _p = filter.check_connection("10.0.0.1").unwrap();
            assert_eq!(filter.active_connections(), 1);
        }
        assert_eq!(filter.active_connections(), 0);
    }

    #[test]
    fn test_combined_ip_and_limit() {
        let ips = vec!["10.0.0.0/8".to_string()];
        let filter = TcpFilter::new(Some(1), &ips).unwrap();
        let p = filter.check_connection("10.0.0.1").unwrap();
        assert_eq!(filter.active_connections(), 1);
        let result = filter.check_connection("10.0.0.2");
        assert!(result.is_err());
        drop(p);
        let result = filter.check_connection("192.168.1.1");
        assert!(result.is_err());
        assert_eq!(filter.active_connections(), 0);
    }

    #[test]
    fn test_no_ip_filter_accepts_all() {
        let filter = TcpFilter::new(Some(100), &[]).unwrap();
        assert!(filter.check_connection("1.2.3.4").is_ok());
        assert!(filter.check_connection("::1").is_ok());
    }

    #[test]
    fn test_invalid_ip_entries_rejected() {
        let ips = vec!["not-valid".to_string()];
        assert!(TcpFilter::new(None, &ips).is_err());
    }

    // Holds several permits at once (single-threaded despite the name).
    #[test]
    fn test_multiple_permits_concurrent() {
        let filter = TcpFilter::new(Some(5), &[]).unwrap();
        let permits: Vec<_> = (0..5)
            .map(|i| filter.check_connection(&format!("10.0.0.{}", i)).unwrap())
            .collect();
        assert_eq!(filter.active_connections(), 5);
        assert!(filter.check_connection("10.0.0.5").is_err());
        drop(permits);
        assert_eq!(filter.active_connections(), 0);
    }

    #[test]
    fn test_limit_one() {
        let filter = TcpFilter::new(Some(1), &[]).unwrap();
        let p = filter.check_connection("10.0.0.1").unwrap();
        assert!(filter.check_connection("10.0.0.2").is_err());
        drop(p);
        assert!(filter.check_connection("10.0.0.2").is_ok());
    }

    // A cap of zero admits nothing and must not leave the counter incremented.
    #[test]
    fn test_limit_zero_rejects_everything() {
        let filter = TcpFilter::new(Some(0), &[]).unwrap();
        let result = filter.check_connection("10.0.0.1");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("limit reached"));
        assert_eq!(filter.active_connections(), 0);
    }

    // Repeated rejections must not consume or leak slots.
    #[test]
    fn test_rejected_attempts_do_not_consume_slots() {
        let filter = TcpFilter::new(Some(1), &[]).unwrap();
        let held = filter.check_connection("10.0.0.1").unwrap();
        for _ in 0..10 {
            assert!(filter.check_connection("10.0.0.2").is_err());
        }
        assert_eq!(filter.active_connections(), 1);
        drop(held);
        assert_eq!(filter.active_connections(), 0);
        assert!(filter.check_connection("10.0.0.2").is_ok());
    }

    // The rejection message reports the observed usage against the cap.
    #[test]
    fn test_limit_error_reports_usage() {
        let filter = TcpFilter::new(Some(2), &[]).unwrap();
        let _a = filter.check_connection("10.0.0.1").unwrap();
        let _b = filter.check_connection("10.0.0.2").unwrap();
        let err = filter.check_connection("10.0.0.3").unwrap_err();
        assert!(err.to_string().contains("(2/2)"));
    }

    // A permit may be moved to, and released on, a different thread.
    #[test]
    fn test_permit_released_from_another_thread() {
        let filter = TcpFilter::new(Some(1), &[]).unwrap();
        let permit = filter.check_connection("10.0.0.1").unwrap();
        std::thread::spawn(move || drop(permit)).join().unwrap();
        assert_eq!(filter.active_connections(), 0);
        assert!(filter.check_connection("10.0.0.2").is_ok());
    }
}