1use std::collections::BTreeMap;
5use std::time::Duration;
6
7use crate::error::ConfigError;
8use crate::path::ConfigPath;
9use crate::reference::reference_config;
10
/// A dynamically typed configuration value.
///
/// Object keys are held in a [`BTreeMap`], so they iterate in sorted key order.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigValue {
    /// An explicit null.
    Null,
    /// A boolean flag.
    Bool(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// A text value.
    String(String),
    /// An ordered list of nested values.
    Array(Vec<Self>),
    /// A mapping from keys to nested values.
    Object(BTreeMap<String, Self>),
}
21
22impl ConfigValue {
23 pub fn type_name(&self) -> &'static str {
24 match self {
25 Self::Null => "null",
26 Self::Bool(_) => "bool",
27 Self::Int(_) => "int",
28 Self::Float(_) => "float",
29 Self::String(_) => "string",
30 Self::Array(_) => "array",
31 Self::Object(_) => "object",
32 }
33 }
34
35 fn from_toml(v: toml::Value) -> Self {
36 match v {
37 toml::Value::String(s) => Self::String(s),
38 toml::Value::Integer(i) => Self::Int(i),
39 toml::Value::Float(f) => Self::Float(f),
40 toml::Value::Boolean(b) => Self::Bool(b),
41 toml::Value::Datetime(d) => Self::String(d.to_string()),
42 toml::Value::Array(a) => Self::Array(a.into_iter().map(Self::from_toml).collect()),
43 toml::Value::Table(t) => {
44 Self::Object(t.into_iter().map(|(k, v)| (k, Self::from_toml(v))).collect())
45 }
46 }
47 }
48}
49
50#[derive(Debug, Clone, Default, PartialEq)]
52pub struct Config {
53 root: BTreeMap<String, ConfigValue>,
54}
55
56impl Config {
57 pub fn empty() -> Self {
58 Self::default()
59 }
60
61 pub fn reference() -> Self {
63 Self::from_toml_str(reference_config()).expect("built-in reference.conf.toml is valid")
64 }
65
66 pub fn from_toml_str(s: &str) -> Result<Self, ConfigError> {
67 let v: toml::Value = toml::from_str(s)?;
68 let table = match v {
69 toml::Value::Table(t) => t,
70 _ => return Err(ConfigError::WrongType { path: "".into(), expected: "object" }),
71 };
72 Ok(Self { root: table.into_iter().map(|(k, v)| (k, ConfigValue::from_toml(v))).collect() })
73 }
74
75 pub fn from_hocon_str(s: &str) -> Result<Self, ConfigError> {
78 let v = crate::hocon::parse(s, std::path::Path::new("."))?;
79 match v {
80 ConfigValue::Object(o) => Ok(Self { root: o }),
81 _ => Err(ConfigError::WrongType { path: "".into(), expected: "object" }),
82 }
83 }
84
85 pub fn from_hocon_file(path: impl AsRef<std::path::Path>) -> Result<Self, ConfigError> {
88 let v = crate::hocon::parse_file(path.as_ref())?;
89 match v {
90 ConfigValue::Object(o) => Ok(Self { root: o }),
91 _ => Err(ConfigError::WrongType { path: "".into(), expected: "object" }),
92 }
93 }
94
95 pub fn with_fallback(mut self, fallback: Self) -> Self {
98 merge_object(&mut self.root, fallback.root, false);
99 self
100 }
101
102 pub fn merged_with(mut self, other: Self) -> Self {
104 merge_object(&mut self.root, other.root, true);
105 self
106 }
107
108 pub fn get(&self, path: &str) -> Option<&ConfigValue> {
109 let p = ConfigPath::parse(path);
110 lookup(&self.root, p.segments())
111 }
112
113 pub fn get_string(&self, path: &str) -> Result<String, ConfigError> {
114 match self.get(path) {
115 Some(ConfigValue::String(s)) => Ok(s.clone()),
116 Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
117 None => Err(ConfigError::NotFound(path.into())),
118 }
119 }
120
121 pub fn get_int(&self, path: &str) -> Result<i64, ConfigError> {
122 match self.get(path) {
123 Some(ConfigValue::Int(i)) => Ok(*i),
124 Some(ConfigValue::Float(f)) => Ok(*f as i64),
125 Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
126 None => Err(ConfigError::NotFound(path.into())),
127 }
128 }
129
130 pub fn get_bool(&self, path: &str) -> Result<bool, ConfigError> {
131 match self.get(path) {
132 Some(ConfigValue::Bool(b)) => Ok(*b),
133 Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
134 None => Err(ConfigError::NotFound(path.into())),
135 }
136 }
137
138 pub fn get_duration(&self, path: &str) -> Result<Duration, ConfigError> {
140 match self.get(path) {
141 Some(ConfigValue::String(s)) => parse_duration(s)
142 .ok_or_else(|| ConfigError::WrongType { path: path.into(), expected: "duration" }),
143 Some(ConfigValue::Int(i)) => Ok(Duration::from_millis(*i as u64)),
144 Some(v) => Err(ConfigError::WrongType { path: path.into(), expected: v.type_name() }),
145 None => Err(ConfigError::NotFound(path.into())),
146 }
147 }
148
149 pub fn get_sub(&self, path: &str) -> Option<Config> {
150 match self.get(path)? {
151 ConfigValue::Object(o) => Some(Self { root: o.clone() }),
152 _ => None,
153 }
154 }
155}
156
157fn lookup<'a>(root: &'a BTreeMap<String, ConfigValue>, segs: &[String]) -> Option<&'a ConfigValue> {
158 let (head, tail) = segs.split_first()?;
159 let v = root.get(head)?;
160 if tail.is_empty() {
161 return Some(v);
162 }
163 match v {
164 ConfigValue::Object(o) => lookup(o, tail),
165 _ => None,
166 }
167}
168
169fn merge_object(
170 dst: &mut BTreeMap<String, ConfigValue>,
171 src: BTreeMap<String, ConfigValue>,
172 override_rhs: bool,
173) {
174 for (k, v) in src {
175 match dst.get_mut(&k) {
176 Some(ConfigValue::Object(inner)) => {
177 if let ConfigValue::Object(src_inner) = v {
178 merge_object(inner, src_inner, override_rhs);
179 } else if override_rhs {
180 dst.insert(k, v);
181 }
182 }
183 Some(_) if override_rhs => {
184 dst.insert(k, v);
185 }
186 Some(_) => {} None => {
188 dst.insert(k, v);
189 }
190 }
191 }
192}
193
194fn parse_duration(s: &str) -> Option<Duration> {
195 let s = s.trim();
196 let (num, unit) = split_number_unit(s)?;
197 let n: f64 = num.parse().ok()?;
198 let ms = match unit {
199 "ms" | "millis" | "milliseconds" => n,
200 "s" | "sec" | "seconds" | "" => n * 1000.0,
201 "m" | "min" | "minutes" => n * 60_000.0,
202 "h" | "hr" | "hours" => n * 3_600_000.0,
203 "d" | "days" => n * 86_400_000.0,
204 _ => return None,
205 };
206 Some(Duration::from_micros((ms * 1000.0) as u64))
207}
208
/// Splits `s` at the first character that cannot be part of a number. Only ASCII digits, `.`
/// and `-` count as numeric.
///
/// Returns the trimmed numeric prefix and the trimmed remainder. The split is purely lexical:
/// the prefix is not validated, so callers must still parse it. This never returns `None`; the
/// `Option` only lets callers chain it with `?`.
fn split_number_unit(s: &str) -> Option<(&str, &str)> {
    let is_numeric = |c: char| c.is_ascii_digit() || matches!(c, '.' | '-');
    let boundary = s
        .char_indices()
        .find(|&(_, c)| !is_numeric(c))
        .map_or(s.len(), |(i, _)| i);
    let number = s[..boundary].trim();
    let unit = s[boundary..].trim();
    Some((number, unit))
}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline TOML snippet, panicking if the test input is malformed.
    fn parse_toml(src: &str) -> Config {
        Config::from_toml_str(src).unwrap()
    }

    #[test]
    fn reference_loads() {
        let reference = Config::reference();
        assert!(reference.get_string("akka.actor.provider").is_ok());
    }

    #[test]
    fn fallback_keeps_existing() {
        let primary = parse_toml("[akka]\nfoo = \"a\"\n");
        let fallback = parse_toml("[akka]\nfoo = \"b\"\nbar = \"B\"\n");
        let merged = primary.with_fallback(fallback);
        assert_eq!(merged.get_string("akka.foo").unwrap(), "a");
        assert_eq!(merged.get_string("akka.bar").unwrap(), "B");
    }

    #[test]
    fn override_merge() {
        let base = parse_toml("[akka]\nfoo = \"a\"\n");
        let overlay = parse_toml("[akka]\nfoo = \"b\"\n");
        let merged = base.merged_with(overlay);
        assert_eq!(merged.get_string("akka.foo").unwrap(), "b");
    }

    #[test]
    fn duration_parses_units() {
        let cases = [
            ("500ms", Duration::from_millis(500)),
            ("2s", Duration::from_secs(2)),
        ];
        for &(text, expected) in &cases {
            let cfg = parse_toml(&format!("[x]\nt = \"{}\"\n", text));
            assert_eq!(cfg.get_duration("x.t").unwrap(), expected);
        }
    }

    #[test]
    fn get_sub_returns_sub_tree() {
        let actor = Config::reference().get_sub("akka.actor").unwrap();
        assert!(actor.get_string("provider").is_ok());
    }
}