1use std::collections::HashMap;
7use std::sync::Arc;
8use async_trait::async_trait;
9use sqlx::{Pool, Postgres, Row as SqlxRow, postgres::PgPoolOptions, Column};
10use serde_json::Value as JsonValue;
11use crate::error::{OrmResult, OrmError};
12use super::core::*;
13
/// PostgreSQL implementation of the [`DatabaseBackend`] abstraction.
///
/// The backend itself is stateless; all connection state lives in the pools
/// it creates.
#[derive(Debug)]
pub struct PostgresBackend;

impl PostgresBackend {
    /// Returns a new (stateless) PostgreSQL backend handle.
    pub fn new() -> Self {
        PostgresBackend
    }
}
24
25#[async_trait]
26impl DatabaseBackend for PostgresBackend {
27 async fn create_pool(&self, database_url: &str, config: DatabasePoolConfig) -> OrmResult<Arc<dyn DatabasePool>> {
28 let mut options = PgPoolOptions::new()
29 .max_connections(config.max_connections)
30 .min_connections(config.min_connections)
31 .acquire_timeout(std::time::Duration::from_secs(config.acquire_timeout_seconds))
32 .test_before_acquire(config.test_before_acquire);
33
34 if let Some(idle_timeout) = config.idle_timeout_seconds {
35 options = options.idle_timeout(std::time::Duration::from_secs(idle_timeout));
36 }
37
38 if let Some(max_lifetime) = config.max_lifetime_seconds {
39 options = options.max_lifetime(std::time::Duration::from_secs(max_lifetime));
40 }
41
42 let sqlx_pool = options.connect(database_url)
43 .await
44 .map_err(|e| OrmError::Connection(format!("Failed to create PostgreSQL pool: {}", e)))?;
45
46 Ok(Arc::new(PostgresPool::new(Arc::new(sqlx_pool))))
47 }
48
49 fn sql_dialect(&self) -> SqlDialect {
50 SqlDialect::PostgreSQL
51 }
52
53 fn backend_type(&self) -> crate::backends::DatabaseBackendType {
54 crate::backends::DatabaseBackendType::PostgreSQL
55 }
56
57 fn validate_database_url(&self, url: &str) -> OrmResult<()> {
58 if !url.starts_with("postgresql://") && !url.starts_with("postgres://") {
59 return Err(OrmError::Connection("Invalid PostgreSQL URL scheme".to_string()));
60 }
61 Ok(())
62 }
63
64 fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig> {
65 let parsed = url::Url::parse(url)
67 .map_err(|e| OrmError::Connection(format!("Invalid database URL: {}", e)))?;
68
69 let host = parsed.host_str()
70 .ok_or_else(|| OrmError::Connection("Missing host in database URL".to_string()))?
71 .to_string();
72
73 let port = parsed.port().unwrap_or(5432);
74
75 let database = parsed.path().trim_start_matches('/').to_string();
76 if database.is_empty() {
77 return Err(OrmError::Connection("Missing database name in URL".to_string()));
78 }
79
80 let username = if parsed.username().is_empty() {
81 None
82 } else {
83 Some(parsed.username().to_string())
84 };
85
86 let password = parsed.password().map(|p| p.to_string());
87
88 let mut additional_params = HashMap::new();
89 for (key, value) in parsed.query_pairs() {
90 additional_params.insert(key.to_string(), value.to_string());
91 }
92
93 let ssl_mode = additional_params.get("sslmode").cloned();
94
95 Ok(DatabaseConnectionConfig {
96 host,
97 port,
98 database,
99 username,
100 password,
101 ssl_mode,
102 additional_params,
103 })
104 }
105}
106
107pub struct PostgresPool {
109 pool: Arc<Pool<Postgres>>,
110}
111
112impl PostgresPool {
113 pub fn new(pool: Arc<Pool<Postgres>>) -> Self {
114 Self { pool }
115 }
116}
117
118#[async_trait]
119impl DatabasePool for PostgresPool {
120 async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>> {
121 let conn = self.pool.acquire()
122 .await
123 .map_err(|e| OrmError::Connection(format!("Failed to acquire connection: {}", e)))?;
124
125 Ok(Box::new(PostgresConnection::new(conn)))
126 }
127
128 async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>> {
129 Err(OrmError::Query("Transaction support not yet fully implemented in abstraction layer".to_string()))
131 }
132
133 async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
134 let mut query = sqlx::query(sql);
135
136 for param in params {
137 query = bind_database_value(query, param)?;
138 }
139
140 let result = query.execute(&*self.pool)
141 .await
142 .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
143
144 Ok(result.rows_affected())
145 }
146
147 async fn fetch_all(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
148 let mut query = sqlx::query(sql);
149
150 for param in params {
151 query = bind_database_value(query, param)?;
152 }
153
154 let rows = query.fetch_all(&*self.pool)
155 .await
156 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
157
158 Ok(rows.into_iter().map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>).collect())
159 }
160
161 async fn fetch_optional(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
162 let mut query = sqlx::query(sql);
163
164 for param in params {
165 query = bind_database_value(query, param)?;
166 }
167
168 let row = query.fetch_optional(&*self.pool)
169 .await
170 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
171
172 Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
173 }
174
175 async fn close(&self) -> OrmResult<()> {
176 self.pool.close().await;
177 Ok(())
178 }
179
180 fn stats(&self) -> DatabasePoolStats {
181 let total = self.pool.size() as u32;
182 let idle = self.pool.num_idle() as u32;
183 let active = total.saturating_sub(idle);
184
185 DatabasePoolStats {
186 total_connections: total,
187 idle_connections: idle,
188 active_connections: active,
189 }
190 }
191
192 async fn health_check(&self) -> OrmResult<std::time::Duration> {
193 let start = std::time::Instant::now();
194
195 sqlx::query("SELECT 1")
196 .execute(&*self.pool)
197 .await
198 .map_err(|e| OrmError::Connection(format!("Health check failed: {}", e)))?;
199
200 Ok(start.elapsed())
201 }
202}
203
204pub struct PostgresConnection {
206 conn: sqlx::pool::PoolConnection<Postgres>,
207}
208
209impl PostgresConnection {
210 pub fn new(conn: sqlx::pool::PoolConnection<Postgres>) -> Self {
211 Self { conn }
212 }
213}
214
215#[async_trait]
216impl DatabaseConnection for PostgresConnection {
217 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
218 let mut query = sqlx::query(sql);
219
220 for param in params {
221 query = bind_database_value(query, param)?;
222 }
223
224 let result = query.execute(&mut *self.conn)
225 .await
226 .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
227
228 Ok(result.rows_affected())
229 }
230
231 async fn fetch_all(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
232 let mut query = sqlx::query(sql);
233
234 for param in params {
235 query = bind_database_value(query, param)?;
236 }
237
238 let rows = query.fetch_all(&mut *self.conn)
239 .await
240 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
241
242 Ok(rows.into_iter().map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>).collect())
243 }
244
245 async fn fetch_optional(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
246 let mut query = sqlx::query(sql);
247
248 for param in params {
249 query = bind_database_value(query, param)?;
250 }
251
252 let row = query.fetch_optional(&mut *self.conn)
253 .await
254 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
255
256 Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
257 }
258
259 async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>> {
260 Err(OrmError::Query("Transaction support not yet fully implemented in abstraction layer".to_string()))
262 }
263
264 async fn close(&mut self) -> OrmResult<()> {
265 Ok(())
267 }
268}
269
270pub struct PostgresTransaction<'c> {
272 tx: Option<sqlx::Transaction<'c, Postgres>>,
273}
274
275impl<'c> PostgresTransaction<'c> {
276 pub fn new(tx: sqlx::Transaction<'c, Postgres>) -> Self {
277 Self { tx: Some(tx) }
278 }
279}
280
281#[async_trait]
282impl<'c> DatabaseTransaction for PostgresTransaction<'c> {
283 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
284 let tx = self.tx.as_mut().ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
285
286 let mut query = sqlx::query(sql);
287
288 for param in params {
289 query = bind_database_value(query, param)?;
290 }
291
292 let result = query.execute(&mut **tx)
293 .await
294 .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
295
296 Ok(result.rows_affected())
297 }
298
299 async fn fetch_all(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
300 let tx = self.tx.as_mut().ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
301
302 let mut query = sqlx::query(sql);
303
304 for param in params {
305 query = bind_database_value(query, param)?;
306 }
307
308 let rows = query.fetch_all(&mut **tx)
309 .await
310 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
311
312 Ok(rows.into_iter().map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>).collect())
313 }
314
315 async fn fetch_optional(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
316 let tx = self.tx.as_mut().ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
317
318 let mut query = sqlx::query(sql);
319
320 for param in params {
321 query = bind_database_value(query, param)?;
322 }
323
324 let row = query.fetch_optional(&mut **tx)
325 .await
326 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
327
328 Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
329 }
330
331 async fn commit(mut self: Box<Self>) -> OrmResult<()> {
332 let tx = self.tx.take().ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
333
334 tx.commit()
335 .await
336 .map_err(|e| OrmError::Query(format!("Transaction commit failed: {}", e)))?;
337
338 Ok(())
339 }
340
341 async fn rollback(mut self: Box<Self>) -> OrmResult<()> {
342 let tx = self.tx.take().ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
343
344 tx.rollback()
345 .await
346 .map_err(|e| OrmError::Query(format!("Transaction rollback failed: {}", e)))?;
347
348 Ok(())
349 }
350}
351
352pub struct PostgresRow {
354 row: sqlx::postgres::PgRow,
355}
356
357impl PostgresRow {
358 pub fn new(row: sqlx::postgres::PgRow) -> Self {
359 Self { row }
360 }
361}
362
363impl DatabaseRow for PostgresRow {
364 fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue> {
365 postgres_value_to_database_value(&self.row, index)
366 }
367
368 fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue> {
369 let columns = self.row.columns();
371 let index = columns.iter().position(|col| col.name() == name)
372 .ok_or_else(|| OrmError::Query(format!("Column '{}' not found", name)))?;
373
374 postgres_value_to_database_value(&self.row, index)
375 }
376
377 fn column_count(&self) -> usize {
378 self.row.len()
379 }
380
381 fn column_names(&self) -> Vec<String> {
382 self.row.columns().iter().map(|col| col.name().to_string()).collect()
383 }
384
385 fn to_json(&self) -> OrmResult<JsonValue> {
386 let mut map = serde_json::Map::new();
387
388 for (i, column) in self.row.columns().iter().enumerate() {
389 let value = self.get_by_index(i)?;
390 map.insert(column.name().to_string(), value.to_json());
391 }
392
393 Ok(JsonValue::Object(map))
394 }
395
396 fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>> {
397 let mut map = HashMap::new();
398
399 for (i, column) in self.row.columns().iter().enumerate() {
400 let value = self.get_by_index(i)?;
401 map.insert(column.name().to_string(), value);
402 }
403
404 Ok(map)
405 }
406}
407
408fn bind_database_value<'a>(
410 query: sqlx::query::Query<'a, Postgres, sqlx::postgres::PgArguments>,
411 value: &DatabaseValue
412) -> OrmResult<sqlx::query::Query<'a, Postgres, sqlx::postgres::PgArguments>> {
413 match value {
414 DatabaseValue::Null => Ok(query.bind(Option::<String>::None)),
415 DatabaseValue::Bool(b) => Ok(query.bind(*b)),
416 DatabaseValue::Int32(i) => Ok(query.bind(*i)),
417 DatabaseValue::Int64(i) => Ok(query.bind(*i)),
418 DatabaseValue::Float32(f) => Ok(query.bind(*f)),
419 DatabaseValue::Float64(f) => Ok(query.bind(*f)),
420 DatabaseValue::String(s) => Ok(query.bind(s.clone())),
421 DatabaseValue::Bytes(b) => Ok(query.bind(b.clone())),
422 DatabaseValue::Uuid(u) => Ok(query.bind(*u)),
423 DatabaseValue::DateTime(dt) => Ok(query.bind(*dt)),
424 DatabaseValue::Date(d) => Ok(query.bind(*d)),
425 DatabaseValue::Time(t) => Ok(query.bind(*t)),
426 DatabaseValue::Json(j) => Ok(query.bind(j.clone())),
427 DatabaseValue::Array(_) => Err(OrmError::Query("Array binding not yet implemented for PostgreSQL".to_string())),
428 }
429}
430
431fn postgres_value_to_database_value(row: &sqlx::postgres::PgRow, index: usize) -> OrmResult<DatabaseValue> {
433 use sqlx::{Row, Column, TypeInfo};
434
435 let column = &row.columns()[index];
436 let type_name = column.type_info().name();
437
438 if let Ok(Some(value)) = row.try_get::<Option<String>, _>(index) {
440 return Ok(DatabaseValue::String(value));
441 }
442
443 match type_name {
444 "BOOL" => {
445 let value: bool = row.try_get(index)
446 .map_err(|e| OrmError::Query(format!("Failed to get bool value: {}", e)))?;
447 Ok(DatabaseValue::Bool(value))
448 },
449 "INT2" => {
450 let value: i16 = row.try_get(index)
451 .map_err(|e| OrmError::Query(format!("Failed to get int16 value: {}", e)))?;
452 Ok(DatabaseValue::Int32(value as i32))
453 },
454 "INT4" => {
455 let value: i32 = row.try_get(index)
456 .map_err(|e| OrmError::Query(format!("Failed to get int32 value: {}", e)))?;
457 Ok(DatabaseValue::Int32(value))
458 },
459 "INT8" => {
460 let value: i64 = row.try_get(index)
461 .map_err(|e| OrmError::Query(format!("Failed to get int64 value: {}", e)))?;
462 Ok(DatabaseValue::Int64(value))
463 },
464 "FLOAT4" => {
465 let value: f32 = row.try_get(index)
466 .map_err(|e| OrmError::Query(format!("Failed to get float32 value: {}", e)))?;
467 Ok(DatabaseValue::Float32(value))
468 },
469 "FLOAT8" => {
470 let value: f64 = row.try_get(index)
471 .map_err(|e| OrmError::Query(format!("Failed to get float64 value: {}", e)))?;
472 Ok(DatabaseValue::Float64(value))
473 },
474 "TEXT" | "VARCHAR" => {
475 let value: String = row.try_get(index)
476 .map_err(|e| OrmError::Query(format!("Failed to get string value: {}", e)))?;
477 Ok(DatabaseValue::String(value))
478 },
479 "BYTEA" => {
480 let value: Vec<u8> = row.try_get(index)
481 .map_err(|e| OrmError::Query(format!("Failed to get bytes value: {}", e)))?;
482 Ok(DatabaseValue::Bytes(value))
483 },
484 "UUID" => {
485 let value: uuid::Uuid = row.try_get(index)
486 .map_err(|e| OrmError::Query(format!("Failed to get UUID value: {}", e)))?;
487 Ok(DatabaseValue::Uuid(value))
488 },
489 "TIMESTAMPTZ" | "TIMESTAMP" => {
490 let value: chrono::DateTime<chrono::Utc> = row.try_get(index)
491 .map_err(|e| OrmError::Query(format!("Failed to get datetime value: {}", e)))?;
492 Ok(DatabaseValue::DateTime(value))
493 },
494 "DATE" => {
495 let value: chrono::NaiveDate = row.try_get(index)
496 .map_err(|e| OrmError::Query(format!("Failed to get date value: {}", e)))?;
497 Ok(DatabaseValue::Date(value))
498 },
499 "TIME" => {
500 let value: chrono::NaiveTime = row.try_get(index)
501 .map_err(|e| OrmError::Query(format!("Failed to get time value: {}", e)))?;
502 Ok(DatabaseValue::Time(value))
503 },
504 "JSON" | "JSONB" => {
505 let value: JsonValue = row.try_get(index)
506 .map_err(|e| OrmError::Query(format!("Failed to get JSON value: {}", e)))?;
507 Ok(DatabaseValue::Json(value))
508 },
509 _ => {
510 let value: String = row.try_get(index)
512 .map_err(|e| OrmError::Query(format!("Failed to get value as string for unknown type '{}': {}", type_name, e)))?;
513 Ok(DatabaseValue::String(value))
514 }
515 }
516}
517
518impl Default for PostgresBackend {
519 fn default() -> Self {
520 Self::new()
521 }
522}