1pub(crate) use dataprof_core::DataProfilerError;
8use dataprof_core::{DataSource, ExecutionMetadata, QualityDimension, QueryEngine};
9use dataprof_metrics::analyze_column;
10use dataprof_runtime::{ProfileReport, ReportAssembler};
11use std::collections::HashMap;
12
13pub mod connection;
14pub mod connectors;
15pub mod retry;
16pub mod sampling;
17pub mod security;
18pub mod streaming;
19
20pub use connection::*;
21pub use connectors::*;
22pub use retry::*;
23pub use sampling::*;
24pub use security::*;
25
/// Connection and profiling settings consumed by [`create_connector`] and
/// [`analyze_database`].
///
/// NOTE(review): `Debug` is derived, so formatting this struct prints
/// `connection_string` verbatim, including any password embedded in it.
/// Consider a redacting `Debug` impl before this value is logged.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
    /// Database URL or file path. `create_connector` selects the backend from
    /// the `postgresql://` / `postgres://`, `mysql://` or `sqlite://` prefix,
    /// a `.db` / `.sqlite` suffix, or the literal `:memory:`. When empty, it is
    /// filled in from the environment by `create_connector`.
    pub connection_string: String,
    /// Batch size passed to `DatabaseConnector::profile_query_streaming` by
    /// `analyze_database` (default 10 000).
    pub batch_size: usize,
    /// Connection limit (default `Some(10)`). Not read in this file;
    /// presumably honoured by the individual connectors — confirm.
    pub max_connections: Option<u32>,
    /// Connection timeout (default 30 seconds). Not read in this file;
    /// presumably honoured by the individual connectors — confirm.
    pub connection_timeout: Option<std::time::Duration>,
    /// Retry policy (default `Some(RetryConfig::default())`). Not read in this file.
    pub retry_config: Option<RetryConfig>,
    /// Optional sampling (default `None`). `analyze_database` uses it only for
    /// table inputs whose row count exceeds the configured sample size.
    pub sampling_config: Option<SamplingConfig>,
    /// SSL settings (default `Some(SslConfig::default())`). During environment
    /// configuration they are applied to a non-empty connection string, or
    /// replaced by environment-loaded settings when the string is empty.
    pub ssl_config: Option<SslConfig>,
    /// When true (the default), `create_connector` runs environment
    /// configuration even for a non-empty connection string and applies
    /// credentials obtained from `DatabaseCredentials::from_environment`.
    pub load_credentials_from_env: bool,
}
38
39impl Default for DatabaseConfig {
40 fn default() -> Self {
41 Self {
42 connection_string: String::new(),
43 batch_size: 10000,
44 max_connections: Some(10),
45 connection_timeout: Some(std::time::Duration::from_secs(30)),
46 retry_config: Some(RetryConfig::default()),
47 sampling_config: None,
48 ssl_config: Some(SslConfig::default()),
49 load_credentials_from_env: true,
50 }
51 }
52}
53
/// Backend-specific database access used by [`analyze_database`].
///
/// [`create_connector`] returns implementations for PostgreSQL, MySQL and
/// SQLite as `Box<dyn DatabaseConnector>`. Methods are declared `async`
/// through `async_trait` so the trait stays usable as a trait object, and
/// implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait DatabaseConnector: Send + Sync {
    /// Opens the connection. `analyze_database` calls this before running
    /// any query.
    async fn connect(&mut self) -> Result<(), DataProfilerError>;

    /// Closes the connection. `analyze_database` calls this after the data
    /// has been fetched.
    async fn disconnect(&mut self) -> Result<(), DataProfilerError>;

    /// Runs `query` and returns its result set keyed by column name.
    ///
    /// NOTE(review): not called in this file; presumably the same
    /// column-name -> string-values shape as `profile_query_streaming` — confirm.
    async fn profile_query(
        &mut self,
        query: &str,
    ) -> Result<HashMap<String, Vec<String>>, DataProfilerError>;

    /// Runs `query` and returns the result set as a map from column name to
    /// that column's values as strings (presumably fetched in batches of
    /// `batch_size` rows — confirm per implementation).
    ///
    /// `analyze_database` profiles each entry with `analyze_column` and uses
    /// the length of any one entry as the row count, so every vector is
    /// expected to have the same length.
    async fn profile_query_streaming(
        &mut self,
        query: &str,
        batch_size: usize,
    ) -> Result<HashMap<String, Vec<String>>, DataProfilerError>;

    /// Returns schema information for `table_name`.
    ///
    /// NOTE(review): not called in this file; presumably the table's column
    /// names — confirm.
    async fn get_table_schema(
        &mut self,
        table_name: &str,
    ) -> Result<Vec<String>, DataProfilerError>;

    /// Returns the number of rows in `table_name`.
    ///
    /// `analyze_database` passes a name already checked by
    /// `security::validate_sql_identifier` and treats an error as 0 rows,
    /// which disables sampling.
    async fn count_table_rows(&mut self, table_name: &str) -> Result<u64, DataProfilerError>;

    /// Checks whether the database is reachable.
    ///
    /// NOTE(review): the distinction between `Ok(false)` and `Err` is left to
    /// implementations — confirm.
    async fn test_connection(&mut self) -> Result<bool, DataProfilerError>;
}
88
89pub fn create_connector(
91 mut config: DatabaseConfig,
92) -> Result<Box<dyn DatabaseConnector>, DataProfilerError> {
93 if config.load_credentials_from_env || config.connection_string.is_empty() {
94 config = apply_environment_configuration(config)?;
95 }
96
97 let connection_str = config.connection_string.as_str();
98
99 if connection_str.starts_with("postgresql://") || connection_str.starts_with("postgres://") {
100 Ok(Box::new(connectors::postgres::PostgresConnector::new(
101 config,
102 )?))
103 } else if connection_str.starts_with("mysql://") {
104 Ok(Box::new(connectors::mysql::MySqlConnector::new(config)?))
105 } else if connection_str.starts_with("sqlite://")
106 || connection_str.ends_with(".db")
107 || connection_str.ends_with(".sqlite")
108 || connection_str == ":memory:"
109 {
110 Ok(Box::new(connectors::sqlite::SqliteConnector::new(config)?))
111 } else {
112 Err(DataProfilerError::DatabaseConfigError {
113 message: format!(
114 "Unsupported database connection string: {}. Supported: postgresql://, mysql://, sqlite://",
115 connection_str
116 ),
117 })
118 }
119}
120
121fn apply_environment_configuration(
123 mut config: DatabaseConfig,
124) -> Result<DatabaseConfig, DataProfilerError> {
125 let database_type = if config.connection_string.is_empty() {
126 if std::env::var("POSTGRES_URL").is_ok()
127 || std::env::var("DATABASE_URL")
128 .map(|url| url.starts_with("postgres"))
129 .unwrap_or(false)
130 {
131 "postgresql".to_string()
132 } else if std::env::var("MYSQL_URL").is_ok() {
133 "mysql".to_string()
134 } else {
135 "postgresql".to_string()
136 }
137 } else {
138 let conn_info = ConnectionInfo::parse(&config.connection_string)?;
139 conn_info.database_type().to_string()
140 };
141 let database_type = database_type.as_str();
142
143 if config.connection_string.is_empty() {
144 let (secure_connection_string, ssl_config) = load_secure_database_config(database_type)?;
145 config.connection_string = secure_connection_string;
146 config.ssl_config = Some(ssl_config);
147 } else {
148 if let Some(ssl_config) = &config.ssl_config {
149 config.connection_string = ssl_config
150 .apply_to_connection_string(config.connection_string.clone(), database_type);
151 }
152
153 if config.load_credentials_from_env {
154 let credentials = DatabaseCredentials::from_environment(database_type);
155 config.connection_string =
156 credentials.apply_to_connection_string(&config.connection_string);
157 }
158 }
159
160 Ok(config)
161}
162
163pub async fn analyze_database(
165 config: DatabaseConfig,
166 query: &str,
167 calculate_quality: bool,
168 quality_dimensions: Option<Vec<QualityDimension>>,
169) -> Result<ProfileReport, DataProfilerError> {
170 let mut connector = create_connector(config.clone())?;
171
172 connector.connect().await?;
173
174 let start = std::time::Instant::now();
175
176 let (actual_query, is_table) = if query.trim().to_uppercase().starts_with("SELECT") {
177 let validated_query = security::validate_base_query(query)?;
178 (validated_query, false)
179 } else {
180 security::validate_sql_identifier(query)?;
181 (format!("SELECT * FROM {}", query), true)
182 };
183
184 let total_rows = if is_table {
185 connector.count_table_rows(query).await.unwrap_or(0)
186 } else {
187 0
188 };
189
190 let (final_query, sample_info) = if let Some(sampling_config) = &config.sampling_config {
191 if total_rows > sampling_config.sample_size as u64 {
192 let sampled_query = sampling_config.generate_sample_query(&actual_query, total_rows)?;
193 let info = SampleInfo::new(
194 total_rows,
195 sampling_config.sample_size.min(total_rows as usize) as u64,
196 sampling_config.strategy.clone(),
197 );
198 (sampled_query, Some(info))
199 } else {
200 (actual_query, None)
201 }
202 } else {
203 (actual_query, None)
204 };
205
206 let columns = connector
207 .profile_query_streaming(&final_query, config.batch_size)
208 .await?;
209
210 connector.disconnect().await?;
211
212 let query_engine = detect_query_engine(&config.connection_string);
213
214 if columns.is_empty() {
215 let mut exec = ExecutionMetadata::new(0, 0, start.elapsed().as_millis());
216 if let Some(ref info) = sample_info
217 && info.sampling_ratio < 1.0
218 {
219 exec = exec
220 .with_sampling(info.sampling_ratio)
221 .with_source_exhausted(false);
222 }
223 return Ok(ReportAssembler::new(
224 DataSource::Query {
225 engine: query_engine.clone(),
226 statement: query.to_string(),
227 database: extract_database_name(&config.connection_string),
228 execution_id: None,
229 },
230 exec,
231 )
232 .skip_quality()
233 .build());
234 }
235
236 let mut column_profiles = Vec::new();
237 let actual_rows_processed = columns.values().next().map(|v| v.len()).unwrap_or(0);
238
239 for (name, data) in &columns {
240 let profile = analyze_column(name, data);
241 column_profiles.push(profile);
242 }
243
244 let scan_time_ms = start.elapsed().as_millis();
245 let sampling_ratio = sample_info.map(|s| s.sampling_ratio).unwrap_or(1.0);
246 let num_columns = column_profiles.len();
247
248 let mut execution = ExecutionMetadata::new(actual_rows_processed, num_columns, scan_time_ms);
249 if sampling_ratio < 1.0 {
250 execution = execution
251 .with_sampling(sampling_ratio)
252 .with_source_exhausted(false);
253 }
254
255 let mut assembler = ReportAssembler::new(
256 DataSource::Query {
257 engine: query_engine,
258 statement: query.to_string(),
259 database: extract_database_name(&config.connection_string),
260 execution_id: None,
261 },
262 execution,
263 )
264 .columns(column_profiles);
265
266 if calculate_quality {
267 assembler = assembler.with_quality_data(columns);
268 if let Some(dims) = quality_dimensions {
269 assembler = assembler.with_requested_dimensions(dims);
270 }
271 } else {
272 assembler = assembler.skip_quality();
273 }
274
275 Ok(assembler.build())
276}
277
278fn detect_query_engine(connection_string: &str) -> QueryEngine {
280 let conn = connection_string.to_lowercase();
281 if conn.starts_with("postgres") || conn.starts_with("postgresql") {
282 QueryEngine::Postgres
283 } else if conn.starts_with("mysql") || conn.starts_with("mariadb") {
284 QueryEngine::MySql
285 } else if conn.starts_with("sqlite") {
286 QueryEngine::Sqlite
287 } else {
288 QueryEngine::Custom("unknown".to_string())
289 }
290}
291
/// Extracts the database name reported in profile metadata from a connection string.
///
/// Query parameters (`?...`) are ignored, and the last non-empty path segment
/// is returned. How the path is found depends on the form of the string:
/// - Non-SQLite URLs (`scheme://authority/path`): the authority
///   (`user:password@host:port`) is skipped, so a URL without a path yields
///   `None` rather than leaking credentials or host names.
/// - `sqlite://` URLs: everything after the scheme is a file path.
/// - Strings without `://` (plain paths, `:memory:`): only text after a `/`
///   is considered. A string with no `/` yields `None`.
fn extract_database_name(connection_string: &str) -> Option<String> {
    // Query parameters can contain '/' themselves (e.g. `sslrootcert=/etc/ca.pem`),
    // so drop them before looking for path separators.
    let without_query = connection_string
        .split_once('?')
        .map_or(connection_string, |(head, _)| head);

    let path = match without_query.split_once("://") {
        // SQLite URLs have no authority: everything after the scheme is a path.
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("sqlite") => rest,
        // Other URLs: the authority ends at the first '/'. No path means no name.
        Some((_, rest)) => rest.split_once('/')?.1,
        // Plain file paths: only text after a '/' is treated as a name.
        None => without_query.rsplit_once('/')?.1,
    };

    let name = path.rsplit('/').next().unwrap_or(path);
    (!name.is_empty()).then(|| name.to_string())
}