// Several test names embed uppercase ASCII words (`Database`, `Send`, `Sync`,
// `TQL`), which would otherwise trigger the non_snake_case lint.
#![allow(non_snake_case)]
use std::sync::Arc;
use triviumdb::Database;
// Vector dimension passed to every `Database::open` in this file; all inserted
// and queried vectors are 4-element arrays.
const DIM: usize = 4;
/// Builds the on-disk path `<system temp>/triviumdb_test/concurrent_<name>` for a
/// test database and returns it as a `String`.
///
/// The parent directory is created if missing; failure to create it is ignored
/// here and will surface later when the database is opened.
fn tmp_db(name: &str) -> String {
    let base = std::env::temp_dir().join("triviumdb_test");
    let _ = std::fs::create_dir_all(&base);
    let file = base.join(format!("concurrent_{}", name));
    file.to_string_lossy().into_owned()
}
/// Best-effort removal of the database file at `path` and its sidecar files
/// (`.wal`, `.vec`, `.lock`, `.flush_ok`).
///
/// Missing files and any other removal errors are deliberately ignored.
fn cleanup(path: &str) {
    const SUFFIXES: [&str; 5] = ["", ".wal", ".vec", ".lock", ".flush_ok"];
    SUFFIXES
        .iter()
        .map(|suffix| format!("{}{}", path, suffix))
        .for_each(|file| {
            let _ = std::fs::remove_file(file);
        });
}
/// Opens the database at `path` with dimension `DIM` and inserts `count` nodes.
///
/// Node `i` gets the vector `[i, sin(i), cos(i), 1.0]` and the payload
/// `{"idx": i, "tag": "node_<i>"}`. The database is flushed before it is
/// returned. Panics on any database error.
fn seed_db(path: &str, count: usize) -> Database<f32> {
    let mut database = Database::<f32>::open(path, DIM).unwrap();
    (0..count).for_each(|i| {
        let x = i as f32;
        let vector = [x, x.sin(), x.cos(), 1.0];
        let payload = serde_json::json!({"idx": i, "tag": format!("node_{}", i)});
        database.insert(&vector, payload).unwrap();
    });
    database.flush().unwrap();
    database
}
/// Eight threads share one `Arc<Database>` and run read-only calls
/// (`node_count`, `all_node_ids`, `get_payload`, `search`) for 200 rounds each.
///
/// No thread may panic, and every round must see a self-consistent view: the
/// id list matches the node count, and every listed id resolves to a payload.
#[test]
fn 并发_8线程同时只读查询_不恐慌不矛盾() {
    let path = tmp_db("multiread");
    cleanup(&path);
    let db = Arc::new(seed_db(&path, 100));

    let workers: Vec<_> = (0..8)
        .map(|thread_id| {
            let shared = Arc::clone(&db);
            std::thread::spawn(move || {
                for round in 0..200 {
                    let count = shared.node_count();
                    assert!(
                        count > 0,
                        "线程 {} 轮 {}: node_count 返回 0",
                        thread_id,
                        round
                    );

                    let ids = shared.all_node_ids();
                    assert_eq!(
                        ids.len(),
                        count,
                        "线程 {} 轮 {}: all_node_ids.len()={} != node_count={}",
                        thread_id,
                        round,
                        ids.len(),
                        count
                    );
                    // Every listed id must resolve to a payload.
                    for &id in ids.iter() {
                        assert!(
                            shared.get_payload(id).is_some(),
                            "线程 {} 轮 {}: id {} 在 all_node_ids 中但 get_payload 返回 None",
                            thread_id,
                            round,
                            id
                        );
                    }

                    // Each thread probes with a slightly different query vector.
                    let probe = [(thread_id as f32) * 0.1, 0.0, 0.0, 1.0];
                    let outcome = shared.search(&probe, 5, 0, 0.0);
                    assert!(
                        outcome.is_ok(),
                        "线程 {} 轮 {}: search 返回 Err: {:?}",
                        thread_id,
                        round,
                        outcome.err()
                    );
                }
            })
        })
        .collect();

    // Join sequentially, reporting each panicked worker as it is found.
    let panic_count = workers
        .into_iter()
        .enumerate()
        .filter_map(|(i, worker)| worker.join().err().map(|e| (i, e)))
        .inspect(|(i, e)| eprintln!(" ❌ 线程 {} panic: {:?}", i, e))
        .count();
    assert_eq!(
        panic_count, 0,
        "{}/8 个只读线程发生了 panic!Database 的并发读不安全",
        panic_count
    );
    eprintln!(" ✅ 8 线程并发只读: 8×200=1600 轮操作,零 panic,数据一致");
    // All workers are joined, so this is the last `Arc`: the database closes here.
    drop(db);
    cleanup(&path);
}
/// One writer thread appends nodes while four reader threads poll the same
/// `Mutex`-guarded database.
///
/// Every reader observation must be self-consistent
/// (`all_node_ids().len() == node_count()`), and the node count must never
/// shrink. Once the writer finishes, the database must hold exactly the seeded
/// nodes plus one node per write.
#[test]
fn 并发_1写4读_写入期间读取一致性() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    // Nodes inserted by `seed_db` before any thread starts.
    const SEED_COUNT: usize = 50;
    // The writer inserts one node per `seq` in `FIRST_SEQ..END_SEQ`.
    const FIRST_SEQ: u32 = 100;
    const END_SEQ: u32 = 300;

    let path = tmp_db("write_read_mix");
    cleanup(&path);
    let db = Arc::new(Mutex::new(seed_db(&path, SEED_COUNT)));
    // Raised once the writer is done so readers stop polling. Relaxed suffices:
    // the flag carries no data, and every database access goes through the mutex.
    let stop = Arc::new(AtomicBool::new(false));

    // The writer is never asked to stop early: `stop` is only raised after the
    // writer has been joined, so every write is expected to land.
    let writer = {
        let db = Arc::clone(&db);
        std::thread::spawn(move || {
            for i in FIRST_SEQ..END_SEQ {
                let mut db = db.lock().unwrap();
                let result = db.insert(
                    &[i as f32, 0.0, 0.0, 0.0],
                    serde_json::json!({"writer": true, "seq": i}),
                );
                assert!(result.is_ok(), "写线程 insert 失败: {:?}", result.err());
            }
        })
    };

    let mut readers = Vec::new();
    for reader_id in 0..4 {
        let db = Arc::clone(&db);
        let stop = Arc::clone(&stop);
        readers.push(std::thread::spawn(move || {
            let mut prev_count = 0usize;
            let mut rounds = 0;
            // Bounded so readers terminate even if `stop` is never raised.
            while !stop.load(Ordering::Relaxed) && rounds < 500 {
                let db = db.lock().unwrap();
                let count = db.node_count();
                assert!(
                    count >= prev_count,
                    "读线程 {} 轮 {}: 节点数从 {} 降到了 {}!数据一致性被破坏",
                    reader_id,
                    rounds,
                    prev_count,
                    count
                );
                prev_count = count;
                let ids = db.all_node_ids();
                assert_eq!(
                    ids.len(),
                    count,
                    "读线程 {} 轮 {}: 不一致",
                    reader_id,
                    rounds
                );
                // Release the lock before yielding so the writer can make progress.
                drop(db);
                std::thread::yield_now();
                rounds += 1;
            }
        }));
    }

    // Don't unwrap the writer's result yet. Readers must be released and joined
    // first, so a writer failure doesn't leave them detached and unreported.
    let writer_result = writer.join();
    stop.store(true, Ordering::Relaxed);
    let mut panic_count = 0;
    for (i, h) in readers.into_iter().enumerate() {
        if let Err(e) = h.join() {
            panic_count += 1;
            eprintln!(" ❌ 读线程 {} panic: {:?}", i, e);
        }
    }
    if let Err(e) = &writer_result {
        eprintln!(" ❌ 写线程 panic: {:?}", e);
    }
    assert!(writer_result.is_ok(), "写线程 panic,读写混合并发测试失败");
    assert_eq!(panic_count, 0, "读写混合并发测试失败");

    // Every write completed, so the final count is known exactly.
    let expected = SEED_COUNT + (END_SEQ - FIRST_SEQ) as usize;
    // The temporary guard is released at the end of this statement.
    let final_count = db.lock().unwrap().node_count();
    assert_eq!(
        final_count, expected,
        "最终节点数 {} != 预期 {}(初始 {} + 写入 {})",
        final_count, expected, SEED_COUNT, expected - SEED_COUNT
    );
    eprintln!(
        " ✅ 1 写 4 读并发: 最终 {} 个节点,节点数单调递增,数据一致",
        final_count
    );
    // Every thread has been joined, so this is the last `Arc`. Dropping it (not
    // just a guard) closes the database before its files are removed.
    drop(db);
    cleanup(&path);
}
/// Compile-time check that `Database<f32>` can be moved to another thread
/// (`Send`) and shared between threads (`Sync`).
///
/// If either bound is missing, this test fails to compile rather than failing
/// at runtime.
#[test]
fn 并发_编译期验证_Database是Send和Sync() {
    fn require_thread_safe<T: Send + Sync>() {}
    require_thread_safe::<Database<f32>>();
    eprintln!(" ✅ Database<f32> 实现了 Send + Sync(编译期验证通过)");
}
/// Eight threads each run a different TQL query 100 times against one shared
/// database.
///
/// Every call must complete without panicking and must return `Ok`. Before this
/// check, a query that errored on every call would have passed silently.
#[test]
fn 并发_8线程TQL查询_结果互不干扰() {
    let path = tmp_db("tql_concurrent");
    cleanup(&path);
    let mut db = Database::<f32>::open(&path, DIM).unwrap();
    // Three people and one project, inserted in a single transaction.
    let ids = {
        let mut tx = db.begin_tx();
        tx.insert(
            &[1.0, 0.0, 0.0, 0.0],
            serde_json::json!({"name": "Alice", "type": "person"}),
        );
        tx.insert(
            &[0.0, 1.0, 0.0, 0.0],
            serde_json::json!({"name": "Bob", "type": "person"}),
        );
        tx.insert(
            &[0.0, 0.0, 1.0, 0.0],
            serde_json::json!({"name": "Charlie", "type": "person"}),
        );
        tx.insert(
            &[0.0, 0.0, 0.0, 1.0],
            serde_json::json!({"name": "Project X", "type": "project"}),
        );
        tx.commit().unwrap()
    };
    db.link(ids[0], ids[1], "knows", 1.0).unwrap();
    db.link(ids[1], ids[2], "knows", 1.0).unwrap();
    db.link(ids[0], ids[3], "works_on", 1.0).unwrap();
    let db = Arc::new(db);
    // One query per thread, mixing payload filters and vector searches.
    let queries = [
        r#"FIND {"type": "person"} RETURN *"#,
        r#"FIND {"name": "Alice"} RETURN *"#,
        r#"FIND {"type": "project"} RETURN *"#,
        "SEARCH [1.0, 0.0, 0.0, 0.0] TOP 2 RETURN *",
        "SEARCH [0.0, 1.0, 0.0, 0.0] TOP 3 RETURN *",
        r#"FIND {"type": "person"} LIMIT 2 RETURN *"#,
        r#"FIND {"name": "Bob"} RETURN *"#,
        "SEARCH [0.0, 0.0, 1.0, 0.0] TOP 1 RETURN *",
    ];
    let mut handles = vec![];
    for (thread_id, query) in queries.iter().enumerate() {
        let db = Arc::clone(&db);
        let query = query.to_string();
        handles.push(std::thread::spawn(move || {
            for round in 0..100 {
                // catch_unwind lets a panic be reported with thread/round/query context.
                let outcome =
                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| db.tql(&query)));
                match outcome {
                    Ok(Ok(_)) => {}
                    Ok(Err(err)) => panic!(
                        "线程 {} 轮 {} TQL 返回 Err: query={:?}, err={:?}",
                        thread_id, round, query, err
                    ),
                    Err(_) => panic!(
                        "线程 {} 轮 {} TQL panic: query={:?}",
                        thread_id, round, query
                    ),
                }
            }
        }));
    }
    let mut panic_count = 0;
    for (i, h) in handles.into_iter().enumerate() {
        if let Err(e) = h.join() {
            panic_count += 1;
            eprintln!(" ❌ TQL 线程 {} panic: {:?}", i, e);
        }
    }
    assert_eq!(panic_count, 0, "{}/8 个 TQL 并发线程 panic!", panic_count);
    eprintln!(" ✅ 8 线程 TQL 并发: 8×100=800 轮查询,零 panic");
    // All threads are joined, so this is the last `Arc`: the database closes here.
    drop(db);
    cleanup(&path);
}