1use percent_encoding::{AsciiSet, NON_ALPHANUMERIC, utf8_percent_encode};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
11use crate::settings::secret_types::SecretString;
12
13const USERINFO_ENCODE_SET: &AsciiSet = &NON_ALPHANUMERIC
17 .remove(b'-')
18 .remove(b'.')
19 .remove(b'_')
20 .remove(b'~');
21
/// Connection settings for a single database.
///
/// `to_url` turns these settings into a connection URL. The hand-written
/// `Debug` and `Display` impls for this type never print the password.
#[derive(Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Backend identifier, e.g. `"reinhardt.db.backends.postgresql"`.
    /// `to_url` selects the URL scheme by substring match on this value.
    pub engine: String,

    /// Database name. For SQLite this is the file path, or `":memory:"`.
    pub name: String,

    /// User name placed in the URL's userinfo; `None` for `sqlite` configs.
    pub user: Option<String>,

    /// Password, held as a `SecretString` rather than a plain `String`.
    /// NOTE(review): whether serde writes the plaintext depends on
    /// `SecretString`'s own Serialize impl, defined elsewhere — confirm.
    pub password: Option<SecretString>,

    /// Server host; `to_url` substitutes `"localhost"` when this is `None`.
    pub host: Option<String>,

    /// Server port; left out of the URL when `None`.
    pub port: Option<u16>,

    /// Extra driver options. `to_url` renders them as URL query parameters
    /// for PostgreSQL/MySQL; SQLite URLs do not include them.
    pub options: HashMap<String, String>,
}
46
47impl fmt::Debug for DatabaseConfig {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 f.debug_struct("DatabaseConfig")
50 .field("engine", &self.engine)
51 .field("name", &self.name)
52 .field("user", &self.user)
53 .field("password", &self.password.as_ref().map(|_| "[REDACTED]"))
54 .field("host", &self.host)
55 .field("port", &self.port)
56 .field("options", &self.options)
57 .finish()
58 }
59}
60
61impl DatabaseConfig {
62 pub fn sqlite(name: impl Into<String>) -> Self {
77 Self {
78 engine: "reinhardt.db.backends.sqlite3".to_string(),
79 name: name.into(),
80 user: None,
81 password: None,
82 host: None,
83 port: None,
84 options: HashMap::new(),
85 }
86 }
87 pub fn postgresql(
104 name: impl Into<String>,
105 user: impl Into<String>,
106 password: impl Into<String>,
107 host: impl Into<String>,
108 port: u16,
109 ) -> Self {
110 Self {
111 engine: "reinhardt.db.backends.postgresql".to_string(),
112 name: name.into(),
113 user: Some(user.into()),
114 password: Some(SecretString::new(password.into())),
115 host: Some(host.into()),
116 port: Some(port),
117 options: HashMap::new(),
118 }
119 }
120 pub fn mysql(
137 name: impl Into<String>,
138 user: impl Into<String>,
139 password: impl Into<String>,
140 host: impl Into<String>,
141 port: u16,
142 ) -> Self {
143 Self {
144 engine: "reinhardt.db.backends.mysql".to_string(),
145 name: name.into(),
146 user: Some(user.into()),
147 password: Some(SecretString::new(password.into())),
148 host: Some(host.into()),
149 port: Some(port),
150 options: HashMap::new(),
151 }
152 }
153
154 pub fn to_url(&self) -> String {
171 let scheme = if self.engine == "sqlite" || self.engine.contains("sqlite") {
174 "sqlite"
175 } else if self.engine == "postgresql"
176 || self.engine == "postgres"
177 || self.engine.contains("postgresql")
178 || self.engine.contains("postgres")
179 {
180 "postgresql"
181 } else if self.engine == "mysql" || self.engine.contains("mysql") {
182 "mysql"
183 } else {
184 "sqlite"
186 };
187
188 match scheme {
189 "sqlite" => {
190 if self.name == ":memory:" {
191 "sqlite::memory:".to_string()
192 } else {
193 use std::path::Path;
196 let path = Path::new(&self.name);
197 if path.is_absolute() {
198 format!("sqlite:///{}", self.name)
200 } else {
201 format!("sqlite:{}", self.name)
203 }
204 }
205 }
206 "postgresql" | "mysql" => {
207 let mut url = format!("{}://", scheme);
208
209 if let Some(user) = &self.user {
211 let encoded_user = utf8_percent_encode(user, USERINFO_ENCODE_SET).to_string();
212 url.push_str(&encoded_user);
213 if let Some(password) = &self.password {
214 url.push(':');
215 let encoded_password =
216 utf8_percent_encode(password.expose_secret(), USERINFO_ENCODE_SET)
217 .to_string();
218 url.push_str(&encoded_password);
219 }
220 url.push('@');
221 }
222
223 let host = self.host.as_deref().unwrap_or("localhost");
225 url.push_str(host);
226
227 if let Some(port) = self.port {
229 url.push(':');
230 url.push_str(&port.to_string());
231 }
232
233 url.push('/');
235 url.push_str(&self.name);
236
237 if !self.options.is_empty() {
239 let mut query_parts = Vec::new();
240 for (key, value) in &self.options {
241 let encoded_key = utf8_percent_encode(key, USERINFO_ENCODE_SET).to_string();
242 let encoded_value =
243 utf8_percent_encode(value, USERINFO_ENCODE_SET).to_string();
244 query_parts.push(format!("{}={}", encoded_key, encoded_value));
245 }
246 url.push('?');
247 url.push_str(&query_parts.join("&"));
248 }
249
250 url
251 }
252 _ => format!("sqlite://{}", self.name),
253 }
254 }
255}
256
257impl fmt::Display for DatabaseConfig {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 let scheme = if self.engine.contains("sqlite") {
261 "sqlite"
262 } else if self.engine.contains("postgresql") || self.engine.contains("postgres") {
263 "postgresql"
264 } else if self.engine.contains("mysql") {
265 "mysql"
266 } else {
267 "unknown"
268 };
269
270 match scheme {
271 "sqlite" => write!(f, "sqlite:{}", self.name),
272 _ => {
273 write!(f, "{}://", scheme)?;
274 if self.user.is_some() || self.password.is_some() {
275 write!(f, "***@")?;
276 }
277 if let Some(host) = &self.host {
278 write!(f, "{}", host)?;
279 }
280 if let Some(port) = self.port {
281 write!(f, ":{}", port)?;
282 }
283 write!(f, "/{}", self.name)
284 }
285 }
286 }
287}
288
289impl Default for DatabaseConfig {
290 fn default() -> Self {
291 Self::sqlite("db.sqlite3".to_string())
292 }
293}
294
/// URL prefixes accepted by `validate_database_url_scheme`.
///
/// Matching is a plain, case-sensitive `starts_with`. The list order is also
/// the order shown in that function's error message. `"sqlite:"` already
/// matches everything `"sqlite://"` does, so the latter entry never changes
/// the outcome of the check; it only appears in the error text.
pub(crate) const VALID_DATABASE_SCHEMES: &[&str] = &[
    "postgres://",
    "postgresql://",
    "sqlite://",
    "sqlite:",
    "mysql://",
    "mariadb://",
];
304
305pub(crate) fn validate_database_url_scheme(url: &str) -> Result<(), String> {
310 if VALID_DATABASE_SCHEMES.iter().any(|s| url.starts_with(s)) {
311 Ok(())
312 } else {
313 Err(format!(
314 "Invalid database URL: unrecognized scheme. Expected one of: {}",
315 VALID_DATABASE_SCHEMES.join(", ")
316 ))
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_settings_db_config_sqlite() {
        let config = DatabaseConfig::sqlite("test.db");

        assert_eq!(config.engine, "reinhardt.db.backends.sqlite3");
        assert_eq!(config.name, "test.db");
        assert!(config.user.is_none(), "sqlite config must not carry a user");
        assert!(config.password.is_none(), "sqlite config must not carry a password");
    }

    #[rstest]
    fn test_settings_db_config_postgresql() {
        let config = DatabaseConfig::postgresql("testdb", "user", "pass", "localhost", 5432);
        let exposed_password = config.password.as_ref().map(|secret| secret.expose_secret());

        assert_eq!(config.engine, "reinhardt.db.backends.postgresql");
        assert_eq!(config.name, "testdb");
        assert_eq!(config.user.as_deref(), Some("user"));
        assert_eq!(exposed_password, Some("pass"));
        assert_eq!(config.port, Some(5432));
    }

    #[rstest]
    fn test_debug_output_redacts_password() {
        let config = DatabaseConfig::postgresql("testdb", "user", "s3cr3t!", "localhost", 5432);

        let rendered = format!("{config:?}");

        assert!(!rendered.contains("s3cr3t!"), "Debug leaked the password: {rendered}");
        assert!(rendered.contains("[REDACTED]"), "missing redaction marker: {rendered}");
    }

    #[rstest]
    fn test_debug_output_without_password() {
        let config = DatabaseConfig::sqlite("test.db");

        let rendered = format!("{config:?}");

        assert!(rendered.contains("None"), "unset fields should show None: {rendered}");
        assert!(rendered.contains("DatabaseConfig"), "missing type name: {rendered}");
    }

    #[rstest]
    fn test_to_url_encodes_special_chars_in_username() {
        let config = DatabaseConfig::postgresql("mydb", "user@domain", "pass", "localhost", 5432);

        let rendered = config.to_url();

        assert!(rendered.contains("user%40domain"), "unencoded user in {rendered}");
        assert!(!rendered.contains("user@domain:"), "raw user leaked into {rendered}");
    }

    #[rstest]
    fn test_to_url_encodes_special_chars_in_password() {
        let config = DatabaseConfig::postgresql("mydb", "user", "p@ss:w/rd#", "localhost", 5432);

        let rendered = config.to_url();

        assert!(rendered.contains("p%40ss%3Aw%2Frd%23"), "unencoded password in {rendered}");
        assert!(!rendered.contains("p@ss:w/rd#"), "raw password leaked into {rendered}");
    }

    #[rstest]
    fn test_to_url_prevents_host_injection() {
        let config = DatabaseConfig::postgresql(
            "mydb",
            "admin@evil.com:9999/fake",
            "pass",
            "localhost",
            5432,
        );

        let rendered = config.to_url();

        assert!(
            rendered.contains("admin%40evil.com%3A9999%2Ffake"),
            "user was not fully encoded in {rendered}"
        );
        assert!(rendered.contains("@localhost:5432"), "real host lost in {rendered}");
    }

    #[rstest]
    fn test_to_url_encodes_query_parameter_values() {
        let mut config = DatabaseConfig::postgresql("mydb", "user", "pass", "localhost", 5432);
        config
            .options
            .insert("sslmode".to_string(), "require&inject=true".to_string());

        let rendered = config.to_url();

        assert!(rendered.contains("require%26inject%3Dtrue"), "unencoded value in {rendered}");
        assert!(!rendered.contains("require&inject=true"), "raw value leaked into {rendered}");
    }

    #[rstest]
    fn test_to_url_simple_credentials() {
        let config = DatabaseConfig::postgresql("mydb", "user", "pass", "localhost", 5432);

        assert_eq!(config.to_url(), "postgresql://user:pass@localhost:5432/mydb");
    }

    #[rstest]
    fn test_display_output_masks_credentials() {
        let config =
            DatabaseConfig::postgresql("mydb", "admin", "s3cr3t!", "db.example.com", 5432);

        let shown = config.to_string();

        assert!(!shown.contains("admin"), "Display leaked the user: {shown}");
        assert!(!shown.contains("s3cr3t!"), "Display leaked the password: {shown}");
        assert!(shown.contains("***@"));
        assert!(shown.contains("db.example.com"));
        assert!(shown.contains("mydb"));
    }

    #[rstest]
    fn test_display_output_sqlite() {
        let config = DatabaseConfig::sqlite("app.db");

        assert_eq!(config.to_string(), "sqlite:app.db");
    }

    #[rstest]
    fn test_password_stored_as_secret_string() {
        let config = DatabaseConfig::postgresql("mydb", "user", "my-secret-pw", "localhost", 5432);

        let secret = config.password.as_ref().expect("postgresql() always sets a password");

        assert_eq!(secret.expose_secret(), "my-secret-pw");
        assert_eq!(secret.to_string(), "[REDACTED]");
    }

    #[rstest]
    #[case("postgres://localhost/db")]
    #[case("postgresql://user:pass@localhost:5432/db")]
    #[case("sqlite::memory:")]
    #[case("sqlite:///path/to/db")]
    #[case("mysql://root@localhost/db")]
    #[case("mariadb://root@localhost/db")]
    fn test_valid_database_url_schemes(#[case] url: &str) {
        let outcome = validate_database_url_scheme(url);

        assert!(outcome.is_ok(), "{url} should be accepted, got {outcome:?}");
    }

    #[rstest]
    #[case("http://localhost/db")]
    #[case("ftp://localhost/db")]
    #[case("redis://localhost")]
    #[case("")]
    #[case("not-a-url")]
    fn test_invalid_database_url_schemes(#[case] url: &str) {
        let Err(message) = validate_database_url_scheme(url) else {
            panic!("{url:?} should have been rejected");
        };

        assert!(message.contains("Invalid database URL"), "unexpected message: {message}");
    }
}