llm_memory_graph/storage/
pooled_backend.rs

1//! Pooled async storage backend with resource management
2//!
3//! This module provides an enterprise-grade connection pooling layer for the
4//! async storage backend. While Sled is an embedded database that doesn't have
5//! traditional "connections", this pool manages concurrent access, provides
6//! backpressure, and collects metrics for production deployments.
7//!
8//! # Features
9//!
10//! - **Concurrent Access Control**: Limits simultaneous operations to prevent resource exhaustion
11//! - **Backpressure**: Applies backpressure when pool is saturated
12//! - **Metrics**: Tracks pool utilization, wait times, and operation counts
13//! - **Timeout Handling**: Configurable timeouts for acquiring pool permits
14//! - **Graceful Degradation**: Handles overload scenarios gracefully
15//!
16//! # Architecture
17//!
18//! ```text
19//! ┌─────────────────────────────────────────┐
20//! │  PooledAsyncBackend                     │
21//! │  ┌────────────────────────────────────┐ │
22//! │  │ Semaphore (max_concurrent)         │ │
23//! │  └────────────────────────────────────┘ │
24//! │  ┌────────────────────────────────────┐ │
25//! │  │ PoolMetrics (atomic counters)      │ │
26//! │  └────────────────────────────────────┘ │
27//! │  ┌────────────────────────────────────┐ │
28//! │  │ AsyncSledBackend (underlying DB)   │ │
29//! │  └────────────────────────────────────┘ │
30//! └─────────────────────────────────────────┘
31//! ```
32
33use crate::error::{Error, Result};
34use crate::storage::{AsyncSledBackend, AsyncStorageBackend, StorageStats};
35use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
36use async_trait::async_trait;
37use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
38use std::sync::Arc;
39use std::time::Duration;
40use tokio::sync::Semaphore;
41use tokio::time::timeout;
42
/// Configuration for the connection pool.
///
/// Build one with [`PoolConfig::new`] (or [`Default`]) and customize it
/// through the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of concurrent operations
    pub max_concurrent: usize,
    /// Timeout for acquiring a pool permit (milliseconds)
    pub acquire_timeout_ms: u64,
    /// Enable detailed metrics collection
    pub enable_metrics: bool,
}

impl Default for PoolConfig {
    /// Defaults: 100 concurrent operations, 5-second acquire timeout,
    /// metrics enabled.
    fn default() -> Self {
        Self {
            max_concurrent: 100,
            acquire_timeout_ms: 5000,
            enable_metrics: true,
        }
    }
}

impl PoolConfig {
    /// Create a new pool configuration with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of concurrent operations.
    pub fn with_max_concurrent(self, max: usize) -> Self {
        Self {
            max_concurrent: max,
            ..self
        }
    }

    /// Set the permit-acquire timeout in milliseconds.
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self {
            acquire_timeout_ms: timeout_ms,
            ..self
        }
    }

    /// Enable or disable detailed metrics collection.
    pub fn with_metrics(self, enable: bool) -> Self {
        Self {
            enable_metrics: enable,
            ..self
        }
    }
}
88
/// Metrics for the connection pool
///
/// All counters use relaxed atomic ordering: each is an independent
/// statistic and readers only need an approximate, eventually-consistent
/// view, so no cross-counter ordering guarantees are required.
#[derive(Debug)]
pub struct PoolMetrics {
    /// Total number of operations attempted (including timed-out acquires)
    total_operations: AtomicU64,
    /// Total number of successful operations
    successful_operations: AtomicU64,
    /// Total number of failed operations (timeouts included)
    failed_operations: AtomicU64,
    /// Total number of permit-acquire timeouts
    timeouts: AtomicU64,
    /// Current number of in-flight operations
    active_operations: AtomicUsize,
    /// Peak number of concurrent operations observed
    peak_concurrent: AtomicUsize,
    /// Total time spent waiting for permits (microseconds)
    total_wait_time_us: AtomicU64,
}

impl PoolMetrics {
    /// Create new pool metrics with every counter at zero
    pub fn new() -> Self {
        Self {
            total_operations: AtomicU64::new(0),
            successful_operations: AtomicU64::new(0),
            failed_operations: AtomicU64::new(0),
            timeouts: AtomicU64::new(0),
            active_operations: AtomicUsize::new(0),
            peak_concurrent: AtomicUsize::new(0),
            total_wait_time_us: AtomicU64::new(0),
        }
    }

    /// Record operation start: bumps the total and active counters and
    /// raises the peak-concurrency high-water mark if needed
    fn operation_started(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        let active = self.active_operations.fetch_add(1, Ordering::Relaxed) + 1;

        // `fetch_max` atomically raises the stored peak to `active` when it
        // is higher, replacing the previous manual compare-exchange loop.
        self.peak_concurrent.fetch_max(active, Ordering::Relaxed);
    }

    /// Record operation completion, classifying it as success or failure
    fn operation_completed(&self, success: bool) {
        if success {
            self.successful_operations.fetch_add(1, Ordering::Relaxed);
        } else {
            self.failed_operations.fetch_add(1, Ordering::Relaxed);
        }
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
    }

    /// Record a permit-acquire timeout
    ///
    /// A timed-out acquire never reaches `operation_started`, so the
    /// attempt is also counted into `total_operations` here. Without this,
    /// `success_rate`/`timeout_rate` would divide by a total that excludes
    /// timeouts, and `timeout_rate` could exceed 1.0.
    fn record_timeout(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        self.timeouts.fetch_add(1, Ordering::Relaxed);
        self.failed_operations.fetch_add(1, Ordering::Relaxed);
    }

    /// Record time spent waiting for a permit, in microseconds
    fn record_wait_time(&self, wait_time_us: u64) {
        self.total_wait_time_us
            .fetch_add(wait_time_us, Ordering::Relaxed);
    }

    /// Get a point-in-time snapshot of current metrics
    ///
    /// Each counter is read individually with a relaxed load, so under
    /// concurrent updates the snapshot is approximate rather than a
    /// mutually consistent cut.
    pub fn snapshot(&self) -> PoolMetricsSnapshot {
        PoolMetricsSnapshot {
            total_operations: self.total_operations.load(Ordering::Relaxed),
            successful_operations: self.successful_operations.load(Ordering::Relaxed),
            failed_operations: self.failed_operations.load(Ordering::Relaxed),
            timeouts: self.timeouts.load(Ordering::Relaxed),
            active_operations: self.active_operations.load(Ordering::Relaxed),
            peak_concurrent: self.peak_concurrent.load(Ordering::Relaxed),
            total_wait_time_us: self.total_wait_time_us.load(Ordering::Relaxed),
        }
    }
}

impl Default for PoolMetrics {
    fn default() -> Self {
        Self::new()
    }
}

/// Snapshot of pool metrics at a point in time
#[derive(Debug, Clone, Copy)]
pub struct PoolMetricsSnapshot {
    /// Total operations attempted (including timed-out acquires)
    pub total_operations: u64,
    /// Successful operations
    pub successful_operations: u64,
    /// Failed operations (timeouts included)
    pub failed_operations: u64,
    /// Number of permit-acquire timeouts
    pub timeouts: u64,
    /// Currently active operations
    pub active_operations: usize,
    /// Peak concurrent operations observed
    pub peak_concurrent: usize,
    /// Total wait time in microseconds
    pub total_wait_time_us: u64,
}

impl PoolMetricsSnapshot {
    /// Average permit wait time in milliseconds (0.0 when no operations)
    pub fn avg_wait_time_ms(&self) -> f64 {
        if self.total_operations == 0 {
            0.0
        } else {
            (self.total_wait_time_us as f64) / (self.total_operations as f64) / 1000.0
        }
    }

    /// Fraction of operations that succeeded, in [0.0, 1.0]
    ///
    /// Defined as 1.0 when no operations have run (vacuous success).
    pub fn success_rate(&self) -> f64 {
        if self.total_operations == 0 {
            1.0
        } else {
            (self.successful_operations as f64) / (self.total_operations as f64)
        }
    }

    /// Fraction of attempts that timed out acquiring a permit, in [0.0, 1.0]
    pub fn timeout_rate(&self) -> f64 {
        if self.total_operations == 0 {
            0.0
        } else {
            (self.timeouts as f64) / (self.total_operations as f64)
        }
    }
}
231
/// Pooled async storage backend with resource management
///
/// This backend wraps AsyncSledBackend with a semaphore-based pool that:
/// - Limits concurrent operations to prevent resource exhaustion
/// - Provides backpressure when the pool is saturated
/// - Collects metrics on pool utilization
/// - Implements timeouts to prevent indefinite blocking
///
/// # Examples
///
/// ```no_run
/// use llm_memory_graph::storage::{PooledAsyncBackend, PoolConfig};
/// use std::path::Path;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let config = PoolConfig::new()
///         .with_max_concurrent(50)
///         .with_timeout(3000);
///
///     let backend = PooledAsyncBackend::open(Path::new("./data/db"), config).await?;
///
///     // Get pool metrics
///     let metrics = backend.metrics();
///     println!("Active operations: {}", metrics.active_operations);
///
///     Ok(())
/// }
/// ```
pub struct PooledAsyncBackend {
    /// Underlying async backend (shared via Arc)
    backend: Arc<AsyncSledBackend>,
    /// Semaphore bounding concurrent access; permit count = config.max_concurrent
    semaphore: Arc<Semaphore>,
    /// Pool configuration supplied at open time
    config: PoolConfig,
    /// Pool metrics (currently always collected, regardless of
    /// `config.enable_metrics` — see `open`)
    metrics: Arc<PoolMetrics>,
}
271
272impl PooledAsyncBackend {
273    /// Open a pooled async backend
274    ///
275    /// # Examples
276    ///
277    /// ```no_run
278    /// use llm_memory_graph::storage::{PooledAsyncBackend, PoolConfig};
279    /// use std::path::Path;
280    ///
281    /// # #[tokio::main]
282    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
283    /// let backend = PooledAsyncBackend::open(
284    ///     Path::new("./data/db"),
285    ///     PoolConfig::default()
286    /// ).await?;
287    /// # Ok(())
288    /// # }
289    /// ```
290    pub async fn open(path: &std::path::Path, config: PoolConfig) -> Result<Self> {
291        let backend = AsyncSledBackend::open(path).await?;
292        let semaphore = Arc::new(Semaphore::new(config.max_concurrent));
293        // Always create metrics regardless of config for now
294        let metrics = Arc::new(PoolMetrics::new());
295
296        Ok(Self {
297            backend: Arc::new(backend),
298            semaphore,
299            config,
300            metrics,
301        })
302    }
303
304    /// Get current pool metrics
305    pub fn metrics(&self) -> PoolMetricsSnapshot {
306        self.metrics.snapshot()
307    }
308
309    /// Get pool configuration
310    pub fn config(&self) -> &PoolConfig {
311        &self.config
312    }
313
314    /// Get number of available permits
315    pub fn available_permits(&self) -> usize {
316        self.semaphore.available_permits()
317    }
318
319    /// Acquire a permit from the pool with timeout
320    async fn acquire_permit(&self) -> Result<tokio::sync::SemaphorePermit<'_>> {
321        let start = std::time::Instant::now();
322
323        let permit = timeout(
324            Duration::from_millis(self.config.acquire_timeout_ms),
325            self.semaphore.acquire(),
326        )
327        .await
328        .map_err(|_| {
329            self.metrics.record_timeout();
330            Error::Storage("Pool acquire timeout".to_string())
331        })?
332        .map_err(|_| Error::Storage("Semaphore closed".to_string()))?;
333
334        // Record wait time
335        let wait_time = start.elapsed().as_micros() as u64;
336        self.metrics.record_wait_time(wait_time);
337        self.metrics.operation_started();
338
339        Ok(permit)
340    }
341
342    /// Execute an operation with pool management
343    async fn with_permit<F, T>(&self, f: F) -> Result<T>
344    where
345        F: std::future::Future<Output = Result<T>>,
346    {
347        let _permit = self.acquire_permit().await?;
348
349        let result = f.await;
350        self.metrics.operation_completed(result.is_ok());
351
352        result
353    }
354}
355
356#[async_trait]
357impl AsyncStorageBackend for PooledAsyncBackend {
358    async fn store_node(&self, node: &Node) -> Result<()> {
359        self.with_permit(self.backend.store_node(node)).await
360    }
361
362    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
363        self.with_permit(self.backend.get_node(id)).await
364    }
365
366    async fn delete_node(&self, id: &NodeId) -> Result<()> {
367        self.with_permit(self.backend.delete_node(id)).await
368    }
369
370    async fn store_edge(&self, edge: &Edge) -> Result<()> {
371        self.with_permit(self.backend.store_edge(edge)).await
372    }
373
374    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
375        self.with_permit(self.backend.get_edge(id)).await
376    }
377
378    async fn delete_edge(&self, id: &EdgeId) -> Result<()> {
379        self.with_permit(self.backend.delete_edge(id)).await
380    }
381
382    async fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
383        self.with_permit(self.backend.get_session_nodes(session_id))
384            .await
385    }
386
387    async fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
388        self.with_permit(self.backend.get_outgoing_edges(node_id))
389            .await
390    }
391
392    async fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
393        self.with_permit(self.backend.get_incoming_edges(node_id))
394            .await
395    }
396
397    async fn flush(&self) -> Result<()> {
398        self.with_permit(self.backend.flush()).await
399    }
400
401    async fn stats(&self) -> Result<StorageStats> {
402        self.with_permit(self.backend.stats()).await
403    }
404
405    async fn store_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
406        self.with_permit(self.backend.store_nodes_batch(nodes))
407            .await
408    }
409
410    async fn store_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
411        self.with_permit(self.backend.store_edges_batch(edges))
412            .await
413    }
414}
415
#[cfg(test)]
mod tests {
    //! Concurrency-sensitive integration tests; each test opens a fresh
    //! temporary sled database via `tempdir` so tests are isolated.

    use super::*;
    use crate::types::ConversationSession;
    use tempfile::tempdir;

    /// A freshly opened pool exposes exactly the configured permit count.
    #[tokio::test]
    async fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(10);

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        assert_eq!(pool.available_permits(), 10);
        assert_eq!(pool.config().max_concurrent, 10);
    }

    /// Store/retrieve round-trip through the pool; metrics should record
    /// at least the two operations as successes.
    #[tokio::test]
    async fn test_pool_operations() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new();

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        // Create session
        let session = ConversationSession::new();
        let node = Node::Session(session.clone());

        // Store node
        pool.store_node(&node).await.unwrap();

        // Retrieve node
        let retrieved = pool.get_node(&session.node_id).await.unwrap();
        assert!(retrieved.is_some());

        // Check metrics (>= rather than == to stay robust to extra
        // internal operations)
        let metrics = pool.metrics();
        assert!(metrics.total_operations >= 2); // At least store + get
        assert!(metrics.successful_operations >= 2);
    }

    /// 50 tasks race for a 20-permit pool: all must complete successfully
    /// and the observed peak concurrency must never exceed the pool size.
    #[tokio::test]
    async fn test_concurrent_operations() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_max_concurrent(20);

        let pool = Arc::new(PooledAsyncBackend::open(dir.path(), config).await.unwrap());

        // Create 50 concurrent operations
        let mut handles = vec![];
        for _ in 0..50 {
            let pool_clone = Arc::clone(&pool);
            let handle = tokio::spawn(async move {
                let session = ConversationSession::new();
                let node = Node::Session(session);
                pool_clone.store_node(&node).await
            });
            handles.push(handle);
        }

        // Wait for all operations
        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        // Check metrics
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 50);
        assert_eq!(metrics.successful_operations, 50);
        assert!(metrics.peak_concurrent <= 20); // Shouldn't exceed pool size
    }

    /// Filling a 2-permit pool drives `available_permits` to zero; the
    /// permits come back once the holders finish.
    ///
    /// NOTE(review): this test calls `acquire_permit` directly, which marks
    /// operations as started without a matching completion, so metrics are
    /// intentionally not asserted here. The 50 ms sleep below is a timing
    /// heuristic for "both spawned tasks hold their permits" and could be
    /// flaky on a heavily loaded machine.
    #[tokio::test]
    async fn test_pool_backpressure() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new()
            .with_max_concurrent(2) // Very small pool
            .with_timeout(1000); // 1 second timeout

        let pool = Arc::new(PooledAsyncBackend::open(dir.path(), config).await.unwrap());

        // Start 2 long-running operations to fill the pool
        let pool1 = Arc::clone(&pool);
        let handle1 = tokio::spawn(async move {
            let _permit = pool1.acquire_permit().await.unwrap();
            tokio::time::sleep(Duration::from_millis(500)).await;
        });

        let pool2 = Arc::clone(&pool);
        let handle2 = tokio::spawn(async move {
            let _permit = pool2.acquire_permit().await.unwrap();
            tokio::time::sleep(Duration::from_millis(500)).await;
        });

        // Give time for permits to be acquired
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Pool should be full
        assert_eq!(pool.available_permits(), 0);

        // Wait for operations to complete
        handle1.await.unwrap();
        handle2.await.unwrap();

        // Permits should be returned
        assert_eq!(pool.available_permits(), 2);
    }

    /// Ten sequential stores are each counted once, with no failures.
    #[tokio::test]
    async fn test_metrics_tracking() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new().with_metrics(true);

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        // Perform operations
        for _ in 0..10 {
            let session = ConversationSession::new();
            pool.store_node(&Node::Session(session)).await.unwrap();
        }

        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 10);
        assert_eq!(metrics.successful_operations, 10);
        assert_eq!(metrics.failed_operations, 0);
        assert!(metrics.avg_wait_time_ms() >= 0.0);
        assert_eq!(metrics.success_rate(), 1.0);
    }

    /// A batch store takes one permit for the whole batch, so it counts as
    /// a single pool operation in the metrics.
    #[tokio::test]
    async fn test_batch_operations_with_pool() {
        let dir = tempdir().unwrap();
        let config = PoolConfig::new();

        let pool = PooledAsyncBackend::open(dir.path(), config).await.unwrap();

        // Create batch of nodes
        let nodes: Vec<Node> = (0..10)
            .map(|_| Node::Session(ConversationSession::new()))
            .collect();

        let ids = pool.store_nodes_batch(&nodes).await.unwrap();
        assert_eq!(ids.len(), 10);

        // Check metrics - batch should count as 1 operation
        let metrics = pool.metrics();
        assert_eq!(metrics.total_operations, 1);
    }
}