llm_memory_graph/storage/pooled_backend.rs

//! Pooled async storage backend.
//!
//! Wraps [`AsyncSledBackend`] behind a Tokio [`Semaphore`] so that at most a
//! configured number of operations run against the underlying store at once,
//! with optional metrics on throughput, contention, and timeouts.

use crate::error::{Error, Result};
use crate::storage::{AsyncSledBackend, AsyncStorageBackend, StorageStats};
use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::time::timeout;

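/// Configuration for [`PooledAsyncBackend`]: concurrency limit, permit
/// acquisition timeout, and metrics toggle. Built fluently, e.g.
/// `PoolConfig::new().with_max_concurrent(20).with_timeout(1000)`.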
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of operations allowed to run concurrently.
    pub max_concurrent: usize,
    /// How long to wait for a permit before timing out, in milliseconds.
    pub acquire_timeout_ms: u64,
    /// Whether to collect pool metrics.
    pub enable_metrics: bool,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            max_concurrent: 100,
            acquire_timeout_ms: 5000,
            enable_metrics: true,
        }
    }
}

impl PoolConfig {
    /// Create a configuration with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of concurrent operations.
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }

    /// Set the permit acquisition timeout in milliseconds.
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.acquire_timeout_ms = timeout_ms;
        self
    }

    /// Enable or disable metrics collection.
    pub fn with_metrics(mut self, enable: bool) -> Self {
        self.enable_metrics = enable;
        self
    }
}

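/// Runtime counters for the pool, updated with relaxed atomics on every
/// operation; read them via [`PoolMetrics::snapshot`].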
#[derive(Debug)]
pub struct PoolMetrics {
    /// Total operations started.
    total_operations: AtomicU64,
    /// Operations that completed successfully.
    successful_operations: AtomicU64,
    /// Operations that returned an error.
    failed_operations: AtomicU64,
    /// Permit acquisitions that timed out.
    timeouts: AtomicU64,
    /// Operations currently in flight.
    active_operations: AtomicUsize,
    /// Highest number of operations observed in flight at once.
    peak_concurrent: AtomicUsize,
    /// Cumulative time spent waiting for permits, in microseconds.
    total_wait_time_us: AtomicU64,
}

impl PoolMetrics {
    /// Create a metrics instance with all counters at zero.
    pub fn new() -> Self {
        Self {
            total_operations: AtomicU64::new(0),
            successful_operations: AtomicU64::new(0),
            failed_operations: AtomicU64::new(0),
            timeouts: AtomicU64::new(0),
            active_operations: AtomicUsize::new(0),
            peak_concurrent: AtomicUsize::new(0),
            total_wait_time_us: AtomicU64::new(0),
        }
    }

    /// Record the start of an operation and update the peak-concurrency watermark.
    fn operation_started(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        let active = self.active_operations.fetch_add(1, Ordering::Relaxed) + 1;

        // Raise the peak-concurrency watermark with a compare-and-swap loop.
        let mut peak = self.peak_concurrent.load(Ordering::Relaxed);
        while active > peak {
            match self.peak_concurrent.compare_exchange_weak(
                peak,
                active,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(p) => peak = p,
            }
        }
    }

    /// Record the completion of an operation.
    fn operation_completed(&self, success: bool) {
        if success {
            self.successful_operations.fetch_add(1, Ordering::Relaxed);
        } else {
            self.failed_operations.fetch_add(1, Ordering::Relaxed);
        }
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
    }

    /// Record a permit-acquisition timeout (also counted as a failure).
    fn record_timeout(&self) {
        self.timeouts.fetch_add(1, Ordering::Relaxed);
        self.failed_operations.fetch_add(1, Ordering::Relaxed);
    }

    /// Add permit wait time, in microseconds, to the running total.
    fn record_wait_time(&self, wait_time_us: u64) {
        self.total_wait_time_us
            .fetch_add(wait_time_us, Ordering::Relaxed);
    }

    /// Take a point-in-time copy of all counters.
    pub fn snapshot(&self) -> PoolMetricsSnapshot {
        PoolMetricsSnapshot {
            total_operations: self.total_operations.load(Ordering::Relaxed),
            successful_operations: self.successful_operations.load(Ordering::Relaxed),
            failed_operations: self.failed_operations.load(Ordering::Relaxed),
            timeouts: self.timeouts.load(Ordering::Relaxed),
            active_operations: self.active_operations.load(Ordering::Relaxed),
            peak_concurrent: self.peak_concurrent.load(Ordering::Relaxed),
            total_wait_time_us: self.total_wait_time_us.load(Ordering::Relaxed),
        }
    }
}

impl Default for PoolMetrics {
    fn default() -> Self {
        Self::new()
    }
}

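/// A plain-value copy of [`PoolMetrics`] taken at a single point in time.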
#[derive(Debug, Clone, Copy)]
pub struct PoolMetricsSnapshot {
    /// Total operations started.
    pub total_operations: u64,
    /// Operations that completed successfully.
    pub successful_operations: u64,
    /// Operations that returned an error.
    pub failed_operations: u64,
    /// Permit acquisitions that timed out.
    pub timeouts: u64,
    /// Operations in flight when the snapshot was taken.
    pub active_operations: usize,
    /// Highest observed number of concurrent operations.
    pub peak_concurrent: usize,
    /// Cumulative permit wait time, in microseconds.
    pub total_wait_time_us: u64,
}

impl PoolMetricsSnapshot {
    /// Average permit wait time per operation, in milliseconds.
    pub fn avg_wait_time_ms(&self) -> f64 {
        if self.total_operations == 0 {
            0.0
        } else {
            (self.total_wait_time_us as f64) / (self.total_operations as f64) / 1000.0
        }
    }

    /// Fraction of operations that succeeded (1.0 when no operations have run).
    pub fn success_rate(&self) -> f64 {
        if self.total_operations == 0 {
            1.0
        } else {
            (self.successful_operations as f64) / (self.total_operations as f64)
        }
    }

    /// Fraction of operations that timed out waiting for a permit.
    pub fn timeout_rate(&self) -> f64 {
        if self.total_operations == 0 {
            0.0
        } else {
            (self.timeouts as f64) / (self.total_operations as f64)
        }
    }
}

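/// Connection-pool-style wrapper around [`AsyncSledBackend`].
///
/// A [`Semaphore`] with `max_concurrent` permits bounds how many operations
/// may hit the underlying store at once; when enabled, [`PoolMetrics`] records
/// throughput, wait times, and timeouts.
///
/// Minimal usage sketch (the external crate/module paths shown here are
/// assumptions based on this file's location, not confirmed re-exports):
///
/// ```no_run
/// # use llm_memory_graph::storage::{PoolConfig, PooledAsyncBackend};
/// # async fn open_pool() -> llm_memory_graph::error::Result<()> {
/// let config = PoolConfig::new().with_max_concurrent(10);
/// let pool = PooledAsyncBackend::open(std::path::Path::new("graph.db"), config).await?;
/// assert_eq!(pool.available_permits(), 10);
/// # Ok(())
/// # }
/// ```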
pub struct PooledAsyncBackend {
    /// The wrapped sled-backed storage.
    backend: Arc<AsyncSledBackend>,
    /// Limits the number of concurrently executing operations.
    semaphore: Arc<Semaphore>,
    /// Pool configuration.
    config: PoolConfig,
    /// Runtime metrics.
    metrics: Arc<PoolMetrics>,
}

impl PooledAsyncBackend {
    /// Open the backend at `path` and wrap it in a pool built from `config`.
    pub async fn open(path: &std::path::Path, config: PoolConfig) -> Result<Self> {
        let backend = AsyncSledBackend::open(path).await?;
        let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
        let metrics = Arc::new(PoolMetrics::new());

        Ok(Self {
            backend: Arc::new(backend),
            semaphore,
            config,
            metrics,
        })
    }

    /// Snapshot of the current pool metrics.
    pub fn metrics(&self) -> PoolMetricsSnapshot {
        self.metrics.snapshot()
    }

    /// The pool configuration in use.
    pub fn config(&self) -> &PoolConfig {
        &self.config
    }

    /// Number of permits currently available (free operation slots).
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }

    /// Acquire a permit, timing out after `acquire_timeout_ms`; records the
    /// wait time and marks the operation as started.
    async fn acquire_permit(&self) -> Result<tokio::sync::SemaphorePermit<'_>> {
        let start = std::time::Instant::now();

        let permit = timeout(
            Duration::from_millis(self.config.acquire_timeout_ms),
            self.semaphore.acquire(),
        )
        .await
        .map_err(|_| {
            self.metrics.record_timeout();
            Error::Storage("Pool acquire timeout".to_string())
        })?
        .map_err(|_| Error::Storage("Semaphore closed".to_string()))?;

        let wait_time = start.elapsed().as_micros() as u64;
        self.metrics.record_wait_time(wait_time);
        self.metrics.operation_started();

        Ok(permit)
    }

    /// Run `f` while holding a permit, recording whether it succeeded.
    async fn with_permit<F, T>(&self, f: F) -> Result<T>
    where
        F: std::future::Future<Output = Result<T>>,
    {
        let _permit = self.acquire_permit().await?;

        let result = f.await;
        self.metrics.operation_completed(result.is_ok());

        result
    }
}

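// Every trait method acquires a permit via `with_permit` and then delegates
// to the inner `AsyncSledBackend`.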
#[async_trait]
impl AsyncStorageBackend for PooledAsyncBackend {
    async fn store_node(&self, node: &Node) -> Result<()> {
        self.with_permit(self.backend.store_node(node)).await
    }

    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
        self.with_permit(self.backend.get_node(id)).await
    }

    async fn delete_node(&self, id: &NodeId) -> Result<()> {
        self.with_permit(self.backend.delete_node(id)).await
    }

    async fn store_edge(&self, edge: &Edge) -> Result<()> {
        self.with_permit(self.backend.store_edge(edge)).await
    }

    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
        self.with_permit(self.backend.get_edge(id)).await
    }

    async fn delete_edge(&self, id: &EdgeId) -> Result<()> {
        self.with_permit(self.backend.delete_edge(id)).await
    }

    async fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
        self.with_permit(self.backend.get_session_nodes(session_id))
            .await
    }

    async fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
        self.with_permit(self.backend.get_outgoing_edges(node_id))
            .await
    }

    async fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
        self.with_permit(self.backend.get_incoming_edges(node_id))
            .await
    }

    async fn flush(&self) -> Result<()> {
        self.with_permit(self.backend.flush()).await
    }

    async fn stats(&self) -> Result<StorageStats> {
        self.with_permit(self.backend.stats()).await
    }

    async fn store_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
        self.with_permit(self.backend.store_nodes_batch(nodes))
            .await
    }

    async fn store_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
        self.with_permit(self.backend.store_edges_batch(edges))
            .await
    }
}

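// Tests exercise pool creation, pooled operations, concurrency limits,
// backpressure, metrics tracking, and batch operations.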
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ConversationSession;
    use tempfile::tempdir;

    #[tokio::test]
    async fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(10);

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        assert_eq!(pool.available_permits(), 10);
        assert_eq!(pool.config().max_concurrent, 10);
    }

    #[tokio::test]
    async fn test_pool_operations() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new();

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        let session = ConversationSession::new();
        let node = Node::Session(session.clone());

        pool.store_node(&node).await.unwrap();

        let retrieved = pool.get_node(&session.node_id).await.unwrap();
        assert!(retrieved.is_some());

        let metrics = pool.metrics();
        assert!(metrics.total_operations >= 2);
        assert!(metrics.successful_operations >= 2);
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(20);

        let pool = Arc::new(PooledAsyncBackend::open(dir.path(), config).await.unwrap());

        // Spawn 50 tasks that each store a node through the pool.
        let mut handles = vec![];
        for _ in 0..50 {
            let pool_clone = Arc::clone(&pool);
            let handle = tokio::spawn(async move {
                let session = ConversationSession::new();
                let node = Node::Session(session);
                pool_clone.store_node(&node).await
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 50);
        assert_eq!(metrics.successful_operations, 50);
        // Concurrency never exceeds the configured limit.
        assert!(metrics.peak_concurrent <= 20);
    }

    #[tokio::test]
    async fn test_pool_backpressure() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new()
            .with_max_concurrent(2)
            .with_timeout(1000);

        let pool = Arc::new(PooledAsyncBackend::open(dir.path(), config).await.unwrap());

        // Each task holds one of the two permits for 500 ms.
        let pool1 = Arc::clone(&pool);
        let handle1 = tokio::spawn(async move {
            let _permit = pool1.acquire_permit().await.unwrap();
            tokio::time::sleep(Duration::from_millis(500)).await;
        });

        let pool2 = Arc::clone(&pool);
        let handle2 = tokio::spawn(async move {
            let _permit = pool2.acquire_permit().await.unwrap();
            tokio::time::sleep(Duration::from_millis(500)).await;
        });

        // Give both tasks time to grab their permits.
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert_eq!(pool.available_permits(), 0);

        handle1.await.unwrap();
        handle2.await.unwrap();

        // Permits are released once the tasks finish.
        assert_eq!(pool.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_metrics(true);

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        for _ in 0..10 {
            let session = ConversationSession::new();
            pool.store_node(&Node::Session(session)).await.unwrap();
        }

        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 10);
        assert_eq!(metrics.successful_operations, 10);
        assert_eq!(metrics.failed_operations, 0);
        assert!(metrics.avg_wait_time_ms() >= 0.0);
        assert_eq!(metrics.success_rate(), 1.0);
    }

    #[tokio::test]
    async fn test_batch_operations_with_pool() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new();

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        let nodes: Vec<Node> = (0..10)
            .map(|_| Node::Session(ConversationSession::new()))
            .collect();

        let ids = pool.store_nodes_batch(&nodes).await.unwrap();
        assert_eq!(ids.len(), 10);

        // The whole batch counts as a single pooled operation.
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 1);
    }
}
564}