Skip to main content

dataprof_db/
connection.rs

1//! Connection management and utilities for database connectors
2
3use crate::DataProfilerError;
4use url::Url;
5
/// Components of a database connection string, as produced by [`ConnectionInfo::parse`].
///
/// URL-style inputs (`scheme://...`) fill in the network fields; an input without `://` is
/// treated as a bare file path, giving scheme `"file"` with only `path` set.
///
/// `Debug` is implemented by hand (rather than derived) so that the password is redacted.
#[derive(Clone)]
pub struct ConnectionInfo {
    /// URL scheme, e.g. `postgresql`, `mysql`, `sqlite`; `file` for bare paths.
    pub scheme: String,
    /// Host name or IP address, if the URL has one.
    pub host: Option<String>,
    /// Port number, if the URL specifies one.
    pub port: Option<u16>,
    /// User name; `None` when the URL has none. Taken from the URL without percent-decoding.
    pub username: Option<String>,
    /// Password, if present. Taken from the URL without percent-decoding.
    pub password: Option<String>,
    /// URL path with its leading `/` characters removed (the database name for server
    /// databases); `None` when the path is empty or just `/`.
    pub database: Option<String>,
    /// Filesystem path: the raw input for bare paths, or the URL path for `file://` URLs.
    pub path: Option<String>,
    /// Decoded query-string parameters.
    pub query_params: std::collections::HashMap<String, String>,
}
18
19impl std::fmt::Debug for ConnectionInfo {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        let redacted_password: Option<&'static str> = match self.password {
22            Some(_) => Some("<REDACTED>"),
23            None => None,
24        };
25        f.debug_struct("ConnectionInfo")
26            .field("scheme", &self.scheme)
27            .field("host", &self.host)
28            .field("port", &self.port)
29            .field("username", &self.username)
30            .field("password", &redacted_password)
31            .field("database", &self.database)
32            .field("path", &self.path)
33            .field("query_params", &self.query_params)
34            .finish()
35    }
36}
37
38impl ConnectionInfo {
39    /// Parse a connection string into its components
40    pub fn parse(connection_string: &str) -> Result<Self, DataProfilerError> {
41        if !connection_string.contains("://") {
42            return Ok(ConnectionInfo {
43                scheme: "file".to_string(),
44                host: None,
45                port: None,
46                username: None,
47                password: None,
48                database: None,
49                path: Some(connection_string.to_string()),
50                query_params: std::collections::HashMap::new(),
51            });
52        }
53
54        let url =
55            Url::parse(connection_string).map_err(|e| DataProfilerError::DatabaseConfigError {
56                message: format!("Invalid connection string: {}", e),
57            })?;
58
59        let mut query_params = std::collections::HashMap::new();
60        for (key, value) in url.query_pairs() {
61            query_params.insert(key.to_string(), value.to_string());
62        }
63
64        Ok(ConnectionInfo {
65            scheme: url.scheme().to_string(),
66            host: url.host_str().map(|s| s.to_string()),
67            port: url.port(),
68            username: if url.username().is_empty() {
69                None
70            } else {
71                Some(url.username().to_string())
72            },
73            password: url.password().map(|s| s.to_string()),
74            database: if url.path().len() > 1 {
75                Some(url.path().trim_start_matches('/').to_string())
76            } else {
77                None
78            },
79            path: if url.scheme() == "file" {
80                Some(url.path().to_string())
81            } else {
82                None
83            },
84            query_params,
85        })
86    }
87
88    /// Get the database type from the scheme
89    pub fn database_type(&self) -> &str {
90        match self.scheme.as_str() {
91            "postgresql" | "postgres" => "postgresql",
92            "mysql" => "mysql",
93            "sqlite" | "file" => "sqlite",
94            _ => "unknown",
95        }
96    }
97
98    /// Build a connection string for specific database libraries
99    pub fn to_connection_string(&self, target_format: &str) -> String {
100        match target_format {
101            "sqlx" => match self.scheme.as_str() {
102                "postgresql" | "postgres" => {
103                    let mut parts = vec![format!("{}://", self.scheme)];
104
105                    if let (Some(user), Some(pass)) = (&self.username, &self.password) {
106                        parts.push(format!("{}:{}@", user, pass));
107                    } else if let Some(user) = &self.username {
108                        parts.push(format!("{}@", user));
109                    }
110
111                    if let Some(host) = &self.host {
112                        parts.push(host.clone());
113                        if let Some(port) = self.port {
114                            parts.push(format!(":{}", port));
115                        }
116                    }
117
118                    if let Some(db) = &self.database {
119                        parts.push(format!("/{}", db));
120                    }
121
122                    parts.join("")
123                }
124                "mysql" => {
125                    let mut parts = vec!["mysql://".to_string()];
126
127                    if let (Some(user), Some(pass)) = (&self.username, &self.password) {
128                        parts.push(format!("{}:{}@", user, pass));
129                    } else if let Some(user) = &self.username {
130                        parts.push(format!("{}@", user));
131                    }
132
133                    if let Some(host) = &self.host {
134                        parts.push(host.clone());
135                        if let Some(port) = self.port {
136                            parts.push(format!(":{}", port));
137                        }
138                    }
139
140                    if let Some(db) = &self.database {
141                        parts.push(format!("/{}", db));
142                    }
143
144                    parts.join("")
145                }
146                "sqlite" | "file" => {
147                    if let Some(path) = &self.path {
148                        format!("sqlite://{}", path)
149                    } else {
150                        "sqlite://memory:".to_string()
151                    }
152                }
153                _ => self.to_original_string(),
154            },
155            _ => self.to_original_string(),
156        }
157    }
158
159    /// Reconstruct the original connection string
160    pub fn to_original_string(&self) -> String {
161        if let Some(path) = &self.path {
162            return path.clone();
163        }
164
165        let mut parts = vec![format!("{}://", self.scheme)];
166
167        if let (Some(user), Some(pass)) = (&self.username, &self.password) {
168            parts.push(format!("{}:{}@", user, pass));
169        } else if let Some(user) = &self.username {
170            parts.push(format!("{}@", user));
171        }
172
173        if let Some(host) = &self.host {
174            parts.push(host.clone());
175            if let Some(port) = self.port {
176                parts.push(format!(":{}", port));
177            }
178        }
179
180        if let Some(db) = &self.database {
181            parts.push(format!("/{}", db));
182        }
183
184        if !self.query_params.is_empty() {
185            let query: Vec<String> = self
186                .query_params
187                .iter()
188                .map(|(k, v)| format!("{}={}", k, v))
189                .collect();
190            parts.push(format!("?{}", query.join("&")));
191        }
192
193        parts.join("")
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `conn_str`, failing the test if it is rejected.
    fn parse_ok(conn_str: &str) -> ConnectionInfo {
        ConnectionInfo::parse(conn_str).expect("Failed to parse connection string")
    }

    #[test]
    fn test_parse_postgresql_connection() {
        let info = parse_ok("postgresql://user:pass@localhost:5432/mydb");

        assert_eq!(info.scheme, "postgresql");
        assert_eq!(info.host, Some("localhost".to_string()));
        assert_eq!(info.port, Some(5432));
        assert_eq!(info.username, Some("user".to_string()));
        assert_eq!(info.password, Some("pass".to_string()));
        assert_eq!(info.database, Some("mydb".to_string()));
        assert_eq!(info.database_type(), "postgresql");
    }

    #[test]
    fn test_parse_mysql_connection() {
        let info = parse_ok("mysql://root:password@127.0.0.1:3306/testdb");

        assert_eq!(info.scheme, "mysql");
        assert_eq!(info.host, Some("127.0.0.1".to_string()));
        assert_eq!(info.port, Some(3306));
        assert_eq!(info.username, Some("root".to_string()));
        assert_eq!(info.password, Some("password".to_string()));
        assert_eq!(info.database, Some("testdb".to_string()));
        assert_eq!(info.database_type(), "mysql");
    }

    #[test]
    fn test_parse_sqlite_connection() {
        let info = parse_ok("sqlite:///path/to/db.sqlite");

        assert_eq!(info.scheme, "sqlite");
        assert_eq!(info.database_type(), "sqlite");
    }

    #[test]
    fn test_parse_file_path() {
        let info = parse_ok("/path/to/database.db");

        assert_eq!(info.scheme, "file");
        assert_eq!(info.path, Some("/path/to/database.db".to_string()));
        assert_eq!(info.database_type(), "sqlite");
    }

    #[test]
    fn test_postgres_alias_and_unknown_scheme() {
        assert_eq!(parse_ok("postgres://localhost/db").database_type(), "postgresql");
        assert_eq!(parse_ok("redis://localhost:6379").database_type(), "unknown");
    }

    #[test]
    fn test_invalid_url_is_rejected() {
        // Port exceeds u16::MAX, so URL parsing must fail.
        assert!(ConnectionInfo::parse("postgresql://localhost:99999/db").is_err());
    }

    #[test]
    fn test_debug_redacts_password() {
        let rendered = format!("{:?}", parse_ok("postgresql://user:s3cret@localhost/db"));

        assert!(rendered.contains("<REDACTED>"));
        assert!(!rendered.contains("s3cret"));
    }

    #[test]
    fn test_sqlx_connection_strings_for_servers() {
        let pg = parse_ok("postgresql://user:pass@localhost:5432/mydb");
        assert_eq!(
            pg.to_connection_string("sqlx"),
            "postgresql://user:pass@localhost:5432/mydb"
        );

        let my = parse_ok("mysql://root:password@127.0.0.1:3306/testdb");
        assert_eq!(
            my.to_connection_string("sqlx"),
            "mysql://root:password@127.0.0.1:3306/testdb"
        );
    }

    #[test]
    fn test_file_path_round_trip() {
        let info = parse_ok("/path/to/database.db");

        assert_eq!(info.to_original_string(), "/path/to/database.db");
        assert_eq!(info.to_connection_string("sqlx"), "sqlite:///path/to/database.db");
        assert_eq!(info.to_connection_string("other"), "/path/to/database.db");
    }
}