1use eyre::bail;
2use serde::{Serialize, de::DeserializeOwned};
3
4mod serializer;
5
6use std::{
7 ops::{Deref, DerefMut},
8 path::PathBuf,
9 sync::Arc,
10};
11
12fn config_to_value<T: Serialize>(config: &T) -> Result<serializer::Value, eyre::Error> {
13 let config_str = serializer::to_string(config)?;
14 Ok(serializer::from_str(&config_str)?)
15}
16
/// A configuration value of type `T` paired with the file it is loaded from and saved to.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the type itself, so
/// `ConfigStore<T>` can be named in generic code without repeating the serde bounds.
#[derive(Debug, Clone)]
pub struct ConfigStore<T> {
    /// Location of the config file on disk.
    pub path: PathBuf,
    /// Optional top-level key under which this config lives inside the file; when `None`,
    /// the whole file is this config.
    nest: Option<String>,
    /// In-memory copy of the configuration.
    cached: T,
}
23
24impl<T: Default + Serialize + DeserializeOwned + PartialEq> ConfigStore<T> {
25 fn preflight(path: PathBuf, nest: Option<String>) -> Result<Option<Self>, eyre::Error> {
26 if path.is_dir() {
27 bail!(
28 "Given config path is a directory... either change the path or delete the directory."
29 );
30 }
31
32 if !path.exists() {
33 return Ok(Some(Self::new(path, nest)?));
34 }
35
36 if !path.is_file() {
37 bail!(
38 "Given config path exists and is not a file... either change the path or delete the file."
39 );
40 }
41
42 Ok(None)
43 }
44
45 pub fn read(
46 path: impl Into<PathBuf>,
47 nest: impl Into<Option<String>>,
48 ) -> Result<Self, eyre::Error> {
49 let path = path.into();
50 let nest = nest.into();
51
52 if let Some(config) = Self::preflight(path.clone(), nest.clone())? {
53 return Ok(config);
54 }
55
56 let config_str = std::fs::read_to_string(&path)?;
57 let deserialized: serializer::Value = serializer::from_str(&config_str)?;
58
59 let cached = match nest {
60 Some(ref key) => deserialized
61 .get(key)
62 .ok_or_else(|| eyre::eyre!("Nested config '{}' not found", key))?
63 .clone(),
64 None => deserialized,
65 };
66
67 Ok(Self {
68 path,
69 nest,
70 cached: T::deserialize(cached)?,
71 })
72 }
73
74 pub fn arc(self) -> Arc<Self> {
75 return Arc::new(self);
76 }
77
78 #[cfg(feature = "tokio")]
79 pub async fn async_read(
80 path: impl Into<PathBuf>,
81 nest: impl Into<Option<String>>,
82 ) -> Result<Self, eyre::Error> {
83 let path = path.into();
84 let nest = nest.into();
85
86 if let Some(config) = Self::preflight(path.clone(), nest.clone())? {
87 return Ok(config);
88 }
89
90 let config_str = tokio::fs::read_to_string(&path).await?;
91 let deserialized: serializer::Value = serializer::from_str(&config_str)?;
92
93 let cached = match nest {
94 Some(ref key) => deserialized
95 .get(key)
96 .ok_or_else(|| eyre::eyre!("Nested config '{}' not found", key))?
97 .clone(),
98 None => deserialized,
99 };
100
101 Ok(Self {
102 path,
103 nest,
104 cached: T::deserialize(cached)?,
105 })
106 }
107
108 pub fn update(&mut self) -> eyre::Result<bool> {
109 let new = Self::read(self.path.clone(), self.nest.clone())?;
110
111 Ok(match self.cached == new.cached {
112 true => false,
113 false => {
114 self.cached = new.cached;
115 true
116 }
117 })
118 }
119
120 #[cfg(feature = "tokio")]
121 pub async fn async_update(&mut self) -> eyre::Result<bool> {
122 let new = Self::async_read(self.path.clone(), self.nest.clone()).await?;
123
124 Ok(match self.cached == new.cached {
125 true => false,
126 false => {
127 self.cached = new.cached;
128 true
129 }
130 })
131 }
132
133 fn new(path: PathBuf, nest: Option<String>) -> Result<Self, eyre::Error> {
134 std::fs::create_dir_all(path.parent().unwrap())?;
135
136 let config = Self {
137 path,
138 nest,
139 cached: T::default(),
140 };
141
142 config.save()?;
143
144 Ok(config)
145 }
146
147 pub fn into_inner(self) -> T {
148 self.cached
149 }
150
151 pub fn merge(&mut self, other_config: T) -> Result<(), eyre::Error> {
152 let mut config_value = config_to_value(&self.cached)?;
153 let other_value = config_to_value(&other_config)?;
154 serializer::merge_values(&mut config_value, other_value)?;
155 self.cached = T::deserialize(config_value)?;
156 Ok(())
157 }
158
159 pub fn overwrite(&mut self, other_config: T) -> Result<(), eyre::Error> {
160 let mut config_value = config_to_value(&self.cached)?;
161 let other_value = config_to_value(&other_config)?;
162 serializer::overwrite_values(&mut config_value, other_value);
163 self.cached = T::deserialize(config_value)?;
164 Ok(())
165 }
166
167 pub fn save(&self) -> Result<(), eyre::Error> {
168 let to_write = match &self.nest {
169 Some(key) => {
170 let mut root: std::collections::HashMap<String, serializer::Value> =
172 if self.path.exists() {
173 let content = std::fs::read_to_string(&self.path)?;
174 serializer::from_str(&content)?
175 } else {
176 std::collections::HashMap::new()
177 };
178
179 let cached_str = serializer::to_string(&self.cached)?;
181 let cached_value: serializer::Value = serializer::from_str(&cached_str)?;
182
183 root.insert(key.clone(), cached_value);
184 serializer::to_string(&root)?
185 }
186 None => serializer::to_string(&self.cached)?,
187 };
188
189 std::fs::write(&self.path, to_write)?;
190 Ok(())
191 }
192
193 #[cfg(feature = "tokio")]
194 pub async fn async_save(&self) -> Result<(), eyre::Error> {
195 let to_write = match &self.nest {
196 Some(key) => {
197 let mut root: std::collections::HashMap<String, serializer::Value> =
199 if self.path.exists() {
200 let content = std::fs::read_to_string(&self.path)?;
201 serializer::from_str(&content)?
202 } else {
203 std::collections::HashMap::new()
204 };
205
206 let cached_str = serializer::to_string(&self.cached)?;
208 let cached_value: serializer::Value = serializer::from_str(&cached_str)?;
209
210 root.insert(key.clone(), cached_value);
211 serializer::to_string(&root)?
212 }
213 None => serializer::to_string(&self.cached)?,
214 };
215
216 tokio::fs::write(&self.path, to_write).await?;
217 Ok(())
218 }
219}
220
221impl<T: Default + Serialize + DeserializeOwned + PartialEq> Deref for ConfigStore<T> {
222 type Target = T;
223
224 fn deref(&self) -> &Self::Target {
225 &self.cached
226 }
227}
228
229impl<T: Default + Serialize + DeserializeOwned + PartialEq> DerefMut for ConfigStore<T> {
230 fn deref_mut(&mut self) -> &mut Self::Target {
231 &mut self.cached
232 }
233}
234
235impl<T: Default + Serialize + DeserializeOwned + PartialEq> PartialEq for ConfigStore<T> {
236 fn eq(&self, other: &Self) -> bool {
237 self.cached == other.cached
238 }
239}
240impl<T: Default + Serialize + DeserializeOwned + PartialEq> Eq for ConfigStore<T> {}
241
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    struct TestConfig {
        name: Option<String>,
        database: DatabaseConfig,
        features: FeatureConfig,
    }

    #[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    struct DatabaseConfig {
        url: Option<String>,
        pool_size: Option<u16>,
    }

    #[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
    struct FeatureConfig {
        enabled: Option<bool>,
    }

    /// Wraps `config` in a store built directly from its fields; `merge`/`overwrite` never
    /// touch the path, so nothing is read from or written to disk.
    fn in_memory(config: TestConfig) -> ConfigStore<TestConfig> {
        ConfigStore {
            path: PathBuf::from("/tmp/easy-config-store-test.toml"),
            nest: None,
            cached: config,
        }
    }

    /// A config with only `name` set; every other field is left at its default.
    fn named(name: &str) -> TestConfig {
        TestConfig {
            name: Some(name.to_owned()),
            ..TestConfig::default()
        }
    }

    #[test]
    fn merge_adds_missing_nested_values() {
        let mut config = in_memory(TestConfig {
            database: DatabaseConfig {
                url: Some("sqlite://app.db".to_owned()),
                pool_size: None,
            },
            ..named("app")
        });
        let incoming = TestConfig {
            name: None,
            database: DatabaseConfig {
                url: None,
                pool_size: Some(8),
            },
            features: FeatureConfig {
                enabled: Some(true),
            },
        };

        config.merge(incoming).expect("merge succeeds");

        assert_eq!(config.name.as_deref(), Some("app"));
        assert_eq!(config.database.url.as_deref(), Some("sqlite://app.db"));
        assert_eq!(config.database.pool_size, Some(8));
        assert_eq!(config.features.enabled, Some(true));
    }

    #[test]
    fn merge_allows_equal_values() {
        let mut config = in_memory(named("app"));

        config
            .merge(named("app"))
            .expect("equal values do not conflict");

        assert_eq!(config.name.as_deref(), Some("app"));
    }

    #[test]
    fn merge_errors_on_conflicting_values() {
        let mut config = in_memory(named("app"));

        let failure = config
            .merge(named("other"))
            .expect_err("conflict should fail");

        assert!(failure.to_string().contains("conflicting config value"));
        assert_eq!(config.name.as_deref(), Some("app"));
    }

    #[test]
    fn overwrite_replaces_conflicting_values_and_keeps_missing_values() {
        let mut config = in_memory(TestConfig {
            database: DatabaseConfig {
                url: Some("sqlite://app.db".to_owned()),
                pool_size: Some(4),
            },
            ..named("app")
        });
        let incoming = TestConfig {
            database: DatabaseConfig {
                url: None,
                pool_size: Some(8),
            },
            features: FeatureConfig {
                enabled: Some(true),
            },
            ..named("other")
        };

        config.overwrite(incoming).expect("overwrite succeeds");

        assert_eq!(config.name.as_deref(), Some("other"));
        assert_eq!(config.database.url.as_deref(), Some("sqlite://app.db"));
        assert_eq!(config.database.pool_size, Some(8));
        assert_eq!(config.features.enabled, Some(true));
    }
}