1use crate::error::{OrmError, OrmResult};
8use async_trait::async_trait;
9use serde_json::Value as JsonValue;
10use std::collections::HashMap;
11use std::sync::Arc;
12
/// A single database connection, as returned by [`DatabasePool::acquire`].
///
/// Every method takes `&mut self`, so a caller needs exclusive access to the
/// connection to use it. `params` are bound positionally to the placeholders in
/// `sql` (presumably written in the syntax of
/// [`SqlDialect::parameter_placeholder`] for the backend — confirm per implementation).
#[async_trait]
pub trait DatabaseConnection: Send + Sync {
    /// Executes a statement and returns a count reported by the backend
    /// (presumably the number of affected rows — confirm per backend).
    async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;

    /// Runs a query and returns all of its result rows.
    async fn fetch_all(
        &mut self,
        sql: &str,
        params: &[DatabaseValue],
    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;

    /// Runs a query and returns at most one row, or `None` when it produced none.
    async fn fetch_optional(
        &mut self,
        sql: &str,
        params: &[DatabaseValue],
    ) -> OrmResult<Option<Box<dyn DatabaseRow>>>;

    /// Starts a transaction.
    ///
    /// NOTE(review): the returned transaction does not borrow `self` (no lifetime in
    /// the signature), so how it is tied to this connection is up to the
    /// implementation — confirm.
    async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>>;

    /// Closes the connection.
    async fn close(&mut self) -> OrmResult<()>;
}
39
/// An open database transaction.
///
/// `execute` / `fetch_all` / `fetch_optional` run statements inside the transaction.
/// [`commit`](Self::commit) and [`rollback`](Self::rollback) take `self: Box<Self>`,
/// so finishing the transaction consumes it and it cannot be used afterwards.
///
/// NOTE(review): no contract is declared here for a transaction that is dropped
/// without calling either method — confirm what each backend does in that case.
#[async_trait]
pub trait DatabaseTransaction: Send + Sync {
    /// Executes a statement inside the transaction; the `u64` is presumably the
    /// affected-row count — confirm per backend.
    async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;

    /// Runs a query inside the transaction and returns all of its result rows.
    async fn fetch_all(
        &mut self,
        sql: &str,
        params: &[DatabaseValue],
    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;

    /// Runs a query inside the transaction and returns at most one row.
    async fn fetch_optional(
        &mut self,
        sql: &str,
        params: &[DatabaseValue],
    ) -> OrmResult<Option<Box<dyn DatabaseRow>>>;

    /// Commits the transaction, consuming it.
    async fn commit(self: Box<Self>) -> OrmResult<()>;

    /// Rolls the transaction back, consuming it.
    async fn rollback(self: Box<Self>) -> OrmResult<()>;
}
66
/// A pool of database connections, handed out as `Arc<dyn DatabasePool>` by
/// [`DatabaseBackend::create_pool`].
///
/// All methods take `&self`, so the pool can be used through shared references.
/// Besides [`acquire`](Self::acquire), it offers one-shot `execute` / `fetch_all` /
/// `fetch_optional` (presumably each runs on a connection checked out just for
/// that call — confirm per backend).
#[async_trait]
pub trait DatabasePool: Send + Sync {
    /// Checks out a connection from the pool.
    async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>>;

    /// Starts a transaction on a pooled connection.
    async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>>;

    /// Executes a single statement; the `u64` is presumably the affected-row count.
    async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;

    /// Runs a query and returns all of its result rows.
    async fn fetch_all(
        &self,
        sql: &str,
        params: &[DatabaseValue],
    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;

    /// Runs a query and returns at most one row.
    async fn fetch_optional(
        &self,
        sql: &str,
        params: &[DatabaseValue],
    ) -> OrmResult<Option<Box<dyn DatabaseRow>>>;

    /// Closes the pool.
    async fn close(&self) -> OrmResult<()>;

    /// Returns a snapshot of the pool's connection counts. This method is
    /// synchronous (not `async`).
    fn stats(&self) -> DatabasePoolStats;

    /// Checks that the database is reachable and returns a `Duration`
    /// (presumably how long the check took — confirm per backend).
    async fn health_check(&self) -> OrmResult<std::time::Duration>;
}
102
/// Connection counts reported by [`DatabasePool::stats`].
///
/// NOTE(review): presumably `total_connections == idle_connections + active_connections`,
/// but this type does not enforce it — it holds whatever the backend reports.
#[derive(Debug, Clone)]
pub struct DatabasePoolStats {
    /// Connections currently open in the pool.
    pub total_connections: u32,
    /// Open connections that are not checked out.
    pub idle_connections: u32,
    /// Connections currently checked out / in use.
    pub active_connections: u32,
}
110
/// One row of a query result, exposing column values as [`DatabaseValue`]s.
///
/// Typed, `serde`-based access is provided by [`DatabaseRowExt`].
pub trait DatabaseRow: Send + Sync {
    /// Returns the value at column position `index`
    /// (presumably zero-based — confirm per backend).
    fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue>;

    /// Returns the value of the column called `name`.
    ///
    /// `DatabaseRowExt::try_get` treats a `ModelError::ColumnNotFound` error from
    /// this method as "no value", so a missing column should be reported that way.
    fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue>;

    /// Number of columns in the row.
    fn column_count(&self) -> usize;

    /// Names of the row's columns (presumably in column order — confirm per backend).
    fn column_names(&self) -> Vec<String>;

    /// Converts the whole row to JSON (presumably an object keyed by column name).
    fn to_json(&self) -> OrmResult<JsonValue>;

    /// Converts the row into a map (presumably keyed by column name).
    fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>>;
}
131
/// Typed accessors for rows. A blanket impl provides them for every [`DatabaseRow`]
/// (including `dyn DatabaseRow`) by converting the column value to JSON and
/// deserializing it with `serde`.
pub trait DatabaseRowExt {
    /// Reads `column` and deserializes it into `T`.
    ///
    /// # Errors
    /// Fails when the column cannot be read (including when it does not exist), or
    /// with `ModelError::Serialization` when the value cannot be deserialized into
    /// `T`. A SQL `NULL` only succeeds when `T` accepts JSON `null` (e.g. `Option<_>`).
    fn get<T>(&self, column: &str) -> Result<T, crate::error::ModelError>
    where
        T: for<'de> serde::Deserialize<'de>;

    /// Like [`get`](Self::get), but returns `Ok(None)` when the column is `NULL` or
    /// does not exist (`ModelError::ColumnNotFound`).
    ///
    /// # Errors
    /// Any other read error, and deserialization failures, are still returned.
    fn try_get<T>(&self, column: &str) -> Result<Option<T>, crate::error::ModelError>
    where
        T: for<'de> serde::Deserialize<'de>;
}
144
145impl<R: DatabaseRow + ?Sized> DatabaseRowExt for R {
146 fn get<T>(&self, column: &str) -> Result<T, crate::error::ModelError>
147 where
148 T: for<'de> serde::Deserialize<'de>,
149 {
150 let db_value = self.get_by_name(column)?;
151
152 let json_value = db_value.to_json();
153 serde_json::from_value(json_value).map_err(|e| {
154 crate::error::ModelError::Serialization(format!(
155 "Failed to deserialize column '{}': {}",
156 column, e
157 ))
158 })
159 }
160
161 fn try_get<T>(&self, column: &str) -> Result<Option<T>, crate::error::ModelError>
162 where
163 T: for<'de> serde::Deserialize<'de>,
164 {
165 match self.get_by_name(column) {
166 Ok(db_value) => {
167 if db_value.is_null() {
168 Ok(None)
169 } else {
170 let json_value = db_value.to_json();
171 let parsed: T = serde_json::from_value(json_value).map_err(|e| {
172 crate::error::ModelError::Serialization(format!(
173 "Failed to deserialize column '{}': {}",
174 column, e
175 ))
176 })?;
177 Ok(Some(parsed))
178 }
179 }
180 Err(crate::error::ModelError::ColumnNotFound(_)) => Ok(None),
181 Err(e) => Err(e), }
183 }
184}
185
186#[derive(Debug, Clone, PartialEq)]
188pub enum DatabaseValue {
189 Null,
190 Bool(bool),
191 Int32(i32),
192 Int64(i64),
193 Float32(f32),
194 Float64(f64),
195 String(String),
196 Bytes(Vec<u8>),
197 Uuid(uuid::Uuid),
198 DateTime(chrono::DateTime<chrono::Utc>),
199 Date(chrono::NaiveDate),
200 Time(chrono::NaiveTime),
201 Json(JsonValue),
202 Array(Vec<DatabaseValue>),
203}
204
205impl DatabaseValue {
206 pub fn is_null(&self) -> bool {
208 matches!(self, DatabaseValue::Null)
209 }
210
211 pub fn to_json(&self) -> JsonValue {
213 match self {
214 DatabaseValue::Null => JsonValue::Null,
215 DatabaseValue::Bool(b) => JsonValue::Bool(*b),
216 DatabaseValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
217 DatabaseValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
218 DatabaseValue::Float32(f) => JsonValue::Number(
219 serde_json::Number::from_f64(*f as f64)
220 .unwrap_or_else(|| serde_json::Number::from(0)),
221 ),
222 DatabaseValue::Float64(f) => serde_json::Number::from_f64(*f)
223 .map(JsonValue::Number)
224 .unwrap_or(JsonValue::Null),
225 DatabaseValue::String(s) => JsonValue::String(s.clone()),
226 DatabaseValue::Bytes(b) => JsonValue::Array(
227 b.iter()
228 .map(|&x| JsonValue::Number(serde_json::Number::from(x)))
229 .collect(),
230 ),
231 DatabaseValue::Uuid(u) => JsonValue::String(u.to_string()),
232 DatabaseValue::DateTime(dt) => JsonValue::String(dt.to_rfc3339()),
233 DatabaseValue::Date(d) => JsonValue::String(d.to_string()),
234 DatabaseValue::Time(t) => JsonValue::String(t.to_string()),
235 DatabaseValue::Json(j) => j.clone(),
236 DatabaseValue::Array(arr) => {
237 JsonValue::Array(arr.iter().map(|v| v.to_json()).collect())
238 }
239 }
240 }
241
242 pub fn from_json(json: JsonValue) -> Self {
244 match json {
245 JsonValue::Null => DatabaseValue::Null,
246 JsonValue::Bool(b) => DatabaseValue::Bool(b),
247 JsonValue::Number(n) => {
248 if let Some(i) = n.as_i64() {
249 if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
250 DatabaseValue::Int32(i as i32)
251 } else {
252 DatabaseValue::Int64(i)
253 }
254 } else if let Some(f) = n.as_f64() {
255 DatabaseValue::Float64(f)
256 } else {
257 DatabaseValue::Null
258 }
259 }
260 JsonValue::String(s) => {
261 if let Ok(uuid) = uuid::Uuid::parse_str(&s) {
263 DatabaseValue::Uuid(uuid)
264 } else if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(&s) {
265 DatabaseValue::DateTime(dt.with_timezone(&chrono::Utc))
266 } else {
267 DatabaseValue::String(s)
268 }
269 }
270 JsonValue::Array(arr) => {
271 let db_values: Vec<DatabaseValue> =
272 arr.into_iter().map(DatabaseValue::from_json).collect();
273 DatabaseValue::Array(db_values)
274 }
275 JsonValue::Object(_) => DatabaseValue::Json(json),
276 }
277 }
278}
279
280impl From<bool> for DatabaseValue {
281 fn from(value: bool) -> Self {
282 DatabaseValue::Bool(value)
283 }
284}
285
286impl From<i32> for DatabaseValue {
287 fn from(value: i32) -> Self {
288 DatabaseValue::Int32(value)
289 }
290}
291
292impl From<i64> for DatabaseValue {
293 fn from(value: i64) -> Self {
294 DatabaseValue::Int64(value)
295 }
296}
297
298impl From<f32> for DatabaseValue {
299 fn from(value: f32) -> Self {
300 DatabaseValue::Float32(value)
301 }
302}
303
304impl From<f64> for DatabaseValue {
305 fn from(value: f64) -> Self {
306 DatabaseValue::Float64(value)
307 }
308}
309
310impl From<String> for DatabaseValue {
311 fn from(value: String) -> Self {
312 DatabaseValue::String(value)
313 }
314}
315
316impl From<&str> for DatabaseValue {
317 fn from(value: &str) -> Self {
318 DatabaseValue::String(value.to_string())
319 }
320}
321
322impl From<Vec<u8>> for DatabaseValue {
323 fn from(value: Vec<u8>) -> Self {
324 DatabaseValue::Bytes(value)
325 }
326}
327
328impl From<uuid::Uuid> for DatabaseValue {
329 fn from(value: uuid::Uuid) -> Self {
330 DatabaseValue::Uuid(value)
331 }
332}
333
334impl From<chrono::DateTime<chrono::Utc>> for DatabaseValue {
335 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
336 DatabaseValue::DateTime(value)
337 }
338}
339
340impl From<chrono::NaiveDate> for DatabaseValue {
341 fn from(value: chrono::NaiveDate) -> Self {
342 DatabaseValue::Date(value)
343 }
344}
345
346impl From<chrono::NaiveTime> for DatabaseValue {
347 fn from(value: chrono::NaiveTime) -> Self {
348 DatabaseValue::Time(value)
349 }
350}
351
352impl From<JsonValue> for DatabaseValue {
353 fn from(value: JsonValue) -> Self {
354 DatabaseValue::Json(value)
355 }
356}
357
358impl<T> From<Option<T>> for DatabaseValue
359where
360 T: Into<DatabaseValue>,
361{
362 fn from(value: Option<T>) -> Self {
363 match value {
364 Some(v) => v.into(),
365 None => DatabaseValue::Null,
366 }
367 }
368}
369
/// SQL dialect of a backend. It determines placeholder syntax, identifier quoting,
/// and a few dialect-specific SQL snippets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlDialect {
    /// PostgreSQL: numbered `$n` placeholders, `"` identifier quotes.
    PostgreSQL,
    /// MySQL: `?` placeholders, backtick identifier quotes.
    MySQL,
    /// SQLite: `?` placeholders, `"` identifier quotes.
    SQLite,
}

impl SqlDialect {
    /// Placeholder for the parameter at zero-based position `index`.
    ///
    /// PostgreSQL numbers placeholders from 1 (`index == 0` gives `"$1"`). MySQL and
    /// SQLite use a bare positional `"?"` whatever `index` is.
    pub fn parameter_placeholder(&self, index: usize) -> String {
        match self {
            Self::MySQL | Self::SQLite => String::from("?"),
            Self::PostgreSQL => {
                let position = index + 1;
                format!("${}", position)
            }
        }
    }

    /// Character used to quote identifiers such as table and column names.
    pub fn identifier_quote(&self) -> char {
        match self {
            Self::MySQL => '`',
            Self::PostgreSQL | Self::SQLite => '"',
        }
    }

    /// Whether this dialect is treated as supporting booleans.
    ///
    /// NOTE(review): MySQL (`BOOLEAN` is `TINYINT(1)`) and SQLite (integer storage)
    /// both lack a native boolean storage type. Both also accept `TRUE`/`FALSE`
    /// literals in current versions, so the MySQL/SQLite split here looks
    /// arbitrary. Confirm what callers rely on.
    pub fn supports_boolean(&self) -> bool {
        match self {
            Self::MySQL => false,
            Self::PostgreSQL | Self::SQLite => true,
        }
    }

    /// Whether this dialect is treated as having JSON support.
    pub fn supports_json(&self) -> bool {
        match self {
            Self::SQLite => false,
            Self::PostgreSQL | Self::MySQL => true,
        }
    }

    /// SQL expression for the current timestamp.
    pub fn current_timestamp(&self) -> &'static str {
        match self {
            Self::SQLite => "datetime('now')",
            Self::MySQL => "CURRENT_TIMESTAMP",
            Self::PostgreSQL => "NOW()",
        }
    }

    /// Auto-increment snippet for column definitions.
    ///
    /// The snippets differ in kind. PostgreSQL's `SERIAL` replaces the column type.
    /// MySQL's `AUTO_INCREMENT` and SQLite's `AUTOINCREMENT` are column attributes,
    /// and SQLite only allows it on an `INTEGER PRIMARY KEY`.
    pub fn auto_increment(&self) -> &'static str {
        match self {
            Self::SQLite => "AUTOINCREMENT",
            Self::MySQL => "AUTO_INCREMENT",
            Self::PostgreSQL => "SERIAL",
        }
    }
}
430
/// A pluggable database driver. It creates connection pools and describes its SQL
/// dialect.
///
/// Backends are stored in a [`DatabaseBackendRegistry`] as `Arc<dyn DatabaseBackend>`.
#[async_trait]
pub trait DatabaseBackend: Send + Sync {
    /// Creates a connection pool for `database_url` with the given pool settings.
    async fn create_pool(
        &self,
        database_url: &str,
        config: DatabasePoolConfig,
    ) -> OrmResult<Arc<dyn DatabasePool>>;

    /// SQL dialect spoken by this backend.
    fn sql_dialect(&self) -> SqlDialect;

    /// Backend type this implementation handles.
    fn backend_type(&self) -> crate::backends::DatabaseBackendType;

    /// Checks whether `url` is acceptable for this backend, returning an error if not.
    fn validate_database_url(&self, url: &str) -> OrmResult<()>;

    /// Parses `url` into individual connection settings.
    fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig>;
}
453
/// Tuning options passed to [`DatabaseBackend::create_pool`].
///
/// Exactly how each option is enforced is up to the backend — confirm per backend.
#[derive(Debug, Clone)]
pub struct DatabasePoolConfig {
    /// Upper bound on open connections.
    pub max_connections: u32,
    /// Number of connections the pool aims to keep open.
    pub min_connections: u32,
    /// Seconds to wait when acquiring a connection before giving up.
    pub acquire_timeout_seconds: u64,
    /// Seconds a connection may stay idle; presumably `None` means no limit.
    pub idle_timeout_seconds: Option<u64>,
    /// Maximum connection age in seconds; presumably `None` means no limit.
    pub max_lifetime_seconds: Option<u64>,
    /// Whether to check a connection before handing it out.
    pub test_before_acquire: bool,
}

impl Default for DatabasePoolConfig {
    /// Defaults: 1–10 connections, 30 s acquire timeout, 10 min idle timeout,
    /// 30 min maximum lifetime, and a check before each acquire.
    fn default() -> Self {
        const SECONDS_PER_MINUTE: u64 = 60;

        Self {
            test_before_acquire: true,
            max_lifetime_seconds: Some(30 * SECONDS_PER_MINUTE),
            idle_timeout_seconds: Some(10 * SECONDS_PER_MINUTE),
            acquire_timeout_seconds: 30,
            min_connections: 1,
            max_connections: 10,
        }
    }
}
477
/// Connection settings parsed from a database URL by
/// [`DatabaseBackend::parse_database_url`].
///
/// `Debug` is implemented by hand so that the password is never printed, because
/// configs are easily logged with `{:?}`.
#[derive(Clone)]
pub struct DatabaseConnectionConfig {
    /// Server host name or address.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Database name (for file-based backends presumably the path — confirm per backend).
    pub database: String,
    /// User name, if one was given.
    pub username: Option<String>,
    /// Password, if one was given. Shown as `"<redacted>"` by `Debug`.
    pub password: Option<String>,
    /// TLS/SSL mode string, if one was given.
    pub ssl_mode: Option<String>,
    /// Any further backend-specific options (presumably URL query parameters).
    pub additional_params: HashMap<String, String>,
}

impl std::fmt::Debug for DatabaseConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseConnectionConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            // Only reveal whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("ssl_mode", &self.ssl_mode)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
489
490pub struct DatabaseBackendRegistry {
492 backends: HashMap<crate::backends::DatabaseBackendType, Arc<dyn DatabaseBackend>>,
493}
494
495impl DatabaseBackendRegistry {
496 pub fn new() -> Self {
498 Self {
499 backends: HashMap::new(),
500 }
501 }
502
503 pub fn register(
505 &mut self,
506 backend_type: crate::backends::DatabaseBackendType,
507 backend: Arc<dyn DatabaseBackend>,
508 ) {
509 self.backends.insert(backend_type, backend);
510 }
511
512 pub fn get(
514 &self,
515 backend_type: &crate::backends::DatabaseBackendType,
516 ) -> Option<Arc<dyn DatabaseBackend>> {
517 self.backends.get(backend_type).cloned()
518 }
519
520 pub async fn create_pool(
522 &self,
523 database_url: &str,
524 config: DatabasePoolConfig,
525 ) -> OrmResult<Arc<dyn DatabasePool>> {
526 let backend_type = self.detect_backend_from_url(database_url)?;
527 let backend = self.get(&backend_type).ok_or_else(|| {
528 OrmError::Connection(format!("No backend registered for {}", backend_type))
529 })?;
530
531 backend.create_pool(database_url, config).await
532 }
533
534 fn detect_backend_from_url(
536 &self,
537 url: &str,
538 ) -> OrmResult<crate::backends::DatabaseBackendType> {
539 if url.starts_with("postgresql://") || url.starts_with("postgres://") {
540 Ok(crate::backends::DatabaseBackendType::PostgreSQL)
541 } else if url.starts_with("mysql://") {
542 Ok(crate::backends::DatabaseBackendType::MySQL)
543 } else if url.starts_with("sqlite://") || url.starts_with("file:") {
544 Ok(crate::backends::DatabaseBackendType::SQLite)
545 } else {
546 Err(OrmError::Connection(format!(
547 "Unable to detect database backend from URL: {}",
548 url
549 )))
550 }
551 }
552
553 pub fn registered_backends(&self) -> Vec<crate::backends::DatabaseBackendType> {
555 self.backends.keys().cloned().collect()
556 }
557}
558
559impl Default for DatabaseBackendRegistry {
560 fn default() -> Self {
561 Self::new()
562 }
563}