Skip to main content

dataprof_db/connectors/
common.rs

1//! Common utilities and shared logic for database connectors
2//!
3//! This module provides reusable functions to reduce code duplication across
4//! PostgreSQL, MySQL, and SQLite connectors.
5
6use crate::DataProfilerError;
7use crate::security::{validate_base_query, validate_sql_identifier};
8
9/// Generate "not connected to database" error
10#[allow(dead_code)]
11pub fn not_connected_error() -> DataProfilerError {
12    DataProfilerError::database_connection("Not connected to database")
13}
14
15/// Generate feature-not-enabled error for a specific database
16#[allow(dead_code)]
17pub fn feature_not_enabled_error(db_name: &str, feature: &str) -> DataProfilerError {
18    DataProfilerError::database_feature_disabled(db_name, feature)
19}
20
/// Macro that pages through a profiling query in `LIMIT`/`OFFSET` batches and
/// merges the per-batch results.
///
/// Arguments, in order:
///
/// * `$pool` - executor handed to `sqlx::query(..).fetch_all(..)` for each batch.
/// * `$query` - table name or `SELECT` statement given to `build_batch_query`.
/// * `$batch_size` - number of rows requested per batch (`usize`).
/// * `$total_rows` - expected row count; used only for progress reporting.
/// * `$db_name` - literal naming the database in progress log lines.
///
/// `$batch_size` and `$total_rows` are evaluated exactly once. `$pool` and
/// `$query` are expanded in place and therefore evaluated once per batch.
///
/// The expansion uses `.await` and `?`, so it must be invoked inside an
/// `async` context whose error type can be built from `DataProfilerError`.
/// It evaluates to whatever `merge_column_batches` returns for the collected
/// batches.
///
/// Every batch is a map from column name to the column's values as strings.
/// A value that cannot be read as a `String` is stored as an empty string.
/// Columns that share a name share a single entry in the map.
///
/// NOTE(review): paging relies on `OFFSET` alone; unless the supplied query
/// has a stable `ORDER BY`, the database may be free to return rows in a
/// different order between batches - confirm against callers.
#[macro_export]
macro_rules! streaming_profile_loop {
    ($pool:expr, $query:expr, $batch_size:expr, $total_rows:expr, $db_name:literal) => {{
        use sqlx::{Column, Row};
        use $crate::connectors::common::build_batch_query;
        use $crate::streaming::{StreamingProgress, merge_column_batches};

        // Bind the numeric arguments once: the loop below needs each of them
        // several times, and re-expanding the caller's expression would
        // re-evaluate it (and any side effect it has) on every use.
        let batch_size: usize = $batch_size;
        let total_rows = $total_rows;

        let mut progress = StreamingProgress::new(Some(total_rows as u64));
        let mut all_batches: Vec<std::collections::HashMap<String, Vec<String>>> = Vec::new();
        let mut offset = 0usize;

        loop {
            let batch_query = build_batch_query($query, batch_size, offset)?;
            let rows = sqlx::query(&batch_query)
                .fetch_all($pool)
                .await
                .map_err(|e| $crate::DataProfilerError::DatabaseQueryError {
                    message: format!("Batch query execution failed: {}", e),
                })?;

            // An empty page means the previous one was the last.
            if rows.is_empty() {
                break;
            }

            // The first row's column metadata describes the whole batch.
            let columns = rows[0].columns();
            let mut batch_result: std::collections::HashMap<String, Vec<String>> =
                std::collections::HashMap::with_capacity(columns.len());

            for col in columns {
                batch_result.insert(col.name().to_string(), Vec::with_capacity(rows.len()));
            }

            for row in &rows {
                for (i, col) in columns.iter().enumerate() {
                    // Best effort: anything that does not decode as a String
                    // is recorded as "" instead of failing the whole batch.
                    let value: Option<String> = row.try_get(i).ok();
                    if let Some(column_data) = batch_result.get_mut(col.name()) {
                        column_data.push(value.unwrap_or_default());
                    }
                }
            }

            let batch_size_actual = rows.len();
            all_batches.push(batch_result);
            progress.update(batch_size_actual as u64);

            if let Some(percentage) = progress.percentage() {
                log::info!(
                    "{} streaming progress: {:.1}% ({}/{} rows)",
                    $db_name,
                    percentage,
                    progress.processed_rows,
                    total_rows
                );
            }

            offset += batch_size;
            // A short page means the result set is exhausted; skip the extra
            // round trip that would only return zero rows.
            if batch_size_actual < batch_size {
                break;
            }
        }

        merge_column_batches(all_batches)
    }};
}
86
/// Macro that converts a collection of sqlx rows into a column-oriented
/// `HashMap<String, Vec<String>>`.
///
/// `$rows` is evaluated exactly once and only borrowed, so it is safe to pass
/// either a variable or an expression that produces the rows.
///
/// * An empty collection yields an empty map.
/// * Otherwise the column names are taken from the first row, and every row
///   contributes one string per column, read by column position.
/// * A value that cannot be read as a `String` is stored as an empty string.
/// * Columns that share a name share a single entry in the map.
#[macro_export]
macro_rules! process_rows_to_columns {
    ($rows:expr) => {{
        use sqlx::{Column, Row};

        // Borrow the argument once; expanding `$rows` at every use would
        // re-run the caller's expression each time.
        let rows = &$rows;

        if rows.is_empty() {
            std::collections::HashMap::new()
        } else {
            // The first row's column metadata describes the whole result set.
            let columns = rows[0].columns();
            let mut result: std::collections::HashMap<String, Vec<String>> =
                std::collections::HashMap::with_capacity(columns.len());

            for col in columns {
                result.insert(col.name().to_string(), Vec::with_capacity(rows.len()));
            }

            for row in rows {
                for (i, col) in columns.iter().enumerate() {
                    // Best effort: anything that does not decode as a String
                    // is recorded as "" instead of aborting the conversion.
                    let value: Option<String> = row.try_get(i).ok();
                    if let Some(column_data) = result.get_mut(col.name()) {
                        column_data.push(value.unwrap_or_default());
                    }
                }
            }

            result
        }
    }};
}
117
118/// Build a count query for a given table or query
119#[allow(dead_code)]
120pub fn build_count_query(query: &str) -> Result<String, DataProfilerError> {
121    if query.trim().to_uppercase().starts_with("SELECT") {
122        let validated_query = validate_base_query(query)?;
123        Ok(format!(
124            "SELECT COUNT(*) FROM ({}) as count_subquery",
125            validated_query
126        ))
127    } else {
128        validate_sql_identifier(query)?;
129        Ok(format!("SELECT COUNT(*) FROM {}", query))
130    }
131}
132
133/// Build a batch query with LIMIT and OFFSET
134#[allow(dead_code)]
135pub fn build_batch_query(
136    query: &str,
137    batch_size: usize,
138    offset: usize,
139) -> Result<String, DataProfilerError> {
140    let validated_query = if query.trim().to_uppercase().starts_with("SELECT") {
141        validate_base_query(query)?
142    } else {
143        validate_sql_identifier(query)?;
144        format!("SELECT * FROM {}", query)
145    };
146    Ok(format!(
147        "{} LIMIT {} OFFSET {}",
148        validated_query, batch_size, offset
149    ))
150}
151
#[cfg(test)]
mod tests {
    use super::{build_batch_query, build_count_query};

    /// A bare table name is counted directly.
    #[test]
    fn test_build_count_query_table() {
        let sql = build_count_query("users").unwrap();
        assert_eq!(sql, "SELECT COUNT(*) FROM users");
    }

    /// A full SELECT statement is wrapped in a counting subquery.
    #[test]
    fn test_build_count_query_select() {
        let base = "SELECT * FROM users WHERE active = true";
        let sql = build_count_query(base).unwrap();
        assert!(sql.contains("SELECT COUNT(*) FROM"));
        assert!(sql.contains("count_subquery"));
    }

    /// A table name is expanded and paged with LIMIT/OFFSET.
    #[test]
    fn test_build_batch_query() {
        let sql = build_batch_query("users", 100, 0).unwrap();
        assert_eq!(sql, "SELECT * FROM users LIMIT 100 OFFSET 0");
    }
}
174}