prax_query/connection/
parser.rs

1//! Connection string parser.
2
3use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5use tracing::debug;
6
/// The database backend that a connection URL targets.
///
/// Resolved from the URL scheme by `Driver::from_scheme`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Driver {
    /// PostgreSQL (`postgres://`, `postgresql://`).
    Postgres,
    /// MySQL or MariaDB (`mysql://`, `mariadb://`).
    MySql,
    /// SQLite, addressed by a file path or `:memory:` (`sqlite://`, `sqlite3://`, `file://`).
    Sqlite,
}
17
18impl Driver {
19    /// Get the default port for this driver.
20    pub fn default_port(&self) -> Option<u16> {
21        match self {
22            Self::Postgres => Some(5432),
23            Self::MySql => Some(3306),
24            Self::Sqlite => None,
25        }
26    }
27
28    /// Get the driver name.
29    pub fn name(&self) -> &'static str {
30        match self {
31            Self::Postgres => "postgres",
32            Self::MySql => "mysql",
33            Self::Sqlite => "sqlite",
34        }
35    }
36
37    /// Parse driver from URL scheme.
38    pub fn from_scheme(scheme: &str) -> ConnectionResult<Self> {
39        match scheme.to_lowercase().as_str() {
40            "postgres" | "postgresql" => Ok(Self::Postgres),
41            "mysql" | "mariadb" => Ok(Self::MySql),
42            "sqlite" | "sqlite3" | "file" => Ok(Self::Sqlite),
43            other => Err(ConnectionError::UnknownDriver(other.to_string())),
44        }
45    }
46}
47
48impl std::fmt::Display for Driver {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(f, "{}", self.name())
51    }
52}
53
54/// A parsed database URL.
55#[derive(Debug, Clone)]
56pub struct ParsedUrl {
57    /// Database driver.
58    pub driver: Driver,
59    /// Username (if any).
60    pub user: Option<String>,
61    /// Password (if any).
62    pub password: Option<String>,
63    /// Host (for network databases).
64    pub host: Option<String>,
65    /// Port (for network databases).
66    pub port: Option<u16>,
67    /// Database name or file path.
68    pub database: Option<String>,
69    /// Query parameters.
70    pub params: HashMap<String, String>,
71}
72
73impl ParsedUrl {
74    /// Check if this is an in-memory SQLite database.
75    pub fn is_memory(&self) -> bool {
76        self.driver == Driver::Sqlite
77            && self.database.as_ref().map_or(false, |d| {
78                d == ":memory:" || d.is_empty()
79            })
80    }
81
82    /// Get a query parameter.
83    pub fn param(&self, key: &str) -> Option<&str> {
84        self.params.get(key).map(|s| s.as_str())
85    }
86
87    /// Convert back to a URL string.
88    pub fn to_url(&self) -> String {
89        let mut url = format!("{}://", self.driver.name());
90
91        // Add credentials
92        if let Some(ref user) = self.user {
93            url.push_str(&url_encode(user));
94            if let Some(ref pass) = self.password {
95                url.push(':');
96                url.push_str(&url_encode(pass));
97            }
98            url.push('@');
99        }
100
101        // Add host/port
102        if let Some(ref host) = self.host {
103            url.push_str(host);
104            if let Some(port) = self.port {
105                url.push(':');
106                url.push_str(&port.to_string());
107            }
108        }
109
110        // Add database
111        if let Some(ref db) = self.database {
112            url.push('/');
113            url.push_str(db);
114        }
115
116        // Add query params
117        if !self.params.is_empty() {
118            url.push('?');
119            let params: Vec<_> = self
120                .params
121                .iter()
122                .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
123                .collect();
124            url.push_str(&params.join("&"));
125        }
126
127        url
128    }
129}
130
131/// Connection string parser.
132#[derive(Debug, Clone)]
133pub struct ConnectionString {
134    parsed: ParsedUrl,
135    original: String,
136}
137
138impl ConnectionString {
139    /// Parse a connection URL.
140    ///
141    /// # Examples
142    ///
143    /// ```rust
144    /// use prax_query::connection::ConnectionString;
145    ///
146    /// // PostgreSQL
147    /// let conn = ConnectionString::parse("postgres://user:pass@localhost:5432/mydb").unwrap();
148    ///
149    /// // MySQL
150    /// let conn = ConnectionString::parse("mysql://user:pass@localhost/mydb").unwrap();
151    ///
152    /// // SQLite
153    /// let conn = ConnectionString::parse("sqlite://./data.db").unwrap();
154    /// let conn = ConnectionString::parse("sqlite::memory:").unwrap();
155    /// ```
156    pub fn parse(url: &str) -> ConnectionResult<Self> {
157        debug!(url_len = url.len(), "ConnectionString::parse()");
158        let original = url.to_string();
159        let parsed = parse_url(url)?;
160        debug!(driver = %parsed.driver, host = ?parsed.host, database = ?parsed.database, "Connection parsed");
161        Ok(Self { parsed, original })
162    }
163
164    /// Parse from environment variable.
165    pub fn from_env(var: &str) -> ConnectionResult<Self> {
166        let url = std::env::var(var)
167            .map_err(|_| ConnectionError::EnvNotFound(var.to_string()))?;
168        Self::parse(&url)
169    }
170
171    /// Parse from DATABASE_URL environment variable.
172    pub fn from_database_url() -> ConnectionResult<Self> {
173        Self::from_env("DATABASE_URL")
174    }
175
176    /// Get the original URL string.
177    pub fn as_str(&self) -> &str {
178        &self.original
179    }
180
181    /// Get the database driver.
182    pub fn driver(&self) -> Driver {
183        self.parsed.driver
184    }
185
186    /// Get the username.
187    pub fn user(&self) -> Option<&str> {
188        self.parsed.user.as_deref()
189    }
190
191    /// Get the password.
192    pub fn password(&self) -> Option<&str> {
193        self.parsed.password.as_deref()
194    }
195
196    /// Get the host.
197    pub fn host(&self) -> Option<&str> {
198        self.parsed.host.as_deref()
199    }
200
201    /// Get the port.
202    pub fn port(&self) -> Option<u16> {
203        self.parsed.port
204    }
205
206    /// Get the port or the default for the driver.
207    pub fn port_or_default(&self) -> Option<u16> {
208        self.parsed.port.or_else(|| self.parsed.driver.default_port())
209    }
210
211    /// Get the database name.
212    pub fn database(&self) -> Option<&str> {
213        self.parsed.database.as_deref()
214    }
215
216    /// Get a query parameter.
217    pub fn param(&self, key: &str) -> Option<&str> {
218        self.parsed.param(key)
219    }
220
221    /// Get all query parameters.
222    pub fn params(&self) -> &HashMap<String, String> {
223        &self.parsed.params
224    }
225
226    /// Get the parsed URL.
227    pub fn parsed(&self) -> &ParsedUrl {
228        &self.parsed
229    }
230
231    /// Check if this is an in-memory SQLite database.
232    pub fn is_memory(&self) -> bool {
233        self.parsed.is_memory()
234    }
235
236    /// Build a new URL with modified parameters.
237    pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
238        self.parsed.params.insert(key.into(), value.into());
239        self.original = self.parsed.to_url();
240        self
241    }
242
243    /// Build a new URL without a specific parameter.
244    pub fn without_param(mut self, key: &str) -> Self {
245        self.parsed.params.remove(key);
246        self.original = self.parsed.to_url();
247        self
248    }
249}
250
251/// Parse a database URL into its components.
252fn parse_url(url: &str) -> ConnectionResult<ParsedUrl> {
253    // Handle SQLite memory shorthand
254    if url == "sqlite::memory:" || url == ":memory:" {
255        return Ok(ParsedUrl {
256            driver: Driver::Sqlite,
257            user: None,
258            password: None,
259            host: None,
260            port: None,
261            database: Some(":memory:".to_string()),
262            params: HashMap::new(),
263        });
264    }
265
266    // Find scheme
267    let (scheme, rest) = url
268        .split_once("://")
269        .ok_or_else(|| ConnectionError::InvalidUrl("Missing scheme (e.g., postgres://)".to_string()))?;
270
271    let driver = Driver::from_scheme(scheme)?;
272
273    // Handle SQLite specially (path-based)
274    if driver == Driver::Sqlite {
275        return parse_sqlite_url(rest);
276    }
277
278    // Parse network URL
279    parse_network_url(driver, rest)
280}
281
282fn parse_sqlite_url(rest: &str) -> ConnectionResult<ParsedUrl> {
283    // Split off query params
284    let (path, params) = parse_query_params(rest);
285
286    let database = if path.is_empty() || path == ":memory:" {
287        Some(":memory:".to_string())
288    } else {
289        Some(url_decode(&path))
290    };
291
292    Ok(ParsedUrl {
293        driver: Driver::Sqlite,
294        user: None,
295        password: None,
296        host: None,
297        port: None,
298        database,
299        params,
300    })
301}
302
303fn parse_network_url(driver: Driver, rest: &str) -> ConnectionResult<ParsedUrl> {
304    // Split off query params
305    let (main, params) = parse_query_params(rest);
306
307    // Split credentials from host
308    let (creds, host_part) = if let Some(at_pos) = main.rfind('@') {
309        (Some(&main[..at_pos]), &main[at_pos + 1..])
310    } else {
311        (None, main.as_str())
312    };
313
314    // Parse credentials
315    let (user, password) = if let Some(creds) = creds {
316        if let Some((u, p)) = creds.split_once(':') {
317            (Some(url_decode(u)), Some(url_decode(p)))
318        } else {
319            (Some(url_decode(creds)), None)
320        }
321    } else {
322        (None, None)
323    };
324
325    // Split host from database
326    let (host_port, database) = if let Some(slash_pos) = host_part.find('/') {
327        (&host_part[..slash_pos], Some(url_decode(&host_part[slash_pos + 1..])))
328    } else {
329        (host_part, None)
330    };
331
332    // Parse host and port
333    let (host, port) = if host_port.is_empty() {
334        (None, None)
335    } else if let Some(colon_pos) = host_port.rfind(':') {
336        // Check if it's IPv6 address [::1]
337        if host_port.starts_with('[') {
338            if let Some(bracket_pos) = host_port.find(']') {
339                if colon_pos > bracket_pos {
340                    // Port after IPv6 address
341                    let port = host_port[colon_pos + 1..]
342                        .parse()
343                        .map_err(|_| ConnectionError::InvalidUrl("Invalid port number".to_string()))?;
344                    (Some(host_port[..colon_pos].to_string()), Some(port))
345                } else {
346                    // No port, just IPv6 address
347                    (Some(host_port.to_string()), None)
348                }
349            } else {
350                return Err(ConnectionError::InvalidUrl("Invalid IPv6 address".to_string()));
351            }
352        } else {
353            // Regular host:port
354            let port = host_port[colon_pos + 1..]
355                .parse()
356                .map_err(|_| ConnectionError::InvalidUrl("Invalid port number".to_string()))?;
357            (Some(host_port[..colon_pos].to_string()), Some(port))
358        }
359    } else {
360        (Some(host_port.to_string()), None)
361    };
362
363    Ok(ParsedUrl {
364        driver,
365        user,
366        password,
367        host,
368        port,
369        database,
370        params,
371    })
372}
373
374fn parse_query_params(input: &str) -> (String, HashMap<String, String>) {
375    if let Some((main, query)) = input.split_once('?') {
376        let params = query
377            .split('&')
378            .filter_map(|pair| {
379                let (key, value) = pair.split_once('=')?;
380                Some((url_decode(key), url_decode(value)))
381            })
382            .collect();
383        (main.to_string(), params)
384    } else {
385        (input.to_string(), HashMap::new())
386    }
387}
388
/// Percent-decode `s` (RFC 3986 `%XX` escapes) into a string.
///
/// - Escapes are decoded to raw bytes and the result is interpreted as UTF-8,
///   so multi-byte sequences such as `%C3%A9` become `é`.
/// - Bytes that do not form valid UTF-8 are replaced with U+FFFD.
/// - A `%` not followed by exactly two hex digits is kept literally.
/// - `+` is *not* turned into a space. That is an HTML-form convention; in a
///   connection URL `+` is an ordinary character, common in generated
///   passwords.
fn url_decode(s: &str) -> String {
    /// Value of one ASCII hex digit, or `None` if `b` is not one.
    fn hex_value(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let hi = bytes.get(i + 1).copied().and_then(hex_value);
            let lo = bytes.get(i + 2).copied().and_then(hex_value);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }

    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}
412
/// Percent-encode `s` for use in a URL component.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) are kept as-is.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
/// Working on bytes avoids allocating a temporary `String` per character and a
/// `format!` call per byte. Non-ASCII characters are always multi-byte, so they
/// are fully escaped.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut result = String::with_capacity(s.len() * 3);
    for &byte in s.as_bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            result.push(char::from(byte));
        } else {
            result.push('%');
            result.push(char::from(HEX[usize::from(byte >> 4)]));
            result.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    result
}
427
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_postgres_full() {
        let conn = ConnectionString::parse("postgres://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), Some("pass"));
        assert_eq!(conn.host(), Some("localhost"));
        assert_eq!(conn.port(), Some(5432));
        assert_eq!(conn.database(), Some("mydb"));
    }

    #[test]
    fn test_parse_postgres_with_params() {
        let conn = ConnectionString::parse(
            "postgres://user:pass@localhost/mydb?sslmode=require&connect_timeout=10"
        ).unwrap();
        assert_eq!(conn.param("sslmode"), Some("require"));
        assert_eq!(conn.param("connect_timeout"), Some("10"));
    }

    #[test]
    fn test_parse_postgres_no_password() {
        let conn = ConnectionString::parse("postgres://user@localhost/mydb").unwrap();
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), None);
    }

    #[test]
    fn test_parse_mysql() {
        let conn = ConnectionString::parse("mysql://root:secret@127.0.0.1:3306/testdb").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
        assert_eq!(conn.host(), Some("127.0.0.1"));
        assert_eq!(conn.port(), Some(3306));
    }

    #[test]
    fn test_parse_mariadb() {
        let conn = ConnectionString::parse("mariadb://user:pass@localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
    }

    #[test]
    fn test_parse_sqlite_file() {
        let conn = ConnectionString::parse("sqlite://./data/app.db").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert_eq!(conn.database(), Some("./data/app.db"));
    }

    #[test]
    fn test_parse_sqlite_memory() {
        let conn = ConnectionString::parse("sqlite::memory:").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert!(conn.is_memory());

        let conn = ConnectionString::parse("sqlite://:memory:").unwrap();
        assert!(conn.is_memory());
    }

    #[test]
    fn test_parse_special_characters() {
        let conn = ConnectionString::parse("postgres://user:p%40ss%3Aword@localhost/db").unwrap();
        assert_eq!(conn.password(), Some("p@ss:word"));
    }

    // Bracketed IPv6 hosts keep their brackets; a port may follow the ']'.
    #[test]
    fn test_parse_ipv6_host() {
        let conn = ConnectionString::parse("postgres://user:pass@[::1]:5433/mydb").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), Some(5433));
        assert_eq!(conn.database(), Some("mydb"));

        let conn = ConnectionString::parse("postgres://[::1]/mydb").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), None);
    }

    #[test]
    fn test_invalid_port() {
        assert!(ConnectionString::parse("postgres://localhost:notaport/db").is_err());
        // Out of range for u16.
        assert!(ConnectionString::parse("postgres://localhost:70000/db").is_err());
    }

    #[test]
    fn test_scheme_is_case_insensitive() {
        let conn = ConnectionString::parse("PostgreSQL://localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
    }

    #[test]
    fn test_default_port() {
        assert_eq!(Driver::Postgres.default_port(), Some(5432));
        assert_eq!(Driver::MySql.default_port(), Some(3306));
        assert_eq!(Driver::Sqlite.default_port(), None);
    }

    #[test]
    fn test_port_or_default() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        assert_eq!(conn.port(), None);
        assert_eq!(conn.port_or_default(), Some(5432));
    }

    #[test]
    fn test_with_param() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        let conn = conn.with_param("sslmode", "require");
        assert_eq!(conn.param("sslmode"), Some("require"));
    }

    // Removing a parameter must also drop it from the regenerated URL text.
    #[test]
    fn test_without_param() {
        let conn = ConnectionString::parse("postgres://localhost/db?sslmode=require").unwrap();
        let conn = conn.without_param("sslmode");
        assert_eq!(conn.param("sslmode"), None);
        assert!(!conn.as_str().contains("sslmode"));
    }

    #[test]
    fn test_to_url_roundtrip() {
        let original = "postgres://user:pass@localhost:5432/mydb?sslmode=require";
        let conn = ConnectionString::parse(original).unwrap();
        let rebuilt = conn.parsed().to_url();
        assert!(rebuilt.contains("postgres://"));
        assert!(rebuilt.contains("localhost:5432"));
        assert!(rebuilt.contains("sslmode=require"));
    }

    // Reserved characters survive an encode/decode round trip.
    #[test]
    fn test_url_encode_decode_roundtrip() {
        let raw = "p@ss:w/rd?";
        assert_eq!(url_decode(&url_encode(raw)), raw);
    }

    #[test]
    fn test_invalid_url() {
        assert!(ConnectionString::parse("not-a-url").is_err());
        assert!(ConnectionString::parse("unknown://localhost").is_err());
    }
}
533
534