1use serde::de::DeserializeOwned;
23use serde_yaml::Value;
24use std::env;
25use std::fs;
26use std::path::{Path, PathBuf};
27use std::str::FromStr;
28use thiserror::Error;
29
30#[derive(Debug, Error)]
31pub enum YmlError {
32 #[error("failed to read yaml file at {path}: {source}")]
33 ReadFile {
34 path: PathBuf,
35 #[source]
36 source: std::io::Error,
37 },
38 #[error("failed to parse yaml file at {path}: {source}")]
39 ParseFile {
40 path: PathBuf,
41 #[source]
42 source: serde_yaml::Error,
43 },
44 #[error("yaml path is invalid: {0}")]
45 InvalidPath(String),
46 #[error("failed to read current directory: {0}")]
47 CurrentDir(#[source] std::io::Error),
48}
49
/// A single step of a [`YamlPath`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum YamlPathSegment {
    /// A mapping key, written bare (`datasource`) or bracketed/quoted (`['jdbc.url']`).
    Key(String),
    /// A zero-based position, written as an all-digit bracket segment such as `[0]`.
    Index(usize),
}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct YamlPath {
58 segments: Vec<YamlPathSegment>,
59}
60
61impl YamlPath {
62 pub fn parse(path: impl AsRef<str>) -> Result<Self, YmlError> {
63 let input = path.as_ref().trim();
64 if input.is_empty() {
65 return Err(YmlError::InvalidPath("path cannot be empty".to_owned()));
66 }
67
68 let chars: Vec<char> = input.chars().collect();
69 let mut index = 0usize;
70 let mut segments = Vec::new();
71
72 while index < chars.len() {
73 skip_whitespace(&chars, &mut index);
74
75 if index < chars.len() && chars[index] == '.' {
76 index += 1;
77 skip_whitespace(&chars, &mut index);
78 if index >= chars.len() {
79 return Err(YmlError::InvalidPath(format!(
80 "path `{input}` cannot end with `.`"
81 )));
82 }
83 }
84
85 if index >= chars.len() {
86 break;
87 }
88
89 if chars[index] == '[' {
90 segments.push(parse_bracket_segment(&chars, &mut index, input)?);
91 } else {
92 segments.push(parse_bare_segment(&chars, &mut index, input)?);
93 }
94
95 skip_whitespace(&chars, &mut index);
96 if index < chars.len() && chars[index] != '.' && chars[index] != '[' {
97 return Err(YmlError::InvalidPath(format!(
98 "unexpected character `{}` in path `{input}`",
99 chars[index]
100 )));
101 }
102 }
103
104 Ok(Self { segments })
105 }
106
107 pub fn segments(&self) -> &[YamlPathSegment] {
108 &self.segments
109 }
110}
111
112impl FromStr for YamlPath {
113 type Err = YmlError;
114
115 fn from_str(value: &str) -> Result<Self, Self::Err> {
116 Self::parse(value)
117 }
118}
119
120#[derive(Debug, Clone, PartialEq)]
121pub struct YamlDoc {
122 value: Value,
123}
124
125impl YamlDoc {
126 pub fn from_value(value: Value) -> Self {
127 Self { value }
128 }
129
130 pub fn as_value(&self) -> &Value {
131 &self.value
132 }
133
134 pub fn into_inner(self) -> Value {
135 self.value
136 }
137
138 pub fn get_path(&self, path: &YamlPath) -> Option<&Value> {
139 lookup_value(&self.value, path)
140 }
141
142 pub fn get(&self, path: &str) -> Result<Option<&Value>, YmlError> {
143 let parsed = YamlPath::parse(path)?;
144 Ok(self.get_path(&parsed))
145 }
146
147 pub fn get_string(&self, path: &str) -> Result<Option<String>, YmlError> {
148 let parsed = YamlPath::parse(path)?;
149 Ok(self.get_string_at(&parsed))
150 }
151
152 pub fn get_string_at(&self, path: &YamlPath) -> Option<String> {
153 self.get_path(path)
154 .and_then(stringify_scalar)
155 .map(|value| env_subst(&value))
156 }
157}
158
159pub trait YamlLookup {
160 fn yaml_lookup(&self, path: &YamlPath) -> Option<&Value>;
161}
162
163impl YamlLookup for YamlDoc {
164 fn yaml_lookup(&self, path: &YamlPath) -> Option<&Value> {
165 self.get_path(path)
166 }
167}
168
169impl YamlLookup for Value {
170 fn yaml_lookup(&self, path: &YamlPath) -> Option<&Value> {
171 lookup_value(self, path)
172 }
173}
174
175impl<T> YamlLookup for &T
176where
177 T: YamlLookup + ?Sized,
178{
179 fn yaml_lookup(&self, path: &YamlPath) -> Option<&Value> {
180 (*self).yaml_lookup(path)
181 }
182}
183
184pub fn get_yaml_path_value<'a, T>(doc: &'a T, path: &YamlPath) -> Option<&'a Value>
185where
186 T: YamlLookup + ?Sized,
187{
188 doc.yaml_lookup(path)
189}
190
191pub fn load_yaml<T, P>(path: P) -> Result<T, YmlError>
192where
193 T: DeserializeOwned,
194 P: AsRef<Path>,
195{
196 let path = path.as_ref();
197 let content = fs::read_to_string(path).map_err(|source| YmlError::ReadFile {
198 path: path.to_path_buf(),
199 source,
200 })?;
201 serde_yaml::from_str::<T>(&content).map_err(|source| YmlError::ParseFile {
202 path: path.to_path_buf(),
203 source,
204 })
205}
206
207pub fn load_yaml_value<P>(path: P) -> Result<YamlDoc, YmlError>
208where
209 P: AsRef<Path>,
210{
211 load_yaml::<Value, _>(path).map(YamlDoc::from_value)
212}
213
/// Expands `${NAME}` and `${NAME:default}` placeholders from the process environment.
///
/// - A variable that is unset, not valid Unicode, or blank (whitespace only) is replaced
///   by `default`, or by an empty string when no `:default` is given.
/// - Defaults may contain placeholders of their own (`${A:${B:fallback}}`). These are
///   matched by brace depth and expanded recursively.
/// - An unterminated `${` is copied verbatim together with the rest of the input.
pub fn env_subst(input: impl AsRef<str>) -> String {
    let source = input.as_ref();
    let mut result = String::with_capacity(source.len());
    let mut cursor = 0usize;

    while let Some(relative_start) = source[cursor..].find("${") {
        let start = cursor + relative_start;
        result.push_str(&source[cursor..start]);

        let body_start = start + 2;
        match find_placeholder_end(&source[body_start..]) {
            Some(relative_end) => {
                let end = body_start + relative_end;
                let body = &source[body_start..end];
                let (name, default_value) = body.split_once(':').unwrap_or((body, ""));
                let value = env::var(name)
                    .ok()
                    .filter(|candidate| !candidate.trim().is_empty())
                    // Only the fallback is expanded recursively; env values are used as-is.
                    .unwrap_or_else(|| env_subst(default_value));
                result.push_str(&value);
                cursor = end + 1;
            }
            None => {
                // Unterminated placeholder: keep the remainder untouched.
                result.push_str(&source[start..]);
                cursor = source.len();
            }
        }
    }

    result.push_str(&source[cursor..]);
    result
}

/// Returns the byte offset of the `}` that closes a placeholder whose body begins at the
/// start of `body`, skipping over nested `${...}` pairs. Returns `None` if unterminated.
///
/// Scanning bytes is UTF-8 safe here because `$`, `{` and `}` are ASCII and never occur
/// inside multi-byte sequences.
fn find_placeholder_end(body: &str) -> Option<usize> {
    let bytes = body.as_bytes();
    let mut depth = 0usize;
    let mut index = 0usize;
    while index < bytes.len() {
        match bytes[index] {
            b'$' if bytes.get(index + 1) == Some(&b'{') => {
                depth += 1;
                index += 2;
                continue;
            }
            b'}' => {
                if depth == 0 {
                    return Some(index);
                }
                depth -= 1;
            }
            _ => {}
        }
        index += 1;
    }
    None
}
247
248#[derive(Debug, Clone, PartialEq, Eq)]
249pub struct SpringYaml {
250 root: PathBuf,
251}
252
253impl SpringYaml {
254 pub fn from_dir(path: impl Into<PathBuf>) -> Self {
255 Self { root: path.into() }
256 }
257
258 pub fn from_current_dir() -> Result<Self, YmlError> {
259 let root = env::current_dir().map_err(YmlError::CurrentDir)?;
260 Ok(Self { root })
261 }
262
263 pub fn root(&self) -> &Path {
264 &self.root
265 }
266
267 pub fn resolve_resource(&self, resource_name: &str) -> PathBuf {
268 let base_name = resource_name
269 .strip_suffix(".yml")
270 .or_else(|| resource_name.strip_suffix(".yaml"))
271 .unwrap_or(resource_name);
272
273 let extensions = if resource_name.contains('.') {
274 ["", ".yml", ".yaml"]
275 } else {
276 [".yml", ".yaml", ""]
277 };
278
279 for extension in extensions {
280 if extension.is_empty() && !resource_name.contains('.') {
281 continue;
282 }
283
284 let candidate = self.root.join(format!("{base_name}{extension}"));
285 if candidate.exists() {
286 return candidate;
287 }
288 }
289
290 let fallback_extension = if resource_name.contains('.') {
291 ""
292 } else {
293 ".yml"
294 };
295 self.root.join(format!("{base_name}{fallback_extension}"))
296 }
297
298 pub fn get_yml_content(&self, resource_name: &str) -> Result<String, YmlError> {
299 let path = self.resolve_resource(resource_name);
300 fs::read_to_string(&path).map_err(|source| YmlError::ReadFile { path, source })
301 }
302
303 pub fn load_named(&self, resource_name: &str) -> Result<YamlDoc, YmlError> {
304 let path = self.resolve_resource(resource_name);
305 load_yaml_value(path)
306 }
307
308 pub fn load_active(&self) -> Result<YamlDoc, YmlError> {
309 let primary = self.load_named("application")?;
310 let profile = primary
311 .get_string("spring.profiles.active")?
312 .filter(|value| !value.trim().is_empty());
313
314 if let Some(profile_name) = profile {
315 let active_path = self.resolve_resource(&format!("application-{profile_name}"));
316 if active_path.exists() {
317 return load_yaml_value(active_path);
318 }
319 }
320
321 Ok(primary)
322 }
323}
324
/// Database connection settings extracted from a Spring configuration document.
///
/// `Debug` is hand-written so secrets stay out of logs:
/// - `jdbc_password` is masked.
/// - Credentials embedded in `jdbc_url` are masked. This covers URL user-info
///   (`scheme://user:secret@host`) and `password=` parameters that follow `?`, `&` or `;`.
#[derive(Clone, PartialEq, Eq)]
pub struct DatabaseConfig {
    /// Connection URL exactly as configured (JDBC or R2DBC).
    pub jdbc_url: String,
    /// Configured user name, if any.
    pub jdbc_username: Option<String>,
    /// Configured password, if any. Never printed by `Debug`.
    pub jdbc_password: Option<String>,
}

/// Placeholder written in place of secrets by `DatabaseConfig`'s `Debug` output.
const DATABASE_CONFIG_REDACTED: &str = "***REDACTED***";

impl std::fmt::Debug for DatabaseConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseConfig")
            .field("jdbc_url", &redact_url_credentials(&self.jdbc_url))
            .field("jdbc_username", &self.jdbc_username)
            .field(
                "jdbc_password",
                &self
                    .jdbc_password
                    .as_ref()
                    .map(|_| DATABASE_CONFIG_REDACTED),
            )
            .finish()
    }
}

/// Returns `url` with user-info passwords and `password=` parameters masked.
fn redact_url_credentials(url: &str) -> String {
    redact_password_parameters(&redact_userinfo_password(url))
}

/// Masks the password in `scheme://user:password@host...`. Other URLs are returned unchanged.
fn redact_userinfo_password(url: &str) -> String {
    let authority_start = match url.find("://") {
        Some(position) => position + 3,
        None => return url.to_owned(),
    };
    let rest = &url[authority_start..];
    // The authority ends at the first path/query/fragment/parameter delimiter.
    let authority_len = rest
        .find(|c: char| matches!(c, '/' | '?' | '#' | ';'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_len];
    let userinfo_len = match authority.rfind('@') {
        Some(position) => position,
        None => return url.to_owned(),
    };
    let password_offset = match authority[..userinfo_len].find(':') {
        Some(position) => position + 1,
        None => return url.to_owned(),
    };

    let mut masked = String::with_capacity(url.len());
    masked.push_str(&url[..authority_start + password_offset]);
    masked.push_str(DATABASE_CONFIG_REDACTED);
    masked.push_str(&url[authority_start + userinfo_len..]);
    masked
}

/// Masks the value of every case-insensitive `password=` parameter that follows
/// `?`, `&` or `;`. Each value runs up to the next `&`, `;`, or end of string.
fn redact_password_parameters(url: &str) -> String {
    const KEY: &str = "password=";
    // ASCII lowercasing keeps every byte offset identical to `url`.
    let lowered = url.to_ascii_lowercase();
    let mut masked = String::with_capacity(url.len());
    let mut copied_up_to = 0usize;
    let mut search_from = 0usize;

    while let Some(offset) = lowered[search_from..].find(KEY) {
        let key_start = search_from + offset;
        let value_start = key_start + KEY.len();
        let at_parameter_boundary =
            key_start > 0 && matches!(url.as_bytes()[key_start - 1], b'?' | b'&' | b';');
        if at_parameter_boundary {
            let value_end = url[value_start..]
                .find(|c: char| c == '&' || c == ';')
                .map_or(url.len(), |end| value_start + end);
            masked.push_str(&url[copied_up_to..value_start]);
            masked.push_str(DATABASE_CONFIG_REDACTED);
            copied_up_to = value_end;
            search_from = value_end;
        } else {
            search_from = value_start;
        }
    }

    masked.push_str(&url[copied_up_to..]);
    masked
}
346
347#[derive(Debug, Default, Clone, Copy)]
348pub struct DatabaseConfigReader;
349
350impl DatabaseConfigReader {
351 pub fn read(
352 path: impl AsRef<Path>,
353 prefer_data_source_name: Option<&str>,
354 ) -> Result<Option<DatabaseConfig>, YmlError> {
355 let spring_yaml = SpringYaml::from_dir(path.as_ref().to_path_buf());
356 let active = spring_yaml.load_active()?;
357 Self::read_from_doc(&active, prefer_data_source_name)
358 }
359
360 pub fn read_from_doc(
361 doc: &YamlDoc,
362 prefer_data_source_name: Option<&str>,
363 ) -> Result<Option<DatabaseConfig>, YmlError> {
364 if let Some(name) = prefer_data_source_name {
365 if let Some(config) = Self::read_named_data_source(doc, name)? {
366 return Ok(Some(config));
367 }
368 }
369
370 for url_path in SINGLE_DATASOURCE_PATHS {
371 if let Some(url) = read_non_blank_property(doc, url_path)? {
372 let base_path = extract_base_path(url_path);
373 let username = read_non_blank_property(doc, &format!("{base_path}.username"))?
374 .or(read_non_blank_property(doc, "spring.datasource.username")?)
375 .or(read_non_blank_property(doc, "spring.r2dbc.username")?);
376 let password = read_non_blank_property(doc, &format!("{base_path}.password"))?
377 .or(read_non_blank_property(doc, "spring.datasource.password")?)
378 .or(read_non_blank_property(doc, "spring.r2dbc.password")?);
379
380 return Ok(Some(DatabaseConfig {
381 jdbc_url: url,
382 jdbc_username: username,
383 jdbc_password: password,
384 }));
385 }
386 }
387
388 for data_source_name in COMMON_DATA_SOURCE_NAMES {
389 if let Some(config) = Self::read_named_data_source(doc, data_source_name)? {
390 return Ok(Some(config));
391 }
392 }
393
394 Ok(None)
395 }
396
397 fn read_named_data_source(
398 doc: &YamlDoc,
399 data_source_name: &str,
400 ) -> Result<Option<DatabaseConfig>, YmlError> {
401 let url_paths = [
402 format!("spring.datasource.{data_source_name}.url"),
403 format!("spring.datasource.{data_source_name}.jdbc-url"),
404 format!("spring.datasource.dynamic.datasource.{data_source_name}.url"),
405 format!("spring.datasource.mp.datasource.{data_source_name}.url"),
406 ];
407
408 for url_path in url_paths {
409 if let Some(url) = read_non_blank_property(doc, &url_path)? {
410 let base_path = extract_base_path(&url_path);
411 let username = read_non_blank_property(doc, &format!("{base_path}.username"))?;
412 let password = read_non_blank_property(doc, &format!("{base_path}.password"))?;
413
414 return Ok(Some(DatabaseConfig {
415 jdbc_url: url,
416 jdbc_username: username,
417 jdbc_password: password,
418 }));
419 }
420 }
421
422 Ok(None)
423 }
424}
425
/// URL properties probed, in priority order, for a single (unnamed) data source.
/// The first non-blank one wins in `DatabaseConfigReader::read_from_doc`.
const SINGLE_DATASOURCE_PATHS: &[&str] = &[
    // Standard Spring Boot JDBC / Hikari keys.
    "spring.datasource.url",
    "spring.datasource.jdbc-url",
    // Reactive driver.
    "spring.r2dbc.url",
    // Conventional names for a "main" data source in multi-datasource setups.
    "spring.datasource.primary.url",
    "spring.datasource.master.url",
    "spring.datasource.default.url",
    "spring.data.jdbc.url",
];

/// Data source names tried, in order, when no `SINGLE_DATASOURCE_PATHS` entry matches.
const COMMON_DATA_SOURCE_NAMES: &[&str] = &[
    "master",
    "primary",
    "default",
    "main",
    "slave",
];
437
438fn read_non_blank_property(doc: &YamlDoc, path: &str) -> Result<Option<String>, YmlError> {
439 Ok(doc
440 .get_string(path)?
441 .filter(|value| !value.trim().is_empty()))
442}
443
/// Returns the property prefix that owns `url_path` by dropping its final segment.
///
/// For example, `spring.r2dbc.url` becomes `spring.r2dbc`, and
/// `spring.datasource.master.jdbc-url` becomes `spring.datasource.master`, so sibling
/// `username`/`password` keys are looked up next to the URL. A path without any `.` is
/// returned unchanged.
///
/// Earlier code also trimmed trailing `.jdbc`/`.r2dbc`. That turned `spring.r2dbc.url`
/// into `spring` and sent credential lookups to the wrong prefix.
fn extract_base_path(url_path: &str) -> String {
    url_path
        .rsplit_once('.')
        .map_or(url_path, |(base, _)| base)
        .to_owned()
}
454
/// Advances `*index` past any Unicode whitespace in `chars`. An index at or beyond the
/// end is left unchanged.
fn skip_whitespace(chars: &[char], index: &mut usize) {
    let remaining = chars.get(*index..).unwrap_or(&[]);
    *index += remaining.iter().take_while(|c| c.is_whitespace()).count();
}
460
461fn parse_bare_segment(
462 chars: &[char],
463 index: &mut usize,
464 original: &str,
465) -> Result<YamlPathSegment, YmlError> {
466 let mut segment = String::new();
467 while *index < chars.len() && chars[*index] != '.' && chars[*index] != '[' {
468 segment.push(chars[*index]);
469 *index += 1;
470 }
471
472 let trimmed = segment.trim();
473 if trimmed.is_empty() {
474 return Err(YmlError::InvalidPath(format!(
475 "empty segment in path `{original}`"
476 )));
477 }
478
479 Ok(YamlPathSegment::Key(trimmed.to_owned()))
480}
481
482fn parse_bracket_segment(
483 chars: &[char],
484 index: &mut usize,
485 original: &str,
486) -> Result<YamlPathSegment, YmlError> {
487 *index += 1;
488 skip_whitespace(chars, index);
489
490 if *index >= chars.len() {
491 return Err(YmlError::InvalidPath(format!(
492 "unclosed bracket in path `{original}`"
493 )));
494 }
495
496 let segment = if matches!(chars[*index], '"' | '\'') {
497 let quote = chars[*index];
498 *index += 1;
499 let mut value = String::new();
500 let mut closed = false;
501
502 while *index < chars.len() {
503 let current = chars[*index];
504 if current == '\\' {
505 *index += 1;
506 if *index < chars.len() {
507 value.push(chars[*index]);
508 *index += 1;
509 }
510 continue;
511 }
512
513 if current == quote {
514 *index += 1;
515 closed = true;
516 break;
517 }
518
519 value.push(current);
520 *index += 1;
521 }
522
523 if !closed {
524 return Err(YmlError::InvalidPath(format!(
525 "unclosed quoted segment in path `{original}`"
526 )));
527 }
528
529 YamlPathSegment::Key(value)
530 } else {
531 let mut raw = String::new();
532 while *index < chars.len() && chars[*index] != ']' {
533 raw.push(chars[*index]);
534 *index += 1;
535 }
536
537 let trimmed = raw.trim();
538 if trimmed.is_empty() {
539 return Err(YmlError::InvalidPath(format!(
540 "empty bracket segment in path `{original}`"
541 )));
542 }
543
544 if trimmed.chars().all(|character| character.is_ascii_digit()) {
545 let value = trimmed.parse::<usize>().map_err(|_| {
546 YmlError::InvalidPath(format!(
547 "invalid sequence index `{trimmed}` in `{original}`"
548 ))
549 })?;
550 YamlPathSegment::Index(value)
551 } else {
552 YamlPathSegment::Key(trimmed.to_owned())
553 }
554 };
555
556 skip_whitespace(chars, index);
557 if *index >= chars.len() || chars[*index] != ']' {
558 return Err(YmlError::InvalidPath(format!(
559 "missing closing `]` in path `{original}`"
560 )));
561 }
562 *index += 1;
563
564 Ok(segment)
565}
566
567fn lookup_value<'a>(root: &'a Value, path: &YamlPath) -> Option<&'a Value> {
568 let mut current = root;
569
570 for segment in path.segments() {
571 current = match segment {
572 YamlPathSegment::Key(key) => {
573 let mapping = current.as_mapping()?;
574 let key = Value::String(key.clone());
575 mapping.get(&key)?
576 }
577 YamlPathSegment::Index(index) => current.as_sequence()?.get(*index)?,
578 };
579 }
580
581 Some(current)
582}
583
584fn stringify_scalar(value: &Value) -> Option<String> {
585 match value {
586 Value::String(inner) => Some(inner.clone()),
587 Value::Number(inner) => Some(inner.to_string()),
588 Value::Bool(inner) => Some(inner.to_string()),
589 _ => None,
590 }
591}
592
/// Builds a `YamlPath` at the call site, panicking if the path is malformed.
///
/// Two forms are accepted:
/// - A string literal: `yaml_path!("spring.datasource['jdbc.url']")`.
/// - Bare tokens: `yaml_path!(spring.datasource.jdbc-url)`.
///
/// The token form is recovered with `stringify!`. Its spacing is not guaranteed, and
/// some toolchains render `jdbc-url` as `jdbc - url`. Whitespace around every `-` is
/// therefore removed before parsing so kebab-case Spring keys resolve. Keys that really
/// contain spaces next to `-` must use the string-literal form.
///
/// # Panics
/// Panics if the resulting text is not a valid path.
#[macro_export]
macro_rules! yaml_path {
    ($path:literal) => {{
        <$crate::YamlPath as ::std::str::FromStr>::from_str($path)
            .expect("yaml_path!: invalid path literal")
    }};
    ($($path:tt)+) => {{
        let __raw: &str = ::core::stringify!($($path)+);
        let __normalized: ::std::string::String = __raw
            .split('-')
            .map(|__part: &str| __part.trim())
            .collect::<::std::vec::Vec<&str>>()
            .join("-");
        <$crate::YamlPath as ::std::str::FromStr>::from_str(&__normalized)
            .expect("yaml_path!: invalid path tokens")
    }};
}
604
/// Looks up a path in a document at the call site, returning `Option<&Value>`.
///
/// `$doc` is anything implementing `YamlLookup` (e.g. a `YamlDoc` or a `Value`). The path
/// may be a string literal or bare tokens; both are forwarded to `yaml_path!`, which
/// dispatches on the form and panics on an invalid path.
#[macro_export]
macro_rules! yaml_get {
    ($doc:expr, $($path:tt)+) => {{
        let __path = $crate::yaml_path!($($path)+);
        $crate::get_yaml_path_value(&$doc, &__path)
    }};
}
616
#[cfg(test)]
mod debug_redaction_tests {
    use super::DatabaseConfig;

    /// The `Debug` output must show the redaction marker and never the raw password.
    #[test]
    fn database_config_debug_redacts_password() {
        let config = DatabaseConfig {
            jdbc_url: String::from("jdbc:postgresql://localhost/app"),
            jdbc_username: Some(String::from("demo")),
            jdbc_password: Some(String::from("super-secret")),
        };

        let rendered = format!("{:?}", config);

        assert!(
            rendered.contains("***REDACTED***"),
            "expected redaction marker in {rendered}"
        );
        assert!(
            !rendered.contains("super-secret"),
            "password leaked in {rendered}"
        );
    }
}