use std::collections::HashSet;
use std::net::IpAddr;
use std::time::{Duration, Instant};
use dashmap::{DashMap, DashSet};
use super::{Detector, DetectorResult};
use crate::correlation::{CampaignUpdate, CorrelationReason, CorrelationType, FingerprintIndex};
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TokenFingerprint {
pub issuer: Option<String>,
pub algorithm: String,
pub header_fields: Vec<String>,
pub payload_fields: Vec<String>,
}
impl TokenFingerprint {
pub fn from_jwt_parts(_header: &str, _payload: &str) -> Option<Self> {
Some(Self {
issuer: None,
algorithm: "RS256".to_string(),
header_fields: vec!["alg".to_string(), "typ".to_string()],
payload_fields: vec!["sub".to_string(), "iss".to_string()],
})
}
pub fn fingerprint_hash(&self) -> String {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
Hash::hash(self, &mut hasher);
format!("{:016x}", hasher.finish())
}
}
#[derive(Debug, Clone)]
pub struct AuthTokenConfig {
pub min_ips: usize,
pub window: Duration,
pub base_confidence: f64,
pub confidence_scale_divisor: f64,
}
impl Default for AuthTokenConfig {
fn default() -> Self {
Self {
min_ips: 2,
window: Duration::from_secs(600), base_confidence: 0.85,
confidence_scale_divisor: 8.0,
}
}
}
pub struct AuthTokenDetector {
config: AuthTokenConfig,
token_index: DashMap<String, Vec<(IpAddr, Instant)>>,
detected: DashSet<String>,
}
impl AuthTokenDetector {
pub fn new(config: AuthTokenConfig) -> Self {
Self {
config,
token_index: DashMap::new(),
detected: DashSet::new(),
}
}
pub fn record_token(&self, ip: IpAddr, fingerprint: TokenFingerprint) {
let hash = fingerprint.fingerprint_hash();
let now = Instant::now();
let cutoff = now - self.config.window;
self.token_index
.entry(hash)
.and_modify(|entry| {
entry.push((ip, now));
entry.retain(|(_, ts)| *ts > cutoff);
})
.or_insert_with(|| vec![(ip, now)]);
}
pub fn record_jwt(&self, ip: IpAddr, jwt: &str) {
let parts: Vec<&str> = jwt.split('.').collect();
if parts.len() >= 2 {
if let Some(fp) = TokenFingerprint::from_jwt_parts(parts[0], parts[1]) {
self.record_token(ip, fp);
}
}
}
fn get_correlated_groups(&self) -> Vec<(String, Vec<IpAddr>)> {
let cutoff = Instant::now() - self.config.window;
self.token_index
.iter()
.filter(|entry| !self.detected.contains(entry.key()))
.filter_map(|entry| {
let hash = entry.key().clone();
let entries = entry.value();
let recent_ips: HashSet<IpAddr> = entries
.iter()
.filter(|(_, ts)| *ts > cutoff)
.map(|(ip, _)| *ip)
.collect();
if recent_ips.len() >= self.config.min_ips {
Some((hash, recent_ips.into_iter().collect()))
} else {
None
}
})
.collect()
}
}
impl Detector for AuthTokenDetector {
fn name(&self) -> &'static str {
"auth_token"
}
fn analyze(&self, _index: &FingerprintIndex) -> DetectorResult<Vec<CampaignUpdate>> {
let groups = self.get_correlated_groups();
let mut updates = Vec::new();
for (token_hash, ips) in groups {
let confidence = (ips.len() as f64 / self.config.confidence_scale_divisor).min(1.0)
* self.config.base_confidence;
updates.push(CampaignUpdate {
campaign_id: Some(format!(
"auth-token-{}",
&token_hash[..8.min(token_hash.len())]
)),
status: None,
confidence: Some(confidence),
attack_types: Some(vec!["credential_stuffing".to_string()]),
add_member_ips: Some(ips.iter().map(|ip| ip.to_string()).collect()),
add_correlation_reason: Some(CorrelationReason::new(
CorrelationType::AuthToken,
confidence,
format!(
"{} IPs using tokens with identical structure/issuer",
ips.len()
),
ips.iter().map(|ip| ip.to_string()).collect(),
)),
..Default::default()
});
self.detected.insert(token_hash);
}
Ok(updates)
}
fn should_trigger(&self, ip: &IpAddr, _index: &FingerprintIndex) -> bool {
self.token_index.iter().any(|entry| {
let entries = entry.value();
entries.iter().any(|(entry_ip, _)| entry_ip == ip)
&& entries.len() >= self.config.min_ips - 1
})
}
fn scan_interval_ms(&self) -> u64 {
5000
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = AuthTokenConfig::default();
assert_eq!(config.min_ips, 2);
}
#[test]
fn test_record_token() {
let detector = AuthTokenDetector::new(AuthTokenConfig::default());
let ip: IpAddr = "192.168.1.1".parse().unwrap();
let fp = TokenFingerprint {
issuer: Some("https://auth.example.com".to_string()),
algorithm: "RS256".to_string(),
header_fields: vec!["alg".to_string()],
payload_fields: vec!["sub".to_string()],
};
detector.record_token(ip, fp);
}
#[test]
fn test_detection_with_multiple_ips() {
let detector = AuthTokenDetector::new(AuthTokenConfig::default());
let fp = TokenFingerprint {
issuer: Some("malicious-issuer".to_string()),
algorithm: "HS256".to_string(),
header_fields: vec!["alg".to_string()],
payload_fields: vec!["sub".to_string()],
};
for i in 1..=3 {
let ip: IpAddr = format!("10.0.0.{}", i).parse().unwrap();
detector.record_token(ip, fp.clone());
}
let index = FingerprintIndex::new();
let updates = detector.analyze(&index).unwrap();
assert_eq!(updates.len(), 1);
}
#[test]
fn test_name() {
let detector = AuthTokenDetector::new(AuthTokenConfig::default());
assert_eq!(detector.name(), "auth_token");
}
#[test]
fn test_scan_interval() {
let detector = AuthTokenDetector::new(AuthTokenConfig::default());
assert_eq!(detector.scan_interval_ms(), 5000);
}
}