1use percent_encoding::{AsciiSet, NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
11use crate::settings::secret_types::SecretString;
12
13const USERINFO_ENCODE_SET: &AsciiSet = &NON_ALPHANUMERIC
17 .remove(b'-')
18 .remove(b'.')
19 .remove(b'_')
20 .remove(b'~');
21
22#[non_exhaustive]
24#[derive(Clone, Serialize, Deserialize)]
25pub struct DatabaseConfig {
26 pub engine: String,
28
29 pub name: String,
31
32 pub user: Option<String>,
34
35 pub password: Option<SecretString>,
37
38 pub host: Option<String>,
40
41 pub port: Option<u16>,
43
44 #[serde(default)]
46 pub options: HashMap<String, String>,
47}
48
49impl fmt::Debug for DatabaseConfig {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 f.debug_struct("DatabaseConfig")
52 .field("engine", &self.engine)
53 .field("name", &self.name)
54 .field("user", &self.user)
55 .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
56 .field("host", &self.host)
57 .field("port", &self.port)
58 .field("options", &self.options)
59 .finish()
60 }
61}
62
63impl DatabaseConfig {
64 pub fn new(engine: impl Into<String>, name: impl Into<String>) -> Self {
76 Self {
77 engine: engine.into(),
78 name: name.into(),
79 user: None,
80 password: None,
81 host: None,
82 port: None,
83 options: HashMap::new(),
84 }
85 }
86
87 pub fn with_user(mut self, user: impl Into<String>) -> Self {
89 self.user = Some(user.into());
90 self
91 }
92
93 pub fn with_password(mut self, password: impl Into<String>) -> Self {
95 self.password = Some(SecretString::new(password.into()));
96 self
97 }
98
99 pub fn with_host(mut self, host: impl Into<String>) -> Self {
101 self.host = Some(host.into());
102 self
103 }
104
105 pub fn with_port(mut self, port: u16) -> Self {
107 self.port = Some(port);
108 self
109 }
110
111 pub fn sqlite(name: impl Into<String>) -> Self {
126 Self {
127 engine: "reinhardt.db.backends.sqlite3".to_string(),
128 name: name.into(),
129 user: None,
130 password: None,
131 host: None,
132 port: None,
133 options: HashMap::new(),
134 }
135 }
136 pub fn postgresql(
153 name: impl Into<String>,
154 user: impl Into<String>,
155 password: impl Into<String>,
156 host: impl Into<String>,
157 port: u16,
158 ) -> Self {
159 Self {
160 engine: "reinhardt.db.backends.postgresql".to_string(),
161 name: name.into(),
162 user: Some(user.into()),
163 password: Some(SecretString::new(password.into())),
164 host: Some(host.into()),
165 port: Some(port),
166 options: HashMap::new(),
167 }
168 }
169 pub fn mysql(
186 name: impl Into<String>,
187 user: impl Into<String>,
188 password: impl Into<String>,
189 host: impl Into<String>,
190 port: u16,
191 ) -> Self {
192 Self {
193 engine: "reinhardt.db.backends.mysql".to_string(),
194 name: name.into(),
195 user: Some(user.into()),
196 password: Some(SecretString::new(password.into())),
197 host: Some(host.into()),
198 port: Some(port),
199 options: HashMap::new(),
200 }
201 }
202
203 pub fn to_url(&self) -> String {
220 let scheme = if self.engine == "sqlite" || self.engine.contains("sqlite") {
223 "sqlite"
224 } else if self.engine == "postgresql"
225 || self.engine == "postgres"
226 || self.engine.contains("postgresql")
227 || self.engine.contains("postgres")
228 {
229 "postgresql"
230 } else if self.engine == "mysql" || self.engine.contains("mysql") {
231 "mysql"
232 } else {
233 "sqlite"
235 };
236
237 match scheme {
238 "sqlite" => {
239 if self.name == ":memory:" {
240 "sqlite::memory:".to_string()
241 } else {
242 use std::path::Path;
245 let path = Path::new(&self.name);
246 if path.is_absolute() {
247 format!("sqlite:///{}", self.name)
249 } else {
250 format!("sqlite:{}", self.name)
252 }
253 }
254 }
255 "postgresql" | "mysql" => {
256 let mut url = format!("{}://", scheme);
257
258 if let Some(user) = &self.user {
260 let encoded_user = utf8_percent_encode(user, USERINFO_ENCODE_SET).to_string();
261 url.push_str(&encoded_user);
262 if let Some(password) = &self.password {
263 url.push(':');
264 let encoded_password =
265 utf8_percent_encode(password.expose_secret(), USERINFO_ENCODE_SET)
266 .to_string();
267 url.push_str(&encoded_password);
268 }
269 url.push('@');
270 }
271
272 let host = self.host.as_deref().unwrap_or("localhost");
274 url.push_str(host);
275
276 if let Some(port) = self.port {
278 url.push(':');
279 url.push_str(&port.to_string());
280 }
281
282 url.push('/');
284 url.push_str(&self.name);
285
286 if !self.options.is_empty() {
288 let mut query_parts = Vec::new();
289 for (key, value) in &self.options {
290 let encoded_key = utf8_percent_encode(key, USERINFO_ENCODE_SET).to_string();
291 let encoded_value =
292 utf8_percent_encode(value, USERINFO_ENCODE_SET).to_string();
293 query_parts.push(format!("{}={}", encoded_key, encoded_value));
294 }
295 url.push('?');
296 url.push_str(&query_parts.join("&"));
297 }
298
299 url
300 }
301 _ => format!("sqlite://{}", self.name),
302 }
303 }
304}
305
306impl fmt::Display for DatabaseConfig {
307 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308 let scheme = if self.engine.contains("sqlite") {
310 "sqlite"
311 } else if self.engine.contains("postgresql") || self.engine.contains("postgres") {
312 "postgresql"
313 } else if self.engine.contains("mysql") {
314 "mysql"
315 } else {
316 "unknown"
317 };
318
319 match scheme {
320 "sqlite" => write!(f, "sqlite:{}", self.name),
321 _ => {
322 write!(f, "{}://", scheme)?;
323 if self.user.is_some() || self.password.is_some() {
324 write!(f, "***@")?;
325 }
326 if let Some(host) = &self.host {
327 write!(f, "{}", host)?;
328 }
329 if let Some(port) = self.port {
330 write!(f, ":{}", port)?;
331 }
332 write!(f, "/{}", self.name)
333 }
334 }
335 }
336}
337
338impl Default for DatabaseConfig {
339 fn default() -> Self {
340 Self::sqlite("db.sqlite3".to_string())
341 }
342}
343
/// URL prefixes accepted by [`validate_database_url_scheme`].
///
/// Matching is a plain, case-sensitive prefix test. Note that `sqlite:` also covers
/// `sqlite://` URLs.
pub const VALID_DATABASE_SCHEMES: &[&str] = &[
    "postgres://",
    "postgresql://",
    "sqlite://",
    "sqlite:",
    "mysql://",
    "mariadb://",
];

/// Checks that `url` starts with one of [`VALID_DATABASE_SCHEMES`].
///
/// # Errors
///
/// Returns a message listing every accepted prefix when no prefix matches.
pub fn validate_database_url_scheme(url: &str) -> Result<(), String> {
    let matched = VALID_DATABASE_SCHEMES
        .iter()
        .copied()
        .find(|&prefix| url.starts_with(prefix));

    match matched {
        Some(_) => Ok(()),
        None => {
            let expected = VALID_DATABASE_SCHEMES.join(", ");
            Err(format!(
                "Invalid database URL: unrecognized scheme. Expected one of: {}",
                expected
            ))
        }
    }
}
368
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// PostgreSQL fixture on `localhost:5432` used by most URL tests.
    fn local_pg(name: &str, user: &str, password: &str) -> DatabaseConfig {
        DatabaseConfig::postgresql(name, user, password, "localhost", 5432)
    }

    #[rstest]
    fn test_settings_db_config_sqlite() {
        let config = DatabaseConfig::sqlite("test.db");

        assert_eq!(config.engine, "reinhardt.db.backends.sqlite3");
        assert_eq!(config.name, "test.db");
        assert!(config.user.is_none(), "sqlite() must not set a user");
        assert!(config.password.is_none(), "sqlite() must not set a password");
    }

    #[rstest]
    fn test_settings_db_config_postgresql() {
        let config = local_pg("testdb", "user", "pass");

        assert_eq!(config.engine, "reinhardt.db.backends.postgresql");
        assert_eq!(config.name, "testdb");
        assert_eq!(config.user.as_deref(), Some("user"));
        assert_eq!(
            config.password.as_ref().map(|secret| secret.expose_secret()),
            Some("pass")
        );
        assert_eq!(config.port, Some(5432));
    }

    #[rstest]
    fn test_debug_output_redacts_password() {
        let rendered = format!("{:?}", local_pg("testdb", "user", "s3cr3t!"));

        assert!(!rendered.contains("s3cr3t!"), "password leaked: {rendered}");
        assert!(rendered.contains("[REDACTED]"));
    }

    #[rstest]
    fn test_debug_output_without_password() {
        let rendered = format!("{:?}", DatabaseConfig::sqlite("test.db"));

        assert!(rendered.contains("None"));
        assert!(rendered.contains("DatabaseConfig"));
    }

    #[rstest]
    fn test_to_url_encodes_special_chars_in_username() {
        let url = local_pg("mydb", "user@domain", "pass").to_url();

        assert!(url.contains("user%40domain"), "unexpected url: {url}");
        assert!(!url.contains("user@domain:"));
    }

    #[rstest]
    fn test_to_url_encodes_special_chars_in_password() {
        let url = local_pg("mydb", "user", "p@ss:w/rd#").to_url();

        assert!(url.contains("p%40ss%3Aw%2Frd%23"), "unexpected url: {url}");
        assert!(!url.contains("p@ss:w/rd#"));
    }

    #[rstest]
    fn test_to_url_prevents_host_injection() {
        let url = local_pg("mydb", "admin@evil.com:9999/fake", "pass").to_url();

        assert!(url.contains("admin%40evil.com%3A9999%2Ffake"), "unexpected url: {url}");
        assert!(url.contains("@localhost:5432"));
    }

    #[rstest]
    fn test_to_url_encodes_query_parameter_values() {
        let mut config = local_pg("mydb", "user", "pass");
        config
            .options
            .insert("sslmode".to_string(), "require&inject=true".to_string());

        let url = config.to_url();

        assert!(url.contains("require%26inject%3Dtrue"), "unexpected url: {url}");
        assert!(!url.contains("require&inject=true"));
    }

    #[rstest]
    fn test_to_url_simple_credentials() {
        assert_eq!(
            local_pg("mydb", "user", "pass").to_url(),
            "postgresql://user:pass@localhost:5432/mydb"
        );
    }

    #[rstest]
    fn test_display_output_masks_credentials() {
        let shown =
            DatabaseConfig::postgresql("mydb", "admin", "s3cr3t!", "db.example.com", 5432)
                .to_string();

        for hidden in ["admin", "s3cr3t!"] {
            assert!(!shown.contains(hidden), "{hidden:?} leaked into {shown:?}");
        }
        for visible in ["***@", "db.example.com", "mydb"] {
            assert!(shown.contains(visible), "{visible:?} missing from {shown:?}");
        }
    }

    #[rstest]
    fn test_display_output_sqlite() {
        assert_eq!(DatabaseConfig::sqlite("app.db").to_string(), "sqlite:app.db");
    }

    #[rstest]
    fn test_password_stored_as_secret_string() {
        let config = local_pg("mydb", "user", "my-secret-pw");
        let secret = config
            .password
            .as_ref()
            .expect("postgresql() always sets a password");

        assert_eq!(secret.expose_secret(), "my-secret-pw");
        assert_eq!(secret.to_string(), "[REDACTED]");
    }

    #[rstest]
    #[case("postgres://localhost/db")]
    #[case("postgresql://user:pass@localhost:5432/db")]
    #[case("sqlite::memory:")]
    #[case("sqlite:///path/to/db")]
    #[case("mysql://root@localhost/db")]
    #[case("mariadb://root@localhost/db")]
    fn test_valid_database_url_schemes(#[case] url: &str) {
        assert!(
            validate_database_url_scheme(url).is_ok(),
            "expected {url:?} to be accepted"
        );
    }

    #[rstest]
    #[case("http://localhost/db")]
    #[case("ftp://localhost/db")]
    #[case("redis://localhost")]
    #[case("")]
    #[case("not-a-url")]
    fn test_invalid_database_url_schemes(#[case] url: &str) {
        let message =
            validate_database_url_scheme(url).expect_err("scheme should have been rejected");

        assert!(message.contains("Invalid database URL"));
    }
}
569}