use crate::pipeline::EvasionPipeline;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Lookup key for the learning cache: one WAF fingerprint paired with one
/// payload type.
///
/// The key is serialized to JSON to form the string key of the on-disk map, so
/// the field order below is part of the persisted format.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheKey {
/// Fingerprint string of the WAF (the tests in this file use values such as
/// `"cloudflare"` and `"aws_waf"`).
pub waf_fingerprint: String,
/// Payload category (the tests in this file use values such as `"sql"` and
/// `"xss"`).
pub payload_type: String,
}
impl CacheKey {
    /// Builds a key from anything convertible into owned strings.
    ///
    /// `waf` becomes [`CacheKey::waf_fingerprint`] and `payload` becomes
    /// [`CacheKey::payload_type`]; no normalisation is applied to either.
    #[must_use]
    pub fn new(waf: impl Into<String>, payload: impl Into<String>) -> Self {
        let waf_fingerprint: String = waf.into();
        let payload_type: String = payload.into();
        Self {
            waf_fingerprint,
            payload_type,
        }
    }
}
/// Statistics stored in the learning cache for one [`CacheKey`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheEntry {
/// Pipeline stored when the entry was first created; later
/// `record_success` / `record_failure` calls for the same key do not replace it.
pub pipeline: EvasionPipeline,
/// Number of `record_success` calls counted for this entry.
pub successes: u32,
/// Number of `record_success` plus `record_failure` calls counted for this entry.
pub attempts: u32,
/// Seconds since the Unix epoch at the most recent `record_success`;
/// `0` when none has been recorded.
pub last_success_epoch: u64,
}
impl CacheEntry {
    /// Fraction of recorded attempts that succeeded, as `successes / attempts`.
    ///
    /// Returns `0.0` when no attempts have been recorded, so the division is
    /// never performed with a zero denominator.
    #[must_use]
    pub fn success_rate(&self) -> f64 {
        match self.attempts {
            0 => 0.0,
            total => f64::from(self.successes) / f64::from(total),
        }
    }
}
/// Cache of per-key [`CacheEntry`] statistics that can be loaded from and saved
/// to a JSON file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearningCache {
// File this cache is saved to. Skipped by serde, so it is never written to the
// JSON and is `None` after plain deserialization or `Default`; `open` sets it.
#[serde(skip)]
path: Option<PathBuf>,
// Entries keyed by the JSON string that `cache_key_str` produces for a `CacheKey`.
entries: HashMap<String, CacheEntry>,
}
/// Renders a [`CacheKey`] as the JSON string used to key the entries map.
///
/// The normal path is `serde_json::to_string`. If that fails, the same JSON
/// object shape is assembled by hand from the two fields, each quoted through
/// `serde_json` and replaced by an empty JSON string if quoting fails too.
fn cache_key_str(k: &CacheKey) -> String {
    // JSON-quote one field, degrading to `""` rather than failing.
    let quoted = |field: &String| -> String {
        match serde_json::to_string(field) {
            Ok(text) => text,
            Err(_) => "\"\"".to_string(),
        }
    };

    match serde_json::to_string(k) {
        Ok(encoded) => encoded,
        Err(_) => format!(
            "{{\"waf_fingerprint\":{},\"payload_type\":{}}}",
            quoted(&k.waf_fingerprint),
            quoted(&k.payload_type),
        ),
    }
}
impl LearningCache {
pub fn open_default() -> Result<Self, LearningCacheError> {
let home = dirs::home_dir().ok_or(LearningCacheError::NoHomeDir)?;
let path = home.join(".wafrift").join("learning_cache.json");
Self::open(path)
}
pub fn open(path: impl AsRef<Path>) -> Result<Self, LearningCacheError> {
let path = path.as_ref();
if path.exists() {
let contents = fs::read_to_string(path).map_err(LearningCacheError::Io)?;
let mut cache: LearningCache =
serde_json::from_str(&contents).map_err(LearningCacheError::Serde)?;
cache.path = Some(path.to_path_buf());
Ok(cache)
} else {
Ok(Self {
path: Some(path.to_path_buf()),
entries: HashMap::new(),
})
}
}
#[must_use]
pub fn get(&self, key: &CacheKey) -> Option<&CacheEntry> {
self.entries.get(&cache_key_str(key))
}
pub fn record_success(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
let now = current_epoch();
let entry = self
.entries
.entry(cache_key_str(&key))
.or_insert(CacheEntry {
pipeline,
successes: 0,
attempts: 0,
last_success_epoch: 0,
});
entry.successes += 1;
entry.attempts += 1;
entry.last_success_epoch = now;
}
pub fn record_failure(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
let entry = self
.entries
.entry(cache_key_str(&key))
.or_insert(CacheEntry {
pipeline,
successes: 0,
attempts: 0,
last_success_epoch: 0,
});
entry.attempts += 1;
}
pub fn save(&self) -> Result<(), LearningCacheError> {
let path = self.path.as_ref().ok_or(LearningCacheError::NoPath)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(LearningCacheError::Io)?;
}
let json = serde_json::to_string_pretty(self).map_err(LearningCacheError::Serde)?;
fs::write(path, json).map_err(LearningCacheError::Io)
}
#[must_use]
pub fn keys(&self) -> Vec<CacheKey> {
self.entries
.keys()
.filter_map(|s| serde_json::from_str(s).ok())
.collect()
}
}
/// Errors returned by [`LearningCache`] operations.
#[derive(Debug)]
pub enum LearningCacheError {
/// Reading the cache file, creating its parent directories, or writing it failed.
Io(std::io::Error),
/// Parsing the cache file or serializing the cache to JSON failed.
Serde(serde_json::Error),
/// `dirs::home_dir()` returned `None`, so the default location is unknown.
NoHomeDir,
/// `save` was called on a cache that has no path set.
NoPath,
}
/// Human-readable messages for [`LearningCacheError`].
///
/// The wrapping variants append the underlying error's own `Display` text after
/// a fixed prefix; the unit variants print a fixed message.
impl std::fmt::Display for LearningCacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages, nothing to interpolate.
            Self::NoHomeDir => f.write_str("cannot determine home directory"),
            Self::NoPath => f.write_str("no path set for learning cache"),
            // Prefix plus the wrapped error's message.
            Self::Io(e) => write!(f, "learning cache I/O error: {e}"),
            Self::Serde(e) => write!(f, "learning cache serialization error: {e}"),
        }
    }
}
// Empty impl: `source()` is not overridden, so it keeps the trait's default.
// The wrapped error's message is already embedded in the `Display` output above.
impl std::error::Error for LearningCacheError {}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn current_epoch() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use wafrift_types::Technique;

    /// Returns `file_name` inside the system temp directory, after a
    /// best-effort removal of any file left over from an earlier run.
    fn fresh_path(file_name: &str) -> std::path::PathBuf {
        let location = std::env::temp_dir().join(file_name);
        let _ = fs::remove_file(&location);
        location
    }

    /// A recorded success survives a save followed by a reopen.
    #[test]
    fn cache_roundtrip() {
        let location = fresh_path("wafrift_learning_cache_test.json");

        let mut writer = LearningCache::open(&location).unwrap();
        let pipeline = EvasionPipeline::new("test", vec![Technique::UserAgentRotation], 1);
        writer.record_success(CacheKey::new("cloudflare", "sql"), pipeline);
        writer.save().unwrap();

        let reader = LearningCache::open(&location).unwrap();
        let stored = reader.get(&CacheKey::new("cloudflare", "sql")).unwrap();
        assert_eq!(stored.successes, 1);
        assert_eq!(stored.attempts, 1);

        let _ = fs::remove_file(&location);
    }

    /// Data written by one cache instance is visible to a later instance
    /// created after the first one has been dropped.
    #[test]
    fn cache_persists_across_process_restarts() {
        let location = fresh_path("wafrift_learning_cache_restart.json");

        // First "process": record and persist, then drop the instance.
        let mut first_run = LearningCache::open(&location).unwrap();
        let pipeline =
            EvasionPipeline::new("win", vec![Technique::GrammarMutation("sql".into())], 2);
        first_run.record_success(CacheKey::new("aws_waf", "xss"), pipeline);
        first_run.save().unwrap();
        drop(first_run);

        // Second "process": reopen from disk and inspect.
        let second_run = LearningCache::open(&location).unwrap();
        let stored = second_run.get(&CacheKey::new("aws_waf", "xss")).unwrap();
        assert_eq!(stored.successes, 1);
        assert!(stored.last_success_epoch > 0);
        drop(second_run);

        let _ = fs::remove_file(&location);
    }

    /// A recorded failure bumps `attempts` only, and that survives a reopen.
    #[test]
    fn cache_failure_tracking() {
        let location = fresh_path("wafrift_learning_cache_fail.json");

        let mut writer = LearningCache::open(&location).unwrap();
        let pipeline = EvasionPipeline::new("lose", vec![], 1);
        let key = CacheKey::new("modsecurity", "cmdi");
        writer.record_failure(key.clone(), pipeline);
        writer.save().unwrap();

        let reader = LearningCache::open(&location).unwrap();
        let stored = reader.get(&key).unwrap();
        assert_eq!(stored.successes, 0);
        assert_eq!(stored.attempts, 1);

        let _ = fs::remove_file(&location);
    }
}