Skip to main content

wafrift_strategy/
learning_cache.rs

1//! Learning cache — persistent per-WAF, per-payload-type pipeline memory.
2//!
3//! After a successful bypass, the winning pipeline is cached to disk
4//! and re-used on subsequent scans of the same WAF + payload type.
5
6use crate::pipeline::EvasionPipeline;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::io::Read;
11use std::path::{Path, PathBuf};
12
13/// Cache key: WAF fingerprint + payload type.
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct CacheKey {
16    pub waf_fingerprint: String,
17    pub payload_type: String,
18}
19
20impl CacheKey {
21    #[must_use]
22    pub fn new(waf: impl Into<String>, payload: impl Into<String>) -> Self {
23        Self {
24            waf_fingerprint: waf.into(),
25            payload_type: payload.into(),
26        }
27    }
28}
29
30/// A single cached entry: the winning pipeline and its success stats.
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32pub struct CacheEntry {
33    pub pipeline: EvasionPipeline,
34    pub successes: u32,
35    pub attempts: u32,
36    pub last_success_epoch: u64,
37}
38
39impl CacheEntry {
40    #[must_use]
41    pub fn success_rate(&self) -> f64 {
42        if self.attempts == 0 {
43            0.0
44        } else {
45            f64::from(self.successes) / f64::from(self.attempts)
46        }
47    }
48}
49
50/// On-disk learning cache.
51///
52/// Keys are JSON-serialized [`CacheKey`] strings because JSON object keys must be strings.
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct LearningCache {
55    #[serde(skip)]
56    path: Option<PathBuf>,
57    entries: HashMap<String, CacheEntry>,
58}
59
60fn cache_key_str(k: &CacheKey) -> String {
61    serde_json::to_string(k).unwrap_or_else(|_| {
62        format!(
63            "{{\"waf_fingerprint\":{},\"payload_type\":{}}}",
64            serde_json::to_string(&k.waf_fingerprint).unwrap_or_else(|_| "\"\"".to_string()),
65            serde_json::to_string(&k.payload_type).unwrap_or_else(|_| "\"\"".to_string()),
66        )
67    })
68}
69
70impl LearningCache {
71    /// Open the default cache at `~/.wafrift/learning_cache.json`.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the home directory cannot be determined.
76    pub fn open_default() -> Result<Self, LearningCacheError> {
77        let home = dirs::home_dir().ok_or(LearningCacheError::NoHomeDir)?;
78        let path = home.join(".wafrift").join("learning_cache.json");
79        Self::open(path)
80    }
81
82    /// Open or create a cache at a specific path.
83    ///
84    /// A corrupted cache file (kill-9 mid-save, disk corruption, partial
85    /// flush) is moved aside to `<path>.corrupt-<epoch>` and a fresh
86    /// empty cache is returned. Crashing the whole strategy engine on
87    /// one bad JSON file would lose all subsequent learning — better to
88    /// surface the corruption via `tracing::warn` and keep going.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error only if the file exists, looks fine, and the
93    /// underlying I/O still fails (permission denied, etc.).
94    pub fn open(path: impl AsRef<Path>) -> Result<Self, LearningCacheError> {
95        let path = path.as_ref();
96        if path.exists() {
97            // Audit (2026-05-10): pre-fix the cache was loaded with no
98            // size or depth limit on the JSON. A maliciously crafted
99            // ~/.wafrift/learning_cache.json could exhaust memory
100            // (multi-GB file) or stack (deeply nested arrays). Cap the
101            // file at MAX_CACHE_FILE_BYTES; the JSON parser then has
102            // a bounded heap and stack via that bound.
103            //
104            // Audit (2026-05-27): the previous fix used metadata().len()
105            // followed by read_to_string() — a TOCTOU window where a
106            // symlink swap or file growth between the two calls could
107            // bypass the cap. Use File::open() + take(cap+1) instead:
108            // the cap is enforced DURING the read on the same open
109            // descriptor, closing the race.
110            const MAX_CACHE_FILE_BYTES: u64 = 16 * 1024 * 1024;
111            let f = fs::File::open(path).map_err(LearningCacheError::Io)?;
112            let mut limited = f.take(MAX_CACHE_FILE_BYTES + 1);
113            let mut raw = Vec::new();
114            limited
115                .read_to_end(&mut raw)
116                .map_err(LearningCacheError::Io)?;
117            if raw.len() as u64 > MAX_CACHE_FILE_BYTES {
118                tracing::warn!(
119                    path = %path.display(),
120                    cap = MAX_CACHE_FILE_BYTES,
121                    "learning cache file exceeds size cap; moving aside and starting fresh"
122                );
123                let backup = path.with_extension(format!("oversize-{}", current_epoch()));
124                let _ = fs::rename(path, &backup);
125                return Ok(Self {
126                    path: Some(path.to_path_buf()),
127                    entries: HashMap::new(),
128                });
129            }
130            let contents = String::from_utf8(raw)
131                .map_err(|e| {
132                    std::io::Error::new(
133                        std::io::ErrorKind::InvalidData,
134                        format!("{}: learning cache is not valid UTF-8: {e}", path.display()),
135                    )
136                })
137                .map_err(LearningCacheError::Io)?;
138            match serde_json::from_str::<LearningCache>(&contents) {
139                Ok(mut cache) => {
140                    cache.path = Some(path.to_path_buf());
141                    Ok(cache)
142                }
143                Err(e) => {
144                    let backup = path.with_extension(format!("corrupt-{}", current_epoch()));
145                    let backup_msg = match fs::rename(path, &backup) {
146                        Ok(()) => format!("moved aside to {}", backup.display()),
147                        Err(rename_err) => {
148                            format!("could not rename ({rename_err}); leaving file in place")
149                        }
150                    };
151                    tracing::warn!(
152                        path = %path.display(),
153                        error = %e,
154                        backup = %backup_msg,
155                        "learning cache file corrupted; starting fresh"
156                    );
157                    Ok(Self {
158                        path: Some(path.to_path_buf()),
159                        entries: HashMap::new(),
160                    })
161                }
162            }
163        } else {
164            Ok(Self {
165                path: Some(path.to_path_buf()),
166                entries: HashMap::new(),
167            })
168        }
169    }
170
171    /// Look up a cached pipeline.
172    #[must_use]
173    pub fn get(&self, key: &CacheKey) -> Option<&CacheEntry> {
174        self.entries.get(&cache_key_str(key))
175    }
176
177    /// Record a successful bypass.
178    ///
179    /// The stored pipeline is ALWAYS overwritten with the
180    /// just-succeeded pipeline. Pre-fix this used `or_insert`
181    /// which left the existing entry's pipeline untouched — if
182    /// the first interaction for a key was a `record_failure`,
183    /// the failing pipeline got stored permanently and every
184    /// subsequent `record_success` (with a DIFFERENT, working
185    /// pipeline) silently kept the loser as the cached winner.
186    /// The planner then promoted the known-failing pipeline to
187    /// the top of every future scan.
188    pub fn record_success(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
189        let now = current_epoch();
190        let entry = self
191            .entries
192            .entry(cache_key_str(&key))
193            .or_insert_with(|| CacheEntry {
194                pipeline: pipeline.clone(),
195                successes: 0,
196                attempts: 0,
197                last_success_epoch: 0,
198            });
199        // Always update to the just-succeeded pipeline — even if
200        // it's the same shape as the cached one, this is cheap.
201        entry.pipeline = pipeline;
202        entry.successes = entry.successes.saturating_add(1);
203        entry.attempts = entry.attempts.saturating_add(1);
204        entry.last_success_epoch = now;
205    }
206
207    /// Record a failed attempt.
208    ///
209    /// Failures DO NOT overwrite the stored pipeline — the cached
210    /// winner is set by `record_success`. If no success has been
211    /// recorded yet, the failing pipeline is what's stored, but
212    /// the next success will replace it.
213    pub fn record_failure(&mut self, key: CacheKey, pipeline: EvasionPipeline) {
214        let entry = self
215            .entries
216            .entry(cache_key_str(&key))
217            .or_insert(CacheEntry {
218                pipeline,
219                successes: 0,
220                attempts: 0,
221                last_success_epoch: 0,
222            });
223        entry.attempts = entry.attempts.saturating_add(1);
224    }
225
226    /// Persist the cache to disk atomically.
227    ///
228    /// Writes to a sibling `<path>.tmp.<pid>.<epoch>` file, fsyncs it,
229    /// then renames over the target path. A kill-9 between `write` and
230    /// `rename` leaves the previous good cache file untouched instead
231    /// of producing the half-written JSON that was poisoning subsequent
232    /// `open` calls.
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if the file cannot be written or renamed.
237    pub fn save(&self) -> Result<(), LearningCacheError> {
238        let path = self.path.as_ref().ok_or(LearningCacheError::NoPath)?;
239        if let Some(parent) = path.parent() {
240            fs::create_dir_all(parent).map_err(LearningCacheError::Io)?;
241        }
242        let json = serde_json::to_string_pretty(self).map_err(LearningCacheError::Serde)?;
243
244        // Sibling tmp file in the same directory so `rename` is atomic
245        // (cross-FS rename on /tmp would silently fall back to copy).
246        let tmp = path.with_extension(format!("tmp.{}.{}", std::process::id(), current_epoch()));
247        // Scope the file handle so the OS releases its descriptor before
248        // we rename — Windows would otherwise refuse the rename.
249        {
250            use std::io::Write;
251            let mut f = fs::File::create(&tmp).map_err(LearningCacheError::Io)?;
252            f.write_all(json.as_bytes())
253                .map_err(LearningCacheError::Io)?;
254            f.sync_all().map_err(LearningCacheError::Io)?;
255        }
256        if let Err(e) = fs::rename(&tmp, path) {
257            // Clean up the orphaned tmp file before propagating.
258            let _ = fs::remove_file(&tmp);
259            return Err(LearningCacheError::Io(e));
260        }
261        Ok(())
262    }
263
264    /// All cached keys.
265    #[must_use]
266    pub fn keys(&self) -> Vec<CacheKey> {
267        self.entries
268            .keys()
269            .filter_map(|s| match serde_json::from_str(s) {
270                Ok(k) => Some(k),
271                Err(e) => {
272                    tracing::warn!(key = %s, error = %e, "learning cache key parse failed");
273                    None
274                }
275            })
276            .collect()
277    }
278}
279
280/// Errors from learning cache operations.
281#[derive(Debug, thiserror::Error)]
282pub enum LearningCacheError {
283    #[error("learning cache I/O error: {0}")]
284    Io(#[from] std::io::Error),
285    #[error("learning cache serialization error: {0}")]
286    Serde(#[from] serde_json::Error),
287    #[error("cannot determine home directory")]
288    NoHomeDir,
289    #[error("no path set for learning cache")]
290    NoPath,
291}
292
/// Current Unix time in whole seconds; `0` if the system clock reports a
/// time before 1970 (rather than panicking).
fn current_epoch() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::EvasionStage;
    use wafrift_types::Technique;

    /// Scratch file in the temp dir, unique per test `tag` and per process,
    /// so parallel or concurrent test runs never share a cache file.
    fn scratch_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "wafrift_learning_cache_{tag}_{}.json",
            std::process::id()
        ))
    }

    #[test]
    fn cache_roundtrip() {
        let tmp = scratch_path("roundtrip");
        let _ = fs::remove_file(&tmp);

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new(
            "test",
            vec![EvasionStage {
                technique: Technique::UserAgentRotation,
                context: None,
            }],
            1,
        );
        cache.record_success(CacheKey::new("cloudflare", "sql"), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&CacheKey::new("cloudflare", "sql")).unwrap();
        assert_eq!(entry.successes, 1);
        assert_eq!(entry.attempts, 1);

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn cache_persists_across_process_restarts() {
        let tmp = scratch_path("restart");
        let _ = fs::remove_file(&tmp);

        // Process 1
        {
            let mut cache = LearningCache::open(&tmp).unwrap();
            let pipeline = EvasionPipeline::new(
                "win",
                vec![EvasionStage {
                    technique: Technique::GrammarMutation("sql".into()),
                    context: None,
                }],
                2,
            );
            cache.record_success(CacheKey::new("aws_waf", "xss"), pipeline);
            cache.save().unwrap();
        }

        // Process 2
        {
            let cache = LearningCache::open(&tmp).unwrap();
            let entry = cache.get(&CacheKey::new("aws_waf", "xss")).unwrap();
            assert_eq!(entry.successes, 1);
            assert!(entry.last_success_epoch > 0);
        }

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn cache_failure_tracking() {
        let tmp = scratch_path("fail");
        let _ = fs::remove_file(&tmp);

        let mut cache = LearningCache::open(&tmp).unwrap();
        let pipeline = EvasionPipeline::new("lose", vec![], 1);
        let key = CacheKey::new("modsecurity", "cmdi");
        cache.record_failure(key.clone(), pipeline);
        cache.save().unwrap();

        let cache2 = LearningCache::open(&tmp).unwrap();
        let entry = cache2.get(&key).unwrap();
        assert_eq!(entry.successes, 0);
        assert_eq!(entry.attempts, 1);

        let _ = fs::remove_file(&tmp);
    }

    #[test]
    fn record_success_after_failure_overwrites_stored_pipeline() {
        // Regression for F44: pre-fix record_success used
        // or_insert which left the existing entry's pipeline
        // untouched. If the first call was record_failure with a
        // losing pipeline, the loser became permanent — the
        // planner promoted it to every future scan.
        let mut cache = LearningCache::default();
        let loser = EvasionPipeline::new("LOSER", vec![], 1);
        let winner = EvasionPipeline::new("WINNER", vec![], 1);
        let key = CacheKey::new("cloudflare", "xss");
        cache.record_failure(key.clone(), loser);
        cache.record_success(key.clone(), winner);
        let entry = cache.get(&key).expect("entry present");
        assert_eq!(
            entry.pipeline.name, "WINNER",
            "post-success the winning pipeline must be cached"
        );
        assert_eq!(entry.successes, 1);
        assert_eq!(entry.attempts, 2);
    }

    #[test]
    fn second_record_success_overwrites_first_pipeline() {
        // Newer better pipeline must replace the older one.
        let mut cache = LearningCache::default();
        let first = EvasionPipeline::new("FIRST", vec![], 1);
        let second = EvasionPipeline::new("SECOND", vec![], 1);
        let key = CacheKey::new("awswaf", "sql");
        cache.record_success(key.clone(), first);
        cache.record_success(key.clone(), second);
        let entry = cache.get(&key).unwrap();
        assert_eq!(entry.pipeline.name, "SECOND");
        assert_eq!(entry.successes, 2);
    }

    /// Anti-regression: a cache file exceeding the size cap must be moved aside
    /// and a fresh empty cache returned, not OOM-crash the process.
    /// Also validates that the cap is enforced during the read (not via a
    /// pre-check metadata() call that a file swap could race).
    #[test]
    fn oversized_cache_file_is_moved_aside_and_returns_empty_cache() {
        use std::io::Write;

        let tmp = scratch_path("oversize");
        let _ = fs::remove_file(&tmp);

        // Write a file larger than the 16 MiB cap (17 MiB of spaces, which is
        // valid UTF-8 but exceeds the limit before JSON parsing even starts).
        {
            let mut f = fs::File::create(&tmp).unwrap();
            let chunk = vec![b' '; 64 * 1024];
            for _ in 0..(17 * 1024 * 1024 / chunk.len()) {
                f.write_all(&chunk).unwrap();
            }
            f.sync_all().unwrap();
        }

        // Must not panic; returns an empty cache.
        let cache = LearningCache::open(&tmp).expect("open must succeed (not Err) for oversize");
        assert!(
            cache.keys().is_empty(),
            "oversize cache must be treated as empty"
        );

        // Original path must have been moved aside (an oversize-* sibling
        // should now exist in the temp dir).
        assert!(
            !tmp.exists(),
            "oversize file must be moved aside, not left at the original path"
        );

        // Clean up the moved-aside file. Match on "<stem>." plus ".oversize-"
        // so this works whether the backup kept the `.json` extension
        // (`<stem>.json.oversize-<epoch>`) or replaced it
        // (`<stem>.oversize-<epoch>`). The previous pattern required
        // `.json.oversize`, never matched the replaced-extension form, and
        // leaked a 17 MiB file into the temp dir on every run.
        let prefix = format!("wafrift_learning_cache_oversize_{}.", std::process::id());
        let dir = tmp.parent().unwrap_or(std::path::Path::new("."));
        if let Ok(entries) = fs::read_dir(dir) {
            for entry in entries.flatten() {
                let name = entry.file_name();
                let name = name.to_string_lossy();
                if name.starts_with(&prefix) && name.contains(".oversize-") {
                    let _ = fs::remove_file(entry.path());
                }
            }
        }
    }
}