Skip to main content

dataprof_db/connectors/
postgres.rs

1//! PostgreSQL database connector with connection pooling
2
3#[cfg(not(feature = "postgres"))]
4use super::common::feature_not_enabled_error;
5#[cfg(feature = "postgres")]
6use super::common::{build_count_query, not_connected_error};
7use crate::connection::ConnectionInfo;
8#[cfg(feature = "postgres")]
9use crate::security::validate_sql_identifier;
10use crate::{DataProfilerError, DatabaseConfig, DatabaseConnector};
11#[cfg(feature = "postgres")]
12use crate::{process_rows_to_columns, streaming_profile_loop};
13use async_trait::async_trait;
14use std::collections::HashMap;
15
16#[cfg(feature = "postgres")]
17use {sqlx::postgres::PgPool, sqlx::postgres::PgPoolOptions};
18
19/// PostgreSQL connector with connection pooling support
20pub struct PostgresConnector {
21    #[allow(dead_code)]
22    config: DatabaseConfig,
23    #[allow(dead_code)]
24    connection_info: ConnectionInfo,
25    #[cfg(feature = "postgres")]
26    pool: Option<PgPool>,
27    #[cfg(not(feature = "postgres"))]
28    #[allow(dead_code)]
29    pool: Option<()>,
30}
31
32impl PostgresConnector {
33    /// Create a new PostgreSQL connector
34    pub fn new(config: DatabaseConfig) -> Result<Self, DataProfilerError> {
35        let connection_info = ConnectionInfo::parse(&config.connection_string)?;
36
37        if connection_info.database_type() != "postgresql" {
38            return Err(DataProfilerError::DatabaseConfigError {
39                message: format!(
40                    "Invalid connection string for PostgreSQL: {}",
41                    config.connection_string
42                ),
43            });
44        }
45
46        Ok(Self {
47            config,
48            connection_info,
49            pool: None,
50        })
51    }
52}
53
54#[async_trait]
55impl DatabaseConnector for PostgresConnector {
56    async fn connect(&mut self) -> Result<(), DataProfilerError> {
57        #[cfg(feature = "postgres")]
58        {
59            let connection_string = self.connection_info.to_connection_string("sqlx");
60
61            let pool = PgPoolOptions::new()
62                .max_connections(self.config.max_connections.unwrap_or(10))
63                .acquire_timeout(
64                    self.config
65                        .connection_timeout
66                        .unwrap_or(std::time::Duration::from_secs(30)),
67                )
68                .connect(&connection_string)
69                .await
70                .map_err(|e| {
71                    DataProfilerError::database_connection(&format!(
72                        "Failed to connect to PostgreSQL: {}",
73                        e
74                    ))
75                })?;
76
77            self.pool = Some(pool);
78            Ok(())
79        }
80
81        #[cfg(not(feature = "postgres"))]
82        {
83            Err(DataProfilerError::database_feature_disabled(
84                "PostgreSQL",
85                "postgres",
86            ))
87        }
88    }
89
90    async fn disconnect(&mut self) -> Result<(), DataProfilerError> {
91        #[cfg(feature = "postgres")]
92        {
93            if let Some(pool) = &self.pool {
94                pool.close().await;
95                self.pool = None;
96            }
97        }
98        Ok(())
99    }
100
101    #[allow(unused_variables)]
102    async fn profile_query(
103        &mut self,
104        query: &str,
105    ) -> Result<HashMap<String, Vec<String>>, DataProfilerError> {
106        #[cfg(feature = "postgres")]
107        {
108            let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
109
110            let rows = sqlx::query(query).fetch_all(pool).await.map_err(|e| {
111                DataProfilerError::database_query(&format!("Query execution failed: {}", e))
112            })?;
113
114            Ok(process_rows_to_columns!(rows))
115        }
116
117        #[cfg(not(feature = "postgres"))]
118        Err(feature_not_enabled_error("PostgreSQL", "postgres"))
119    }
120
121    #[allow(unused_variables)]
122    async fn profile_query_streaming(
123        &mut self,
124        query: &str,
125        batch_size: usize,
126    ) -> Result<HashMap<String, Vec<String>>, DataProfilerError> {
127        #[cfg(feature = "postgres")]
128        {
129            let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
130
131            let count_query = build_count_query(query)?;
132            let total_rows: i64 = sqlx::query_scalar(&count_query)
133                .fetch_one(pool)
134                .await
135                .map_err(|e| {
136                    DataProfilerError::database_query(&format!("Failed to count rows: {}", e))
137                })?;
138
139            streaming_profile_loop!(pool, query, batch_size, total_rows, "PostgreSQL")
140        }
141
142        #[cfg(not(feature = "postgres"))]
143        Err(feature_not_enabled_error("PostgreSQL", "postgres"))
144    }
145
146    #[allow(unused_variables)]
147    async fn get_table_schema(
148        &mut self,
149        table_name: &str,
150    ) -> Result<Vec<String>, DataProfilerError> {
151        #[cfg(feature = "postgres")]
152        {
153            use sqlx::Row;
154
155            let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
156
157            let query = r#"
158                SELECT column_name
159                FROM information_schema.columns
160                WHERE table_name = $1
161                ORDER BY ordinal_position
162            "#;
163
164            let rows = sqlx::query(query)
165                .bind(table_name)
166                .fetch_all(pool)
167                .await
168                .map_err(|e| {
169                    DataProfilerError::database_query(&format!("Failed to get table schema: {}", e))
170                })?;
171
172            let mut columns = Vec::new();
173            for row in rows {
174                let column_name: String = row.try_get(0).map_err(|e| {
175                    DataProfilerError::database_query(&format!("Failed to read column name: {}", e))
176                })?;
177                columns.push(column_name);
178            }
179
180            Ok(columns)
181        }
182
183        #[cfg(not(feature = "postgres"))]
184        Err(feature_not_enabled_error("PostgreSQL", "postgres"))
185    }
186
187    #[allow(unused_variables)]
188    async fn count_table_rows(&mut self, table_name: &str) -> Result<u64, DataProfilerError> {
189        #[cfg(feature = "postgres")]
190        {
191            let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
192
193            validate_sql_identifier(table_name)?;
194            let query = format!("SELECT COUNT(*) FROM {}", table_name);
195            let count: i64 = sqlx::query_scalar(&query)
196                .fetch_one(pool)
197                .await
198                .map_err(|e| {
199                    DataProfilerError::database_query(&format!("Failed to count rows: {}", e))
200                })?;
201
202            Ok(count as u64)
203        }
204
205        #[cfg(not(feature = "postgres"))]
206        Err(feature_not_enabled_error("PostgreSQL", "postgres"))
207    }
208
209    async fn test_connection(&mut self) -> Result<bool, DataProfilerError> {
210        #[cfg(feature = "postgres")]
211        {
212            let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
213
214            let result: i32 = sqlx::query_scalar("SELECT 1")
215                .fetch_one(pool)
216                .await
217                .map_err(|e| {
218                    DataProfilerError::database_query(&format!("Connection test failed: {}", e))
219                })?;
220
221            Ok(result == 1)
222        }
223
224        #[cfg(not(feature = "postgres"))]
225        Err(feature_not_enabled_error("PostgreSQL", "postgres"))
226    }
227}