use std::fs;
use wafrift_strategy::learning_cache::{CacheKey, LearningCache};
use wafrift_strategy::pipeline::{EvasionPipeline, EvasionStage};
use wafrift_types::Technique;
/// Builds a scratch-file path inside the system temp directory.
///
/// The file name is `wafrift_lc_robust_<pid>_<nanos>_<suffix>.json`, where
/// `<pid>` is the current process id and `<nanos>` is the number of
/// nanoseconds since the Unix epoch (0 if the system clock reads earlier
/// than the epoch).
fn unique_tmp(suffix: &str) -> std::path::PathBuf {
    let pid = std::process::id();
    let stamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        // Clock is before the epoch: fall back to a fixed stamp.
        Err(_) => 0,
    };
    let mut target = std::env::temp_dir();
    target.push(format!("wafrift_lc_robust_{pid}_{stamp}_{suffix}.json"));
    target
}
/// Creates an `EvasionPipeline` called `name` that holds exactly one stage:
/// `Technique::UserAgentRotation` with no context. The final constructor
/// argument is `1`.
fn pipeline(name: &str) -> EvasionPipeline {
    let only_stage = EvasionStage {
        technique: Technique::UserAgentRotation,
        context: None,
    };
    let stages = vec![only_stage];
    EvasionPipeline::new(name, stages, 1)
}
/// Opening a cache file that contains invalid JSON must succeed, yield an
/// empty cache, and leave behind a sibling file whose name starts with the
/// original stem and whose extension starts with `corrupt-`.
#[test]
fn open_does_not_crash_on_corrupt_json() {
    let path = unique_tmp("corrupt");
    let _ = fs::remove_file(&path);
    fs::write(&path, b"{ not valid json").unwrap();
    let cache = LearningCache::open(&path).expect("must not crash on corrupt JSON");
    let reset_to_empty = cache.keys().is_empty();

    let dir = path.parent().unwrap();
    let stem = path.file_stem().unwrap().to_string_lossy().to_string();

    // Collect every sibling that belongs to this test. `flatten()` skips
    // directory entries that fail to read instead of panicking on them; the
    // temp directory is shared, and the cleanup loops in this file already
    // tolerate such entries the same way.
    let siblings: Vec<_> = fs::read_dir(dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.path())
        .filter(|p| {
            p.file_name()
                .is_some_and(|n| n.to_string_lossy().starts_with(&stem))
        })
        .collect();

    let moved_aside = siblings.iter().any(|p| {
        p.extension()
            .is_some_and(|ext| ext.to_string_lossy().starts_with("corrupt-"))
    });

    // Clean up before asserting so a failing assertion does not leak the
    // moved-aside file into the temp directory.
    for p in &siblings {
        let _ = fs::remove_file(p);
    }

    assert!(reset_to_empty, "corrupt cache must reset to empty");
    assert!(
        moved_aside,
        "corrupt file must be moved aside to <stem>.corrupt-<epoch>"
    );
}
/// A cache file that stops in the middle of a JSON document must still open
/// and produce an empty cache.
#[test]
fn open_does_not_crash_on_truncated_file() {
    let path = unique_tmp("truncated");
    let _ = fs::remove_file(&path);

    // JSON that ends part-way through a key.
    let partial_doc = b"{\n \"entries\": {\n \"key1\": {\n \"pip";
    fs::write(&path, partial_doc).unwrap();

    let cache = LearningCache::open(&path).expect("must recover from truncated JSON");
    assert!(cache.keys().is_empty());

    // Best-effort removal of every file in the same directory whose name
    // starts with this test's stem.
    let stem = path.file_stem().unwrap().to_string_lossy().to_string();
    let dir = path.parent().unwrap();
    fs::read_dir(dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.path())
        .filter(|candidate| {
            candidate
                .file_name()
                .is_some_and(|name| name.to_string_lossy().starts_with(&stem))
        })
        .for_each(|candidate| {
            let _ = fs::remove_file(candidate);
        });
}
/// After opening a file full of binary garbage, the cache must accept a new
/// success record, save it, and show that record when the file is reopened.
#[test]
fn open_after_corrupt_file_can_save_again() {
    let path = unique_tmp("recover_save");
    let _ = fs::remove_file(&path);
    fs::write(&path, b"\x00\x00\x00garbage\x00\x00").unwrap();

    // The same key is needed for both the write and the later lookup.
    let key = || CacheKey::new("modsec", "xss");

    let mut cache = LearningCache::open(&path).expect("must recover from binary garbage");
    cache.record_success(key(), pipeline("p1"));
    cache
        .save()
        .expect("save after corruption recovery must succeed");

    let reopened = LearningCache::open(&path).unwrap();
    let recorded = reopened
        .get(&key())
        .expect("entry must persist after recovery save")
        .successes;
    assert_eq!(recorded, 1);

    // Best-effort cleanup: the cache file itself, then any sibling whose
    // name starts with the same stem.
    let _ = fs::remove_file(&path);
    let stem = path.file_stem().unwrap().to_string_lossy().to_string();
    let dir = path.parent().unwrap();
    fs::read_dir(dir)
        .unwrap()
        .flatten()
        .map(|entry| entry.path())
        .filter(|candidate| {
            candidate
                .file_name()
                .is_some_and(|name| name.to_string_lossy().starts_with(&stem))
        })
        .for_each(|candidate| {
            let _ = fs::remove_file(candidate);
        });
}
/// Saving twice must leave no temporary sibling next to the cache file, and
/// the saved file must reopen with both recorded successes.
#[test]
fn save_does_not_leave_partial_file_visible() {
    let path = unique_tmp("atomic");
    let _ = fs::remove_file(&path);
    let mut cache = LearningCache::open(&path).unwrap();
    cache.record_success(CacheKey::new("waf-a", "sql"), pipeline("p"));
    cache.save().unwrap();
    cache.record_success(CacheKey::new("waf-a", "sql"), pipeline("p"));
    cache.save().unwrap();

    let dir = path.parent().unwrap();
    let stem = path.file_stem().unwrap().to_string_lossy().to_string();

    // `Path::extension()` is the text after the *final* dot, so it can never
    // contain a dot itself: a name such as `<stem>.json.tmp.123` has the
    // extension `123`. Testing the extension for a `tmp.` prefix therefore
    // never matches, so the check is made against the whole file name.
    // NOTE(review): assumes save() stages data in a sibling whose name, after
    // the stem, contains `.tmp` — confirm against LearningCache::save.
    let orphans: Vec<_> = fs::read_dir(dir)
        .unwrap()
        .flatten()
        .filter(|e| {
            e.file_name()
                .to_string_lossy()
                .strip_prefix(stem.as_str())
                .is_some_and(|rest| rest.contains(".tmp"))
        })
        .collect();
    assert!(
        orphans.is_empty(),
        "save() must not leave orphan tmp files: {orphans:?}"
    );

    let cache2 = LearningCache::open(&path).expect("post-save file must reopen cleanly");
    assert_eq!(
        cache2.get(&CacheKey::new("waf-a", "sql")).unwrap().successes,
        2
    );
    let _ = fs::remove_file(&path);
}
/// After recording 50 distinct keys, `save()` must write a UTF-8 JSON
/// document that starts with `{`, mentions `"entries"`, and whose `entries`
/// object has 50 members.
#[test]
fn save_writes_full_pretty_json_each_call() {
    let path = unique_tmp("pretty");
    let _ = fs::remove_file(&path);

    let mut cache = LearningCache::open(&path).unwrap();
    (0..50).for_each(|i| {
        let waf = format!("waf-{i}");
        let label = format!("p{i}");
        cache.record_success(CacheKey::new(waf, "xss"), pipeline(&label));
    });
    cache.save().unwrap();

    let raw = fs::read(&path).unwrap();
    let text = std::str::from_utf8(&raw).expect("save must produce valid utf-8 json");
    assert!(text.starts_with('{'), "pretty json must start with '{{'");
    assert!(text.contains("\"entries\""), "must contain entries field");

    // Parse the file independently of the cache to count what was written.
    let doc: serde_json::Value = serde_json::from_str(text).expect("must reopen as valid json");
    let entry_count = doc["entries"].as_object().unwrap().len();
    assert_eq!(entry_count, 50, "must have all 50 entries");

    let _ = fs::remove_file(&path);
}