use crate::error::Result;
use crate::row::Row;
use crate::value::Value;
use asupersync::{Cx, Outcome};
/// Transaction isolation level, rendered into SQL by [`IsolationLevel::as_sql`].
///
/// Defaults to [`IsolationLevel::ReadCommitted`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`; this is the `Default`.
    #[default]
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keyword phrase naming this isolation level
    /// (e.g. `"REPEATABLE READ"`), suitable for splicing into a
    /// `SET TRANSACTION ISOLATION LEVEL ...` style statement.
    #[must_use]
    pub const fn as_sql(&self) -> &'static str {
        match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}
/// Handle for a prepared statement: an identifier, the SQL text it was prepared from,
/// the number of bind parameters it expects and, optionally, its result column names.
#[derive(Clone, Debug)]
pub struct PreparedStatement {
    /// Identifier supplied by whoever constructs the handle (presumably the driver).
    id: u64,
    /// SQL text the statement was prepared from.
    sql: String,
    /// Exact number of parameters [`PreparedStatement::validate_params`] accepts.
    param_count: usize,
    /// Result column names; `None` when the constructor was not given any.
    columns: Option<Vec<String>>,
}

impl PreparedStatement {
    /// Builds a handle that carries no column metadata.
    #[must_use]
    pub fn new(id: u64, sql: String, param_count: usize) -> Self {
        Self { columns: None, param_count, sql, id }
    }

    /// Builds a handle that also records the result column names.
    #[must_use]
    pub fn with_columns(id: u64, sql: String, param_count: usize, columns: Vec<String>) -> Self {
        let columns = Some(columns);
        Self { columns, param_count, sql, id }
    }

    /// The statement identifier given at construction.
    #[must_use]
    pub const fn id(&self) -> u64 {
        self.id
    }

    /// The SQL text the statement was prepared from.
    #[must_use]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }

    /// The number of bind parameters the statement expects.
    #[must_use]
    pub const fn param_count(&self) -> usize {
        self.param_count
    }

    /// The result column names, if they were supplied via [`PreparedStatement::with_columns`].
    #[must_use]
    pub fn columns(&self) -> Option<&[String]> {
        self.columns.as_deref()
    }

    /// Returns `true` when `params` has exactly [`PreparedStatement::param_count`] entries.
    ///
    /// Only the count is checked; the parameter values themselves are not inspected.
    #[must_use]
    pub fn validate_params(&self, params: &[Value]) -> bool {
        self.param_count == params.len()
    }
}
/// SQL dialect spoken by a backend.
///
/// Controls placeholder syntax, identifier quoting and a few operator differences.
/// Defaults to [`Dialect::Postgres`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Dialect {
    /// PostgreSQL: `$N` placeholders, `"ident"` quoting, `||` concatenation, `ILIKE`.
    #[default]
    Postgres,
    /// SQLite: `?N` placeholders, `"ident"` quoting, `||` concatenation.
    Sqlite,
    /// MySQL: anonymous `?` placeholders, `` `ident` `` quoting, no infix concat operator.
    Mysql,
}

impl Dialect {
    /// Returns the bind-parameter placeholder for the parameter at `index`.
    ///
    /// Postgres renders `$index` and SQLite renders `?index`. Both dialects number
    /// parameters starting at 1, so callers should pass 1-based indices. MySQL only has
    /// anonymous positional `?` placeholders: `index` is ignored and parameters bind in
    /// textual order.
    #[must_use]
    pub fn placeholder(self, index: usize) -> String {
        match self {
            Dialect::Postgres => format!("${index}"),
            Dialect::Sqlite => format!("?{index}"),
            Dialect::Mysql => "?".to_string(),
        }
    }

    /// Returns the infix string-concatenation operator for this dialect.
    ///
    /// Returns the empty string for MySQL. In MySQL, `||` means logical OR unless the
    /// `PIPES_AS_CONCAT` SQL mode is enabled, so concatenation must be written as
    /// `CONCAT(a, b)` instead. Callers should check for `""` rather than splicing it
    /// between operands.
    #[must_use]
    pub const fn concat_op(self) -> &'static str {
        match self {
            Dialect::Postgres | Dialect::Sqlite => "||",
            // Empty sentinel: there is no portable infix concat operator in MySQL.
            Dialect::Mysql => "",
        }
    }

    /// Whether the dialect has the case-insensitive `ILIKE` operator (Postgres only).
    #[must_use]
    pub const fn supports_ilike(self) -> bool {
        matches!(self, Dialect::Postgres)
    }

    /// Quotes `name` as a single SQL identifier.
    ///
    /// Postgres and SQLite wrap the name in double quotes, and MySQL wraps it in
    /// backticks. Any embedded quote character is doubled, so the result always stays
    /// one identifier token and cannot break out of the quotes.
    #[must_use]
    pub fn quote_identifier(self, name: &str) -> String {
        match self {
            Dialect::Postgres | Dialect::Sqlite => {
                let escaped = name.replace('"', "\"\"");
                format!("\"{escaped}\"")
            }
            Dialect::Mysql => {
                let escaped = name.replace('`', "``");
                format!("`{escaped}`")
            }
        }
    }
}
/// An asynchronous database connection.
///
/// Every operation takes a [`Cx`] and returns a `Send` future. Those futures resolve to
/// an [`Outcome`] carrying a `crate::Error` on failure, except [`Connection::close`],
/// which resolves to the crate's `Result<()>`. The trait requires `Send + Sync`, so a
/// shared `&Self` may be used from multiple threads.
pub trait Connection: Send + Sync {
    /// Transaction handle returned by [`Connection::begin`] and [`Connection::begin_with`].
    ///
    /// It may borrow the connection for `'conn`, so the connection must outlive it.
    type Tx<'conn>: TransactionOps
    where
        Self: 'conn;

    /// SQL dialect spoken by this connection.
    ///
    /// The default implementation returns [`Dialect::Postgres`]; backends for other
    /// databases are expected to override it.
    fn dialect(&self) -> Dialect {
        Dialect::Postgres
    }

    /// Runs `sql` with `params` bound and resolves to every returned row.
    fn query(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Vec<Row>, crate::Error>> + Send;

    /// Runs `sql` with `params` bound and resolves to an optional single row.
    ///
    /// NOTE(review): presumably the first row, or `None` when the result is empty.
    /// Confirm how each driver treats additional rows.
    fn query_one(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Option<Row>, crate::Error>> + Send;

    /// Runs a statement with `params` bound, resolving to a `u64`.
    ///
    /// NOTE(review): the `u64` is presumably the affected-row count. Confirm per driver.
    fn execute(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<u64, crate::Error>> + Send;

    /// Runs an insert statement with `params` bound, resolving to an `i64`.
    ///
    /// NOTE(review): the `i64` is presumably the generated row id. Confirm per driver.
    fn insert(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<i64, crate::Error>> + Send;

    /// Executes a list of `(sql, params)` statements.
    ///
    /// NOTE(review): the result presumably holds one `u64` per statement, in input
    /// order. Whether the batch is atomic is not visible here; confirm per driver.
    fn batch(
        &self,
        cx: &Cx,
        statements: &[(String, Vec<Value>)],
    ) -> impl Future<Output = Outcome<Vec<u64>, crate::Error>> + Send;

    /// Opens a transaction without specifying an isolation level.
    fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, crate::Error>> + Send;

    /// Opens a transaction at the given `isolation` level.
    fn begin_with(
        &self,
        cx: &Cx,
        isolation: IsolationLevel,
    ) -> impl Future<Output = Outcome<Self::Tx<'_>, crate::Error>> + Send;

    /// Prepares `sql`, resolving to a [`PreparedStatement`] handle.
    fn prepare(
        &self,
        cx: &Cx,
        sql: &str,
    ) -> impl Future<Output = Outcome<PreparedStatement, crate::Error>> + Send;

    /// Runs a prepared statement with `params` bound and resolves to its rows.
    ///
    /// [`PreparedStatement::validate_params`] can check the argument count beforehand.
    fn query_prepared(
        &self,
        cx: &Cx,
        stmt: &PreparedStatement,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Vec<Row>, crate::Error>> + Send;

    /// Runs a prepared statement with `params` bound, resolving to a `u64`, in the same
    /// way as [`Connection::execute`].
    fn execute_prepared(
        &self,
        cx: &Cx,
        stmt: &PreparedStatement,
        params: &[Value],
    ) -> impl Future<Output = Outcome<u64, crate::Error>> + Send;

    /// Checks that the connection is responsive.
    fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), crate::Error>> + Send;

    /// Resolves to `true` only when [`Connection::ping`] resolves to `Outcome::Ok(())`.
    ///
    /// `Err`, `Cancelled` and `Panicked` outcomes from `ping` all map to `false`. The
    /// match is deliberately exhaustive, so a new `Outcome` variant forces a decision here.
    fn is_valid(&self, cx: &Cx) -> impl Future<Output = bool> + Send {
        async {
            match self.ping(cx).await {
                Outcome::Ok(()) => true,
                Outcome::Err(_) | Outcome::Cancelled(_) | Outcome::Panicked(_) => false,
            }
        }
    }

    /// Closes the connection, consuming it.
    ///
    /// Unlike the other methods, this resolves to the crate's `Result` rather than an
    /// [`Outcome`].
    fn close(self, cx: &Cx) -> impl Future<Output = Result<()>> + Send;
}
/// Operations available on an open transaction.
///
/// All returned futures are `Send`. The trait itself requires only `Send`, not `Sync`.
/// [`TransactionOps::commit`] and [`TransactionOps::rollback`] take `self` by value,
/// so no further operations can be issued through the handle once either is called.
pub trait TransactionOps: Send {
    /// Runs `sql` with `params` bound inside the transaction and resolves to every row.
    fn query(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Vec<Row>, crate::Error>> + Send;

    /// Runs `sql` with `params` bound inside the transaction and resolves to an optional
    /// single row.
    ///
    /// NOTE(review): presumably the first row, or `None` when empty. Confirm per driver.
    fn query_one(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Option<Row>, crate::Error>> + Send;

    /// Runs a statement with `params` bound inside the transaction, resolving to a `u64`.
    ///
    /// NOTE(review): presumably the affected-row count. Confirm per driver.
    fn execute(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<u64, crate::Error>> + Send;

    /// Creates a savepoint called `name`.
    ///
    /// NOTE(review): `name` is a raw string. If an implementation splices it into SQL
    /// text, it should be quoted (e.g. with [`Dialect::quote_identifier`]).
    fn savepoint(
        &self,
        cx: &Cx,
        name: &str,
    ) -> impl Future<Output = Outcome<(), crate::Error>> + Send;

    /// Rolls the transaction back to the savepoint called `name`.
    fn rollback_to(
        &self,
        cx: &Cx,
        name: &str,
    ) -> impl Future<Output = Outcome<(), crate::Error>> + Send;

    /// Releases the savepoint called `name`.
    fn release(
        &self,
        cx: &Cx,
        name: &str,
    ) -> impl Future<Output = Outcome<(), crate::Error>> + Send;

    /// Commits the transaction, consuming the handle.
    fn commit(self, cx: &Cx) -> impl Future<Output = Outcome<(), crate::Error>> + Send;

    /// Rolls the transaction back, consuming the handle.
    fn rollback(self, cx: &Cx) -> impl Future<Output = Outcome<(), crate::Error>> + Send;
}
/// Transaction handle that forwards every operation to a [`TransactionInternal`] object.
///
/// It implements [`TransactionOps`] purely by delegation. Finish it with
/// [`TransactionOps::commit`] or [`TransactionOps::rollback`]. Its `Drop` impl takes no
/// action, so dropping an unfinished handle does not issue a rollback by itself.
pub struct Transaction<'conn> {
    // Backend object that performs the actual work; borrowed for `'conn`.
    conn: &'conn dyn TransactionInternal,
    // Set to `true` when `commit` or `rollback` starts running on this handle.
    finalized: bool,
}
/// Object-safe backend interface behind [`Transaction`].
///
/// It mirrors [`TransactionOps`] but returns pinned, boxed futures, so it can be used as
/// `dyn TransactionInternal`, which is how [`Transaction`] stores it. By lifetime
/// elision, each returned future borrows only `&self` (the `'_`). It cannot keep hold of
/// `cx`, `sql`, `name` or `params`, so implementations must copy whatever they need from
/// those arguments before returning the future.
///
/// Unlike `TransactionOps`, commit and rollback take `&self`. Consuming the handle is
/// left to the [`Transaction`] wrapper.
pub trait TransactionInternal: Send + Sync {
    /// Backend for [`TransactionOps::query`].
    fn query_internal(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<Vec<Row>, crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::query_one`].
    fn query_one_internal(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<Option<Row>, crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::execute`].
    fn execute_internal(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<u64, crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::savepoint`].
    fn savepoint_internal(
        &self,
        cx: &Cx,
        name: &str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<(), crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::rollback_to`].
    fn rollback_to_internal(
        &self,
        cx: &Cx,
        name: &str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<(), crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::release`].
    fn release_internal(
        &self,
        cx: &Cx,
        name: &str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<(), crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::commit`].
    fn commit_internal(
        &self,
        cx: &Cx,
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<(), crate::Error>> + Send + '_>>;

    /// Backend for [`TransactionOps::rollback`].
    fn rollback_internal(
        &self,
        cx: &Cx,
    ) -> std::pin::Pin<Box<dyn Future<Output = Outcome<(), crate::Error>> + Send + '_>>;
}
impl<'conn> Transaction<'conn> {
    /// Wraps `conn` in a new, not-yet-finalized transaction handle.
    ///
    /// This only builds the wrapper; it does not talk to the backend. Any `BEGIN` has to
    /// have been issued by whoever produced `conn`.
    #[must_use]
    pub fn new(conn: &'conn dyn TransactionInternal) -> Self {
        Self {
            conn,
            finalized: false,
        }
    }

    /// Returns whether `commit` or `rollback` has started on this handle.
    ///
    /// Both of those methods consume the handle, so code that still holds a `Transaction`
    /// will always observe `false` here. The flag is mainly meaningful during `Drop`.
    #[must_use]
    pub const fn is_finalized(&self) -> bool {
        self.finalized
    }
}
/// [`TransactionOps`] for [`Transaction`]: every call is forwarded unchanged to the
/// wrapped [`TransactionInternal`] object.
impl TransactionOps for Transaction<'_> {
    /// Forwards to [`TransactionInternal::query_internal`].
    fn query(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Vec<Row>, crate::Error>> + Send {
        let inner = self.conn;
        inner.query_internal(cx, sql, params)
    }

    /// Forwards to [`TransactionInternal::query_one_internal`].
    fn query_one(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Option<Row>, crate::Error>> + Send {
        let inner = self.conn;
        inner.query_one_internal(cx, sql, params)
    }

    /// Forwards to [`TransactionInternal::execute_internal`].
    fn execute(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<u64, crate::Error>> + Send {
        let inner = self.conn;
        inner.execute_internal(cx, sql, params)
    }

    /// Forwards to [`TransactionInternal::savepoint_internal`].
    fn savepoint(
        &self,
        cx: &Cx,
        name: &str,
    ) -> impl Future<Output = Outcome<(), crate::Error>> + Send {
        let inner = self.conn;
        inner.savepoint_internal(cx, name)
    }

    /// Forwards to [`TransactionInternal::rollback_to_internal`].
    fn rollback_to(
        &self,
        cx: &Cx,
        name: &str,
    ) -> impl Future<Output = Outcome<(), crate::Error>> + Send {
        let inner = self.conn;
        inner.rollback_to_internal(cx, name)
    }

    /// Forwards to [`TransactionInternal::release_internal`].
    fn release(
        &self,
        cx: &Cx,
        name: &str,
    ) -> impl Future<Output = Outcome<(), crate::Error>> + Send {
        let inner = self.conn;
        inner.release_internal(cx, name)
    }

    /// Marks the handle finalized, then awaits [`TransactionInternal::commit_internal`].
    ///
    /// Being an `async fn`, nothing happens until the returned future is first polled.
    async fn commit(mut self, cx: &Cx) -> Outcome<(), crate::Error> {
        let inner = self.conn;
        self.finalized = true;
        inner.commit_internal(cx).await
    }

    /// Marks the handle finalized, then awaits [`TransactionInternal::rollback_internal`].
    ///
    /// Being an `async fn`, nothing happens until the returned future is first polled.
    async fn rollback(mut self, cx: &Cx) -> Outcome<(), crate::Error> {
        let inner = self.conn;
        self.finalized = true;
        inner.rollback_internal(cx).await
    }
}
use std::future::Future;
impl Drop for Transaction<'_> {
    /// Deliberately performs no cleanup, even when the transaction was never finalized.
    ///
    /// Finishing a transaction goes through the async `commit_internal` /
    /// `rollback_internal` methods of [`TransactionInternal`]. A synchronous `Drop` cannot
    /// await those, so an unfinished handle (see [`Transaction::is_finalized`]) is simply
    /// released here.
    // NOTE(review): whether the backend rolls back an abandoned transaction is not
    // visible in this file. Confirm against the `TransactionInternal` implementations.
    fn drop(&mut self) {}
}
/// Settings used to open a database connection.
///
/// Build one with [`ConnectionConfig::new`] and the chained setters. Unset fields keep
/// their [`Default`] values: 30 000 ms connect and query timeouts, [`SslMode::Prefer`],
/// and no application name.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Connection URL, stored verbatim.
    ///
    /// NOTE: the derived `Debug` prints this field as-is, so any credentials embedded in
    /// the URL will show up in `{:?}` output and logs.
    pub url: String,
    /// Connect timeout in milliseconds.
    pub connect_timeout_ms: u64,
    /// Query timeout in milliseconds.
    pub query_timeout_ms: u64,
    /// TLS policy; see [`SslMode`].
    pub ssl_mode: SslMode,
    /// Optional application name to report to the server.
    pub application_name: Option<String>,
}

/// TLS policy for new connections.
///
/// Defaults to [`SslMode::Prefer`]. The variant names match libpq's `sslmode`
/// settings. NOTE(review): how each mode is enforced is up to the driver; confirm per
/// backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// libpq `disable`.
    Disable,
    /// libpq `prefer`; the `Default`.
    #[default]
    Prefer,
    /// libpq `require`.
    Require,
    /// libpq `verify-ca`.
    VerifyCa,
    /// libpq `verify-full`.
    VerifyFull,
}

impl Default for ConnectionConfig {
    /// An empty URL, 30 000 ms timeouts, [`SslMode::Prefer`] and no application name.
    fn default() -> Self {
        Self {
            url: String::new(),
            connect_timeout_ms: 30_000,
            query_timeout_ms: 30_000,
            ssl_mode: SslMode::default(),
            application_name: None,
        }
    }
}

impl ConnectionConfig {
    /// Creates a config for `url`, with every other field at its default.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Default::default()
        }
    }

    /// Sets the connect timeout in milliseconds.
    ///
    /// Consumes and returns the config, so the result must be used.
    #[must_use]
    pub fn connect_timeout(mut self, ms: u64) -> Self {
        self.connect_timeout_ms = ms;
        self
    }

    /// Sets the query timeout in milliseconds.
    ///
    /// Consumes and returns the config, so the result must be used.
    #[must_use]
    pub fn query_timeout(mut self, ms: u64) -> Self {
        self.query_timeout_ms = ms;
        self
    }

    /// Sets the TLS policy.
    ///
    /// Consumes and returns the config, so the result must be used.
    #[must_use]
    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
        self.ssl_mode = mode;
        self
    }

    /// Sets the application name.
    ///
    /// Consumes and returns the config, so the result must be used.
    #[must_use]
    pub fn application_name(mut self, name: impl Into<String>) -> Self {
        self.application_name = Some(name.into());
        self
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the plain-data types in this module.
    use super::*;

    #[test]
    fn test_isolation_level_default() {
        assert_eq!(IsolationLevel::default(), IsolationLevel::ReadCommitted);
    }

    #[test]
    fn test_isolation_level_as_sql() {
        // Table-driven: one row per variant.
        let expectations = [
            (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];
        for (level, sql) in expectations {
            assert_eq!(level.as_sql(), sql, "wrong SQL for {level:?}");
        }
    }

    #[test]
    fn test_prepared_statement_new() {
        let sql = "SELECT * FROM users WHERE id = $1";
        let stmt = PreparedStatement::new(1, sql.to_string(), 1);
        assert_eq!(stmt.id(), 1);
        assert_eq!(stmt.sql(), sql);
        assert_eq!(stmt.param_count(), 1);
        assert_eq!(stmt.columns(), None);
    }

    #[test]
    fn test_prepared_statement_with_columns() {
        let names = vec!["id".to_string(), "name".to_string()];
        let stmt = PreparedStatement::with_columns(
            2,
            "SELECT id, name FROM users".to_string(),
            0,
            names.clone(),
        );
        assert_eq!(stmt.id(), 2);
        assert_eq!(stmt.param_count(), 0);
        assert_eq!(stmt.columns(), Some(names.as_slice()));
    }

    #[test]
    fn test_prepared_statement_validate_params() {
        let stmt = PreparedStatement::new(1, "SELECT $1, $2".to_string(), 2);
        // Check every prefix length 0..=3; only the 2-element prefix should validate.
        let args = [Value::Int(1), Value::Int(2), Value::Int(3)];
        for len in 0..=args.len() {
            assert_eq!(stmt.validate_params(&args[..len]), len == 2, "len = {len}");
        }
    }

    #[test]
    fn test_ssl_mode_default() {
        assert!(matches!(SslMode::default(), SslMode::Prefer));
    }

    #[test]
    fn test_connection_config_builder() {
        let url = "postgres://localhost/test";
        let config = ConnectionConfig::new(url)
            .connect_timeout(5000)
            .query_timeout(10000)
            .ssl_mode(SslMode::Require)
            .application_name("test_app");
        assert_eq!(config.url, url);
        assert_eq!(config.connect_timeout_ms, 5000);
        assert_eq!(config.query_timeout_ms, 10000);
        assert!(matches!(config.ssl_mode, SslMode::Require));
        assert_eq!(config.application_name.as_deref(), Some("test_app"));
    }

    #[test]
    fn test_connection_config_default() {
        // Exhaustive destructuring: adding a field forces this test to be revisited.
        let ConnectionConfig {
            url,
            connect_timeout_ms,
            query_timeout_ms,
            ssl_mode,
            application_name,
        } = ConnectionConfig::default();
        assert!(url.is_empty());
        assert_eq!(connect_timeout_ms, 30_000);
        assert_eq!(query_timeout_ms, 30_000);
        assert!(matches!(ssl_mode, SslMode::Prefer));
        assert_eq!(application_name, None);
    }
}