1use super::core::*;
7use crate::error::{OrmError, OrmResult};
8use async_trait::async_trait;
9use serde_json::Value as JsonValue;
10use sqlx::{postgres::PgPoolOptions, Column, Pool, Postgres, Row as SqlxRow};
11use std::collections::HashMap;
12use std::sync::Arc;
13
/// PostgreSQL implementation of the ORM's database backend abstraction.
///
/// The type carries no state; all connection details are supplied to the
/// trait methods at call time.
#[derive(Debug)]
pub struct PostgresBackend;

impl PostgresBackend {
    /// Creates a new (stateless) PostgreSQL backend handle.
    pub fn new() -> Self {
        PostgresBackend
    }
}
24
25#[async_trait]
26impl DatabaseBackend for PostgresBackend {
27 async fn create_pool(
28 &self,
29 database_url: &str,
30 config: DatabasePoolConfig,
31 ) -> OrmResult<Arc<dyn DatabasePool>> {
32 let mut options = PgPoolOptions::new()
33 .max_connections(config.max_connections)
34 .min_connections(config.min_connections)
35 .acquire_timeout(std::time::Duration::from_secs(
36 config.acquire_timeout_seconds,
37 ))
38 .test_before_acquire(config.test_before_acquire);
39
40 if let Some(idle_timeout) = config.idle_timeout_seconds {
41 options = options.idle_timeout(std::time::Duration::from_secs(idle_timeout));
42 }
43
44 if let Some(max_lifetime) = config.max_lifetime_seconds {
45 options = options.max_lifetime(std::time::Duration::from_secs(max_lifetime));
46 }
47
48 let sqlx_pool = options.connect(database_url).await.map_err(|e| {
49 OrmError::Connection(format!("Failed to create PostgreSQL pool: {}", e))
50 })?;
51
52 Ok(Arc::new(PostgresPool::new(Arc::new(sqlx_pool))))
53 }
54
55 fn sql_dialect(&self) -> SqlDialect {
56 SqlDialect::PostgreSQL
57 }
58
59 fn backend_type(&self) -> crate::backends::DatabaseBackendType {
60 crate::backends::DatabaseBackendType::PostgreSQL
61 }
62
63 fn validate_database_url(&self, url: &str) -> OrmResult<()> {
64 if !url.starts_with("postgresql://") && !url.starts_with("postgres://") {
65 return Err(OrmError::Connection(
66 "Invalid PostgreSQL URL scheme".to_string(),
67 ));
68 }
69 Ok(())
70 }
71
72 fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig> {
73 let parsed = url::Url::parse(url)
75 .map_err(|e| OrmError::Connection(format!("Invalid database URL: {}", e)))?;
76
77 let host = parsed
78 .host_str()
79 .ok_or_else(|| OrmError::Connection("Missing host in database URL".to_string()))?
80 .to_string();
81
82 let port = parsed.port().unwrap_or(5432);
83
84 let database = parsed.path().trim_start_matches('/').to_string();
85 if database.is_empty() {
86 return Err(OrmError::Connection(
87 "Missing database name in URL".to_string(),
88 ));
89 }
90
91 let username = if parsed.username().is_empty() {
92 None
93 } else {
94 Some(parsed.username().to_string())
95 };
96
97 let password = parsed.password().map(|p| p.to_string());
98
99 let mut additional_params = HashMap::new();
100 for (key, value) in parsed.query_pairs() {
101 additional_params.insert(key.to_string(), value.to_string());
102 }
103
104 let ssl_mode = additional_params.get("sslmode").cloned();
105
106 Ok(DatabaseConnectionConfig {
107 host,
108 port,
109 database,
110 username,
111 password,
112 ssl_mode,
113 additional_params,
114 })
115 }
116}
117
118pub struct PostgresPool {
120 pool: Arc<Pool<Postgres>>,
121}
122
123impl PostgresPool {
124 pub fn new(pool: Arc<Pool<Postgres>>) -> Self {
125 Self { pool }
126 }
127}
128
129#[async_trait]
130impl DatabasePool for PostgresPool {
131 async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>> {
132 let conn =
133 self.pool.acquire().await.map_err(|e| {
134 OrmError::Connection(format!("Failed to acquire connection: {}", e))
135 })?;
136
137 Ok(Box::new(PostgresConnection::new(conn)))
138 }
139
140 async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>> {
141 Err(OrmError::Query(
143 "Transaction support not yet fully implemented in abstraction layer".to_string(),
144 ))
145 }
146
147 async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
148 let mut query = sqlx::query(sql);
149
150 for param in params {
151 query = bind_database_value(query, param)?;
152 }
153
154 let result = query
155 .execute(&*self.pool)
156 .await
157 .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
158
159 Ok(result.rows_affected())
160 }
161
162 async fn fetch_all(
163 &self,
164 sql: &str,
165 params: &[DatabaseValue],
166 ) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
167 let mut query = sqlx::query(sql);
168
169 for param in params {
170 query = bind_database_value(query, param)?;
171 }
172
173 let rows = query
174 .fetch_all(&*self.pool)
175 .await
176 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
177
178 Ok(rows
179 .into_iter()
180 .map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>)
181 .collect())
182 }
183
184 async fn fetch_optional(
185 &self,
186 sql: &str,
187 params: &[DatabaseValue],
188 ) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
189 let mut query = sqlx::query(sql);
190
191 for param in params {
192 query = bind_database_value(query, param)?;
193 }
194
195 let row = query
196 .fetch_optional(&*self.pool)
197 .await
198 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
199
200 Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
201 }
202
203 async fn close(&self) -> OrmResult<()> {
204 self.pool.close().await;
205 Ok(())
206 }
207
208 fn stats(&self) -> DatabasePoolStats {
209 let total = self.pool.size();
210 let idle = self.pool.num_idle() as u32;
211 let active = total.saturating_sub(idle);
212
213 DatabasePoolStats {
214 total_connections: total,
215 idle_connections: idle,
216 active_connections: active,
217 }
218 }
219
220 async fn health_check(&self) -> OrmResult<std::time::Duration> {
221 let start = std::time::Instant::now();
222
223 sqlx::query("SELECT 1")
224 .execute(&*self.pool)
225 .await
226 .map_err(|e| OrmError::Connection(format!("Health check failed: {}", e)))?;
227
228 Ok(start.elapsed())
229 }
230}
231
232pub struct PostgresConnection {
234 conn: sqlx::pool::PoolConnection<Postgres>,
235}
236
237impl PostgresConnection {
238 pub fn new(conn: sqlx::pool::PoolConnection<Postgres>) -> Self {
239 Self { conn }
240 }
241}
242
243#[async_trait]
244impl DatabaseConnection for PostgresConnection {
245 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
246 let mut query = sqlx::query(sql);
247
248 for param in params {
249 query = bind_database_value(query, param)?;
250 }
251
252 let result = query
253 .execute(&mut *self.conn)
254 .await
255 .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
256
257 Ok(result.rows_affected())
258 }
259
260 async fn fetch_all(
261 &mut self,
262 sql: &str,
263 params: &[DatabaseValue],
264 ) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
265 let mut query = sqlx::query(sql);
266
267 for param in params {
268 query = bind_database_value(query, param)?;
269 }
270
271 let rows = query
272 .fetch_all(&mut *self.conn)
273 .await
274 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
275
276 Ok(rows
277 .into_iter()
278 .map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>)
279 .collect())
280 }
281
282 async fn fetch_optional(
283 &mut self,
284 sql: &str,
285 params: &[DatabaseValue],
286 ) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
287 let mut query = sqlx::query(sql);
288
289 for param in params {
290 query = bind_database_value(query, param)?;
291 }
292
293 let row = query
294 .fetch_optional(&mut *self.conn)
295 .await
296 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
297
298 Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
299 }
300
301 async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>> {
302 Err(OrmError::Query(
304 "Transaction support not yet fully implemented in abstraction layer".to_string(),
305 ))
306 }
307
308 async fn close(&mut self) -> OrmResult<()> {
309 Ok(())
311 }
312}
313
314pub struct PostgresTransaction<'c> {
316 tx: Option<sqlx::Transaction<'c, Postgres>>,
317}
318
319impl<'c> PostgresTransaction<'c> {
320 pub fn new(tx: sqlx::Transaction<'c, Postgres>) -> Self {
321 Self { tx: Some(tx) }
322 }
323}
324
325#[async_trait]
326impl<'c> DatabaseTransaction for PostgresTransaction<'c> {
327 async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
328 let tx = self
329 .tx
330 .as_mut()
331 .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
332
333 let mut query = sqlx::query(sql);
334
335 for param in params {
336 query = bind_database_value(query, param)?;
337 }
338
339 let result = query
340 .execute(&mut **tx)
341 .await
342 .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
343
344 Ok(result.rows_affected())
345 }
346
347 async fn fetch_all(
348 &mut self,
349 sql: &str,
350 params: &[DatabaseValue],
351 ) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
352 let tx = self
353 .tx
354 .as_mut()
355 .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
356
357 let mut query = sqlx::query(sql);
358
359 for param in params {
360 query = bind_database_value(query, param)?;
361 }
362
363 let rows = query
364 .fetch_all(&mut **tx)
365 .await
366 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
367
368 Ok(rows
369 .into_iter()
370 .map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>)
371 .collect())
372 }
373
374 async fn fetch_optional(
375 &mut self,
376 sql: &str,
377 params: &[DatabaseValue],
378 ) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
379 let tx = self
380 .tx
381 .as_mut()
382 .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
383
384 let mut query = sqlx::query(sql);
385
386 for param in params {
387 query = bind_database_value(query, param)?;
388 }
389
390 let row = query
391 .fetch_optional(&mut **tx)
392 .await
393 .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
394
395 Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
396 }
397
398 async fn commit(mut self: Box<Self>) -> OrmResult<()> {
399 let tx = self
400 .tx
401 .take()
402 .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
403
404 tx.commit()
405 .await
406 .map_err(|e| OrmError::Query(format!("Transaction commit failed: {}", e)))?;
407
408 Ok(())
409 }
410
411 async fn rollback(mut self: Box<Self>) -> OrmResult<()> {
412 let tx = self
413 .tx
414 .take()
415 .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
416
417 tx.rollback()
418 .await
419 .map_err(|e| OrmError::Query(format!("Transaction rollback failed: {}", e)))?;
420
421 Ok(())
422 }
423}
424
425pub struct PostgresRow {
427 row: sqlx::postgres::PgRow,
428}
429
430impl PostgresRow {
431 pub fn new(row: sqlx::postgres::PgRow) -> Self {
432 Self { row }
433 }
434}
435
436impl DatabaseRow for PostgresRow {
437 fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue> {
438 postgres_value_to_database_value(&self.row, index)
439 }
440
441 fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue> {
442 let columns = self.row.columns();
444 let index = columns
445 .iter()
446 .position(|col| col.name() == name)
447 .ok_or_else(|| OrmError::Query(format!("Column '{}' not found", name)))?;
448
449 postgres_value_to_database_value(&self.row, index)
450 }
451
452 fn column_count(&self) -> usize {
453 self.row.len()
454 }
455
456 fn column_names(&self) -> Vec<String> {
457 self.row
458 .columns()
459 .iter()
460 .map(|col| col.name().to_string())
461 .collect()
462 }
463
464 fn to_json(&self) -> OrmResult<JsonValue> {
465 let mut map = serde_json::Map::new();
466
467 for (i, column) in self.row.columns().iter().enumerate() {
468 let value = self.get_by_index(i)?;
469 map.insert(column.name().to_string(), value.to_json());
470 }
471
472 Ok(JsonValue::Object(map))
473 }
474
475 fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>> {
476 let mut map = HashMap::new();
477
478 for (i, column) in self.row.columns().iter().enumerate() {
479 let value = self.get_by_index(i)?;
480 map.insert(column.name().to_string(), value);
481 }
482
483 Ok(map)
484 }
485}
486
487fn bind_database_value<'a>(
489 query: sqlx::query::Query<'a, Postgres, sqlx::postgres::PgArguments>,
490 value: &DatabaseValue,
491) -> OrmResult<sqlx::query::Query<'a, Postgres, sqlx::postgres::PgArguments>> {
492 match value {
493 DatabaseValue::Null => Ok(query.bind(Option::<String>::None)),
494 DatabaseValue::Bool(b) => Ok(query.bind(*b)),
495 DatabaseValue::Int32(i) => Ok(query.bind(*i)),
496 DatabaseValue::Int64(i) => Ok(query.bind(*i)),
497 DatabaseValue::Float32(f) => Ok(query.bind(*f)),
498 DatabaseValue::Float64(f) => Ok(query.bind(*f)),
499 DatabaseValue::String(s) => Ok(query.bind(s.clone())),
500 DatabaseValue::Bytes(b) => Ok(query.bind(b.clone())),
501 DatabaseValue::Uuid(u) => Ok(query.bind(*u)),
502 DatabaseValue::DateTime(dt) => Ok(query.bind(*dt)),
503 DatabaseValue::Date(d) => Ok(query.bind(*d)),
504 DatabaseValue::Time(t) => Ok(query.bind(*t)),
505 DatabaseValue::Json(j) => Ok(query.bind(j.clone())),
506 DatabaseValue::Array(_) => Err(OrmError::Query(
507 "Array binding not yet implemented for PostgreSQL".to_string(),
508 )),
509 }
510}
511
512fn postgres_value_to_database_value(
514 row: &sqlx::postgres::PgRow,
515 index: usize,
516) -> OrmResult<DatabaseValue> {
517 use sqlx::{Column, Row, TypeInfo};
518
519 let column = &row.columns()[index];
520 let type_name = column.type_info().name();
521
522 if let Ok(Some(value)) = row.try_get::<Option<String>, _>(index) {
524 return Ok(DatabaseValue::String(value));
525 }
526
527 match type_name {
528 "BOOL" => {
529 let value: bool = row
530 .try_get(index)
531 .map_err(|e| OrmError::Query(format!("Failed to get bool value: {}", e)))?;
532 Ok(DatabaseValue::Bool(value))
533 }
534 "INT2" => {
535 let value: i16 = row
536 .try_get(index)
537 .map_err(|e| OrmError::Query(format!("Failed to get int16 value: {}", e)))?;
538 Ok(DatabaseValue::Int32(value as i32))
539 }
540 "INT4" => {
541 let value: i32 = row
542 .try_get(index)
543 .map_err(|e| OrmError::Query(format!("Failed to get int32 value: {}", e)))?;
544 Ok(DatabaseValue::Int32(value))
545 }
546 "INT8" => {
547 let value: i64 = row
548 .try_get(index)
549 .map_err(|e| OrmError::Query(format!("Failed to get int64 value: {}", e)))?;
550 Ok(DatabaseValue::Int64(value))
551 }
552 "FLOAT4" => {
553 let value: f32 = row
554 .try_get(index)
555 .map_err(|e| OrmError::Query(format!("Failed to get float32 value: {}", e)))?;
556 Ok(DatabaseValue::Float32(value))
557 }
558 "FLOAT8" => {
559 let value: f64 = row
560 .try_get(index)
561 .map_err(|e| OrmError::Query(format!("Failed to get float64 value: {}", e)))?;
562 Ok(DatabaseValue::Float64(value))
563 }
564 "TEXT" | "VARCHAR" => {
565 let value: String = row
566 .try_get(index)
567 .map_err(|e| OrmError::Query(format!("Failed to get string value: {}", e)))?;
568 Ok(DatabaseValue::String(value))
569 }
570 "BYTEA" => {
571 let value: Vec<u8> = row
572 .try_get(index)
573 .map_err(|e| OrmError::Query(format!("Failed to get bytes value: {}", e)))?;
574 Ok(DatabaseValue::Bytes(value))
575 }
576 "UUID" => {
577 let value: uuid::Uuid = row
578 .try_get(index)
579 .map_err(|e| OrmError::Query(format!("Failed to get UUID value: {}", e)))?;
580 Ok(DatabaseValue::Uuid(value))
581 }
582 "TIMESTAMPTZ" | "TIMESTAMP" => {
583 let value: chrono::DateTime<chrono::Utc> = row
584 .try_get(index)
585 .map_err(|e| OrmError::Query(format!("Failed to get datetime value: {}", e)))?;
586 Ok(DatabaseValue::DateTime(value))
587 }
588 "DATE" => {
589 let value: chrono::NaiveDate = row
590 .try_get(index)
591 .map_err(|e| OrmError::Query(format!("Failed to get date value: {}", e)))?;
592 Ok(DatabaseValue::Date(value))
593 }
594 "TIME" => {
595 let value: chrono::NaiveTime = row
596 .try_get(index)
597 .map_err(|e| OrmError::Query(format!("Failed to get time value: {}", e)))?;
598 Ok(DatabaseValue::Time(value))
599 }
600 "JSON" | "JSONB" => {
601 let value: JsonValue = row
602 .try_get(index)
603 .map_err(|e| OrmError::Query(format!("Failed to get JSON value: {}", e)))?;
604 Ok(DatabaseValue::Json(value))
605 }
606 _ => {
607 let value: String = row.try_get(index).map_err(|e| {
609 OrmError::Query(format!(
610 "Failed to get value as string for unknown type '{}': {}",
611 type_name, e
612 ))
613 })?;
614 Ok(DatabaseValue::String(value))
615 }
616 }
617}
618
619impl Default for PostgresBackend {
620 fn default() -> Self {
621 Self::new()
622 }
623}