rootasrole_core/database/
mod.rs1use std::path::Path;
2use std::{cell::RefCell, error::Error, rc::Rc};
3
4use crate::save_settings;
5use crate::util::{toggle_lock_config, ImmutableLock};
6use crate::version::PACKAGE_VERSION;
7
8use actor::SUserType;
9use bon::{builder, Builder};
10use chrono::Duration;
11use linked_hash_set::LinkedHashSet;
12use log::debug;
13use options::EnvBehavior;
14use serde::{de, Deserialize, Serialize};
15
16use self::{migration::Migration, options::EnvKey, structs::SConfig, versionning::Versioning};
17
18use crate::util::warn_if_mutable;
19use crate::SettingsFile;
20use crate::{open_with_privileges, write_json_config};
21use crate::{util::immutable_effective, RemoteStorageSettings, ROOTASROLE};
22
23pub mod actor;
24#[cfg(feature = "finder")]
25pub mod finder;
26pub mod migration;
27pub mod options;
28pub mod structs;
29pub mod versionning;
30
31#[derive(Debug, Default, Builder)]
32#[builder(on(_, overwritable))]
33pub struct FilterMatcher {
34 pub role: Option<String>,
35 pub task: Option<String>,
36 pub env_behavior: Option<EnvBehavior>,
37 #[builder(into)]
38 pub user: Option<SUserType>,
39}
40
41pub fn make_weak_config(config: &Rc<RefCell<SConfig>>) {
42 for role in &config.as_ref().borrow().roles {
43 role.as_ref().borrow_mut()._config = Some(Rc::downgrade(config));
44 for task in &role.as_ref().borrow().tasks {
45 task.as_ref().borrow_mut()._role = Some(Rc::downgrade(role));
46 }
47 }
48}
49
50pub fn read_json_config<P: AsRef<Path>>(
51 settings: Rc<RefCell<SettingsFile>>,
52 settings_path: P,
53) -> Result<Rc<RefCell<SConfig>>, Box<dyn Error>> {
54 let default_remote: RemoteStorageSettings = RemoteStorageSettings::default();
55 let binding = settings.as_ref().borrow();
56 let path = binding
57 .storage
58 .settings
59 .as_ref()
60 .unwrap_or(&default_remote)
61 .path
62 .as_ref();
63 if path.is_none() || path.is_some_and(|p| p == settings_path.as_ref()) {
64 make_weak_config(&settings.as_ref().borrow().config);
65 return Ok(settings.as_ref().borrow().config.clone());
66 } else {
67 let file = open_with_privileges(path.unwrap())?;
68 warn_if_mutable(
69 &file,
70 settings
71 .as_ref()
72 .borrow()
73 .storage
74 .settings
75 .as_ref()
76 .unwrap_or(&default_remote)
77 .immutable
78 .unwrap_or(true),
79 )?;
80 let versionned_config: Versioning<Rc<RefCell<SConfig>>> = serde_json::from_reader(file)?;
81 let config = versionned_config.data;
82 if let Ok(true) = Migration::migrate(
83 &versionned_config.version,
84 &mut *config.as_ref().borrow_mut(),
85 versionning::JSON_MIGRATIONS,
86 ) {
87 save_json(settings.clone(), config.clone())?;
88 } else {
89 debug!("No migrations needed");
90 }
91 make_weak_config(&config);
92 Ok(config)
93 }
94}
95
96pub fn save_json(
97 settings: Rc<RefCell<SettingsFile>>,
98 config: Rc<RefCell<SConfig>>,
99) -> Result<(), Box<dyn Error>> {
100 let default_remote: RemoteStorageSettings = RemoteStorageSettings::default();
101 let into = ROOTASROLE.into();
102 let binding = settings.as_ref().borrow();
103 let path = binding
104 .storage
105 .settings
106 .as_ref()
107 .unwrap_or(&default_remote)
108 .path
109 .as_ref()
110 .unwrap_or(&into);
111 if path == &into {
112 return save_settings(settings.clone());
114 }
115
116 debug!("Writing config file");
117 let versionned: Versioning<Rc<RefCell<SConfig>>> = Versioning {
118 version: PACKAGE_VERSION.to_owned().parse()?,
119 data: config,
120 };
121 if let Some(settings) = &settings.as_ref().borrow().storage.settings {
122 if settings.immutable.unwrap_or(true) {
123 debug!("Toggling immutable on for config file");
124 toggle_lock_config(path, ImmutableLock::Unset)?;
125 }
126 }
127 write_sconfig(&settings.as_ref().borrow(), versionned)?;
128 if let Some(settings) = &settings.as_ref().borrow().storage.settings {
129 if settings.immutable.unwrap_or(true) {
130 debug!("Toggling immutable off for config file");
131 toggle_lock_config(path, ImmutableLock::Set)?;
132 }
133 }
134 debug!("Resetting immutable privilege");
135 immutable_effective(false)?;
136 Ok(())
137}
138
139fn write_sconfig(
140 settings: &SettingsFile,
141 config: Versioning<Rc<RefCell<SConfig>>>,
142) -> Result<(), Box<dyn Error>> {
143 let default_remote = RemoteStorageSettings::default();
144 let binding = ROOTASROLE.into();
145 let path = settings
146 .storage
147 .settings
148 .as_ref()
149 .unwrap_or(&default_remote)
150 .path
151 .as_ref()
152 .unwrap_or(&binding);
153 write_json_config(&config, path)?;
154 Ok(())
155}
156
157fn lhs_deserialize_envkey<'de, D>(
159 deserializer: D,
160) -> Result<Option<LinkedHashSet<EnvKey>>, D::Error>
161where
162 D: de::Deserializer<'de>,
163{
164 if let Ok(v) = Vec::<EnvKey>::deserialize(deserializer) {
165 Ok(Some(v.into_iter().collect()))
166 } else {
167 Ok(None)
168 }
169}
170
171fn lhs_serialize_envkey<S>(
173 value: &Option<LinkedHashSet<EnvKey>>,
174 serializer: S,
175) -> Result<S::Ok, S::Error>
176where
177 S: serde::Serializer,
178{
179 if let Some(v) = value {
180 let v: Vec<EnvKey> = v.iter().cloned().collect();
181 v.serialize(serializer)
182 } else {
183 serializer.serialize_none()
184 }
185}
186
187fn lhs_deserialize<'de, D>(deserializer: D) -> Result<Option<LinkedHashSet<String>>, D::Error>
189where
190 D: de::Deserializer<'de>,
191{
192 if let Ok(v) = Vec::<String>::deserialize(deserializer) {
193 Ok(Some(v.into_iter().collect()))
194 } else {
195 Ok(None)
196 }
197}
198
199fn lhs_serialize<S>(value: &Option<LinkedHashSet<String>>, serializer: S) -> Result<S::Ok, S::Error>
201where
202 S: serde::Serializer,
203{
204 if let Some(v) = value {
205 let v: Vec<String> = v.iter().cloned().collect();
206 v.serialize(serializer)
207 } else {
208 serializer.serialize_none()
209 }
210}
211
/// Returns `true` when `t` compares equal to `T::default()`.
///
/// Its `fn(&T) -> bool` shape fits serde's `skip_serializing_if` attribute.
pub fn is_default<T: PartialEq + Default>(t: &T) -> bool {
    let baseline = T::default();
    t.eq(&baseline)
}
215
216fn serialize_duration<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
217where
218 S: serde::Serializer,
219{
220 match value {
222 Some(value) => serializer.serialize_str(&format!(
223 "{:#02}:{:#02}:{:#02}",
224 value.num_hours(),
225 value.num_minutes() % 60,
226 value.num_seconds() % 60
227 )),
228 None => serializer.serialize_none(),
229 }
230}
231
232fn deserialize_duration<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
233where
234 D: de::Deserializer<'de>,
235{
236 let s = String::deserialize(deserializer)?;
237 let mut parts = s.split(':');
238 if let (Some(hours), Some(minutes), Some(seconds)) = (parts.next(), parts.next(), parts.next())
240 {
241 let hours: i64 = hours.parse().map_err(de::Error::custom)?;
242 let minutes: i64 = minutes.parse().map_err(de::Error::custom)?;
243 let seconds: i64 = seconds.parse().map_err(de::Error::custom)?;
244 return Ok(Some(
245 Duration::hours(hours) + Duration::minutes(minutes) + Duration::seconds(seconds),
246 ));
247 }
248 Err(de::Error::custom("Invalid duration format"))
249}
250
251fn serialize_capset<S>(value: &capctl::CapSet, serializer: S) -> Result<S::Ok, S::Error>
252where
253 S: serde::Serializer,
254{
255 let v: Vec<String> = value.iter().map(|cap| cap.to_string()).collect();
256 v.serialize(serializer)
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Newtype that routes (de)serialization of a `LinkedHashSet` through the module's
    /// `lhs_*` helpers, so `serde_json` can drive them directly.
    struct LinkedHashSetTester<T>(LinkedHashSet<T>);

    impl<'de> Deserialize<'de> for LinkedHashSetTester<EnvKey> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            Ok(Self(
                lhs_deserialize_envkey(deserializer).map(|v| v.unwrap())?,
            ))
        }
    }

    impl Serialize for LinkedHashSetTester<EnvKey> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            lhs_serialize_envkey(&Some(self.0.clone()), serializer)
        }
    }

    impl<'de> Deserialize<'de> for LinkedHashSetTester<String> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            Ok(Self(lhs_deserialize(deserializer).map(|v| v.unwrap())?))
        }
    }

    impl Serialize for LinkedHashSetTester<String> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            lhs_serialize(&Some(self.0.clone()), serializer)
        }
    }

    /// Newtype that routes a `Duration` through `serialize_duration` / `deserialize_duration`.
    struct DurationTester(Duration);

    impl<'de> Deserialize<'de> for DurationTester {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            Ok(Self(
                deserialize_duration(deserializer).map(|v| v.unwrap())?,
            ))
        }
    }

    impl Serialize for DurationTester {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // `Duration` is `Copy`, so no clone is needed.
            serialize_duration(&Some(self.0), serializer)
        }
    }

    #[test]
    fn test_lhs_deserialize_envkey() {
        let json = r#"["key1", "key2", "key3"]"#;
        let deserialized: Option<LinkedHashSetTester<EnvKey>> = serde_json::from_str(json).unwrap();
        assert!(deserialized.is_some());
        let set = deserialized.unwrap();
        assert_eq!(set.0.len(), 3);
        assert!(set.0.contains(&EnvKey::from("key1")));
        assert!(set.0.contains(&EnvKey::from("key2")));
        assert!(set.0.contains(&EnvKey::from("key3")));
    }

    #[test]
    fn test_lhs_serialize_envkey() {
        let mut set = LinkedHashSetTester(LinkedHashSet::new());
        set.0.insert(EnvKey::from("key1"));
        set.0.insert(EnvKey::from("key2"));
        set.0.insert(EnvKey::from("key3"));
        let serialized = serde_json::to_string(&Some(set)).unwrap();
        assert_eq!(serialized, r#"["key1","key2","key3"]"#);
    }

    #[test]
    fn test_lhs_deserialize() {
        let json = r#"["value1", "value2", "value3"]"#;
        let deserialized: Option<LinkedHashSetTester<String>> = serde_json::from_str(json).unwrap();
        assert!(deserialized.is_some());
        let set = deserialized.unwrap();
        assert_eq!(set.0.len(), 3);
        assert!(set.0.contains("value1"));
        assert!(set.0.contains("value2"));
        assert!(set.0.contains("value3"));
    }

    #[test]
    fn test_lhs_serialize() {
        let mut set = LinkedHashSetTester(LinkedHashSet::new());
        set.0.insert("value1".to_string());
        set.0.insert("value2".to_string());
        set.0.insert("value3".to_string());
        let serialized = serde_json::to_string(&Some(set)).unwrap();
        assert_eq!(serialized, r#"["value1","value2","value3"]"#);
    }

    #[test]
    fn test_lhs_preserves_insertion_order() {
        let json = r#"["zeta","alpha","mid"]"#;
        let set: LinkedHashSetTester<String> = serde_json::from_str(json).unwrap();
        let order: Vec<String> = set.0.iter().cloned().collect();
        assert_eq!(
            order,
            vec!["zeta".to_string(), "alpha".to_string(), "mid".to_string()]
        );
        assert_eq!(serde_json::to_string(&set).unwrap(), json);
    }

    #[test]
    fn test_lhs_deserialize_collapses_duplicates() {
        let set: LinkedHashSetTester<String> =
            serde_json::from_str(r#"["a", "b", "a"]"#).unwrap();
        assert_eq!(set.0.len(), 2);
        assert!(set.0.contains("a"));
        assert!(set.0.contains("b"));
    }

    #[test]
    fn test_serialize_none_as_null() {
        let mut buf = Vec::new();
        lhs_serialize(&None, &mut serde_json::Serializer::new(&mut buf)).unwrap();
        assert_eq!(buf, b"null");

        let mut buf = Vec::new();
        serialize_duration(&None, &mut serde_json::Serializer::new(&mut buf)).unwrap();
        assert_eq!(buf, b"null");
    }

    #[test]
    fn test_serialize_duration() {
        let duration = Some(DurationTester(Duration::seconds(3661)));
        let serialized = serde_json::to_string(&duration).unwrap();
        assert_eq!(serialized, r#""01:01:01""#);
    }

    #[test]
    fn test_deserialize_duration() {
        let json = r#""01:01:01""#;
        let deserialized: Option<DurationTester> = serde_json::from_str(json).unwrap();
        assert!(deserialized.is_some());
        let duration = deserialized.unwrap();
        assert_eq!(duration.0.num_seconds(), 3661);
    }

    #[test]
    fn test_duration_hours_are_not_wrapped() {
        let duration: DurationTester = serde_json::from_str(r#""25:00:00""#).unwrap();
        assert_eq!(duration.0.num_seconds(), 90_000);
        assert_eq!(serde_json::to_string(&duration).unwrap(), r#""25:00:00""#);
    }

    #[test]
    fn test_deserialize_duration_rejects_malformed() {
        assert!(serde_json::from_str::<DurationTester>(r#""01:01""#).is_err());
        assert!(serde_json::from_str::<DurationTester>(r#""xx:00:00""#).is_err());
    }

    #[test]
    fn test_is_default() {
        assert!(is_default(&0));
        assert!(is_default(&String::new()));
        assert!(!is_default(&1));
        assert!(!is_default(&"non-default".to_string()));
    }
}