1use std::collections::HashMap;
8use std::sync::Arc;
9use async_trait::async_trait;
10use serde_json::Value as JsonValue;
11use crate::error::{OrmResult, OrmError};
12
13#[async_trait]
15pub trait DatabaseConnection: Send + Sync {
16 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
18
19 async fn fetch_all(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
21
22 async fn fetch_optional(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
24
25 async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>>;
27
28 async fn close(&mut self) -> OrmResult<()>;
30}
31
32#[async_trait]
34pub trait DatabaseTransaction: Send + Sync {
35 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
37
38 async fn fetch_all(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
40
41 async fn fetch_optional(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
43
44 async fn commit(self: Box<Self>) -> OrmResult<()>;
46
47 async fn rollback(self: Box<Self>) -> OrmResult<()>;
49}
50
51#[async_trait]
53pub trait DatabasePool: Send + Sync {
54 async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>>;
56
57 async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>>;
59
60 async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
62
63 async fn fetch_all(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
65
66 async fn fetch_optional(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
68
69 async fn close(&self) -> OrmResult<()>;
71
72 fn stats(&self) -> DatabasePoolStats;
74
75 async fn health_check(&self) -> OrmResult<std::time::Duration>;
77}
78
/// Point-in-time connection counts reported by [`DatabasePool::stats`].
#[derive(Clone, Debug)]
pub struct DatabasePoolStats {
    /// All connections currently held by the pool.
    pub total_connections: u32,
    /// Connections sitting unused in the pool.
    pub idle_connections: u32,
    /// Connections currently checked out / in use.
    pub active_connections: u32,
}
86
87pub trait DatabaseRow: Send + Sync {
89 fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue>;
91
92 fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue>;
94
95 fn column_count(&self) -> usize;
97
98 fn column_names(&self) -> Vec<String>;
100
101 fn to_json(&self) -> OrmResult<JsonValue>;
103
104 fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>>;
106}
107
108pub trait DatabaseRowExt {
110 fn get<T>(&self, column: &str) -> Result<T, crate::error::ModelError>
112 where
113 T: for<'de> serde::Deserialize<'de>;
114
115 fn try_get<T>(&self, column: &str) -> Result<Option<T>, crate::error::ModelError>
117 where
118 T: for<'de> serde::Deserialize<'de>;
119}
120
121impl<R: DatabaseRow + ?Sized> DatabaseRowExt for R {
122 fn get<T>(&self, column: &str) -> Result<T, crate::error::ModelError>
123 where
124 T: for<'de> serde::Deserialize<'de>,
125 {
126 let db_value = self.get_by_name(column)?;
127
128 let json_value = db_value.to_json();
129 serde_json::from_value(json_value)
130 .map_err(|e| crate::error::ModelError::Serialization(format!("Failed to deserialize column '{}': {}", column, e)))
131 }
132
133 fn try_get<T>(&self, column: &str) -> Result<Option<T>, crate::error::ModelError>
134 where
135 T: for<'de> serde::Deserialize<'de>,
136 {
137 match self.get_by_name(column) {
138 Ok(db_value) => {
139 if db_value.is_null() {
140 Ok(None)
141 } else {
142 let json_value = db_value.to_json();
143 let parsed: T = serde_json::from_value(json_value)
144 .map_err(|e| crate::error::ModelError::Serialization(format!("Failed to deserialize column '{}': {}", column, e)))?;
145 Ok(Some(parsed))
146 }
147 },
148 Err(crate::error::ModelError::ColumnNotFound(_)) => Ok(None),
149 Err(e) => Err(e), }
151 }
152}
153
154#[derive(Debug, Clone, PartialEq)]
156pub enum DatabaseValue {
157 Null,
158 Bool(bool),
159 Int32(i32),
160 Int64(i64),
161 Float32(f32),
162 Float64(f64),
163 String(String),
164 Bytes(Vec<u8>),
165 Uuid(uuid::Uuid),
166 DateTime(chrono::DateTime<chrono::Utc>),
167 Date(chrono::NaiveDate),
168 Time(chrono::NaiveTime),
169 Json(JsonValue),
170 Array(Vec<DatabaseValue>),
171}
172
173impl DatabaseValue {
174 pub fn is_null(&self) -> bool {
176 matches!(self, DatabaseValue::Null)
177 }
178
179 pub fn to_json(&self) -> JsonValue {
181 match self {
182 DatabaseValue::Null => JsonValue::Null,
183 DatabaseValue::Bool(b) => JsonValue::Bool(*b),
184 DatabaseValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
185 DatabaseValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
186 DatabaseValue::Float32(f) => {
187 JsonValue::Number(serde_json::Number::from_f64(*f as f64).unwrap_or_else(|| serde_json::Number::from(0)))
188 },
189 DatabaseValue::Float64(f) => {
190 serde_json::Number::from_f64(*f)
191 .map(JsonValue::Number)
192 .unwrap_or(JsonValue::Null)
193 },
194 DatabaseValue::String(s) => JsonValue::String(s.clone()),
195 DatabaseValue::Bytes(b) => JsonValue::Array(b.iter().map(|&x| JsonValue::Number(serde_json::Number::from(x))).collect()),
196 DatabaseValue::Uuid(u) => JsonValue::String(u.to_string()),
197 DatabaseValue::DateTime(dt) => JsonValue::String(dt.to_rfc3339()),
198 DatabaseValue::Date(d) => JsonValue::String(d.to_string()),
199 DatabaseValue::Time(t) => JsonValue::String(t.to_string()),
200 DatabaseValue::Json(j) => j.clone(),
201 DatabaseValue::Array(arr) => JsonValue::Array(arr.iter().map(|v| v.to_json()).collect()),
202 }
203 }
204
205 pub fn from_json(json: JsonValue) -> Self {
207 match json {
208 JsonValue::Null => DatabaseValue::Null,
209 JsonValue::Bool(b) => DatabaseValue::Bool(b),
210 JsonValue::Number(n) => {
211 if let Some(i) = n.as_i64() {
212 if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
213 DatabaseValue::Int32(i as i32)
214 } else {
215 DatabaseValue::Int64(i)
216 }
217 } else if let Some(f) = n.as_f64() {
218 DatabaseValue::Float64(f)
219 } else {
220 DatabaseValue::Null
221 }
222 },
223 JsonValue::String(s) => {
224 if let Ok(uuid) = uuid::Uuid::parse_str(&s) {
226 DatabaseValue::Uuid(uuid)
227 } else if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(&s) {
228 DatabaseValue::DateTime(dt.with_timezone(&chrono::Utc))
229 } else {
230 DatabaseValue::String(s)
231 }
232 },
233 JsonValue::Array(arr) => {
234 let db_values: Vec<DatabaseValue> = arr.into_iter().map(DatabaseValue::from_json).collect();
235 DatabaseValue::Array(db_values)
236 },
237 JsonValue::Object(_) => DatabaseValue::Json(json),
238 }
239 }
240}
241
242impl From<bool> for DatabaseValue {
243 fn from(value: bool) -> Self {
244 DatabaseValue::Bool(value)
245 }
246}
247
248impl From<i32> for DatabaseValue {
249 fn from(value: i32) -> Self {
250 DatabaseValue::Int32(value)
251 }
252}
253
254impl From<i64> for DatabaseValue {
255 fn from(value: i64) -> Self {
256 DatabaseValue::Int64(value)
257 }
258}
259
260impl From<f32> for DatabaseValue {
261 fn from(value: f32) -> Self {
262 DatabaseValue::Float32(value)
263 }
264}
265
266impl From<f64> for DatabaseValue {
267 fn from(value: f64) -> Self {
268 DatabaseValue::Float64(value)
269 }
270}
271
272impl From<String> for DatabaseValue {
273 fn from(value: String) -> Self {
274 DatabaseValue::String(value)
275 }
276}
277
278impl From<&str> for DatabaseValue {
279 fn from(value: &str) -> Self {
280 DatabaseValue::String(value.to_string())
281 }
282}
283
284impl From<Vec<u8>> for DatabaseValue {
285 fn from(value: Vec<u8>) -> Self {
286 DatabaseValue::Bytes(value)
287 }
288}
289
290impl From<uuid::Uuid> for DatabaseValue {
291 fn from(value: uuid::Uuid) -> Self {
292 DatabaseValue::Uuid(value)
293 }
294}
295
296impl From<chrono::DateTime<chrono::Utc>> for DatabaseValue {
297 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
298 DatabaseValue::DateTime(value)
299 }
300}
301
302impl From<chrono::NaiveDate> for DatabaseValue {
303 fn from(value: chrono::NaiveDate) -> Self {
304 DatabaseValue::Date(value)
305 }
306}
307
308impl From<chrono::NaiveTime> for DatabaseValue {
309 fn from(value: chrono::NaiveTime) -> Self {
310 DatabaseValue::Time(value)
311 }
312}
313
314impl From<JsonValue> for DatabaseValue {
315 fn from(value: JsonValue) -> Self {
316 DatabaseValue::Json(value)
317 }
318}
319
320impl<T> From<Option<T>> for DatabaseValue
321where
322 T: Into<DatabaseValue>,
323{
324 fn from(value: Option<T>) -> Self {
325 match value {
326 Some(v) => v.into(),
327 None => DatabaseValue::Null,
328 }
329 }
330}
331
/// The SQL flavour spoken by a backend.
///
/// It determines placeholder syntax, identifier quoting and a few DDL/feature choices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlDialect {
    /// PostgreSQL: `$n` placeholders, `"` identifier quoting.
    PostgreSQL,
    /// MySQL: `?` placeholders, backtick identifier quoting.
    MySQL,
    /// SQLite: `?` placeholders, `"` identifier quoting.
    SQLite,
}

impl SqlDialect {
    /// Returns the bind-parameter placeholder for the zero-based parameter `index`.
    ///
    /// PostgreSQL placeholders are numbered from one (`$1`, `$2`, …). MySQL and
    /// SQLite use a positional `?` regardless of `index`.
    pub fn parameter_placeholder(&self, index: usize) -> String {
        match self {
            Self::MySQL | Self::SQLite => String::from("?"),
            Self::PostgreSQL => {
                let position = index + 1;
                format!("${}", position)
            }
        }
    }

    /// Returns the character used to quote identifiers in this dialect.
    pub fn identifier_quote(&self) -> char {
        match self {
            Self::MySQL => '`',
            Self::PostgreSQL | Self::SQLite => '"',
        }
    }

    /// Whether this crate treats the dialect as supporting boolean values
    /// natively (`false` only for MySQL).
    pub fn supports_boolean(&self) -> bool {
        match self {
            Self::MySQL => false,
            Self::PostgreSQL | Self::SQLite => true,
        }
    }

    /// Whether this crate treats the dialect as supporting JSON natively
    /// (`false` only for SQLite).
    pub fn supports_json(&self) -> bool {
        match self {
            Self::SQLite => false,
            Self::PostgreSQL | Self::MySQL => true,
        }
    }

    /// Returns the SQL expression yielding the current date and time.
    pub fn current_timestamp(&self) -> &'static str {
        match self {
            Self::SQLite => "datetime('now')",
            Self::MySQL => "CURRENT_TIMESTAMP",
            Self::PostgreSQL => "NOW()",
        }
    }

    /// Returns the dialect's auto-increment keyword.
    ///
    /// NOTE: the values differ in kind and placement. PostgreSQL's `SERIAL` is a
    /// column type. MySQL's `AUTO_INCREMENT` and SQLite's `AUTOINCREMENT` are
    /// column modifiers, and SQLite additionally requires `INTEGER PRIMARY KEY`.
    pub fn auto_increment(&self) -> &'static str {
        match self {
            Self::SQLite => "AUTOINCREMENT",
            Self::MySQL => "AUTO_INCREMENT",
            Self::PostgreSQL => "SERIAL",
        }
    }
}
392
393#[async_trait]
395pub trait DatabaseBackend: Send + Sync {
396 async fn create_pool(&self, database_url: &str, config: DatabasePoolConfig) -> OrmResult<Arc<dyn DatabasePool>>;
398
399 fn sql_dialect(&self) -> SqlDialect;
401
402 fn backend_type(&self) -> crate::backends::DatabaseBackendType;
404
405 fn validate_database_url(&self, url: &str) -> OrmResult<()>;
407
408 fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig>;
410}
411
/// Tuning parameters for a connection pool.
///
/// All durations are in seconds. For the optional limits, `None` presumably
/// disables the limit — confirm in the backends.
#[derive(Clone, Debug)]
pub struct DatabasePoolConfig {
    /// Upper bound on pooled connections.
    pub max_connections: u32,
    /// Lower bound on pooled connections.
    pub min_connections: u32,
    /// How long to wait for a connection before giving up.
    pub acquire_timeout_seconds: u64,
    /// How long a connection may sit idle.
    pub idle_timeout_seconds: Option<u64>,
    /// Maximum total lifetime of a connection.
    pub max_lifetime_seconds: Option<u64>,
    /// Whether to test a connection before handing it out.
    pub test_before_acquire: bool,
}

impl Default for DatabasePoolConfig {
    /// Defaults:
    /// - between 1 and 10 connections;
    /// - 30 s acquire timeout;
    /// - 10 min idle timeout;
    /// - 30 min maximum lifetime;
    /// - connections tested before acquisition.
    fn default() -> Self {
        const TEN_MINUTES: u64 = 10 * 60;
        const HALF_AN_HOUR: u64 = 30 * 60;

        DatabasePoolConfig {
            test_before_acquire: true,
            max_lifetime_seconds: Some(HALF_AN_HOUR),
            idle_timeout_seconds: Some(TEN_MINUTES),
            acquire_timeout_seconds: 30,
            min_connections: 1,
            max_connections: 10,
        }
    }
}
435
/// Connection parameters parsed from a database URL.
///
/// `Debug` is implemented by hand so that secrets never reach logs. `password`
/// is shown only as `Some("<redacted>")` or `None`. Entries of
/// `additional_params` whose key mentions a password, secret or token have
/// their value redacted too. All other fields are printed as-is.
#[derive(Clone)]
pub struct DatabaseConnectionConfig {
    /// Server host name or address.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Database name.
    pub database: String,
    /// Login user, if given.
    pub username: Option<String>,
    /// Login password, if given. Redacted from `Debug` output.
    pub password: Option<String>,
    /// SSL mode string, if given (format is backend-specific — TODO confirm).
    pub ssl_mode: Option<String>,
    /// Any further key/value connection parameters.
    pub additional_params: HashMap<String, String>,
}

impl std::fmt::Debug for DatabaseConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        /// Heuristic for parameter names that may carry credentials
        /// (e.g. `password`, `sslpassword`, `secret`, `token`).
        fn is_sensitive(key: &str) -> bool {
            let key = key.to_ascii_lowercase();
            ["password", "secret", "token"]
                .iter()
                .any(|needle| key.contains(needle))
        }

        let params: HashMap<&str, &str> = self
            .additional_params
            .iter()
            .map(|(key, value)| {
                let shown = if is_sensitive(key) { REDACTED } else { value.as_str() };
                (key.as_str(), shown)
            })
            .collect();

        f.debug_struct("DatabaseConnectionConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            // Reveal only whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| REDACTED))
            .field("ssl_mode", &self.ssl_mode)
            .field("additional_params", &params)
            .finish()
    }
}
447
448pub struct DatabaseBackendRegistry {
450 backends: HashMap<crate::backends::DatabaseBackendType, Arc<dyn DatabaseBackend>>,
451}
452
453impl DatabaseBackendRegistry {
454 pub fn new() -> Self {
456 Self {
457 backends: HashMap::new(),
458 }
459 }
460
461 pub fn register(&mut self, backend_type: crate::backends::DatabaseBackendType, backend: Arc<dyn DatabaseBackend>) {
463 self.backends.insert(backend_type, backend);
464 }
465
466 pub fn get(&self, backend_type: &crate::backends::DatabaseBackendType) -> Option<Arc<dyn DatabaseBackend>> {
468 self.backends.get(backend_type).cloned()
469 }
470
471 pub async fn create_pool(&self, database_url: &str, config: DatabasePoolConfig) -> OrmResult<Arc<dyn DatabasePool>> {
473 let backend_type = self.detect_backend_from_url(database_url)?;
474 let backend = self.get(&backend_type)
475 .ok_or_else(|| OrmError::Connection(format!("No backend registered for {}", backend_type)))?;
476
477 backend.create_pool(database_url, config).await
478 }
479
480 fn detect_backend_from_url(&self, url: &str) -> OrmResult<crate::backends::DatabaseBackendType> {
482 if url.starts_with("postgresql://") || url.starts_with("postgres://") {
483 Ok(crate::backends::DatabaseBackendType::PostgreSQL)
484 } else if url.starts_with("mysql://") {
485 Ok(crate::backends::DatabaseBackendType::MySQL)
486 } else if url.starts_with("sqlite://") || url.starts_with("file:") {
487 Ok(crate::backends::DatabaseBackendType::SQLite)
488 } else {
489 Err(OrmError::Connection(format!("Unable to detect database backend from URL: {}", url)))
490 }
491 }
492
493 pub fn registered_backends(&self) -> Vec<crate::backends::DatabaseBackendType> {
495 self.backends.keys().cloned().collect()
496 }
497}
498
499impl Default for DatabaseBackendRegistry {
500 fn default() -> Self {
501 Self::new()
502 }
503}