minkv 0.3.0

一个轻量级持久化键值存储,支持内存和文件后端,提供 CLI 和 TCP 服务器
Documentation
//! 服务器并发集成测试。
//!
//! 测试场景:
//! - 多客户端同时写入同一个键。
//! - 交替 set/get 操作的一致性。
//! - 慢速客户端不会导致死锁。
//! - SCAN 前缀扫描功能。

use std::io::{BufRead, BufReader, BufWriter, Write};
use std::net::{TcpListener, TcpStream, SocketAddr};
use std::sync::{Arc, Mutex, RwLock};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;
use minkv::{KvStore, FileStorage};

/// Starts a background test server and returns its listening address, the
/// server thread's join handle, and its shutdown flag.
///
/// The server listens on a random port on 127.0.0.1 and hands accepted
/// connections to 4 worker threads over an mpsc channel. Each call persists to
/// its own `test_store_<pid>_<seq>.json` file, which is cleared on start and
/// removed on a clean shutdown. This keeps tests that run in parallel from
/// overwriting or reading each other's data.
///
/// Storing `true` into the returned flag makes the accept loop exit (it polls
/// the flag about every 10 ms). The server thread then closes the channel and
/// joins all workers.
///
/// # Panics
/// Panics if the listener cannot be bound or switched to non-blocking mode.
fn start_server() -> (SocketAddr, thread::JoinHandle<()>, Arc<AtomicBool>) {
    // Per-process sequence number so every server instance gets its own data file.
    static SERVER_SEQ: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    let shutdown = Arc::new(AtomicBool::new(false));
    let shutdown_clone = Arc::clone(&shutdown);

    let listener = TcpListener::bind("127.0.0.1:0").expect("无法绑定端口");
    let addr = listener.local_addr().unwrap();
    println!("[测试服务器] 启动,监听 {}", addr);

    // `cargo test` runs tests in parallel by default; a single shared file would
    // let concurrently running servers clobber and observe each other's data.
    let store_file = format!(
        "test_store_{}_{}.json",
        std::process::id(),
        SERVER_SEQ.fetch_add(1, Ordering::Relaxed)
    );

    let handle = thread::spawn(move || {
        // Start from an empty store even if an earlier run left this file behind.
        let _ = std::fs::remove_file(&store_file);
        let store = Arc::new(RwLock::new(KvStore::<FileStorage>::open(store_file.as_str())));
        let store_path = store.read().unwrap().path().to_path_buf();
        let (tx, rx) = std::sync::mpsc::channel::<TcpStream>();
        // Workers share one receiver; the mutex is only held while waiting for the next stream.
        let rx = Arc::new(Mutex::new(rx));

        const NUM_WORKERS: usize = 4;
        let mut workers = Vec::with_capacity(NUM_WORKERS);
        for id in 0..NUM_WORKERS {
            let store_clone = Arc::clone(&store);
            let rx_clone = Arc::clone(&rx);
            let path_clone = store_path.clone();
            workers.push(thread::spawn(move || {
                loop {
                    let stream = {
                        let receiver = rx_clone.lock().unwrap();
                        match receiver.recv() {
                            Ok(s) => s,
                            // Channel closed: the accept loop has shut down.
                            Err(_) => break,
                        }
                    };
                    if let Err(e) = handle_test_connection(
                        id,
                        stream,
                        Arc::clone(&store_clone),
                        &path_clone,
                    ) {
                        eprintln!("工作线程 {} 处理错误: {}", id, e);
                    }
                }
            }));
        }

        // Non-blocking accept lets the loop poll the shutdown flag.
        listener.set_nonblocking(true).unwrap();
        loop {
            if shutdown_clone.load(Ordering::Relaxed) {
                println!("[测试服务器] 收到退出信号,关闭");
                break;
            }
            match listener.accept() {
                Ok((stream, _)) => {
                    // Linux does not propagate O_NONBLOCK to accepted sockets, but
                    // BSD-derived platforms (e.g. macOS) can. The workers rely on
                    // blocking reads, so reset the mode explicitly.
                    if let Err(e) = stream.set_nonblocking(false) {
                        eprintln!("设置阻塞模式失败: {}", e);
                        continue;
                    }
                    if let Err(e) = tx.send(stream) {
                        eprintln!("发送连接失败: {}", e);
                        break;
                    }
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    thread::sleep(Duration::from_millis(10));
                }
                Err(e) => {
                    eprintln!("接受连接错误: {}", e);
                    break;
                }
            }
        }
        // Dropping the sender makes every idle worker's `recv` fail, ending its loop.
        drop(tx);
        for w in workers {
            let _ = w.join();
        }
        // Remove this server's data file so the test run leaves no residue.
        let _ = std::fs::remove_file(&store_path);
        println!("[测试服务器] 停止");
    });

    (addr, handle, shutdown)
}

/// Serves a single client connection on the test server; the command handling
/// is meant to match the production server.
///
/// Protocol:
/// - Each request is one line of whitespace-separated tokens.
/// - The command verb is matched case-insensitively, and blank lines are skipped.
/// - `GET`, `SET` and `REMOVE` get exactly one reply line.
/// - `SCAN <prefix>` gets one `key value` line per match, followed by a final `OK` line.
///
/// Mutations (`SET`, and a `REMOVE` that actually deleted a key) take a
/// snapshot via `clone_data` while holding the write lock. They then persist
/// it with `FileStorage::save_from_snapshot` after the lock has been released.
///
/// Returns `Ok(())` once the client closes its side of the connection (EOF).
///
/// # Errors
/// Propagates I/O errors from `peer_addr`, `try_clone`, reading a request, or
/// writing/flushing a reply.
///
/// # Panics
/// Panics if the store's `RwLock` is poisoned.
fn handle_test_connection(
    _worker_id: usize,
    mut stream: TcpStream,
    store: Arc<RwLock<KvStore<FileStorage>>>,
    store_path: &std::path::Path,
) -> std::io::Result<()> {
    // Fails early if the peer is already gone; the address itself is not used.
    stream.peer_addr()?;
    // Clone the socket for writing first, so the reader can then borrow `stream` mutably.
    let mut out = BufWriter::new(stream.try_clone()?);
    let mut input = BufReader::new(&mut stream);
    let mut request = String::new();

    loop {
        request.clear();
        if input.read_line(&mut request)? == 0 {
            // EOF: the client hung up.
            return Ok(());
        }

        // Whitespace-only lines produce no tokens and are ignored.
        let tokens: Vec<&str> = request.split_whitespace().collect();
        if tokens.is_empty() {
            continue;
        }

        let verb = tokens[0].to_lowercase();
        let reply = match (verb.as_str(), &tokens[1..]) {
            ("get", &[key]) => {
                let guard = store.read().unwrap();
                match guard.get(key) {
                    Some(found) => found.to_owned(),
                    None => format!("错误:键 '{}' 不存在", key),
                }
            }
            ("get", _) => "错误:用法 GET <key>".to_string(),

            ("set", &[key, value]) => {
                let (key, value) = (key.to_string(), value.to_string());
                // Snapshot under the write lock; the file write happens after the lock is released.
                let snapshot = {
                    let mut guard = store.write().unwrap();
                    guard.set(key, value);
                    guard.clone_data()
                };
                match FileStorage::save_from_snapshot(&snapshot, store_path) {
                    Ok(()) => "OK".to_string(),
                    Err(e) => format!("错误:保存数据失败 - {}", e),
                }
            }
            ("set", _) => "错误:用法 SET <key> <value>".to_string(),

            ("remove", &[key]) => {
                // Only take (and persist) a snapshot when something was actually removed.
                let snapshot = {
                    let mut guard = store.write().unwrap();
                    if guard.remove(key).is_some() {
                        Some(guard.clone_data())
                    } else {
                        None
                    }
                };
                match snapshot {
                    Some(data) => match FileStorage::save_from_snapshot(&data, store_path) {
                        Ok(()) => "OK".to_string(),
                        Err(e) => format!("错误:保存数据失败 - {}", e),
                    },
                    None => format!("错误:键 '{}' 不存在", key),
                }
            }
            ("remove", _) => "错误:用法 REMOVE <key>".to_string(),

            ("scan", &[prefix]) => {
                // Collect first; the temporary read guard is dropped at the end of this statement,
                // before any socket I/O.
                let hits: Vec<(String, String)> = store.read().unwrap().scan(prefix).collect();
                for (k, v) in hits {
                    writeln!(out, "{} {}", k, v)?;
                }
                // Terminator line, then a single flush for the whole batch.
                writeln!(out, "OK")?;
                out.flush()?;
                continue;
            }
            ("scan", _) => "错误:用法 SCAN <prefix>".to_string(),

            _ => format!("错误:未知命令 '{}'", tokens[0]),
        };

        writeln!(out, "{}", reply)?;
        out.flush()?;
    }
}

/// Opens a fresh connection to `addr`, sends `cmd` as one newline-terminated
/// line, and returns the first reply line with surrounding whitespace trimmed.
///
/// Returns an empty string if the server closes the connection without
/// replying.
///
/// # Panics
/// Panics on any connection or I/O failure (acceptable for a test helper).
fn send_command(addr: SocketAddr, cmd: &str) -> String {
    let mut conn = TcpStream::connect(addr).unwrap();
    conn.write_all(format!("{}\n", cmd).as_bytes()).unwrap();
    conn.flush().unwrap();

    let mut first_line = String::new();
    BufReader::new(conn).read_line(&mut first_line).unwrap();
    first_line.trim().to_owned()
}

/// Sends `scan <prefix>` over a fresh connection and collects the reply lines
/// up to, but not including, the first line that trims to `OK`.
///
/// If the server closes the connection before sending `OK`, every line
/// received so far is returned.
///
/// # Panics
/// Panics on any connection or I/O failure (acceptable for a test helper).
fn send_scan_command(addr: SocketAddr, prefix: &str) -> Vec<String> {
    let mut conn = TcpStream::connect(addr).unwrap();
    conn.write_all(format!("scan {}\n", prefix).as_bytes()).unwrap();
    conn.flush().unwrap();

    BufReader::new(conn)
        .lines()
        .map(|line| line.unwrap())
        .take_while(|line| line.trim() != "OK")
        .collect()
}

/// Many clients concurrently `SET` the same key.
///
/// Every write must be acknowledged with `OK`, and the final `GET` must return
/// exactly one of the values that was written.
#[test]
fn test_concurrent_set_same_key() {
    let (addr, server_handle, shutdown) = start_server();
    // No warm-up sleep: the listener is already bound and listening when
    // `start_server` returns, so early connections wait in the accept backlog.

    const NUM_CLIENTS: usize = 20;
    let client_handles: Vec<_> = (0..NUM_CLIENTS)
        .map(|i| {
            thread::spawn(move || {
                let resp = send_command(addr, &format!("set x val_{}", i));
                assert_eq!(resp, "OK", "客户端 {} 写入失败", i);
            })
        })
        .collect();
    for h in client_handles {
        h.join().unwrap();
    }

    let final_value = send_command(addr, "get x");
    let written: Vec<String> = (0..NUM_CLIENTS).map(|i| format!("val_{}", i)).collect();
    assert!(
        written.contains(&final_value),
        "最终值 '{}' 不是任何客户端写入的值",
        final_value
    );

    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}

/// Clients interleave `SET counter` and `GET counter` concurrently.
///
/// Every `SET` must be acknowledged with `OK`. Every `GET`, including the final
/// one, must return a value that some client actually wrote (`<client>_<op>`),
/// never an error or an empty string.
#[test]
fn test_concurrent_get_set_consistency() {
    let (addr, server_handle, shutdown) = start_server();
    // No warm-up sleep: the listener is bound before `start_server` returns.

    const NUM_CLIENTS: usize = 10;
    const OPS_PER_CLIENT: usize = 20;

    /// True when `v` has the form `<client>_<op>` with both indices in range.
    fn is_written_value(v: &str) -> bool {
        let mut parts = v.splitn(2, '_');
        match (parts.next(), parts.next()) {
            (Some(c), Some(o)) => match (c.parse::<usize>(), o.parse::<usize>()) {
                (Ok(c), Ok(o)) => c < NUM_CLIENTS && o < OPS_PER_CLIENT,
                _ => false,
            },
            _ => false,
        }
    }

    let client_handles: Vec<_> = (0..NUM_CLIENTS)
        .map(|i| {
            thread::spawn(move || {
                for j in 0..OPS_PER_CLIENT {
                    let set_resp = send_command(addr, &format!("set counter {}_{}", i, j));
                    assert_eq!(set_resp, "OK", "客户端 {} 第 {} 次写入失败", i, j);
                    // Another client may have overwritten the key already, but the
                    // value must still be one that some client wrote.
                    let got = send_command(addr, "get counter");
                    assert!(is_written_value(&got), "读取到意外的值: '{}'", got);
                }
            })
        })
        .collect();
    for h in client_handles {
        h.join().unwrap();
    }

    let final_value = send_command(addr, "get counter");
    assert!(is_written_value(&final_value), "最终值异常: '{}'", final_value);

    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}

/// A client that keeps its connection open must not block other clients.
///
/// The server handles a connection on one worker until the client disconnects,
/// so an idle open connection occupies that worker. The test runs in stages:
/// 1. The slow client connects, completes a `SET`, and signals that it is ready.
/// 2. It then holds its connection idle while the fast client performs 100
///    `SET`/`GET` round trips.
/// 3. The slow client waits for the fast client's completion signal with a
///    timeout, so a deadlock shows up as a test failure rather than a hang.
/// 4. Finally, the held connection must still serve a `GET`.
#[test]
fn test_slow_client_no_deadlock() {
    let (addr, server_handle, shutdown) = start_server();

    // How long the slow client holds its connection waiting for the fast client.
    const FAST_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);

    let (ready_tx, ready_rx) = std::sync::mpsc::channel::<()>();
    let (done_tx, done_rx) = std::sync::mpsc::channel::<()>();

    let slow_handle = thread::spawn(move || {
        let mut stream = TcpStream::connect(addr).unwrap();
        // One reader for the whole connection, so bytes it has buffered are never discarded.
        let mut reader = BufReader::new(stream.try_clone().unwrap());
        let mut resp = String::new();

        writeln!(stream, "set slow_key slow_value").unwrap();
        stream.flush().unwrap();
        reader.read_line(&mut resp).unwrap();
        assert_eq!(resp.trim(), "OK");

        // The connection is now being served by a worker; let the fast client start.
        ready_tx.send(()).unwrap();
        done_rx
            .recv_timeout(FAST_CLIENT_TIMEOUT)
            .expect("慢速客户端持有连接期间,快速客户端未能完成");

        resp.clear();
        writeln!(stream, "get slow_key").unwrap();
        stream.flush().unwrap();
        reader.read_line(&mut resp).unwrap();
        assert_eq!(resp.trim(), "slow_value");
    });

    let fast_handle = thread::spawn(move || {
        // If the slow client failed before becoming ready, there is nothing to overlap with.
        if ready_rx.recv().is_err() {
            return;
        }
        for i in 0..100 {
            assert_eq!(send_command(addr, &format!("set fast_{} value", i)), "OK");
            assert_eq!(send_command(addr, &format!("get fast_{}", i)), "value");
        }
        let _ = done_tx.send(());
    });

    // Join the slow client first: it has a timeout, whereas a deadlocked fast client would hang.
    slow_handle.join().unwrap();
    fast_handle.join().unwrap();

    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}

/// `SCAN <prefix>` returns exactly the keys with that prefix, formatted as
/// `key value` lines, and none of the others.
#[test]
fn test_scan_prefix() {
    let (addr, server_handle, shutdown) = start_server();
    // No warm-up sleep: the listener is bound before `start_server` returns.

    for cmd in &["set fruit_apple red", "set fruit_banana yellow", "set veg_carrot orange"] {
        assert_eq!(send_command(addr, cmd), "OK", "命令 '{}' 失败", cmd);
    }

    let mut lines = send_scan_command(addr, "fruit_");
    // The test does not rely on the store's iteration order; sort before an exact comparison.
    lines.sort();
    assert_eq!(lines, vec!["fruit_apple red", "fruit_banana yellow"]);

    shutdown.store(true, Ordering::Relaxed);
    server_handle.join().unwrap();
}