1use async_trait::async_trait;
11use std::future::Future;
12use std::pin::Pin;
13use std::time::{Duration, Instant};
14
15use crate::error::Result;
16use crate::types::{Row, Value};
17
18#[async_trait]
36pub trait ConnectionLifecycle: Send + Sync {
37 fn created_at(&self) -> Instant;
39
40 fn age(&self) -> Duration {
42 self.created_at().elapsed()
43 }
44
45 fn is_expired(&self, max_lifetime: Duration) -> bool {
47 self.age() > max_lifetime
48 }
49
50 async fn idle_time(&self) -> Duration;
52
53 async fn is_idle_expired(&self, idle_timeout: Duration) -> bool {
55 self.idle_time().await > idle_timeout
56 }
57
58 async fn touch(&self);
60}
61
62#[async_trait]
64pub trait Connection: Send + Sync {
65 async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
67
68 async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64>;
70
71 async fn execute_batch(&self, statements: &[(&str, &[Value])]) -> Result<Vec<u64>> {
73 let mut results = Vec::with_capacity(statements.len());
74 for (sql, params) in statements {
75 results.push(self.execute(sql, params).await?);
76 }
77 Ok(results)
78 }
79
80 async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement>>;
82
83 async fn begin(&self) -> Result<Box<dyn Transaction>>;
85
86 async fn begin_with_isolation(
88 &self,
89 isolation: IsolationLevel,
90 ) -> Result<Box<dyn Transaction>> {
91 let tx = self.begin().await?;
93 tx.set_isolation_level(isolation).await?;
94 Ok(tx)
95 }
96
97 async fn query_stream(&self, sql: &str, params: &[Value]) -> Result<Pin<Box<dyn RowStream>>>;
99
100 async fn query_one(&self, sql: &str, params: &[Value]) -> Result<Option<Row>> {
102 let rows = self.query(sql, params).await?;
103 Ok(rows.into_iter().next())
104 }
105
106 async fn is_valid(&self) -> bool;
108
109 async fn close(&self) -> Result<()>;
111}
112
113#[async_trait]
115pub trait PreparedStatement: Send + Sync {
116 async fn execute(&self, params: &[Value]) -> Result<u64>;
118
119 async fn query(&self, params: &[Value]) -> Result<Vec<Row>>;
121
122 fn sql(&self) -> &str;
124}
125
126#[async_trait]
128pub trait Transaction: Send + Sync {
129 async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
131
132 async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64>;
134
135 async fn execute_batch(&self, statements: &[(&str, &[Value])]) -> Result<Vec<u64>> {
137 let mut results = Vec::with_capacity(statements.len());
138 for (sql, params) in statements {
139 results.push(self.execute(sql, params).await?);
140 }
141 Ok(results)
142 }
143
144 async fn commit(self: Box<Self>) -> Result<()>;
146
147 async fn rollback(self: Box<Self>) -> Result<()>;
149
150 async fn set_isolation_level(&self, level: IsolationLevel) -> Result<()>;
152
153 async fn savepoint(&self, name: &str) -> Result<()>;
155
156 async fn rollback_to_savepoint(&self, name: &str) -> Result<()>;
158
159 async fn release_savepoint(&self, name: &str) -> Result<()>;
161}
162
163pub trait RowStream: Send {
165 fn next(&mut self) -> Pin<Box<dyn Future<Output = Result<Option<Row>>> + Send + '_>>;
167}
168
/// Transaction isolation level, as passed to `Transaction::set_isolation_level` and
/// `Connection::begin_with_isolation`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SERIALIZABLE`
    Serializable,
    /// `SNAPSHOT` (not every backend supports this level)
    Snapshot,
}
183
184impl IsolationLevel {
185 pub fn to_sql(&self) -> &'static str {
187 match self {
188 Self::ReadUncommitted => "READ UNCOMMITTED",
189 Self::ReadCommitted => "READ COMMITTED",
190 Self::RepeatableRead => "REPEATABLE READ",
191 Self::Serializable => "SERIALIZABLE",
192 Self::Snapshot => "SNAPSHOT",
193 }
194 }
195}
196
197impl std::fmt::Display for IsolationLevel {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 write!(f, "{}", self.to_sql())
200 }
201}
202
/// Settings used by a `ConnectionFactory` to open a connection.
///
/// `Debug` is implemented by hand (not derived) so the URL's password can be redacted.
#[derive(Clone)]
pub struct ConnectionConfig {
    /// Connection URL; may embed credentials.
    pub url: String,
    /// Optional application name to report to the server (driver-dependent).
    pub application_name: Option<String>,
    /// Timeout for establishing the connection, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Timeout for individual queries, in milliseconds.
    pub query_timeout_ms: u64,
    /// Size of the prepared-statement cache (interpretation is up to the driver).
    pub statement_cache_size: usize,
    /// Extra driver-specific key/value properties.
    pub properties: std::collections::HashMap<String, String>,
}
219
220impl std::fmt::Debug for ConnectionConfig {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 let redacted_url = match url::Url::parse(&self.url) {
224 Ok(mut parsed) => {
225 if parsed.password().is_some() {
226 let _ = parsed.set_password(Some("***"));
227 }
228 parsed.to_string()
229 }
230 Err(_) => "***".to_string(),
231 };
232
233 f.debug_struct("ConnectionConfig")
234 .field("url", &redacted_url)
235 .field("connect_timeout_ms", &self.connect_timeout_ms)
236 .field("query_timeout_ms", &self.query_timeout_ms)
237 .field("statement_cache_size", &self.statement_cache_size)
238 .field("application_name", &self.application_name)
239 .field("properties", &self.properties)
240 .finish()
241 }
242}
243
244impl Default for ConnectionConfig {
245 fn default() -> Self {
246 Self {
247 url: String::new(),
248 connect_timeout_ms: 10_000,
249 query_timeout_ms: 30_000,
250 statement_cache_size: 100,
251 application_name: Some("rivven-rdbc".into()),
252 properties: std::collections::HashMap::new(),
253 }
254 }
255}
256
257impl ConnectionConfig {
258 pub fn new(url: impl Into<String>) -> Self {
260 Self {
261 url: url.into(),
262 ..Default::default()
263 }
264 }
265
266 pub fn with_connect_timeout(mut self, ms: u64) -> Self {
268 self.connect_timeout_ms = ms;
269 self
270 }
271
272 pub fn with_query_timeout(mut self, ms: u64) -> Self {
274 self.query_timeout_ms = ms;
275 self
276 }
277
278 pub fn with_statement_cache_size(mut self, size: usize) -> Self {
280 self.statement_cache_size = size;
281 self
282 }
283
284 pub fn with_application_name(mut self, name: impl Into<String>) -> Self {
286 self.application_name = Some(name.into());
287 self
288 }
289
290 pub fn with_property(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
292 self.properties.insert(key.into(), value.into());
293 self
294 }
295}
296
297#[async_trait]
299pub trait ConnectionFactory: Send + Sync {
300 async fn connect(&self, config: &ConnectionConfig) -> Result<Box<dyn Connection>>;
302
303 fn database_type(&self) -> DatabaseType;
305}
306
/// Database backend identifier, reported by `ConnectionFactory::database_type`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DatabaseType {
    /// PostgreSQL.
    PostgreSQL,
    /// MySQL.
    MySQL,
    /// Microsoft SQL Server.
    SqlServer,
    /// SQLite.
    SQLite,
    /// Oracle Database.
    Oracle,
    /// Fallback for a backend not covered by the other variants.
    Unknown,
}
323
324impl std::fmt::Display for DatabaseType {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 match self {
327 Self::PostgreSQL => write!(f, "PostgreSQL"),
328 Self::MySQL => write!(f, "MySQL"),
329 Self::SqlServer => write!(f, "SQL Server"),
330 Self::SQLite => write!(f, "SQLite"),
331 Self::Oracle => write!(f, "Oracle"),
332 Self::Unknown => write!(f, "Unknown"),
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`ConnectionLifecycle`] implementation so the trait's default
    /// `age`/`is_expired` methods can be exercised directly.
    struct MockLifecycle {
        created: Instant,
    }

    #[async_trait]
    impl ConnectionLifecycle for MockLifecycle {
        fn created_at(&self) -> Instant {
            self.created
        }

        async fn idle_time(&self) -> Duration {
            Duration::from_secs(0)
        }

        async fn touch(&self) {}
    }

    #[test]
    fn test_isolation_level_to_sql() {
        assert_eq!(IsolationLevel::ReadUncommitted.to_sql(), "READ UNCOMMITTED");
        assert_eq!(IsolationLevel::ReadCommitted.to_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.to_sql(), "REPEATABLE READ");
        assert_eq!(IsolationLevel::Serializable.to_sql(), "SERIALIZABLE");
        assert_eq!(IsolationLevel::Snapshot.to_sql(), "SNAPSHOT");
    }

    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("postgres://localhost/test")
            .with_connect_timeout(5000)
            .with_query_timeout(15000)
            .with_statement_cache_size(42)
            .with_application_name("myapp")
            .with_property("sslmode", "require");

        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.connect_timeout_ms, 5000);
        assert_eq!(config.query_timeout_ms, 15000);
        assert_eq!(config.statement_cache_size, 42);
        assert_eq!(config.application_name, Some("myapp".into()));
        assert_eq!(config.properties.get("sslmode"), Some(&"require".into()));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = ConnectionConfig::default();
        assert!(config.url.is_empty());
        assert_eq!(config.connect_timeout_ms, 10_000);
        assert_eq!(config.query_timeout_ms, 30_000);
        assert_eq!(config.statement_cache_size, 100);
        assert_eq!(config.application_name.as_deref(), Some("rivven-rdbc"));
        assert!(config.properties.is_empty());
    }

    #[test]
    fn test_connection_config_debug_redacts_url_password() {
        let config = ConnectionConfig::new("postgres://user:hunter2@localhost/db");
        let rendered = format!("{:?}", config);
        assert!(!rendered.contains("hunter2"), "password leaked: {rendered}");
        assert!(rendered.contains("localhost"));
    }

    #[test]
    fn test_database_type_display() {
        assert_eq!(format!("{}", DatabaseType::PostgreSQL), "PostgreSQL");
        assert_eq!(format!("{}", DatabaseType::MySQL), "MySQL");
        assert_eq!(format!("{}", DatabaseType::SqlServer), "SQL Server");
        assert_eq!(format!("{}", DatabaseType::SQLite), "SQLite");
        assert_eq!(format!("{}", DatabaseType::Oracle), "Oracle");
        assert_eq!(format!("{}", DatabaseType::Unknown), "Unknown");
    }

    #[test]
    fn test_isolation_level_display() {
        assert_eq!(
            format!("{}", IsolationLevel::ReadUncommitted),
            "READ UNCOMMITTED"
        );
        assert_eq!(
            format!("{}", IsolationLevel::ReadCommitted),
            "READ COMMITTED"
        );
        assert_eq!(
            format!("{}", IsolationLevel::RepeatableRead),
            "REPEATABLE READ"
        );
        assert_eq!(format!("{}", IsolationLevel::Serializable), "SERIALIZABLE");
        assert_eq!(format!("{}", IsolationLevel::Snapshot), "SNAPSHOT");
    }

    #[test]
    fn test_connection_lifecycle_defaults() {
        let conn = MockLifecycle {
            created: Instant::now(),
        };

        // Freshly created: tiny age, well inside a generous lifetime.
        assert!(conn.age() < Duration::from_secs(1));
        assert!(!conn.is_expired(Duration::from_secs(1800)));

        // After a short sleep the default methods must observe the elapsed time.
        std::thread::sleep(Duration::from_millis(5));
        assert!(conn.age() >= Duration::from_millis(5));
        assert!(conn.is_expired(Duration::from_millis(1)));
    }
}