prax_query/connection/
parser.rs

1//! Connection string parser.
2
3use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5use tracing::debug;
6
7/// Database driver type.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
9pub enum Driver {
10    /// PostgreSQL
11    Postgres,
12    /// MySQL / MariaDB
13    MySql,
14    /// SQLite
15    Sqlite,
16}
17
18impl Driver {
19    /// Get the default port for this driver.
20    pub fn default_port(&self) -> Option<u16> {
21        match self {
22            Self::Postgres => Some(5432),
23            Self::MySql => Some(3306),
24            Self::Sqlite => None,
25        }
26    }
27
28    /// Get the driver name.
29    pub fn name(&self) -> &'static str {
30        match self {
31            Self::Postgres => "postgres",
32            Self::MySql => "mysql",
33            Self::Sqlite => "sqlite",
34        }
35    }
36
37    /// Parse driver from URL scheme.
38    pub fn from_scheme(scheme: &str) -> ConnectionResult<Self> {
39        match scheme.to_lowercase().as_str() {
40            "postgres" | "postgresql" => Ok(Self::Postgres),
41            "mysql" | "mariadb" => Ok(Self::MySql),
42            "sqlite" | "sqlite3" | "file" => Ok(Self::Sqlite),
43            other => Err(ConnectionError::UnknownDriver(other.to_string())),
44        }
45    }
46}
47
48impl std::fmt::Display for Driver {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(f, "{}", self.name())
51    }
52}
53
54/// A parsed database URL.
55#[derive(Debug, Clone)]
56pub struct ParsedUrl {
57    /// Database driver.
58    pub driver: Driver,
59    /// Username (if any).
60    pub user: Option<String>,
61    /// Password (if any).
62    pub password: Option<String>,
63    /// Host (for network databases).
64    pub host: Option<String>,
65    /// Port (for network databases).
66    pub port: Option<u16>,
67    /// Database name or file path.
68    pub database: Option<String>,
69    /// Query parameters.
70    pub params: HashMap<String, String>,
71}
72
73impl ParsedUrl {
74    /// Check if this is an in-memory SQLite database.
75    pub fn is_memory(&self) -> bool {
76        self.driver == Driver::Sqlite
77            && self
78                .database
79                .as_ref()
80                .is_some_and(|d| d == ":memory:" || d.is_empty())
81    }
82
83    /// Get a query parameter.
84    pub fn param(&self, key: &str) -> Option<&str> {
85        self.params.get(key).map(|s| s.as_str())
86    }
87
88    /// Convert back to a URL string.
89    pub fn to_url(&self) -> String {
90        let mut url = format!("{}://", self.driver.name());
91
92        // Add credentials
93        if let Some(ref user) = self.user {
94            url.push_str(&url_encode(user));
95            if let Some(ref pass) = self.password {
96                url.push(':');
97                url.push_str(&url_encode(pass));
98            }
99            url.push('@');
100        }
101
102        // Add host/port
103        if let Some(ref host) = self.host {
104            url.push_str(host);
105            if let Some(port) = self.port {
106                url.push(':');
107                url.push_str(&port.to_string());
108            }
109        }
110
111        // Add database
112        if let Some(ref db) = self.database {
113            url.push('/');
114            url.push_str(db);
115        }
116
117        // Add query params
118        if !self.params.is_empty() {
119            url.push('?');
120            let params: Vec<_> = self
121                .params
122                .iter()
123                .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
124                .collect();
125            url.push_str(&params.join("&"));
126        }
127
128        url
129    }
130}
131
132/// Connection string parser.
133#[derive(Debug, Clone)]
134pub struct ConnectionString {
135    parsed: ParsedUrl,
136    original: String,
137}
138
139impl ConnectionString {
140    /// Parse a connection URL.
141    ///
142    /// # Examples
143    ///
144    /// ```rust
145    /// use prax_query::connection::ConnectionString;
146    ///
147    /// // PostgreSQL
148    /// let conn = ConnectionString::parse("postgres://user:pass@localhost:5432/mydb").unwrap();
149    ///
150    /// // MySQL
151    /// let conn = ConnectionString::parse("mysql://user:pass@localhost/mydb").unwrap();
152    ///
153    /// // SQLite
154    /// let conn = ConnectionString::parse("sqlite://./data.db").unwrap();
155    /// let conn = ConnectionString::parse("sqlite::memory:").unwrap();
156    /// ```
157    pub fn parse(url: &str) -> ConnectionResult<Self> {
158        debug!(url_len = url.len(), "ConnectionString::parse()");
159        let original = url.to_string();
160        let parsed = parse_url(url)?;
161        debug!(driver = %parsed.driver, host = ?parsed.host, database = ?parsed.database, "Connection parsed");
162        Ok(Self { parsed, original })
163    }
164
165    /// Parse from environment variable.
166    pub fn from_env(var: &str) -> ConnectionResult<Self> {
167        let url = std::env::var(var).map_err(|_| ConnectionError::EnvNotFound(var.to_string()))?;
168        Self::parse(&url)
169    }
170
171    /// Parse from DATABASE_URL environment variable.
172    pub fn from_database_url() -> ConnectionResult<Self> {
173        Self::from_env("DATABASE_URL")
174    }
175
176    /// Get the original URL string.
177    pub fn as_str(&self) -> &str {
178        &self.original
179    }
180
181    /// Get the database driver.
182    pub fn driver(&self) -> Driver {
183        self.parsed.driver
184    }
185
186    /// Get the username.
187    pub fn user(&self) -> Option<&str> {
188        self.parsed.user.as_deref()
189    }
190
191    /// Get the password.
192    pub fn password(&self) -> Option<&str> {
193        self.parsed.password.as_deref()
194    }
195
196    /// Get the host.
197    pub fn host(&self) -> Option<&str> {
198        self.parsed.host.as_deref()
199    }
200
201    /// Get the port.
202    pub fn port(&self) -> Option<u16> {
203        self.parsed.port
204    }
205
206    /// Get the port or the default for the driver.
207    pub fn port_or_default(&self) -> Option<u16> {
208        self.parsed
209            .port
210            .or_else(|| self.parsed.driver.default_port())
211    }
212
213    /// Get the database name.
214    pub fn database(&self) -> Option<&str> {
215        self.parsed.database.as_deref()
216    }
217
218    /// Get a query parameter.
219    pub fn param(&self, key: &str) -> Option<&str> {
220        self.parsed.param(key)
221    }
222
223    /// Get all query parameters.
224    pub fn params(&self) -> &HashMap<String, String> {
225        &self.parsed.params
226    }
227
228    /// Get the parsed URL.
229    pub fn parsed(&self) -> &ParsedUrl {
230        &self.parsed
231    }
232
233    /// Check if this is an in-memory SQLite database.
234    pub fn is_memory(&self) -> bool {
235        self.parsed.is_memory()
236    }
237
238    /// Build a new URL with modified parameters.
239    pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
240        self.parsed.params.insert(key.into(), value.into());
241        self.original = self.parsed.to_url();
242        self
243    }
244
245    /// Build a new URL without a specific parameter.
246    pub fn without_param(mut self, key: &str) -> Self {
247        self.parsed.params.remove(key);
248        self.original = self.parsed.to_url();
249        self
250    }
251}
252
253/// Parse a database URL into its components.
254fn parse_url(url: &str) -> ConnectionResult<ParsedUrl> {
255    // Handle SQLite memory shorthand
256    if url == "sqlite::memory:" || url == ":memory:" {
257        return Ok(ParsedUrl {
258            driver: Driver::Sqlite,
259            user: None,
260            password: None,
261            host: None,
262            port: None,
263            database: Some(":memory:".to_string()),
264            params: HashMap::new(),
265        });
266    }
267
268    // Find scheme
269    let (scheme, rest) = url.split_once("://").ok_or_else(|| {
270        ConnectionError::InvalidUrl("Missing scheme (e.g., postgres://)".to_string())
271    })?;
272
273    let driver = Driver::from_scheme(scheme)?;
274
275    // Handle SQLite specially (path-based)
276    if driver == Driver::Sqlite {
277        return parse_sqlite_url(rest);
278    }
279
280    // Parse network URL
281    parse_network_url(driver, rest)
282}
283
284fn parse_sqlite_url(rest: &str) -> ConnectionResult<ParsedUrl> {
285    // Split off query params
286    let (path, params) = parse_query_params(rest);
287
288    let database = if path.is_empty() || path == ":memory:" {
289        Some(":memory:".to_string())
290    } else {
291        Some(url_decode(&path))
292    };
293
294    Ok(ParsedUrl {
295        driver: Driver::Sqlite,
296        user: None,
297        password: None,
298        host: None,
299        port: None,
300        database,
301        params,
302    })
303}
304
305fn parse_network_url(driver: Driver, rest: &str) -> ConnectionResult<ParsedUrl> {
306    // Split off query params
307    let (main, params) = parse_query_params(rest);
308
309    // Split credentials from host
310    let (creds, host_part) = if let Some(at_pos) = main.rfind('@') {
311        (Some(&main[..at_pos]), &main[at_pos + 1..])
312    } else {
313        (None, main.as_str())
314    };
315
316    // Parse credentials
317    let (user, password) = if let Some(creds) = creds {
318        if let Some((u, p)) = creds.split_once(':') {
319            (Some(url_decode(u)), Some(url_decode(p)))
320        } else {
321            (Some(url_decode(creds)), None)
322        }
323    } else {
324        (None, None)
325    };
326
327    // Split host from database
328    let (host_port, database) = if let Some(slash_pos) = host_part.find('/') {
329        (
330            &host_part[..slash_pos],
331            Some(url_decode(&host_part[slash_pos + 1..])),
332        )
333    } else {
334        (host_part, None)
335    };
336
337    // Parse host and port
338    let (host, port) = if host_port.is_empty() {
339        (None, None)
340    } else if let Some(colon_pos) = host_port.rfind(':') {
341        // Check if it's IPv6 address [::1]
342        if host_port.starts_with('[') {
343            if let Some(bracket_pos) = host_port.find(']') {
344                if colon_pos > bracket_pos {
345                    // Port after IPv6 address
346                    let port = host_port[colon_pos + 1..].parse().map_err(|_| {
347                        ConnectionError::InvalidUrl("Invalid port number".to_string())
348                    })?;
349                    (Some(host_port[..colon_pos].to_string()), Some(port))
350                } else {
351                    // No port, just IPv6 address
352                    (Some(host_port.to_string()), None)
353                }
354            } else {
355                return Err(ConnectionError::InvalidUrl(
356                    "Invalid IPv6 address".to_string(),
357                ));
358            }
359        } else {
360            // Regular host:port
361            let port = host_port[colon_pos + 1..]
362                .parse()
363                .map_err(|_| ConnectionError::InvalidUrl("Invalid port number".to_string()))?;
364            (Some(host_port[..colon_pos].to_string()), Some(port))
365        }
366    } else {
367        (Some(host_port.to_string()), None)
368    };
369
370    Ok(ParsedUrl {
371        driver,
372        user,
373        password,
374        host,
375        port,
376        database,
377        params,
378    })
379}
380
381fn parse_query_params(input: &str) -> (String, HashMap<String, String>) {
382    if let Some((main, query)) = input.split_once('?') {
383        let params = query
384            .split('&')
385            .filter_map(|pair| {
386                let (key, value) = pair.split_once('=')?;
387                Some((url_decode(key), url_decode(value)))
388            })
389            .collect();
390        (main.to_string(), params)
391    } else {
392        (input.to_string(), HashMap::new())
393    }
394}
395
/// Percent-decodes `s` and maps `+` to a space.
///
/// Escapes are first collected as raw bytes, and the result is then read as
/// UTF-8. A multi-byte sequence such as `%C3%A9` therefore decodes to `é`
/// rather than to two Latin-1 characters, which makes this function undo
/// `url_encode`. A `%` that is not followed by exactly two hex digits is kept
/// literally. Byte sequences that are not valid UTF-8 become U+FFFD.
fn url_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let hex_digit = |at: usize| bytes.get(at).and_then(|&b| char::from(b).to_digit(16));

    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                if let (Some(hi), Some(lo)) = (hex_digit(i + 1), hex_digit(i + 2)) {
                    // Two hex digits are at most 0xFF, so the cast cannot truncate.
                    decoded.push((hi * 16 + lo) as u8);
                    i += 3;
                    continue;
                }
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
419
/// Percent-encodes `s` for use in a URL component.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) are kept as they
/// are. Every other character is written as `%XX` escapes (uppercase hex) of
/// its UTF-8 bytes.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut result = String::with_capacity(s.len() * 3);
    // Working on bytes is equivalent to encoding each char's UTF-8 bytes: all
    // bytes of a multi-byte character are >= 0x80 and therefore escaped,
    // and this avoids allocating a String per character and per escape.
    for byte in s.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                result.push(char::from(byte));
            }
            _ => {
                result.push('%');
                result.push(char::from(HEX[usize::from(byte >> 4)]));
                result.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    result
}
434
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_postgres_full() {
        let conn = ConnectionString::parse("postgres://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), Some("pass"));
        assert_eq!(conn.host(), Some("localhost"));
        assert_eq!(conn.port(), Some(5432));
        assert_eq!(conn.database(), Some("mydb"));
    }

    #[test]
    fn test_parse_postgres_with_params() {
        let conn = ConnectionString::parse(
            "postgres://user:pass@localhost/mydb?sslmode=require&connect_timeout=10",
        )
        .unwrap();
        assert_eq!(conn.param("sslmode"), Some("require"));
        assert_eq!(conn.param("connect_timeout"), Some("10"));
    }

    #[test]
    fn test_parse_postgres_no_password() {
        let conn = ConnectionString::parse("postgres://user@localhost/mydb").unwrap();
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), None);
    }

    #[test]
    fn test_parse_scheme_alias_and_case() {
        let conn = ConnectionString::parse("PostgreSQL://localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
    }

    #[test]
    fn test_parse_mysql() {
        let conn = ConnectionString::parse("mysql://root:secret@127.0.0.1:3306/testdb").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
        assert_eq!(conn.host(), Some("127.0.0.1"));
        assert_eq!(conn.port(), Some(3306));
    }

    #[test]
    fn test_parse_mariadb() {
        let conn = ConnectionString::parse("mariadb://user:pass@localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
    }

    #[test]
    fn test_parse_ipv6_host() {
        let conn = ConnectionString::parse("postgres://user:pass@[::1]:5433/db").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), Some(5433));
        assert_eq!(conn.database(), Some("db"));

        let conn = ConnectionString::parse("postgres://[::1]/db").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), None);
    }

    #[test]
    fn test_parse_sqlite_file() {
        let conn = ConnectionString::parse("sqlite://./data/app.db").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert_eq!(conn.database(), Some("./data/app.db"));
        assert!(!conn.is_memory());
        assert_eq!(conn.port_or_default(), None);
    }

    #[test]
    fn test_parse_sqlite_memory() {
        let conn = ConnectionString::parse("sqlite::memory:").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert!(conn.is_memory());

        let conn = ConnectionString::parse("sqlite://:memory:").unwrap();
        assert!(conn.is_memory());
    }

    #[test]
    fn test_parse_sqlite_memory_with_params() {
        let conn = ConnectionString::parse("sqlite://:memory:?cache=shared").unwrap();
        assert!(conn.is_memory());
        assert_eq!(conn.param("cache"), Some("shared"));
    }

    #[test]
    fn test_parse_special_characters() {
        let conn = ConnectionString::parse("postgres://user:p%40ss%3Aword@localhost/db").unwrap();
        assert_eq!(conn.password(), Some("p@ss:word"));
    }

    #[test]
    fn test_query_value_containing_equals() {
        let conn =
            ConnectionString::parse("postgres://localhost/db?options=a%3Db&x=1=2").unwrap();
        assert_eq!(conn.param("options"), Some("a=b"));
        assert_eq!(conn.param("x"), Some("1=2"));
    }

    #[test]
    fn test_default_port() {
        assert_eq!(Driver::Postgres.default_port(), Some(5432));
        assert_eq!(Driver::MySql.default_port(), Some(3306));
        assert_eq!(Driver::Sqlite.default_port(), None);
    }

    #[test]
    fn test_port_or_default() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        assert_eq!(conn.port(), None);
        assert_eq!(conn.port_or_default(), Some(5432));
    }

    #[test]
    fn test_with_param() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        let conn = conn.with_param("sslmode", "require");
        assert_eq!(conn.param("sslmode"), Some("require"));
    }

    #[test]
    fn test_without_param() {
        let conn = ConnectionString::parse("postgres://localhost/db?sslmode=require").unwrap();
        let conn = conn.without_param("sslmode");
        assert_eq!(conn.param("sslmode"), None);
        assert!(!conn.as_str().contains("sslmode"));
    }

    #[test]
    fn test_to_url_roundtrip() {
        let original = "postgres://user:pass@localhost:5432/mydb?sslmode=require";
        let conn = ConnectionString::parse(original).unwrap();
        let rebuilt = conn.parsed().to_url();
        assert_eq!(rebuilt, original);

        let reparsed = ConnectionString::parse(&rebuilt).unwrap();
        assert_eq!(reparsed.driver(), conn.driver());
        assert_eq!(reparsed.user(), conn.user());
        assert_eq!(reparsed.password(), conn.password());
        assert_eq!(reparsed.host(), conn.host());
        assert_eq!(reparsed.port(), conn.port());
        assert_eq!(reparsed.database(), conn.database());
        assert_eq!(reparsed.params(), conn.params());
    }

    #[test]
    fn test_invalid_url() {
        assert!(ConnectionString::parse("not-a-url").is_err());
        assert!(ConnectionString::parse("unknown://localhost").is_err());
    }

    #[test]
    fn test_invalid_port_and_host() {
        assert!(ConnectionString::parse("postgres://localhost:abc/db").is_err());
        assert!(ConnectionString::parse("postgres://localhost:70000/db").is_err());
        assert!(ConnectionString::parse("postgres://[::1:5432/db").is_err());
    }
}