Skip to main content

do_memory_storage_turso/pool/
adaptive.rs

1//! Adaptive connection pool that dynamically adjusts pool size based on load.
2
use do_memory_core::{Error, Result};
use libsql::Database;
use parking_lot::RwLock;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::{debug, info};
11
/// Unique identifier for a connection.
///
/// IDs are allocated monotonically from the pool's `next_conn_id` counter and
/// remain stable for the lifetime of a pooled connection.
pub type ConnectionId = u64;

/// Callback type for connection lifecycle events
///
/// This is called when a connection is dropped, allowing external components
/// (like the prepared statement cache) to clean up resources associated with
/// the connection.
pub type ConnectionCleanupCallback = Arc<dyn Fn(ConnectionId) + Send + Sync>;
21
/// Tunable parameters controlling how the pool grows and shrinks.
#[derive(Debug, Clone)]
pub struct AdaptivePoolConfig {
    /// Lower bound on pool capacity; also the initial size.
    pub min_connections: u32,
    /// Hard upper bound on pool capacity.
    pub max_connections: u32,
    /// Utilization fraction (0.0–1.0) at or above which the pool scales up.
    pub scale_up_threshold: f64,
    /// Utilization fraction (0.0–1.0) at or below which the pool scales down.
    pub scale_down_threshold: f64,
    /// Minimum interval between consecutive scale-up events.
    pub scale_up_cooldown: Duration,
    /// Minimum interval between consecutive scale-down events.
    pub scale_down_cooldown: Duration,
    /// Number of connections added per scale-up event.
    pub scale_up_increment: u32,
    /// Number of connections removed per scale-down event.
    pub scale_down_decrement: u32,
    /// Load-evaluation period; also reused as the acquire timeout in `get`.
    pub check_interval: Duration,
}
34
35impl Default for AdaptivePoolConfig {
36    fn default() -> Self {
37        Self {
38            min_connections: 5,
39            max_connections: 50,
40            scale_up_threshold: 0.7,
41            scale_down_threshold: 0.3,
42            scale_up_cooldown: Duration::from_secs(10),
43            scale_down_cooldown: Duration::from_secs(30),
44            scale_up_increment: 5,
45            scale_down_decrement: 5,
46            check_interval: Duration::from_secs(5),
47        }
48    }
49}
50
/// Point-in-time snapshot of pool metrics, as returned by
/// `AdaptiveConnectionPool::metrics`.
#[derive(Debug, Default)]
pub struct AdaptivePoolMetrics {
    /// Utilization as a whole-number percentage (truncated on store).
    pub utilization_percent: f64,
    /// Connections currently checked out of the pool.
    pub active_connections: u32,
    /// Current dynamic capacity at snapshot time.
    pub max_connections: u32,
    /// Total scale-up events since pool creation.
    pub scale_up_count: u32,
    /// Total scale-down events since pool creation.
    pub scale_down_count: u32,
    /// Running average permit wait time, in microseconds.
    pub avg_wait_time_us: u64,
    /// Total successful permit acquisitions.
    pub total_acquired: u64,
    /// Total connections returned (i.e. dropped).
    pub total_released: u64,
}
62
/// Internal lock-free metric counters, shared between the pool and the
/// connections it hands out (the connection's Drop updates them via raw
/// pointers derived from the pool's Arcs).
#[derive(Debug)]
struct AdaptiveMetrics {
    /// Utilization percentage stored as an integer (fraction truncated).
    utilization_percent: AtomicU64,
    /// Connections currently checked out.
    active_connections: AtomicU32,
    /// Current dynamic capacity (mirrors the pool's `current_max`).
    max_connections: AtomicU32,
    /// Total scale-up events.
    scale_up_count: AtomicU32,
    /// Total scale-down events.
    scale_down_count: AtomicU32,
    /// Running average wait time, in microseconds.
    avg_wait_time_us: AtomicU64,
    /// Total successful permit acquisitions.
    total_acquired: AtomicU64,
    /// Total connections released back (dropped).
    total_released: AtomicU64,
    /// Sum of all observed wait times, in microseconds (avg numerator).
    wait_time_total_us: AtomicU64,
    /// Number of wait-time samples (avg denominator).
    wait_count: AtomicU64,
    /// Nanosecond timestamp of the last scale-up, for cooldown enforcement;
    /// 0 means "never scaled up".
    last_scale_up: AtomicU64,
    /// Nanosecond timestamp of the last scale-down, for cooldown enforcement;
    /// 0 means "never scaled down".
    last_scale_down: AtomicU64,
}
78
79impl Default for AdaptiveMetrics {
80    fn default() -> Self {
81        Self {
82            utilization_percent: AtomicU64::new(0),
83            active_connections: AtomicU32::new(0),
84            max_connections: AtomicU32::new(0),
85            scale_up_count: AtomicU32::new(0),
86            scale_down_count: AtomicU32::new(0),
87            avg_wait_time_us: AtomicU64::new(0),
88            total_acquired: AtomicU64::new(0),
89            total_released: AtomicU64::new(0),
90            wait_time_total_us: AtomicU64::new(0),
91            wait_count: AtomicU64::new(0),
92            last_scale_up: AtomicU64::new(0),
93            last_scale_down: AtomicU64::new(0),
94        }
95    }
96}
97
/// Connection pool whose capacity adjusts between configured bounds based on
/// observed utilization (scaling is driven by calls to `check_and_scale`).
pub struct AdaptiveConnectionPool {
    /// Shared libsql database handle; each `get` opens a fresh connection.
    db: Arc<Database>,
    /// Sizing parameters for the pool.
    config: Arc<AdaptivePoolConfig>,
    /// Gate limiting the number of concurrently checked-out connections.
    semaphore: Arc<Semaphore>,
    /// Current dynamic maximum pool size.
    current_max: Arc<AtomicU32>,
    /// Shared metric counters (also reachable from pooled connections).
    metrics: Arc<AdaptiveMetrics>,
    /// Monotonic source of `ConnectionId`s.
    next_conn_id: Arc<AtomicU64>,
    /// Optional hook invoked with the connection ID when a connection drops.
    cleanup_callback: RwLock<Option<ConnectionCleanupCallback>>,
    /// Background task handle.
    /// NOTE(review): both constructors spawn an empty future here, so there
    /// is no automatic load monitoring — confirm whether a real monitor loop
    /// calling `check_and_scale` was intended.
    _monitor_task: tokio::task::JoinHandle<()>,
}
108
109impl AdaptiveConnectionPool {
110    pub async fn new(db: Arc<Database>, config: AdaptivePoolConfig) -> Result<Self> {
111        let config = Arc::new(config);
112        let initial_max = config.min_connections as usize;
113        let min_conn = config.min_connections;
114
115        info!(
116            "Creating adaptive connection pool with min={}, max={}",
117            config.min_connections, config.max_connections
118        );
119
120        let semaphore = Arc::new(Semaphore::new(initial_max));
121
122        let metrics = Arc::new(AdaptiveMetrics::default());
123        metrics.max_connections.store(min_conn, Ordering::Relaxed);
124
125        let pool = Self {
126            db,
127            config: config.clone(),
128            semaphore,
129            current_max: Arc::new(AtomicU32::new(min_conn)),
130            metrics,
131            next_conn_id: Arc::new(AtomicU64::new(1)),
132            cleanup_callback: RwLock::new(None),
133            _monitor_task: tokio::task::spawn(async {}),
134        };
135
136        let conn = pool
137            .db
138            .connect()
139            .map_err(|e| Error::Storage(format!("Failed to connect: {}", e)))?;
140        conn.query("SELECT 1", ())
141            .await
142            .map_err(|e| Error::Storage(format!("Database validation failed: {}", e)))?;
143
144        info!("Adaptive connection pool created successfully");
145
146        Ok(pool)
147    }
148
149    pub async fn new_sync(db: Arc<Database>, config: AdaptivePoolConfig) -> Result<Self> {
150        let config = Arc::new(config);
151        let initial_max = config.min_connections as usize;
152        let min_conn = config.min_connections;
153
154        info!(
155            "Creating adaptive connection pool (sync mode) with min={}, max={}",
156            config.min_connections, config.max_connections
157        );
158
159        let semaphore = Arc::new(Semaphore::new(initial_max));
160
161        let metrics = Arc::new(AdaptiveMetrics::default());
162        metrics.max_connections.store(min_conn, Ordering::Relaxed);
163
164        Ok(Self {
165            db,
166            config,
167            semaphore,
168            current_max: Arc::new(AtomicU32::new(min_conn)),
169            metrics,
170            next_conn_id: Arc::new(AtomicU64::new(1)),
171            cleanup_callback: RwLock::new(None),
172            _monitor_task: tokio::task::spawn(async {}),
173        })
174    }
175
176    async fn try_acquire(&self, timeout: Duration) -> Result<OwnedSemaphorePermit> {
177        let start = Instant::now();
178
179        match tokio::time::timeout(timeout, self.semaphore.clone().acquire_owned()).await {
180            Ok(Ok(permit)) => {
181                let wait_us = start.elapsed().as_micros() as u64;
182
183                self.metrics
184                    .wait_time_total_us
185                    .fetch_add(wait_us, Ordering::Relaxed);
186                self.metrics.wait_count.fetch_add(1, Ordering::Relaxed);
187
188                let total_time = self.metrics.wait_time_total_us.load(Ordering::Relaxed);
189                let count = self.metrics.wait_count.load(Ordering::Relaxed);
190                if let Some(avg) = total_time.checked_div(count) {
191                    self.metrics.avg_wait_time_us.store(avg, Ordering::Relaxed);
192                }
193
194                let active = self
195                    .metrics
196                    .active_connections
197                    .fetch_add(1, Ordering::Relaxed)
198                    + 1;
199
200                let max = self.current_max.load(Ordering::Relaxed);
201                let utilization = (active as f64 / max as f64) * 100.0;
202                self.metrics
203                    .utilization_percent
204                    .store(utilization as u64, Ordering::Relaxed);
205
206                self.metrics.total_acquired.fetch_add(1, Ordering::Relaxed);
207
208                Ok(permit)
209            }
210            Ok(Err(e)) => Err(Error::Storage(format!(
211                "Failed to acquire connection permit: {}",
212                e
213            ))),
214            Err(_) => Err(Error::Storage(format!(
215                "Connection acquisition timed out after {:?}",
216                timeout
217            ))),
218        }
219    }
220
221    async fn scale_up(&self) {
222        let now = Instant::now();
223        let last_up = self.metrics.last_scale_up.load(Ordering::Relaxed);
224
225        // Use duration since a fixed epoch
226        let epoch_duration = Duration::from_nanos(last_up);
227        let last_up_time = Instant::now() - epoch_duration;
228
229        if now.duration_since(last_up_time) < self.config.scale_up_cooldown {
230            return;
231        }
232
233        let current_max = self.current_max.load(Ordering::Relaxed);
234
235        if current_max >= self.config.max_connections {
236            return;
237        }
238
239        let new_max =
240            (current_max + self.config.scale_up_increment).min(self.config.max_connections);
241
242        info!("Scaling up: {} -> {} connections", current_max, new_max);
243
244        self.current_max.store(new_max, Ordering::Relaxed);
245        self.metrics
246            .max_connections
247            .store(new_max, Ordering::Relaxed);
248        self.metrics
249            .last_scale_up
250            .store(now.elapsed().as_nanos() as u64, Ordering::Relaxed);
251        self.metrics.scale_up_count.fetch_add(1, Ordering::Relaxed);
252
253        debug!("Scale up complete: {} connections", new_max);
254    }
255
256    async fn scale_down(&self) {
257        let now = Instant::now();
258        let last_down = self.metrics.last_scale_down.load(Ordering::Relaxed);
259
260        let epoch_duration = Duration::from_nanos(last_down);
261        let last_down_time = Instant::now() - epoch_duration;
262
263        if now.duration_since(last_down_time) < self.config.scale_down_cooldown {
264            return;
265        }
266
267        let current_max = self.current_max.load(Ordering::Relaxed);
268        let active = self.metrics.active_connections.load(Ordering::Relaxed);
269
270        let min_allowed = active.max(self.config.min_connections);
271        let new_max =
272            (current_max.saturating_sub(self.config.scale_down_decrement)).max(min_allowed);
273
274        if new_max >= current_max {
275            return;
276        }
277
278        info!(
279            "Scaling down: {} -> {} connections (active: {})",
280            current_max, new_max, active
281        );
282
283        self.current_max.store(new_max, Ordering::Relaxed);
284        self.metrics
285            .max_connections
286            .store(new_max, Ordering::Relaxed);
287        self.metrics
288            .last_scale_down
289            .store(now.elapsed().as_nanos() as u64, Ordering::Relaxed);
290        self.metrics
291            .scale_down_count
292            .fetch_add(1, Ordering::Relaxed);
293
294        debug!("Scale down complete: {} connections", new_max);
295    }
296
297    pub async fn check_and_scale(&self) {
298        let active = self.metrics.active_connections.load(Ordering::Relaxed);
299        let max = self.current_max.load(Ordering::Relaxed);
300        let utilization = active as f64 / max as f64;
301
302        if utilization >= self.config.scale_up_threshold {
303            self.scale_up().await;
304        } else if utilization <= self.config.scale_down_threshold {
305            self.scale_down().await;
306        }
307    }
308
    /// Acquire a pooled connection, waiting up to `check_interval` for
    /// capacity to become available.
    ///
    /// NOTE(review): the acquire timeout reuses `config.check_interval`
    /// rather than a dedicated acquire-timeout setting — confirm this is
    /// intentional.
    ///
    /// # Errors
    ///
    /// Returns `Error::Storage` when no permit becomes available in time or
    /// when the underlying `connect()` call fails.
    pub async fn get(&self) -> Result<AdaptivePooledConnection> {
        let permit = self.try_acquire(self.config.check_interval).await?;

        // Generate unique connection ID
        let conn_id = self.next_conn_id.fetch_add(1, Ordering::Relaxed);

        // Create a new database connection from the database
        let connection = self
            .db
            .connect()
            .map_err(|e| Error::Storage(format!("Failed to create connection: {}", e)))?;

        // NOTE(review): raw pointers into Arc-owned data are handed to the
        // connection for use in its Drop impl. This is only sound while the
        // pool (and its Arcs) outlives every pooled connection — nothing
        // here enforces that; consider storing Arc clones instead.
        let metrics_ptr = Arc::as_ptr(&self.metrics) as *mut AdaptiveMetrics;
        let current_max_ptr = Arc::as_ptr(&self.current_max) as *mut AtomicU32;

        // Get cleanup callback if registered
        let cleanup_callback = self.cleanup_callback.read().clone();

        debug!("Created connection with ID: {}", conn_id);

        Ok(AdaptivePooledConnection {
            conn_id,
            metrics_ptr,
            current_max_ptr,
            permit: Some(permit),
            connection: Some(connection),
            cleanup_callback,
        })
    }
338
    /// Number of permits currently free in the semaphore.
    pub fn available_connections(&self) -> usize {
        self.semaphore.available_permits()
    }

    /// Pool utilization as a fraction (derived from the stored integer
    /// percentage, so it is quantized to whole percents).
    pub fn utilization(&self) -> f64 {
        self.metrics.utilization_percent.load(Ordering::Relaxed) as f64 / 100.0
    }

    /// Number of connections currently checked out.
    pub fn active_connections(&self) -> u32 {
        self.metrics.active_connections.load(Ordering::Relaxed)
    }

    /// Current dynamic maximum pool size.
    pub fn max_connections(&self) -> u32 {
        self.current_max.load(Ordering::Relaxed)
    }
354
355    pub fn metrics(&self) -> AdaptivePoolMetrics {
356        AdaptivePoolMetrics {
357            utilization_percent: self.metrics.utilization_percent.load(Ordering::Relaxed) as f64,
358            active_connections: self.metrics.active_connections.load(Ordering::Relaxed),
359            max_connections: self.metrics.max_connections.load(Ordering::Relaxed),
360            scale_up_count: self.metrics.scale_up_count.load(Ordering::Relaxed),
361            scale_down_count: self.metrics.scale_down_count.load(Ordering::Relaxed),
362            avg_wait_time_us: self.metrics.avg_wait_time_us.load(Ordering::Relaxed),
363            total_acquired: self.metrics.total_acquired.load(Ordering::Relaxed),
364            total_released: self.metrics.total_released.load(Ordering::Relaxed),
365        }
366    }
367
368    /// Register a cleanup callback to be called when connections are dropped
369    ///
370    /// This allows external components (like the prepared statement cache) to
371    /// clean up resources when a connection is returned to the pool.
372    ///
373    /// # Arguments
374    ///
375    /// * `callback` - Function to call with the connection ID when a connection is dropped
376    ///
377    /// # Example
378    ///
379    /// ```no_run
380    /// use std::sync::Arc;
381    /// use do_memory_storage_turso::pool::{AdaptiveConnectionPool, ConnectionId};
382    /// use do_memory_storage_turso::PreparedStatementCache;
383    ///
384    /// # async fn example(pool: AdaptiveConnectionPool) {
385    /// let cache = Arc::new(PreparedStatementCache::new(100));
386    /// let cache_clone = Arc::clone(&cache);
387    ///
388    /// pool.set_cleanup_callback(Arc::new(move |conn_id: ConnectionId| {
389    ///     cache_clone.clear_connection(conn_id);
390    /// }));
391    /// # }
392    /// ```
393    pub fn set_cleanup_callback(&self, callback: ConnectionCleanupCallback) {
394        *self.cleanup_callback.write() = Some(callback);
395        info!("Connection cleanup callback registered");
396    }
397
398    /// Remove the cleanup callback
399    ///
400    /// This disables automatic cleanup notifications.
401    pub fn remove_cleanup_callback(&self) {
402        *self.cleanup_callback.write() = None;
403        info!("Connection cleanup callback removed");
404    }
405
406    pub async fn shutdown(&self) {
407        info!("Shutting down adaptive connection pool");
408        tokio::time::sleep(Duration::from_millis(100)).await;
409        info!("Adaptive connection pool shutdown complete");
410    }
411}
412
/// A checked-out connection; releases its permit and updates pool metrics
/// when dropped.
pub struct AdaptivePooledConnection {
    /// Stable identifier for this connection's lifetime.
    conn_id: ConnectionId,
    /// Raw pointer to the pool's metrics, derived from an Arc.
    /// NOTE(review): only valid while the pool is alive — nothing in the
    /// type system enforces that; confirm the pool outlives all connections.
    metrics_ptr: *mut AdaptiveMetrics,
    /// Raw pointer to the pool's current-capacity counter (same caveat).
    current_max_ptr: *mut AtomicU32,
    /// Semaphore permit held while checked out; taken and dropped on Drop.
    permit: Option<OwnedSemaphorePermit>,
    /// Underlying libsql connection; `None` after `into_inner`.
    connection: Option<libsql::Connection>,
    /// Optional drop hook, cloned from the pool at checkout time.
    cleanup_callback: Option<ConnectionCleanupCallback>,
}
421
// Manual Debug impl: the raw pointers, permit, and underlying connection are
// omitted from the output; only the stable ID and whether a cleanup callback
// is attached are shown.
impl std::fmt::Debug for AdaptivePooledConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AdaptivePooledConnection")
            .field("conn_id", &self.conn_id)
            .field("has_cleanup_callback", &self.cleanup_callback.is_some())
            .finish()
    }
}
430
// NOTE(review): these impls are asserted, not derived — the compiler refuses
// auto-derivation because of the raw pointer fields. The pointers target
// atomics owned by Arcs inside the pool, and Drop performs only shared (&)
// atomic operations, but nothing ties the connection's lifetime to the
// pool's. Verify the pool always outlives its connections, or replace the
// raw pointers with Arc clones so Send/Sync become automatic.
#[allow(unsafe_code)]
unsafe impl Send for AdaptivePooledConnection {}
#[allow(unsafe_code)]
unsafe impl Sync for AdaptivePooledConnection {}
435
impl AdaptivePooledConnection {
    /// Get the unique connection identifier
    ///
    /// This ID is stable for the lifetime of the connection and can be used
    /// to associate cached data (like prepared statements) with the connection.
    pub fn connection_id(&self) -> ConnectionId {
        self.conn_id
    }

    /// Get a reference to the underlying database connection
    ///
    /// Returns `None` if the connection was previously extracted with
    /// `into_inner`.
    pub fn connection(&self) -> Option<&libsql::Connection> {
        self.connection.as_ref()
    }

    /// Take ownership of the underlying connection
    ///
    /// The semaphore permit is still released (and the cleanup callback still
    /// fires) when `self` is dropped at the end of this call, so the pool's
    /// accounting stays correct even though the raw connection lives on.
    pub fn into_inner(mut self) -> Option<libsql::Connection> {
        self.connection.take()
    }
}
455
456impl Drop for AdaptivePooledConnection {
457    fn drop(&mut self) {
458        if let Some(permit) = self.permit.take() {
459            drop(permit);
460
461            #[allow(unsafe_code)]
462            unsafe {
463                if let Some(metrics) = self.metrics_ptr.as_mut() {
464                    let active = metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
465
466                    let max = self
467                        .current_max_ptr
468                        .as_ref()
469                        .map(|m| m.load(Ordering::Relaxed))
470                        .unwrap_or(1);
471
472                    let new_utilization = ((active - 1) as f64 / max as f64) * 100.0;
473                    metrics
474                        .utilization_percent
475                        .store(new_utilization as u64, Ordering::Relaxed);
476
477                    metrics.total_released.fetch_add(1, Ordering::Relaxed);
478                }
479            }
480
481            // Call cleanup callback if registered
482            if let Some(callback) = &self.cleanup_callback {
483                callback(self.conn_id);
484            }
485        }
486    }
487}
488
489#[cfg(test)]
490#[path = "adaptive_tests.rs"]
491mod tests;