1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12use tokio::fs;
13use tokio::io::AsyncWriteExt;
14use tokio::sync::RwLock;
15
16use crate::error::{HookError, Result};
17use crate::session::SessionContext;
18
19pub fn default_buffer_dir() -> PathBuf {
21 dirs::data_local_dir()
22 .unwrap_or_else(|| PathBuf::from("."))
23 .join("nexus")
24 .join("buffer")
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct BufferEntry {
30 pub timestamp: DateTime<Utc>,
32
33 pub context_type: String,
35
36 pub context: SessionContext,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BufferData {
43 pub started_at: DateTime<Utc>,
45
46 pub entries: Vec<BufferEntry>,
48
49 pub last_flush: Option<DateTime<Utc>>,
51
52 pub agent_type: String,
54}
55
56impl BufferData {
57 pub fn new(agent_type: impl Into<String>) -> Self {
58 Self {
59 started_at: Utc::now(),
60 entries: Vec::new(),
61 last_flush: None,
62 agent_type: agent_type.into(),
63 }
64 }
65}
66
67pub struct PersistentBuffer {
109 buffer_dir: PathBuf,
111
112 buffers: Arc<RwLock<HashMap<String, BufferData>>>,
114
115 flush_interval_secs: u64,
117
118 max_entries: usize,
120}
121
122impl PersistentBuffer {
123 pub fn new(buffer_dir: Option<PathBuf>) -> Result<Self> {
129 let buffer_dir = buffer_dir.unwrap_or_else(default_buffer_dir);
130
131 std::fs::create_dir_all(&buffer_dir)
133 .map_err(|e| HookError::BufferError(format!("Failed to create buffer dir: {}", e)))?;
134
135 Ok(Self {
136 buffer_dir,
137 buffers: Arc::new(RwLock::new(HashMap::new())),
138 flush_interval_secs: 10,
139 max_entries: 10,
140 })
141 }
142
143 pub fn with_flush_interval(mut self, secs: u64) -> Self {
145 self.flush_interval_secs = secs;
146 self
147 }
148
149 pub fn with_max_entries(mut self, max: usize) -> Self {
151 self.max_entries = max;
152 self
153 }
154
155 pub async fn start_buffering(&self, agent_type: &str) -> Result<()> {
157 let mut buffers = self.buffers.write().await;
158
159 if !buffers.contains_key(agent_type) {
160 buffers.insert(agent_type.to_string(), BufferData::new(agent_type));
161 }
162
163 Ok(())
164 }
165
166 pub async fn buffer_context(
168 &self,
169 agent_type: &str,
170 context: SessionContext,
171 context_type: &str,
172 ) -> Result<()> {
173 {
175 let mut buffers = self.buffers.write().await;
176 if !buffers.contains_key(agent_type) {
177 buffers.insert(agent_type.to_string(), BufferData::new(agent_type));
178 }
179 }
180
181 let entry = BufferEntry {
182 timestamp: Utc::now(),
183 context_type: context_type.to_string(),
184 context,
185 };
186
187 let should_flush = {
189 let mut buffers = self.buffers.write().await;
190 if let Some(buffer) = buffers.get_mut(agent_type) {
191 buffer.entries.push(entry);
192 buffer.entries.len() >= self.max_entries
193 } else {
194 false
195 }
196 };
197
198 if should_flush {
200 self.flush_to_disk(agent_type).await?;
201 }
202
203 Ok(())
204 }
205
206 pub async fn flush_to_disk(&self, agent_type: &str) -> Result<()> {
208 let buffers = self.buffers.read().await;
209
210 if let Some(buffer) = buffers.get(agent_type) {
211 let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
212 let tmp_file = self.buffer_dir.join(format!("{}.json.tmp", agent_type));
213 let json = serde_json::to_string_pretty(buffer)
214 .map_err(|e| HookError::BufferError(format!("Failed to serialize: {}", e)))?;
215
216 let mut file = fs::File::create(&tmp_file)
217 .await
218 .map_err(|e| HookError::BufferError(format!("Failed to create file: {}", e)))?;
219
220 file.write_all(json.as_bytes())
221 .await
222 .map_err(|e| HookError::BufferError(format!("Failed to write: {}", e)))?;
223 file.sync_all()
224 .await
225 .map_err(|e| HookError::BufferError(format!("Failed to sync file: {}", e)))?;
226
227 #[cfg(windows)]
228 if buffer_file.exists() {
229 fs::remove_file(&buffer_file).await.map_err(|e| {
230 HookError::BufferError(format!(
231 "Failed to remove existing buffer file before replace: {}",
232 e
233 ))
234 })?;
235 }
236 if let Err(err) = fs::rename(&tmp_file, &buffer_file).await {
237 let _ = fs::remove_file(&tmp_file).await;
238 return Err(HookError::BufferError(format!(
239 "Failed to replace buffer: {}",
240 err
241 )));
242 }
243
244 #[cfg(unix)]
245 if let Some(parent) = buffer_file.parent() {
246 let dir = fs::File::open(parent).await.map_err(|e| {
247 HookError::BufferError(format!("Failed to open buffer dir for sync: {}", e))
248 })?;
249 dir.sync_all().await.map_err(|e| {
250 HookError::BufferError(format!("Failed to sync buffer dir: {}", e))
251 })?;
252 }
253
254 drop(buffers);
256 let mut buffers = self.buffers.write().await;
257 if let Some(buffer) = buffers.get_mut(agent_type) {
258 buffer.last_flush = Some(Utc::now());
259 }
260 }
261
262 Ok(())
263 }
264
265 pub async fn flush_all(&self) -> Result<()> {
267 let buffers = self.buffers.read().await;
268 let agent_types: Vec<String> = buffers.keys().cloned().collect();
269 drop(buffers);
270
271 for agent_type in agent_types {
272 self.flush_to_disk(&agent_type).await?;
273 }
274
275 Ok(())
276 }
277
278 pub async fn recover_buffer(&self, agent_type: &str) -> Result<Option<BufferData>> {
280 let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
281
282 if !buffer_file.exists() {
283 return Ok(None);
284 }
285
286 let content = fs::read_to_string(&buffer_file)
287 .await
288 .map_err(|e| HookError::BufferError(format!("Failed to read buffer: {}", e)))?;
289
290 let data: BufferData = serde_json::from_str(&content)
291 .map_err(|e| HookError::BufferError(format!("Failed to parse buffer: {}", e)))?;
292
293 tracing::info!(
294 "Recovered buffer for {}: {} entries",
295 agent_type,
296 data.entries.len()
297 );
298
299 Ok(Some(data))
300 }
301
302 pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
304 {
306 let mut buffers = self.buffers.write().await;
307 buffers.remove(agent_type);
308 }
309
310 let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
312 if buffer_file.exists() {
313 fs::remove_file(&buffer_file)
314 .await
315 .map_err(|e| HookError::BufferError(format!("Failed to remove buffer: {}", e)))?;
316 }
317
318 Ok(())
319 }
320
321 pub async fn get_buffer_status(&self, agent_type: &str) -> Option<BufferStatus> {
323 let buffers = self.buffers.read().await;
324
325 buffers.get(agent_type).map(|buffer| BufferStatus {
326 agent_type: agent_type.to_string(),
327 started_at: buffer.started_at,
328 entries_count: buffer.entries.len(),
329 last_flush: buffer.last_flush,
330 })
331 }
332
333 pub async fn list_buffers(&self) -> Vec<BufferStatus> {
335 let buffers = self.buffers.read().await;
336
337 buffers
338 .iter()
339 .map(|(agent_type, buffer)| BufferStatus {
340 agent_type: agent_type.clone(),
341 started_at: buffer.started_at,
342 entries_count: buffer.entries.len(),
343 last_flush: buffer.last_flush,
344 })
345 .collect()
346 }
347
348 pub async fn has_buffer(&self, agent_type: &str) -> bool {
350 let buffers = self.buffers.read().await;
351 buffers.contains_key(agent_type)
352 || self
353 .buffer_dir
354 .join(format!("{}.json", agent_type))
355 .exists()
356 }
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize)]
361pub struct BufferStatus {
362 pub agent_type: String,
363 pub started_at: DateTime<Utc>,
364 pub entries_count: usize,
365 pub last_flush: Option<DateTime<Utc>>,
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Agent identifier shared by the single-agent tests.
    const AGENT: &str = "test-agent";

    /// Builds a `PersistentBuffer` rooted in the given temporary directory.
    fn buffer_in(dir: &tempfile::TempDir) -> PersistentBuffer {
        PersistentBuffer::new(Some(dir.path().to_path_buf())).unwrap()
    }

    /// Buffering one context yields a status reporting exactly one entry.
    #[tokio::test]
    async fn test_buffer_context() {
        let tmp = tempdir().unwrap();
        let store = buffer_in(&tmp);

        store.start_buffering(AGENT).await.unwrap();

        let context = SessionContext::new(AGENT);
        store
            .buffer_context(AGENT, context, "checkpoint")
            .await
            .unwrap();

        let status = store.get_buffer_status(AGENT).await.unwrap();
        assert_eq!(status.entries_count, 1);
    }

    /// With a threshold of one entry, buffering flushes immediately and the
    /// entry can be read back from disk.
    #[tokio::test]
    async fn test_flush_and_recover() {
        let tmp = tempdir().unwrap();
        let store = buffer_in(&tmp).with_max_entries(1);

        let context = SessionContext::new(AGENT);

        store.start_buffering(AGENT).await.unwrap();
        store
            .buffer_context(AGENT, context.clone(), "test")
            .await
            .unwrap();

        let recovered = store.recover_buffer(AGENT).await.unwrap();
        assert!(recovered.is_some());

        let data = recovered.unwrap();
        assert_eq!(data.entries.len(), 1);
    }

    /// Clearing removes both the in-memory buffer and the persisted file.
    #[tokio::test]
    async fn test_clear_buffer() {
        let tmp = tempdir().unwrap();
        let store = buffer_in(&tmp);

        store.start_buffering(AGENT).await.unwrap();

        let context = SessionContext::new(AGENT);
        store.buffer_context(AGENT, context, "test").await.unwrap();

        store.flush_to_disk(AGENT).await.unwrap();
        store.clear_buffer(AGENT).await.unwrap();

        let status = store.get_buffer_status(AGENT).await;
        assert!(status.is_none());

        let recovered = store.recover_buffer(AGENT).await.unwrap();
        assert!(recovered.is_none());
    }

    /// Each started agent shows up once in the buffer listing.
    #[tokio::test]
    async fn test_list_buffers() {
        let tmp = tempdir().unwrap();
        let store = buffer_in(&tmp);

        store.start_buffering("agent1").await.unwrap();
        store.start_buffering("agent2").await.unwrap();

        let listed = store.list_buffers().await;
        assert_eq!(listed.len(), 2);
    }
}
448}