use std::io::{BufRead, BufReader, BufWriter, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, Mutex, RwLock};
use std::sync::mpsc::{self, Receiver};
use std::thread;
use minkv::{KvStore, FileStorage};
/// Address (host:port) the TCP listener binds to in `main`.
const SERVER_ADDR: &str = "127.0.0.1:8080";
/// Path passed to `KvStore::open` when the store is opened at startup.
const STORE_PATH: &str = "store.json";
/// Number of worker threads `main` spawns to serve client connections.
const NUM_WORKERS: usize = 4;
/// Entry point of the key-value server.
///
/// Binds the listener on [`SERVER_ADDR`], opens the store at [`STORE_PATH`],
/// spawns [`NUM_WORKERS`] worker threads and forwards every accepted
/// connection to them through an mpsc channel.
///
/// # Panics
///
/// Panics if the listening socket cannot be bound, or — during shutdown — if a
/// worker thread terminated by panicking.
fn main() {
    let listener = TcpListener::bind(SERVER_ADDR).expect("无法绑定端口");
    println!(
        "服务器已启动 (线程池大小: {}),监听地址: {}",
        NUM_WORKERS, SERVER_ADDR
    );

    // The working directory is only needed for this log line, so a failure to
    // resolve it falls back to the relative path instead of aborting startup.
    let shown_path = std::env::current_dir()
        .map(|cwd| cwd.join(STORE_PATH))
        .unwrap_or_else(|_| std::path::PathBuf::from(STORE_PATH));
    println!("数据文件保存路径: {}", shown_path.display());

    let store = Arc::new(RwLock::new(KvStore::<FileStorage>::open(STORE_PATH)));
    // Ask the store for its path once so workers can persist snapshots
    // without taking the lock just to look the path up.
    let store_path = {
        let guard = store.read().unwrap();
        guard.path().to_path_buf()
    };

    let (tx, rx) = mpsc::channel::<TcpStream>();
    let rx = Arc::new(Mutex::new(rx));

    let mut workers = Vec::with_capacity(NUM_WORKERS);
    for id in 0..NUM_WORKERS {
        let store_clone = Arc::clone(&store);
        let rx_clone = Arc::clone(&rx);
        let path_clone = store_path.clone();
        workers.push(thread::spawn(move || {
            worker_loop(id, store_clone, rx_clone, path_clone)
        }));
    }
    // Only the workers should keep the receiver alive. While `main` held its
    // own handle, `tx.send` could never fail, so a pool in which every worker
    // had died would go unnoticed and connections would queue up forever.
    drop(rx);

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                if let Err(e) = tx.send(stream) {
                    eprintln!("发送连接到工作线程失败: {}", e);
                    break;
                }
            }
            Err(e) => eprintln!("接受连接失败: {}", e),
        }
    }

    // Close the channel before joining: a worker blocked in `recv()` only
    // returns once every sender is gone, so joining first could hang.
    drop(tx);
    for handle in workers {
        handle.join().expect("工作线程崩溃");
    }
}
/// Body of one pool thread: repeatedly takes the next connection from the
/// shared channel and serves it with `handle_client`.
///
/// * `worker_id` - index of this worker, used only in log output.
/// * `store` - shared key-value store handed to every client handler.
/// * `rx` - receiving end of the connection channel, shared by all workers.
/// * `store_path` - file path passed on to the handler for persistence.
///
/// Returns once the channel reports that it is closed. Errors from an
/// individual client are logged and do not stop the worker.
fn worker_loop(
    worker_id: usize,
    store: Arc<RwLock<KvStore<FileStorage>>>,
    rx: Arc<Mutex<Receiver<TcpStream>>>,
    store_path: std::path::PathBuf,
) {
    println!("[工作线程 {}] 启动", worker_id);

    // Waits for the next connection. The mutex guard lives only inside this
    // closure, so the lock is released before the connection is served and
    // other workers can receive in the meantime.
    let next_connection = || {
        let inbox = rx.lock().unwrap();
        let received = inbox.recv();
        if received.is_err() {
            println!("[工作线程 {}] 通道关闭,退出", worker_id);
        }
        received.ok()
    };

    while let Some(conn) = next_connection() {
        let outcome = handle_client(worker_id, conn, Arc::clone(&store), &store_path);
        if let Err(e) = outcome {
            eprintln!("[工作线程 {}] 处理客户端出错: {}", worker_id, e);
        }
    }

    println!("[工作线程 {}] 停止", worker_id);
}
/// Serves a single client connection until it closes or an I/O error occurs.
///
/// The protocol is line based. Each non-blank line is split on whitespace
/// into a command (matched case-insensitively) and its arguments:
///
/// * `GET <key>` - replies with the stored value, or an error line if absent.
/// * `SET <key> <value>` - stores the pair, persists the store, replies `OK`.
/// * `REMOVE <key>` - deletes the key, persists the store, replies `OK`.
/// * `SCAN <prefix>` - writes one `<key> <value>` line per match, then `OK`.
///
/// A wrong argument count or an unknown command yields an error line; the
/// connection stays open. Every response is flushed before the next line is
/// read.
///
/// * `worker_id` - index of the serving worker, used only in log output.
/// * `stream` - the accepted client connection.
/// * `store` - shared key-value store.
/// * `store_path` - file the store is persisted to after each mutation.
///
/// # Errors
///
/// Returns the first I/O error hit while querying the peer address, cloning
/// the stream, reading a line, or writing a response.
///
/// # Panics
///
/// Panics if the store lock is poisoned.
///
/// NOTE(review): lines are read without a length limit, and no read timeout
/// is set on the socket.
fn handle_client(
    worker_id: usize,
    mut stream: TcpStream,
    store: Arc<RwLock<KvStore<FileStorage>>>,
    store_path: &std::path::Path,
) -> std::io::Result<()> {
    let peer_addr = stream.peer_addr()?;
    println!("[工作线程 {}] 新客户端连接: {}", worker_id, peer_addr);

    // Reading and writing use separate handles to the same socket so each
    // direction can have its own buffer.
    let mut writer = BufWriter::new(stream.try_clone()?);
    let reader = BufReader::new(&mut stream);

    let result = reader.lines().try_for_each(|line_result| -> std::io::Result<()> {
        let line = line_result?;
        let trimmed = line.trim();
        if trimmed.is_empty() {
            return Ok(());
        }
        println!("[工作线程 {}] 收到命令: {}", worker_id, trimmed);

        let parts: Vec<&str> = trimmed.split_whitespace().collect();
        let (command, args) = match parts.split_first() {
            Some((command, args)) => (*command, args),
            // A non-empty trimmed line always has a token; kept as a cheap guard.
            None => return Ok(()),
        };

        let response = execute_command(worker_id, command, args, &store, store_path, &mut writer)?;
        writeln!(writer, "{}", response)?;
        writer.flush()
    });

    if let Err(e) = result {
        eprintln!("[工作线程 {}] 读取命令流错误: {}", worker_id, e);
        return Err(e);
    }
    println!(
        "[工作线程 {}] 客户端 {} 断开连接",
        worker_id, peer_addr
    );
    Ok(())
}

/// Runs one parsed command against the store and returns the final response
/// line for it; `SCAN` additionally writes its result rows to `writer` first.
///
/// Mutating commands persist the store while still holding the write lock.
/// Snapshots therefore reach the file in the same order the mutations were
/// applied, and an older snapshot can never overwrite a newer one. The cost
/// is that readers wait for the file write.
fn execute_command<W: Write>(
    worker_id: usize,
    command: &str,
    args: &[&str],
    store: &RwLock<KvStore<FileStorage>>,
    store_path: &std::path::Path,
    writer: &mut W,
) -> std::io::Result<String> {
    let response = match (command.to_lowercase().as_str(), args) {
        ("get", &[key]) => {
            let guard = store.read().unwrap();
            match guard.get(key) {
                Some(value) => value.to_owned(),
                None => format!("错误:键 '{}' 不存在", key),
            }
        }
        ("get", _) => "错误:用法 GET <key>".to_string(),

        ("set", &[key, value]) => {
            let mut guard = store.write().unwrap();
            guard.set(key.to_string(), value.to_string());
            let snapshot = guard.clone_data();
            match FileStorage::save_from_snapshot(&snapshot, store_path) {
                Ok(()) => {
                    println!("[工作线程 {}] SET 持久化成功", worker_id);
                    "OK".to_string()
                }
                Err(e) => format!("错误:保存数据失败 - {}", e),
            }
        }
        ("set", _) => "错误:用法 SET <key> <value>".to_string(),

        ("remove", &[key]) => {
            let mut guard = store.write().unwrap();
            if guard.remove(key).is_some() {
                let snapshot = guard.clone_data();
                match FileStorage::save_from_snapshot(&snapshot, store_path) {
                    Ok(()) => "OK".to_string(),
                    Err(e) => format!("错误:保存数据失败 - {}", e),
                }
            } else {
                // Nothing changed, so there is nothing to persist.
                format!("错误:键 '{}' 不存在", key)
            }
        }
        ("remove", _) => "错误:用法 REMOVE <key>".to_string(),

        ("scan", &[prefix]) => {
            // Collect first so the read lock is released before any network
            // write; a slow client must not keep the store locked.
            let items: Vec<(String, String)> = store.read().unwrap().scan(prefix).collect();
            for (k, v) in items {
                writeln!(writer, "{} {}", k, v)?;
            }
            writer.flush()?;
            "OK".to_string()
        }
        ("scan", _) => "错误:用法 SCAN <prefix>".to_string(),

        // Echo the command exactly as the client typed it.
        _ => format!("错误:未知命令 '{}'", command),
    };
    Ok(response)
}