Skip to main content

do_memory_storage_turso/lib_impls/
helpers.rs

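//! Helper methods for `TursoStorage`.
//!
//! Covers connection acquisition (pooled or direct), prepared-statement cache tracking,
//! retrying SQL execution, health checks, and statistics accessors.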
4
5use do_memory_core::{Error, Result};
6use libsql::Connection;
7use std::time::Instant;
8use tracing::{debug, error, warn};
9
10use super::storage::TursoStorage;
11use crate::prepared::ConnectionId;
12
// Inherent helper methods for `TursoStorage`; the struct itself is defined in
// `super::storage`.
15impl TursoStorage {
16    /// Check if a connection pool is available
17    ///
18    /// Returns true if any connection pool (standard, keepalive, or adaptive) is configured.
19    ///
20    /// The prepared statement cache is now connection-aware, so it can safely be used
21    /// with connection pooling. Each connection has its own cache of prepared statements,
22    /// and caches are cleared when connections are returned to the pool.
23    pub(crate) fn has_connection_pool(&self) -> bool {
24        // Connection-aware prepared statement cache is now implemented
25        // The cache stores statements per-connection using ConnectionId,
26        // ensuring statements are only used with the connection they were prepared on.
27        self.pool.is_some() || self.adaptive_pool.is_some() || {
28            #[cfg(feature = "keepalive-pool")]
29            {
30                self.keepalive_pool.is_some()
31            }
32            #[cfg(not(feature = "keepalive-pool"))]
33            {
34                false
35            }
36        }
37    }
38
39    /// Get a database connection
40    ///
41    /// If connection pooling is enabled, this will use a pooled connection.
42    /// If keep-alive pool is enabled, it will be used for reduced overhead.
43    /// If adaptive pool is enabled, it will be used for variable load optimization.
44    /// Otherwise, it creates a new connection each time.
45    pub async fn get_connection(&self) -> Result<Connection> {
46        // Check adaptive pool first (highest priority for variable load)
47        if let Some(ref adaptive_pool) = self.adaptive_pool {
48            let adaptive_conn = adaptive_pool.get().await?;
49            // Extract the connection from the pooled connection
50            if let Some(conn) = adaptive_conn.into_inner() {
51                return Ok(conn);
52            }
53            // Fallback if connection extraction fails
54            return Err(Error::Storage(
55                "Failed to extract connection from adaptive pool".to_string(),
56            ));
57        }
58
59        #[cfg(feature = "keepalive-pool")]
60        {
61            if let Some(ref keepalive_pool) = self.keepalive_pool {
62                // Use keep-alive pool for reduced connection overhead
63                let keepalive_conn = keepalive_pool.get().await?;
64                return keepalive_conn.into_connection();
65            }
66        }
67
68        if let Some(ref pool) = self.pool {
69            // Use connection pool
70            let pooled_conn = pool.get().await?;
71            Ok(pooled_conn.into_inner()?)
72        } else {
73            // Create direct connection (legacy mode)
74            self.db
75                .connect()
76                .map_err(|e| Error::Storage(format!("Failed to get connection: {}", e)))
77        }
78    }
79
80    /// Get a database connection with its cache ID
81    ///
82    /// This method returns both the connection and a unique connection ID
83    /// for use with the prepared statement cache. The ID should be passed
84    /// to `prepare_cached` and `clear_prepared_cache` for proper cache management.
85    ///
86    /// # Returns
87    ///
88    /// A tuple of (Connection, ConnectionId)
89    ///
90    /// # Example
91    ///
92    /// ```rust,no_run
93    /// # use do_memory_storage_turso::TursoStorage;
94    /// # async fn example(storage: &TursoStorage) -> anyhow::Result<()> {
95    /// let (conn, conn_id) = storage.get_connection_with_id().await?;
96    /// // Use conn and conn_id with prepare_cached
97    /// # Ok(())
98    /// # }
99    /// ```
100    pub async fn get_connection_with_id(&self) -> Result<(Connection, ConnectionId)> {
101        let conn = self.get_connection().await?;
102        // Generate a connection ID for this checkout
103        // This ID tracks statement preparation for statistics
104        let conn_id = self.prepared_cache.get_connection_id();
105        Ok((conn, conn_id))
106    }
107
108    /// Get count of records in a table
109    ///
110    /// # Safety
111    /// The `table` parameter should be a validated table name from a fixed whitelist.
112    /// This function does not sanitize the table name, so callers must ensure
113    /// it comes from a trusted source (e.g., the fixed list in capacity.rs).
114    /// CodeQL may flag this as a potential SQL injection, but it is safe when
115    /// used with the predefined table names from the whitelist.
116    pub async fn get_count(&self, conn: &Connection, table: &str) -> Result<usize> {
117        // SAFETY: This function is only called with table names from a fixed whitelist
118        // in capacity.rs (episodes, patterns, heuristics, embeddings, etc.).
119        // No user input can reach this function, preventing SQL injection.
120        #[allow(clippy::literal_string_with_formatting_args)]
121        let sql = format!("SELECT COUNT(*) as count FROM {}", table);
122        let mut rows = conn
123            .query(&sql, ())
124            .await
125            .map_err(|e| Error::Storage(format!("Failed to count {}: {}", table, e)))?;
126
127        if let Some(row) = rows
128            .next()
129            .await
130            .map_err(|e| Error::Storage(format!("Failed to fetch count for {}: {}", table, e)))?
131        {
132            let count: i64 = row
133                .get(0)
134                .map_err(|e| Error::Storage(format!("Failed to parse count: {}", e)))?;
135            Ok(count as usize)
136        } else {
137            Ok(0)
138        }
139    }
140
141    /// Execute PRAGMA statements for database configuration
142    ///
143    /// PRAGMA statements may return rows, so we need to consume them before continuing.
144    pub async fn execute_pragmas(&self, conn: &Connection) -> Result<()> {
145        // Enable WAL mode for better concurrent access
146        let _ = conn.execute("PRAGMA journal_mode=WAL", ()).await;
147
148        // Increase busy timeout
149        let _ = conn.execute("PRAGMA busy_timeout=30000", ()).await;
150
151        Ok(())
152    }
153
154    /// Execute a SQL statement with retry logic
155    pub async fn execute_with_retry(&self, conn: &Connection, sql: &str) -> Result<()> {
156        let mut attempts = 0;
157        let mut delay = std::time::Duration::from_millis(self.config.retry_base_delay_ms);
158
159        loop {
160            match conn.execute(sql, ()).await {
161                Ok(_) => {
162                    if attempts > 0 {
163                        debug!("SQL succeeded after {} retries", attempts);
164                    }
165                    return Ok(());
166                }
167                Err(e) => {
168                    attempts += 1;
169                    if attempts >= self.config.max_retries {
170                        error!("SQL failed after {} attempts: {}", attempts, e);
171                        return Err(Error::Storage(format!(
172                            "SQL execution failed after {} retries: {}",
173                            attempts, e
174                        )));
175                    }
176
177                    warn!("SQL attempt {} failed: {}, retrying...", attempts, e);
178                    tokio::time::sleep(delay).await;
179
180                    // Exponential backoff
181                    delay = std::cmp::min(
182                        delay * 2,
183                        std::time::Duration::from_millis(self.config.retry_max_delay_ms),
184                    );
185                }
186            }
187        }
188    }
189
190    /// Health check - verify database connectivity
191    pub async fn health_check(&self) -> Result<bool> {
192        let conn = self.get_connection().await?;
193        match conn.query("SELECT 1", ()).await {
194            Ok(_) => Ok(true),
195            Err(e) => {
196                error!("Health check failed: {}", e);
197                Ok(false)
198            }
199        }
200    }
201
202    /// Wrap this storage with a cache layer using default cache configuration
203    ///
204    /// This provides transparent caching for episodes, patterns, and heuristics
205    /// with adaptive TTL based on access patterns.
206    ///
207    /// # Example
208    ///
209    /// ```no_run
210    /// # use do_memory_storage_turso::{TursoStorage, CacheConfig};
211    /// # async fn example() -> anyhow::Result<()> {
212    /// let storage = TursoStorage::new("file:test.db", "").await?;
213    /// let cached = storage.with_cache_default();
214    /// # Ok(())
215    /// # }
216    /// ```
217    pub fn with_cache_default(self) -> crate::cache::CachedTursoStorage {
218        self.with_cache(crate::cache::CacheConfig::default())
219    }
220
221    /// Wrap this storage with a cache layer using custom cache configuration
222    ///
223    /// # Arguments
224    ///
225    /// * `config` - Cache configuration to use
226    ///
227    /// # Returns
228    ///
229    /// A new `CachedTursoStorage` wrapping this storage
230    ///
231    /// # Example
232    ///
233    /// ```no_run
234    /// # use do_memory_storage_turso::{TursoStorage, CacheConfig};
235    /// # use std::time::Duration;
236    /// # async fn example() -> anyhow::Result<()> {
237    /// let storage = TursoStorage::new("file:test.db", "").await?;
238    /// let config = CacheConfig {
239    ///     max_episodes: 1000,
240    ///     episode_ttl: Duration::from_secs(3600),
241    ///     ..Default::default()
242    /// };
243    /// let cached = storage.with_cache(config);
244    /// # Ok(())
245    /// # }
246    /// ```
247    pub fn with_cache(
248        self,
249        cache_config: crate::cache::CacheConfig,
250    ) -> crate::cache::CachedTursoStorage {
251        crate::cache::CachedTursoStorage::new(self, cache_config)
252    }
253
254    /// Get the cache configuration if set
255    pub fn cache_config(&self) -> Option<&crate::cache::CacheConfig> {
256        self.config.cache_config.as_ref()
257    }
258
259    /// Get prepared statement cache statistics
260    pub fn prepared_cache_stats(&self) -> crate::prepared::PreparedCacheStats {
261        self.prepared_cache.stats()
262    }
263
264    /// Get a reference to the prepared statement cache
265    pub fn prepared_cache(&self) -> &crate::prepared::PreparedStatementCache {
266        &self.prepared_cache
267    }
268
269    /// Get database statistics
270    pub async fn get_statistics(&self) -> Result<crate::trait_impls::StorageStatistics> {
271        let conn = self.get_connection().await?;
272
273        let episode_count = self.get_count(&conn, "episodes").await?;
274        let pattern_count = self.get_count(&conn, "patterns").await?;
275        let heuristic_count = self.get_count(&conn, "heuristics").await?;
276
277        Ok(crate::trait_impls::StorageStatistics {
278            episode_count,
279            pattern_count,
280            heuristic_count,
281        })
282    }
283
284    /// Get pool statistics if pooling is enabled
285    pub async fn pool_statistics(&self) -> Option<crate::pool::PoolStatistics> {
286        if let Some(ref pool) = self.pool {
287            Some(pool.statistics().await)
288        } else {
289            None
290        }
291    }
292
293    /// Get pool utilization if pooling is enabled
294    pub async fn pool_utilization(&self) -> Option<f32> {
295        if let Some(ref pool) = self.pool {
296            Some(pool.utilization().await)
297        } else {
298            self.adaptive_pool
299                .as_ref()
300                .map(|adaptive_pool| adaptive_pool.utilization() as f32)
301        }
302    }
303
304    /// Get adaptive pool metrics if enabled
305    pub fn adaptive_pool_metrics(&self) -> Option<crate::pool::AdaptivePoolMetrics> {
306        self.adaptive_pool.as_ref().map(|pool| pool.metrics())
307    }
308
309    /// Get current adaptive pool size
310    pub fn adaptive_pool_size(&self) -> Option<(u32, u32)> {
311        self.adaptive_pool
312            .as_ref()
313            .map(|pool| (pool.active_connections(), pool.max_connections()))
314    }
315
316    /// Manually trigger adaptive pool scaling check
317    pub async fn check_adaptive_pool_scale(&self) {
318        if let Some(ref adaptive_pool) = self.adaptive_pool {
319            adaptive_pool.check_and_scale().await;
320        }
321    }
322
323    /// Get keep-alive pool statistics if enabled
324    #[cfg(feature = "keepalive-pool")]
325    pub fn keepalive_statistics(&self) -> Option<crate::pool::KeepAliveStatistics> {
326        self.keepalive_pool.as_ref().map(|pool| pool.statistics())
327    }
328
329    /// Get keep-alive configuration if enabled
330    #[cfg(feature = "keepalive-pool")]
331    pub fn keepalive_config(&self) -> Option<&crate::pool::KeepAliveConfig> {
332        self.keepalive_pool.as_ref().map(|pool| pool.config())
333    }
334
335    /// Prepare a SQL statement with cache tracking
336    ///
337    /// This method prepares a SQL statement and tracks cache statistics.
338    /// If the statement is already cached for this connection, it's a cache hit.
339    /// Otherwise, it's a cache miss and the statement is prepared and tracked.
340    ///
341    /// # Arguments
342    ///
343    /// * `conn_id` - Connection identifier for cache tracking
344    /// * `conn` - Database connection to prepare on
345    /// * `sql` - SQL statement to prepare
346    ///
347    /// # Returns
348    ///
349    /// The prepared statement
350    ///
351    /// # Errors
352    ///
353    /// Returns error if statement preparation fails
354    pub async fn prepare_cached(
355        &self,
356        conn_id: ConnectionId,
357        conn: &Connection,
358        sql: &str,
359    ) -> Result<libsql::Statement> {
360        // Check if this is a cache hit
361        if self.prepared_cache.is_cached(conn_id, sql) {
362            self.prepared_cache.record_hit(conn_id, sql);
363        }
364
365        // Prepare the statement
366        let start = Instant::now();
367        let stmt = conn
368            .prepare(sql)
369            .await
370            .map_err(|e| Error::Storage(format!("Failed to prepare statement: {}", e)))?;
371        let prepare_time_us = start.elapsed().as_micros() as u64;
372
373        // Record the miss (or re-record if it was a hit - tracks preparation time)
374        self.prepared_cache
375            .record_miss(conn_id, sql, prepare_time_us);
376
377        Ok(stmt)
378    }
379
380    /// Clear the prepared statement cache for a connection
381    ///
382    /// This should be called when a connection is returned to the pool
383    /// to prevent memory leaks and ensure proper cache management.
384    ///
385    /// # Arguments
386    ///
387    /// * `conn_id` - Connection identifier to clear
388    ///
389    /// # Returns
390    ///
391    /// Number of statements cleared from the cache
392    pub fn clear_prepared_cache(&self, conn_id: ConnectionId) -> usize {
393        self.prepared_cache.clear_connection(conn_id)
394    }
395
396    /// Get compression statistics if compression feature is enabled
397    ///
398    /// Returns a snapshot of compression statistics including:
399    /// - Total original bytes
400    /// - Total compressed bytes
401    /// - Compression ratio
402    /// - Bandwidth savings percentage
403    /// - Compression/decompression time
404    ///
405    /// # Returns
406    ///
407    /// Compression statistics snapshot
408    #[cfg(feature = "compression")]
409    pub fn compression_statistics(&self) -> crate::CompressionStatistics {
410        self.compression_stats
411            .lock()
412            .map(|stats| stats.clone())
413            .unwrap_or_else(|_| crate::CompressionStatistics::new())
414    }
415
416    /// Reset compression statistics
417    ///
418    /// This is useful for testing or for tracking statistics over specific time windows.
419    #[cfg(feature = "compression")]
420    pub fn reset_compression_statistics(&self) {
421        if let Ok(mut stats) = self.compression_stats.lock() {
422            *stats = crate::CompressionStatistics::new();
423        }
424    }
425}