use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use crate::error::{HookError, Result};
use crate::session::SessionContext;
/// Returns the default on-disk location for buffer snapshots: `<data_local_dir>/nexus/buffer`.
///
/// Falls back to the current working directory (`./nexus/buffer`) when the platform
/// data-local directory cannot be determined.
pub fn default_buffer_dir() -> PathBuf {
    let base = match dirs::data_local_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    base.join("nexus").join("buffer")
}
/// A single buffered session context together with the time it was buffered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferEntry {
/// UTC time at which the entry was buffered (`Utc::now()` in `buffer_context`).
pub timestamp: DateTime<Utc>,
/// Caller-supplied label describing the kind of context (the tests use e.g. `"checkpoint"`).
pub context_type: String,
/// The captured session context itself.
pub context: SessionContext,
}
/// All buffered entries for one agent type.
///
/// This is the value held in memory and serialized as JSON by `flush_to_disk`.
/// Field order determines the field order of the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferData {
/// When this buffer was created (`Utc::now()` in `BufferData::new`).
pub started_at: DateTime<Utc>,
/// Buffered entries, in the order they were appended.
pub entries: Vec<BufferEntry>,
/// Time of the most recent successful flush to disk; `None` if never flushed.
pub last_flush: Option<DateTime<Utc>>,
/// Agent type this buffer belongs to.
pub agent_type: String,
}
impl BufferData {
    /// Creates an empty buffer for `agent_type` that started now and has never been flushed.
    pub fn new(agent_type: impl Into<String>) -> Self {
        let started_at = Utc::now();
        let agent_type: String = agent_type.into();
        Self {
            agent_type,
            last_flush: None,
            entries: Vec::new(),
            started_at,
        }
    }
}
/// Per-agent buffer of session contexts, kept in memory and periodically written to
/// `<buffer_dir>/<agent_type>.json`.
pub struct PersistentBuffer {
/// Directory holding the per-agent JSON snapshots (created by `new`).
buffer_dir: PathBuf,
/// In-memory buffers keyed by agent type, shared behind an async `RwLock`.
buffers: Arc<RwLock<HashMap<String, BufferData>>>,
/// Flush interval in seconds; 10 by default in `new`, configurable via `with_flush_interval`.
flush_interval_secs: u64,
/// Entry-count threshold consulted by `buffer_context` when deciding to flush;
/// 10 by default in `new`, configurable via `with_max_entries`.
max_entries: usize,
}
impl PersistentBuffer {
    /// Creates a buffer persisting to `buffer_dir`, or to [`default_buffer_dir`] when `None`.
    ///
    /// The directory (and any missing parents) is created eagerly. Defaults: a flush
    /// interval of 10 seconds and a flush threshold of 10 entries.
    ///
    /// # Errors
    /// Returns `HookError::BufferError` if the directory cannot be created.
    pub fn new(buffer_dir: Option<PathBuf>) -> Result<Self> {
        let buffer_dir = buffer_dir.unwrap_or_else(default_buffer_dir);
        std::fs::create_dir_all(&buffer_dir)
            .map_err(|e| HookError::BufferError(format!("Failed to create buffer dir: {}", e)))?;
        Ok(Self {
            buffer_dir,
            buffers: Arc::new(RwLock::new(HashMap::new())),
            flush_interval_secs: 10,
            max_entries: 10,
        })
    }

    /// Sets the maximum age, in seconds, of unflushed data.
    ///
    /// Once that much time has passed since the last flush (or since the buffer started),
    /// the next `buffer_context` call flushes. `0` flushes on every append.
    pub fn with_flush_interval(mut self, secs: u64) -> Self {
        self.flush_interval_secs = secs;
        self
    }

    /// Sets the entry-count flush cadence.
    ///
    /// `buffer_context` flushes whenever the buffer length reaches a multiple of `max`.
    /// `0` flushes on every append.
    pub fn with_max_entries(mut self, max: usize) -> Self {
        self.max_entries = max;
        self
    }

    /// Ensures an in-memory buffer exists for `agent_type`; an existing buffer is left untouched.
    pub async fn start_buffering(&self, agent_type: &str) -> Result<()> {
        let mut buffers = self.buffers.write().await;
        buffers
            .entry(agent_type.to_string())
            .or_insert_with(|| BufferData::new(agent_type));
        Ok(())
    }

    /// Appends `context` (labelled `context_type`) to `agent_type`'s buffer, creating the
    /// buffer if needed, and flushes it to disk when [`Self::flush_due`] says so.
    ///
    /// # Errors
    /// Propagates any error from `flush_to_disk`; the entry stays buffered in memory either way.
    pub async fn buffer_context(
        &self,
        agent_type: &str,
        context: SessionContext,
        context_type: &str,
    ) -> Result<()> {
        // Create-if-missing and push happen under ONE write lock. With two separate locks,
        // a concurrent `clear_buffer` could remove the buffer in between and the entry
        // would be silently dropped.
        let should_flush = {
            let mut buffers = self.buffers.write().await;
            let buffer = buffers
                .entry(agent_type.to_string())
                .or_insert_with(|| BufferData::new(agent_type));
            buffer.entries.push(BufferEntry {
                timestamp: Utc::now(),
                context_type: context_type.to_string(),
                context,
            });
            self.flush_due(buffer)
        };
        // The write guard is released above; `flush_to_disk` acquires its own locks.
        if should_flush {
            self.flush_to_disk(agent_type).await?;
        }
        Ok(())
    }

    /// Decides whether `buffer` should be flushed right after an append.
    ///
    /// Returns true when the entry count is a multiple of `max_entries` (always when it is 0).
    /// Also returns true when at least `flush_interval_secs` have elapsed since the last
    /// flush, or since `started_at` if the buffer was never flushed. Using a multiple,
    /// rather than `len >= max_entries`, avoids rewriting the entire, ever-growing buffer
    /// on every append once the threshold has been passed.
    fn flush_due(&self, buffer: &BufferData) -> bool {
        let count_due = self.max_entries == 0 || buffer.entries.len() % self.max_entries == 0;
        let since = buffer.last_flush.unwrap_or(buffer.started_at);
        // Saturate rather than wrap: a huge u64 must not turn into a negative interval.
        let interval = i64::try_from(self.flush_interval_secs).unwrap_or(i64::MAX);
        let age_due = (Utc::now() - since).num_seconds() >= interval;
        count_due || age_due
    }

    /// Atomically writes the in-memory buffer for `agent_type` to `<buffer_dir>/<agent_type>.json`.
    ///
    /// The JSON is written to a uniquely named temp file, fsynced, and renamed over the target.
    /// `rename` replaces an existing file on both Unix and Windows, so a crash leaves either the
    /// old snapshot or the new one, never a torn or missing file. On Unix the directory is also
    /// fsynced so the rename itself is durable. On success `last_flush` is updated.
    /// An unknown `agent_type` is a no-op.
    ///
    /// The read lock is held for the whole write. Appends and `clear_buffer` need the write
    /// lock, so flushes that overlap in time always serialize the same snapshot. It also means
    /// a concurrent `clear_buffer` cannot delete the file only to have it recreated by this flush.
    ///
    /// # Errors
    /// Returns `HookError::BufferError` on serialization or I/O failure. The temp file is
    /// removed on a best-effort basis.
    pub async fn flush_to_disk(&self, agent_type: &str) -> Result<()> {
        let buffers = self.buffers.read().await;
        let json = match buffers.get(agent_type) {
            Some(buffer) => serde_json::to_string_pretty(buffer)
                .map_err(|e| HookError::BufferError(format!("Failed to serialize: {}", e)))?,
            None => return Ok(()),
        };
        let buffer_file = self.buffer_path(agent_type);
        let tmp_file = self.tmp_path(agent_type);

        if let Err(err) = Self::write_synced(&tmp_file, json.as_bytes()).await {
            // The file handle was closed when `write_synced` returned, so removal can succeed
            // on Windows too.
            let _ = fs::remove_file(&tmp_file).await;
            return Err(err);
        }
        if let Err(err) = fs::rename(&tmp_file, &buffer_file).await {
            let _ = fs::remove_file(&tmp_file).await;
            return Err(HookError::BufferError(format!(
                "Failed to replace buffer: {}",
                err
            )));
        }
        #[cfg(unix)]
        if let Some(parent) = buffer_file.parent() {
            let dir = fs::File::open(parent).await.map_err(|e| {
                HookError::BufferError(format!("Failed to open buffer dir for sync: {}", e))
            })?;
            dir.sync_all().await.map_err(|e| {
                HookError::BufferError(format!("Failed to sync buffer dir: {}", e))
            })?;
        }
        // Tokio's RwLock cannot upgrade a read guard, so release it before taking the write lock.
        drop(buffers);

        let mut buffers = self.buffers.write().await;
        if let Some(buffer) = buffers.get_mut(agent_type) {
            buffer.last_flush = Some(Utc::now());
        }
        Ok(())
    }

    /// Creates (or truncates) `path`, writes `bytes`, and fsyncs the file.
    ///
    /// The handle is dropped on return, so callers can delete the file afterwards on any platform.
    async fn write_synced(path: &std::path::Path, bytes: &[u8]) -> Result<()> {
        let mut file = fs::File::create(path)
            .await
            .map_err(|e| HookError::BufferError(format!("Failed to create file: {}", e)))?;
        file.write_all(bytes)
            .await
            .map_err(|e| HookError::BufferError(format!("Failed to write: {}", e)))?;
        file.sync_all()
            .await
            .map_err(|e| HookError::BufferError(format!("Failed to sync file: {}", e)))?;
        Ok(())
    }

    /// Returns the on-disk snapshot path for `agent_type`.
    ///
    /// NOTE(review): `agent_type` is joined into the path unsanitized. Values containing path
    /// separators or `..` would escape `buffer_dir`, so confirm callers only pass trusted names.
    fn buffer_path(&self, agent_type: &str) -> PathBuf {
        self.buffer_dir.join(format!("{}.json", agent_type))
    }

    /// Returns a temp-file path that is unique per flush (process id + in-process counter).
    ///
    /// Overlapping flushes of the same agent share the read lock. With a fixed temp name they
    /// would truncate and interleave each other's writes and could rename a torn file into place.
    fn tmp_path(&self, agent_type: &str) -> PathBuf {
        use std::sync::atomic::{AtomicU64, Ordering};
        static NEXT_TMP_ID: AtomicU64 = AtomicU64::new(0);
        let id = NEXT_TMP_ID.fetch_add(1, Ordering::Relaxed);
        self.buffer_dir.join(format!(
            "{}.json.{}.{}.tmp",
            agent_type,
            std::process::id(),
            id
        ))
    }

    /// Flushes every in-memory buffer, stopping at the first error.
    pub async fn flush_all(&self) -> Result<()> {
        // Collect keys and release the read guard first: `flush_to_disk` takes the write lock,
        // which would deadlock against a guard still held here.
        let agent_types: Vec<String> = self.buffers.read().await.keys().cloned().collect();
        for agent_type in &agent_types {
            self.flush_to_disk(agent_type).await?;
        }
        Ok(())
    }

    /// Loads the persisted snapshot for `agent_type`, or `Ok(None)` if none exists on disk.
    ///
    /// The in-memory buffer is not modified.
    ///
    /// # Errors
    /// Returns `HookError::BufferError` if the file exists but cannot be read or parsed.
    pub async fn recover_buffer(&self, agent_type: &str) -> Result<Option<BufferData>> {
        // Read directly and treat NotFound as "no buffer". This avoids a separate blocking
        // `exists()` check that could race with a concurrent delete.
        let content = match fs::read_to_string(self.buffer_path(agent_type)).await {
            Ok(content) => content,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(e) => {
                return Err(HookError::BufferError(format!(
                    "Failed to read buffer: {}",
                    e
                )))
            }
        };
        let data: BufferData = serde_json::from_str(&content)
            .map_err(|e| HookError::BufferError(format!("Failed to parse buffer: {}", e)))?;
        tracing::info!(
            "Recovered buffer for {}: {} entries",
            agent_type,
            data.entries.len()
        );
        Ok(Some(data))
    }

    /// Drops the in-memory buffer for `agent_type` and deletes its snapshot file, if any.
    ///
    /// # Errors
    /// Returns `HookError::BufferError` if the file exists but cannot be removed.
    pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
        // Keep the write lock across the file removal so no append or flush can recreate
        // the buffer between the two steps.
        let mut buffers = self.buffers.write().await;
        buffers.remove(agent_type);
        let removed = fs::remove_file(self.buffer_path(agent_type)).await;
        drop(buffers);
        match removed {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(HookError::BufferError(format!(
                "Failed to remove buffer: {}",
                e
            ))),
        }
    }

    /// Returns a summary of the in-memory buffer for `agent_type`, if one exists.
    pub async fn get_buffer_status(&self, agent_type: &str) -> Option<BufferStatus> {
        let buffers = self.buffers.read().await;
        buffers
            .get(agent_type)
            .map(|buffer| Self::status_of(agent_type, buffer))
    }

    /// Returns summaries of all in-memory buffers (in `HashMap` iteration order).
    pub async fn list_buffers(&self) -> Vec<BufferStatus> {
        let buffers = self.buffers.read().await;
        buffers
            .iter()
            .map(|(agent_type, buffer)| Self::status_of(agent_type, buffer))
            .collect()
    }

    /// Builds the public status summary for one buffer.
    fn status_of(agent_type: &str, buffer: &BufferData) -> BufferStatus {
        BufferStatus {
            agent_type: agent_type.to_string(),
            started_at: buffer.started_at,
            entries_count: buffer.entries.len(),
            last_flush: buffer.last_flush,
        }
    }

    /// Returns true if `agent_type` has an in-memory buffer or a snapshot file on disk.
    pub async fn has_buffer(&self, agent_type: &str) -> bool {
        if self.buffers.read().await.contains_key(agent_type) {
            return true;
        }
        // Same semantics as `Path::exists` (metadata succeeds), but without blocking the runtime.
        fs::metadata(self.buffer_path(agent_type)).await.is_ok()
    }
}
/// Lightweight summary of one in-memory buffer, as returned by
/// `get_buffer_status` / `list_buffers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferStatus {
/// Agent type the buffer belongs to.
pub agent_type: String,
/// When the buffer was created.
pub started_at: DateTime<Utc>,
/// Number of entries currently held in memory.
pub entries_count: usize,
/// Time of the most recent successful flush, if any.
pub last_flush: Option<DateTime<Utc>>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a buffer rooted in a fresh temp directory.
    ///
    /// Keep the returned `TempDir` alive for the whole test; dropping it deletes the directory.
    fn temp_buffer() -> (tempfile::TempDir, PersistentBuffer) {
        let dir = tempdir().unwrap();
        let buffer = PersistentBuffer::new(Some(dir.path().to_path_buf())).unwrap();
        (dir, buffer)
    }

    #[tokio::test]
    async fn test_buffer_context() {
        let (_dir, buffer) = temp_buffer();
        buffer.start_buffering("test-agent").await.unwrap();
        let ctx = SessionContext::new("test-agent");
        buffer
            .buffer_context("test-agent", ctx, "checkpoint")
            .await
            .unwrap();
        let status = buffer.get_buffer_status("test-agent").await.unwrap();
        assert_eq!(status.entries_count, 1);
    }

    #[tokio::test]
    async fn test_buffer_context_creates_missing_buffer() {
        let (_dir, buffer) = temp_buffer();
        // No start_buffering: buffer_context must create the buffer itself.
        buffer
            .buffer_context("lazy-agent", SessionContext::new("lazy-agent"), "checkpoint")
            .await
            .unwrap();
        let status = buffer.get_buffer_status("lazy-agent").await.unwrap();
        assert_eq!(status.entries_count, 1);
        // One entry is below the default threshold of 10, so nothing was flushed yet.
        assert!(status.last_flush.is_none());
    }

    #[tokio::test]
    async fn test_flush_and_recover() {
        let (_dir, buffer) = temp_buffer();
        let buffer = buffer.with_max_entries(1);
        let ctx = SessionContext::new("test-agent");
        buffer.start_buffering("test-agent").await.unwrap();
        buffer
            .buffer_context("test-agent", ctx.clone(), "test")
            .await
            .unwrap();
        let recovered = buffer.recover_buffer("test-agent").await.unwrap();
        assert!(recovered.is_some());
        let data = recovered.unwrap();
        assert_eq!(data.entries.len(), 1);
    }

    #[tokio::test]
    async fn test_flush_sets_last_flush_and_leaves_no_temp_files() {
        let (dir, buffer) = temp_buffer();
        buffer.start_buffering("agent").await.unwrap();
        buffer.flush_to_disk("agent").await.unwrap();
        let status = buffer.get_buffer_status("agent").await.unwrap();
        assert!(status.last_flush.is_some());
        let names: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("agent.json")]);
    }

    #[tokio::test]
    async fn test_clear_buffer() {
        let (_dir, buffer) = temp_buffer();
        buffer.start_buffering("test-agent").await.unwrap();
        let ctx = SessionContext::new("test-agent");
        buffer
            .buffer_context("test-agent", ctx, "test")
            .await
            .unwrap();
        buffer.flush_to_disk("test-agent").await.unwrap();
        buffer.clear_buffer("test-agent").await.unwrap();
        let status = buffer.get_buffer_status("test-agent").await;
        assert!(status.is_none());
        let recovered = buffer.recover_buffer("test-agent").await.unwrap();
        assert!(recovered.is_none());
    }

    #[tokio::test]
    async fn test_recover_and_clear_missing_buffer() {
        let (_dir, buffer) = temp_buffer();
        assert!(buffer.recover_buffer("ghost").await.unwrap().is_none());
        // Clearing something that was never persisted is not an error.
        buffer.clear_buffer("ghost").await.unwrap();
    }

    #[tokio::test]
    async fn test_has_buffer_checks_memory_and_disk() {
        let dir = tempdir().unwrap();
        let first = PersistentBuffer::new(Some(dir.path().to_path_buf())).unwrap();
        assert!(!first.has_buffer("agent").await);
        first.start_buffering("agent").await.unwrap();
        assert!(first.has_buffer("agent").await);
        first.flush_to_disk("agent").await.unwrap();

        // A fresh instance has nothing in memory but must still see the persisted snapshot.
        let second = PersistentBuffer::new(Some(dir.path().to_path_buf())).unwrap();
        assert!(second.has_buffer("agent").await);
        second.clear_buffer("agent").await.unwrap();
        assert!(!second.has_buffer("agent").await);
    }

    #[tokio::test]
    async fn test_list_buffers() {
        let (_dir, buffer) = temp_buffer();
        buffer.start_buffering("agent1").await.unwrap();
        buffer.start_buffering("agent2").await.unwrap();
        let buffers = buffer.list_buffers().await;
        assert_eq!(buffers.len(), 2);
    }
}