1use super::{ConnectionError, ConnectionResult};
4use std::collections::HashMap;
5use tracing::debug;
6
/// Database backend selected by the scheme of a connection URL.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Driver {
    /// PostgreSQL (`postgres://`, `postgresql://`).
    Postgres,
    /// MySQL or MariaDB (`mysql://`, `mariadb://`).
    MySql,
    /// SQLite (`sqlite://`, `sqlite3://`, `file://`).
    Sqlite,
}
17
18impl Driver {
19 pub fn default_port(&self) -> Option<u16> {
21 match self {
22 Self::Postgres => Some(5432),
23 Self::MySql => Some(3306),
24 Self::Sqlite => None,
25 }
26 }
27
28 pub fn name(&self) -> &'static str {
30 match self {
31 Self::Postgres => "postgres",
32 Self::MySql => "mysql",
33 Self::Sqlite => "sqlite",
34 }
35 }
36
37 pub fn from_scheme(scheme: &str) -> ConnectionResult<Self> {
39 match scheme.to_lowercase().as_str() {
40 "postgres" | "postgresql" => Ok(Self::Postgres),
41 "mysql" | "mariadb" => Ok(Self::MySql),
42 "sqlite" | "sqlite3" | "file" => Ok(Self::Sqlite),
43 other => Err(ConnectionError::UnknownDriver(other.to_string())),
44 }
45 }
46}
47
48impl std::fmt::Display for Driver {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}", self.name())
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct ParsedUrl {
57 pub driver: Driver,
59 pub user: Option<String>,
61 pub password: Option<String>,
63 pub host: Option<String>,
65 pub port: Option<u16>,
67 pub database: Option<String>,
69 pub params: HashMap<String, String>,
71}
72
73impl ParsedUrl {
74 pub fn is_memory(&self) -> bool {
76 self.driver == Driver::Sqlite
77 && self
78 .database
79 .as_ref()
80 .is_some_and(|d| d == ":memory:" || d.is_empty())
81 }
82
83 pub fn param(&self, key: &str) -> Option<&str> {
85 self.params.get(key).map(|s| s.as_str())
86 }
87
88 pub fn to_url(&self) -> String {
90 let mut url = format!("{}://", self.driver.name());
91
92 if let Some(ref user) = self.user {
94 url.push_str(&url_encode(user));
95 if let Some(ref pass) = self.password {
96 url.push(':');
97 url.push_str(&url_encode(pass));
98 }
99 url.push('@');
100 }
101
102 if let Some(ref host) = self.host {
104 url.push_str(host);
105 if let Some(port) = self.port {
106 url.push(':');
107 url.push_str(&port.to_string());
108 }
109 }
110
111 if let Some(ref db) = self.database {
113 url.push('/');
114 url.push_str(db);
115 }
116
117 if !self.params.is_empty() {
119 url.push('?');
120 let params: Vec<_> = self
121 .params
122 .iter()
123 .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
124 .collect();
125 url.push_str(¶ms.join("&"));
126 }
127
128 url
129 }
130}
131
132#[derive(Debug, Clone)]
134pub struct ConnectionString {
135 parsed: ParsedUrl,
136 original: String,
137}
138
139impl ConnectionString {
140 pub fn parse(url: &str) -> ConnectionResult<Self> {
158 debug!(url_len = url.len(), "ConnectionString::parse()");
159 let original = url.to_string();
160 let parsed = parse_url(url)?;
161 debug!(driver = %parsed.driver, host = ?parsed.host, database = ?parsed.database, "Connection parsed");
162 Ok(Self { parsed, original })
163 }
164
165 pub fn from_env(var: &str) -> ConnectionResult<Self> {
167 let url = std::env::var(var).map_err(|_| ConnectionError::EnvNotFound(var.to_string()))?;
168 Self::parse(&url)
169 }
170
171 pub fn from_database_url() -> ConnectionResult<Self> {
173 Self::from_env("DATABASE_URL")
174 }
175
176 pub fn as_str(&self) -> &str {
178 &self.original
179 }
180
181 pub fn driver(&self) -> Driver {
183 self.parsed.driver
184 }
185
186 pub fn user(&self) -> Option<&str> {
188 self.parsed.user.as_deref()
189 }
190
191 pub fn password(&self) -> Option<&str> {
193 self.parsed.password.as_deref()
194 }
195
196 pub fn host(&self) -> Option<&str> {
198 self.parsed.host.as_deref()
199 }
200
201 pub fn port(&self) -> Option<u16> {
203 self.parsed.port
204 }
205
206 pub fn port_or_default(&self) -> Option<u16> {
208 self.parsed
209 .port
210 .or_else(|| self.parsed.driver.default_port())
211 }
212
213 pub fn database(&self) -> Option<&str> {
215 self.parsed.database.as_deref()
216 }
217
218 pub fn param(&self, key: &str) -> Option<&str> {
220 self.parsed.param(key)
221 }
222
223 pub fn params(&self) -> &HashMap<String, String> {
225 &self.parsed.params
226 }
227
228 pub fn parsed(&self) -> &ParsedUrl {
230 &self.parsed
231 }
232
233 pub fn is_memory(&self) -> bool {
235 self.parsed.is_memory()
236 }
237
238 pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
240 self.parsed.params.insert(key.into(), value.into());
241 self.original = self.parsed.to_url();
242 self
243 }
244
245 pub fn without_param(mut self, key: &str) -> Self {
247 self.parsed.params.remove(key);
248 self.original = self.parsed.to_url();
249 self
250 }
251}
252
253fn parse_url(url: &str) -> ConnectionResult<ParsedUrl> {
255 if url == "sqlite::memory:" || url == ":memory:" {
257 return Ok(ParsedUrl {
258 driver: Driver::Sqlite,
259 user: None,
260 password: None,
261 host: None,
262 port: None,
263 database: Some(":memory:".to_string()),
264 params: HashMap::new(),
265 });
266 }
267
268 let (scheme, rest) = url.split_once("://").ok_or_else(|| {
270 ConnectionError::InvalidUrl("Missing scheme (e.g., postgres://)".to_string())
271 })?;
272
273 let driver = Driver::from_scheme(scheme)?;
274
275 if driver == Driver::Sqlite {
277 return parse_sqlite_url(rest);
278 }
279
280 parse_network_url(driver, rest)
282}
283
284fn parse_sqlite_url(rest: &str) -> ConnectionResult<ParsedUrl> {
285 let (path, params) = parse_query_params(rest);
287
288 let database = if path.is_empty() || path == ":memory:" {
289 Some(":memory:".to_string())
290 } else {
291 Some(url_decode(&path))
292 };
293
294 Ok(ParsedUrl {
295 driver: Driver::Sqlite,
296 user: None,
297 password: None,
298 host: None,
299 port: None,
300 database,
301 params,
302 })
303}
304
305fn parse_network_url(driver: Driver, rest: &str) -> ConnectionResult<ParsedUrl> {
306 let (main, params) = parse_query_params(rest);
308
309 let (creds, host_part) = if let Some(at_pos) = main.rfind('@') {
311 (Some(&main[..at_pos]), &main[at_pos + 1..])
312 } else {
313 (None, main.as_str())
314 };
315
316 let (user, password) = if let Some(creds) = creds {
318 if let Some((u, p)) = creds.split_once(':') {
319 (Some(url_decode(u)), Some(url_decode(p)))
320 } else {
321 (Some(url_decode(creds)), None)
322 }
323 } else {
324 (None, None)
325 };
326
327 let (host_port, database) = if let Some(slash_pos) = host_part.find('/') {
329 (
330 &host_part[..slash_pos],
331 Some(url_decode(&host_part[slash_pos + 1..])),
332 )
333 } else {
334 (host_part, None)
335 };
336
337 let (host, port) = if host_port.is_empty() {
339 (None, None)
340 } else if let Some(colon_pos) = host_port.rfind(':') {
341 if host_port.starts_with('[') {
343 if let Some(bracket_pos) = host_port.find(']') {
344 if colon_pos > bracket_pos {
345 let port = host_port[colon_pos + 1..].parse().map_err(|_| {
347 ConnectionError::InvalidUrl("Invalid port number".to_string())
348 })?;
349 (Some(host_port[..colon_pos].to_string()), Some(port))
350 } else {
351 (Some(host_port.to_string()), None)
353 }
354 } else {
355 return Err(ConnectionError::InvalidUrl(
356 "Invalid IPv6 address".to_string(),
357 ));
358 }
359 } else {
360 let port = host_port[colon_pos + 1..]
362 .parse()
363 .map_err(|_| ConnectionError::InvalidUrl("Invalid port number".to_string()))?;
364 (Some(host_port[..colon_pos].to_string()), Some(port))
365 }
366 } else {
367 (Some(host_port.to_string()), None)
368 };
369
370 Ok(ParsedUrl {
371 driver,
372 user,
373 password,
374 host,
375 port,
376 database,
377 params,
378 })
379}
380
381fn parse_query_params(input: &str) -> (String, HashMap<String, String>) {
382 if let Some((main, query)) = input.split_once('?') {
383 let params = query
384 .split('&')
385 .filter_map(|pair| {
386 let (key, value) = pair.split_once('=')?;
387 Some((url_decode(key), url_decode(value)))
388 })
389 .collect();
390 (main.to_string(), params)
391 } else {
392 (input.to_string(), HashMap::new())
393 }
394}
395
/// Percent-decodes `s`, also mapping `+` to a space.
///
/// Each `%XX` escape (exactly two hex digits) becomes one raw byte. The collected
/// bytes are interpreted as UTF-8, so multi-byte escapes such as `%C3%A9` decode to
/// `é`; invalid UTF-8 is replaced with U+FFFD. A `%` not followed by two hex digits
/// is kept literally.
///
/// NOTE(review): RFC 3986 gives `+` no special meaning outside form-encoded query
/// strings, so a literal `+` in a user name or password is decoded to a space here.
/// Confirm this is intended before relying on it for credentials.
fn url_decode(s: &str) -> String {
    // Value of an ASCII hex digit. Unlike `u8::from_str_radix`, this rejects sign
    // characters and never accepts a lone trailing digit.
    fn hex_value(b: u8) -> Option<u8> {
        // `to_digit(16)` is < 16, so the cast cannot truncate.
        (b as char).to_digit(16).map(|d| d as u8)
    }

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_value);
                let lo = bytes.get(i + 2).copied().and_then(hex_value);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    out.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&out).into_owned()
}
419
/// Percent-encodes `s`, leaving only RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - _ . ~`) literal.
///
/// Every other character is written as its UTF-8 bytes, each as `%XX` with
/// uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut result = String::with_capacity(s.len() * 3);
    // Iterating bytes is equivalent to iterating chars here: unreserved characters
    // are single-byte ASCII, and every byte of a multi-byte char is >= 0x80 and
    // therefore escaped. This avoids a temporary String per character.
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            result.push(byte as char);
        } else {
            result.push('%');
            result.push(HEX[usize::from(byte >> 4)] as char);
            result.push(HEX[usize::from(byte & 0x0F)] as char);
        }
    }
    result
}
434
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_postgres_full() {
        let conn = ConnectionString::parse("postgres://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), Some("pass"));
        assert_eq!(conn.host(), Some("localhost"));
        assert_eq!(conn.port(), Some(5432));
        assert_eq!(conn.database(), Some("mydb"));
    }

    #[test]
    fn test_parse_postgres_with_params() {
        let conn = ConnectionString::parse(
            "postgres://user:pass@localhost/mydb?sslmode=require&connect_timeout=10",
        )
        .unwrap();
        assert_eq!(conn.param("sslmode"), Some("require"));
        assert_eq!(conn.param("connect_timeout"), Some("10"));
    }

    #[test]
    fn test_parse_postgres_no_password() {
        let conn = ConnectionString::parse("postgres://user@localhost/mydb").unwrap();
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), None);
    }

    #[test]
    fn test_parse_mysql() {
        let conn = ConnectionString::parse("mysql://root:secret@127.0.0.1:3306/testdb").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
        assert_eq!(conn.host(), Some("127.0.0.1"));
        assert_eq!(conn.port(), Some(3306));
    }

    #[test]
    fn test_parse_mariadb() {
        let conn = ConnectionString::parse("mariadb://user:pass@localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::MySql);
    }

    #[test]
    fn test_scheme_is_case_insensitive() {
        let conn = ConnectionString::parse("PostgreSQL://localhost/db").unwrap();
        assert_eq!(conn.driver(), Driver::Postgres);
    }

    #[test]
    fn test_parse_sqlite_file() {
        let conn = ConnectionString::parse("sqlite://./data/app.db").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert_eq!(conn.database(), Some("./data/app.db"));
    }

    #[test]
    fn test_parse_sqlite_with_params() {
        let conn = ConnectionString::parse("sqlite://./app.db?mode=ro").unwrap();
        assert_eq!(conn.database(), Some("./app.db"));
        assert_eq!(conn.param("mode"), Some("ro"));
        assert!(!conn.is_memory());
    }

    #[test]
    fn test_parse_sqlite_memory() {
        let conn = ConnectionString::parse("sqlite::memory:").unwrap();
        assert_eq!(conn.driver(), Driver::Sqlite);
        assert!(conn.is_memory());

        let conn = ConnectionString::parse("sqlite://:memory:").unwrap();
        assert!(conn.is_memory());

        let conn = ConnectionString::parse(":memory:").unwrap();
        assert!(conn.is_memory());
    }

    #[test]
    fn test_parse_special_characters() {
        let conn = ConnectionString::parse("postgres://user:p%40ss%3Aword@localhost/db").unwrap();
        assert_eq!(conn.password(), Some("p@ss:word"));
    }

    #[test]
    fn test_parse_unencoded_at_in_password() {
        let conn = ConnectionString::parse("postgres://user:p@ss@localhost/db").unwrap();
        assert_eq!(conn.user(), Some("user"));
        assert_eq!(conn.password(), Some("p@ss"));
        assert_eq!(conn.host(), Some("localhost"));
    }

    #[test]
    fn test_parse_ipv6_host() {
        let conn = ConnectionString::parse("postgres://user@[::1]:5433/db").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), Some(5433));
        assert_eq!(conn.database(), Some("db"));

        let conn = ConnectionString::parse("postgres://[::1]/db").unwrap();
        assert_eq!(conn.host(), Some("[::1]"));
        assert_eq!(conn.port(), None);
    }

    #[test]
    fn test_invalid_port_and_host() {
        assert!(ConnectionString::parse("postgres://localhost:abc/db").is_err());
        assert!(ConnectionString::parse("postgres://localhost:70000/db").is_err());
        assert!(ConnectionString::parse("postgres://[::1/db").is_err());
    }

    #[test]
    fn test_default_port() {
        assert_eq!(Driver::Postgres.default_port(), Some(5432));
        assert_eq!(Driver::MySql.default_port(), Some(3306));
        assert_eq!(Driver::Sqlite.default_port(), None);
    }

    #[test]
    fn test_driver_display() {
        assert_eq!(Driver::Postgres.to_string(), "postgres");
        assert_eq!(Driver::MySql.to_string(), "mysql");
        assert_eq!(Driver::Sqlite.to_string(), "sqlite");
    }

    #[test]
    fn test_port_or_default() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        assert_eq!(conn.port(), None);
        assert_eq!(conn.port_or_default(), Some(5432));
    }

    #[test]
    fn test_with_param() {
        let conn = ConnectionString::parse("postgres://localhost/db").unwrap();
        let conn = conn.with_param("sslmode", "require");
        assert_eq!(conn.param("sslmode"), Some("require"));
        assert!(conn.as_str().contains("sslmode=require"));
    }

    #[test]
    fn test_without_param() {
        let conn = ConnectionString::parse("postgres://localhost/db?sslmode=require").unwrap();
        let conn = conn.without_param("sslmode");
        assert_eq!(conn.param("sslmode"), None);
        assert!(conn.params().is_empty());
    }

    #[test]
    fn test_to_url_roundtrip() {
        let original = "postgres://user:pass@localhost:5432/mydb?sslmode=require";
        let conn = ConnectionString::parse(original).unwrap();
        let rebuilt = conn.parsed().to_url();
        assert!(rebuilt.contains("postgres://"));
        assert!(rebuilt.contains("localhost:5432"));
        assert!(rebuilt.contains("sslmode=require"));
    }

    #[test]
    fn test_invalid_url() {
        assert!(ConnectionString::parse("not-a-url").is_err());
        assert!(ConnectionString::parse("unknown://localhost").is_err());
    }
}