Skip to main content

nexus_memory_hooks/
buffer.rs

1//! Persistent buffer for crash recovery
2//!
3//! The persistent buffer acts as a safety net to prevent memory loss
4//! even if all hooks fail. It continuously buffers session context
5//! for recovery after crashes.
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12use tokio::fs;
13use tokio::io::AsyncWriteExt;
14use tokio::sync::RwLock;
15
16use crate::error::{HookError, Result};
17use crate::session::SessionContext;
18
19/// Default buffer directory
20pub fn default_buffer_dir() -> PathBuf {
21    dirs::data_local_dir()
22        .unwrap_or_else(|| PathBuf::from("."))
23        .join("nexus")
24        .join("buffer")
25}
26
27/// Buffer entry stored in memory and on disk
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct BufferEntry {
30    /// Timestamp when entry was created
31    pub timestamp: DateTime<Utc>,
32
33    /// Type of context
34    pub context_type: String,
35
36    /// The context data
37    pub context: SessionContext,
38}
39
40/// Buffer data structure stored on disk
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BufferData {
43    /// When buffering started
44    pub started_at: DateTime<Utc>,
45
46    /// Buffer entries
47    pub entries: Vec<BufferEntry>,
48
49    /// When last flushed to disk
50    pub last_flush: Option<DateTime<Utc>>,
51
52    /// Agent type
53    pub agent_type: String,
54}
55
56impl BufferData {
57    pub fn new(agent_type: impl Into<String>) -> Self {
58        Self {
59            started_at: Utc::now(),
60            entries: Vec::new(),
61            last_flush: None,
62            agent_type: agent_type.into(),
63        }
64    }
65}
66
67/// Persistent buffer for crash recovery
68///
69/// This buffer continuously stores session context entries. If all other
70/// detection methods fail, we can recover from this buffer.
71///
72/// # Buffer Lifecycle
73///
74/// 1. Start buffering when session starts
75/// 2. Continuously append context entries
76/// 3. Periodically flush to disk
77/// 4. Recover after crash
78/// 5. Clear buffer after successful storage
79///
80/// # Example
81///
82/// ```rust,no_run
83/// use nexus_memory_hooks::buffer::PersistentBuffer;
84/// use nexus_memory_hooks::session::SessionContext;
85///
86/// #[tokio::main]
87/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
88///     let mut buffer = PersistentBuffer::new(None)?;
89///
90///     // Start buffering for an agent
91///     buffer.start_buffering("claude-code").await?;
92///
93///     // Buffer context periodically
94///     let ctx = SessionContext::new("claude-code");
95///     buffer.buffer_context("claude-code", ctx, "checkpoint").await?;
96///
97///     // Recover after crash
98///     if let Some(data) = buffer.recover_buffer("claude-code").await? {
99///         println!("Recovered {} entries", data.entries.len());
100///     }
101///
102///     // Clear after successful storage
103///     buffer.clear_buffer("claude-code").await?;
104///
105///     Ok(())
106/// }
107/// ```
108pub struct PersistentBuffer {
109    /// Buffer directory
110    buffer_dir: PathBuf,
111
112    /// In-memory buffers by agent type
113    buffers: Arc<RwLock<HashMap<String, BufferData>>>,
114
115    /// Flush interval in seconds
116    flush_interval_secs: u64,
117
118    /// Maximum entries before auto-flush
119    max_entries: usize,
120}
121
122impl PersistentBuffer {
123    /// Create a new persistent buffer
124    ///
125    /// # Arguments
126    ///
127    /// * `buffer_dir` - Directory for buffer files (default: ~/.local/share/nexus/buffer)
128    pub fn new(buffer_dir: Option<PathBuf>) -> Result<Self> {
129        let buffer_dir = buffer_dir.unwrap_or_else(default_buffer_dir);
130
131        // Create directory if it doesn't exist
132        std::fs::create_dir_all(&buffer_dir)
133            .map_err(|e| HookError::BufferError(format!("Failed to create buffer dir: {}", e)))?;
134
135        Ok(Self {
136            buffer_dir,
137            buffers: Arc::new(RwLock::new(HashMap::new())),
138            flush_interval_secs: 10,
139            max_entries: 10,
140        })
141    }
142
143    /// Set flush interval
144    pub fn with_flush_interval(mut self, secs: u64) -> Self {
145        self.flush_interval_secs = secs;
146        self
147    }
148
149    /// Set max entries before auto-flush
150    pub fn with_max_entries(mut self, max: usize) -> Self {
151        self.max_entries = max;
152        self
153    }
154
155    /// Start buffering for an agent
156    pub async fn start_buffering(&self, agent_type: &str) -> Result<()> {
157        let mut buffers = self.buffers.write().await;
158
159        if !buffers.contains_key(agent_type) {
160            buffers.insert(agent_type.to_string(), BufferData::new(agent_type));
161        }
162
163        Ok(())
164    }
165
166    /// Buffer a context entry
167    pub async fn buffer_context(
168        &self,
169        agent_type: &str,
170        context: SessionContext,
171        context_type: &str,
172    ) -> Result<()> {
173        // Ensure buffering is started
174        {
175            let mut buffers = self.buffers.write().await;
176            if !buffers.contains_key(agent_type) {
177                buffers.insert(agent_type.to_string(), BufferData::new(agent_type));
178            }
179        }
180
181        let entry = BufferEntry {
182            timestamp: Utc::now(),
183            context_type: context_type.to_string(),
184            context,
185        };
186
187        // Add to memory buffer
188        let should_flush = {
189            let mut buffers = self.buffers.write().await;
190            if let Some(buffer) = buffers.get_mut(agent_type) {
191                buffer.entries.push(entry);
192                buffer.entries.len() >= self.max_entries
193            } else {
194                false
195            }
196        };
197
198        // Auto-flush if buffer is large
199        if should_flush {
200            self.flush_to_disk(agent_type).await?;
201        }
202
203        Ok(())
204    }
205
206    /// Flush buffer to disk
207    pub async fn flush_to_disk(&self, agent_type: &str) -> Result<()> {
208        let buffers = self.buffers.read().await;
209
210        if let Some(buffer) = buffers.get(agent_type) {
211            let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
212            let tmp_file = self.buffer_dir.join(format!("{}.json.tmp", agent_type));
213            let json = serde_json::to_string_pretty(buffer)
214                .map_err(|e| HookError::BufferError(format!("Failed to serialize: {}", e)))?;
215
216            let mut file = fs::File::create(&tmp_file)
217                .await
218                .map_err(|e| HookError::BufferError(format!("Failed to create file: {}", e)))?;
219
220            file.write_all(json.as_bytes())
221                .await
222                .map_err(|e| HookError::BufferError(format!("Failed to write: {}", e)))?;
223            file.sync_all()
224                .await
225                .map_err(|e| HookError::BufferError(format!("Failed to sync file: {}", e)))?;
226
227            #[cfg(windows)]
228            if buffer_file.exists() {
229                fs::remove_file(&buffer_file).await.map_err(|e| {
230                    HookError::BufferError(format!(
231                        "Failed to remove existing buffer file before replace: {}",
232                        e
233                    ))
234                })?;
235            }
236            if let Err(err) = fs::rename(&tmp_file, &buffer_file).await {
237                let _ = fs::remove_file(&tmp_file).await;
238                return Err(HookError::BufferError(format!(
239                    "Failed to replace buffer: {}",
240                    err
241                )));
242            }
243
244            #[cfg(unix)]
245            if let Some(parent) = buffer_file.parent() {
246                let dir = fs::File::open(parent).await.map_err(|e| {
247                    HookError::BufferError(format!("Failed to open buffer dir for sync: {}", e))
248                })?;
249                dir.sync_all().await.map_err(|e| {
250                    HookError::BufferError(format!("Failed to sync buffer dir: {}", e))
251                })?;
252            }
253
254            // Update last_flush time
255            drop(buffers);
256            let mut buffers = self.buffers.write().await;
257            if let Some(buffer) = buffers.get_mut(agent_type) {
258                buffer.last_flush = Some(Utc::now());
259            }
260        }
261
262        Ok(())
263    }
264
265    /// Flush all buffers to disk
266    pub async fn flush_all(&self) -> Result<()> {
267        let buffers = self.buffers.read().await;
268        let agent_types: Vec<String> = buffers.keys().cloned().collect();
269        drop(buffers);
270
271        for agent_type in agent_types {
272            self.flush_to_disk(&agent_type).await?;
273        }
274
275        Ok(())
276    }
277
278    /// Recover buffered context after crash
279    pub async fn recover_buffer(&self, agent_type: &str) -> Result<Option<BufferData>> {
280        let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
281
282        if !buffer_file.exists() {
283            return Ok(None);
284        }
285
286        let content = fs::read_to_string(&buffer_file)
287            .await
288            .map_err(|e| HookError::BufferError(format!("Failed to read buffer: {}", e)))?;
289
290        let data: BufferData = serde_json::from_str(&content)
291            .map_err(|e| HookError::BufferError(format!("Failed to parse buffer: {}", e)))?;
292
293        tracing::info!(
294            "Recovered buffer for {}: {} entries",
295            agent_type,
296            data.entries.len()
297        );
298
299        Ok(Some(data))
300    }
301
302    /// Clear buffer after successful storage
303    pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
304        // Clear from memory
305        {
306            let mut buffers = self.buffers.write().await;
307            buffers.remove(agent_type);
308        }
309
310        // Clear from disk
311        let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
312        if buffer_file.exists() {
313            fs::remove_file(&buffer_file)
314                .await
315                .map_err(|e| HookError::BufferError(format!("Failed to remove buffer: {}", e)))?;
316        }
317
318        Ok(())
319    }
320
321    /// Get buffer status
322    pub async fn get_buffer_status(&self, agent_type: &str) -> Option<BufferStatus> {
323        let buffers = self.buffers.read().await;
324
325        buffers.get(agent_type).map(|buffer| BufferStatus {
326            agent_type: agent_type.to_string(),
327            started_at: buffer.started_at,
328            entries_count: buffer.entries.len(),
329            last_flush: buffer.last_flush,
330        })
331    }
332
333    /// List all active buffers
334    pub async fn list_buffers(&self) -> Vec<BufferStatus> {
335        let buffers = self.buffers.read().await;
336
337        buffers
338            .iter()
339            .map(|(agent_type, buffer)| BufferStatus {
340                agent_type: agent_type.clone(),
341                started_at: buffer.started_at,
342                entries_count: buffer.entries.len(),
343                last_flush: buffer.last_flush,
344            })
345            .collect()
346    }
347
348    /// Check if buffer exists for agent
349    pub async fn has_buffer(&self, agent_type: &str) -> bool {
350        let buffers = self.buffers.read().await;
351        buffers.contains_key(agent_type)
352            || self
353                .buffer_dir
354                .join(format!("{}.json", agent_type))
355                .exists()
356    }
357}
358
359/// Buffer status information
360#[derive(Debug, Clone, Serialize, Deserialize)]
361pub struct BufferStatus {
362    pub agent_type: String,
363    pub started_at: DateTime<Utc>,
364    pub entries_count: usize,
365    pub last_flush: Option<DateTime<Utc>>,
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    const AGENT: &str = "test-agent";

    /// Create a `PersistentBuffer` rooted in a fresh temporary directory.
    ///
    /// The returned `TempDir` must stay alive while the buffer is in use, because
    /// dropping it deletes the directory.
    fn fresh_buffer() -> (TempDir, PersistentBuffer) {
        let dir = tempdir().unwrap();
        let buffer = PersistentBuffer::new(Some(dir.path().to_path_buf())).unwrap();
        (dir, buffer)
    }

    #[tokio::test]
    async fn test_buffer_context() {
        let (_dir, buffer) = fresh_buffer();

        buffer.start_buffering(AGENT).await.unwrap();
        buffer
            .buffer_context(AGENT, SessionContext::new(AGENT), "checkpoint")
            .await
            .unwrap();

        let count = buffer
            .get_buffer_status(AGENT)
            .await
            .map(|status| status.entries_count);
        assert_eq!(count, Some(1));
    }

    #[tokio::test]
    async fn test_flush_and_recover() {
        let (_dir, buffer) = fresh_buffer();
        // With a max of one entry, the very first append triggers an auto-flush.
        let buffer = buffer.with_max_entries(1);

        buffer.start_buffering(AGENT).await.unwrap();
        buffer
            .buffer_context(AGENT, SessionContext::new(AGENT), "test")
            .await
            .unwrap();

        let data = buffer
            .recover_buffer(AGENT)
            .await
            .unwrap()
            .expect("auto-flush should have written a snapshot");
        assert_eq!(data.entries.len(), 1);
    }

    #[tokio::test]
    async fn test_clear_buffer() {
        let (_dir, buffer) = fresh_buffer();

        buffer.start_buffering(AGENT).await.unwrap();
        buffer
            .buffer_context(AGENT, SessionContext::new(AGENT), "test")
            .await
            .unwrap();
        buffer.flush_to_disk(AGENT).await.unwrap();

        buffer.clear_buffer(AGENT).await.unwrap();

        // Gone from memory...
        assert!(buffer.get_buffer_status(AGENT).await.is_none());
        // ...and from disk.
        assert!(buffer.recover_buffer(AGENT).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_list_buffers() {
        let (_dir, buffer) = fresh_buffer();

        for agent in ["agent1", "agent2"] {
            buffer.start_buffering(agent).await.unwrap();
        }

        assert_eq!(buffer.list_buffers().await.len(), 2);
    }
}
448}