elif_orm/backends/
core.rs

//! Core database backend traits.
//!
//! This module defines the core traits and types used to abstract over database
//! backends. They hide database-specific implementations behind a unified
//! interface so the ORM can work with different database systems.
6
7use crate::error::{OrmError, OrmResult};
8use async_trait::async_trait;
9use serde_json::Value as JsonValue;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13/// Abstract database connection trait
14#[async_trait]
15pub trait DatabaseConnection: Send + Sync {
16    /// Execute a query and return affected rows count
17    async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
18
19    /// Execute a query and return the result rows
20    async fn fetch_all(
21        &mut self,
22        sql: &str,
23        params: &[DatabaseValue],
24    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
25
26    /// Execute a query and return the first result row
27    async fn fetch_optional(
28        &mut self,
29        sql: &str,
30        params: &[DatabaseValue],
31    ) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
32
33    /// Begin a transaction
34    async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>>;
35
36    /// Close the connection
37    async fn close(&mut self) -> OrmResult<()>;
38}
39
40/// Abstract database transaction trait
41#[async_trait]
42pub trait DatabaseTransaction: Send + Sync {
43    /// Execute a query within the transaction
44    async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
45
46    /// Execute a query and return result rows within the transaction
47    async fn fetch_all(
48        &mut self,
49        sql: &str,
50        params: &[DatabaseValue],
51    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
52
53    /// Execute a query and return the first result row within the transaction
54    async fn fetch_optional(
55        &mut self,
56        sql: &str,
57        params: &[DatabaseValue],
58    ) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
59
60    /// Commit the transaction
61    async fn commit(self: Box<Self>) -> OrmResult<()>;
62
63    /// Rollback the transaction
64    async fn rollback(self: Box<Self>) -> OrmResult<()>;
65}
66
67/// Abstract database connection pool trait
68#[async_trait]
69pub trait DatabasePool: Send + Sync {
70    /// Acquire a connection from the pool
71    async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>>;
72
73    /// Begin a transaction from the pool
74    async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>>;
75
76    /// Execute a query directly on the pool
77    async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
78
79    /// Execute a query and return result rows directly on the pool
80    async fn fetch_all(
81        &self,
82        sql: &str,
83        params: &[DatabaseValue],
84    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
85
86    /// Execute a query and return the first result row directly on the pool
87    async fn fetch_optional(
88        &self,
89        sql: &str,
90        params: &[DatabaseValue],
91    ) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
92
93    /// Close the pool
94    async fn close(&self) -> OrmResult<()>;
95
96    /// Get pool statistics
97    fn stats(&self) -> DatabasePoolStats;
98
99    /// Perform a health check on the pool
100    async fn health_check(&self) -> OrmResult<std::time::Duration>;
101}
102
/// Connection counts reported by [`DatabasePool::stats`].
#[derive(Debug, Clone)]
pub struct DatabasePoolStats {
    /// Total number of connections the pool reports.
    pub total_connections: u32,
    /// Number of connections the pool reports as idle.
    pub idle_connections: u32,
    /// Number of connections the pool reports as in use.
    pub active_connections: u32,
}
110
111/// Abstract database row trait
112pub trait DatabaseRow: Send + Sync {
113    /// Get a column value by index
114    fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue>;
115
116    /// Get a column value by name
117    fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue>;
118
119    /// Get column count
120    fn column_count(&self) -> usize;
121
122    /// Get column names
123    fn column_names(&self) -> Vec<String>;
124
125    /// Convert row to JSON value
126    fn to_json(&self) -> OrmResult<JsonValue>;
127
128    /// Convert row to HashMap
129    fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>>;
130}
131
132/// Extension trait for DatabaseRow to support typed column access for models
133pub trait DatabaseRowExt {
134    /// Get a typed value from a column (for model deserialization)
135    fn get<T>(&self, column: &str) -> Result<T, crate::error::ModelError>
136    where
137        T: for<'de> serde::Deserialize<'de>;
138
139    /// Try to get an optional typed value from a column
140    fn try_get<T>(&self, column: &str) -> Result<Option<T>, crate::error::ModelError>
141    where
142        T: for<'de> serde::Deserialize<'de>;
143}
144
145impl<R: DatabaseRow + ?Sized> DatabaseRowExt for R {
146    fn get<T>(&self, column: &str) -> Result<T, crate::error::ModelError>
147    where
148        T: for<'de> serde::Deserialize<'de>,
149    {
150        let db_value = self.get_by_name(column)?;
151
152        let json_value = db_value.to_json();
153        serde_json::from_value(json_value).map_err(|e| {
154            crate::error::ModelError::Serialization(format!(
155                "Failed to deserialize column '{}': {}",
156                column, e
157            ))
158        })
159    }
160
161    fn try_get<T>(&self, column: &str) -> Result<Option<T>, crate::error::ModelError>
162    where
163        T: for<'de> serde::Deserialize<'de>,
164    {
165        match self.get_by_name(column) {
166            Ok(db_value) => {
167                if db_value.is_null() {
168                    Ok(None)
169                } else {
170                    let json_value = db_value.to_json();
171                    let parsed: T = serde_json::from_value(json_value).map_err(|e| {
172                        crate::error::ModelError::Serialization(format!(
173                            "Failed to deserialize column '{}': {}",
174                            column, e
175                        ))
176                    })?;
177                    Ok(Some(parsed))
178                }
179            }
180            Err(crate::error::ModelError::ColumnNotFound(_)) => Ok(None),
181            Err(e) => Err(e), // Preserve the original error type and information
182        }
183    }
184}
185
186/// Database value enumeration for type-safe parameter binding
187#[derive(Debug, Clone, PartialEq)]
188pub enum DatabaseValue {
189    Null,
190    Bool(bool),
191    Int32(i32),
192    Int64(i64),
193    Float32(f32),
194    Float64(f64),
195    String(String),
196    Bytes(Vec<u8>),
197    Uuid(uuid::Uuid),
198    DateTime(chrono::DateTime<chrono::Utc>),
199    Date(chrono::NaiveDate),
200    Time(chrono::NaiveTime),
201    Json(JsonValue),
202    Array(Vec<DatabaseValue>),
203}
204
205impl DatabaseValue {
206    /// Check if the value is null
207    pub fn is_null(&self) -> bool {
208        matches!(self, DatabaseValue::Null)
209    }
210
211    /// Convert to JSON value
212    pub fn to_json(&self) -> JsonValue {
213        match self {
214            DatabaseValue::Null => JsonValue::Null,
215            DatabaseValue::Bool(b) => JsonValue::Bool(*b),
216            DatabaseValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
217            DatabaseValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
218            DatabaseValue::Float32(f) => JsonValue::Number(
219                serde_json::Number::from_f64(*f as f64)
220                    .unwrap_or_else(|| serde_json::Number::from(0)),
221            ),
222            DatabaseValue::Float64(f) => serde_json::Number::from_f64(*f)
223                .map(JsonValue::Number)
224                .unwrap_or(JsonValue::Null),
225            DatabaseValue::String(s) => JsonValue::String(s.clone()),
226            DatabaseValue::Bytes(b) => JsonValue::Array(
227                b.iter()
228                    .map(|&x| JsonValue::Number(serde_json::Number::from(x)))
229                    .collect(),
230            ),
231            DatabaseValue::Uuid(u) => JsonValue::String(u.to_string()),
232            DatabaseValue::DateTime(dt) => JsonValue::String(dt.to_rfc3339()),
233            DatabaseValue::Date(d) => JsonValue::String(d.to_string()),
234            DatabaseValue::Time(t) => JsonValue::String(t.to_string()),
235            DatabaseValue::Json(j) => j.clone(),
236            DatabaseValue::Array(arr) => {
237                JsonValue::Array(arr.iter().map(|v| v.to_json()).collect())
238            }
239        }
240    }
241
242    /// Create DatabaseValue from JSON value
243    pub fn from_json(json: JsonValue) -> Self {
244        match json {
245            JsonValue::Null => DatabaseValue::Null,
246            JsonValue::Bool(b) => DatabaseValue::Bool(b),
247            JsonValue::Number(n) => {
248                if let Some(i) = n.as_i64() {
249                    if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
250                        DatabaseValue::Int32(i as i32)
251                    } else {
252                        DatabaseValue::Int64(i)
253                    }
254                } else if let Some(f) = n.as_f64() {
255                    DatabaseValue::Float64(f)
256                } else {
257                    DatabaseValue::Null
258                }
259            }
260            JsonValue::String(s) => {
261                // Try to parse as UUID first
262                if let Ok(uuid) = uuid::Uuid::parse_str(&s) {
263                    DatabaseValue::Uuid(uuid)
264                } else if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(&s) {
265                    DatabaseValue::DateTime(dt.with_timezone(&chrono::Utc))
266                } else {
267                    DatabaseValue::String(s)
268                }
269            }
270            JsonValue::Array(arr) => {
271                let db_values: Vec<DatabaseValue> =
272                    arr.into_iter().map(DatabaseValue::from_json).collect();
273                DatabaseValue::Array(db_values)
274            }
275            JsonValue::Object(_) => DatabaseValue::Json(json),
276        }
277    }
278}
279
280impl From<bool> for DatabaseValue {
281    fn from(value: bool) -> Self {
282        DatabaseValue::Bool(value)
283    }
284}
285
286impl From<i32> for DatabaseValue {
287    fn from(value: i32) -> Self {
288        DatabaseValue::Int32(value)
289    }
290}
291
292impl From<i64> for DatabaseValue {
293    fn from(value: i64) -> Self {
294        DatabaseValue::Int64(value)
295    }
296}
297
298impl From<f32> for DatabaseValue {
299    fn from(value: f32) -> Self {
300        DatabaseValue::Float32(value)
301    }
302}
303
304impl From<f64> for DatabaseValue {
305    fn from(value: f64) -> Self {
306        DatabaseValue::Float64(value)
307    }
308}
309
310impl From<String> for DatabaseValue {
311    fn from(value: String) -> Self {
312        DatabaseValue::String(value)
313    }
314}
315
316impl From<&str> for DatabaseValue {
317    fn from(value: &str) -> Self {
318        DatabaseValue::String(value.to_string())
319    }
320}
321
322impl From<Vec<u8>> for DatabaseValue {
323    fn from(value: Vec<u8>) -> Self {
324        DatabaseValue::Bytes(value)
325    }
326}
327
328impl From<uuid::Uuid> for DatabaseValue {
329    fn from(value: uuid::Uuid) -> Self {
330        DatabaseValue::Uuid(value)
331    }
332}
333
334impl From<chrono::DateTime<chrono::Utc>> for DatabaseValue {
335    fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
336        DatabaseValue::DateTime(value)
337    }
338}
339
340impl From<chrono::NaiveDate> for DatabaseValue {
341    fn from(value: chrono::NaiveDate) -> Self {
342        DatabaseValue::Date(value)
343    }
344}
345
346impl From<chrono::NaiveTime> for DatabaseValue {
347    fn from(value: chrono::NaiveTime) -> Self {
348        DatabaseValue::Time(value)
349    }
350}
351
352impl From<JsonValue> for DatabaseValue {
353    fn from(value: JsonValue) -> Self {
354        DatabaseValue::Json(value)
355    }
356}
357
358impl<T> From<Option<T>> for DatabaseValue
359where
360    T: Into<DatabaseValue>,
361{
362    fn from(value: Option<T>) -> Self {
363        match value {
364            Some(v) => v.into(),
365            None => DatabaseValue::Null,
366        }
367    }
368}
369
/// SQL dialects the ORM can generate database-specific SQL for.
///
/// The enum has no fields, so it is `Copy` and `Hash`. It can be passed by
/// value and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlDialect {
    PostgreSQL,
    MySQL,
    SQLite,
}

impl SqlDialect {
    /// Returns the bind-parameter placeholder for the parameter at `index`.
    ///
    /// `index` is zero-based. PostgreSQL uses numbered placeholders starting
    /// at `$1` (so `index == 0` yields `"$1"`). MySQL and SQLite use a
    /// positional `?` for every parameter.
    pub fn parameter_placeholder(&self, index: usize) -> String {
        match self {
            SqlDialect::PostgreSQL => format!("${}", index + 1),
            SqlDialect::MySQL | SqlDialect::SQLite => "?".to_string(),
        }
    }

    /// Returns the character used to quote identifiers.
    ///
    /// This is a double quote for PostgreSQL and SQLite, and a backtick for MySQL.
    pub fn identifier_quote(&self) -> char {
        match self {
            SqlDialect::PostgreSQL | SqlDialect::SQLite => '"',
            SqlDialect::MySQL => '`',
        }
    }

    /// Quotes `identifier` for this dialect and escapes any embedded quote characters.
    ///
    /// All three dialects escape a quote character inside a quoted identifier
    /// by doubling it. For example, `my"col` becomes `"my""col"` for
    /// PostgreSQL and SQLite, and a backtick inside a MySQL identifier is
    /// doubled the same way.
    ///
    /// Prefer this over wrapping a name in
    /// [`identifier_quote`](Self::identifier_quote) by hand. That approach
    /// produces broken SQL, and can allow injection, when the name itself
    /// contains the quote character.
    pub fn quote_identifier(&self, identifier: &str) -> String {
        let quote = self.identifier_quote();
        let mut quoted = String::with_capacity(identifier.len() + 2);
        quoted.push(quote);
        for ch in identifier.chars() {
            if ch == quote {
                // Escape the quote by doubling it.
                quoted.push(quote);
            }
            quoted.push(ch);
        }
        quoted.push(quote);
        quoted
    }

    /// Reports whether boolean values are treated as supported.
    ///
    /// Returns `true` for PostgreSQL and SQLite and `false` for MySQL.
    pub fn supports_boolean(&self) -> bool {
        match self {
            SqlDialect::PostgreSQL | SqlDialect::SQLite => true,
            SqlDialect::MySQL => false,
        }
    }

    /// Reports whether JSON column types are treated as supported.
    ///
    /// Returns `true` for PostgreSQL and MySQL and `false` for SQLite.
    pub fn supports_json(&self) -> bool {
        match self {
            SqlDialect::PostgreSQL | SqlDialect::MySQL => true,
            SqlDialect::SQLite => false,
        }
    }

    /// Returns the SQL expression that evaluates to the current timestamp.
    pub fn current_timestamp(&self) -> &'static str {
        match self {
            SqlDialect::PostgreSQL => "NOW()",
            SqlDialect::MySQL => "CURRENT_TIMESTAMP",
            SqlDialect::SQLite => "datetime('now')",
        }
    }

    /// Returns this dialect's auto-increment keyword for column definitions.
    ///
    /// The returned fragments are not interchangeable:
    /// - PostgreSQL's `SERIAL` is a column type.
    /// - MySQL's `AUTO_INCREMENT` is a column modifier.
    /// - SQLite's `AUTOINCREMENT` is a column modifier that is only valid
    ///   after `INTEGER PRIMARY KEY`.
    pub fn auto_increment(&self) -> &'static str {
        match self {
            SqlDialect::PostgreSQL => "SERIAL",
            SqlDialect::MySQL => "AUTO_INCREMENT",
            SqlDialect::SQLite => "AUTOINCREMENT",
        }
    }
}
430
431/// Database backend trait that provides database-specific implementations
432#[async_trait]
433pub trait DatabaseBackend: Send + Sync {
434    /// Create a connection pool from a database URL
435    async fn create_pool(
436        &self,
437        database_url: &str,
438        config: DatabasePoolConfig,
439    ) -> OrmResult<Arc<dyn DatabasePool>>;
440
441    /// Get the SQL dialect used by this backend
442    fn sql_dialect(&self) -> SqlDialect;
443
444    /// Get the backend type
445    fn backend_type(&self) -> crate::backends::DatabaseBackendType;
446
447    /// Validate a database URL for this backend
448    fn validate_database_url(&self, url: &str) -> OrmResult<()>;
449
450    /// Parse connection parameters from a database URL
451    fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig>;
452}
453
/// Settings handed to [`DatabaseBackend::create_pool`].
#[derive(Debug, Clone)]
pub struct DatabasePoolConfig {
    /// Maximum number of connections.
    pub max_connections: u32,
    /// Minimum number of connections.
    pub min_connections: u32,
    /// Acquire timeout, in seconds.
    pub acquire_timeout_seconds: u64,
    /// Idle timeout in seconds, if any.
    pub idle_timeout_seconds: Option<u64>,
    /// Maximum connection lifetime in seconds, if any.
    pub max_lifetime_seconds: Option<u64>,
    /// Whether connections are tested before being acquired.
    pub test_before_acquire: bool,
}

impl Default for DatabasePoolConfig {
    /// Defaults:
    /// - 10 connections at most and 1 at least;
    /// - 30 s acquire timeout;
    /// - 10 min idle timeout and 30 min maximum lifetime;
    /// - connections tested before acquisition.
    fn default() -> Self {
        const TEN_MINUTES: u64 = 10 * 60;
        const THIRTY_MINUTES: u64 = 30 * 60;

        Self {
            test_before_acquire: true,
            max_lifetime_seconds: Some(THIRTY_MINUTES),
            idle_timeout_seconds: Some(TEN_MINUTES),
            acquire_timeout_seconds: 30,
            min_connections: 1,
            max_connections: 10,
        }
    }
}
477
/// Connection settings parsed from a database URL by
/// [`DatabaseBackend::parse_database_url`].
///
/// `Debug` is implemented by hand so that the password never ends up in logs
/// or panic messages. It is rendered as `Some("***")` when present.
#[derive(Clone)]
pub struct DatabaseConnectionConfig {
    pub host: String,
    pub port: u16,
    pub database: String,
    pub username: Option<String>,
    pub password: Option<String>,
    pub ssl_mode: Option<String>,
    pub additional_params: HashMap<String, String>,
}

impl std::fmt::Debug for DatabaseConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseConnectionConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            // Show only whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "***"))
            .field("ssl_mode", &self.ssl_mode)
            // NOTE(review): printed verbatim. If URLs can carry secrets in
            // query parameters, consider redacting those keys as well.
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
489
490/// Database backend registry for managing multiple backend implementations
491pub struct DatabaseBackendRegistry {
492    backends: HashMap<crate::backends::DatabaseBackendType, Arc<dyn DatabaseBackend>>,
493}
494
495impl DatabaseBackendRegistry {
496    /// Create a new backend registry
497    pub fn new() -> Self {
498        Self {
499            backends: HashMap::new(),
500        }
501    }
502
503    /// Register a database backend
504    pub fn register(
505        &mut self,
506        backend_type: crate::backends::DatabaseBackendType,
507        backend: Arc<dyn DatabaseBackend>,
508    ) {
509        self.backends.insert(backend_type, backend);
510    }
511
512    /// Get a database backend by type
513    pub fn get(
514        &self,
515        backend_type: &crate::backends::DatabaseBackendType,
516    ) -> Option<Arc<dyn DatabaseBackend>> {
517        self.backends.get(backend_type).cloned()
518    }
519
520    /// Create a connection pool using the appropriate backend for the given URL
521    pub async fn create_pool(
522        &self,
523        database_url: &str,
524        config: DatabasePoolConfig,
525    ) -> OrmResult<Arc<dyn DatabasePool>> {
526        let backend_type = self.detect_backend_from_url(database_url)?;
527        let backend = self.get(&backend_type).ok_or_else(|| {
528            OrmError::Connection(format!("No backend registered for {}", backend_type))
529        })?;
530
531        backend.create_pool(database_url, config).await
532    }
533
534    /// Detect database backend type from URL
535    fn detect_backend_from_url(
536        &self,
537        url: &str,
538    ) -> OrmResult<crate::backends::DatabaseBackendType> {
539        if url.starts_with("postgresql://") || url.starts_with("postgres://") {
540            Ok(crate::backends::DatabaseBackendType::PostgreSQL)
541        } else if url.starts_with("mysql://") {
542            Ok(crate::backends::DatabaseBackendType::MySQL)
543        } else if url.starts_with("sqlite://") || url.starts_with("file:") {
544            Ok(crate::backends::DatabaseBackendType::SQLite)
545        } else {
546            Err(OrmError::Connection(format!(
547                "Unable to detect database backend from URL: {}",
548                url
549            )))
550        }
551    }
552
553    /// List all registered backend types
554    pub fn registered_backends(&self) -> Vec<crate::backends::DatabaseBackendType> {
555        self.backends.keys().cloned().collect()
556    }
557}
558
559impl Default for DatabaseBackendRegistry {
560    fn default() -> Self {
561        Self::new()
562    }
563}