use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::time::{Duration, Instant};
use crate::error::Result;
use crate::types::{Row, Value};
/// Age and idle-time bookkeeping for a connection.
///
/// Implementors supply [`created_at`](Self::created_at),
/// [`idle_time`](Self::idle_time) and [`touch`](Self::touch); the expiry
/// checks are derived from those.
#[async_trait]
pub trait ConnectionLifecycle: Send + Sync {
    /// Returns the instant this connection reports as its creation time.
    fn created_at(&self) -> Instant;

    /// Time elapsed since [`created_at`](Self::created_at).
    fn age(&self) -> Duration {
        let born = self.created_at();
        born.elapsed()
    }

    /// Returns `true` when [`age`](Self::age) is strictly greater than
    /// `max_lifetime`; an age exactly equal to the limit is not expired.
    fn is_expired(&self, max_lifetime: Duration) -> bool {
        max_lifetime < self.age()
    }

    /// Returns how long the connection has been idle.
    async fn idle_time(&self) -> Duration;

    /// Returns `true` when [`idle_time`](Self::idle_time) is strictly greater
    /// than `idle_timeout`.
    async fn is_idle_expired(&self, idle_timeout: Duration) -> bool {
        let idle = self.idle_time().await;
        idle > idle_timeout
    }

    /// Activity hook for the connection.
    // NOTE(review): presumably resets the clock behind `idle_time` — confirm
    // against the implementors.
    async fn touch(&self);
}
/// A database connection.
///
/// Methods declared without a body must be supplied by each driver; the
/// descriptions of those follow their names and signatures and should be
/// confirmed against the concrete implementations. Methods with a body are
/// convenience wrappers built on the required ones.
#[async_trait]
pub trait Connection: Send + Sync {
    /// Runs `sql` with `params` and returns all resulting rows.
    async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;

    /// Runs `sql` with `params`, returning a `u64` on success.
    // NOTE(review): the `u64` is presumably the affected-row count — confirm.
    async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64>;

    /// Runs every `(sql, params)` pair through [`execute`](Self::execute), in
    /// order, collecting one result per statement.
    ///
    /// # Errors
    ///
    /// Stops at the first failing statement and returns its error. Results of
    /// the statements that already ran are discarded; nothing here undoes
    /// their effects.
    async fn execute_batch(&self, statements: &[(&str, &[Value])]) -> Result<Vec<u64>> {
        let mut results = Vec::with_capacity(statements.len());
        for (sql, params) in statements {
            results.push(self.execute(sql, params).await?);
        }
        Ok(results)
    }

    /// Prepares `sql` and returns a reusable statement handle.
    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement>>;

    /// Starts a transaction on this connection.
    async fn begin(&self) -> Result<Box<dyn Transaction>>;

    /// Starts a transaction via [`begin`](Self::begin) and then applies
    /// `isolation` to it with [`Transaction::set_isolation_level`].
    ///
    /// # Errors
    ///
    /// Returns the error from `begin` or from setting the isolation level.
    /// In the latter case the freshly opened transaction is rolled back
    /// (best effort) before the error is returned, so a failed call does not
    /// leave a transaction open on the connection.
    async fn begin_with_isolation(
        &self,
        isolation: IsolationLevel,
    ) -> Result<Box<dyn Transaction>> {
        let tx = self.begin().await?;
        let applied = tx.set_isolation_level(isolation).await;
        match applied {
            Ok(()) => Ok(tx),
            Err(err) => {
                // Dropping the box does not necessarily end the transaction,
                // so close it explicitly. A rollback failure is ignored on
                // purpose: the isolation-level error is the one the caller
                // needs to see.
                let _ = tx.rollback().await;
                Err(err)
            }
        }
    }

    /// Runs `sql` with `params` and returns a pinned, boxed [`RowStream`]
    /// over the results.
    async fn query_stream(&self, sql: &str, params: &[Value]) -> Result<Pin<Box<dyn RowStream>>>;

    /// Runs `sql` with `params` and returns the first row, or `None` if the
    /// query produced no rows.
    ///
    /// This default goes through [`query`](Self::query), so every row is
    /// fetched before all but the first are dropped.
    async fn query_one(&self, sql: &str, params: &[Value]) -> Result<Option<Row>> {
        let rows = self.query(sql, params).await?;
        Ok(rows.into_iter().next())
    }

    /// Reports whether the connection is still usable.
    async fn is_valid(&self) -> bool;

    /// Closes the connection.
    async fn close(&self) -> Result<()>;
}
/// A statement handle as returned by [`Connection::prepare`], runnable with
/// different parameter lists.
///
/// The methods have no bodies here; the descriptions follow their signatures.
#[async_trait]
pub trait PreparedStatement: Send + Sync {
/// Runs the statement with `params`, returning a `u64` on success.
// NOTE(review): the `u64` is presumably the affected-row count — confirm.
async fn execute(&self, params: &[Value]) -> Result<u64>;
/// Runs the statement with `params` and returns the resulting rows.
async fn query(&self, params: &[Value]) -> Result<Vec<Row>>;
/// Returns the statement's SQL text.
fn sql(&self) -> &str;
}
/// Operations available on an open transaction.
///
/// [`commit`](Self::commit) and [`rollback`](Self::rollback) take the boxed
/// transaction by value, so it cannot be used once either has been called.
/// Methods without a body are driver-provided; their descriptions follow
/// their names and signatures.
#[async_trait]
pub trait Transaction: Send + Sync {
    /// Runs `sql` with `params` and returns all resulting rows.
    async fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;

    /// Runs `sql` with `params`, returning a `u64` on success.
    // NOTE(review): the `u64` is presumably the affected-row count — confirm.
    async fn execute(&self, sql: &str, params: &[Value]) -> Result<u64>;

    /// Runs every `(sql, params)` pair through [`execute`](Self::execute), in
    /// order, collecting one result per statement.
    ///
    /// Stops at the first failing statement and returns its error.
    async fn execute_batch(&self, statements: &[(&str, &[Value])]) -> Result<Vec<u64>> {
        let mut affected = Vec::with_capacity(statements.len());
        for &(sql, params) in statements.iter() {
            let count = self.execute(sql, params).await?;
            affected.push(count);
        }
        Ok(affected)
    }

    /// Commits the transaction, consuming it.
    async fn commit(self: Box<Self>) -> Result<()>;

    /// Rolls the transaction back, consuming it.
    async fn rollback(self: Box<Self>) -> Result<()>;

    /// Applies `level` to this transaction.
    async fn set_isolation_level(&self, level: IsolationLevel) -> Result<()>;

    /// Creates a savepoint called `name`.
    async fn savepoint(&self, name: &str) -> Result<()>;

    /// Rolls back to the savepoint called `name`.
    async fn rollback_to_savepoint(&self, name: &str) -> Result<()>;

    /// Releases the savepoint called `name`.
    async fn release_savepoint(&self, name: &str) -> Result<()>;
}
/// A pull-based source of rows, as returned (pinned and boxed) by
/// [`Connection::query_stream`].
pub trait RowStream: Send {
/// Returns a boxed future that borrows the stream mutably and resolves to
/// `Ok(Some(row))`, `Ok(None)` or `Err(_)`.
// NOTE(review): `Ok(None)` presumably marks the end of the stream — confirm.
fn next(&mut self) -> Pin<Box<dyn Future<Output = Result<Option<Row>>> + Send + '_>>;
}
/// Transaction isolation levels that can be requested for a transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`.
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
    /// Rendered as `SNAPSHOT`.
    Snapshot,
}

impl IsolationLevel {
    /// Returns the upper-case SQL keyword(s) naming this isolation level.
    pub fn to_sql(&self) -> &'static str {
        match self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
            Self::Snapshot => "SNAPSHOT",
        }
    }
}

impl std::fmt::Display for IsolationLevel {
    /// Writes the same text as [`IsolationLevel::to_sql`].
    ///
    /// Width, fill, alignment and precision flags are honored, so
    /// `format!("{:<16}", level)` pads as it would for a `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and silently
        // drop the caller's flags; `pad` applies them.
        f.pad(self.to_sql())
    }
}
/// Settings used when opening a database connection.
///
/// `Debug` is implemented by hand for this type so that a password embedded
/// in `url` is masked in diagnostic output.
#[derive(Clone)]
pub struct ConnectionConfig {
/// Connection URL; may embed credentials.
pub url: String,
/// Connect timeout in milliseconds (`Default` uses `10_000`).
pub connect_timeout_ms: u64,
/// Query timeout in milliseconds (`Default` uses `30_000`).
pub query_timeout_ms: u64,
/// Statement cache size (`Default` uses `100`).
// NOTE(review): presumably a count of cached prepared statements — confirm.
pub statement_cache_size: usize,
/// Optional application name (`Default` uses `Some("rivven-rdbc")`).
pub application_name: Option<String>,
/// Free-form string key/value properties (empty by default).
pub properties: std::collections::HashMap<String, String>,
}
impl std::fmt::Debug for ConnectionConfig {
    /// Formats the config for diagnostics with the URL password masked.
    ///
    /// A URL that parses has any password replaced by `***` and is printed in
    /// its re-serialized form; a URL that fails to parse is printed as `***`
    /// in its entirety. All other fields are printed as they are.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this
        // impl to be revisited.
        let Self {
            url: raw_url,
            connect_timeout_ms,
            query_timeout_ms,
            statement_cache_size,
            application_name,
            properties,
        } = self;

        let redacted_url = url::Url::parse(raw_url)
            .map(|mut parsed| {
                if parsed.password().is_some() {
                    // The result is deliberately ignored, as before.
                    let _ = parsed.set_password(Some("***"));
                }
                parsed.to_string()
            })
            .unwrap_or_else(|_| "***".to_string());

        let mut out = f.debug_struct("ConnectionConfig");
        out.field("url", &redacted_url);
        out.field("connect_timeout_ms", connect_timeout_ms);
        out.field("query_timeout_ms", query_timeout_ms);
        out.field("statement_cache_size", statement_cache_size);
        out.field("application_name", application_name);
        out.field("properties", properties);
        out.finish()
    }
}
impl Default for ConnectionConfig {
    /// Builds a config with an empty URL, a 10 000 ms connect timeout, a
    /// 30 000 ms query timeout, a statement cache size of 100, the
    /// application name `rivven-rdbc` and no extra properties.
    fn default() -> Self {
        let application_name = Some(String::from("rivven-rdbc"));
        Self {
            url: String::default(),
            connect_timeout_ms: 10_000,
            query_timeout_ms: 30_000,
            statement_cache_size: 100,
            application_name,
            properties: Default::default(),
        }
    }
}
impl ConnectionConfig {
    /// Creates a config for `url`, with every other field taken from
    /// [`Default`].
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            ..Self::default()
        }
    }

    /// Returns the config with the connect timeout set to `ms` milliseconds.
    pub fn with_connect_timeout(self, ms: u64) -> Self {
        Self {
            connect_timeout_ms: ms,
            ..self
        }
    }

    /// Returns the config with the query timeout set to `ms` milliseconds.
    pub fn with_query_timeout(self, ms: u64) -> Self {
        Self {
            query_timeout_ms: ms,
            ..self
        }
    }

    /// Returns the config with the statement cache size set to `size`.
    pub fn with_statement_cache_size(self, size: usize) -> Self {
        Self {
            statement_cache_size: size,
            ..self
        }
    }

    /// Returns the config with the application name set to `name`.
    pub fn with_application_name(self, name: impl Into<String>) -> Self {
        Self {
            application_name: Some(name.into()),
            ..self
        }
    }

    /// Returns the config with `key` mapped to `value` in `properties`,
    /// replacing any value previously stored under the same key.
    pub fn with_property(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.properties.insert(k, v);
        self
    }
}
/// Opens [`Connection`]s from a [`ConnectionConfig`].
///
/// The methods have no bodies here; the descriptions follow their signatures.
#[async_trait]
pub trait ConnectionFactory: Send + Sync {
/// Opens a connection using `config` and returns it boxed.
async fn connect(&self, config: &ConnectionConfig) -> Result<Box<dyn Connection>>;
/// Returns the [`DatabaseType`] this factory reports.
fn database_type(&self) -> DatabaseType;
}
/// The database product, as reported by [`ConnectionFactory::database_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// Displayed as `PostgreSQL`.
    PostgreSQL,
    /// Displayed as `MySQL`.
    MySQL,
    /// Displayed as `SQL Server`.
    SqlServer,
    /// Displayed as `SQLite`.
    SQLite,
    /// Displayed as `Oracle`.
    Oracle,
    /// Displayed as `Unknown`.
    Unknown,
}

impl std::fmt::Display for DatabaseType {
    /// Writes the human-readable product name.
    ///
    /// Width, fill, alignment and precision flags are honored, so
    /// `format!("{:<12}", db)` pads as it would for a `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::PostgreSQL => "PostgreSQL",
            Self::MySQL => "MySQL",
            Self::SqlServer => "SQL Server",
            Self::SQLite => "SQLite",
            Self::Oracle => "Oracle",
            Self::Unknown => "Unknown",
        };
        // `write!(f, "...")` ignores the caller's format flags; `pad`
        // applies them.
        f.pad(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest possible [`ConnectionLifecycle`] implementor, used to
    /// exercise the trait's provided methods.
    struct StubConnection {
        created: Instant,
    }

    #[async_trait]
    impl ConnectionLifecycle for StubConnection {
        fn created_at(&self) -> Instant {
            self.created
        }

        async fn idle_time(&self) -> Duration {
            self.created.elapsed()
        }

        async fn touch(&self) {}
    }

    #[test]
    fn test_isolation_level_to_sql() {
        assert_eq!(IsolationLevel::ReadCommitted.to_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::Serializable.to_sql(), "SERIALIZABLE");
    }

    #[test]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new("postgres://localhost/test")
            .with_connect_timeout(5000)
            .with_query_timeout(15000)
            .with_statement_cache_size(25)
            .with_application_name("myapp")
            .with_property("sslmode", "require");
        assert_eq!(config.url, "postgres://localhost/test");
        assert_eq!(config.connect_timeout_ms, 5000);
        assert_eq!(config.query_timeout_ms, 15000);
        assert_eq!(config.statement_cache_size, 25);
        assert_eq!(config.application_name, Some("myapp".into()));
        assert_eq!(config.properties.get("sslmode"), Some(&"require".into()));
    }

    #[test]
    fn test_connection_config_debug_redacts_password() {
        // A parsable URL keeps everything but the password.
        let config = ConnectionConfig::new("postgres://app:s3cret@localhost:5432/test");
        let rendered = format!("{:?}", config);
        assert!(!rendered.contains("s3cret"));
        assert!(rendered.contains("***"));
        assert!(rendered.contains("localhost"));

        // An unparsable URL is masked entirely.
        let config = ConnectionConfig::new("host=localhost password=s3cret");
        let rendered = format!("{:?}", config);
        assert!(!rendered.contains("s3cret"));
        assert!(rendered.contains("***"));
    }

    #[test]
    fn test_database_type_display() {
        assert_eq!(format!("{}", DatabaseType::PostgreSQL), "PostgreSQL");
        assert_eq!(format!("{}", DatabaseType::MySQL), "MySQL");
        assert_eq!(format!("{}", DatabaseType::SqlServer), "SQL Server");
        assert_eq!(format!("{}", DatabaseType::SQLite), "SQLite");
        assert_eq!(format!("{}", DatabaseType::Oracle), "Oracle");
        assert_eq!(format!("{}", DatabaseType::Unknown), "Unknown");
    }

    #[test]
    fn test_isolation_level_display() {
        assert_eq!(
            format!("{}", IsolationLevel::ReadUncommitted),
            "READ UNCOMMITTED"
        );
        assert_eq!(
            format!("{}", IsolationLevel::ReadCommitted),
            "READ COMMITTED"
        );
        assert_eq!(
            format!("{}", IsolationLevel::RepeatableRead),
            "REPEATABLE READ"
        );
        assert_eq!(format!("{}", IsolationLevel::Serializable), "SERIALIZABLE");
        assert_eq!(format!("{}", IsolationLevel::Snapshot), "SNAPSHOT");
    }

    #[test]
    fn test_connection_lifecycle_defaults() {
        // Goes through the trait's provided `age` / `is_expired` rather than
        // asserting on `Instant` directly.
        let conn = StubConnection {
            created: Instant::now(),
        };
        assert!(conn.age() < Duration::from_secs(1));
        assert!(!conn.is_expired(Duration::from_secs(1800)));

        // Let a little time pass so the age is safely above the short limit.
        std::thread::sleep(Duration::from_millis(5));
        assert!(conn.is_expired(Duration::from_millis(1)));
    }
}