1use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5use tracing::debug;
6
/// Database backend selected by the scheme of a connection URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Driver {
    /// PostgreSQL, selected by the `postgres` and `postgresql` schemes.
    Postgres,
    /// MySQL-compatible servers, selected by the `mysql` and `mariadb` schemes.
    MySql,
    /// SQLite, selected by the `sqlite`, `sqlite3` and `file` schemes and by the
    /// `sqlite::memory:` / `:memory:` shorthands.
    Sqlite,
}
17
18impl Driver {
19 pub fn default_port(&self) -> Option<u16> {
21 match self {
22 Self::Postgres => Some(5432),
23 Self::MySql => Some(3306),
24 Self::Sqlite => None,
25 }
26 }
27
28 pub fn name(&self) -> &'static str {
30 match self {
31 Self::Postgres => "postgres",
32 Self::MySql => "mysql",
33 Self::Sqlite => "sqlite",
34 }
35 }
36
37 pub fn from_scheme(scheme: &str) -> ConnectionResult<Self> {
39 match scheme.to_lowercase().as_str() {
40 "postgres" | "postgresql" => Ok(Self::Postgres),
41 "mysql" | "mariadb" => Ok(Self::MySql),
42 "sqlite" | "sqlite3" | "file" => Ok(Self::Sqlite),
43 other => Err(ConnectionError::UnknownDriver(other.to_string())),
44 }
45 }
46}
47
48impl std::fmt::Display for Driver {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}", self.name())
51 }
52}
53
/// Individual components of a parsed connection URL.
///
/// User, password, database and query parameters are stored percent-decoded.
/// The host is stored verbatim.
// NOTE(review): the derived `Debug` output includes `password` in clear text;
// consider a redacting `Debug` impl if these values can end up in logs.
#[derive(Debug, Clone)]
pub struct ParsedUrl {
    /// Backend selected by the URL scheme.
    pub driver: Driver,
    /// User name from the `user[:password]@` section; always `None` for SQLite.
    pub user: Option<String>,
    /// Password from the `user:password@` section, if one was given.
    pub password: Option<String>,
    /// Host name; IPv6 literals keep their brackets (e.g. `[::1]`).
    /// Always `None` for SQLite.
    pub host: Option<String>,
    /// Port written explicitly in the URL; `Driver::default_port` gives the
    /// fallback.
    pub port: Option<u16>,
    /// Database name for network drivers, or the file path / `:memory:` for
    /// SQLite.
    pub database: Option<String>,
    /// Query-string parameters; if a key is repeated, the last value wins.
    pub params: HashMap<String, String>,
}
72
73impl ParsedUrl {
74 pub fn is_memory(&self) -> bool {
76 self.driver == Driver::Sqlite
77 && self.database.as_ref().map_or(false, |d| {
78 d == ":memory:" || d.is_empty()
79 })
80 }
81
82 pub fn param(&self, key: &str) -> Option<&str> {
84 self.params.get(key).map(|s| s.as_str())
85 }
86
87 pub fn to_url(&self) -> String {
89 let mut url = format!("{}://", self.driver.name());
90
91 if let Some(ref user) = self.user {
93 url.push_str(&url_encode(user));
94 if let Some(ref pass) = self.password {
95 url.push(':');
96 url.push_str(&url_encode(pass));
97 }
98 url.push('@');
99 }
100
101 if let Some(ref host) = self.host {
103 url.push_str(host);
104 if let Some(port) = self.port {
105 url.push(':');
106 url.push_str(&port.to_string());
107 }
108 }
109
110 if let Some(ref db) = self.database {
112 url.push('/');
113 url.push_str(db);
114 }
115
116 if !self.params.is_empty() {
118 url.push('?');
119 let params: Vec<_> = self
120 .params
121 .iter()
122 .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
123 .collect();
124 url.push_str(¶ms.join("&"));
125 }
126
127 url
128 }
129}
130
/// A parsed database connection string.
///
/// Stores the decoded components together with the URL text returned by
/// `as_str`.
// NOTE(review): the derived `Debug` prints both `original` and the parsed
// password in clear text; consider a redacting `Debug` impl.
#[derive(Debug, Clone)]
pub struct ConnectionString {
    /// Decoded components of the URL.
    parsed: ParsedUrl,
    /// The URL text passed to `parse`, or the re-serialised form
    /// (`ParsedUrl::to_url`) after `with_param` / `without_param`.
    original: String,
}
137
138impl ConnectionString {
139 pub fn parse(url: &str) -> ConnectionResult<Self> {
157 debug!(url_len = url.len(), "ConnectionString::parse()");
158 let original = url.to_string();
159 let parsed = parse_url(url)?;
160 debug!(driver = %parsed.driver, host = ?parsed.host, database = ?parsed.database, "Connection parsed");
161 Ok(Self { parsed, original })
162 }
163
164 pub fn from_env(var: &str) -> ConnectionResult<Self> {
166 let url = std::env::var(var)
167 .map_err(|_| ConnectionError::EnvNotFound(var.to_string()))?;
168 Self::parse(&url)
169 }
170
171 pub fn from_database_url() -> ConnectionResult<Self> {
173 Self::from_env("DATABASE_URL")
174 }
175
176 pub fn as_str(&self) -> &str {
178 &self.original
179 }
180
181 pub fn driver(&self) -> Driver {
183 self.parsed.driver
184 }
185
186 pub fn user(&self) -> Option<&str> {
188 self.parsed.user.as_deref()
189 }
190
191 pub fn password(&self) -> Option<&str> {
193 self.parsed.password.as_deref()
194 }
195
196 pub fn host(&self) -> Option<&str> {
198 self.parsed.host.as_deref()
199 }
200
201 pub fn port(&self) -> Option<u16> {
203 self.parsed.port
204 }
205
206 pub fn port_or_default(&self) -> Option<u16> {
208 self.parsed.port.or_else(|| self.parsed.driver.default_port())
209 }
210
211 pub fn database(&self) -> Option<&str> {
213 self.parsed.database.as_deref()
214 }
215
216 pub fn param(&self, key: &str) -> Option<&str> {
218 self.parsed.param(key)
219 }
220
221 pub fn params(&self) -> &HashMap<String, String> {
223 &self.parsed.params
224 }
225
226 pub fn parsed(&self) -> &ParsedUrl {
228 &self.parsed
229 }
230
231 pub fn is_memory(&self) -> bool {
233 self.parsed.is_memory()
234 }
235
236 pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
238 self.parsed.params.insert(key.into(), value.into());
239 self.original = self.parsed.to_url();
240 self
241 }
242
243 pub fn without_param(mut self, key: &str) -> Self {
245 self.parsed.params.remove(key);
246 self.original = self.parsed.to_url();
247 self
248 }
249}
250
251fn parse_url(url: &str) -> ConnectionResult<ParsedUrl> {
253 if url == "sqlite::memory:" || url == ":memory:" {
255 return Ok(ParsedUrl {
256 driver: Driver::Sqlite,
257 user: None,
258 password: None,
259 host: None,
260 port: None,
261 database: Some(":memory:".to_string()),
262 params: HashMap::new(),
263 });
264 }
265
266 let (scheme, rest) = url
268 .split_once("://")
269 .ok_or_else(|| ConnectionError::InvalidUrl("Missing scheme (e.g., postgres://)".to_string()))?;
270
271 let driver = Driver::from_scheme(scheme)?;
272
273 if driver == Driver::Sqlite {
275 return parse_sqlite_url(rest);
276 }
277
278 parse_network_url(driver, rest)
280}
281
282fn parse_sqlite_url(rest: &str) -> ConnectionResult<ParsedUrl> {
283 let (path, params) = parse_query_params(rest);
285
286 let database = if path.is_empty() || path == ":memory:" {
287 Some(":memory:".to_string())
288 } else {
289 Some(url_decode(&path))
290 };
291
292 Ok(ParsedUrl {
293 driver: Driver::Sqlite,
294 user: None,
295 password: None,
296 host: None,
297 port: None,
298 database,
299 params,
300 })
301}
302
303fn parse_network_url(driver: Driver, rest: &str) -> ConnectionResult<ParsedUrl> {
304 let (main, params) = parse_query_params(rest);
306
307 let (creds, host_part) = if let Some(at_pos) = main.rfind('@') {
309 (Some(&main[..at_pos]), &main[at_pos + 1..])
310 } else {
311 (None, main.as_str())
312 };
313
314 let (user, password) = if let Some(creds) = creds {
316 if let Some((u, p)) = creds.split_once(':') {
317 (Some(url_decode(u)), Some(url_decode(p)))
318 } else {
319 (Some(url_decode(creds)), None)
320 }
321 } else {
322 (None, None)
323 };
324
325 let (host_port, database) = if let Some(slash_pos) = host_part.find('/') {
327 (&host_part[..slash_pos], Some(url_decode(&host_part[slash_pos + 1..])))
328 } else {
329 (host_part, None)
330 };
331
332 let (host, port) = if host_port.is_empty() {
334 (None, None)
335 } else if let Some(colon_pos) = host_port.rfind(':') {
336 if host_port.starts_with('[') {
338 if let Some(bracket_pos) = host_port.find(']') {
339 if colon_pos > bracket_pos {
340 let port = host_port[colon_pos + 1..]
342 .parse()
343 .map_err(|_| ConnectionError::InvalidUrl("Invalid port number".to_string()))?;
344 (Some(host_port[..colon_pos].to_string()), Some(port))
345 } else {
346 (Some(host_port.to_string()), None)
348 }
349 } else {
350 return Err(ConnectionError::InvalidUrl("Invalid IPv6 address".to_string()));
351 }
352 } else {
353 let port = host_port[colon_pos + 1..]
355 .parse()
356 .map_err(|_| ConnectionError::InvalidUrl("Invalid port number".to_string()))?;
357 (Some(host_port[..colon_pos].to_string()), Some(port))
358 }
359 } else {
360 (Some(host_port.to_string()), None)
361 };
362
363 Ok(ParsedUrl {
364 driver,
365 user,
366 password,
367 host,
368 port,
369 database,
370 params,
371 })
372}
373
374fn parse_query_params(input: &str) -> (String, HashMap<String, String>) {
375 if let Some((main, query)) = input.split_once('?') {
376 let params = query
377 .split('&')
378 .filter_map(|pair| {
379 let (key, value) = pair.split_once('=')?;
380 Some((url_decode(key), url_decode(value)))
381 })
382 .collect();
383 (main.to_string(), params)
384 } else {
385 (input.to_string(), HashMap::new())
386 }
387}
388
/// Percent-decodes a URL component.
///
/// - Each `%XX` escape (two hex digits, either case) becomes one raw byte, and
///   the resulting bytes are read as UTF-8, so `%C3%A9` decodes to `é`.
/// - Byte sequences that are not valid UTF-8 become U+FFFD.
/// - A `%` not followed by two hex digits is kept literally.
/// - `+` decodes to a space (form-encoding convention).
fn url_decode(s: &str) -> String {
    /// Decodes two ASCII hex digits into the byte they denote.
    fn hex_pair(pair: &[u8]) -> Option<u8> {
        let high = char::from(pair[0]).to_digit(16)?;
        let low = char::from(pair[1]).to_digit(16)?;
        // Both digits are < 16, so the combined value always fits in a byte.
        Some((high * 16 + low) as u8)
    }

    // NOTE(review): `+` -> space is also applied to user, password and
    // database components, where RFC 3986 treats `+` literally. Kept for
    // compatibility with existing URLs.
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            // Decode to bytes rather than chars: a multi-byte UTF-8 character
            // spans several `%XX` escapes and must be reassembled as a whole.
            b'%' => match bytes.get(i + 1..i + 3).and_then(hex_pair) {
                Some(byte) => {
                    decoded.push(byte);
                    i += 3;
                }
                None => {
                    decoded.push(b'%');
                    i += 1;
                }
            },
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
412
/// Percent-encodes every byte of `s` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`).
///
/// Non-ASCII characters are encoded byte by byte from their UTF-8 form, with
/// upper-case hex digits (`é` -> `%C3%A9`).
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len() * 3);
    // Working on bytes avoids allocating a String per char and per escape.
    // Unreserved characters are all ASCII, so every byte of a multi-byte
    // character is escaped.
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
427
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_postgres_full() {
        let conn = ConnectionString::parse("postgres://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), Some("pass"));
        assert_eq!(conn.host(), Some("localhost"));
        assert_eq!(conn.port(), Some(5432));
        assert_eq!(conn.database(), Some("mydb"));
    }

    #[test]
    fn test_parse_postgres_with_params() {
        let conn = ConnectionString::parse(
            "postgres://user:pass@localhost/mydb?sslmode=require&connect_timeout=10",
        )
        .unwrap();
        assert_eq!(conn.param("sslmode"), Some("require"));
        assert_eq!(conn.param("connect_timeout"), Some("10"));
    }

    #[test]
    fn test_parse_postgres_no_password() {
        let conn = ConnectionString::parse("postgres://user@localhost/mydb").unwrap();
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), None);
    }

    #[test]
    fn test_parse_password_with_unescaped_at_sign() {
        // The last `@` separates credentials from the host.
        let conn = ConnectionString::parse("postgres://user:p@ss@localhost/db").unwrap();
        assert_eq!(conn.password(), Some("p@ss"));
        assert_eq!(conn.host(), Some("localhost"));
    }

    #[test]
    fn test_parse_mysql() {
        let conn = ConnectionString::parse("mysql://root:secret@127.0.0.1:3306/testdb").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
        assert_eq!(conn.host(), Some("127.0.0.1"));
        assert_eq!(conn.port(), Some(3306));
    }

    #[test]
    fn test_parse_mariadb() {
        let conn = ConnectionString::parse("mariadb://user:pass@localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
    }

    #[test]
    fn test_scheme_is_case_insensitive() {
        assert!(matches!(Driver::from_scheme("PostgreSQL"), Ok(Driver::Postgres)));
        assert!(matches!(Driver::from_scheme("MySQL"), Ok(Driver::MySql)));
        assert_eq!(Driver::Postgres.to_string(), "postgres");
    }

    #[test]
    fn test_parse_ipv6_host() {
        let conn = ConnectionString::parse("postgres://user@[::1]:5433/db").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), Some(5433));

        let conn = ConnectionString::parse("postgres://[::1]/db").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), None);
    }

    #[test]
    fn test_parse_sqlite_file() {
        let conn = ConnectionString::parse("sqlite://./data/app.db").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert_eq!(conn.database(), Some("./data/app.db"));
    }

    #[test]
    fn test_parse_sqlite_memory() {
        let conn = ConnectionString::parse("sqlite::memory:").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert!(conn.is_memory());

        let conn = ConnectionString::parse("sqlite://:memory:").unwrap();
        assert!(conn.is_memory());
    }

    #[test]
    fn test_parse_special_characters() {
        let conn = ConnectionString::parse("postgres://user:p%40ss%3Aword@localhost/db").unwrap();
        assert_eq!(conn.password(), Some("p@ss:word"));
    }

    #[test]
    fn test_url_encode_decode_roundtrip_ascii() {
        let raw = "p@ss:w/rd?&=";
        assert_eq!(url_decode(&url_encode(raw)), raw);
    }

    #[test]
    fn test_default_port() {
        assert_eq!(Driver::Postgres.default_port(), Some(5432));
        assert_eq!(Driver::MySql.default_port(), Some(3306));
        assert_eq!(Driver::Sqlite.default_port(), None);
    }

    #[test]
    fn test_port_or_default() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        assert_eq!(conn.port(), None);
        assert_eq!(conn.port_or_default(), Some(5432));
    }

    #[test]
    fn test_with_param() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        let conn = conn.with_param("sslmode", "require");
        assert_eq!(conn.param("sslmode"), Some("require"));
    }

    #[test]
    fn test_without_param() {
        let conn = ConnectionString::parse(
            "postgres://localhost/db?sslmode=require&application_name=app",
        )
        .unwrap();
        let conn = conn.without_param("sslmode");
        assert_eq!(conn.param("sslmode"), None);
        assert_eq!(conn.param("application_name"), Some("app"));
    }

    #[test]
    fn test_to_url_roundtrip() {
        let original = "postgres://user:pass@localhost:5432/mydb?sslmode=require";
        let conn = ConnectionString::parse(original).unwrap();
        let rebuilt = conn.parsed().to_url();
        assert_eq!(rebuilt, original);

        // Re-parsing the serialised form must yield the same components.
        let reparsed = ConnectionString::parse(&rebuilt).unwrap();
        assert_eq!(reparsed.driver(), conn.driver());
        assert_eq!(reparsed.user(), conn.user());
        assert_eq!(reparsed.password(), conn.password());
        assert_eq!(reparsed.host(), conn.host());
        assert_eq!(reparsed.port(), conn.port());
        assert_eq!(reparsed.database(), conn.database());
        assert_eq!(reparsed.params(), conn.params());
    }

    #[test]
    fn test_invalid_url() {
        assert!(ConnectionString::parse("not-a-url").is_err());
        assert!(ConnectionString::parse("unknown://localhost").is_err());
    }

    #[test]
    fn test_invalid_port_and_host() {
        assert!(ConnectionString::parse("postgres://localhost:notaport/db").is_err());
        assert!(ConnectionString::parse("postgres://localhost:70000/db").is_err());
        assert!(ConnectionString::parse("postgres://[::1/db").is_err());
    }
}
533
534