use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{Error, Message, buffer::Buffer};
/// Settings for a [`DiskBuffer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiskBufferConfig {
/// Path of the backing file; `DiskBuffer::new` opens it, creating the file
/// and any missing parent directories.
pub path: String,
/// Highest byte offset the write position may reach. `push` rejects a frame
/// (4-byte length prefix plus payload) that would end past this offset.
pub max_size: usize,
/// When `true`, `close` removes the backing file after releasing its handles.
pub delete_on_close: bool,
}
/// Message buffer stored in a single file as a sequence of length-prefixed
/// frames: a 4-byte little-endian payload length followed by the payload.
pub struct DiskBuffer {
// Copy of the configuration passed to `new`.
config: DiskBufferConfig,
// Location of the backing file, built from `config.path`.
file_path: PathBuf,
// Buffered read handle; `close` sets it to `None`. It is created with
// `File::try_clone` from the same file the writer wraps.
reader: Mutex<Option<BufReader<File>>>,
// Buffered write handle; `close` sets it to `None`.
writer: Mutex<Option<BufWriter<File>>>,
// Byte offset of the next frame `pop` will read; starts at 0.
read_position: Mutex<u64>,
// Byte offset at which the next `push` writes; starts at the file length
// observed when the buffer was opened.
write_position: Mutex<u64>,
}
impl DiskBuffer {
pub fn new(config: &DiskBufferConfig) -> Result<Self, Error> {
let file_path = PathBuf::from(&config.path);
if let Some(parent) = file_path.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent).map_err(Error::Io)?
}
}
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(&file_path)
.map_err(Error::Io)?;
let file_size = file.metadata().map_err(Error::Io)?.len();
let reader = BufReader::new(file.try_clone().map_err(Error::Io)?);
let writer = BufWriter::new(file);
Ok(Self {
config: config.clone(),
file_path,
reader: Mutex::new(Some(reader)),
writer: Mutex::new(Some(writer)),
read_position: Mutex::new(0),
write_position: Mutex::new(file_size),
})
}
fn serialize_message(&self, msg: &Message) -> Result<Vec<u8>, Error> {
let content = msg.content().to_vec();
let content_len = content.len() as u32;
let mut buffer = Vec::with_capacity(4 + content_len as usize);
buffer.extend_from_slice(&content_len.to_le_bytes());
buffer.extend_from_slice(&content);
Ok(buffer)
}
fn deserialize_message(&self, mut bytes: &[u8]) -> Result<Message, Error> {
if bytes.len() < 4 {
return Err(Error::Processing("无效的消息格式".to_string()));
}
let mut len_bytes = [0u8; 4];
bytes.read_exact(&mut len_bytes).map_err(Error::Io)?;
let content_len = u32::from_le_bytes(len_bytes) as usize;
let mut content = vec![0u8; content_len];
bytes.read_exact(&mut content).map_err(Error::Io)?;
Ok(Message::new(content))
}
}
#[async_trait]
impl Buffer for DiskBuffer {
async fn push(&self, msg: &Message) -> Result<(), Error> {
let data = self.serialize_message(msg)?;
let mut write_pos = self.write_position.lock().map_err(|e| Error::Unknown(e.to_string()))?;
if (*write_pos as usize) + data.len() > self.config.max_size {
return Err(Error::Processing("磁盘缓冲区已满".to_string()));
}
let mut writer_guard = self.writer.lock().map_err(|e| Error::Unknown(e.to_string()))?;
let writer = writer_guard.as_mut().ok_or_else(|| Error::Connection("缓冲区未连接".to_string()))?;
writer.seek(SeekFrom::Start(*write_pos)).map_err(Error::Io)?;
writer.write_all(&data).map_err(Error::Io)?;
writer.flush().map_err(Error::Io)?;
*write_pos += data.len() as u64;
Ok(())
}
async fn pop(&self) -> Result<Option<Message>, Error> {
let mut read_pos = self.read_position.lock().map_err(|e| Error::Unknown(e.to_string()))?;
let write_pos = self.write_position.lock().map_err(|e| Error::Unknown(e.to_string()))?;
if *read_pos >= *write_pos {
return Ok(None);
}
let mut reader_guard = self.reader.lock().map_err(|e| Error::Unknown(e.to_string()))?;
let reader = reader_guard.as_mut().ok_or_else(|| Error::Connection("缓冲区未连接".to_string()))?;
reader.seek(SeekFrom::Start(*read_pos)).map_err(Error::Io)?;
let mut len_bytes = [0u8; 4];
reader.read_exact(&mut len_bytes).map_err(Error::Io)?;
let content_len = u32::from_le_bytes(len_bytes) as usize;
let mut content = vec![0u8; content_len];
reader.read_exact(&mut content).map_err(Error::Io)?;
*read_pos += (4 + content_len) as u64;
let msg = Message::new(content);
Ok(Some(msg))
}
async fn close(&self) -> Result<(), Error> {
let mut reader_guard = self.reader.lock().map_err(|e| Error::Unknown(e.to_string()))?;
let mut writer_guard = self.writer.lock().map_err(|e| Error::Unknown(e.to_string()))?;
*reader_guard = None;
*writer_guard = None;
if self.config.delete_on_close && Path::new(&self.file_path).exists() {
std::fs::remove_file(&self.file_path).map_err(Error::Io)?;
}
Ok(())
}
}