pub mod default;
pub mod mysql;
pub mod postgres;
pub mod sqlite;
pub use default::{DefaultDriver, DefaultDriverFactory};
pub use mysql::{MysqlDriver, MysqlDriverFactory};
pub use postgres::{PostgresDriver, PostgresDriverFactory};
pub use sqlite::{SqliteDriver, SqliteDriverFactory};
use crate::error::Error;
/// Settings needed to open a database connection.
///
/// `Debug` is implemented by hand (below) so that the password never ends up in logs or panic
/// messages; every other field is printed normally.
#[derive(Clone)]
pub struct ConnectionConfig {
    /// Backend name: `sqlite`, `postgres` and `mysql` set `"sqlite"`, `"postgres"` and
    /// `"mysql"`; [`ConnectionConfig::new`] and `Default` leave it empty.
    pub db_type: String,
    /// Host to connect to; [`ConnectionConfig::sqlite`] stores the database file path here.
    pub host: String,
    /// Port to connect to; `0` for SQLite and in `Default`.
    pub port: u16,
    pub username: String,
    pub password: String,
    /// Database name; [`ConnectionConfig::sqlite`] stores the database file path here too.
    pub database: String,
    /// TLS policy; every constructor starts from `SslMode::Disable`.
    pub ssl_mode: SslMode,
    /// Connection cap, 10 unless overridden. NOTE(review): presumably a pool size — confirm in
    /// the drivers.
    pub max_connections: usize,
}

impl std::fmt::Debug for ConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a password is set, never its contents.
        let password = if self.password.is_empty() { "" } else { "<redacted>" };
        f.debug_struct("ConnectionConfig")
            .field("db_type", &self.db_type)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            .field("password", &password)
            .field("database", &self.database)
            .field("ssl_mode", &self.ssl_mode)
            .field("max_connections", &self.max_connections)
            .finish()
    }
}

impl Default for ConnectionConfig {
    /// Empty strings, port `0`, TLS disabled and the default connection cap.
    fn default() -> Self {
        ConnectionConfig {
            db_type: String::new(),
            host: String::new(),
            port: 0,
            username: String::new(),
            password: String::new(),
            database: String::new(),
            ssl_mode: SslMode::Disable,
            max_connections: Self::DEFAULT_MAX_CONNECTIONS,
        }
    }
}

impl ConnectionConfig {
    /// `max_connections` used by `Default` and by every constructor.
    const DEFAULT_MAX_CONNECTIONS: usize = 10;

    /// Builds a config from explicit connection parameters without choosing a backend
    /// (`db_type` stays empty).
    ///
    /// Note the argument order — `username, password, database` — differs from
    /// [`ConnectionConfig::postgres`] and [`ConnectionConfig::mysql`], which take `database`
    /// before the credentials.
    pub fn new(host: &str, port: u16, username: &str, password: &str, database: &str) -> Self {
        ConnectionConfig {
            host: host.to_string(),
            port,
            username: username.to_string(),
            password: password.to_string(),
            database: database.to_string(),
            ..Self::default()
        }
    }

    /// Config for an SQLite database file; `db_path` is stored in both `host` and `database`,
    /// with port `0` and empty credentials.
    pub fn sqlite(db_path: &str) -> Self {
        ConnectionConfig {
            db_type: "sqlite".to_string(),
            host: db_path.to_string(),
            database: db_path.to_string(),
            ..Self::default()
        }
    }

    /// Config for a PostgreSQL server (`db_type = "postgres"`).
    pub fn postgres(host: &str, port: u16, database: &str, username: &str, password: &str) -> Self {
        Self::for_backend("postgres", host, port, database, username, password)
    }

    /// Config for a MySQL server (`db_type = "mysql"`).
    pub fn mysql(host: &str, port: u16, database: &str, username: &str, password: &str) -> Self {
        Self::for_backend("mysql", host, port, database, username, password)
    }

    /// Shared body of `postgres`/`mysql`: the fields of [`ConnectionConfig::new`] plus a
    /// backend name.
    fn for_backend(
        db_type: &str,
        host: &str,
        port: u16,
        database: &str,
        username: &str,
        password: &str,
    ) -> Self {
        ConnectionConfig {
            db_type: db_type.to_string(),
            ..Self::new(host, port, username, password, database)
        }
    }

    /// Returns the config with `ssl_mode` replaced by `mode`.
    #[must_use]
    pub fn with_ssl(mut self, mode: SslMode) -> Self {
        self.ssl_mode = mode;
        self
    }

    /// Returns the config with `max_connections` replaced by `max`.
    #[must_use]
    pub fn with_max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }
}
/// TLS policy for a connection; `ConnectionConfig` starts every config at `Disable`.
///
/// The variant names appear to mirror PostgreSQL libpq's `sslmode` values; how each driver maps
/// them onto its own TLS settings is up to the driver — NOTE(review): confirm per backend.
///
/// A plain field-less enum, so it is `Copy`/`Eq`/`Hash` like `DatabaseType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SslMode {
    Disable,
    Prefer,
    Require,
    VerifyCa,
    VerifyFull,
}
/// Identifies a database backend; `DriverRegistry` uses it as the key for registered factories,
/// and `DatabaseDriver::db_type` / `DriverFactory::db_type` report it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
/// PostgreSQL backend.
Postgresql,
/// MySQL backend.
Mysql,
/// SQLite backend.
Sqlite,
/// No specific backend. NOTE(review): presumably what the `default` driver reports — confirm.
/// Because this variant is named `None`, a glob import such as `use DatabaseType::*` shadows
/// the prelude's `Option::None` in that scope.
None,
}
/// Outcome of running a statement.
#[derive(Debug, Clone, Default)]
pub struct QueryResult {
    /// Rows returned by the statement.
    pub rows: Vec<Row>,
    /// Number of rows affected. NOTE(review): meaning for row-returning statements is
    /// driver-defined — confirm.
    pub affected_rows: u64,
    /// Id generated by an insert, if the backend reported one.
    pub last_insert_id: Option<i64>,
}

/// One result row: column metadata paired positionally with optional textual values.
///
/// `columns[i]` describes `values[i]`; [`Row::push`] is the only mutator and always appends to
/// both vectors, so they stay the same length.
#[derive(Debug, Clone, Default)]
pub struct Row {
    columns: Vec<Column>,
    values: Vec<Option<String>>,
}

impl Row {
    /// Creates an empty row (equivalent to `Row::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty row with room for `capacity` columns before reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Row {
            columns: Vec::with_capacity(capacity),
            values: Vec::with_capacity(capacity),
        }
    }

    /// Appends a column and its value. `None` presumably represents SQL NULL.
    pub fn push(&mut self, column: Column, value: Option<String>) {
        self.columns.push(column);
        self.values.push(value);
    }

    /// Value at position `idx`; `None` if `idx` is out of range or the stored value is `None`.
    pub fn get(&self, idx: usize) -> Option<&str> {
        self.values.get(idx).and_then(|v| v.as_deref())
    }

    /// Value of the first column whose name equals `name` exactly; `None` if no column matches
    /// or its value is `None`.
    pub fn get_by_name(&self, name: &str) -> Option<&str> {
        self.columns
            .iter()
            .position(|c| c.name == name)
            .and_then(|idx| self.get(idx))
    }

    /// Number of columns in the row.
    pub fn column_count(&self) -> usize {
        self.columns.len()
    }

    /// Iterates over `(column, value)` pairs in column order.
    pub fn iter(&self) -> RowIter<'_> {
        RowIter { row: self, idx: 0 }
    }
}
/// Borrowing iterator over the `(column, value)` pairs of a `Row`, produced by `Row::iter`.
#[derive(Clone)]
pub struct RowIter<'a> {
    row: &'a Row,
    idx: usize,
}

impl<'a> RowIter<'a> {
    /// Pairs still to be yielded. Bounded by the shorter backing vector so the count stays exact
    /// even if `columns` and `values` were ever to differ in length.
    fn remaining(&self) -> usize {
        self.row
            .columns
            .len()
            .min(self.row.values.len())
            .saturating_sub(self.idx)
    }
}

impl<'a> Iterator for RowIter<'a> {
    type Item = (&'a Column, Option<&'a str>);

    fn next(&mut self) -> Option<Self::Item> {
        // Checked access: stop instead of panicking if the two vectors are out of step.
        let col = self.row.columns.get(self.idx)?;
        let val = self.row.values.get(self.idx)?.as_deref();
        self.idx += 1;
        Some((col, val))
    }

    /// Exact bounds, so `collect` and friends can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining();
        (n, Some(n))
    }
}

impl<'a> ExactSizeIterator for RowIter<'a> {}

// Once `next` returns `None` the index no longer advances and the row is borrowed immutably,
// so every later call also returns `None`.
impl<'a> std::iter::FusedIterator for RowIter<'a> {}
/// Metadata describing one column of a result row.
#[derive(Debug, Clone, PartialEq)]
pub struct Column {
    /// Column name; `Row::get_by_name` matches against it exactly.
    pub name: String,
    /// Declared data type of the column.
    pub data_type: DataType,
    /// Whether the column may hold nulls; `Column::new` starts at `true`.
    pub nullable: bool,
}

impl Column {
    /// Creates a nullable column named `name` with type `data_type`.
    pub fn new(name: &str, data_type: DataType) -> Self {
        Column {
            name: name.to_string(),
            data_type,
            nullable: true,
        }
    }

    /// Returns the column marked as non-nullable.
    #[must_use]
    pub fn not_null(mut self) -> Self {
        self.nullable = false;
        self
    }
}
/// Column data type.
///
/// Several variant names follow PostgreSQL type names (`Int2`/`Int4`/`Int8`, `Float4`/`Float8`,
/// `Jsonb`, `Bytea`); NOTE(review): presumably each driver maps its native types onto these —
/// confirm. Every payload is `Eq + Hash`, so the enum is too (usable as a map/set key).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    Boolean,
    Int2,
    Int4,
    Int8,
    Float4,
    Float8,
    Text,
    /// Variable-length string; the `usize` is presumably the declared maximum length.
    Varchar(usize),
    /// Fixed-length string; the `usize` is presumably the declared length.
    Char(usize),
    Date,
    Time,
    Timestamp,
    Json,
    Jsonb,
    Uuid,
    Bytea,
    /// Array whose elements have the boxed type.
    Array(Box<DataType>),
    /// Type identified only by name, for anything the other variants do not cover.
    Custom(String),
}
/// A value passed to a statement, either bound by the driver or rendered inline via
/// `Parameter::as_sql_string`.
///
/// `PartialEq` (not `Eq`) because `Float` carries an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum Parameter {
    /// Rendered as `NULL`.
    Null,
    Int(i64),
    Float(f64),
    String(String),
    Bool(bool),
    /// Raw binary data.
    Bytes(Vec<u8>),
}
impl Parameter {
pub fn as_sql_string(&self, _db_type: DatabaseType) -> String {
match self {
Parameter::Null => "NULL".to_string(),
Parameter::Int(v) => v.to_string(),
Parameter::Float(v) => v.to_string(),
Parameter::Bool(v) => v.to_string(),
Parameter::String(v) => {
let escaped = v.replace('\'', "''");
format!("'{}'", escaped)
}
Parameter::Bytes(v) => {
format!("'\\x{}'", hex::encode(v))
}
}
}
}
/// Widens losslessly to [`Parameter::Int`].
impl From<i32> for Parameter {
    fn from(v: i32) -> Self {
        Parameter::Int(i64::from(v))
    }
}

impl From<i64> for Parameter {
    fn from(v: i64) -> Self {
        Parameter::Int(v)
    }
}

impl From<&str> for Parameter {
    fn from(v: &str) -> Self {
        Parameter::String(v.to_string())
    }
}

impl From<String> for Parameter {
    fn from(v: String) -> Self {
        Parameter::String(v)
    }
}

impl From<bool> for Parameter {
    fn from(v: bool) -> Self {
        Parameter::Bool(v)
    }
}

impl From<f64> for Parameter {
    fn from(v: f64) -> Self {
        Parameter::Float(v)
    }
}

/// Widens losslessly to [`Parameter::Float`].
impl From<f32> for Parameter {
    fn from(v: f32) -> Self {
        Parameter::Float(f64::from(v))
    }
}

/// Takes ownership of the buffer as [`Parameter::Bytes`].
impl From<Vec<u8>> for Parameter {
    fn from(v: Vec<u8>) -> Self {
        Parameter::Bytes(v)
    }
}

/// Copies the slice into [`Parameter::Bytes`].
impl<'a> From<&'a [u8]> for Parameter {
    fn from(v: &'a [u8]) -> Self {
        Parameter::Bytes(v.to_vec())
    }
}

/// `None` becomes [`Parameter::Null`]; `Some(x)` converts `x` as usual.
impl<T> From<Option<T>> for Parameter
where
    T: Into<Parameter>,
{
    fn from(v: Option<T>) -> Self {
        v.map_or(Parameter::Null, Into::into)
    }
}
/// A backend-specific database driver, built by a `DriverFactory`.
///
/// The `Send + Sync` supertraits allow boxed drivers to be moved to and shared across threads.
/// NOTE(review): the method descriptions below are inferred from names and signatures only;
/// the actual contract lives in each implementation — confirm there.
pub trait DatabaseDriver: Send + Sync {
/// Backend this driver implements.
fn db_type(&self) -> DatabaseType;
/// Opens a connection described by `config`.
fn connect(&mut self, config: &ConnectionConfig) -> Result<(), Error>;
/// Closes the connection.
fn close(&mut self) -> Result<(), Error>;
/// Runs `sql` with `params` and returns the resulting rows.
/// Presumably `params` fill placeholders written in `placeholder_style()` syntax — confirm.
fn query(&mut self, sql: &str, params: &[Parameter]) -> Result<QueryResult, Error>;
/// Runs `sql` with `params`; the `u64` is presumably the affected-row count — confirm.
fn execute(&mut self, sql: &str, params: &[Parameter]) -> Result<u64, Error>;
/// Prepares `sql` and registers it under `name` for `execute_prepared`.
fn prepare(&mut self, name: &str, sql: &str) -> Result<(), Error>;
/// Runs the statement previously prepared as `name` with `params`.
fn execute_prepared(&mut self, name: &str, params: &[Parameter]) -> Result<QueryResult, Error>;
/// Presumably starts a transaction.
fn begin(&mut self) -> Result<(), Error>;
/// Presumably commits the current transaction.
fn commit(&mut self) -> Result<(), Error>;
/// Presumably rolls back the current transaction.
fn rollback(&mut self) -> Result<(), Error>;
/// Returns `ident` quoted as an identifier in this backend's syntax.
fn escape_identifier(&self, ident: &str) -> String;
/// Id generated by the most recent insert, or `None` if the backend reports none.
fn last_insert_id(&mut self) -> Result<Option<i64>, Error>;
/// Whether a connection is currently open.
fn is_connected(&self) -> bool;
/// Version string reported by the backend.
fn version(&mut self) -> Result<String, Error>;
/// SQL fragment limiting/offsetting a result set; presumably a `None` bound is omitted.
fn limit_offset_clause(&self, limit: Option<usize>, offset: Option<usize>) -> String;
/// Placeholder syntax this driver expects inside `sql` strings.
fn placeholder_style(&self) -> PlaceholderStyle;
}
/// Bind-parameter syntax a driver expects, as reported by `DatabaseDriver::placeholder_style`.
///
/// Derives `Eq` and `Hash` like `DatabaseType`, so it can key maps and sets.
/// NOTE(review): the concrete syntax of each variant is defined by the drivers; the hints
/// below are presumptions — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlaceholderStyle {
    /// Presumably `$1`, `$2`, … (PostgreSQL style).
    DollarNumbered,
    /// Presumably bare `?` markers.
    Positional,
    /// Presumably SQLite's numbered `?NNN` markers.
    PositionalSqlite,
    /// Presumably named markers such as `:name`.
    Named,
}
/// Builds drivers for a single `DatabaseType`; factories are stored in a `DriverRegistry`.
pub trait DriverFactory: Send + Sync {
/// Returns a boxed driver (presumably a new, not-yet-connected instance on each call — confirm).
fn create(&self) -> Box<dyn DatabaseDriver>;
/// Backend this factory builds drivers for; `DriverRegistry::register` uses it as the key.
fn db_type(&self) -> DatabaseType;
}
/// Lookup table from [`DatabaseType`] to the factory that builds drivers for it.
pub struct DriverRegistry {
    drivers: std::collections::HashMap<DatabaseType, Box<dyn DriverFactory>>,
}

impl Default for DriverRegistry {
    /// An empty registry, same as [`DriverRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl DriverRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        DriverRegistry {
            drivers: std::collections::HashMap::new(),
        }
    }

    /// Stores `factory` under the type reported by its `db_type()`.
    ///
    /// A factory previously registered for the same type is replaced and dropped.
    pub fn register<F>(&mut self, factory: F)
    where
        F: DriverFactory + 'static,
    {
        let db_type = factory.db_type();
        self.drivers.insert(db_type, Box::new(factory));
    }

    /// Factory registered for `db_type`, if any.
    pub fn get(&self, db_type: DatabaseType) -> Option<&dyn DriverFactory> {
        self.drivers.get(&db_type).map(|f| f.as_ref())
    }

    /// Builds a driver with the factory registered for `db_type`; `None` if none is registered.
    pub fn create_driver(&self, db_type: DatabaseType) -> Option<Box<dyn DatabaseDriver>> {
        self.get(db_type).map(|f| f.create())
    }
}