1use crate::pipeline::EvasionPipeline;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::io::Read;
11use std::path::{Path, PathBuf};
12
/// Lookup key for the learning cache: one WAF fingerprint paired with one
/// payload type.
///
/// The key is serialized to JSON (see `cache_key_str`) before being used as
/// the string key of the on-disk map.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheKey {
    /// Fingerprint string of the WAF this entry applies to.
    pub waf_fingerprint: String,
    /// Payload type string this entry applies to.
    pub payload_type: String,
}
19
20impl CacheKey {
21 #[must_use]
22 pub fn new(waf: impl Into<String>, payload: impl Into<String>) -> Self {
23 Self {
24 waf_fingerprint: waf.into(),
25 payload_type: payload.into(),
26 }
27 }
28}
29
/// Outcome statistics stored for one [`CacheKey`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Pipeline stored for this key. `record_success` always overwrites it;
    /// `record_failure` only sets it when the entry is first created.
    pub pipeline: EvasionPipeline,
    /// Number of recorded successes (saturating counter).
    pub successes: u32,
    /// Number of recorded attempts, successes and failures combined
    /// (saturating counter).
    pub attempts: u32,
    /// Seconds since the Unix epoch at the most recent recorded success;
    /// `0` when no success has been recorded.
    pub last_success_epoch: u64,
}
38
39impl CacheEntry {
40 #[must_use]
41 pub fn success_rate(&self) -> f64 {
42 if self.attempts == 0 {
43 0.0
44 } else {
45 f64::from(self.successes) / f64::from(self.attempts)
46 }
47 }
48}
49
/// Persistent map from [`CacheKey`] to [`CacheEntry`], stored as JSON.
///
/// A value built with `Default` has no backing path, so `save` on it returns
/// `LearningCacheError::NoPath`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LearningCache {
    /// Backing file. Excluded from (de)serialization; `open` sets it after
    /// loading.
    #[serde(skip)]
    path: Option<PathBuf>,
    /// Entries keyed by the JSON encoding of their `CacheKey`
    /// (see `cache_key_str`).
    entries: HashMap<String, CacheEntry>,
}
59
60fn cache_key_str(k: &CacheKey) -> String {
61 serde_json::to_string(k).unwrap_or_else(|_| {
62 format!(
63 "{{\"waf_fingerprint\":{},\"payload_type\":{}}}",
64 serde_json::to_string(&k.waf_fingerprint).unwrap_or_else(|_| "\"\"".to_string()),
65 serde_json::to_string(&k.payload_type).unwrap_or_else(|_| "\"\"".to_string()),
66 )
67 })
68}
69
70impl LearningCache {
71 pub fn open_default() -> Result<Self, LearningCacheError> {
77 let home = dirs::home_dir().ok_or(LearningCacheError::NoHomeDir)?;
78 let path = home.join(".wafrift").join("learning_cache.json");
79 Self::open(path)
80 }
81
82 pub fn open(path: impl AsRef<Path>) -> Result<Self, LearningCacheError> {
95 let path = path.as_ref();
96 if path.exists() {
97 const MAX_CACHE_FILE_BYTES: u64 = 16 * 1024 * 1024;
111 let f = fs::File::open(path).map_err(LearningCacheError::Io)?;
112 let mut limited = f.take(MAX_CACHE_FILE_BYTES + 1);
113 let mut raw = Vec::new();
114 limited
115 .read_to_end(&mut raw)
116 .map_err(LearningCacheError::Io)?;
117 if raw.len() as u64 > MAX_CACHE_FILE_BYTES {
118 tracing::warn!(
119 path = %path.display(),
120 cap = MAX_CACHE_FILE_BYTES,
121 "learning cache file exceeds size cap; moving aside and starting fresh"
122 );
123 let backup = path.with_extension(format!("oversize-{}", current_epoch()));
124 let _ = fs::rename(path, &backup);
125 return Ok(Self {
126 path: Some(path.to_path_buf()),
127 entries: HashMap::new(),
128 });
129 }
130 let contents = String::from_utf8(raw)
131 .map_err(|e| {
132 std::io::Error::new(
133 std::io::ErrorKind::InvalidData,
134 format!("{}: learning cache is not valid UTF-8: {e}", path.display()),
135 )
136 })
137 .map_err(LearningCacheError::Io)?;
138 match serde_json::from_str::<LearningCache>(&contents) {
139 Ok(mut cache) => {
140 cache.path = Some(path.to_path_buf());
141 Ok(cache)
142 }
143 Err(e) => {
144 let backup = path.with_extension(format!("corrupt-{}", current_epoch()));
145 let backup_msg = match fs::rename(path, &backup) {
146 Ok(()) => format!("moved aside to {}", backup.display()),
147 Err(rename_err) => {
148 format!("could not rename ({rename_err}); leaving file in place")
149 }
150 };
151 tracing::warn!(
152 path = %path.display(),
153 error = %e,
154 backup = %backup_msg,
155 "learning cache file corrupted; starting fresh"
156 );
157 Ok(Self {
158 path: Some(path.to_path_buf()),
159 entries: HashMap::new(),
160 })
161 }
162 }
163 } else {
164 Ok(Self {
165 path: Some(path.to_path_buf()),
166 entries: HashMap::new(),
167 })
168 }
169 }
170
171 #[must_use]
173 pub fn get(&self, key: &CacheKey) -> Option<&CacheEntry> {
174 self.entries.get(&cache_key_str(key))
175 }
176
177 pub fn record_success(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
189 let now = current_epoch();
190 let entry = self
191 .entries
192 .entry(cache_key_str(&key))
193 .or_insert_with(|| CacheEntry {
194 pipeline: pipeline.clone(),
195 successes: 0,
196 attempts: 0,
197 last_success_epoch: 0,
198 });
199 entry.pipeline = pipeline;
202 entry.successes = entry.successes.saturating_add(1);
203 entry.attempts = entry.attempts.saturating_add(1);
204 entry.last_success_epoch = now;
205 }
206
207 pub fn record_failure(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
214 let entry = self
215 .entries
216 .entry(cache_key_str(&key))
217 .or_insert(CacheEntry {
218 pipeline,
219 successes: 0,
220 attempts: 0,
221 last_success_epoch: 0,
222 });
223 entry.attempts = entry.attempts.saturating_add(1);
224 }
225
226 pub fn save(&self) -> Result<(), LearningCacheError> {
238 let path = self.path.as_ref().ok_or(LearningCacheError::NoPath)?;
239 if let Some(parent) = path.parent() {
240 fs::create_dir_all(parent).map_err(LearningCacheError::Io)?;
241 }
242 let json = serde_json::to_string_pretty(self).map_err(LearningCacheError::Serde)?;
243
244 let tmp = path.with_extension(format!("tmp.{}.{}", std::process::id(), current_epoch()));
247 {
250 use std::io::Write;
251 let mut f = fs::File::create(&tmp).map_err(LearningCacheError::Io)?;
252 f.write_all(json.as_bytes())
253 .map_err(LearningCacheError::Io)?;
254 f.sync_all().map_err(LearningCacheError::Io)?;
255 }
256 if let Err(e) = fs::rename(&tmp, path) {
257 let _ = fs::remove_file(&tmp);
259 return Err(LearningCacheError::Io(e));
260 }
261 Ok(())
262 }
263
264 #[must_use]
266 pub fn keys(&self) -> Vec<CacheKey> {
267 self.entries
268 .keys()
269 .filter_map(|s| match serde_json::from_str(s) {
270 Ok(k) => Some(k),
271 Err(e) => {
272 tracing::warn!(key = %s, error = %e, "learning cache key parse failed");
273 None
274 }
275 })
276 .collect()
277 }
278}
279
/// Errors returned by [`LearningCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum LearningCacheError {
    /// Filesystem failure, or cache contents that are not valid UTF-8.
    #[error("learning cache I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON serialization failure while saving.
    #[error("learning cache serialization error: {0}")]
    Serde(#[from] serde_json::Error),
    /// `open_default` could not determine the home directory.
    #[error("cannot determine home directory")]
    NoHomeDir,
    /// `save` was called on a cache that has no backing path.
    #[error("no path set for learning cache")]
    NoPath,
}
292
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_epoch() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::EvasionStage;
    use wafrift_types::Technique;

    // A recorded success survives a save/open round trip.
    #[test]
    fn cache_roundtrip() {
        let tmp = std::env::temp_dir().join("wafrift_learning_cache_test.json");
        let _ = fs::remove_file(&tmp);

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new(
            "test",
            vec![EvasionStage {
                technique: Technique::UserAgentRotation,
                context: None,
            }],
            1,
        );
        cache.record_success(CacheKey::new("cloudflare", "sql"), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&CacheKey::new("cloudflare", "sql")).unwrap();
        assert_eq!(entry.successes, 1);
        assert_eq!(entry.attempts, 1);

        let _ = fs::remove_file(&tmp);
    }

    // Simulates a restart by dropping the first cache before reopening.
    #[test]
    fn cache_persists_across_process_restarts() {
        let tmp = std::env::temp_dir().join("wafrift_learning_cache_restart.json");
        let _ = fs::remove_file(&tmp);

        {
            let mut cache = LearningCache::open(&tmp).unwrap();
            let pipeline = EvasionPipeline::new(
                "win",
                vec![EvasionStage {
                    technique: Technique::GrammarMutation("sql".into()),
                    context: None,
                }],
                2,
            );
            cache.record_success(CacheKey::new("aws_waf", "xss"), pipeline);
            cache.save().unwrap();
        }

        {
            let cache = LearningCache::open(&tmp).unwrap();
            let entry = cache.get(&CacheKey::new("aws_waf", "xss")).unwrap();
            assert_eq!(entry.successes, 1);
            assert!(entry.last_success_epoch > 0);
        }

        let _ = fs::remove_file(&tmp);
    }

    // A failure increments attempts only, and persists.
    #[test]
    fn cache_failure_tracking() {
        let tmp = std::env::temp_dir().join("wafrift_learning_cache_fail.json");
        let _ = fs::remove_file(&tmp);

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new("lose", vec![], 1);
        let key = CacheKey::new("modsecurity", "cmdi");
        cache.record_failure(key.clone(), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&key).unwrap();
        assert_eq!(entry.successes, 0);
        assert_eq!(entry.attempts, 1);

        let _ = fs::remove_file(&tmp);
    }

    // The pipeline stored by an earlier failure must be replaced by the
    // pipeline that later succeeds.
    #[test]
    fn record_success_after_failure_overwrites_stored_pipeline() {
        let mut cache = LearningCache::default();
        let loser = EvasionPipeline::new("LOSER", vec![], 1);
        let winner = EvasionPipeline::new("WINNER", vec![], 1);
        let key = CacheKey::new("cloudflare", "xss");
        cache.record_failure(key.clone(), loser);
        cache.record_success(key.clone(), winner);
        let entry = cache.get(&key).expect("entry present");
        assert_eq!(
            entry.pipeline.name, "WINNER",
            "post-success the winning pipeline must be cached"
        );
        assert_eq!(entry.successes, 1);
        assert_eq!(entry.attempts, 2);
    }

    // The most recent success wins.
    #[test]
    fn second_record_success_overwrites_first_pipeline() {
        let mut cache = LearningCache::default();
        let first = EvasionPipeline::new("FIRST", vec![], 1);
        let second = EvasionPipeline::new("SECOND", vec![], 1);
        let key = CacheKey::new("awswaf", "sql");
        cache.record_success(key.clone(), first);
        cache.record_success(key.clone(), second);
        let entry = cache.get(&key).unwrap();
        assert_eq!(entry.pipeline.name, "SECOND");
        assert_eq!(entry.successes, 2);
    }

    // A file over the 16 MiB cap must not be parsed: `open` moves it aside and
    // returns an empty cache.
    #[test]
    fn oversized_cache_file_is_moved_aside_and_returns_empty_cache() {
        use std::io::Write;

        let tmp = std::env::temp_dir().join(format!(
            "wafrift_learning_cache_oversize_{}.json",
            std::process::id()
        ));
        let _ = fs::remove_file(&tmp);

        {
            // 17 MiB of spaces, written in 64 KiB chunks.
            let mut f = fs::File::create(&tmp).unwrap();
            let chunk = vec![b' '; 64 * 1024];
            for _ in 0..(17 * 1024 * 1024 / chunk.len()) {
                f.write_all(&chunk).unwrap();
            }
            f.sync_all().unwrap();
        }

        let cache = LearningCache::open(&tmp).expect("open must succeed (not Err) for oversize");
        assert!(
            cache.keys().is_empty(),
            "oversize cache must be treated as empty"
        );

        assert!(
            !tmp.exists(),
            "oversize file must be moved aside, not left at the original path"
        );

        // Remove the 17 MiB backup. `Path::with_extension` replaces `.json`,
        // so the backup is `<stem>.oversize-<epoch>`; matching on the stem
        // plus an `oversize-` marker also covers a `<stem>.json.oversize-…`
        // naming scheme.
        let backup_prefix = format!("wafrift_learning_cache_oversize_{}.", std::process::id());
        if let Ok(entries) = fs::read_dir(tmp.parent().unwrap_or(std::path::Path::new("."))) {
            let mut found_backup = false;
            for entry in entries.flatten() {
                let name = entry.file_name();
                let name_str = name.to_string_lossy();
                if name_str.starts_with(&backup_prefix) && name_str.contains("oversize-") {
                    found_backup = true;
                    let _ = fs::remove_file(entry.path());
                }
            }
            assert!(
                found_backup,
                "moved-aside oversize backup must exist next to the original path"
            );
        }
    }
}