use std::io::{BufRead, BufReader, BufWriter, Write};
use std::net::{TcpListener, TcpStream, SocketAddr};
use std::sync::{Arc, Mutex, RwLock};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;
use minkv::{KvStore, FileStorage};
/// Starts an in-process KV test server on an ephemeral loopback port.
///
/// The listener is bound before this function returns, so clients may connect right
/// away: early connections wait in the OS accept backlog until the accept loop runs.
/// Accepted connections are handed over an mpsc channel to a fixed pool of worker
/// threads, each of which serves one connection at a time via `handle_test_connection`.
///
/// Every call uses its own store file (unique per process and per call). Cargo runs
/// `#[test]` functions in parallel, and servers sharing one file overwrote each other's
/// saves. The file is deleted before the store is opened and again when the server stops.
///
/// Returns the bound address, the handle of the server (accept-loop) thread, and a flag
/// that makes the accept loop exit when set to `true`. After it exits, the workers finish
/// their current connections and are joined before the server thread ends.
///
/// # Panics
/// Panics if the loopback port cannot be bound.
fn start_server() -> (SocketAddr, thread::JoinHandle<()>, Arc<AtomicBool>) {
    // Distinguishes the store files of servers started concurrently in one test binary.
    static NEXT_SERVER_ID: std::sync::atomic::AtomicUsize =
        std::sync::atomic::AtomicUsize::new(0);
    const NUM_WORKERS: usize = 4;

    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = Arc::clone(&shutdown);
    let listener = TcpListener::bind("127.0.0.1:0").expect("无法绑定端口");
    let addr = listener.local_addr().unwrap();
    let store_file = format!(
        "test_store_{}_{}.json",
        std::process::id(),
        NEXT_SERVER_ID.fetch_add(1, Ordering::Relaxed)
    );
    println!("[测试服务器] 启动,监听 {}", addr);
    let handle = thread::spawn(move || {
        // Start empty even if an aborted earlier run left a file with this name behind.
        let _ = std::fs::remove_file(&store_file);
        let store = Arc::new(RwLock::new(KvStore::<FileStorage>::open(store_file.as_str())));
        let store_path = store.read().unwrap().path().to_path_buf();
        let (tx, rx) = std::sync::mpsc::channel::<TcpStream>();
        let rx = Arc::new(Mutex::new(rx));
        let mut workers = Vec::with_capacity(NUM_WORKERS);
        for id in 0..NUM_WORKERS {
            let store_clone = Arc::clone(&store);
            let rx_clone = Arc::clone(&rx);
            let path_clone = store_path.clone();
            workers.push(thread::spawn(move || loop {
                // The guard is a temporary of this statement, so the receiver lock is
                // released before the connection is handled.
                let next = rx_clone.lock().unwrap().recv();
                // `recv` fails only once the sender is dropped: the server is stopping.
                let stream = match next {
                    Ok(s) => s,
                    Err(_) => break,
                };
                if let Err(e) =
                    handle_test_connection(id, stream, Arc::clone(&store_clone), &path_clone)
                {
                    eprintln!("工作线程 {} 处理错误: {}", id, e);
                }
            }));
        }
        // Non-blocking accept lets the loop poll the shutdown flag between connections.
        listener.set_nonblocking(true).unwrap();
        loop {
            if shutdown_clone.load(Ordering::Relaxed) {
                println!("[测试服务器] 收到退出信号,关闭");
                break;
            }
            match listener.accept() {
                Ok((stream, _)) => {
                    // Whether an accepted socket inherits the listener's O_NONBLOCK is
                    // platform-dependent (BSD-derived systems such as macOS do inherit it).
                    // The workers rely on blocking reads, so set the mode explicitly.
                    if let Err(e) = stream.set_nonblocking(false) {
                        eprintln!("设置阻塞模式失败: {}", e);
                        continue;
                    }
                    if let Err(e) = tx.send(stream) {
                        eprintln!("发送连接失败: {}", e);
                        break;
                    }
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    thread::sleep(Duration::from_millis(10));
                }
                Err(e) => {
                    eprintln!("接受连接错误: {}", e);
                    break;
                }
            }
        }
        // Dropping the only sender makes every idle worker's `recv` fail, ending its loop.
        drop(tx);
        for (id, w) in workers.into_iter().enumerate() {
            if w.join().is_err() {
                eprintln!("工作线程 {} 异常退出", id);
            }
        }
        // Best-effort cleanup of this server's private store file.
        let _ = std::fs::remove_file(&store_path);
        println!("[测试服务器] 停止");
    });
    (addr, handle, shutdown)
}
/// Serves a single client connection until the client closes it.
///
/// Each request is one line. Its first whitespace-separated word is the command, matched
/// case-insensitively, and the remaining words are the arguments. Replies:
/// - `GET <key>`: the stored value, or an error line naming the missing key.
/// - `SET <key> <value>`: stores the pair and saves a snapshot to `store_path`, then
///   replies `OK` (or an error line if saving failed).
/// - `REMOVE <key>`: removes the key and saves a snapshot (`OK` / save error), or replies
///   with an error line if the key was absent (nothing is saved in that case).
/// - `SCAN <prefix>`: one `key value` line per entry yielded by `scan(prefix)`, followed
///   by a final `OK` line.
///
/// A wrong argument count gets a usage error line, an unknown command gets an error line,
/// and a blank line gets no reply at all.
///
/// Mutations are saved while the write lock is still held. Saving after releasing it let
/// two writers persist their snapshots in the reverse order of their updates, leaving an
/// older state on disk than in memory, and let two saves hit the same file at once.
///
/// `_worker_id` is part of the caller-facing signature and is not used here.
///
/// # Errors
/// Returns any I/O error from cloning the stream, reading a request, or writing a reply.
///
/// # Panics
/// Panics if the store lock has been poisoned by a panicking thread.
fn handle_test_connection(
    _worker_id: usize,
    mut stream: TcpStream,
    store: Arc<RwLock<KvStore<FileStorage>>>,
    store_path: &std::path::Path,
) -> std::io::Result<()> {
    let mut writer = BufWriter::new(stream.try_clone()?);
    let mut reader = BufReader::new(&mut stream);
    let mut line = String::new();
    loop {
        line.clear();
        if reader.read_line(&mut line)? == 0 {
            // EOF: the client closed its side of the connection.
            break;
        }
        // `split_whitespace` already ignores leading/trailing whitespace and the newline.
        let parts: Vec<&str> = line.split_whitespace().collect();
        if parts.is_empty() {
            continue;
        }
        let response = match parts[0].to_lowercase().as_str() {
            "get" => {
                if parts.len() != 2 {
                    "错误:用法 GET <key>".to_string()
                } else {
                    let s = store.read().unwrap();
                    match s.get(parts[1]) {
                        Some(v) => v.to_owned(),
                        None => format!("错误:键 '{}' 不存在", parts[1]),
                    }
                }
            }
            "set" => {
                if parts.len() != 3 {
                    "错误:用法 SET <key> <value>".to_string()
                } else {
                    let mut s = store.write().unwrap();
                    s.set(parts[1].to_string(), parts[2].to_string());
                    // Saved before the guard drops, so snapshots reach disk in update order.
                    save_result_reply(FileStorage::save_from_snapshot(
                        &s.clone_data(),
                        store_path,
                    ))
                }
            }
            "remove" => {
                if parts.len() != 2 {
                    "错误:用法 REMOVE <key>".to_string()
                } else {
                    let key = parts[1];
                    let mut s = store.write().unwrap();
                    if s.remove(key).is_some() {
                        // Saved before the guard drops, so snapshots reach disk in update order.
                        save_result_reply(FileStorage::save_from_snapshot(
                            &s.clone_data(),
                            store_path,
                        ))
                    } else {
                        format!("错误:键 '{}' 不存在", key)
                    }
                }
            }
            "scan" => {
                if parts.len() != 2 {
                    "错误:用法 SCAN <prefix>".to_string()
                } else {
                    // Collect under the read lock, then release it before any socket I/O.
                    let items: Vec<(String, String)> =
                        store.read().unwrap().scan(parts[1]).collect();
                    for (k, v) in items {
                        writeln!(writer, "{} {}", k, v)?;
                    }
                    writeln!(writer, "OK")?;
                    writer.flush()?;
                    // Multi-line reply already written; skip the single-line reply below.
                    continue;
                }
            }
            _ => format!("错误:未知命令 '{}'", parts[0]),
        };
        writeln!(writer, "{}", response)?;
        writer.flush()?;
    }
    Ok(())
}

/// Turns the outcome of a snapshot save into the protocol reply line (`OK` or an error).
fn save_result_reply<E: std::fmt::Display>(result: Result<(), E>) -> String {
    match result {
        Ok(()) => "OK".to_string(),
        Err(e) => format!("错误:保存数据失败 - {}", e),
    }
}
/// Sends one command line on a fresh connection and returns the first reply line, trimmed.
///
/// A read timeout turns a server that never answers into a test failure instead of a
/// test run that hangs forever.
///
/// # Panics
/// Panics if connecting, writing, or reading fails, including when no reply arrives
/// within the timeout.
fn send_command(addr: SocketAddr, cmd: &str) -> String {
    const REPLY_TIMEOUT: Duration = Duration::from_secs(10);
    let mut stream = TcpStream::connect(addr).unwrap();
    stream.set_read_timeout(Some(REPLY_TIMEOUT)).unwrap();
    writeln!(stream, "{}", cmd).unwrap();
    stream.flush().unwrap();
    let mut resp = String::new();
    BufReader::new(stream).read_line(&mut resp).unwrap();
    resp.trim().to_string()
}
/// Sends `scan <prefix>` on a fresh connection and returns the reply lines before the
/// terminating `OK` line.
///
/// After sending, the write half is shut down so the server sees EOF and closes the
/// connection once it has replied. Without that, a reply that lacks the `OK` terminator
/// (e.g. the usage error for an empty prefix) left the server waiting for another
/// request and this client waiting for `OK`, forever. In that case the lines received
/// before EOF, such as the error line, are returned. A read timeout bounds any other stall.
///
/// # Panics
/// Panics if connecting, writing, shutting down, or reading fails, including when the
/// server stays silent past the timeout.
fn send_scan_command(addr: SocketAddr, prefix: &str) -> Vec<String> {
    const REPLY_TIMEOUT: Duration = Duration::from_secs(10);
    let mut stream = TcpStream::connect(addr).unwrap();
    stream.set_read_timeout(Some(REPLY_TIMEOUT)).unwrap();
    writeln!(stream, "scan {}", prefix).unwrap();
    stream.flush().unwrap();
    // No further requests will be sent on this connection.
    stream.shutdown(std::net::Shutdown::Write).unwrap();
    BufReader::new(stream)
        .lines()
        .map(|line| line.unwrap())
        .take_while(|line| line.trim() != "OK")
        .collect()
}
/// Many clients race to SET the same key. Every write must be acknowledged with `OK`,
/// and the value left behind must be one of the values that was actually written.
#[test]
fn test_concurrent_set_same_key() {
    const NUM_CLIENTS: usize = 20;
    // No start-up sleep: `start_server` binds the listener before returning, and early
    // connections wait in the accept backlog.
    let (addr, server_handle, shutdown) = start_server();
    let mut client_handles = Vec::with_capacity(NUM_CLIENTS);
    for i in 0..NUM_CLIENTS {
        client_handles.push(thread::spawn(move || {
            let reply = send_command(addr, &format!("set x val_{}", i));
            assert_eq!(reply, "OK", "client {}: SET failed", i);
        }));
    }
    for h in client_handles {
        h.join().unwrap();
    }
    let final_value = send_command(addr, "get x");
    let written: Vec<String> = (0..NUM_CLIENTS).map(|i| format!("val_{}", i)).collect();
    assert!(
        written.contains(&final_value),
        "unexpected final value {:?}",
        final_value
    );
    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}
/// Interleaved SET/GET traffic on one key from many clients.
///
/// Every SET must be acknowledged with `OK`. Every GET, and the final read, must return a
/// value some client wrote. The key is never removed, so once a client's own SET has
/// succeeded, an error reply or a foreign value would indicate an inconsistency.
#[test]
fn test_concurrent_get_set_consistency() {
    const NUM_CLIENTS: usize = 10;
    const OPS_PER_CLIENT: usize = 20;
    // Captures nothing, so it is `Copy` and each client thread gets its own copy.
    let was_written = |value: &str| {
        (0..NUM_CLIENTS).any(|i| (0..OPS_PER_CLIENT).any(|j| format!("{}_{}", i, j) == value))
    };
    // No start-up sleep: `start_server` binds the listener before returning.
    let (addr, server_handle, shutdown) = start_server();
    let mut client_handles = Vec::with_capacity(NUM_CLIENTS);
    for i in 0..NUM_CLIENTS {
        client_handles.push(thread::spawn(move || {
            for j in 0..OPS_PER_CLIENT {
                let reply = send_command(addr, &format!("set counter {}_{}", i, j));
                assert_eq!(reply, "OK", "client {} op {}: SET failed", i, j);
                let got = send_command(addr, "get counter");
                assert!(
                    was_written(&got),
                    "client {} op {}: unexpected GET reply {:?}",
                    i,
                    j,
                    got
                );
            }
        }));
    }
    for h in client_handles {
        h.join().unwrap();
    }
    let final_value = send_command(addr, "get counter");
    assert!(
        was_written(&final_value),
        "unexpected final value {:?}",
        final_value
    );
    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}
/// A client that holds its connection idle for two seconds must not stop other clients
/// from being served, and its own session must still work afterwards.
///
/// The slow client occupies one worker for the whole sleep. The fast client's 200
/// requests must all be answered correctly by the remaining workers. Read timeouts turn
/// a stalled server into a failure rather than a hung test run.
#[test]
fn test_slow_client_no_deadlock() {
    const REPLY_TIMEOUT: Duration = Duration::from_secs(10);
    // No start-up sleep: `start_server` binds the listener before returning.
    let (addr, server_handle, shutdown) = start_server();
    let slow_handle = thread::spawn(move || {
        let mut stream = TcpStream::connect(addr).unwrap();
        stream.set_read_timeout(Some(REPLY_TIMEOUT)).unwrap();
        // One reader for the whole session, so no buffered reply bytes are lost between reads.
        let mut reader = BufReader::new(stream.try_clone().unwrap());
        let mut resp = String::new();

        writeln!(stream, "set slow_key slow_value").unwrap();
        stream.flush().unwrap();
        reader.read_line(&mut resp).unwrap();
        assert_eq!(resp.trim(), "OK");

        // Keep the connection (and the worker serving it) idle while the fast client runs.
        thread::sleep(Duration::from_secs(2));

        resp.clear();
        writeln!(stream, "get slow_key").unwrap();
        stream.flush().unwrap();
        reader.read_line(&mut resp).unwrap();
        assert_eq!(resp.trim(), "slow_value");
    });
    let fast_handle = thread::spawn(move || {
        for i in 0..100 {
            assert_eq!(send_command(addr, &format!("set fast_{} value", i)), "OK");
            assert_eq!(send_command(addr, &format!("get fast_{}", i)), "value");
        }
    });
    slow_handle.join().unwrap();
    fast_handle.join().unwrap();
    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}
/// SCAN returns exactly the entries whose keys start with the prefix, with their values.
///
/// The reply order is not asserted, so the lines are sorted before the exact comparison.
#[test]
fn test_scan_prefix() {
    // No start-up sleep: `start_server` binds the listener before returning.
    let (addr, server_handle, shutdown) = start_server();
    for cmd in &[
        "set fruit_apple red",
        "set fruit_banana yellow",
        "set veg_carrot orange",
    ] {
        assert_eq!(send_command(addr, cmd), "OK", "command {:?} failed", cmd);
    }
    let mut lines = send_scan_command(addr, "fruit_");
    lines.sort();
    assert_eq!(lines, ["fruit_apple red", "fruit_banana yellow"]);
    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}