1use figment::providers::Serialized;
2use figment::Figment;
3use serde_json::Value;
4use thiserror::Error;
5
6mod private {
7 pub trait Sealed {}
8
9 impl Sealed for figment::Figment {}
10}
11
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum RemoveExistingKeyError<'a> {
15 #[error("key {0} not found")]
16 NotFound(&'a str),
17}
18
19pub trait FigmentExt: Sized + private::Sealed {
21 fn remove_existing_keys<'a, T: AsRef<str>>(
25 &self,
26 keys: &'a [T],
27 ) -> Result<Self, RemoveExistingKeyError<'a>>;
28
29 fn has_key(&self, key: &str) -> bool;
33}
34
35impl FigmentExt for Figment {
36 fn remove_existing_keys<'a, T: AsRef<str>>(
37 &self,
38 keys: &'a [T],
39 ) -> Result<Self, RemoveExistingKeyError<'a>> {
40 let mut value = self.extract::<Value>().expect("json serializable value");
41 let mut pointer = String::new();
42 let mut parts = vec![];
43 for key in keys {
44 let key = key.as_ref();
45 if !self.has_key(key) {
46 return Err(RemoveExistingKeyError::NotFound(key));
47 }
48 pointer.clear();
49 parts.clear();
50 parts.extend(key.split('.'));
51 match parts.as_slice() {
53 [] => {
54 unreachable!("empty parts");
56 }
57 [field] => {
58 value.as_object_mut().expect("object").remove(*field);
59 }
60 [components @ .., field] => {
61 for c in components {
62 pointer.push('/');
63 pointer.push_str(c);
64 }
65 value
66 .pointer_mut(&pointer)
67 .and_then(Value::as_object_mut)
68 .expect("object")
69 .remove(*field);
70 }
71 }
72 }
73 Ok(Figment::from(Serialized::defaults(value)))
74 }
75
76 fn has_key(&self, key: &str) -> bool {
77 self.find_metadata(key).is_some() && !key.is_empty()
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture figment with nested tables, a plain string, and an array.
    fn get_test_figment() -> Figment {
        let data = serde_json::json!({
            "foo": {
                "bar": {
                    "baz": {
                        "name": "Baz",
                    }
                },
                "s": "Foo string",
            },
            "vec": ["foo", "bar", "baz"],
        });
        Figment::from(Serialized::defaults(data))
    }

    #[test]
    fn remove_existing_keys() {
        let figment = get_test_figment();
        let keys = ["foo.bar.baz.name", "foo.s"];

        // Every key must exist before removal and be gone afterwards.
        assert!(keys.iter().all(|key| figment.has_key(key)));
        let pruned = figment
            .remove_existing_keys(&keys)
            .expect("keys removed");
        assert!(keys.iter().all(|key| !pruned.has_key(key)));
    }

    #[test]
    fn remove_missing_key() {
        let figment = get_test_figment();
        let missing = "foo.not_exist";
        assert!(!figment.has_key(missing));

        // Bound to a local so the borrowed key outlives the returned error.
        let keys = [missing];
        let err = figment
            .remove_existing_keys(&keys)
            .expect_err("key doesn't exist");
        assert!(matches!(err, RemoveExistingKeyError::NotFound(_)));
    }

    #[test]
    fn has_key() {
        assert!(!get_test_figment().has_key(""));
    }
}