Skip to main content

rivven_rdbc/
connection.rs

//! Connection traits for rivven-rdbc
//!
//! Core abstractions for database connectivity:
//! - Connection: Basic connection with query execution
//! - ConnectionLifecycle: Optional lifecycle tracking for pool management
//! - PreparedStatement: Parameterized query support
//! - Transaction: ACID transaction support
//! - RowStream: Streaming row iteration
9
10use async_trait::async_trait;
11use std::future::Future;
12use std::pin::Pin;
13use std::time::{Duration, Instant};
14
15use crate::error::Result;
16use crate::types::{Row, Value};
17
18/// Lifecycle tracking for connections (used by connection pools)
19///
20/// This trait provides optional lifecycle information for connections,
21/// enabling accurate pool management and observability.
22///
23/// # Example
24///
25/// ```rust,ignore
26/// use rivven_rdbc::connection::ConnectionLifecycle;
27/// use std::time::Duration;
28///
29/// let conn = pool.get().await?;
30/// println!("Connection age: {:?}", conn.age());
31/// if conn.is_expired(Duration::from_secs(1800)) {
32///     println!("Connection should be recycled");
33/// }
34/// ```
35#[async_trait]
36pub trait ConnectionLifecycle: Send + Sync {
37    /// Get the instant when this connection was created
38    fn created_at(&self) -> Instant;
39
40    /// Get the age of this connection (time since creation)
41    fn age(&self) -> Duration {
42        self.created_at().elapsed()
43    }
44
45    /// Check if connection has exceeded the given maximum lifetime
46    fn is_expired(&self, max_lifetime: Duration) -> bool {
47        self.age() > max_lifetime
48    }
49
50    /// Get duration since this connection was last used
51    async fn idle_time(&self) -> Duration;
52
53    /// Check if connection has exceeded idle timeout
54    async fn is_idle_expired(&self, idle_timeout: Duration) -> bool {
55        self.idle_time().await > idle_timeout
56    }
57
58    /// Update the last-used timestamp (called when connection is actively used)
59    async fn touch(&self);
60}
61
62/// A connection to a database
63#[async_trait]
64pub trait Connection: Send + Sync {
65    /// Execute a query that returns rows
66    async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
67
68    /// Execute a query that modifies data, returns affected row count
69    async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64>;
70
71    /// Execute a batch of statements, returns affected counts per statement
72    async fn execute_batch(&self, statements: &[(&str, &[Value])]) -> Result<Vec<u64>> {
73        let mut results = Vec::with_capacity(statements.len());
74        for (sql, params) in statements {
75            results.push(self.execute(sql, params).await?);
76        }
77        Ok(results)
78    }
79
80    /// Prepare a statement for repeated execution
81    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement>>;
82
83    /// Begin a transaction
84    async fn begin(&self) -> Result<Box<dyn Transaction>>;
85
86    /// Begin a transaction with specified isolation level
87    async fn begin_with_isolation(
88        &self,
89        isolation: IsolationLevel,
90    ) -> Result<Box<dyn Transaction>> {
91        // Default implementation just begins and sets isolation
92        let tx = self.begin().await?;
93        tx.set_isolation_level(isolation).await?;
94        Ok(tx)
95    }
96
97    /// Execute a query and stream results
98    async fn query_stream(&self, sql: &str, params: &[Value]) -> Result<Pin<Box<dyn RowStream>>>;
99
100    /// Execute a query and return the first row (convenience method)
101    async fn query_one(&self, sql: &str, params: &[Value]) -> Result<Option<Row>> {
102        let rows = self.query(sql, params).await?;
103        Ok(rows.into_iter().next())
104    }
105
106    /// Check if connection is valid/alive
107    async fn is_valid(&self) -> bool;
108
109    /// Close the connection
110    async fn close(&self) -> Result<()>;
111}
112
113/// A prepared statement
114#[async_trait]
115pub trait PreparedStatement: Send + Sync {
116    /// Execute the prepared statement with given parameters
117    async fn execute(&self, params: &[Value]) -> Result<u64>;
118
119    /// Query with the prepared statement
120    async fn query(&self, params: &[Value]) -> Result<Vec<Row>>;
121
122    /// Get the SQL string
123    fn sql(&self) -> &str;
124}
125
126/// A database transaction
127#[async_trait]
128pub trait Transaction: Send + Sync {
129    /// Execute a query that returns rows
130    async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
131
132    /// Execute a query that modifies data
133    async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64>;
134
135    /// Execute a batch of statements within the transaction, returns affected counts per statement
136    async fn execute_batch(&self, statements: &[(&str, &[Value])]) -> Result<Vec<u64>> {
137        let mut results = Vec::with_capacity(statements.len());
138        for (sql, params) in statements {
139            results.push(self.execute(sql, params).await?);
140        }
141        Ok(results)
142    }
143
144    /// Commit the transaction
145    async fn commit(self: Box<Self>) -> Result<()>;
146
147    /// Rollback the transaction
148    async fn rollback(self: Box<Self>) -> Result<()>;
149
150    /// Set transaction isolation level
151    async fn set_isolation_level(&self, level: IsolationLevel) -> Result<()>;
152
153    /// Create a savepoint
154    async fn savepoint(&self, name: &str) -> Result<()>;
155
156    /// Rollback to a savepoint
157    async fn rollback_to_savepoint(&self, name: &str) -> Result<()>;
158
159    /// Release a savepoint
160    async fn release_savepoint(&self, name: &str) -> Result<()>;
161}
162
163/// Streaming row iterator
164pub trait RowStream: Send {
165    /// Get the next row
166    fn next(&mut self) -> Pin<Box<dyn Future<Output = Result<Option<Row>>> + Send + '_>>;
167}
168
/// Transaction isolation levels
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// Read uncommitted - dirty reads possible
    ReadUncommitted,
    /// Read committed - no dirty reads (PostgreSQL default)
    ReadCommitted,
    /// Repeatable read - no non-repeatable reads (MySQL default)
    RepeatableRead,
    /// Serializable - full isolation
    Serializable,
    /// Snapshot isolation (SQL Server specific)
    Snapshot,
}

impl IsolationLevel {
    /// Returns the SQL keyword(s) for this level, as used in a
    /// `SET TRANSACTION` statement.
    pub fn to_sql(&self) -> &'static str {
        match self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
            Self::Snapshot => "SNAPSHOT",
        }
    }
}

impl std::fmt::Display for IsolationLevel {
    /// Writes the same text as [`IsolationLevel::to_sql`].
    ///
    /// Width, alignment and precision flags in the caller's format string
    /// (e.g. `{:<16}`) are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's formatting flags; `write!(f, "{}", ..)`
        // would start from a fresh format spec and silently drop them.
        f.pad(self.to_sql())
    }
}
202
/// Settings used when opening a connection
#[derive(Clone)]
pub struct ConnectionConfig {
    /// Connection URL, e.g. `postgres://user:pass@host:5432/db`
    pub url: String,
    /// How long to wait for the connection to be established, in milliseconds
    pub connect_timeout_ms: u64,
    /// Per-query timeout in milliseconds; `0` means no timeout
    pub query_timeout_ms: u64,
    /// Size of the statement cache
    pub statement_cache_size: usize,
    /// Application name (shown in pg_stat_activity, etc)
    pub application_name: Option<String>,
    /// Extra connection properties as key/value pairs
    pub properties: std::collections::HashMap<String, String>,
}
219
220impl std::fmt::Debug for ConnectionConfig {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        // Redact credentials from the URL to prevent leaking passwords to logs.
223        let redacted_url = match url::Url::parse(&self.url) {
224            Ok(mut parsed) => {
225                if parsed.password().is_some() {
226                    let _ = parsed.set_password(Some("***"));
227                }
228                parsed.to_string()
229            }
230            Err(_) => "***".to_string(),
231        };
232
233        f.debug_struct("ConnectionConfig")
234            .field("url", &redacted_url)
235            .field("connect_timeout_ms", &self.connect_timeout_ms)
236            .field("query_timeout_ms", &self.query_timeout_ms)
237            .field("statement_cache_size", &self.statement_cache_size)
238            .field("application_name", &self.application_name)
239            .field("properties", &self.properties)
240            .finish()
241    }
242}
243
244impl Default for ConnectionConfig {
245    fn default() -> Self {
246        Self {
247            url: String::new(),
248            connect_timeout_ms: 10_000,
249            query_timeout_ms: 30_000,
250            statement_cache_size: 100,
251            application_name: Some("rivven-rdbc".into()),
252            properties: std::collections::HashMap::new(),
253        }
254    }
255}
256
257impl ConnectionConfig {
258    /// Create configuration with just a URL
259    pub fn new(url: impl Into<String>) -> Self {
260        Self {
261            url: url.into(),
262            ..Default::default()
263        }
264    }
265
266    /// Set connection timeout
267    pub fn with_connect_timeout(mut self, ms: u64) -> Self {
268        self.connect_timeout_ms = ms;
269        self
270    }
271
272    /// Set query timeout
273    pub fn with_query_timeout(mut self, ms: u64) -> Self {
274        self.query_timeout_ms = ms;
275        self
276    }
277
278    /// Set statement cache size
279    pub fn with_statement_cache_size(mut self, size: usize) -> Self {
280        self.statement_cache_size = size;
281        self
282    }
283
284    /// Set application name
285    pub fn with_application_name(mut self, name: impl Into<String>) -> Self {
286        self.application_name = Some(name.into());
287        self
288    }
289
290    /// Add a connection property
291    pub fn with_property(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
292        self.properties.insert(key.into(), value.into());
293        self
294    }
295}
296
297/// Factory for creating connections
298#[async_trait]
299pub trait ConnectionFactory: Send + Sync {
300    /// Create a new connection
301    async fn connect(&self, config: &ConnectionConfig) -> Result<Box<dyn Connection>>;
302
303    /// Get the database type
304    fn database_type(&self) -> DatabaseType;
305}
306
/// Database type identifier
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// PostgreSQL
    PostgreSQL,
    /// MySQL/MariaDB
    MySQL,
    /// SQL Server
    SqlServer,
    /// SQLite
    SQLite,
    /// Oracle
    Oracle,
    /// Unknown/custom
    Unknown,
}

impl std::fmt::Display for DatabaseType {
    /// Writes the human-readable product name (e.g. `SQL Server`).
    ///
    /// Width, alignment and precision flags in the caller's format string
    /// (e.g. `{:>12}`) are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::PostgreSQL => "PostgreSQL",
            Self::MySQL => "MySQL",
            Self::SqlServer => "SQL Server",
            Self::SQLite => "SQLite",
            Self::Oracle => "Oracle",
            Self::Unknown => "Unknown",
        };
        // `pad` applies the caller's formatting flags; `write!(f, "...")`
        // would ignore them.
        f.pad(label)
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `ConnectionLifecycle` implementor with a fixed creation
    /// instant, used to exercise the trait's provided methods.
    struct FixedLifecycle {
        created: Instant,
    }

    #[async_trait]
    impl ConnectionLifecycle for FixedLifecycle {
        fn created_at(&self) -> Instant {
            self.created
        }

        async fn idle_time(&self) -> Duration {
            Duration::from_secs(0)
        }

        async fn touch(&self) {}
    }

    #[test]
    fn test_isolation_level_to_sql() {
        assert_eq!(IsolationLevel::ReadCommitted.to_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::Serializable.to_sql(), "SERIALIZABLE");
    }

    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("postgres://localhost/test")
            .with_connect_timeout(5000)
            .with_query_timeout(15000)
            .with_application_name("myapp")
            .with_property("sslmode", "require");

        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.connect_timeout_ms, 5000);
        assert_eq!(config.query_timeout_ms, 15000);
        assert_eq!(config.application_name, Some("myapp".into()));
        assert_eq!(config.properties.get("sslmode"), Some(&"require".into()));
    }

    /// The URL password must never appear in `Debug` output.
    #[test]
    fn test_connection_config_debug_redacts_url_password() {
        let config = ConnectionConfig::new("postgres://user:hunter2@localhost:5432/db");
        let rendered = format!("{:?}", config);

        assert!(!rendered.contains("hunter2"));
        assert!(rendered.contains("***"));
    }

    #[test]
    fn test_database_type_display() {
        assert_eq!(format!("{}", DatabaseType::PostgreSQL), "PostgreSQL");
        assert_eq!(format!("{}", DatabaseType::MySQL), "MySQL");
        assert_eq!(format!("{}", DatabaseType::SqlServer), "SQL Server");
        assert_eq!(format!("{}", DatabaseType::SQLite), "SQLite");
        assert_eq!(format!("{}", DatabaseType::Oracle), "Oracle");
        assert_eq!(format!("{}", DatabaseType::Unknown), "Unknown");
    }

    #[test]
    fn test_isolation_level_display() {
        assert_eq!(
            format!("{}", IsolationLevel::ReadUncommitted),
            "READ UNCOMMITTED"
        );
        assert_eq!(
            format!("{}", IsolationLevel::ReadCommitted),
            "READ COMMITTED"
        );
        assert_eq!(
            format!("{}", IsolationLevel::RepeatableRead),
            "REPEATABLE READ"
        );
        assert_eq!(format!("{}", IsolationLevel::Serializable), "SERIALIZABLE");
        assert_eq!(format!("{}", IsolationLevel::Snapshot), "SNAPSHOT");
    }

    /// Exercises the provided (default) methods of `ConnectionLifecycle`
    /// through a real implementor rather than re-deriving their logic.
    #[test]
    fn test_connection_lifecycle_defaults() {
        let conn = FixedLifecycle {
            created: Instant::now(),
        };

        // `age` is measured from `created_at`, so a fresh connection is young.
        assert_eq!(conn.created_at(), conn.created);
        assert!(conn.age() < Duration::from_secs(60));

        // A just-created connection has not exceeded a 30 minute lifetime.
        assert!(!conn.is_expired(Duration::from_secs(1800)));

        // After sleeping past a very short lifetime, the connection is expired.
        std::thread::sleep(Duration::from_millis(5));
        assert!(conn.age() > Duration::from_millis(1));
        assert!(conn.is_expired(Duration::from_millis(1)));
    }
}
410}