1use std::collections::HashMap;
6use std::env;
7use std::path::Path;
8
9use base64::Engine;
10
11use crate::types::QualifiedIdentifier;
12
13use super::error::ConfigError;
14use super::jwt::parse_js_path;
15use super::types::{AppConfig, IsolationLevel, LogLevel, OpenApiMode};
16
17pub async fn load_config(
35 file_path: Option<&Path>,
36 db_settings: HashMap<String, String>,
37) -> Result<AppConfig, ConfigError> {
38 let mut config = AppConfig::default();
39
40 if let Some(path) = file_path {
42 let file_contents = tokio::fs::read_to_string(path).await?;
43 parse_config_file(&file_contents, &mut config)?;
44 config.config_file_path = Some(path.to_path_buf());
45 }
46
47 apply_env_overrides(&mut config)?;
49
50 for (key, value) in db_settings {
52 let _ = apply_config_value(&mut config, &key, &value);
53 }
54
55 post_process_config(&mut config)?;
57
58 validate_config(&config)?;
60
61 Ok(config)
62}
63
64fn parse_config_file(contents: &str, config: &mut AppConfig) -> Result<(), ConfigError> {
68 for (line_num, line) in contents.lines().enumerate() {
69 let line = line.trim();
70
71 if line.is_empty() || line.starts_with('#') || line.starts_with("--") {
73 continue;
74 }
75
76 if let Some((key, value)) = parse_key_value(line) {
78 apply_config_value(config, &key, &value).map_err(|e| match e {
79 ConfigError::InvalidValue { .. } => ConfigError::Parse {
80 line: Some(line_num + 1),
81 message: e.to_string(),
82 },
83 _ => e,
84 })?;
85 }
86 }
87
88 Ok(())
89}
90
/// Splits a `key = value` line into an owned `(key, value)` pair.
///
/// The line is split at the first `=`; both halves are trimmed of surrounding
/// whitespace. If the value is wrapped in one matching pair of double or
/// single quotes, exactly that pair is removed. Quotes inside the value, and
/// unbalanced quotes, are preserved, so `"say 'hi'"` yields `say 'hi'`.
///
/// Returns `None` when the line contains no `=` or when the key is empty.
fn parse_key_value(line: &str) -> Option<(String, String)> {
    let (raw_key, raw_value) = line.split_once('=')?;

    let key = raw_key.trim();
    if key.is_empty() {
        return None;
    }

    let trimmed = raw_value.trim();
    // Remove a single enclosing quote pair only; stripping every quote
    // character at both ends would corrupt values that end in a quote.
    let value = ['"', '\'']
        .iter()
        .find_map(|&quote| trimmed.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(trimmed);

    Some((key.to_string(), value.to_string()))
}
111
112fn apply_env_overrides(config: &mut AppConfig) -> Result<(), ConfigError> {
117 for (key, value) in env::vars() {
118 if let Some(config_key) = key.strip_prefix("DBRST_") {
119 let config_key = config_key.to_lowercase().replace('_', "-");
120 let _ = apply_config_value(config, &config_key, &value);
122 }
123 }
124 Ok(())
125}
126
127pub fn apply_config_value(
129 config: &mut AppConfig,
130 key: &str,
131 value: &str,
132) -> Result<(), ConfigError> {
133 match key {
134 "db-uri" => config.db_uri = value.to_string(),
136 "db-schemas" | "db-schema" => {
137 config.db_schemas = parse_comma_list(value);
138 }
139 "db-anon-role" => {
140 config.db_anon_role = if value.is_empty() {
141 None
142 } else {
143 Some(value.to_string())
144 };
145 }
146 "db-pool" => {
147 config.db_pool_size = parse_int(key, value)?;
148 }
149 "db-pool-acquisition-timeout" => {
150 config.db_pool_acquisition_timeout = parse_int(key, value)?;
151 }
152 "db-pool-max-lifetime" => {
153 config.db_pool_max_lifetime = parse_int(key, value)?;
154 }
155 "db-pool-max-idletime" | "db-pool-timeout" => {
156 config.db_pool_max_idletime = parse_int(key, value)?;
157 }
158 "db-pool-automatic-recovery" => {
159 config.db_pool_automatic_recovery = parse_bool(value)?;
160 }
161 "db-prepared-statements" => {
162 config.db_prepared_statements = parse_bool(value)?;
163 }
164 "db-pre-request" | "pre-request" => {
165 config.db_pre_request = if value.is_empty() {
166 None
167 } else {
168 Some(parse_qualified_identifier(key, value)?)
169 };
170 }
171 "db-root-spec" | "root-spec" => {
172 config.db_root_spec = if value.is_empty() {
173 None
174 } else {
175 Some(parse_qualified_identifier(key, value)?)
176 };
177 }
178 "db-extra-search-path" => {
179 config.db_extra_search_path = parse_comma_list(value);
180 }
181 "db-hoisted-tx-settings" => {
182 config.db_hoisted_tx_settings = parse_comma_list(value);
183 }
184 "db-max-rows" | "max-rows" => {
185 config.db_max_rows = if value.is_empty() {
186 None
187 } else {
188 Some(parse_int(key, value)?)
189 };
190 }
191 "db-plan-enabled" => {
192 config.db_plan_enabled = parse_bool(value)?;
193 }
194 "db-tx-end" => match value {
195 "commit" => {
196 config.db_tx_rollback_all = false;
197 config.db_tx_allow_override = false;
198 }
199 "commit-allow-override" => {
200 config.db_tx_rollback_all = false;
201 config.db_tx_allow_override = true;
202 }
203 "rollback" => {
204 config.db_tx_rollback_all = true;
205 config.db_tx_allow_override = false;
206 }
207 "rollback-allow-override" => {
208 config.db_tx_rollback_all = true;
209 config.db_tx_allow_override = true;
210 }
211 _ => {
212 return Err(ConfigError::InvalidValue {
213 key: key.to_string(),
214 value: value.to_string(),
215 expected: Some(
216 "commit, commit-allow-override, rollback, rollback-allow-override"
217 .to_string(),
218 ),
219 });
220 }
221 },
222 "db-tx-read-isolation" => {
223 config.db_tx_read_isolation = parse_isolation_level(value)?;
224 }
225 "db-tx-write-isolation" => {
226 config.db_tx_write_isolation = parse_isolation_level(value)?;
227 }
228 "db-aggregates-enabled" => {
229 config.db_aggregates_enabled = parse_bool(value)?;
230 }
231 "db-config" => {
232 config.db_config = parse_bool(value)?;
233 }
234 "db-pre-config" => {
235 config.db_pre_config = if value.is_empty() {
236 None
237 } else {
238 Some(parse_qualified_identifier(key, value)?)
239 };
240 }
241 "db-channel" => {
242 config.db_channel = value.to_string();
243 }
244 "db-channel-enabled" => {
245 config.db_channel_enabled = parse_bool(value)?;
246 }
247
248 "server-host" => config.server_host = value.to_string(),
250 "server-port" => {
251 config.server_port = parse_int(key, value)?;
252 }
253 "server-unix-socket" => {
254 config.server_unix_socket = if value.is_empty() {
255 None
256 } else {
257 Some(value.into())
258 };
259 }
260 "server-unix-socket-mode" => {
261 config.server_unix_socket_mode =
262 u32::from_str_radix(value, 8).map_err(|_| ConfigError::InvalidValue {
263 key: key.to_string(),
264 value: value.to_string(),
265 expected: Some("octal number (e.g., 660)".to_string()),
266 })?;
267 }
268 "server-cors-allowed-origins" => {
269 config.server_cors_allowed_origins = if value.is_empty() {
270 None
271 } else {
272 Some(parse_comma_list(value))
273 };
274 }
275 "server-trace-header" => {
276 config.server_trace_header = if value.is_empty() {
277 None
278 } else {
279 Some(value.to_string())
280 };
281 }
282 "server-timing-enabled" => {
283 config.server_timing_enabled = parse_bool(value)?;
284 }
285 "server-max-body-size" => {
286 config.server_max_body_size = parse_int(key, value)?;
287 }
288
289 "admin-server-host" => config.admin_server_host = value.to_string(),
291 "admin-server-port" => {
292 config.admin_server_port = if value.is_empty() {
293 None
294 } else {
295 Some(parse_int(key, value)?)
296 };
297 }
298
299 "jwt-secret" => {
301 config.jwt_secret = if value.is_empty() {
302 None
303 } else {
304 Some(value.to_string())
305 };
306 }
307 "jwt-secret-is-base64" | "secret-is-base64" => {
308 config.jwt_secret_is_base64 = parse_bool(value)?;
309 }
310 "jwt-aud" => {
311 config.jwt_aud = if value.is_empty() {
312 None
313 } else {
314 Some(value.to_string())
315 };
316 }
317 "jwt-role-claim-key" | "role-claim-key" => {
318 config.jwt_role_claim_key = parse_js_path(value)?;
319 }
320 "jwt-cache-max-entries" => {
321 config.jwt_cache_max_entries = parse_int(key, value)?;
322 }
323
324 "log-level" => {
326 config.log_level = LogLevel::parse(value).ok_or_else(|| ConfigError::InvalidValue {
327 key: key.to_string(),
328 value: value.to_string(),
329 expected: Some("crit, error, warn, info, debug".to_string()),
330 })?;
331 }
332 "log-query" => {
333 config.log_query = parse_bool(value)?;
334 }
335
336 "openapi-mode" => {
338 config.openapi_mode =
339 OpenApiMode::parse(value).ok_or_else(|| ConfigError::InvalidValue {
340 key: key.to_string(),
341 value: value.to_string(),
342 expected: Some("follow-privileges, ignore-privileges, disabled".to_string()),
343 })?;
344 }
345 "openapi-security-active" => {
346 config.openapi_security_active = parse_bool(value)?;
347 }
348 "openapi-server-proxy-uri" => {
349 config.openapi_server_proxy_uri = if value.is_empty() {
350 None
351 } else {
352 Some(value.to_string())
353 };
354 }
355
356 "server-streaming-enabled" => {
358 config.server_streaming_enabled = parse_bool(value)?;
359 }
360 "server-streaming-threshold" => {
361 config.server_streaming_threshold =
362 value
363 .parse::<u64>()
364 .map_err(|_| ConfigError::InvalidValue {
365 key: key.to_string(),
366 value: value.to_string(),
367 expected: Some("positive integer (bytes)".to_string()),
368 })?;
369 }
370
371 key if key.starts_with("app.settings.") => {
373 if let Some(setting_key) = key.strip_prefix("app.settings.") {
374 config
375 .app_settings
376 .insert(setting_key.to_string(), value.to_string());
377 }
378 }
379
380 _ => {
382 tracing::debug!("Unknown config key: {}", key);
383 }
384 }
385
386 Ok(())
387}
388
389pub fn parse_bool(value: &str) -> Result<bool, ConfigError> {
391 match value.to_lowercase().as_str() {
392 "true" | "yes" | "on" | "1" => Ok(true),
393 "false" | "no" | "off" | "0" => Ok(false),
394 _ => Err(ConfigError::InvalidBool(value.to_string())),
395 }
396}
397
398fn parse_int<T: std::str::FromStr>(key: &str, value: &str) -> Result<T, ConfigError>
400where
401 T::Err: std::fmt::Display,
402{
403 value
404 .parse()
405 .map_err(|e: T::Err| ConfigError::InvalidValue {
406 key: key.to_string(),
407 value: value.to_string(),
408 expected: Some(format!("integer ({})", e)),
409 })
410}
411
412fn parse_isolation_level(value: &str) -> Result<IsolationLevel, ConfigError> {
414 match value.to_lowercase().as_str() {
415 "read-committed" | "readcommitted" => Ok(IsolationLevel::ReadCommitted),
416 "repeatable-read" | "repeatableread" => Ok(IsolationLevel::RepeatableRead),
417 "serializable" => Ok(IsolationLevel::Serializable),
418 _ => Err(ConfigError::InvalidValue {
419 key: "isolation-level".to_string(),
420 value: value.to_string(),
421 expected: Some("read-committed, repeatable-read, serializable".to_string()),
422 }),
423 }
424}
425
/// Splits a comma-separated setting into its trimmed, non-empty items.
///
/// Whitespace around each item is removed. Empty items, which come from an
/// empty or whitespace-only value, a trailing comma, or doubled commas, are
/// dropped so that callers never see an empty name.
fn parse_comma_list(value: &str) -> Vec<String> {
    value
        .split(',')
        .map(str::trim)
        .filter(|item| !item.is_empty())
        .map(str::to_string)
        .collect()
}
434
435fn parse_qualified_identifier(key: &str, value: &str) -> Result<QualifiedIdentifier, ConfigError> {
437 QualifiedIdentifier::parse(value).map_err(|_| ConfigError::InvalidValue {
438 key: key.to_string(),
439 value: value.to_string(),
440 expected: Some("qualified identifier (schema.name or name)".to_string()),
441 })
442}
443
444fn post_process_config(config: &mut AppConfig) -> Result<(), ConfigError> {
446 if config.jwt_secret_is_base64
448 && let Some(ref secret) = config.jwt_secret
449 {
450 let decoded = base64::engine::general_purpose::STANDARD.decode(secret)?;
451 config.jwt_secret = Some(String::from_utf8(decoded)?);
452 }
453
454 if !config.db_uri.contains("application_name") {
456 let separator = if config.db_uri.contains('?') {
457 "&"
458 } else {
459 "?"
460 };
461 config.db_uri = format!(
462 "{}{}fallback_application_name=dbrest",
463 config.db_uri, separator
464 );
465 }
466
467 Ok(())
468}
469
470pub fn validate_config(config: &AppConfig) -> Result<(), ConfigError> {
472 if config.db_schemas.is_empty() {
474 return Err(ConfigError::Validation(
475 "db-schemas cannot be empty".to_string(),
476 ));
477 }
478
479 for schema in &config.db_schemas {
481 if schema == "pg_catalog" || schema == "information_schema" {
482 return Err(ConfigError::Validation(format!(
483 "db-schemas cannot include system schema: '{}'",
484 schema
485 )));
486 }
487 }
488
489 if let Some(admin_port) = config.admin_server_port
491 && admin_port == config.server_port
492 {
493 return Err(ConfigError::Validation(
494 "admin-server-port cannot be the same as server-port".to_string(),
495 ));
496 }
497
498 if let Some(ref secret) = config.jwt_secret {
500 let is_jwks = secret.trim().starts_with('{');
502 if !is_jwks && secret.len() < 32 {
503 return Err(ConfigError::Validation(
504 "jwt-secret must be at least 32 characters long".to_string(),
505 ));
506 }
507 }
508
509 if config.db_pool_size == 0 {
511 return Err(ConfigError::Validation(
512 "db-pool must be greater than 0".to_string(),
513 ));
514 }
515
516 Ok(())
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    // Builds the owned pair shape that `parse_key_value` returns.
    fn pair(key: &str, value: &str) -> Option<(String, String)> {
        Some((key.to_string(), value.to_string()))
    }

    // Returns a default config after letting `tweak` adjust it.
    fn config_with(tweak: impl FnOnce(&mut AppConfig)) -> AppConfig {
        let mut config = AppConfig::default();
        tweak(&mut config);
        config
    }

    #[test]
    fn test_parse_key_value() {
        let accepted = ["key=value", "key = value", "key=\"value\"", "key='value'"];
        for line in accepted {
            assert_eq!(parse_key_value(line), pair("key", "value"), "line: {:?}", line);
        }

        for line in ["no_equals", "=value"] {
            assert_eq!(parse_key_value(line), None, "line: {:?}", line);
        }
    }

    #[test]
    fn test_parse_bool() {
        for truthy in ["true", "TRUE", "yes", "on", "1"] {
            assert!(parse_bool(truthy).unwrap(), "input: {:?}", truthy);
        }

        for falsy in ["false", "FALSE", "no", "off", "0"] {
            assert!(!parse_bool(falsy).unwrap(), "input: {:?}", falsy);
        }

        assert!(parse_bool("maybe").is_err());
    }

    #[test]
    fn test_parse_comma_list() {
        let cases: [(&str, &[&str]); 3] = [
            ("a,b,c", &["a", "b", "c"]),
            ("a, b, c", &["a", "b", "c"]),
            ("single", &["single"]),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_comma_list(input), expected, "input: {:?}", input);
        }

        assert!(parse_comma_list("").is_empty());
    }

    #[test]
    fn test_apply_config_value() {
        let mut config = AppConfig::default();
        let settings = [
            ("server-port", "8080"),
            ("db-schemas", "api,public"),
            ("db-pool", "20"),
            ("log-level", "debug"),
        ];
        for (key, value) in settings {
            apply_config_value(&mut config, key, value).unwrap();
        }

        assert_eq!(config.server_port, 8080);
        assert_eq!(config.db_schemas, vec!["api", "public"]);
        assert_eq!(config.db_pool_size, 20);
        assert_eq!(config.log_level, LogLevel::Debug);
    }

    #[test]
    fn test_apply_config_tx_end() {
        let mut config = AppConfig::default();

        // (setting value, expected rollback_all, expected allow_override)
        let cases = [
            ("commit", false, false),
            ("rollback-allow-override", true, true),
        ];
        for (mode, rollback_all, allow_override) in cases {
            apply_config_value(&mut config, "db-tx-end", mode).unwrap();
            assert_eq!(config.db_tx_rollback_all, rollback_all, "mode: {}", mode);
            assert_eq!(config.db_tx_allow_override, allow_override, "mode: {}", mode);
        }
    }

    #[test]
    fn test_apply_config_app_settings() {
        let mut config = AppConfig::default();
        apply_config_value(&mut config, "app.settings.my-key", "my-value").unwrap();

        let stored = config.app_settings.get("my-key");
        assert_eq!(stored, Some(&"my-value".to_string()));
    }

    #[test]
    fn test_validate_config_empty_schemas() {
        let config = config_with(|c| c.db_schemas = vec![]);
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_system_schema() {
        for reserved in ["pg_catalog", "information_schema"] {
            let config = config_with(|c| c.db_schemas = vec![reserved.to_string()]);
            assert!(validate_config(&config).is_err(), "schema: {}", reserved);
        }
    }

    #[test]
    fn test_validate_config_same_ports() {
        let config = config_with(|c| {
            c.server_port = 3000;
            c.admin_server_port = Some(3000);
        });
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_short_jwt_secret() {
        let too_short = config_with(|c| c.jwt_secret = Some("short".to_string()));
        assert!(validate_config(&too_short).is_err());

        let long_enough = config_with(|c| c.jwt_secret = Some("a".repeat(32)));
        assert!(validate_config(&long_enough).is_ok());
    }

    #[test]
    fn test_validate_config_jwks_bypass() {
        // Shorter than 32 bytes, but a JSON object is exempt from the length rule.
        let config = config_with(|c| c.jwt_secret = Some("{\"keys\":[]}".to_string()));
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_parse_config_file() {
        let contents = r#"
# Comment line
server-port = 8080
db-schemas = api, public
log-level = debug

-- Another comment style
db-pool = 25
"#;

        let mut config = AppConfig::default();
        parse_config_file(contents, &mut config).unwrap();

        assert_eq!(config.server_port, 8080);
        assert_eq!(config.db_schemas, vec!["api", "public"]);
        assert_eq!(config.log_level, LogLevel::Debug);
        assert_eq!(config.db_pool_size, 25);
    }

    #[tokio::test]
    async fn test_load_config_defaults() {
        let loaded = load_config(None, HashMap::new()).await;
        let config = loaded.unwrap();

        assert_eq!(config.server_port, 3000);
        assert_eq!(config.db_schemas, vec!["public"]);
    }
}