1use std::collections::BTreeMap;
43use std::sync::Arc;
44
45use serde::Serialize;
46
47use crate::env::ConfigEnv;
48use crate::error::{ConfigError, ConfigErrors, SourceErrorKind, SourceLocation};
49use crate::source::{ConfigValues, Source};
50use crate::value::{ConfigValue, Value};
51
52enum DefaultsSource<T> {
54 Value(T),
56 Fn(Arc<dyn Fn() -> T + Send + Sync>),
58}
59
60impl<T: Clone> Clone for DefaultsSource<T> {
61 fn clone(&self) -> Self {
62 match self {
63 DefaultsSource::Value(v) => DefaultsSource::Value(v.clone()),
64 DefaultsSource::Fn(f) => DefaultsSource::Fn(Arc::clone(f)),
65 }
66 }
67}
68
69pub struct Defaults<T> {
94 source: DefaultsSource<T>,
95}
96
97impl<T: Serialize + Clone + Send + Sync + 'static> Defaults<T> {
98 pub fn from(value: T) -> Self {
108 Self {
109 source: DefaultsSource::Value(value),
110 }
111 }
112
113 pub fn from_fn<F>(f: F) -> Self
129 where
130 F: Fn() -> T + Send + Sync + 'static,
131 {
132 Self {
133 source: DefaultsSource::Fn(Arc::new(f)),
134 }
135 }
136}
137
138impl<T: Clone> Clone for Defaults<T> {
139 fn clone(&self) -> Self {
140 Self {
141 source: self.source.clone(),
142 }
143 }
144}
145
146impl Defaults<()> {
147 pub fn partial() -> PartialDefaults {
163 PartialDefaults::new()
164 }
165}
166
167#[cfg(feature = "watch")]
168use std::path::PathBuf;
169
170impl<T: Serialize + Clone + Send + Sync + 'static> Source for Defaults<T> {
171 fn load(&self, _env: &dyn ConfigEnv) -> Result<ConfigValues, ConfigErrors> {
178 let value = match &self.source {
179 DefaultsSource::Value(v) => v.clone(),
180 DefaultsSource::Fn(f) => f(),
181 };
182
183 serialize_to_config_values(&value, "defaults")
184 }
185
186 fn name(&self) -> &str {
187 "defaults"
188 }
189
190 #[cfg(feature = "watch")]
191 fn watch_path(&self) -> Option<PathBuf> {
192 None
194 }
195
196 #[cfg(feature = "watch")]
197 fn clone_box(&self) -> Box<dyn Source> {
198 Box::new(self.clone())
199 }
200}
201
202#[derive(Debug, Clone, Default)]
218pub struct PartialDefaults {
219 values: BTreeMap<String, Value>,
220}
221
222impl PartialDefaults {
223 pub fn new() -> Self {
225 Self {
226 values: BTreeMap::new(),
227 }
228 }
229
230 pub fn set<V: Into<Value>>(mut self, path: impl Into<String>, value: V) -> Self {
242 self.values.insert(path.into(), value.into());
243 self
244 }
245
246 pub fn set_many<I, K, V>(mut self, iter: I) -> Self
261 where
262 I: IntoIterator<Item = (K, V)>,
263 K: Into<String>,
264 V: Into<Value>,
265 {
266 for (path, value) in iter {
267 self.values.insert(path.into(), value.into());
268 }
269 self
270 }
271}
272
273impl Source for PartialDefaults {
274 fn load(&self, _env: &dyn ConfigEnv) -> Result<ConfigValues, ConfigErrors> {
278 let mut config_values = ConfigValues::empty();
279
280 for (path, value) in &self.values {
281 config_values.insert(
282 path.clone(),
283 ConfigValue::new(
284 value.clone(),
285 SourceLocation::new(format!("defaults:{}", path)),
286 ),
287 );
288 }
289
290 Ok(config_values)
291 }
292
293 fn name(&self) -> &str {
294 "defaults"
295 }
296
297 #[cfg(feature = "watch")]
298 fn watch_path(&self) -> Option<PathBuf> {
299 None
301 }
302
303 #[cfg(feature = "watch")]
304 fn clone_box(&self) -> Box<dyn Source> {
305 Box::new(self.clone())
306 }
307}
308
309fn serialize_to_config_values<T: Serialize>(
311 value: &T,
312 source_name: &str,
313) -> Result<ConfigValues, ConfigErrors> {
314 let json = serde_json::to_value(value).map_err(|e| {
316 ConfigErrors::single(ConfigError::SourceError {
317 source_name: source_name.to_string(),
318 kind: SourceErrorKind::Other {
319 message: format!("Failed to serialize defaults: {}", e),
320 },
321 })
322 })?;
323
324 let mut values = ConfigValues::empty();
325 flatten_json(&json, "", source_name, &mut values);
326 Ok(values)
327}
328
329fn flatten_json(
331 value: &serde_json::Value,
332 prefix: &str,
333 source_name: &str,
334 values: &mut ConfigValues,
335) {
336 match value {
337 serde_json::Value::Object(map) => {
338 for (key, val) in map {
339 let path = if prefix.is_empty() {
340 key.clone()
341 } else {
342 format!("{}.{}", prefix, key)
343 };
344 flatten_json(val, &path, source_name, values);
345 }
346 }
347 serde_json::Value::Array(arr) => {
348 for (i, val) in arr.iter().enumerate() {
349 let path = format!("{}[{}]", prefix, i);
350 flatten_json(val, &path, source_name, values);
351 }
352 values.insert(
354 format!("{}.__len", prefix),
355 ConfigValue::new(
356 Value::Integer(arr.len() as i64),
357 SourceLocation::new(source_name),
358 ),
359 );
360 }
361 _ => {
362 values.insert(
363 prefix.to_string(),
364 ConfigValue::new(json_to_value(value), SourceLocation::new(source_name)),
365 );
366 }
367 }
368}
369
370fn json_to_value(json: &serde_json::Value) -> Value {
372 match json {
373 serde_json::Value::Null => Value::Null,
374 serde_json::Value::Bool(b) => Value::Bool(*b),
375 serde_json::Value::Number(n) => {
376 if let Some(i) = n.as_i64() {
377 Value::Integer(i)
378 } else if let Some(f) = n.as_f64() {
379 Value::Float(f)
380 } else {
381 Value::String(n.to_string())
383 }
384 }
385 serde_json::Value::String(s) => Value::String(s.clone()),
386 serde_json::Value::Array(arr) => Value::Array(arr.iter().map(json_to_value).collect()),
387 serde_json::Value::Object(map) => Value::Table(
388 map.iter()
389 .map(|(k, v)| (k.clone(), json_to_value(v)))
390 .collect(),
391 ),
392 }
393}
394
395#[cfg(test)]
396mod tests {
397 use super::*;
398 use crate::env::MockEnv;
399 use serde::{Deserialize, Serialize};
400
401 #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
402 struct SimpleConfig {
403 host: String,
404 port: u16,
405 }
406
407 #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
408 struct NestedConfig {
409 server: ServerConfig,
410 database: DatabaseConfig,
411 }
412
413 #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
414 struct ServerConfig {
415 host: String,
416 port: u16,
417 }
418
419 #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
420 struct DatabaseConfig {
421 host: String,
422 pool_size: u32,
423 }
424
425 #[derive(Debug, Clone, Default, Serialize, Deserialize)]
426 struct ConfigWithArrays {
427 hosts: Vec<String>,
428 ports: Vec<u16>,
429 }
430
431 #[test]
432 fn test_defaults_from_value() {
433 let env = MockEnv::new();
434 let config = SimpleConfig {
435 host: "localhost".to_string(),
436 port: 8080,
437 };
438
439 let source = Defaults::from(config);
440 let values = source.load(&env).expect("should load successfully");
441
442 assert_eq!(
443 values.get("host").map(|v| v.value.as_str()),
444 Some(Some("localhost"))
445 );
446 assert_eq!(
447 values.get("port").map(|v| v.value.as_integer()),
448 Some(Some(8080))
449 );
450 }
451
452 #[test]
453 fn test_defaults_from_default_trait() {
454 let env = MockEnv::new();
455 let source = Defaults::from(SimpleConfig::default());
456 let values = source.load(&env).expect("should load successfully");
457
458 assert_eq!(values.get("host").map(|v| v.value.as_str()), Some(Some("")));
460 assert_eq!(
461 values.get("port").map(|v| v.value.as_integer()),
462 Some(Some(0))
463 );
464 }
465
466 #[test]
467 fn test_defaults_from_closure() {
468 let env = MockEnv::new();
469 let source = Defaults::from_fn(|| SimpleConfig {
470 host: "computed".to_string(),
471 port: 3000,
472 });
473
474 let values = source.load(&env).expect("should load successfully");
475
476 assert_eq!(
477 values.get("host").map(|v| v.value.as_str()),
478 Some(Some("computed"))
479 );
480 assert_eq!(
481 values.get("port").map(|v| v.value.as_integer()),
482 Some(Some(3000))
483 );
484 }
485
486 #[test]
487 fn test_defaults_closure_called_each_time() {
488 use std::sync::atomic::{AtomicU32, Ordering};
489
490 let env = MockEnv::new();
491 let counter = Arc::new(AtomicU32::new(0));
492 let counter_clone = Arc::clone(&counter);
493
494 let source = Defaults::from_fn(move || {
495 counter_clone.fetch_add(1, Ordering::SeqCst);
496 SimpleConfig {
497 host: "localhost".to_string(),
498 port: 8080,
499 }
500 });
501
502 source.load(&env).expect("should load");
503 source.load(&env).expect("should load");
504 source.load(&env).expect("should load");
505
506 assert_eq!(counter.load(Ordering::SeqCst), 3);
507 }
508
509 #[test]
510 fn test_defaults_nested_structs() {
511 let env = MockEnv::new();
512 let config = NestedConfig {
513 server: ServerConfig {
514 host: "localhost".to_string(),
515 port: 8080,
516 },
517 database: DatabaseConfig {
518 host: "db.example.com".to_string(),
519 pool_size: 10,
520 },
521 };
522
523 let source = Defaults::from(config);
524 let values = source.load(&env).expect("should load successfully");
525
526 assert_eq!(
527 values.get("server.host").map(|v| v.value.as_str()),
528 Some(Some("localhost"))
529 );
530 assert_eq!(
531 values.get("server.port").map(|v| v.value.as_integer()),
532 Some(Some(8080))
533 );
534 assert_eq!(
535 values.get("database.host").map(|v| v.value.as_str()),
536 Some(Some("db.example.com"))
537 );
538 assert_eq!(
539 values
540 .get("database.pool_size")
541 .map(|v| v.value.as_integer()),
542 Some(Some(10))
543 );
544 }
545
546 #[test]
547 fn test_defaults_with_arrays() {
548 let env = MockEnv::new();
549 let config = ConfigWithArrays {
550 hosts: vec!["host1".to_string(), "host2".to_string()],
551 ports: vec![8080, 8081, 8082],
552 };
553
554 let source = Defaults::from(config);
555 let values = source.load(&env).expect("should load successfully");
556
557 assert_eq!(
558 values.get("hosts[0]").map(|v| v.value.as_str()),
559 Some(Some("host1"))
560 );
561 assert_eq!(
562 values.get("hosts[1]").map(|v| v.value.as_str()),
563 Some(Some("host2"))
564 );
565 assert_eq!(
566 values.get("hosts.__len").map(|v| v.value.as_integer()),
567 Some(Some(2))
568 );
569 assert_eq!(
570 values.get("ports[0]").map(|v| v.value.as_integer()),
571 Some(Some(8080))
572 );
573 assert_eq!(
574 values.get("ports[1]").map(|v| v.value.as_integer()),
575 Some(Some(8081))
576 );
577 assert_eq!(
578 values.get("ports[2]").map(|v| v.value.as_integer()),
579 Some(Some(8082))
580 );
581 assert_eq!(
582 values.get("ports.__len").map(|v| v.value.as_integer()),
583 Some(Some(3))
584 );
585 }
586
587 #[test]
588 fn test_partial_defaults_basic() {
589 let env = MockEnv::new();
590 let source = Defaults::partial()
591 .set("server.port", 8080i64)
592 .set("database.pool_size", 10i64);
593
594 let values = source.load(&env).expect("should load successfully");
595
596 assert_eq!(
597 values.get("server.port").map(|v| v.value.as_integer()),
598 Some(Some(8080))
599 );
600 assert_eq!(
601 values
602 .get("database.pool_size")
603 .map(|v| v.value.as_integer()),
604 Some(Some(10))
605 );
606 }
607
608 #[test]
609 fn test_partial_defaults_various_types() {
610 let env = MockEnv::new();
611 let source = Defaults::partial()
612 .set("string_val", "hello")
613 .set("int_val", 42i64)
614 .set("float_val", 2.72f64)
615 .set("bool_val", true);
616
617 let values = source.load(&env).expect("should load successfully");
618
619 assert_eq!(
620 values.get("string_val").map(|v| v.value.as_str()),
621 Some(Some("hello"))
622 );
623 assert_eq!(
624 values.get("int_val").map(|v| v.value.as_integer()),
625 Some(Some(42))
626 );
627 assert_eq!(
628 values.get("float_val").map(|v| v.value.as_float()),
629 Some(Some(2.72))
630 );
631 assert_eq!(
632 values.get("bool_val").map(|v| v.value.as_bool()),
633 Some(Some(true))
634 );
635 }
636
637 #[test]
638 fn test_partial_defaults_set_many() {
639 let env = MockEnv::new();
640 let defaults = vec![
641 ("server.port", Value::Integer(8080)),
642 ("server.host", Value::String("localhost".to_string())),
643 ("debug", Value::Bool(true)),
644 ];
645
646 let source = Defaults::partial().set_many(defaults);
647 let values = source.load(&env).expect("should load successfully");
648
649 assert_eq!(
650 values.get("server.port").map(|v| v.value.as_integer()),
651 Some(Some(8080))
652 );
653 assert_eq!(
654 values.get("server.host").map(|v| v.value.as_str()),
655 Some(Some("localhost"))
656 );
657 assert_eq!(
658 values.get("debug").map(|v| v.value.as_bool()),
659 Some(Some(true))
660 );
661 }
662
663 #[test]
664 fn test_partial_defaults_source_location() {
665 let env = MockEnv::new();
666 let source = Defaults::partial().set("server.port", 8080i64);
667
668 let values = source.load(&env).expect("should load successfully");
669
670 let port_value = values.get("server.port").expect("should exist");
671 assert_eq!(port_value.source.source, "defaults:server.port");
672 }
673
674 #[test]
675 fn test_defaults_source_location() {
676 let env = MockEnv::new();
677 let source = Defaults::from(SimpleConfig {
678 host: "localhost".to_string(),
679 port: 8080,
680 });
681
682 let values = source.load(&env).expect("should load successfully");
683
684 let host_value = values.get("host").expect("should exist");
685 assert_eq!(host_value.source.source, "defaults");
686 }
687
688 #[test]
689 fn test_defaults_name() {
690 let source = Defaults::from(SimpleConfig::default());
691 assert_eq!(source.name(), "defaults");
692
693 let partial = Defaults::partial().set("key", "value");
694 assert_eq!(partial.name(), "defaults");
695 }
696
697 #[test]
698 fn test_defaults_is_send_sync() {
699 fn assert_send_sync<T: Send + Sync>() {}
700 assert_send_sync::<Defaults<SimpleConfig>>();
701 assert_send_sync::<PartialDefaults>();
702 }
703
704 #[test]
705 fn test_defaults_clone() {
706 let source = Defaults::from(SimpleConfig {
707 host: "localhost".to_string(),
708 port: 8080,
709 });
710 let cloned = source.clone();
711
712 let env = MockEnv::new();
713 let values1 = source.load(&env).expect("should load");
714 let values2 = cloned.load(&env).expect("should load");
715
716 assert_eq!(
717 values1.get("host").map(|v| v.value.as_str()),
718 values2.get("host").map(|v| v.value.as_str())
719 );
720 }
721
722 #[test]
723 fn test_partial_defaults_empty() {
724 let env = MockEnv::new();
725 let source = PartialDefaults::new();
726 let values = source.load(&env).expect("should load successfully");
727
728 assert!(values.is_empty());
729 }
730
731 #[test]
732 fn test_json_to_value_null() {
733 let json = serde_json::Value::Null;
734 let value = json_to_value(&json);
735 assert!(value.is_null());
736 }
737
738 #[test]
739 fn test_json_to_value_bool() {
740 let json = serde_json::Value::Bool(true);
741 let value = json_to_value(&json);
742 assert_eq!(value.as_bool(), Some(true));
743 }
744
745 #[test]
746 fn test_json_to_value_integer() {
747 let json = serde_json::json!(42);
748 let value = json_to_value(&json);
749 assert_eq!(value.as_integer(), Some(42));
750 }
751
752 #[test]
753 fn test_json_to_value_float() {
754 let json = serde_json::json!(1.5);
755 let value = json_to_value(&json);
756 assert_eq!(value.as_float(), Some(1.5));
757 }
758
759 #[test]
760 fn test_json_to_value_string() {
761 let json = serde_json::json!("hello");
762 let value = json_to_value(&json);
763 assert_eq!(value.as_str(), Some("hello"));
764 }
765
766 #[test]
767 fn test_json_to_value_array() {
768 let json = serde_json::json!([1, 2, 3]);
769 let value = json_to_value(&json);
770 let arr = value.as_array().expect("should be array");
771 assert_eq!(arr.len(), 3);
772 }
773
774 #[test]
775 fn test_json_to_value_object() {
776 let json = serde_json::json!({"key": "value"});
777 let value = json_to_value(&json);
778 let table = value.as_table().expect("should be table");
779 assert_eq!(table.get("key").and_then(|v| v.as_str()), Some("value"));
780 }
781}