1use std::collections::HashMap;
6use std::env;
7use std::path::Path;
8
9use base64::Engine;
10
11use crate::types::QualifiedIdentifier;
12
13use super::error::ConfigError;
14use super::jwt::parse_js_path;
15use super::types::{AppConfig, IsolationLevel, LogLevel, OpenApiMode};
16
17pub async fn load_config(
35 file_path: Option<&Path>,
36 db_settings: HashMap<String, String>,
37) -> Result<AppConfig, ConfigError> {
38 let mut config = AppConfig::default();
39
40 if let Some(path) = file_path {
42 let file_contents = tokio::fs::read_to_string(path).await?;
43 parse_config_file(&file_contents, &mut config)?;
44 config.config_file_path = Some(path.to_path_buf());
45 }
46
47 apply_env_overrides(&mut config)?;
49
50 for (key, value) in db_settings {
52 let _ = apply_config_value(&mut config, &key, &value);
53 }
54
55 post_process_config(&mut config)?;
57
58 validate_config(&config)?;
60
61 Ok(config)
62}
63
64fn parse_config_file(contents: &str, config: &mut AppConfig) -> Result<(), ConfigError> {
68 for (line_num, line) in contents.lines().enumerate() {
69 let line = line.trim();
70
71 if line.is_empty() || line.starts_with('#') || line.starts_with("--") {
73 continue;
74 }
75
76 if let Some((key, value)) = parse_key_value(line) {
78 apply_config_value(config, &key, &value).map_err(|e| match e {
79 ConfigError::InvalidValue { .. } => ConfigError::Parse {
80 line: Some(line_num + 1),
81 message: e.to_string(),
82 },
83 _ => e,
84 })?;
85 }
86 }
87
88 Ok(())
89}
90
/// Splits a config-file line of the form `key = value` into its parts.
///
/// The line is split at the first `=`, and both sides are trimmed. If the
/// value is wrapped in one matching pair of double (`"`) or single (`'`)
/// quotes, that single pair is removed. Unbalanced or one-sided quotes are
/// kept as part of the value, so values such as secrets that legitimately
/// end in a quote character are not corrupted.
///
/// Returns `None` when the line contains no `=` or the key is empty.
fn parse_key_value(line: &str) -> Option<(String, String)> {
    let (key, value) = line.split_once('=')?;
    let key = key.trim();
    if key.is_empty() {
        return None;
    }

    let value = strip_matching_quotes(value.trim());
    Some((key.to_string(), value.to_string()))
}

/// Removes one surrounding pair of matching `"` or `'` quotes, if present.
fn strip_matching_quotes(value: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner;
        }
    }
    value
}
111
112fn apply_env_overrides(config: &mut AppConfig) -> Result<(), ConfigError> {
117 for (key, value) in env::vars() {
118 if let Some(config_key) = key.strip_prefix("DBRST_") {
119 let config_key = config_key.to_lowercase().replace('_', "-");
120 let _ = apply_config_value(config, &config_key, &value);
122 }
123 }
124 Ok(())
125}
126
127pub fn apply_config_value(
129 config: &mut AppConfig,
130 key: &str,
131 value: &str,
132) -> Result<(), ConfigError> {
133 match key {
134 "db-uri" => config.db_uri = value.to_string(),
136 "db-schemas" | "db-schema" => {
137 config.db_schemas = parse_comma_list(value);
138 }
139 "db-anon-role" => {
140 config.db_anon_role = if value.is_empty() {
141 None
142 } else {
143 Some(value.to_string())
144 };
145 }
146 "db-pool" => {
147 config.db_pool_size = parse_int(key, value)?;
148 }
149 "db-pool-acquisition-timeout" => {
150 config.db_pool_acquisition_timeout = parse_int(key, value)?;
151 }
152 "db-pool-max-lifetime" => {
153 config.db_pool_max_lifetime = parse_int(key, value)?;
154 }
155 "db-pool-max-idletime" | "db-pool-timeout" => {
156 config.db_pool_max_idletime = parse_int(key, value)?;
157 }
158 "db-pool-automatic-recovery" => {
159 config.db_pool_automatic_recovery = parse_bool(value)?;
160 }
161 "db-prepared-statements" => {
162 config.db_prepared_statements = parse_bool(value)?;
163 }
164 "db-pre-request" | "pre-request" => {
165 config.db_pre_request = if value.is_empty() {
166 None
167 } else {
168 Some(parse_qualified_identifier(key, value)?)
169 };
170 }
171 "db-root-spec" | "root-spec" => {
172 config.db_root_spec = if value.is_empty() {
173 None
174 } else {
175 Some(parse_qualified_identifier(key, value)?)
176 };
177 }
178 "db-extra-search-path" => {
179 config.db_extra_search_path = parse_comma_list(value);
180 }
181 "db-hoisted-tx-settings" => {
182 config.db_hoisted_tx_settings = parse_comma_list(value);
183 }
184 "db-max-rows" | "max-rows" => {
185 config.db_max_rows = if value.is_empty() {
186 None
187 } else {
188 Some(parse_int(key, value)?)
189 };
190 }
191 "db-plan-enabled" => {
192 config.db_plan_enabled = parse_bool(value)?;
193 }
194 "db-tx-end" => match value {
195 "commit" => {
196 config.db_tx_rollback_all = false;
197 config.db_tx_allow_override = false;
198 }
199 "commit-allow-override" => {
200 config.db_tx_rollback_all = false;
201 config.db_tx_allow_override = true;
202 }
203 "rollback" => {
204 config.db_tx_rollback_all = true;
205 config.db_tx_allow_override = false;
206 }
207 "rollback-allow-override" => {
208 config.db_tx_rollback_all = true;
209 config.db_tx_allow_override = true;
210 }
211 _ => {
212 return Err(ConfigError::InvalidValue {
213 key: key.to_string(),
214 value: value.to_string(),
215 expected: Some(
216 "commit, commit-allow-override, rollback, rollback-allow-override"
217 .to_string(),
218 ),
219 });
220 }
221 },
222 "db-tx-read-isolation" => {
223 config.db_tx_read_isolation = parse_isolation_level(value)?;
224 }
225 "db-tx-write-isolation" => {
226 config.db_tx_write_isolation = parse_isolation_level(value)?;
227 }
228 "db-aggregates-enabled" => {
229 config.db_aggregates_enabled = parse_bool(value)?;
230 }
231 "db-config" => {
232 config.db_config = parse_bool(value)?;
233 }
234 "db-pre-config" => {
235 config.db_pre_config = if value.is_empty() {
236 None
237 } else {
238 Some(parse_qualified_identifier(key, value)?)
239 };
240 }
241 "db-channel" => {
242 config.db_channel = value.to_string();
243 }
244 "db-channel-enabled" => {
245 config.db_channel_enabled = parse_bool(value)?;
246 }
247
248 "server-host" => config.server_host = value.to_string(),
250 "server-port" => {
251 config.server_port = parse_int(key, value)?;
252 }
253 "server-unix-socket" => {
254 config.server_unix_socket = if value.is_empty() {
255 None
256 } else {
257 Some(value.into())
258 };
259 }
260 "server-unix-socket-mode" => {
261 config.server_unix_socket_mode =
262 u32::from_str_radix(value, 8).map_err(|_| ConfigError::InvalidValue {
263 key: key.to_string(),
264 value: value.to_string(),
265 expected: Some("octal number (e.g., 660)".to_string()),
266 })?;
267 }
268 "server-cors-allowed-origins" => {
269 config.server_cors_allowed_origins = if value.is_empty() {
270 None
271 } else {
272 Some(parse_comma_list(value))
273 };
274 }
275 "server-trace-header" => {
276 config.server_trace_header = if value.is_empty() {
277 None
278 } else {
279 Some(value.to_string())
280 };
281 }
282 "server-timing-enabled" => {
283 config.server_timing_enabled = parse_bool(value)?;
284 }
285 "server-max-body-size" => {
286 config.server_max_body_size = parse_int(key, value)?;
287 }
288
289 "admin-server-host" => config.admin_server_host = value.to_string(),
291 "admin-server-port" => {
292 config.admin_server_port = if value.is_empty() {
293 None
294 } else {
295 Some(parse_int(key, value)?)
296 };
297 }
298
299 "jwt-secret" => {
301 config.jwt_secret = if value.is_empty() {
302 None
303 } else {
304 Some(value.to_string())
305 };
306 }
307 "jwt-secret-is-base64" | "secret-is-base64" => {
308 config.jwt_secret_is_base64 = parse_bool(value)?;
309 }
310 "jwt-aud" => {
311 config.jwt_aud = if value.is_empty() {
312 None
313 } else {
314 Some(value.to_string())
315 };
316 }
317 "jwt-role-claim-key" | "role-claim-key" => {
318 config.jwt_role_claim_key = parse_js_path(value)?;
319 }
320 "jwt-cache-max-entries" => {
321 config.jwt_cache_max_entries = parse_int(key, value)?;
322 }
323
324 "log-level" => {
326 config.log_level = LogLevel::parse(value).ok_or_else(|| ConfigError::InvalidValue {
327 key: key.to_string(),
328 value: value.to_string(),
329 expected: Some("crit, error, warn, info, debug".to_string()),
330 })?;
331 }
332 "log-query" => {
333 config.log_query = parse_bool(value)?;
334 }
335
336 "openapi-mode" => {
338 config.openapi_mode =
339 OpenApiMode::parse(value).ok_or_else(|| ConfigError::InvalidValue {
340 key: key.to_string(),
341 value: value.to_string(),
342 expected: Some("follow-privileges, ignore-privileges, disabled".to_string()),
343 })?;
344 }
345 "openapi-security-active" => {
346 config.openapi_security_active = parse_bool(value)?;
347 }
348 "openapi-server-proxy-uri" => {
349 config.openapi_server_proxy_uri = if value.is_empty() {
350 None
351 } else {
352 Some(value.to_string())
353 };
354 }
355
356 "server-streaming-enabled" => {
358 config.server_streaming_enabled = parse_bool(value)?;
359 }
360 "server-streaming-threshold" => {
361 config.server_streaming_threshold =
362 value
363 .parse::<u64>()
364 .map_err(|_| ConfigError::InvalidValue {
365 key: key.to_string(),
366 value: value.to_string(),
367 expected: Some("positive integer (bytes)".to_string()),
368 })?;
369 }
370
371 "metrics-enabled" => {
373 config.metrics_enabled = parse_bool(value)?;
374 }
375 "metrics-otlp-endpoint" => config.metrics_otlp_endpoint = value.to_string(),
376 "metrics-otlp-protocol" => config.metrics_otlp_protocol = value.to_string(),
377 "metrics-export-interval-secs" => {
378 config.metrics_export_interval_secs = parse_int(key, value)?;
379 }
380 "metrics-service-name" => config.metrics_service_name = value.to_string(),
381 "tracing-enabled" => {
382 config.tracing_enabled = parse_bool(value)?;
383 }
384 "tracing-sampling-ratio" => {
385 config.tracing_sampling_ratio =
386 value
387 .parse::<f64>()
388 .map_err(|_| ConfigError::InvalidValue {
389 key: key.to_string(),
390 value: value.to_string(),
391 expected: Some("float between 0.0 and 1.0".to_string()),
392 })?;
393 }
394
395 key if key.starts_with("app.settings.") => {
397 if let Some(setting_key) = key.strip_prefix("app.settings.") {
398 config
399 .app_settings
400 .insert(setting_key.to_string(), value.to_string());
401 }
402 }
403
404 _ => {
406 tracing::debug!("Unknown config key: {}", key);
407 }
408 }
409
410 Ok(())
411}
412
413pub fn parse_bool(value: &str) -> Result<bool, ConfigError> {
415 match value.to_lowercase().as_str() {
416 "true" | "yes" | "on" | "1" => Ok(true),
417 "false" | "no" | "off" | "0" => Ok(false),
418 _ => Err(ConfigError::InvalidBool(value.to_string())),
419 }
420}
421
422fn parse_int<T: std::str::FromStr>(key: &str, value: &str) -> Result<T, ConfigError>
424where
425 T::Err: std::fmt::Display,
426{
427 value
428 .parse()
429 .map_err(|e: T::Err| ConfigError::InvalidValue {
430 key: key.to_string(),
431 value: value.to_string(),
432 expected: Some(format!("integer ({})", e)),
433 })
434}
435
436fn parse_isolation_level(value: &str) -> Result<IsolationLevel, ConfigError> {
438 match value.to_lowercase().as_str() {
439 "read-committed" | "readcommitted" => Ok(IsolationLevel::ReadCommitted),
440 "repeatable-read" | "repeatableread" => Ok(IsolationLevel::RepeatableRead),
441 "serializable" => Ok(IsolationLevel::Serializable),
442 _ => Err(ConfigError::InvalidValue {
443 key: "isolation-level".to_string(),
444 value: value.to_string(),
445 expected: Some("read-committed, repeatable-read, serializable".to_string()),
446 }),
447 }
448}
449
/// Splits a comma-separated setting into trimmed, non-empty items.
///
/// Empty items are dropped, so stray or trailing commas
/// (`"api,,public,"`) do not produce empty names such as `""`, and an
/// empty or all-blank input yields an empty list.
fn parse_comma_list(value: &str) -> Vec<String> {
    value
        .split(',')
        .map(str::trim)
        .filter(|item| !item.is_empty())
        .map(str::to_string)
        .collect()
}
458
459fn parse_qualified_identifier(key: &str, value: &str) -> Result<QualifiedIdentifier, ConfigError> {
461 QualifiedIdentifier::parse(value).map_err(|_| ConfigError::InvalidValue {
462 key: key.to_string(),
463 value: value.to_string(),
464 expected: Some("qualified identifier (schema.name or name)".to_string()),
465 })
466}
467
468fn post_process_config(config: &mut AppConfig) -> Result<(), ConfigError> {
470 if config.jwt_secret_is_base64
472 && let Some(ref secret) = config.jwt_secret
473 {
474 let decoded = base64::engine::general_purpose::STANDARD.decode(secret)?;
475 config.jwt_secret = Some(String::from_utf8(decoded)?);
476 }
477
478 if !config.db_uri.contains("application_name") {
480 let separator = if config.db_uri.contains('?') {
481 "&"
482 } else {
483 "?"
484 };
485 config.db_uri = format!(
486 "{}{}fallback_application_name=dbrest",
487 config.db_uri, separator
488 );
489 }
490
491 Ok(())
492}
493
494pub fn validate_config(config: &AppConfig) -> Result<(), ConfigError> {
496 if config.db_schemas.is_empty() {
498 return Err(ConfigError::Validation(
499 "db-schemas cannot be empty".to_string(),
500 ));
501 }
502
503 for schema in &config.db_schemas {
505 if schema == "pg_catalog" || schema == "information_schema" {
506 return Err(ConfigError::Validation(format!(
507 "db-schemas cannot include system schema: '{}'",
508 schema
509 )));
510 }
511 }
512
513 if let Some(admin_port) = config.admin_server_port
515 && admin_port == config.server_port
516 {
517 return Err(ConfigError::Validation(
518 "admin-server-port cannot be the same as server-port".to_string(),
519 ));
520 }
521
522 if let Some(ref secret) = config.jwt_secret {
524 let is_jwks = secret.trim().starts_with('{');
526 if !is_jwks && secret.len() < 32 {
527 return Err(ConfigError::Validation(
528 "jwt-secret must be at least 32 characters long".to_string(),
529 ));
530 }
531 }
532
533 if config.db_pool_size == 0 {
535 return Err(ConfigError::Validation(
536 "db-pool must be greater than 0".to_string(),
537 ));
538 }
539
540 Ok(())
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `Some((key, value))` result of `parse_key_value`.
    fn pair(key: &str, value: &str) -> Option<(String, String)> {
        Some((key.to_string(), value.to_string()))
    }

    /// Builds a default `AppConfig` and applies `tweak` to it.
    fn config_with(tweak: impl FnOnce(&mut AppConfig)) -> AppConfig {
        let mut config = AppConfig::default();
        tweak(&mut config);
        config
    }

    #[test]
    fn test_parse_key_value() {
        let cases = [
            ("key=value", pair("key", "value")),
            ("key = value", pair("key", "value")),
            ("key=\"value\"", pair("key", "value")),
            ("key='value'", pair("key", "value")),
            ("no_equals", None),
            ("=value", None),
        ];
        for (line, expected) in cases {
            assert_eq!(parse_key_value(line), expected, "input: {line:?}");
        }
    }

    #[test]
    fn test_parse_bool() {
        for input in ["true", "TRUE", "yes", "on", "1"] {
            assert!(parse_bool(input).unwrap(), "expected {input:?} to be true");
        }
        for input in ["false", "FALSE", "no", "off", "0"] {
            assert!(!parse_bool(input).unwrap(), "expected {input:?} to be false");
        }
        assert!(parse_bool("maybe").is_err());
    }

    #[test]
    fn test_parse_comma_list() {
        let abc = vec!["a", "b", "c"];
        assert_eq!(parse_comma_list("a,b,c"), abc);
        assert_eq!(parse_comma_list("a, b, c"), abc);
        assert_eq!(parse_comma_list("single"), vec!["single"]);
        assert!(parse_comma_list("").is_empty());
    }

    #[test]
    fn test_apply_config_value() {
        let mut config = AppConfig::default();
        let settings = [
            ("server-port", "8080"),
            ("db-schemas", "api,public"),
            ("db-pool", "20"),
            ("log-level", "debug"),
        ];
        for (key, value) in settings {
            apply_config_value(&mut config, key, value).unwrap();
        }

        assert_eq!(config.server_port, 8080);
        assert_eq!(config.db_schemas, vec!["api", "public"]);
        assert_eq!(config.db_pool_size, 20);
        assert_eq!(config.log_level, LogLevel::Debug);
    }

    #[test]
    fn test_apply_config_tx_end() {
        let mut config = AppConfig::default();
        let mut apply_and_read = |mode: &str| {
            apply_config_value(&mut config, "db-tx-end", mode).unwrap();
            (config.db_tx_rollback_all, config.db_tx_allow_override)
        };

        assert_eq!(apply_and_read("commit"), (false, false));
        assert_eq!(apply_and_read("rollback-allow-override"), (true, true));
    }

    #[test]
    fn test_apply_config_app_settings() {
        let mut config = AppConfig::default();
        apply_config_value(&mut config, "app.settings.my-key", "my-value").unwrap();
        assert_eq!(
            config.app_settings.get("my-key").map(String::as_str),
            Some("my-value")
        );
    }

    #[test]
    fn test_validate_config_empty_schemas() {
        let config = config_with(|c| c.db_schemas = vec![]);
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_system_schema() {
        for system_schema in ["pg_catalog", "information_schema"] {
            let config = config_with(|c| c.db_schemas = vec![system_schema.to_string()]);
            assert!(
                validate_config(&config).is_err(),
                "{system_schema} must be rejected"
            );
        }
    }

    #[test]
    fn test_validate_config_same_ports() {
        let config = config_with(|c| {
            c.server_port = 3000;
            c.admin_server_port = Some(3000);
        });
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_validate_config_short_jwt_secret() {
        let too_short = config_with(|c| c.jwt_secret = Some("short".to_string()));
        assert!(validate_config(&too_short).is_err());

        let long_enough = config_with(|c| c.jwt_secret = Some("a".repeat(32)));
        assert!(validate_config(&long_enough).is_ok());
    }

    #[test]
    fn test_validate_config_jwks_bypass() {
        // A JWKS document is exempt from the minimum secret length.
        let config = config_with(|c| c.jwt_secret = Some("{\"keys\":[]}".to_string()));
        assert!(validate_config(&config).is_ok());
    }

    #[test]
    fn test_parse_config_file() {
        let contents = r#"
# Comment line
server-port = 8080
db-schemas = api, public
log-level = debug

-- Another comment style
db-pool = 25
"#;

        let mut config = AppConfig::default();
        parse_config_file(contents, &mut config).unwrap();

        assert_eq!(config.server_port, 8080);
        assert_eq!(config.db_schemas, vec!["api", "public"]);
        assert_eq!(config.log_level, LogLevel::Debug);
        assert_eq!(config.db_pool_size, 25);
    }

    #[tokio::test]
    async fn test_load_config_defaults() {
        let config = load_config(None, HashMap::new()).await.unwrap();
        assert_eq!(config.server_port, 3000);
        assert_eq!(config.db_schemas, vec!["public"]);
    }
}