Skip to main content

dataprof_db/
lib.rs

1//! Database connectivity module for dataprof.
2//!
3//! This crate owns the database profiling surface, including connection
4//! handling, secure configuration helpers, sampling, streaming, and the
5//! feature-gated sqlx-based connectors.
6
7pub(crate) use dataprof_core::DataProfilerError;
8use dataprof_core::{DataSource, ExecutionMetadata, QualityDimension, QueryEngine};
9use dataprof_metrics::analyze_column;
10use dataprof_runtime::{ProfileReport, ReportAssembler};
11use std::collections::HashMap;
12
13pub mod connection;
14pub mod connectors;
15pub mod retry;
16pub mod sampling;
17pub mod security;
18pub mod streaming;
19
20pub use connection::*;
21pub use connectors::*;
22pub use retry::*;
23pub use sampling::*;
24pub use security::*;
25
26/// Database configuration for connection strings and settings
27#[derive(Debug, Clone)]
28pub struct DatabaseConfig {
29    pub connection_string: String,
30    pub batch_size: usize,
31    pub max_connections: Option<u32>,
32    pub connection_timeout: Option<std::time::Duration>,
33    pub retry_config: Option<RetryConfig>,
34    pub sampling_config: Option<SamplingConfig>,
35    pub ssl_config: Option<SslConfig>,
36    pub load_credentials_from_env: bool,
37}
38
39impl Default for DatabaseConfig {
40    fn default() -> Self {
41        Self {
42            connection_string: String::new(),
43            batch_size: 10000,
44            max_connections: Some(10),
45            connection_timeout: Some(std::time::Duration::from_secs(30)),
46            retry_config: Some(RetryConfig::default()),
47            sampling_config: None,
48            ssl_config: Some(SslConfig::default()),
49            load_credentials_from_env: true,
50        }
51    }
52}
53
54/// Trait that all database connectors must implement
55#[async_trait::async_trait]
56pub trait DatabaseConnector: Send + Sync {
57    /// Connect to the database
58    async fn connect(&mut self) -> Result<(), DataProfilerError>;
59
60    /// Disconnect from the database
61    async fn disconnect(&mut self) -> Result<(), DataProfilerError>;
62
63    /// Execute a query and get column data for profiling
64    async fn profile_query(
65        &mut self,
66        query: &str,
67    ) -> Result<HashMap<String, Vec<String>>, DataProfilerError>;
68
69    /// Execute a query with streaming for large result sets
70    async fn profile_query_streaming(
71        &mut self,
72        query: &str,
73        batch_size: usize,
74    ) -> Result<HashMap<String, Vec<String>>, DataProfilerError>;
75
76    /// Get table schema information
77    async fn get_table_schema(
78        &mut self,
79        table_name: &str,
80    ) -> Result<Vec<String>, DataProfilerError>;
81
82    /// Count total rows in table (for progress tracking)
83    async fn count_table_rows(&mut self, table_name: &str) -> Result<u64, DataProfilerError>;
84
85    /// Test connection
86    async fn test_connection(&mut self) -> Result<bool, DataProfilerError>;
87}
88
89/// Factory function to create appropriate database connector
90pub fn create_connector(
91    mut config: DatabaseConfig,
92) -> Result<Box<dyn DatabaseConnector>, DataProfilerError> {
93    if config.load_credentials_from_env || config.connection_string.is_empty() {
94        config = apply_environment_configuration(config)?;
95    }
96
97    let connection_str = config.connection_string.as_str();
98
99    if connection_str.starts_with("postgresql://") || connection_str.starts_with("postgres://") {
100        Ok(Box::new(connectors::postgres::PostgresConnector::new(
101            config,
102        )?))
103    } else if connection_str.starts_with("mysql://") {
104        Ok(Box::new(connectors::mysql::MySqlConnector::new(config)?))
105    } else if connection_str.starts_with("sqlite://")
106        || connection_str.ends_with(".db")
107        || connection_str.ends_with(".sqlite")
108        || connection_str == ":memory:"
109    {
110        Ok(Box::new(connectors::sqlite::SqliteConnector::new(config)?))
111    } else {
112        Err(DataProfilerError::DatabaseConfigError {
113            message: format!(
114                "Unsupported database connection string: {}. Supported: postgresql://, mysql://, sqlite://",
115                connection_str
116            ),
117        })
118    }
119}
120
121/// Apply environment configuration to database config
122fn apply_environment_configuration(
123    mut config: DatabaseConfig,
124) -> Result<DatabaseConfig, DataProfilerError> {
125    let database_type = if config.connection_string.is_empty() {
126        if std::env::var("POSTGRES_URL").is_ok()
127            || std::env::var("DATABASE_URL")
128                .map(|url| url.starts_with("postgres"))
129                .unwrap_or(false)
130        {
131            "postgresql".to_string()
132        } else if std::env::var("MYSQL_URL").is_ok() {
133            "mysql".to_string()
134        } else {
135            "postgresql".to_string()
136        }
137    } else {
138        let conn_info = ConnectionInfo::parse(&config.connection_string)?;
139        conn_info.database_type().to_string()
140    };
141    let database_type = database_type.as_str();
142
143    if config.connection_string.is_empty() {
144        let (secure_connection_string, ssl_config) = load_secure_database_config(database_type)?;
145        config.connection_string = secure_connection_string;
146        config.ssl_config = Some(ssl_config);
147    } else {
148        if let Some(ssl_config) = &config.ssl_config {
149            config.connection_string = ssl_config
150                .apply_to_connection_string(config.connection_string.clone(), database_type);
151        }
152
153        if config.load_credentials_from_env {
154            let credentials = DatabaseCredentials::from_environment(database_type);
155            config.connection_string =
156                credentials.apply_to_connection_string(&config.connection_string);
157        }
158    }
159
160    Ok(config)
161}
162
163/// High-level function to analyze a database table or query.
164pub async fn analyze_database(
165    config: DatabaseConfig,
166    query: &str,
167    calculate_quality: bool,
168    quality_dimensions: Option<Vec<QualityDimension>>,
169) -> Result<ProfileReport, DataProfilerError> {
170    let mut connector = create_connector(config.clone())?;
171
172    connector.connect().await?;
173
174    let start = std::time::Instant::now();
175
176    let (actual_query, is_table) = if query.trim().to_uppercase().starts_with("SELECT") {
177        let validated_query = security::validate_base_query(query)?;
178        (validated_query, false)
179    } else {
180        security::validate_sql_identifier(query)?;
181        (format!("SELECT * FROM {}", query), true)
182    };
183
184    let total_rows = if is_table {
185        connector.count_table_rows(query).await.unwrap_or(0)
186    } else {
187        0
188    };
189
190    let (final_query, sample_info) = if let Some(sampling_config) = &config.sampling_config {
191        if total_rows > sampling_config.sample_size as u64 {
192            let sampled_query = sampling_config.generate_sample_query(&actual_query, total_rows)?;
193            let info = SampleInfo::new(
194                total_rows,
195                sampling_config.sample_size.min(total_rows as usize) as u64,
196                sampling_config.strategy.clone(),
197            );
198            (sampled_query, Some(info))
199        } else {
200            (actual_query, None)
201        }
202    } else {
203        (actual_query, None)
204    };
205
206    let columns = connector
207        .profile_query_streaming(&final_query, config.batch_size)
208        .await?;
209
210    connector.disconnect().await?;
211
212    let query_engine = detect_query_engine(&config.connection_string);
213
214    if columns.is_empty() {
215        let mut exec = ExecutionMetadata::new(0, 0, start.elapsed().as_millis());
216        if let Some(ref info) = sample_info
217            && info.sampling_ratio < 1.0
218        {
219            exec = exec
220                .with_sampling(info.sampling_ratio)
221                .with_source_exhausted(false);
222        }
223        return Ok(ReportAssembler::new(
224            DataSource::Query {
225                engine: query_engine.clone(),
226                statement: query.to_string(),
227                database: extract_database_name(&config.connection_string),
228                execution_id: None,
229            },
230            exec,
231        )
232        .skip_quality()
233        .build());
234    }
235
236    let mut column_profiles = Vec::new();
237    let actual_rows_processed = columns.values().next().map(|v| v.len()).unwrap_or(0);
238
239    for (name, data) in &columns {
240        let profile = analyze_column(name, data);
241        column_profiles.push(profile);
242    }
243
244    let scan_time_ms = start.elapsed().as_millis();
245    let sampling_ratio = sample_info.map(|s| s.sampling_ratio).unwrap_or(1.0);
246    let num_columns = column_profiles.len();
247
248    let mut execution = ExecutionMetadata::new(actual_rows_processed, num_columns, scan_time_ms);
249    if sampling_ratio < 1.0 {
250        execution = execution
251            .with_sampling(sampling_ratio)
252            .with_source_exhausted(false);
253    }
254
255    let mut assembler = ReportAssembler::new(
256        DataSource::Query {
257            engine: query_engine,
258            statement: query.to_string(),
259            database: extract_database_name(&config.connection_string),
260            execution_id: None,
261        },
262        execution,
263    )
264    .columns(column_profiles);
265
266    if calculate_quality {
267        assembler = assembler.with_quality_data(columns);
268        if let Some(dims) = quality_dimensions {
269            assembler = assembler.with_requested_dimensions(dims);
270        }
271    } else {
272        assembler = assembler.skip_quality();
273    }
274
275    Ok(assembler.build())
276}
277
278/// Detect query engine from connection string
279fn detect_query_engine(connection_string: &str) -> QueryEngine {
280    let conn = connection_string.to_lowercase();
281    if conn.starts_with("postgres") || conn.starts_with("postgresql") {
282        QueryEngine::Postgres
283    } else if conn.starts_with("mysql") || conn.starts_with("mariadb") {
284        QueryEngine::MySql
285    } else if conn.starts_with("sqlite") {
286        QueryEngine::Sqlite
287    } else {
288        QueryEngine::Custom("unknown".to_string())
289    }
290}
291
/// Extracts the database name from a connection string.
///
/// - **Network URLs** (`postgresql://`, `mysql://`, …): the name is the path
///   after the authority (`user:pass@host:port`). A URL with no path yields
///   `None`, never the authority, which may contain credentials.
/// - **`sqlite://` URLs:** everything after the scheme is a file path, and the
///   file name is returned.
/// - **Strings without a scheme:** the component after the last `/` is
///   returned; with no `/` at all, the result is `None`.
///
/// The query string (`?…`) is ignored in all cases, so paths inside
/// parameters such as `sslrootcert=/etc/ca.pem` are never picked up.
/// An empty name yields `None`.
fn extract_database_name(connection_string: &str) -> Option<String> {
    /// Cuts `s` at the start of its query string, if any.
    fn before_query(s: &str) -> &str {
        s.split('?').next().unwrap_or(s)
    }

    let path = match connection_string.split_once("://") {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("sqlite") => before_query(rest),
        Some((_, rest)) => {
            // Per RFC 3986 the authority ends at the first '/' or '?'.
            let authority_end = rest.find(['/', '?']).unwrap_or(rest.len());
            before_query(rest[authority_end..].strip_prefix('/')?)
        }
        None => before_query(connection_string).rsplit_once('/')?.1,
    };

    let name = path.rsplit('/').next().unwrap_or(path);
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}