use super::error::DatabaseError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// SQL backend family. Used in this module to pick dialect-specific SQL
/// (see `IsolationLevel::begin_transaction_sql`) and to answer capability questions
/// (see `DatabaseType::supports_transactional_ddl`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DatabaseType {
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
    /// MySQL.
    Mysql,
}
impl DatabaseType {
    /// Reports whether this backend is treated as supporting transactional DDL, i.e. schema
    /// changes that can run inside a transaction.
    ///
    /// Returns `true` for PostgreSQL and SQLite, and `false` for MySQL.
    pub fn supports_transactional_ddl(&self) -> bool {
        match *self {
            DatabaseType::Mysql => false,
            DatabaseType::Postgres | DatabaseType::Sqlite => true,
        }
    }
}
/// A dynamically typed SQL value.
///
/// Used both for bound statement parameters (the `TransactionExecutor` methods take
/// `Vec<QueryValue>`) and for the column values stored in a `Row`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QueryValue {
    /// SQL `NULL`. None of the `TryFrom<QueryValue>` conversions in this module accept it.
    Null,
    Bool(bool),
    /// All integers are held as `i64`; the `From<i32>` impl widens into this variant and the
    /// `i32`/`u32`/`u64` conversions narrow out of it with a range check.
    Int(i64),
    Float(f64),
    String(String),
    /// Raw binary data.
    Bytes(Vec<u8>),
    /// A point in time in UTC.
    Timestamp(chrono::DateTime<chrono::Utc>),
    Uuid(Uuid),
    // NOTE(review): presumably a marker for "the database's current time" (e.g. `NOW()` /
    // `CURRENT_TIMESTAMP`) that backends render themselves — confirm in the executor code.
    Now,
}
impl From<&str> for QueryValue {
fn from(s: &str) -> Self {
QueryValue::String(s.to_string())
}
}
impl From<String> for QueryValue {
fn from(s: String) -> Self {
QueryValue::String(s)
}
}
impl From<i64> for QueryValue {
fn from(i: i64) -> Self {
QueryValue::Int(i)
}
}
impl From<i32> for QueryValue {
fn from(i: i32) -> Self {
QueryValue::Int(i as i64)
}
}
impl From<f64> for QueryValue {
fn from(f: f64) -> Self {
QueryValue::Float(f)
}
}
impl From<bool> for QueryValue {
fn from(b: bool) -> Self {
QueryValue::Bool(b)
}
}
impl From<chrono::DateTime<chrono::Utc>> for QueryValue {
    /// Wraps a UTC instant as `QueryValue::Timestamp`.
    fn from(instant: chrono::DateTime<chrono::Utc>) -> Self {
        Self::Timestamp(instant)
    }
}
impl From<Uuid> for QueryValue {
fn from(u: Uuid) -> Self {
QueryValue::Uuid(u)
}
}
/// Outcome of a statement run through `TransactionExecutor::execute`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryResult {
    /// Number of rows the statement affected (presumably as reported by the backend driver).
    pub rows_affected: u64,
}
/// One result row: column values keyed by column name.
///
/// Typed access goes through `Row::get`, which reports an absent key as
/// `DatabaseError::ColumnNotFound`.
#[derive(Debug, Clone, PartialEq)]
pub struct Row {
    /// Column name -> value. Public, so callers may also build or inspect rows directly.
    pub data: HashMap<String, QueryValue>,
}
impl Row {
pub fn new() -> Self {
Self {
data: HashMap::new(),
}
}
pub fn insert(&mut self, key: String, value: QueryValue) {
self.data.insert(key, value);
}
pub fn get<T: TryFrom<QueryValue>>(&self, key: &str) -> std::result::Result<T, DatabaseError>
where
DatabaseError: From<<T as TryFrom<QueryValue>>::Error>,
{
self.data
.get(key)
.cloned()
.ok_or_else(|| DatabaseError::ColumnNotFound(key.to_string()))
.and_then(|v| v.try_into().map_err(Into::into))
}
}
impl Default for Row {
fn default() -> Self {
Self::new()
}
}
impl TryFrom<QueryValue> for i64 {
    type Error = DatabaseError;

    /// Extracts the integer held by `QueryValue::Int`.
    ///
    /// Any other variant, including `Null` and `Float`, yields `DatabaseError::TypeError`.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        if let QueryValue::Int(n) = value {
            Ok(n)
        } else {
            Err(DatabaseError::TypeError(format!(
                "Cannot convert {:?} to i64",
                value
            )))
        }
    }
}
impl TryFrom<QueryValue> for i32 {
    type Error = DatabaseError;

    /// Narrows a `QueryValue::Int` to `i32`.
    ///
    /// Non-integer variants and integers outside the `i32` range both yield
    /// `DatabaseError::TypeError`, with distinct messages.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        let wide = match value {
            QueryValue::Int(wide) => wide,
            other => {
                return Err(DatabaseError::TypeError(format!(
                    "Cannot convert {:?} to i32",
                    other
                )));
            }
        };
        i32::try_from(wide)
            .map_err(|_| DatabaseError::TypeError(format!("Value {} out of range for i32", wide)))
    }
}
impl TryFrom<QueryValue> for u64 {
    type Error = DatabaseError;

    /// Converts a non-negative `QueryValue::Int` to `u64`.
    ///
    /// Negative integers and non-integer variants yield `DatabaseError::TypeError`.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        let signed = if let QueryValue::Int(n) = value {
            n
        } else {
            return Err(DatabaseError::TypeError(format!(
                "Cannot convert {:?} to u64",
                value
            )));
        };
        u64::try_from(signed)
            .map_err(|_| DatabaseError::TypeError(format!("Value {} out of range for u64", signed)))
    }
}
impl TryFrom<QueryValue> for u32 {
    type Error = DatabaseError;

    /// Narrows a `QueryValue::Int` to `u32`.
    ///
    /// Integers outside `0..=u32::MAX` and non-integer variants yield
    /// `DatabaseError::TypeError`.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        match value {
            QueryValue::Int(n) => match u32::try_from(n) {
                Ok(narrow) => Ok(narrow),
                Err(_) => Err(DatabaseError::TypeError(format!(
                    "Value {} out of range for u32",
                    n
                ))),
            },
            _ => Err(DatabaseError::TypeError(format!(
                "Cannot convert {:?} to u32",
                value
            ))),
        }
    }
}
impl TryFrom<QueryValue> for String {
    type Error = DatabaseError;

    /// Moves the text out of `QueryValue::String`.
    ///
    /// No other variant is stringified; they all yield `DatabaseError::TypeError`.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        match value {
            QueryValue::String(text) => Ok(text),
            unexpected => Err(DatabaseError::TypeError(format!(
                "Cannot convert {:?} to String",
                unexpected
            ))),
        }
    }
}
impl TryFrom<QueryValue> for bool {
    type Error = DatabaseError;

    /// Extracts the flag from `QueryValue::Bool`. Integers are not interpreted as booleans.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        if let QueryValue::Bool(flag) = value {
            return Ok(flag);
        }
        Err(DatabaseError::TypeError(format!(
            "Cannot convert {:?} to bool",
            value
        )))
    }
}
impl TryFrom<QueryValue> for f64 {
    type Error = DatabaseError;

    /// Extracts the number from `QueryValue::Float`.
    ///
    /// `QueryValue::Int` is not widened to `f64`; it yields `DatabaseError::TypeError` like
    /// every other non-float variant.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        match value {
            QueryValue::Float(x) => Ok(x),
            other => Err(DatabaseError::TypeError(format!(
                "Cannot convert {:?} to f64",
                other
            ))),
        }
    }
}
impl TryFrom<QueryValue> for chrono::DateTime<chrono::Utc> {
    type Error = DatabaseError;

    /// Extracts the instant from `QueryValue::Timestamp`.
    ///
    /// Strings are not parsed and `QueryValue::Now` is not resolved; both yield
    /// `DatabaseError::TypeError`.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        if let QueryValue::Timestamp(instant) = value {
            return Ok(instant);
        }
        Err(DatabaseError::TypeError(format!(
            "Cannot convert {:?} to DateTime<Utc>",
            value
        )))
    }
}
impl TryFrom<QueryValue> for Uuid {
    type Error = DatabaseError;

    /// Accepts `QueryValue::Uuid` directly, or parses a `QueryValue::String` with
    /// `Uuid::parse_str`.
    ///
    /// An unparsable string and every other variant yield `DatabaseError::TypeError`. The
    /// parser's own error detail is not included in the message.
    fn try_from(value: QueryValue) -> std::result::Result<Self, Self::Error> {
        match value {
            QueryValue::Uuid(id) => Ok(id),
            QueryValue::String(text) => match Uuid::parse_str(&text) {
                Ok(id) => Ok(id),
                Err(_) => Err(DatabaseError::TypeError(format!(
                    "Invalid UUID string: {}",
                    text
                ))),
            },
            other => Err(DatabaseError::TypeError(format!(
                "Cannot convert {:?} to Uuid",
                other
            ))),
        }
    }
}
/// Transaction isolation level. `IsolationLevel::default()` is `ReadCommitted`.
///
/// Rendered to SQL by `IsolationLevel::to_sql` and `IsolationLevel::begin_transaction_sql`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`.
    ReadUncommitted,
    /// `READ COMMITTED`; the `Default` variant.
    #[default]
    ReadCommitted,
    /// `REPEATABLE READ`.
    RepeatableRead,
    /// `SERIALIZABLE`. On SQLite the transaction is opened with `BEGIN EXCLUSIVE` instead.
    Serializable,
}
impl IsolationLevel {
    /// Returns the SQL name of this level, e.g. `"READ COMMITTED"`.
    ///
    /// The text is the same for every backend; `db_type` does not influence the result.
    pub fn to_sql(&self, db_type: DatabaseType) -> &'static str {
        let _ = db_type;
        match *self {
            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
            IsolationLevel::ReadCommitted => "READ COMMITTED",
            IsolationLevel::RepeatableRead => "REPEATABLE READ",
            IsolationLevel::Serializable => "SERIALIZABLE",
        }
    }

    /// Builds the SQL that opens a transaction at this isolation level on `db_type`.
    ///
    /// - PostgreSQL: `BEGIN ISOLATION LEVEL <level>`.
    /// - MySQL: `SET TRANSACTION ISOLATION LEVEL <level>; START TRANSACTION`. This is two
    ///   statements joined by `;` in one string. NOTE(review): the executor must run it in a
    ///   mode that permits multiple statements, or split it; confirm in the MySQL backend.
    /// - SQLite: `BEGIN EXCLUSIVE` for `Serializable`, plain `BEGIN` for every other level.
    pub fn begin_transaction_sql(&self, db_type: DatabaseType) -> String {
        let level = self.to_sql(db_type);
        match db_type {
            DatabaseType::Postgres => format!("BEGIN ISOLATION LEVEL {}", level),
            DatabaseType::Mysql => format!(
                "SET TRANSACTION ISOLATION LEVEL {}; START TRANSACTION",
                level
            ),
            DatabaseType::Sqlite if matches!(self, IsolationLevel::Serializable) => {
                "BEGIN EXCLUSIVE".to_string()
            }
            DatabaseType::Sqlite => "BEGIN".to_string(),
        }
    }
}
/// A validated savepoint name, plus helpers that render the savepoint SQL statements.
///
/// Names are checked by [`validate_savepoint_name`]: they must be non-empty, contain only
/// alphanumeric characters (in the `char::is_alphanumeric` sense) and `_`, and must not start
/// with a numeric character. This means a name can never break out of the quoted identifier.
/// The `"`-doubling in the rendered SQL is kept as defence in depth.
///
/// NOTE(review): double-quoted identifiers are standard SQL and are accepted by PostgreSQL and
/// SQLite. MySQL only treats `"` as an identifier quote under the `ANSI_QUOTES` SQL mode, so
/// confirm these helpers are not used on MySQL connections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Savepoint {
    name: String,
}

impl Savepoint {
    /// Creates a savepoint, panicking if `name` is invalid.
    ///
    /// # Panics
    ///
    /// Panics with a message beginning `"Invalid savepoint name"` when `name` fails
    /// [`validate_savepoint_name`]. For names that are not compile-time constants, prefer
    /// [`Savepoint::try_new`].
    pub fn new(name: impl Into<String>) -> Self {
        Self::try_new(name).unwrap_or_else(|e| panic!("Invalid savepoint name: {}", e))
    }

    /// Creates a savepoint, returning the validation message instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns the message from [`validate_savepoint_name`] when `name` is empty, contains a
    /// character other than an alphanumeric or `_`, or starts with a numeric character.
    pub fn try_new(name: impl Into<String>) -> Result<Self, String> {
        let name = name.into();
        validate_savepoint_name(&name)?;
        Ok(Self { name })
    }

    /// The savepoint name as given (unquoted).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// `SAVEPOINT "<name>"`.
    pub fn to_sql(&self) -> String {
        format!("SAVEPOINT {}", self.quoted_name())
    }

    /// `RELEASE SAVEPOINT "<name>"`.
    pub fn release_sql(&self) -> String {
        format!("RELEASE SAVEPOINT {}", self.quoted_name())
    }

    /// `ROLLBACK TO SAVEPOINT "<name>"`.
    pub fn rollback_sql(&self) -> String {
        format!("ROLLBACK TO SAVEPOINT {}", self.quoted_name())
    }

    /// Renders the name as a double-quoted SQL identifier, doubling any embedded `"`.
    fn quoted_name(&self) -> String {
        format!("\"{}\"", self.name.replace('"', "\"\""))
    }
}

/// Checks that `name` is safe to embed in savepoint SQL.
///
/// The checks run in this order, and the first failure is reported:
/// 1. the name is non-empty;
/// 2. every character is alphanumeric (`char::is_alphanumeric`) or `_`;
/// 3. the first character is not numeric (`char::is_numeric`).
///
/// # Errors
///
/// Returns a human-readable message describing the first failed check.
fn validate_savepoint_name(name: &str) -> Result<(), String> {
    if name.is_empty() {
        return Err("Savepoint name cannot be empty".to_string());
    }
    // Rejecting everything outside [alphanumeric_] is what blocks quotes, `;`, spaces and
    // comment markers from reaching the generated SQL.
    if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
        return Err(format!(
            "Savepoint name '{}' contains invalid characters. Only alphanumeric characters and underscores are allowed",
            name
        ));
    }
    if matches!(name.chars().next(), Some(first) if first.is_numeric()) {
        return Err(format!(
            "Savepoint name '{}' cannot start with a number",
            name
        ));
    }
    Ok(())
}
#[async_trait::async_trait]
pub trait TransactionExecutor: Send + Sync {
async fn execute(
&mut self,
sql: &str,
params: Vec<QueryValue>,
) -> super::error::Result<QueryResult>;
async fn fetch_one(&mut self, sql: &str, params: Vec<QueryValue>) -> super::error::Result<Row>;
async fn fetch_all(
&mut self,
sql: &str,
params: Vec<QueryValue>,
) -> super::error::Result<Vec<Row>>;
async fn fetch_optional(
&mut self,
sql: &str,
params: Vec<QueryValue>,
) -> super::error::Result<Option<Row>>;
async fn commit(self: Box<Self>) -> super::error::Result<()>;
async fn rollback(self: Box<Self>) -> super::error::Result<()>;
async fn savepoint(&mut self, name: &str) -> super::error::Result<()> {
let _ = name;
Err(super::error::DatabaseError::NotSupported(
"Savepoints are not supported by this backend".to_string(),
))
}
async fn release_savepoint(&mut self, name: &str) -> super::error::Result<()> {
let _ = name;
Err(super::error::DatabaseError::NotSupported(
"Savepoints are not supported by this backend".to_string(),
))
}
async fn rollback_to_savepoint(&mut self, name: &str) -> super::error::Result<()> {
let _ = name;
Err(super::error::DatabaseError::NotSupported(
"Savepoints are not supported by this backend".to_string(),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // A plain name is stored verbatim and rendered as a quoted identifier in all three forms.
    #[rstest]
    fn test_savepoint_valid_name() {
        let savepoint = Savepoint::new("sp1");
        assert_eq!(savepoint.name(), "sp1");
        let rendered = [
            savepoint.to_sql(),
            savepoint.release_sql(),
            savepoint.rollback_sql(),
        ];
        assert_eq!(
            rendered,
            [
                "SAVEPOINT \"sp1\"",
                "RELEASE SAVEPOINT \"sp1\"",
                "ROLLBACK TO SAVEPOINT \"sp1\"",
            ]
        );
    }

    #[rstest]
    fn test_savepoint_valid_underscore_name() {
        let rendered = Savepoint::new("my_savepoint_1").to_sql();
        assert_eq!(rendered, "SAVEPOINT \"my_savepoint_1\"");
    }

    // The constructor must refuse anything that could terminate the statement or the quoted
    // identifier, and names the validator rejects for structural reasons.
    #[rstest]
    #[should_panic(expected = "Invalid savepoint name")]
    fn test_savepoint_rejects_sql_injection_semicolon() {
        let _ = Savepoint::new("sp1; DROP TABLE users; --");
    }

    #[rstest]
    #[should_panic(expected = "Invalid savepoint name")]
    fn test_savepoint_rejects_sql_injection_quotes() {
        let _ = Savepoint::new("sp1\" ; DROP TABLE users; --");
    }

    #[rstest]
    #[should_panic(expected = "Invalid savepoint name")]
    fn test_savepoint_rejects_empty_name() {
        let _ = Savepoint::new("");
    }

    #[rstest]
    #[should_panic(expected = "Invalid savepoint name")]
    fn test_savepoint_rejects_name_starting_with_number() {
        let _ = Savepoint::new("1invalid");
    }

    #[rstest]
    #[should_panic(expected = "Invalid savepoint name")]
    fn test_savepoint_rejects_spaces() {
        let _ = Savepoint::new("sp 1");
    }

    #[rstest]
    fn test_validate_savepoint_name_valid() {
        for candidate in ["sp1", "my_savepoint", "_internal"] {
            assert!(validate_savepoint_name(candidate).is_ok());
        }
    }

    #[rstest]
    fn test_validate_savepoint_name_rejects_injection() {
        for candidate in ["sp; DROP TABLE", "sp\"injection", "sp' OR '1'='1", ""] {
            assert!(validate_savepoint_name(candidate).is_err());
        }
    }
}