elif_orm/backends/
postgres.rs

//! PostgreSQL Backend Implementation
//!
//! This module provides the PostgreSQL-specific implementation of the database
//! backend traits, using sqlx as the underlying database driver.
5
6use super::core::*;
7use crate::error::{OrmError, OrmResult};
8use async_trait::async_trait;
9use serde_json::Value as JsonValue;
10use sqlx::{postgres::PgPoolOptions, Column, Pool, Postgres, Row as SqlxRow};
11use std::collections::HashMap;
12use std::sync::Arc;
13
/// Entry point for the PostgreSQL backend.
///
/// This is a unit struct: it holds no fields, so every instance is
/// interchangeable with any other.
#[derive(Debug)]
pub struct PostgresBackend;

impl PostgresBackend {
    /// Build a new PostgreSQL backend value.
    pub fn new() -> Self {
        PostgresBackend
    }
}
24
25#[async_trait]
26impl DatabaseBackend for PostgresBackend {
27    async fn create_pool(
28        &self,
29        database_url: &str,
30        config: DatabasePoolConfig,
31    ) -> OrmResult<Arc<dyn DatabasePool>> {
32        let mut options = PgPoolOptions::new()
33            .max_connections(config.max_connections)
34            .min_connections(config.min_connections)
35            .acquire_timeout(std::time::Duration::from_secs(
36                config.acquire_timeout_seconds,
37            ))
38            .test_before_acquire(config.test_before_acquire);
39
40        if let Some(idle_timeout) = config.idle_timeout_seconds {
41            options = options.idle_timeout(std::time::Duration::from_secs(idle_timeout));
42        }
43
44        if let Some(max_lifetime) = config.max_lifetime_seconds {
45            options = options.max_lifetime(std::time::Duration::from_secs(max_lifetime));
46        }
47
48        let sqlx_pool = options.connect(database_url).await.map_err(|e| {
49            OrmError::Connection(format!("Failed to create PostgreSQL pool: {}", e))
50        })?;
51
52        Ok(Arc::new(PostgresPool::new(Arc::new(sqlx_pool))))
53    }
54
55    fn sql_dialect(&self) -> SqlDialect {
56        SqlDialect::PostgreSQL
57    }
58
59    fn backend_type(&self) -> crate::backends::DatabaseBackendType {
60        crate::backends::DatabaseBackendType::PostgreSQL
61    }
62
63    fn validate_database_url(&self, url: &str) -> OrmResult<()> {
64        if !url.starts_with("postgresql://") && !url.starts_with("postgres://") {
65            return Err(OrmError::Connection(
66                "Invalid PostgreSQL URL scheme".to_string(),
67            ));
68        }
69        Ok(())
70    }
71
72    fn parse_database_url(&self, url: &str) -> OrmResult<DatabaseConnectionConfig> {
73        // Basic URL parsing for PostgreSQL
74        let parsed = url::Url::parse(url)
75            .map_err(|e| OrmError::Connection(format!("Invalid database URL: {}", e)))?;
76
77        let host = parsed
78            .host_str()
79            .ok_or_else(|| OrmError::Connection("Missing host in database URL".to_string()))?
80            .to_string();
81
82        let port = parsed.port().unwrap_or(5432);
83
84        let database = parsed.path().trim_start_matches('/').to_string();
85        if database.is_empty() {
86            return Err(OrmError::Connection(
87                "Missing database name in URL".to_string(),
88            ));
89        }
90
91        let username = if parsed.username().is_empty() {
92            None
93        } else {
94            Some(parsed.username().to_string())
95        };
96
97        let password = parsed.password().map(|p| p.to_string());
98
99        let mut additional_params = HashMap::new();
100        for (key, value) in parsed.query_pairs() {
101            additional_params.insert(key.to_string(), value.to_string());
102        }
103
104        let ssl_mode = additional_params.get("sslmode").cloned();
105
106        Ok(DatabaseConnectionConfig {
107            host,
108            port,
109            database,
110            username,
111            password,
112            ssl_mode,
113            additional_params,
114        })
115    }
116}
117
118/// PostgreSQL connection pool implementation
119pub struct PostgresPool {
120    pool: Arc<Pool<Postgres>>,
121}
122
123impl PostgresPool {
124    pub fn new(pool: Arc<Pool<Postgres>>) -> Self {
125        Self { pool }
126    }
127}
128
129#[async_trait]
130impl DatabasePool for PostgresPool {
131    async fn acquire(&self) -> OrmResult<Box<dyn DatabaseConnection>> {
132        let conn =
133            self.pool.acquire().await.map_err(|e| {
134                OrmError::Connection(format!("Failed to acquire connection: {}", e))
135            })?;
136
137        Ok(Box::new(PostgresConnection::new(conn)))
138    }
139
140    async fn begin_transaction(&self) -> OrmResult<Box<dyn DatabaseTransaction>> {
141        // For now, we'll use a simpler approach that doesn't require explicit lifetimes
142        Err(OrmError::Query(
143            "Transaction support not yet fully implemented in abstraction layer".to_string(),
144        ))
145    }
146
147    async fn execute(&self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
148        let mut query = sqlx::query(sql);
149
150        for param in params {
151            query = bind_database_value(query, param)?;
152        }
153
154        let result = query
155            .execute(&*self.pool)
156            .await
157            .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
158
159        Ok(result.rows_affected())
160    }
161
162    async fn fetch_all(
163        &self,
164        sql: &str,
165        params: &[DatabaseValue],
166    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
167        let mut query = sqlx::query(sql);
168
169        for param in params {
170            query = bind_database_value(query, param)?;
171        }
172
173        let rows = query
174            .fetch_all(&*self.pool)
175            .await
176            .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
177
178        Ok(rows
179            .into_iter()
180            .map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>)
181            .collect())
182    }
183
184    async fn fetch_optional(
185        &self,
186        sql: &str,
187        params: &[DatabaseValue],
188    ) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
189        let mut query = sqlx::query(sql);
190
191        for param in params {
192            query = bind_database_value(query, param)?;
193        }
194
195        let row = query
196            .fetch_optional(&*self.pool)
197            .await
198            .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
199
200        Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
201    }
202
203    async fn close(&self) -> OrmResult<()> {
204        self.pool.close().await;
205        Ok(())
206    }
207
208    fn stats(&self) -> DatabasePoolStats {
209        let total = self.pool.size();
210        let idle = self.pool.num_idle() as u32;
211        let active = total.saturating_sub(idle);
212
213        DatabasePoolStats {
214            total_connections: total,
215            idle_connections: idle,
216            active_connections: active,
217        }
218    }
219
220    async fn health_check(&self) -> OrmResult<std::time::Duration> {
221        let start = std::time::Instant::now();
222
223        sqlx::query("SELECT 1")
224            .execute(&*self.pool)
225            .await
226            .map_err(|e| OrmError::Connection(format!("Health check failed: {}", e)))?;
227
228        Ok(start.elapsed())
229    }
230}
231
232/// PostgreSQL connection implementation
233pub struct PostgresConnection {
234    conn: sqlx::pool::PoolConnection<Postgres>,
235}
236
237impl PostgresConnection {
238    pub fn new(conn: sqlx::pool::PoolConnection<Postgres>) -> Self {
239        Self { conn }
240    }
241}
242
243#[async_trait]
244impl DatabaseConnection for PostgresConnection {
245    async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
246        let mut query = sqlx::query(sql);
247
248        for param in params {
249            query = bind_database_value(query, param)?;
250        }
251
252        let result = query
253            .execute(&mut *self.conn)
254            .await
255            .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
256
257        Ok(result.rows_affected())
258    }
259
260    async fn fetch_all(
261        &mut self,
262        sql: &str,
263        params: &[DatabaseValue],
264    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
265        let mut query = sqlx::query(sql);
266
267        for param in params {
268            query = bind_database_value(query, param)?;
269        }
270
271        let rows = query
272            .fetch_all(&mut *self.conn)
273            .await
274            .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
275
276        Ok(rows
277            .into_iter()
278            .map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>)
279            .collect())
280    }
281
282    async fn fetch_optional(
283        &mut self,
284        sql: &str,
285        params: &[DatabaseValue],
286    ) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
287        let mut query = sqlx::query(sql);
288
289        for param in params {
290            query = bind_database_value(query, param)?;
291        }
292
293        let row = query
294            .fetch_optional(&mut *self.conn)
295            .await
296            .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
297
298        Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
299    }
300
301    async fn begin_transaction(&mut self) -> OrmResult<Box<dyn DatabaseTransaction>> {
302        // For now, we'll use a simpler approach that doesn't require explicit lifetimes
303        Err(OrmError::Query(
304            "Transaction support not yet fully implemented in abstraction layer".to_string(),
305        ))
306    }
307
308    async fn close(&mut self) -> OrmResult<()> {
309        // Connection will be returned to pool automatically when dropped
310        Ok(())
311    }
312}
313
314/// PostgreSQL transaction implementation
315pub struct PostgresTransaction<'c> {
316    tx: Option<sqlx::Transaction<'c, Postgres>>,
317}
318
319impl<'c> PostgresTransaction<'c> {
320    pub fn new(tx: sqlx::Transaction<'c, Postgres>) -> Self {
321        Self { tx: Some(tx) }
322    }
323}
324
325#[async_trait]
326impl<'c> DatabaseTransaction for PostgresTransaction<'c> {
327    async fn execute(&mut self, sql: &str, params: &[DatabaseValue]) -> OrmResult<u64> {
328        let tx = self
329            .tx
330            .as_mut()
331            .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
332
333        let mut query = sqlx::query(sql);
334
335        for param in params {
336            query = bind_database_value(query, param)?;
337        }
338
339        let result = query
340            .execute(&mut **tx)
341            .await
342            .map_err(|e| OrmError::Query(format!("Query execution failed: {}", e)))?;
343
344        Ok(result.rows_affected())
345    }
346
347    async fn fetch_all(
348        &mut self,
349        sql: &str,
350        params: &[DatabaseValue],
351    ) -> OrmResult<Vec<Box<dyn DatabaseRow>>> {
352        let tx = self
353            .tx
354            .as_mut()
355            .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
356
357        let mut query = sqlx::query(sql);
358
359        for param in params {
360            query = bind_database_value(query, param)?;
361        }
362
363        let rows = query
364            .fetch_all(&mut **tx)
365            .await
366            .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
367
368        Ok(rows
369            .into_iter()
370            .map(|row| Box::new(PostgresRow::new(row)) as Box<dyn DatabaseRow>)
371            .collect())
372    }
373
374    async fn fetch_optional(
375        &mut self,
376        sql: &str,
377        params: &[DatabaseValue],
378    ) -> OrmResult<Option<Box<dyn DatabaseRow>>> {
379        let tx = self
380            .tx
381            .as_mut()
382            .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
383
384        let mut query = sqlx::query(sql);
385
386        for param in params {
387            query = bind_database_value(query, param)?;
388        }
389
390        let row = query
391            .fetch_optional(&mut **tx)
392            .await
393            .map_err(|e| OrmError::Query(format!("Query fetch failed: {}", e)))?;
394
395        Ok(row.map(|r| Box::new(PostgresRow::new(r)) as Box<dyn DatabaseRow>))
396    }
397
398    async fn commit(mut self: Box<Self>) -> OrmResult<()> {
399        let tx = self
400            .tx
401            .take()
402            .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
403
404        tx.commit()
405            .await
406            .map_err(|e| OrmError::Query(format!("Transaction commit failed: {}", e)))?;
407
408        Ok(())
409    }
410
411    async fn rollback(mut self: Box<Self>) -> OrmResult<()> {
412        let tx = self
413            .tx
414            .take()
415            .ok_or_else(|| OrmError::Query("Transaction already completed".to_string()))?;
416
417        tx.rollback()
418            .await
419            .map_err(|e| OrmError::Query(format!("Transaction rollback failed: {}", e)))?;
420
421        Ok(())
422    }
423}
424
425/// PostgreSQL row implementation
426pub struct PostgresRow {
427    row: sqlx::postgres::PgRow,
428}
429
430impl PostgresRow {
431    pub fn new(row: sqlx::postgres::PgRow) -> Self {
432        Self { row }
433    }
434}
435
436impl DatabaseRow for PostgresRow {
437    fn get_by_index(&self, index: usize) -> OrmResult<DatabaseValue> {
438        postgres_value_to_database_value(&self.row, index)
439    }
440
441    fn get_by_name(&self, name: &str) -> OrmResult<DatabaseValue> {
442        // Find the column index by name
443        let columns = self.row.columns();
444        let index = columns
445            .iter()
446            .position(|col| col.name() == name)
447            .ok_or_else(|| OrmError::Query(format!("Column '{}' not found", name)))?;
448
449        postgres_value_to_database_value(&self.row, index)
450    }
451
452    fn column_count(&self) -> usize {
453        self.row.len()
454    }
455
456    fn column_names(&self) -> Vec<String> {
457        self.row
458            .columns()
459            .iter()
460            .map(|col| col.name().to_string())
461            .collect()
462    }
463
464    fn to_json(&self) -> OrmResult<JsonValue> {
465        let mut map = serde_json::Map::new();
466
467        for (i, column) in self.row.columns().iter().enumerate() {
468            let value = self.get_by_index(i)?;
469            map.insert(column.name().to_string(), value.to_json());
470        }
471
472        Ok(JsonValue::Object(map))
473    }
474
475    fn to_map(&self) -> OrmResult<HashMap<String, DatabaseValue>> {
476        let mut map = HashMap::new();
477
478        for (i, column) in self.row.columns().iter().enumerate() {
479            let value = self.get_by_index(i)?;
480            map.insert(column.name().to_string(), value);
481        }
482
483        Ok(map)
484    }
485}
486
487/// Bind a DatabaseValue to a sqlx query
488fn bind_database_value<'a>(
489    query: sqlx::query::Query<'a, Postgres, sqlx::postgres::PgArguments>,
490    value: &DatabaseValue,
491) -> OrmResult<sqlx::query::Query<'a, Postgres, sqlx::postgres::PgArguments>> {
492    match value {
493        DatabaseValue::Null => Ok(query.bind(Option::<String>::None)),
494        DatabaseValue::Bool(b) => Ok(query.bind(*b)),
495        DatabaseValue::Int32(i) => Ok(query.bind(*i)),
496        DatabaseValue::Int64(i) => Ok(query.bind(*i)),
497        DatabaseValue::Float32(f) => Ok(query.bind(*f)),
498        DatabaseValue::Float64(f) => Ok(query.bind(*f)),
499        DatabaseValue::String(s) => Ok(query.bind(s.clone())),
500        DatabaseValue::Bytes(b) => Ok(query.bind(b.clone())),
501        DatabaseValue::Uuid(u) => Ok(query.bind(*u)),
502        DatabaseValue::DateTime(dt) => Ok(query.bind(*dt)),
503        DatabaseValue::Date(d) => Ok(query.bind(*d)),
504        DatabaseValue::Time(t) => Ok(query.bind(*t)),
505        DatabaseValue::Json(j) => Ok(query.bind(j.clone())),
506        DatabaseValue::Array(_) => Err(OrmError::Query(
507            "Array binding not yet implemented for PostgreSQL".to_string(),
508        )),
509    }
510}
511
512/// Convert a PostgreSQL column value to DatabaseValue
513fn postgres_value_to_database_value(
514    row: &sqlx::postgres::PgRow,
515    index: usize,
516) -> OrmResult<DatabaseValue> {
517    use sqlx::{Column, Row, TypeInfo};
518
519    let column = &row.columns()[index];
520    let type_name = column.type_info().name();
521
522    // Handle null values
523    if let Ok(Some(value)) = row.try_get::<Option<String>, _>(index) {
524        return Ok(DatabaseValue::String(value));
525    }
526
527    match type_name {
528        "BOOL" => {
529            let value: bool = row
530                .try_get(index)
531                .map_err(|e| OrmError::Query(format!("Failed to get bool value: {}", e)))?;
532            Ok(DatabaseValue::Bool(value))
533        }
534        "INT2" => {
535            let value: i16 = row
536                .try_get(index)
537                .map_err(|e| OrmError::Query(format!("Failed to get int16 value: {}", e)))?;
538            Ok(DatabaseValue::Int32(value as i32))
539        }
540        "INT4" => {
541            let value: i32 = row
542                .try_get(index)
543                .map_err(|e| OrmError::Query(format!("Failed to get int32 value: {}", e)))?;
544            Ok(DatabaseValue::Int32(value))
545        }
546        "INT8" => {
547            let value: i64 = row
548                .try_get(index)
549                .map_err(|e| OrmError::Query(format!("Failed to get int64 value: {}", e)))?;
550            Ok(DatabaseValue::Int64(value))
551        }
552        "FLOAT4" => {
553            let value: f32 = row
554                .try_get(index)
555                .map_err(|e| OrmError::Query(format!("Failed to get float32 value: {}", e)))?;
556            Ok(DatabaseValue::Float32(value))
557        }
558        "FLOAT8" => {
559            let value: f64 = row
560                .try_get(index)
561                .map_err(|e| OrmError::Query(format!("Failed to get float64 value: {}", e)))?;
562            Ok(DatabaseValue::Float64(value))
563        }
564        "TEXT" | "VARCHAR" => {
565            let value: String = row
566                .try_get(index)
567                .map_err(|e| OrmError::Query(format!("Failed to get string value: {}", e)))?;
568            Ok(DatabaseValue::String(value))
569        }
570        "BYTEA" => {
571            let value: Vec<u8> = row
572                .try_get(index)
573                .map_err(|e| OrmError::Query(format!("Failed to get bytes value: {}", e)))?;
574            Ok(DatabaseValue::Bytes(value))
575        }
576        "UUID" => {
577            let value: uuid::Uuid = row
578                .try_get(index)
579                .map_err(|e| OrmError::Query(format!("Failed to get UUID value: {}", e)))?;
580            Ok(DatabaseValue::Uuid(value))
581        }
582        "TIMESTAMPTZ" | "TIMESTAMP" => {
583            let value: chrono::DateTime<chrono::Utc> = row
584                .try_get(index)
585                .map_err(|e| OrmError::Query(format!("Failed to get datetime value: {}", e)))?;
586            Ok(DatabaseValue::DateTime(value))
587        }
588        "DATE" => {
589            let value: chrono::NaiveDate = row
590                .try_get(index)
591                .map_err(|e| OrmError::Query(format!("Failed to get date value: {}", e)))?;
592            Ok(DatabaseValue::Date(value))
593        }
594        "TIME" => {
595            let value: chrono::NaiveTime = row
596                .try_get(index)
597                .map_err(|e| OrmError::Query(format!("Failed to get time value: {}", e)))?;
598            Ok(DatabaseValue::Time(value))
599        }
600        "JSON" | "JSONB" => {
601            let value: JsonValue = row
602                .try_get(index)
603                .map_err(|e| OrmError::Query(format!("Failed to get JSON value: {}", e)))?;
604            Ok(DatabaseValue::Json(value))
605        }
606        _ => {
607            // Fallback: try to get as string
608            let value: String = row.try_get(index).map_err(|e| {
609                OrmError::Query(format!(
610                    "Failed to get value as string for unknown type '{}': {}",
611                    type_name, e
612                ))
613            })?;
614            Ok(DatabaseValue::String(value))
615        }
616    }
617}
618
619impl Default for PostgresBackend {
620    fn default() -> Self {
621        Self::new()
622    }
623}