nexus_memory_hooks/
buffer.rs1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12use tokio::fs;
13use tokio::io::AsyncWriteExt;
14use tokio::sync::RwLock;
15
16use crate::error::{HookError, Result};
17use crate::session::SessionContext;
18
19pub fn default_buffer_dir() -> PathBuf {
21 dirs::data_local_dir()
22 .unwrap_or_else(|| PathBuf::from("."))
23 .join("nexus")
24 .join("buffer")
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct BufferEntry {
30 pub timestamp: DateTime<Utc>,
32
33 pub context_type: String,
35
36 pub context: SessionContext,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BufferData {
43 pub started_at: DateTime<Utc>,
45
46 pub entries: Vec<BufferEntry>,
48
49 pub last_flush: Option<DateTime<Utc>>,
51
52 pub agent_type: String,
54}
55
56impl BufferData {
57 pub fn new(agent_type: impl Into<String>) -> Self {
58 Self {
59 started_at: Utc::now(),
60 entries: Vec::new(),
61 last_flush: None,
62 agent_type: agent_type.into(),
63 }
64 }
65}
66
67pub struct PersistentBuffer {
109 buffer_dir: PathBuf,
111
112 buffers: Arc<RwLock<HashMap<String, BufferData>>>,
114
115 flush_interval_secs: u64,
117
118 max_entries: usize,
120}
121
122impl PersistentBuffer {
123 pub fn new(buffer_dir: Option<PathBuf>) -> Result<Self> {
129 let buffer_dir = buffer_dir.unwrap_or_else(default_buffer_dir);
130
131 std::fs::create_dir_all(&buffer_dir)
133 .map_err(|e| HookError::BufferError(format!("Failed to create buffer dir: {}", e)))?;
134
135 Ok(Self {
136 buffer_dir,
137 buffers: Arc::new(RwLock::new(HashMap::new())),
138 flush_interval_secs: 10,
139 max_entries: 10,
140 })
141 }
142
143 pub fn with_flush_interval(mut self, secs: u64) -> Self {
145 self.flush_interval_secs = secs;
146 self
147 }
148
149 pub fn with_max_entries(mut self, max: usize) -> Self {
151 self.max_entries = max;
152 self
153 }
154
155 pub async fn start_buffering(&self, agent_type: &str) -> Result<()> {
157 let mut buffers = self.buffers.write().await;
158
159 if !buffers.contains_key(agent_type) {
160 buffers.insert(agent_type.to_string(), BufferData::new(agent_type));
161 }
162
163 Ok(())
164 }
165
166 pub async fn buffer_context(
168 &self,
169 agent_type: &str,
170 context: SessionContext,
171 context_type: &str,
172 ) -> Result<()> {
173 {
175 let mut buffers = self.buffers.write().await;
176 if !buffers.contains_key(agent_type) {
177 buffers.insert(agent_type.to_string(), BufferData::new(agent_type));
178 }
179 }
180
181 let entry = BufferEntry {
182 timestamp: Utc::now(),
183 context_type: context_type.to_string(),
184 context,
185 };
186
187 let should_flush = {
189 let mut buffers = self.buffers.write().await;
190 if let Some(buffer) = buffers.get_mut(agent_type) {
191 buffer.entries.push(entry);
192 buffer.entries.len() >= self.max_entries
193 } else {
194 false
195 }
196 };
197
198 if should_flush {
200 self.flush_to_disk(agent_type).await?;
201 }
202
203 Ok(())
204 }
205
206 pub async fn flush_to_disk(&self, agent_type: &str) -> Result<()> {
208 let buffers = self.buffers.read().await;
209
210 if let Some(buffer) = buffers.get(agent_type) {
211 let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
212 let json = serde_json::to_string_pretty(buffer)
213 .map_err(|e| HookError::BufferError(format!("Failed to serialize: {}", e)))?;
214
215 let mut file = fs::File::create(&buffer_file)
216 .await
217 .map_err(|e| HookError::BufferError(format!("Failed to create file: {}", e)))?;
218
219 file.write_all(json.as_bytes())
220 .await
221 .map_err(|e| HookError::BufferError(format!("Failed to write: {}", e)))?;
222
223 drop(buffers);
225 let mut buffers = self.buffers.write().await;
226 if let Some(buffer) = buffers.get_mut(agent_type) {
227 buffer.last_flush = Some(Utc::now());
228 }
229 }
230
231 Ok(())
232 }
233
234 pub async fn flush_all(&self) -> Result<()> {
236 let buffers = self.buffers.read().await;
237 let agent_types: Vec<String> = buffers.keys().cloned().collect();
238 drop(buffers);
239
240 for agent_type in agent_types {
241 self.flush_to_disk(&agent_type).await?;
242 }
243
244 Ok(())
245 }
246
247 pub async fn recover_buffer(&self, agent_type: &str) -> Result<Option<BufferData>> {
249 let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
250
251 if !buffer_file.exists() {
252 return Ok(None);
253 }
254
255 let content = fs::read_to_string(&buffer_file)
256 .await
257 .map_err(|e| HookError::BufferError(format!("Failed to read buffer: {}", e)))?;
258
259 let data: BufferData = serde_json::from_str(&content)
260 .map_err(|e| HookError::BufferError(format!("Failed to parse buffer: {}", e)))?;
261
262 tracing::info!(
263 "Recovered buffer for {}: {} entries",
264 agent_type,
265 data.entries.len()
266 );
267
268 Ok(Some(data))
269 }
270
271 pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
273 {
275 let mut buffers = self.buffers.write().await;
276 buffers.remove(agent_type);
277 }
278
279 let buffer_file = self.buffer_dir.join(format!("{}.json", agent_type));
281 if buffer_file.exists() {
282 fs::remove_file(&buffer_file)
283 .await
284 .map_err(|e| HookError::BufferError(format!("Failed to remove buffer: {}", e)))?;
285 }
286
287 Ok(())
288 }
289
290 pub async fn get_buffer_status(&self, agent_type: &str) -> Option<BufferStatus> {
292 let buffers = self.buffers.read().await;
293
294 buffers.get(agent_type).map(|buffer| BufferStatus {
295 agent_type: agent_type.to_string(),
296 started_at: buffer.started_at,
297 entries_count: buffer.entries.len(),
298 last_flush: buffer.last_flush,
299 })
300 }
301
302 pub async fn list_buffers(&self) -> Vec<BufferStatus> {
304 let buffers = self.buffers.read().await;
305
306 buffers
307 .iter()
308 .map(|(agent_type, buffer)| BufferStatus {
309 agent_type: agent_type.clone(),
310 started_at: buffer.started_at,
311 entries_count: buffer.entries.len(),
312 last_flush: buffer.last_flush,
313 })
314 .collect()
315 }
316
317 pub async fn has_buffer(&self, agent_type: &str) -> bool {
319 let buffers = self.buffers.read().await;
320 buffers.contains_key(agent_type)
321 || self
322 .buffer_dir
323 .join(format!("{}.json", agent_type))
324 .exists()
325 }
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct BufferStatus {
331 pub agent_type: String,
332 pub started_at: DateTime<Utc>,
333 pub entries_count: usize,
334 pub last_flush: Option<DateTime<Utc>>,
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Agent name shared by the single-agent tests.
    const AGENT: &str = "test-agent";

    /// Creates a `PersistentBuffer` backed by a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside so it stays alive (and the
    /// directory stays on disk) for the duration of the test.
    fn fresh_buffer() -> (TempDir, PersistentBuffer) {
        let tmp = tempdir().unwrap();
        let store = PersistentBuffer::new(Some(tmp.path().to_path_buf())).unwrap();
        (tmp, store)
    }

    /// Buffering one context yields an in-memory entry count of one.
    #[tokio::test]
    async fn test_buffer_context() {
        let (_tmp, store) = fresh_buffer();

        store.start_buffering(AGENT).await.unwrap();

        let session = SessionContext::new(AGENT);
        store
            .buffer_context(AGENT, session, "checkpoint")
            .await
            .unwrap();

        let status = store.get_buffer_status(AGENT).await.unwrap();
        assert_eq!(status.entries_count, 1);
    }

    /// With a threshold of one, the first entry is flushed and recoverable.
    #[tokio::test]
    async fn test_flush_and_recover() {
        let (_tmp, store) = fresh_buffer();
        let store = store.with_max_entries(1);

        let session = SessionContext::new(AGENT);

        store.start_buffering(AGENT).await.unwrap();
        store.buffer_context(AGENT, session, "test").await.unwrap();

        let recovered = store.recover_buffer(AGENT).await.unwrap();
        assert!(recovered.is_some());

        let data = recovered.unwrap();
        assert_eq!(data.entries.len(), 1);
    }

    /// Clearing removes both the in-memory buffer and the file on disk.
    #[tokio::test]
    async fn test_clear_buffer() {
        let (_tmp, store) = fresh_buffer();

        store.start_buffering(AGENT).await.unwrap();

        let session = SessionContext::new(AGENT);
        store.buffer_context(AGENT, session, "test").await.unwrap();

        store.flush_to_disk(AGENT).await.unwrap();
        store.clear_buffer(AGENT).await.unwrap();

        assert!(store.get_buffer_status(AGENT).await.is_none());
        assert!(store.recover_buffer(AGENT).await.unwrap().is_none());
    }

    /// Every started buffer shows up in the listing.
    #[tokio::test]
    async fn test_list_buffers() {
        let (_tmp, store) = fresh_buffer();

        for agent in ["agent1", "agent2"] {
            store.start_buffering(agent).await.unwrap();
        }

        let listed = store.list_buffers().await;
        assert_eq!(listed.len(), 2);
    }
}