1use std::collections::HashMap;
8use std::sync::Arc;
9use async_trait::async_trait;
10use serde_json::Value as JsonValue;
11use crate::error::{OrmResult, OrmError};
12
13#[async_trait]
15pub trait DatabaseConnection: Send + Sync {
16 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
18
19 async fn fetch_all(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
21
22 async fn fetch_optional(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
24
25 async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>>;
27
28 async fn close(&mut self) -> OrmResult<()>;
30}
31
32#[async_trait]
34pub trait DatabaseTransaction: Send + Sync {
35 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
37
38 async fn fetch_all(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
40
41 async fn fetch_optional(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
43
44 async fn commit(self: Box<Self>) -> OrmResult<()>;
46
47 async fn rollback(self: Box<Self>) -> OrmResult<()>;
49}
50
51#[async_trait]
53pub trait DatabasePool: Send + Sync {
54 async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>>;
56
57 async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>>;
59
60 async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64>;
62
63 async fn fetch_all(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>>;
65
66 async fn fetch_optional(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>>;
68
69 async fn close(&self) -> OrmResult<()>;
71
72 fn stats(&self) -> DatabasePoolStats;
74
75 async fn health_check(&self) -> OrmResult<std::time::Duration>;
77}
78
/// Connection counters reported by a pool.
#[derive(Clone, Debug)]
pub struct DatabasePoolStats {
    /// Number of connections the pool reports in total.
    pub total_connections: u32,
    /// Number of connections the pool reports as idle.
    pub idle_connections: u32,
    /// Number of connections the pool reports as active.
    pub active_connections: u32,
}
86
87pub trait DatabaseRow: Send + Sync {
89 fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue>;
91
92 fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue>;
94
95 fn column_count(&self) -> usize;
97
98 fn column_names(&self) -> Vec<String>;
100
101 fn to_json(&self) -> OrmResult<JsonValue>;
103
104 fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>>;
106}
107
108#[derive(Debug, Clone, PartialEq)]
110pub enum DatabaseValue {
111 Null,
112 Bool(bool),
113 Int32(i32),
114 Int64(i64),
115 Float32(f32),
116 Float64(f64),
117 String(String),
118 Bytes(Vec<u8>),
119 Uuid(uuid::Uuid),
120 DateTime(chrono::DateTime<chrono::Utc>),
121 Date(chrono::NaiveDate),
122 Time(chrono::NaiveTime),
123 Json(JsonValue),
124 Array(Vec<DatabaseValue>),
125}
126
127impl DatabaseValue {
128 pub fn is_null(&self) -> bool {
130 matches!(self, DatabaseValue::Null)
131 }
132
133 pub fn to_json(&self) -> JsonValue {
135 match self {
136 DatabaseValue::Null => JsonValue::Null,
137 DatabaseValue::Bool(b) => JsonValue::Bool(*b),
138 DatabaseValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
139 DatabaseValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
140 DatabaseValue::Float32(f) => {
141 JsonValue::Number(serde_json::Number::from_f64(*f as f64).unwrap_or_else(|| serde_json::Number::from(0)))
142 },
143 DatabaseValue::Float64(f) => {
144 serde_json::Number::from_f64(*f)
145 .map(JsonValue::Number)
146 .unwrap_or(JsonValue::Null)
147 },
148 DatabaseValue::String(s) => JsonValue::String(s.clone()),
149 DatabaseValue::Bytes(b) => JsonValue::Array(b.iter().map(|&x| JsonValue::Number(serde_json::Number::from(x))).collect()),
150 DatabaseValue::Uuid(u) => JsonValue::String(u.to_string()),
151 DatabaseValue::DateTime(dt) => JsonValue::String(dt.to_rfc3339()),
152 DatabaseValue::Date(d) => JsonValue::String(d.to_string()),
153 DatabaseValue::Time(t) => JsonValue::String(t.to_string()),
154 DatabaseValue::Json(j) => j.clone(),
155 DatabaseValue::Array(arr) => JsonValue::Array(arr.iter().map(|v| v.to_json()).collect()),
156 }
157 }
158}
159
160impl From<bool> for DatabaseValue {
161 fn from(value: bool) -> Self {
162 DatabaseValue::Bool(value)
163 }
164}
165
166impl From<i32> for DatabaseValue {
167 fn from(value: i32) -> Self {
168 DatabaseValue::Int32(value)
169 }
170}
171
172impl From<i64> for DatabaseValue {
173 fn from(value: i64) -> Self {
174 DatabaseValue::Int64(value)
175 }
176}
177
178impl From<f32> for DatabaseValue {
179 fn from(value: f32) -> Self {
180 DatabaseValue::Float32(value)
181 }
182}
183
184impl From<f64> for DatabaseValue {
185 fn from(value: f64) -> Self {
186 DatabaseValue::Float64(value)
187 }
188}
189
190impl From<String> for DatabaseValue {
191 fn from(value: String) -> Self {
192 DatabaseValue::String(value)
193 }
194}
195
196impl From<&str> for DatabaseValue {
197 fn from(value: &str) -> Self {
198 DatabaseValue::String(value.to_string())
199 }
200}
201
202impl From<Vec<u8>> for DatabaseValue {
203 fn from(value: Vec<u8>) -> Self {
204 DatabaseValue::Bytes(value)
205 }
206}
207
208impl From<uuid::Uuid> for DatabaseValue {
209 fn from(value: uuid::Uuid) -> Self {
210 DatabaseValue::Uuid(value)
211 }
212}
213
214impl From<chrono::DateTime<chrono::Utc>> for DatabaseValue {
215 fn from(value: chrono::DateTime<chrono::Utc>) -> Self {
216 DatabaseValue::DateTime(value)
217 }
218}
219
220impl From<chrono::NaiveDate> for DatabaseValue {
221 fn from(value: chrono::NaiveDate) -> Self {
222 DatabaseValue::Date(value)
223 }
224}
225
226impl From<chrono::NaiveTime> for DatabaseValue {
227 fn from(value: chrono::NaiveTime) -> Self {
228 DatabaseValue::Time(value)
229 }
230}
231
232impl From<JsonValue> for DatabaseValue {
233 fn from(value: JsonValue) -> Self {
234 DatabaseValue::Json(value)
235 }
236}
237
238impl<T> From<Option<T>> for DatabaseValue
239where
240 T: Into<DatabaseValue>,
241{
242 fn from(value: Option<T>) -> Self {
243 match value {
244 Some(v) => v.into(),
245 None => DatabaseValue::Null,
246 }
247 }
248}
249
/// SQL dialects for which dialect-specific syntax fragments are provided.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlDialect {
    /// PostgreSQL dialect.
    PostgreSQL,
    /// MySQL dialect.
    MySQL,
    /// SQLite dialect.
    SQLite,
}

impl SqlDialect {
    /// Returns the bind-parameter placeholder for the zero-based `index`.
    ///
    /// PostgreSQL yields a numbered placeholder one higher than `index`
    /// (`$1` for index 0); MySQL and SQLite always yield `?`.
    pub fn parameter_placeholder(&self, index: usize) -> String {
        match self {
            Self::MySQL | Self::SQLite => String::from("?"),
            Self::PostgreSQL => {
                let ordinal = index + 1;
                format!("${}", ordinal)
            }
        }
    }

    /// Returns the identifier quote character: a backtick for MySQL and a
    /// double quote for PostgreSQL and SQLite.
    pub fn identifier_quote(&self) -> char {
        match self {
            Self::MySQL => '`',
            Self::PostgreSQL | Self::SQLite => '"',
        }
    }

    /// Reports `true` for PostgreSQL and SQLite, `false` for MySQL.
    pub fn supports_boolean(&self) -> bool {
        matches!(self, Self::PostgreSQL | Self::SQLite)
    }

    /// Reports `true` for PostgreSQL and MySQL, `false` for SQLite.
    pub fn supports_json(&self) -> bool {
        matches!(self, Self::PostgreSQL | Self::MySQL)
    }

    /// Returns the SQL expression this dialect uses for the current timestamp.
    pub fn current_timestamp(&self) -> &'static str {
        match self {
            Self::SQLite => "datetime('now')",
            Self::MySQL => "CURRENT_TIMESTAMP",
            Self::PostgreSQL => "NOW()",
        }
    }

    /// Returns the keyword this dialect uses for auto-incrementing columns.
    pub fn auto_increment(&self) -> &'static str {
        match self {
            Self::SQLite => "AUTOINCREMENT",
            Self::MySQL => "AUTO_INCREMENT",
            Self::PostgreSQL => "SERIAL",
        }
    }
}
310
311#[async_trait]
313pub trait DatabaseBackend: Send + Sync {
314 async fn create_pool(&self, database_url: &str, config: DatabasePoolConfig) -> OrmResult<Arc<dyn DatabasePool>>;
316
317 fn sql_dialect(&self) -> SqlDialect;
319
320 fn backend_type(&self) -> crate::backends::DatabaseBackendType;
322
323 fn validate_database_url(&self, url: &str) -> OrmResult<()>;
325
326 fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig>;
328}
329
/// Tunables handed to a backend when a connection pool is created.
#[derive(Clone, Debug)]
pub struct DatabasePoolConfig {
    /// Upper bound on pooled connections (default `10`).
    pub max_connections: u32,
    /// Lower bound on pooled connections (default `1`).
    pub min_connections: u32,
    /// Acquire timeout in seconds (default `30`).
    pub acquire_timeout_seconds: u64,
    /// Idle timeout in seconds, if any (default `Some(600)`).
    pub idle_timeout_seconds: Option<u64>,
    /// Maximum connection lifetime in seconds, if any (default `Some(1800)`).
    pub max_lifetime_seconds: Option<u64>,
    /// Flag named for testing a connection before it is acquired
    /// (default `true`).
    pub test_before_acquire: bool,
}

impl Default for DatabasePoolConfig {
    /// Builds the default configuration: 1 to 10 connections, a 30 s acquire
    /// timeout, a 600 s idle timeout, an 1800 s maximum lifetime, and
    /// `test_before_acquire` enabled.
    fn default() -> Self {
        DatabasePoolConfig {
            test_before_acquire: true,
            max_lifetime_seconds: Some(1800),
            idle_timeout_seconds: Some(600),
            acquire_timeout_seconds: 30,
            min_connections: 1,
            max_connections: 10,
        }
    }
}
353
/// Connection settings, produced by `DatabaseBackend::parse_database_url`.
///
/// `Debug` is implemented by hand so that formatting the configuration (for
/// example in logs or error messages) never prints the password.
#[derive(Clone)]
pub struct DatabaseConnectionConfig {
    /// Database server host.
    pub host: String,
    /// Database server port.
    pub port: u16,
    /// Database name.
    pub database: String,
    /// Optional user name.
    pub username: Option<String>,
    /// Optional password; redacted in the `Debug` output.
    pub password: Option<String>,
    /// Optional SSL mode string.
    pub ssl_mode: Option<String>,
    /// Any further key/value parameters.
    pub additional_params: HashMap<String, String>,
}

impl std::fmt::Debug for DatabaseConnectionConfig {
    /// Formats every field like a derived `Debug` would, except that a present
    /// password is shown as `Some("<redacted>")` instead of its real value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible without exposing the secret.
        let password = self.password.as_ref().map(|_| "<redacted>");
        f.debug_struct("DatabaseConnectionConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("username", &self.username)
            .field("password", &password)
            .field("ssl_mode", &self.ssl_mode)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
365
366pub struct DatabaseBackendRegistry {
368 backends: HashMap<crate::backends::DatabaseBackendType, Arc<dyn DatabaseBackend>>,
369}
370
371impl DatabaseBackendRegistry {
372 pub fn new() -> Self {
374 Self {
375 backends: HashMap::new(),
376 }
377 }
378
379 pub fn register(&mut self, backend_type: crate::backends::DatabaseBackendType, backend: Arc<dyn DatabaseBackend>) {
381 self.backends.insert(backend_type, backend);
382 }
383
384 pub fn get(&self, backend_type: &crate::backends::DatabaseBackendType) -> Option<Arc<dyn DatabaseBackend>> {
386 self.backends.get(backend_type).cloned()
387 }
388
389 pub async fn create_pool(&self, database_url: &str, config: DatabasePoolConfig) -> OrmResult<Arc<dyn DatabasePool>> {
391 let backend_type = self.detect_backend_from_url(database_url)?;
392 let backend = self.get(&backend_type)
393 .ok_or_else(|| OrmError::Connection(format!("No backend registered for {}", backend_type)))?;
394
395 backend.create_pool(database_url, config).await
396 }
397
398 fn detect_backend_from_url(&self, url: &str) -> OrmResult<crate::backends::DatabaseBackendType> {
400 if url.starts_with("postgresql://") || url.starts_with("postgres://") {
401 Ok(crate::backends::DatabaseBackendType::PostgreSQL)
402 } else if url.starts_with("mysql://") {
403 Ok(crate::backends::DatabaseBackendType::MySQL)
404 } else if url.starts_with("sqlite://") || url.starts_with("file:") {
405 Ok(crate::backends::DatabaseBackendType::SQLite)
406 } else {
407 Err(OrmError::Connection(format!("Unable to detect database backend from URL: {}", url)))
408 }
409 }
410
411 pub fn registered_backends(&self) -> Vec<crate::backends::DatabaseBackendType> {
413 self.backends.keys().cloned().collect()
414 }
415}
416
417impl Default for DatabaseBackendRegistry {
418 fn default() -> Self {
419 Self::new()
420 }
421}