1use percent_encoding::{AsciiSet, NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
11use crate::settings::secret_types::SecretString;
12
/// Bytes that [`DatabaseConfig::to_url`] percent-encodes.
///
/// Starts from `NON_ALPHANUMERIC` and removes the RFC 3986 "unreserved"
/// punctuation (`-`, `.`, `_`, `~`). Every byte other than ASCII alphanumerics
/// and those four marks is therefore escaped. Despite the name, `to_url` uses
/// this set for the user name, the password, and the query-option keys and
/// values alike.
const USERINFO_ENCODE_SET: &AsciiSet = &NON_ALPHANUMERIC
    .remove(b'-')
    .remove(b'.')
    .remove(b'_')
    .remove(b'~');
21
22#[non_exhaustive]
24#[derive(Clone, Serialize, Deserialize)]
25pub struct DatabaseConfig {
26 pub engine: String,
28
29 pub name: String,
31
32 pub user: Option<String>,
34
35 pub password: Option<SecretString>,
37
38 pub host: Option<String>,
40
41 pub port: Option<u16>,
43
44 pub options: HashMap<String, String>,
46}
47
48impl fmt::Debug for DatabaseConfig {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("DatabaseConfig")
51 .field("engine", &self.engine)
52 .field("name", &self.name)
53 .field("user", &self.user)
54 .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
55 .field("host", &self.host)
56 .field("port", &self.port)
57 .field("options", &self.options)
58 .finish()
59 }
60}
61
62impl DatabaseConfig {
63 pub fn new(engine: impl Into<String>, name: impl Into<String>) -> Self {
75 Self {
76 engine: engine.into(),
77 name: name.into(),
78 user: None,
79 password: None,
80 host: None,
81 port: None,
82 options: HashMap::new(),
83 }
84 }
85
86 pub fn with_user(mut self, user: impl Into<String>) -> Self {
88 self.user = Some(user.into());
89 self
90 }
91
92 pub fn with_password(mut self, password: impl Into<String>) -> Self {
94 self.password = Some(SecretString::new(password.into()));
95 self
96 }
97
98 pub fn with_host(mut self, host: impl Into<String>) -> Self {
100 self.host = Some(host.into());
101 self
102 }
103
104 pub fn with_port(mut self, port: u16) -> Self {
106 self.port = Some(port);
107 self
108 }
109
110 pub fn sqlite(name: impl Into<String>) -> Self {
125 Self {
126 engine: "reinhardt.db.backends.sqlite3".to_string(),
127 name: name.into(),
128 user: None,
129 password: None,
130 host: None,
131 port: None,
132 options: HashMap::new(),
133 }
134 }
135 pub fn postgresql(
152 name: impl Into<String>,
153 user: impl Into<String>,
154 password: impl Into<String>,
155 host: impl Into<String>,
156 port: u16,
157 ) -> Self {
158 Self {
159 engine: "reinhardt.db.backends.postgresql".to_string(),
160 name: name.into(),
161 user: Some(user.into()),
162 password: Some(SecretString::new(password.into())),
163 host: Some(host.into()),
164 port: Some(port),
165 options: HashMap::new(),
166 }
167 }
168 pub fn mysql(
185 name: impl Into<String>,
186 user: impl Into<String>,
187 password: impl Into<String>,
188 host: impl Into<String>,
189 port: u16,
190 ) -> Self {
191 Self {
192 engine: "reinhardt.db.backends.mysql".to_string(),
193 name: name.into(),
194 user: Some(user.into()),
195 password: Some(SecretString::new(password.into())),
196 host: Some(host.into()),
197 port: Some(port),
198 options: HashMap::new(),
199 }
200 }
201
202 pub fn to_url(&self) -> String {
219 let scheme = if self.engine == "sqlite" || self.engine.contains("sqlite") {
222 "sqlite"
223 } else if self.engine == "postgresql"
224 || self.engine == "postgres"
225 || self.engine.contains("postgresql")
226 || self.engine.contains("postgres")
227 {
228 "postgresql"
229 } else if self.engine == "mysql" || self.engine.contains("mysql") {
230 "mysql"
231 } else {
232 "sqlite"
234 };
235
236 match scheme {
237 "sqlite" => {
238 if self.name == ":memory:" {
239 "sqlite::memory:".to_string()
240 } else {
241 use std::path::Path;
244 let path = Path::new(&self.name);
245 if path.is_absolute() {
246 format!("sqlite:///{}", self.name)
248 } else {
249 format!("sqlite:{}", self.name)
251 }
252 }
253 }
254 "postgresql" | "mysql" => {
255 let mut url = format!("{}://", scheme);
256
257 if let Some(user) = &self.user {
259 let encoded_user = utf8_percent_encode(user, USERINFO_ENCODE_SET).to_string();
260 url.push_str(&encoded_user);
261 if let Some(password) = &self.password {
262 url.push(':');
263 let encoded_password =
264 utf8_percent_encode(password.expose_secret(), USERINFO_ENCODE_SET)
265 .to_string();
266 url.push_str(&encoded_password);
267 }
268 url.push('@');
269 }
270
271 let host = self.host.as_deref().unwrap_or("localhost");
273 url.push_str(host);
274
275 if let Some(port) = self.port {
277 url.push(':');
278 url.push_str(&port.to_string());
279 }
280
281 url.push('/');
283 url.push_str(&self.name);
284
285 if !self.options.is_empty() {
287 let mut query_parts = Vec::new();
288 for (key, value) in &self.options {
289 let encoded_key = utf8_percent_encode(key, USERINFO_ENCODE_SET).to_string();
290 let encoded_value =
291 utf8_percent_encode(value, USERINFO_ENCODE_SET).to_string();
292 query_parts.push(format!("{}={}", encoded_key, encoded_value));
293 }
294 url.push('?');
295 url.push_str(&query_parts.join("&"));
296 }
297
298 url
299 }
300 _ => format!("sqlite://{}", self.name),
301 }
302 }
303}
304
305impl fmt::Display for DatabaseConfig {
306 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307 let scheme = if self.engine.contains("sqlite") {
309 "sqlite"
310 } else if self.engine.contains("postgresql") || self.engine.contains("postgres") {
311 "postgresql"
312 } else if self.engine.contains("mysql") {
313 "mysql"
314 } else {
315 "unknown"
316 };
317
318 match scheme {
319 "sqlite" => write!(f, "sqlite:{}", self.name),
320 _ => {
321 write!(f, "{}://", scheme)?;
322 if self.user.is_some() || self.password.is_some() {
323 write!(f, "***@")?;
324 }
325 if let Some(host) = &self.host {
326 write!(f, "{}", host)?;
327 }
328 if let Some(port) = self.port {
329 write!(f, ":{}", port)?;
330 }
331 write!(f, "/{}", self.name)
332 }
333 }
334 }
335}
336
337impl Default for DatabaseConfig {
338 fn default() -> Self {
339 Self::sqlite("db.sqlite3".to_string())
340 }
341}
342
/// URL prefixes accepted by [`validate_database_url_scheme`].
///
/// Matching is a plain `starts_with` check, so `"sqlite:"` also accepts
/// `"sqlite::memory:"` and anything beginning with `"sqlite://"`. The order
/// here is the order listed in the validation error message.
pub(crate) const VALID_DATABASE_SCHEMES: &[&str] = &[
    "postgres://",
    "postgresql://",
    "sqlite://",
    "sqlite:",
    "mysql://",
    "mariadb://",
];
352
353pub(crate) fn validate_database_url_scheme(url: &str) -> Result<(), String> {
358 if VALID_DATABASE_SCHEMES.iter().any(|s| url.starts_with(s)) {
359 Ok(())
360 } else {
361 Err(format!(
362 "Invalid database URL: unrecognized scheme. Expected one of: {}",
363 VALID_DATABASE_SCHEMES.join(", ")
364 ))
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_settings_db_config_sqlite() {
        let db = DatabaseConfig::sqlite("test.db");

        assert_eq!(db.engine, "reinhardt.db.backends.sqlite3");
        assert_eq!(db.name, "test.db");
        assert!(db.user.is_none());
        assert!(db.password.is_none());
    }

    #[rstest]
    fn test_settings_db_config_postgresql() {
        let db = DatabaseConfig::postgresql("testdb", "user", "pass", "localhost", 5432);

        assert_eq!(db.engine, "reinhardt.db.backends.postgresql");
        assert_eq!(db.name, "testdb");
        assert_eq!(db.user, Some("user".to_string()));
        assert_eq!(
            db.password.as_ref().map(|p| p.expose_secret()),
            Some("pass")
        );
        assert_eq!(db.port, Some(5432));
    }

    #[rstest]
    fn test_debug_output_redacts_password() {
        let db = DatabaseConfig::postgresql("testdb", "user", "s3cr3t!", "localhost", 5432);

        let debug_output = format!("{:?}", db);

        assert!(!debug_output.contains("s3cr3t!"));
        assert!(debug_output.contains("[REDACTED]"));
    }

    #[rstest]
    fn test_debug_output_without_password() {
        let db = DatabaseConfig::sqlite("test.db");

        let debug_output = format!("{:?}", db);

        assert!(debug_output.contains("None"));
        assert!(debug_output.contains("DatabaseConfig"));
    }

    #[rstest]
    fn test_to_url_encodes_special_chars_in_username() {
        // The constructor already sets the user; no reassignment is needed.
        let db = DatabaseConfig::postgresql("mydb", "user@domain", "pass", "localhost", 5432);

        let url = db.to_url();

        assert!(url.contains("user%40domain"));
        assert!(!url.contains("user@domain:"));
    }

    #[rstest]
    fn test_to_url_encodes_special_chars_in_password() {
        let db = DatabaseConfig::postgresql("mydb", "user", "p@ss:w/rd#", "localhost", 5432);

        let url = db.to_url();

        assert!(url.contains("p%40ss%3Aw%2Frd%23"));
        assert!(!url.contains("p@ss:w/rd#"));
    }

    #[rstest]
    fn test_to_url_prevents_host_injection() {
        let db = DatabaseConfig::postgresql(
            "mydb",
            "admin@evil.com:9999/fake",
            "pass",
            "localhost",
            5432,
        );

        let url = db.to_url();

        assert!(url.contains("admin%40evil.com%3A9999%2Ffake"));
        assert!(url.contains("@localhost:5432"));
    }

    #[rstest]
    fn test_to_url_encodes_query_parameter_values() {
        let mut db = DatabaseConfig::postgresql("mydb", "user", "pass", "localhost", 5432);
        db.options
            .insert("sslmode".to_string(), "require&inject=true".to_string());

        let url = db.to_url();

        assert!(url.contains("require%26inject%3Dtrue"));
        assert!(!url.contains("require&inject=true"));
    }

    #[rstest]
    fn test_to_url_simple_credentials() {
        let db = DatabaseConfig::postgresql("mydb", "user", "pass", "localhost", 5432);

        let url = db.to_url();

        assert_eq!(url, "postgresql://user:pass@localhost:5432/mydb");
    }

    #[rstest]
    fn test_to_url_single_option() {
        let mut db = DatabaseConfig::postgresql("mydb", "user", "pass", "localhost", 5432);
        db.options
            .insert("sslmode".to_string(), "require".to_string());

        let url = db.to_url();

        assert_eq!(
            url,
            "postgresql://user:pass@localhost:5432/mydb?sslmode=require"
        );
    }

    #[rstest]
    fn test_to_url_mysql() {
        let db = DatabaseConfig::mysql("shop", "root", "pw", "db.local", 3306);

        let url = db.to_url();

        assert_eq!(url, "mysql://root:pw@db.local:3306/shop");
    }

    #[rstest]
    fn test_to_url_defaults_host_to_localhost() {
        let db = DatabaseConfig::new("postgresql", "mydb");

        let url = db.to_url();

        assert_eq!(url, "postgresql://localhost/mydb");
    }

    #[rstest]
    fn test_to_url_sqlite_in_memory() {
        let db = DatabaseConfig::sqlite(":memory:");

        assert_eq!(db.to_url(), "sqlite::memory:");
    }

    #[rstest]
    fn test_to_url_sqlite_relative_path() {
        let db = DatabaseConfig::sqlite("app.db");

        assert_eq!(db.to_url(), "sqlite:app.db");
    }

    #[rstest]
    fn test_to_url_unknown_engine_falls_back_to_sqlite() {
        let db = DatabaseConfig::new("oracle", "legacy.db");

        assert_eq!(db.to_url(), "sqlite:legacy.db");
    }

    #[rstest]
    fn test_default_is_sqlite() {
        let db = DatabaseConfig::default();

        assert_eq!(db.engine, "reinhardt.db.backends.sqlite3");
        assert_eq!(db.name, "db.sqlite3");
        assert!(db.options.is_empty());
    }

    #[rstest]
    fn test_display_output_masks_credentials() {
        let db = DatabaseConfig::postgresql("mydb", "admin", "s3cr3t!", "db.example.com", 5432);

        let display_output = format!("{}", db);

        assert!(!display_output.contains("admin"));
        assert!(!display_output.contains("s3cr3t!"));
        assert!(display_output.contains("***@"));
        assert!(display_output.contains("db.example.com"));
        assert!(display_output.contains("mydb"));
    }

    #[rstest]
    fn test_display_output_sqlite() {
        let db = DatabaseConfig::sqlite("app.db");

        let display_output = format!("{}", db);

        assert_eq!(display_output, "sqlite:app.db");
    }

    #[rstest]
    fn test_password_stored_as_secret_string() {
        let db = DatabaseConfig::postgresql("mydb", "user", "my-secret-pw", "localhost", 5432);

        let password = db.password.as_ref().unwrap();

        assert_eq!(password.expose_secret(), "my-secret-pw");
        assert_eq!(format!("{}", password), "[REDACTED]");
    }

    #[rstest]
    #[case("postgres://localhost/db")]
    #[case("postgresql://user:pass@localhost:5432/db")]
    #[case("sqlite::memory:")]
    #[case("sqlite:///path/to/db")]
    #[case("mysql://root@localhost/db")]
    #[case("mariadb://root@localhost/db")]
    fn test_valid_database_url_schemes(#[case] url: &str) {
        assert!(validate_database_url_scheme(url).is_ok());
    }

    #[rstest]
    #[case("http://localhost/db")]
    #[case("ftp://localhost/db")]
    #[case("redis://localhost")]
    #[case("")]
    #[case("not-a-url")]
    fn test_invalid_database_url_schemes(#[case] url: &str) {
        let result = validate_database_url_scheme(url);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid database URL"));
    }
}