dataprof_db/connectors/
postgres.rs1#[cfg(not(feature = "postgres"))]
4use super::common::feature_not_enabled_error;
5#[cfg(feature = "postgres")]
6use super::common::{build_count_query, not_connected_error};
7use crate::connection::ConnectionInfo;
8#[cfg(feature = "postgres")]
9use crate::security::validate_sql_identifier;
10use crate::{DataProfilerError, DatabaseConfig, DatabaseConnector};
11#[cfg(feature = "postgres")]
12use crate::{process_rows_to_columns, streaming_profile_loop};
13use async_trait::async_trait;
14use std::collections::HashMap;
15
16#[cfg(feature = "postgres")]
17use {sqlx::postgres::PgPool, sqlx::postgres::PgPoolOptions};
18
19pub struct PostgresConnector {
21 #[allow(dead_code)]
22 config: DatabaseConfig,
23 #[allow(dead_code)]
24 connection_info: ConnectionInfo,
25 #[cfg(feature = "postgres")]
26 pool: Option<PgPool>,
27 #[cfg(not(feature = "postgres"))]
28 #[allow(dead_code)]
29 pool: Option<()>,
30}
31
32impl PostgresConnector {
33 pub fn new(config: DatabaseConfig) -> Result<Self, DataProfilerError> {
35 let connection_info = ConnectionInfo::parse(&config.connection_string)?;
36
37 if connection_info.database_type() != "postgresql" {
38 return Err(DataProfilerError::DatabaseConfigError {
39 message: format!(
40 "Invalid connection string for PostgreSQL: {}",
41 config.connection_string
42 ),
43 });
44 }
45
46 Ok(Self {
47 config,
48 connection_info,
49 pool: None,
50 })
51 }
52}
53
54#[async_trait]
55impl DatabaseConnector for PostgresConnector {
56 async fn connect(&mut self) -> Result<(), DataProfilerError> {
57 #[cfg(feature = "postgres")]
58 {
59 let connection_string = self.connection_info.to_connection_string("sqlx");
60
61 let pool = PgPoolOptions::new()
62 .max_connections(self.config.max_connections.unwrap_or(10))
63 .acquire_timeout(
64 self.config
65 .connection_timeout
66 .unwrap_or(std::time::Duration::from_secs(30)),
67 )
68 .connect(&connection_string)
69 .await
70 .map_err(|e| {
71 DataProfilerError::database_connection(&format!(
72 "Failed to connect to PostgreSQL: {}",
73 e
74 ))
75 })?;
76
77 self.pool = Some(pool);
78 Ok(())
79 }
80
81 #[cfg(not(feature = "postgres"))]
82 {
83 Err(DataProfilerError::database_feature_disabled(
84 "PostgreSQL",
85 "postgres",
86 ))
87 }
88 }
89
90 async fn disconnect(&mut self) -> Result<(), DataProfilerError> {
91 #[cfg(feature = "postgres")]
92 {
93 if let Some(pool) = &self.pool {
94 pool.close().await;
95 self.pool = None;
96 }
97 }
98 Ok(())
99 }
100
101 #[allow(unused_variables)]
102 async fn profile_query(
103 &mut self,
104 query: &str,
105 ) -> Result<HashMap<String, Vec<String>>, DataProfilerError> {
106 #[cfg(feature = "postgres")]
107 {
108 let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
109
110 let rows = sqlx::query(query).fetch_all(pool).await.map_err(|e| {
111 DataProfilerError::database_query(&format!("Query execution failed: {}", e))
112 })?;
113
114 Ok(process_rows_to_columns!(rows))
115 }
116
117 #[cfg(not(feature = "postgres"))]
118 Err(feature_not_enabled_error("PostgreSQL", "postgres"))
119 }
120
121 #[allow(unused_variables)]
122 async fn profile_query_streaming(
123 &mut self,
124 query: &str,
125 batch_size: usize,
126 ) -> Result<HashMap<String, Vec<String>>, DataProfilerError> {
127 #[cfg(feature = "postgres")]
128 {
129 let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
130
131 let count_query = build_count_query(query)?;
132 let total_rows: i64 = sqlx::query_scalar(&count_query)
133 .fetch_one(pool)
134 .await
135 .map_err(|e| {
136 DataProfilerError::database_query(&format!("Failed to count rows: {}", e))
137 })?;
138
139 streaming_profile_loop!(pool, query, batch_size, total_rows, "PostgreSQL")
140 }
141
142 #[cfg(not(feature = "postgres"))]
143 Err(feature_not_enabled_error("PostgreSQL", "postgres"))
144 }
145
146 #[allow(unused_variables)]
147 async fn get_table_schema(
148 &mut self,
149 table_name: &str,
150 ) -> Result<Vec<String>, DataProfilerError> {
151 #[cfg(feature = "postgres")]
152 {
153 use sqlx::Row;
154
155 let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
156
157 let query = r#"
158 SELECT column_name
159 FROM information_schema.columns
160 WHERE table_name = $1
161 ORDER BY ordinal_position
162 "#;
163
164 let rows = sqlx::query(query)
165 .bind(table_name)
166 .fetch_all(pool)
167 .await
168 .map_err(|e| {
169 DataProfilerError::database_query(&format!("Failed to get table schema: {}", e))
170 })?;
171
172 let mut columns = Vec::new();
173 for row in rows {
174 let column_name: String = row.try_get(0).map_err(|e| {
175 DataProfilerError::database_query(&format!("Failed to read column name: {}", e))
176 })?;
177 columns.push(column_name);
178 }
179
180 Ok(columns)
181 }
182
183 #[cfg(not(feature = "postgres"))]
184 Err(feature_not_enabled_error("PostgreSQL", "postgres"))
185 }
186
187 #[allow(unused_variables)]
188 async fn count_table_rows(&mut self, table_name: &str) -> Result<u64, DataProfilerError> {
189 #[cfg(feature = "postgres")]
190 {
191 let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
192
193 validate_sql_identifier(table_name)?;
194 let query = format!("SELECT COUNT(*) FROM {}", table_name);
195 let count: i64 = sqlx::query_scalar(&query)
196 .fetch_one(pool)
197 .await
198 .map_err(|e| {
199 DataProfilerError::database_query(&format!("Failed to count rows: {}", e))
200 })?;
201
202 Ok(count as u64)
203 }
204
205 #[cfg(not(feature = "postgres"))]
206 Err(feature_not_enabled_error("PostgreSQL", "postgres"))
207 }
208
209 async fn test_connection(&mut self) -> Result<bool, DataProfilerError> {
210 #[cfg(feature = "postgres")]
211 {
212 let pool = self.pool.as_ref().ok_or_else(not_connected_error)?;
213
214 let result: i32 = sqlx::query_scalar("SELECT 1")
215 .fetch_one(pool)
216 .await
217 .map_err(|e| {
218 DataProfilerError::database_query(&format!("Connection test failed: {}", e))
219 })?;
220
221 Ok(result == 1)
222 }
223
224 #[cfg(not(feature = "postgres"))]
225 Err(feature_not_enabled_error("PostgreSQL", "postgres"))
226 }
227}