// atomr_config/value.rs

//! `Config` and `ConfigValue` — the HOCON-equivalent value tree.
2
3use std::collections::BTreeMap;
4use std::time::Duration;
5
6use crate::error::ConfigError;
7use crate::path::ConfigPath;
8use crate::reference::reference_config;
9
/// One node of a configuration tree.
///
/// Scalars, arrays and nested objects are all expressed through this single
/// enum. Objects store their entries in a `BTreeMap`, so keys iterate in
/// sorted order.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigValue {
    /// The null value.
    Null,
    /// A boolean scalar.
    Bool(bool),
    /// A signed 64-bit integer scalar.
    Int(i64),
    /// A 64-bit floating-point scalar.
    Float(f64),
    /// A text scalar.
    String(String),
    /// An ordered list of values.
    Array(Vec<ConfigValue>),
    /// A nested object mapping keys to values.
    Object(BTreeMap<String, ConfigValue>),
}
20
21impl ConfigValue {
22    pub fn type_name(&self) -> &'static str {
23        match self {
24            Self::Null => "null",
25            Self::Bool(_) => "bool",
26            Self::Int(_) => "int",
27            Self::Float(_) => "float",
28            Self::String(_) => "string",
29            Self::Array(_) => "array",
30            Self::Object(_) => "object",
31        }
32    }
33
34    fn from_toml(v: toml::Value) -> Self {
35        match v {
36            toml::Value::String(s) => Self::String(s),
37            toml::Value::Integer(i) => Self::Int(i),
38            toml::Value::Float(f) => Self::Float(f),
39            toml::Value::Boolean(b) => Self::Bool(b),
40            toml::Value::Datetime(d) => Self::String(d.to_string()),
41            toml::Value::Array(a) => Self::Array(a.into_iter().map(Self::from_toml).collect()),
42            toml::Value::Table(t) => {
43                Self::Object(t.into_iter().map(|(k, v)| (k, Self::from_toml(v))).collect())
44            }
45        }
46    }
47}
48
49/// Akka `Config` root — a merged, layered value tree.
50#[derive(Debug, Clone, Default, PartialEq)]
51pub struct Config {
52    root: BTreeMap<String, ConfigValue>,
53}
54
55impl Config {
56    pub fn empty() -> Self {
57        Self::default()
58    }
59
60    /// Load the atomr reference configuration.
61    pub fn reference() -> Self {
62        Self::from_toml_str(reference_config()).expect("built-in reference.conf.toml is valid")
63    }
64
65    pub fn from_toml_str(s: &str) -> Result<Self, ConfigError> {
66        let v: toml::Value = toml::from_str(s)?;
67        let table = match v {
68            toml::Value::Table(t) => t,
69            _ => return Err(ConfigError::WrongType { path: "".into(), expected: "object" }),
70        };
71        Ok(Self { root: table.into_iter().map(|(k, v)| (k, ConfigValue::from_toml(v))).collect() })
72    }
73
74    /// Parse a HOCON document (Pekko `reference.conf`
75    /// syntax). See [`crate::hocon`] for the supported subset.
76    pub fn from_hocon_str(s: &str) -> Result<Self, ConfigError> {
77        let v = crate::hocon::parse(s, std::path::Path::new("."))?;
78        match v {
79            ConfigValue::Object(o) => Ok(Self { root: o }),
80            _ => Err(ConfigError::WrongType { path: "".into(), expected: "object" }),
81        }
82    }
83
84    /// Parse a HOCON file from disk; `include` directives resolve
85    /// relative to the file's parent directory.
86    pub fn from_hocon_file(path: impl AsRef<std::path::Path>) -> Result<Self, ConfigError> {
87        let v = crate::hocon::parse_file(path.as_ref())?;
88        match v {
89            ConfigValue::Object(o) => Ok(Self { root: o }),
90            _ => Err(ConfigError::WrongType { path: "".into(), expected: "object" }),
91        }
92    }
93
94    /// Merge `other` on top of `self`; keys from `other` win for scalars,
95    /// objects merge recursively — matches HOCON fallback/merge semantics.
96    pub fn with_fallback(mut self, fallback: Self) -> Self {
97        merge_object(&mut self.root, fallback.root, /*override_rhs=*/ false);
98        self
99    }
100
101    /// Merge `other` on top of `self`, where `other` wins.
102    pub fn merged_with(mut self, other: Self) -> Self {
103        merge_object(&mut self.root, other.root, true);
104        self
105    }
106
107    pub fn get(&self, path: &str) -> Option<&ConfigValue> {
108        let p = ConfigPath::parse(path);
109        lookup(&self.root, p.segments())
110    }
111
112    pub fn get_string(&self, path: &str) -> Result<String, ConfigError> {
113        match self.get(path) {
114            Some(ConfigValue::String(s)) => Ok(s.clone()),
115            Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
116            None => Err(ConfigError::NotFound(path.into())),
117        }
118    }
119
120    pub fn get_int(&self, path: &str) -> Result<i64, ConfigError> {
121        match self.get(path) {
122            Some(ConfigValue::Int(i)) => Ok(*i),
123            Some(ConfigValue::Float(f)) => Ok(*f as i64),
124            Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
125            None => Err(ConfigError::NotFound(path.into())),
126        }
127    }
128
129    pub fn get_bool(&self, path: &str) -> Result<bool, ConfigError> {
130        match self.get(path) {
131            Some(ConfigValue::Bool(b)) => Ok(*b),
132            Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
133            None => Err(ConfigError::NotFound(path.into())),
134        }
135    }
136
137    /// Accepts "10ms", "5s", "2m", "1h", or integer milliseconds.
138    pub fn get_duration(&self, path: &str) -> Result<Duration, ConfigError> {
139        match self.get(path) {
140            Some(ConfigValue::String(s)) => parse_duration(s)
141                .ok_or_else(|| ConfigError::WrongType { path: path.into(), expected: "duration" }),
142            Some(ConfigValue::Int(i)) => Ok(Duration::from_millis(*i as u64)),
143            Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
144            None => Err(ConfigError::NotFound(path.into())),
145        }
146    }
147
148    pub fn get_sub(&self, path: &str) -> Option<Config> {
149        match self.get(path)? {
150            ConfigValue::Object(o) => Some(Self { root: o.clone() }),
151            _ => None,
152        }
153    }
154
155    /// Deserialize a sub-tree at `path` into a strongly-typed value `T`.
156    /// Bridge through `serde_json::Value` so any `serde::Deserialize`
157    /// type composes. -equivalent of typed `Config.As<T>()`
158    /// extension.
159    ///
160    /// Returns [`ConfigError::NotFound`] if `path` is absent.
161    pub fn extract<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, ConfigError> {
162        let v = self.get(path).ok_or_else(|| ConfigError::NotFound(path.into()))?;
163        let json = config_value_to_json(v);
164        serde_json::from_value(json)
165            .map_err(|e| ConfigError::WrongType { path: path.into(), expected: leak(e.to_string()) })
166    }
167
168    /// Deserialize the entire root config into `T`.
169    pub fn extract_root<T: serde::de::DeserializeOwned>(&self) -> Result<T, ConfigError> {
170        let json = config_value_to_json(&ConfigValue::Object(self.root.clone()));
171        serde_json::from_value(json)
172            .map_err(|e| ConfigError::WrongType { path: "".into(), expected: leak(e.to_string()) })
173    }
174}
175
/// Turn an owned `String` into a `&'static str` by leaking its heap buffer.
///
/// The memory is never reclaimed, so this is only suitable for rare paths.
fn leak(s: String) -> &'static str {
    let boxed: Box<str> = s.into_boxed_str();
    Box::leak(boxed)
}
179
180fn config_value_to_json(v: &ConfigValue) -> serde_json::Value {
181    match v {
182        ConfigValue::Null => serde_json::Value::Null,
183        ConfigValue::Bool(b) => serde_json::Value::Bool(*b),
184        ConfigValue::Int(i) => serde_json::Value::Number((*i).into()),
185        ConfigValue::Float(f) => {
186            serde_json::Number::from_f64(*f).map(serde_json::Value::Number).unwrap_or(serde_json::Value::Null)
187        }
188        ConfigValue::String(s) => serde_json::Value::String(s.clone()),
189        ConfigValue::Array(items) => {
190            serde_json::Value::Array(items.iter().map(config_value_to_json).collect())
191        }
192        ConfigValue::Object(o) => {
193            let map: serde_json::Map<String, serde_json::Value> =
194                o.iter().map(|(k, v)| (k.clone(), config_value_to_json(v))).collect();
195            serde_json::Value::Object(map)
196        }
197    }
198}
199
200fn lookup<'a>(root: &'a BTreeMap<String, ConfigValue>, segs: &[String]) -> Option<&'a ConfigValue> {
201    let (head, tail) = segs.split_first()?;
202    let v = root.get(head)?;
203    if tail.is_empty() {
204        return Some(v);
205    }
206    match v {
207        ConfigValue::Object(o) => lookup(o, tail),
208        _ => None,
209    }
210}
211
212fn merge_object(
213    dst: &mut BTreeMap<String, ConfigValue>,
214    src: BTreeMap<String, ConfigValue>,
215    override_rhs: bool,
216) {
217    for (k, v) in src {
218        match dst.get_mut(&k) {
219            Some(ConfigValue::Object(inner)) => {
220                if let ConfigValue::Object(src_inner) = v {
221                    merge_object(inner, src_inner, override_rhs);
222                } else if override_rhs {
223                    dst.insert(k, v);
224                }
225            }
226            Some(_) if override_rhs => {
227                dst.insert(k, v);
228            }
229            Some(_) => {} // keep existing
230            None => {
231                dst.insert(k, v);
232            }
233        }
234    }
235}
236
/// Parse a textual duration such as `"500ms"`, `"2s"`, `"1.5m"`, `"1h"` or `"3d"`.
///
/// Surrounding whitespace and whitespace between number and unit are ignored.
/// A number without a unit is interpreted as seconds.
///
/// Returns `None` when the number does not parse, the unit is unknown, the
/// value is negative, or it is too large to express as whole microseconds in
/// a `u64`.
fn parse_duration(s: &str) -> Option<Duration> {
    let s = s.trim();
    let (num, unit) = split_number_unit(s)?;
    let n: f64 = num.parse().ok()?;
    let ms = match unit {
        "ms" | "millis" | "milliseconds" => n,
        // NOTE(review): a unit-less string means seconds here, whereas a
        // unit-less integer is treated as milliseconds by `get_duration` —
        // confirm this asymmetry is intended.
        "s" | "sec" | "seconds" | "" => n * 1000.0,
        "m" | "min" | "minutes" => n * 60_000.0,
        "h" | "hr" | "hours" => n * 3_600_000.0,
        "d" | "days" => n * 86_400_000.0,
        _ => return None,
    };
    let micros = ms * 1000.0;
    // A float-to-integer `as` cast saturates: negatives would silently turn
    // into a zero duration and huge values into `u64::MAX`. Reject both (and
    // non-finite results) instead of returning a misleading duration.
    if !micros.is_finite() || micros < 0.0 || micros >= u64::MAX as f64 {
        return None;
    }
    Some(Duration::from_micros(micros as u64))
}

/// Split `s` into its leading numeric text and the remaining unit text.
///
/// The numeric part is the longest prefix made of ASCII digits, `.` and `-`;
/// both halves are trimmed. Always returns `Some`; the numeric half may be
/// empty or malformed and is validated by the caller.
fn split_number_unit(s: &str) -> Option<(&str, &str)> {
    let is_numeric = |c: char| matches!(c, '0'..='9' | '.' | '-');
    let boundary = s.find(|c: char| !is_numeric(c)).unwrap_or(s.len());
    let (number, unit) = s.split_at(boundary);
    Some((number.trim(), unit.trim()))
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a TOML snippet, panicking if it is malformed.
    fn parse(src: &str) -> Config {
        Config::from_toml_str(src).unwrap()
    }

    /// The built-in reference configuration parses and exposes a known key.
    #[test]
    fn reference_loads() {
        let reference = Config::reference();
        assert!(reference.get_string("akka.actor.provider").is_ok());
    }

    /// `with_fallback` keeps existing values and only fills in missing keys.
    #[test]
    fn fallback_keeps_existing() {
        let primary = parse("[akka]\nfoo = \"a\"\n");
        let fallback = parse("[akka]\nfoo = \"b\"\nbar = \"B\"\n");
        let merged = primary.with_fallback(fallback);
        assert_eq!(merged.get_string("akka.foo").unwrap(), "a");
        assert_eq!(merged.get_string("akka.bar").unwrap(), "B");
    }

    /// `merged_with` lets the right-hand side win on conflicts.
    #[test]
    fn override_merge() {
        let base = parse("[akka]\nfoo = \"a\"\n");
        let overlay = parse("[akka]\nfoo = \"b\"\n");
        let merged = base.merged_with(overlay);
        assert_eq!(merged.get_string("akka.foo").unwrap(), "b");
    }

    /// Millisecond and second suffixes are both understood.
    #[test]
    fn duration_parses_units() {
        let millis = parse("[x]\nt = \"500ms\"\n");
        assert_eq!(millis.get_duration("x.t").unwrap(), Duration::from_millis(500));

        let seconds = parse("[x]\nt = \"2s\"\n");
        assert_eq!(seconds.get_duration("x.t").unwrap(), Duration::from_secs(2));
    }

    /// A sub-config is addressed relative to its own root.
    #[test]
    fn get_sub_returns_sub_tree() {
        let reference = Config::reference();
        let actor = reference.get_sub("akka.actor").unwrap();
        assert!(actor.get_string("provider").is_ok());
    }

    /// A table deserializes into a matching serde struct.
    #[test]
    fn extract_typed_value() {
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct Cluster {
            seed_nodes: Vec<String>,
            min_members: u32,
        }

        let toml = "[akka.cluster]\nseed_nodes = [\"a\", \"b\"]\nmin_members = 3\n";
        let config = parse(toml);
        let expected = Cluster { seed_nodes: vec!["a".into(), "b".into()], min_members: 3 };
        let actual: Cluster = config.extract("akka.cluster").unwrap();
        assert_eq!(actual, expected);
    }

    /// Extracting from an absent path reports `NotFound`.
    #[test]
    fn extract_returns_not_found_for_missing_path() {
        let config = Config::empty();
        let outcome: Result<u32, _> = config.extract("missing.key");
        assert!(matches!(outcome, Err(ConfigError::NotFound(_))));
    }

    /// Extracting into an incompatible type reports `WrongType`.
    #[test]
    fn extract_returns_wrong_type_for_mismatch() {
        let config = parse("[x]\ny = \"not a number\"\n");
        let outcome: Result<u32, _> = config.extract("x.y");
        assert!(matches!(outcome, Err(ConfigError::WrongType { .. })));
    }
}