#![allow(non_snake_case)]
use std::io::Cursor;
use triviumdb::Database;
use triviumdb::storage::wal::{Wal, WalEntry};
const DIM: usize = 4;
/// Returns the path (as a `String`) of a scratch database named
/// `midwrite_<name>` inside the `triviumdb_test` directory under the system
/// temp dir.
///
/// The directory is created on a best-effort basis; a creation failure is
/// ignored here.
fn tmp_db(name: &str) -> String {
    let mut target = std::env::temp_dir();
    target.push("triviumdb_test");
    let _ = std::fs::create_dir_all(&target);
    target.push(format!("midwrite_{}", name));
    target.to_string_lossy().into_owned()
}
/// Removes the file at `path` and its `.wal`, `.vec`, `.lock` and `.flush_ok`
/// siblings. Removal errors (including "file not found") are ignored.
fn cleanup(path: &str) {
    const SUFFIXES: [&str; 5] = ["", ".wal", ".vec", ".lock", ".flush_ok"];
    SUFFIXES.iter().for_each(|suffix| {
        let _ = std::fs::remove_file(format!("{}{}", path, suffix));
    });
}
/// Builds an in-memory WAL image holding five entries: three inserts, one
/// link and one payload update.
///
/// Each entry is written as one frame: the bincode-serialized entry length as
/// a little-endian `u32`, the serialized bytes, then the CRC32 of those bytes
/// as a little-endian `u32`.
fn build_valid_wal_bytes() -> Vec<u8> {
    let alice = WalEntry::Insert {
        id: 1,
        vector: vec![1.0, 0.0, 0.0, 0.0],
        payload: String::from(r#"{"name":"alice"}"#),
    };
    let bob = WalEntry::Insert {
        id: 2,
        vector: vec![0.0, 1.0, 0.0, 0.0],
        payload: String::from(r#"{"name":"bob"}"#),
    };
    let alice_knows_bob = WalEntry::Link {
        src: 1,
        dst: 2,
        label: String::from("knows"),
        weight: 1.0,
    };
    let charlie = WalEntry::Insert {
        id: 3,
        vector: vec![0.0, 0.0, 1.0, 0.0],
        payload: String::from(r#"{"name":"charlie"}"#),
    };
    let alice_aged = WalEntry::UpdatePayload {
        id: 1,
        payload: String::from(r#"{"name":"alice","age":30}"#),
    };
    let records: Vec<WalEntry<f32>> = vec![alice, bob, alice_knows_bob, charlie, alice_aged];

    records
        .iter()
        .flat_map(|record| {
            let body = bincode::serialize(record).unwrap();
            let checksum = crc32fast::hash(&body);
            let body_len = body.len() as u32;

            let mut frame = Vec::with_capacity(body.len() + 8);
            frame.extend_from_slice(&body_len.to_le_bytes());
            frame.extend_from_slice(&body);
            frame.extend_from_slice(&checksum.to_le_bytes());
            frame
        })
        .collect()
}
/// Truncates a known-good WAL image at every byte offset and checks, for each
/// prefix, that the reader:
///
/// * does not panic and returns `Ok`,
/// * never yields more than the five entries of the full image, and
/// * never yields fewer entries than it did for the prefix one byte shorter.
#[test]
fn WAL_逐字节截断_覆盖每个断电时间点() {
    let full_wal = build_valid_wal_bytes();
    let full_len = full_wal.len();
    let (full_entries, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(&full_wal)).unwrap();
    assert_eq!(
        full_entries.len(),
        5,
        "完整 WAL 应包含 5 条记录,实际 {}",
        full_entries.len()
    );

    // Entry count recovered from the prefix that is one byte shorter. It is
    // carried across iterations so that each prefix is parsed once instead of
    // being re-parsed as the "previous" prefix on the following iteration.
    let mut prev_count = None;
    for cut_at in 0..full_len {
        let truncated = &full_wal[..cut_at];
        let result = std::panic::catch_unwind(|| {
            Wal::read_entries_from_reader::<f32>(Cursor::new(truncated))
        });
        assert!(
            result.is_ok(),
            "WAL 在偏移 {}/{} 处截断后 read_entries_from_reader panic 了!",
            cut_at,
            full_len
        );
        // Outer unwrap: no panic happened. Inner unwrap: the reader returned Ok.
        let (entries, _) = result.unwrap().unwrap();
        let recovered = entries.len();
        assert!(
            recovered <= 5,
            "截断在 {}/{} 处,条目数 {} 超过了原始 5 条",
            cut_at,
            full_len,
            recovered
        );
        // Monotonicity: a longer prefix must not recover fewer entries.
        if let Some(prev) = prev_count {
            assert!(
                recovered >= prev,
                "单调性违反:截断在 {} 处恢复 {} 条,但在 {} 处恢复了 {} 条",
                cut_at,
                recovered,
                cut_at - 1,
                prev
            );
        }
        prev_count = Some(recovered);
    }
    eprintln!(
        " ✅ WAL 逐字节截断: 在 {}/{} 个断电点上均安全恢复,零 panic",
        full_len, full_len
    );
}
/// Writes two standalone inserts followed by one transaction
/// (`TxBegin`, three inserts, `TxCommit`) and truncates the image at every
/// offset strictly inside the transaction region.
///
/// The full image must yield 5 entries; every truncated image must yield
/// either 2 or 5 entries, never a partial transaction.
#[test]
fn WAL_事务边界截断_未提交事务必须被丢弃() {
    /// Appends one `[len: u32 LE][bincode bytes][crc32: u32 LE]` frame to `out`.
    fn append_frame(out: &mut Vec<u8>, entry: &WalEntry<f32>) {
        let body = bincode::serialize(entry).unwrap();
        let checksum = crc32fast::hash(&body);
        out.extend_from_slice(&(body.len() as u32).to_le_bytes());
        out.extend_from_slice(&body);
        out.extend_from_slice(&checksum.to_le_bytes());
    }

    /// Builds the insert entry used throughout this test for a given id.
    fn insert_for(id: u64) -> WalEntry<f32> {
        WalEntry::Insert {
            id,
            vector: vec![id as f32, 0.0, 0.0, 0.0],
            payload: format!(r#"{{"id":{}}}"#, id),
        }
    }

    // Region 1: two inserts written outside any transaction.
    let mut wal = Vec::new();
    for id in 1..=2u64 {
        append_frame(&mut wal, &insert_for(id));
    }
    let committed_boundary = wal.len();

    // Region 2: a single transaction wrapping inserts 3, 4 and 5.
    let mut tx_records: Vec<WalEntry<f32>> = vec![WalEntry::TxBegin { tx_id: 42 }];
    tx_records.extend((3..=5u64).map(insert_for));
    tx_records.push(WalEntry::TxCommit { tx_id: 42 });
    for record in &tx_records {
        append_frame(&mut wal, record);
    }
    let full_len = wal.len();

    let (full, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(&wal)).unwrap();
    assert_eq!(full.len(), 5, "完整事务应恢复 5 条记录");

    for cut_at in (committed_boundary + 1)..full_len {
        let prefix = &wal[..cut_at];
        let (recovered, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(prefix)).unwrap();
        let count = recovered.len();
        assert!(
            matches!(count, 2 | 5),
            "截断在 {}/{} 处(事务区域内),恢复了 {} 条记录。\
             应该只有 2(丢弃未提交事务)或 5(事务完整提交)",
            cut_at,
            full_len,
            count
        );
    }
    eprintln!(
        " ✅ 事务边界截断: 在事务区域 {} 个断电点上,未提交事务均被正确丢弃",
        full_len - committed_boundary - 1
    );
}
/// Builds a single WAL frame (`len | data | crc`) and truncates it at every
/// offset from 1 to `total - 1`, expecting the reader to return zero entries
/// each time; the untruncated frame must return exactly one entry.
///
/// The cut positions are grouped by the field the cut falls in:
///
/// * `1..4`                 – inside the 4-byte length field,
/// * `data_start..crc_start` – length complete, data missing or partial,
/// * `crc_start..total`      – data complete, CRC missing or partial.
///
/// The second and third ranges start at the field boundary itself (offset 4
/// and offset `crc_start`), so the "field fully written, next field absent"
/// cases are exercised and the number of tested points equals `total - 1`,
/// matching the summary line printed at the end.
#[test]
fn WAL_单帧字段边界截断_len_data_crc各字段() {
    let entry = WalEntry::Insert::<f32> {
        id: 1,
        vector: vec![1.0, 2.0, 3.0, 4.0],
        payload: r#"{"key":"value","nested":{"a":1}}"#.to_string(),
    };
    let data = bincode::serialize(&entry).unwrap();
    let crc = crc32fast::hash(&data);
    let len = data.len() as u32;

    let mut frame = Vec::with_capacity(data.len() + 8);
    frame.extend_from_slice(&len.to_le_bytes());
    frame.extend_from_slice(&data);
    frame.extend_from_slice(&crc.to_le_bytes());

    let total = frame.len();
    let data_start = 4;
    let crc_start = 4 + data.len();
    eprintln!(
        " 📊 帧结构: total={} bytes, len=[0..4), data=[4..{}), crc=[{}..{})",
        total, crc_start, crc_start, total
    );

    // Cut inside the length field.
    for cut in 1..data_start {
        let truncated = &frame[..cut];
        let (entries, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(truncated)).unwrap();
        assert_eq!(
            entries.len(),
            0,
            "len 字段内截断(offset={})应返回 0 条记录",
            cut
        );
    }

    // Length field complete; data absent (cut == data_start) or partial.
    for cut in data_start..crc_start {
        let truncated = &frame[..cut];
        let (entries, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(truncated)).unwrap();
        assert_eq!(
            entries.len(),
            0,
            "data 字段内截断(offset={})应返回 0 条记录",
            cut
        );
    }

    // Data complete; CRC absent (cut == crc_start) or partial.
    for cut in crc_start..total {
        let truncated = &frame[..cut];
        let (entries, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(truncated)).unwrap();
        assert_eq!(
            entries.len(),
            0,
            "CRC 字段内截断(offset={})应返回 0 条记录",
            cut
        );
    }

    let (entries, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(&frame)).unwrap();
    assert_eq!(entries.len(), 1, "完整帧应恢复 1 条记录");
    eprintln!(
        " ✅ 单帧字段边界: len/data/crc 每个截断点均安全(共 {} 个点)",
        total - 1
    );
}
/// Feeds 10,000 deterministic pseudo-random byte strings (each 0 to 1023 bytes
/// long) to the WAL reader and requires that none of them makes it panic.
///
/// Only panics are counted; whether the reader returns `Ok` or `Err` for a
/// given input is not checked. Bytes are derived from `DefaultHasher` applied
/// to the round number and byte index, so runs are reproducible.
#[test]
fn WAL_解析器_10000轮随机字节绝不panic() {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Hashes `value` with a freshly constructed `DefaultHasher`.
    fn digest<T: Hash>(value: &T) -> u64 {
        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        state.finish()
    }

    let rounds = 10_000;
    let mut panic_count = 0;
    for seed in 0..rounds {
        // Input length comes from the seed alone; each byte from (seed, index).
        let len = (digest(&seed) % 1024) as usize;
        let garbage: Vec<u8> = (0..len).map(|i| digest(&(seed, i)) as u8).collect();

        let outcome = std::panic::catch_unwind(|| {
            Wal::read_entries_from_reader::<f32>(Cursor::new(&garbage))
        });
        if outcome.is_ok() {
            continue;
        }

        panic_count += 1;
        // Show at most the first 32 bytes of the offending input.
        let shown = garbage.len().min(32);
        eprintln!(
            " ❌ Panic at seed={}, len={}, bytes={:?}",
            seed,
            len,
            &garbage[..shown]
        );
    }
    assert_eq!(
        panic_count, 0,
        "{}/{} 轮随机字节触发了 WAL 解析器 panic!",
        panic_count, rounds
    );
    eprintln!(
        " ✅ WAL 解析器: {}/{} 轮随机字节,零 panic",
        rounds, rounds
    );
}
/// Appends several kinds of garbage to a valid WAL image and checks that the
/// reader neither panics nor loses any entry of the valid prefix: the entry
/// count must equal the count read from the clean image.
#[test]
fn WAL_有效前缀加垃圾尾部_前缀条目必须完整恢复() {
    let valid_wal = build_valid_wal_bytes();
    let (baseline, _) = Wal::read_entries_from_reader::<f32>(Cursor::new(&valid_wal)).unwrap();
    let baseline_count = baseline.len();

    // A bare length header announcing 0xFFFFFFFF bytes, with nothing after it.
    let huge_len_header = 0xFFFFFFFFu32.to_le_bytes().to_vec();

    // A header announcing 100 bytes followed by only 50 bytes.
    let short_body = {
        let mut tail = 100u32.to_le_bytes().to_vec();
        tail.extend_from_slice(&[0xAA; 50]);
        tail
    };

    // A well-formed insert frame whose trailing CRC is deliberately wrong.
    let bad_crc_frame = {
        let poison = WalEntry::Insert::<f32> {
            id: 999,
            vector: vec![9.0, 9.0, 9.0, 9.0],
            payload: r#"{"poison":true}"#.to_string(),
        };
        let body = bincode::serialize(&poison).unwrap();
        let bad_crc = 0xDEADBEEFu32;
        let mut tail = (body.len() as u32).to_le_bytes().to_vec();
        tail.extend_from_slice(&body);
        tail.extend_from_slice(&bad_crc.to_le_bytes());
        tail
    };

    let garbage_patterns: Vec<Vec<u8>> = vec![
        vec![0xFF; 64],
        vec![0x00; 64],
        vec![0xDE, 0xAD, 0xBE, 0xEF],
        b"INVALID WAL ENTRY GARBAGE DATA HERE!!!".to_vec(),
        huge_len_header,
        short_body,
        bad_crc_frame,
    ];

    for (i, tail) in garbage_patterns.iter().enumerate() {
        let corrupted: Vec<u8> = valid_wal.iter().chain(tail.iter()).copied().collect();
        let outcome = std::panic::catch_unwind(|| {
            Wal::read_entries_from_reader::<f32>(Cursor::new(&corrupted))
        });
        assert!(outcome.is_ok(), "垃圾 pattern #{} 导致解析器 panic", i);
        let (entries, _) = outcome.unwrap().unwrap();
        let recovered = entries.len();
        assert_eq!(
            recovered,
            baseline_count,
            "垃圾 pattern #{}: 有效前缀的 {} 条记录应完整恢复,实际 {}",
            i,
            baseline_count,
            recovered
        );
    }
    eprintln!(
        " ✅ 有效前缀+垃圾尾部: {} 种 pattern 下 {} 条前缀记录均完整恢复",
        garbage_patterns.len(),
        baseline_count
    );
}
/// End-to-end check against a real on-disk database.
///
/// Five nodes are inserted and flushed, then three more are inserted without
/// a flush before the handle is dropped. The resulting `.wal` file is then
/// rewritten with prefixes of its original contents at a sample of offsets,
/// and the database is reopened after each rewrite. Reopening must not panic;
/// if it succeeds, the node count must lie in `5..=8`. An `Err` from `open` is
/// logged and accepted.
///
/// If the `.wal` file is missing or empty after the setup phase, the test
/// logs a warning and returns without checking anything.
#[test]
fn WAL_端到端_文件物理截断后Database重新加载() {
    let path = tmp_db("e2e_truncate");
    cleanup(&path);

    // Setup phase; the handle is dropped at the end of this scope.
    {
        let mut db = Database::<f32>::open(&path, DIM).unwrap();
        for seq in 0..5u32 {
            let payload = serde_json::json!({"phase": "flushed", "seq": seq});
            db.insert(&[seq as f32, 0.0, 0.0, 0.0], payload).unwrap();
        }
        db.flush().unwrap();
        for seq in 5..8u32 {
            let payload = serde_json::json!({"phase": "wal_only", "seq": seq});
            db.insert(&[seq as f32, 0.0, 0.0, 0.0], payload).unwrap();
        }
    }

    let wal_path = format!("{}.wal", path);
    let wal_bytes = std::fs::read(&wal_path).unwrap_or_default();
    let wal_len = wal_bytes.len();
    if wal_len == 0 {
        eprintln!(" ⚠️ WAL 文件为空(可能 Drop 时未写入),跳过截断测试");
        cleanup(&path);
        return;
    }

    // Sample every 10th offset plus a handful of hand-picked ones (the first
    // few bytes, the quartiles, and the last byte), sorted and de-duplicated.
    let landmarks = [
        1,
        2,
        3,
        4,
        5,
        wal_len / 4,
        wal_len / 2,
        wal_len * 3 / 4,
        wal_len - 1,
    ];
    let mut cut_points: Vec<usize> = (0..wal_len)
        .step_by(10)
        .chain(landmarks.iter().copied().filter(|&offset| offset < wal_len))
        .collect();
    cut_points.sort_unstable();
    cut_points.dedup();

    for &cut_at in &cut_points {
        std::fs::write(&wal_path, &wal_bytes[..cut_at]).unwrap();
        std::fs::remove_file(format!("{}.flush_ok", path)).ok();

        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Database::<f32>::open(&path, DIM)
        }));
        assert!(
            outcome.is_ok(),
            "WAL 截断到 {}/{} 字节后 Database::open panic 了!",
            cut_at,
            wal_len
        );

        let db = match outcome.unwrap() {
            Ok(db) => db,
            Err(e) => {
                eprintln!(
                    " ⚠️ WAL 截断到 {}/{} 后加载失败(可接受): {}",
                    cut_at, wal_len, e
                );
                continue;
            }
        };

        let nodes = db.node_count();
        assert!(
            nodes >= 5,
            "WAL 截断到 {}/{} 后节点数 {} < 5(已 flush 的节点丢失了)",
            cut_at,
            wal_len,
            nodes
        );
        assert!(
            nodes <= 8,
            "WAL 截断到 {}/{} 后节点数 {} > 8(凭空多出节点)",
            cut_at,
            wal_len,
            nodes
        );
    }
    eprintln!(
        " ✅ 端到端 WAL 截断: 在 {} 个采样截断点上均安全恢复或优雅拒绝",
        cut_points.len()
    );

    // Put the original WAL bytes back, then remove every file of the scratch DB.
    std::fs::write(&wal_path, &wal_bytes).ok();
    cleanup(&path);
}