use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tracing::{info, warn};
use crate::ProxyState;
/// Largest gene bank file, in bytes (64 MiB), that [`load`] will accept.
/// A larger file is logged as suspicious and ignored; loading starts from an
/// empty default bank instead.
pub const MAX_GENE_BANK_BYTES: u64 = 64 * 1024 * 1024;
/// Maximum number of entries [`restore`] allows in `ProxyState::hosts`.
/// New hosts beyond this are skipped, and any overflow is trimmed oldest-first.
pub const MAX_RESTORED_HOSTS: usize = 10_000;
pub use wafrift_types::gene_bank_io::{PersistedGeneBank, PersistedHostState};
/// Resolve the gene bank location from the user-supplied setting.
///
/// * `"off"` or `"-"` disables persistence and yields `None`.
/// * `""` (nothing supplied) selects `<home>/.wafrift/gene-bank.json`. The home
///   directory is taken from `HOME`, falling back to `USERPROFILE`. It yields
///   `None` when neither variable is set.
/// * Any other value is used verbatim as the path.
#[must_use]
pub fn default_gene_bank_path(supplied: &str) -> Option<PathBuf> {
    match supplied {
        "off" | "-" => None,
        "" => {
            let home_dir = std::env::var_os("HOME").or_else(|| std::env::var_os("USERPROFILE"))?;
            let mut bank_path = PathBuf::from(home_dir);
            bank_path.push(".wafrift");
            bank_path.push("gene-bank.json");
            Some(bank_path)
        }
        explicit => Some(PathBuf::from(explicit)),
    }
}
pub fn load(path: &Path) -> PersistedGeneBank {
match std::fs::metadata(path) {
Ok(meta) if meta.len() > MAX_GENE_BANK_BYTES => {
warn!(
path = %path.display(),
size = meta.len(),
cap = MAX_GENE_BANK_BYTES,
"gene bank file exceeds {MAX_GENE_BANK_BYTES}-byte cap; starting fresh. \
Fix: this file is far larger than any real bank — inspect for corruption \
or remove it. If a legitimate operator workflow needs more, raise \
MAX_GENE_BANK_BYTES rather than disabling the guard."
);
return PersistedGeneBank::default();
}
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
}
Err(_) => {
}
}
match std::fs::read_to_string(path) {
Ok(s) => {
if s.trim().is_empty() {
info!(path = %path.display(), "gene bank file is empty; starting fresh");
return PersistedGeneBank::default();
}
match serde_json::from_str::<serde_json::Value>(&s) {
Ok(serde_json::Value::Object(map)) => {
let has_numeric_schema =
map.get("schema").is_some_and(serde_json::Value::is_u64);
if has_numeric_schema {
let value = serde_json::Value::Object(map);
match serde_json::from_value::<PersistedGeneBank>(value) {
Ok(bank) => {
if bank.schema > 1 {
warn!(
path = %path.display(),
schema = bank.schema,
"gene bank has newer schema than expected (1); data may be incomplete"
);
}
bank
}
Err(e) => {
warn!(
path = %path.display(),
error = %e,
"gene bank malformed (schema-tagged object failed strict parse); starting fresh. Fix: inspect the file and fix the JSON syntax, or delete it to start over."
);
PersistedGeneBank::default()
}
}
} else {
let value = serde_json::Value::Object(map);
match serde_json::from_value::<HashMap<String, PersistedHostState>>(value) {
Ok(flat) => {
warn!(
path = %path.display(),
"loaded v0.1 gene-bank (flat HashMap); migrating to schema 1"
);
PersistedGeneBank {
schema: 1,
hosts: flat,
}
}
Err(e) => {
warn!(
path = %path.display(),
error = %e,
"gene bank malformed (v0.1 flat HashMap failed parse); starting fresh."
);
PersistedGeneBank::default()
}
}
}
}
Ok(_) => {
warn!(
path = %path.display(),
"gene bank malformed (top-level JSON is not an object); starting fresh."
);
PersistedGeneBank::default()
}
Err(e) => {
warn!(
path = %path.display(),
error = %e,
"gene bank malformed (invalid JSON); starting fresh. Fix: inspect the file and fix the JSON syntax, or delete it to start over."
);
PersistedGeneBank::default()
}
}
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
info!(path = %path.display(), "gene bank not found; starting fresh");
PersistedGeneBank::default()
}
Err(e) => {
warn!(
path = %path.display(),
error = %e,
"gene bank unreadable; starting fresh. Fix: check file permissions."
);
PersistedGeneBank::default()
}
}
}
/// Persist the learnable per-host state of `state` to `path` as a schema-1
/// gene bank.
///
/// The parent directory is created if needed. Hosts are omitted when they have
/// no proven winners, no blocklist entries, no WAF name, and zero blocks.
/// Only `proven_winners`, `blocklisted` and `waf_name` are written per host.
/// The JSON is pretty-printed and written through
/// `wafrift_types::loaders::write_atomic`.
///
/// # Errors
/// Returns any error from directory creation, serialization, or the write.
pub fn save(state: &ProxyState, path: &Path) -> std::io::Result<()> {
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir)?;
    }

    let hosts = state
        .hosts
        .iter()
        .filter(|(_, hs)| {
            !hs.proven_winners.is_empty()
                || !hs.blocklisted.is_empty()
                || hs.waf_name.is_some()
                || hs.blocks != 0
        })
        .map(|(host, hs)| {
            let snapshot = PersistedHostState {
                proven_winners: hs.proven_winners.clone(),
                blocklisted: hs.blocklisted.clone(),
                waf_name: hs.waf_name.clone(),
            };
            (host.clone(), snapshot)
        })
        .collect();

    let bank = PersistedGeneBank { schema: 1, hosts };
    let encoded = serde_json::to_string_pretty(&bank)?;
    wafrift_types::loaders::write_atomic(path, encoded.as_bytes())?;
    Ok(())
}
/// Merge a loaded gene bank into `state`.
///
/// Returns the number of hosts whose proven winners were restored.
///
/// Capping:
/// * A host not yet in `state.hosts` is skipped once the table holds
///   [`MAX_RESTORED_HOSTS`] entries.
/// * Hosts already present are always merged.
///
/// Per-field merge:
/// * Non-empty `proven_winners` replaces the current list and marks discovery
///   complete.
/// * A non-empty `blocklisted` replaces the current blocklist.
/// * A `waf_name` replaces the current one and marks the WAF as confirmed.
/// * Empty fields leave the existing state untouched.
///
/// Every merged host is queued on `state.host_fifo` at most once. If the table
/// still exceeds the cap afterwards, which can only happen when it was already
/// over the cap on entry, the oldest FIFO entries are evicted.
/// NOTE(review): the returned count is taken before that eviction pass, so it
/// may include hosts that were evicted.
pub fn restore(state: &mut ProxyState, bank: PersistedGeneBank) -> usize {
    let mut with_winners = 0usize;
    // Hosts already queued for eviction, so none is pushed onto the FIFO twice.
    let mut queued: std::collections::HashSet<String> =
        state.host_fifo.iter().cloned().collect();

    for (host, persisted) in bank.hosts {
        let table_full = state.hosts.len() >= MAX_RESTORED_HOSTS;
        if table_full && !state.hosts.contains_key(&host) {
            continue;
        }

        let slot = state.hosts.entry(host.clone()).or_default();
        if !persisted.proven_winners.is_empty() {
            slot.proven_winners = persisted.proven_winners;
            slot.discovery_complete = true;
            with_winners += 1;
        }
        if !persisted.blocklisted.is_empty() {
            slot.blocklisted = persisted.blocklisted;
        }
        if let Some(name) = persisted.waf_name {
            slot.waf_name = Some(name);
            slot.waf_confirmed = true;
        }

        if queued.insert(host.clone()) {
            state.host_fifo.push_back(host);
        }
    }

    while state.hosts.len() > MAX_RESTORED_HOSTS {
        let Some(oldest) = state.host_fifo.pop_front() else {
            break;
        };
        state.hosts.remove(&oldest);
    }

    with_winners
}
#[cfg(test)]
mod tests {
    // Unit tests for path resolution, tolerant loading (missing / empty / legacy /
    // malformed / oversized files), the save→load round trip, and restore capping.
    // Each test uses a per-process temp file name so parallel tests do not collide.
    use super::*;
    #[test]
    fn default_path_off_returns_none() {
        assert_eq!(default_gene_bank_path("off"), None);
        assert_eq!(default_gene_bank_path("-"), None);
    }
    #[test]
    fn default_path_explicit_returns_pathbuf() {
        let p = default_gene_bank_path("/tmp/custom.json").expect("explicit ok");
        assert_eq!(p, PathBuf::from("/tmp/custom.json"));
    }
    #[test]
    fn load_missing_file_returns_empty_bank() {
        let path = std::env::temp_dir().join(format!(
            "wafrift-genebank-load-missing-{}",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let bank = load(&path);
        assert_eq!(bank.schema, 0);
        assert!(bank.hosts.is_empty());
    }
    #[test]
    fn load_empty_file_returns_empty_bank() {
        let path = std::env::temp_dir().join(format!(
            "wafrift-genebank-load-empty-{}",
            std::process::id()
        ));
        std::fs::write(&path, "").unwrap();
        let bank = load(&path);
        assert!(bank.hosts.is_empty());
        let _ = std::fs::remove_file(&path);
    }
    #[test]
    fn load_v01_flat_format_migrates_to_schema_1() {
        let path =
            std::env::temp_dir().join(format!("wafrift-genebank-load-v01-{}", std::process::id()));
        let legacy = r#"{
            "example.com": {
                "proven_winners": ["encoding::Double"],
                "blocklisted": [],
                "waf_name": "Cloudflare"
            }
        }"#;
        std::fs::write(&path, legacy).unwrap();
        let bank = load(&path);
        assert_eq!(bank.schema, 1);
        let host = bank.hosts.get("example.com").expect("example.com migrated");
        assert_eq!(host.proven_winners, vec!["encoding::Double".to_string()]);
        assert_eq!(host.waf_name.as_deref(), Some("Cloudflare"));
        let _ = std::fs::remove_file(&path);
    }
    #[test]
    fn load_v01_host_literally_named_schema_is_not_misread_as_v1_tag() {
        let path = std::env::temp_dir().join(format!(
            "wafrift-genebank-load-host-named-schema-{}",
            std::process::id()
        ));
        let legacy = r#"{
            "schema": {
                "proven_winners": ["encoding::Hex"],
                "blocklisted": [],
                "waf_name": "AWS"
            }
        }"#;
        std::fs::write(&path, legacy).unwrap();
        let bank = load(&path);
        assert_eq!(
            bank.schema, 1,
            "must migrate to schema 1, not treat 'schema' as a version tag"
        );
        let host = bank
            .hosts
            .get("schema")
            .expect("host named 'schema' must survive migration");
        assert_eq!(host.proven_winners, vec!["encoding::Hex".to_string()]);
        assert_eq!(host.waf_name.as_deref(), Some("AWS"));
        let _ = std::fs::remove_file(&path);
    }
    #[test]
    fn load_malformed_json_returns_empty_bank_does_not_panic() {
        let path = std::env::temp_dir().join(format!(
            "wafrift-genebank-load-malformed-{}",
            std::process::id()
        ));
        std::fs::write(&path, "{ not valid json").unwrap();
        let bank = load(&path);
        assert!(bank.hosts.is_empty());
        let _ = std::fs::remove_file(&path);
    }
    #[test]
    fn load_oversized_file_returns_empty_bank_does_not_oom() {
        let path = std::env::temp_dir().join(format!(
            "wafrift-genebank-load-oversize-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        // Sparse file: reports MAX + 1 bytes without actually writing them.
        let f = std::fs::File::create(&path).unwrap();
        f.set_len(MAX_GENE_BANK_BYTES + 1).unwrap();
        drop(f);
        let bank = load(&path);
        assert_eq!(
            bank.schema, 0,
            "oversize file must return default empty bank, not partial parse"
        );
        assert!(bank.hosts.is_empty());
        let _ = std::fs::remove_file(&path);
    }
    #[test]
    fn save_then_load_round_trips_schema_tagged_bank() {
        // Exercises the primary (schema-tagged) load path. It also checks that
        // whatever `save` writes is accepted by `load`'s strict parse.
        let path = std::env::temp_dir().join(format!(
            "wafrift-genebank-roundtrip-{}",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        let mut bank = PersistedGeneBank {
            schema: 1,
            hosts: HashMap::new(),
        };
        bank.hosts.insert(
            "roundtrip.example".to_string(),
            PersistedHostState {
                proven_winners: vec!["encoding::Double".into()],
                blocklisted: vec![],
                waf_name: Some("Cloudflare".to_string()),
            },
        );
        let mut state = ProxyState::default();
        assert_eq!(restore(&mut state, bank), 1);
        save(&state, &path).expect("save to temp dir");
        let loaded = load(&path);
        let _ = std::fs::remove_file(&path);
        assert_eq!(loaded.schema, 1, "save must write a schema-1 tagged bank");
        let host = loaded
            .hosts
            .get("roundtrip.example")
            .expect("host must survive save -> load");
        assert_eq!(host.proven_winners, vec!["encoding::Double".to_string()]);
        assert_eq!(host.waf_name.as_deref(), Some("Cloudflare"));
    }
    #[test]
    fn restore_caps_hosts_during_loop_not_only_at_end() {
        let mut bank = PersistedGeneBank {
            schema: 1,
            hosts: HashMap::new(),
        };
        for i in 0..(MAX_RESTORED_HOSTS + 50) {
            bank.hosts.insert(
                format!("h{i}.example"),
                PersistedHostState {
                    proven_winners: vec!["url_encode".into()],
                    blocklisted: vec![],
                    waf_name: None,
                },
            );
        }
        let mut state = ProxyState::default();
        let restored = restore(&mut state, bank);
        assert!(
            state.hosts.len() <= MAX_RESTORED_HOSTS,
            "restore must never leave state.hosts above the cap (saw {})",
            state.hosts.len()
        );
        assert!(
            restored <= MAX_RESTORED_HOSTS,
            "restore must not report more entries than the cap"
        );
    }
}