rootasrole_core/database/
mod.rs

1use std::path::Path;
2use std::{cell::RefCell, error::Error, rc::Rc};
3
4use crate::save_settings;
5use crate::util::{toggle_lock_config, ImmutableLock};
6use crate::version::PACKAGE_VERSION;
7
8use actor::SUserType;
9use bon::{builder, Builder};
10use chrono::Duration;
11use linked_hash_set::LinkedHashSet;
12use log::debug;
13use options::EnvBehavior;
14use serde::{de, Deserialize, Serialize};
15
16use self::{migration::Migration, options::EnvKey, structs::SConfig, versionning::Versioning};
17
18use crate::util::warn_if_mutable;
19use crate::SettingsFile;
20use crate::{open_with_privileges, write_json_config};
21use crate::{util::immutable_effective, RemoteStorageSettings, ROOTASROLE};
22
23pub mod actor;
24#[cfg(feature = "finder")]
25pub mod finder;
26pub mod migration;
27pub mod options;
28pub mod structs;
29pub mod versionning;
30
31#[derive(Debug, Default, Builder)]
32#[builder(on(_, overwritable))]
33pub struct FilterMatcher {
34    pub role: Option<String>,
35    pub task: Option<String>,
36    pub env_behavior: Option<EnvBehavior>,
37    #[builder(into)]
38    pub user: Option<SUserType>,
39}
40
41pub fn make_weak_config(config: &Rc<RefCell<SConfig>>) {
42    for role in &config.as_ref().borrow().roles {
43        role.as_ref().borrow_mut()._config = Some(Rc::downgrade(config));
44        for task in &role.as_ref().borrow().tasks {
45            task.as_ref().borrow_mut()._role = Some(Rc::downgrade(role));
46        }
47    }
48}
49
50pub fn read_json_config<P: AsRef<Path>>(
51    settings: Rc<RefCell<SettingsFile>>,
52    settings_path: P,
53) -> Result<Rc<RefCell<SConfig>>, Box<dyn Error>> {
54    let default_remote: RemoteStorageSettings = RemoteStorageSettings::default();
55    let binding = settings.as_ref().borrow();
56    let path = binding
57        .storage
58        .settings
59        .as_ref()
60        .unwrap_or(&default_remote)
61        .path
62        .as_ref();
63    if path.is_none() || path.is_some_and(|p| p == settings_path.as_ref()) {
64        make_weak_config(&settings.as_ref().borrow().config);
65        return Ok(settings.as_ref().borrow().config.clone());
66    } else {
67        let file = open_with_privileges(path.unwrap())?;
68        warn_if_mutable(
69            &file,
70            settings
71                .as_ref()
72                .borrow()
73                .storage
74                .settings
75                .as_ref()
76                .unwrap_or(&default_remote)
77                .immutable
78                .unwrap_or(true),
79        )?;
80        let versionned_config: Versioning<Rc<RefCell<SConfig>>> = serde_json::from_reader(file)?;
81        let config = versionned_config.data;
82        if let Ok(true) = Migration::migrate(
83            &versionned_config.version,
84            &mut *config.as_ref().borrow_mut(),
85            versionning::JSON_MIGRATIONS,
86        ) {
87            save_json(settings.clone(), config.clone())?;
88        } else {
89            debug!("No migrations needed");
90        }
91        make_weak_config(&config);
92        Ok(config)
93    }
94}
95
96pub fn save_json(
97    settings: Rc<RefCell<SettingsFile>>,
98    config: Rc<RefCell<SConfig>>,
99) -> Result<(), Box<dyn Error>> {
100    let default_remote: RemoteStorageSettings = RemoteStorageSettings::default();
101    let into = ROOTASROLE.into();
102    let binding = settings.as_ref().borrow();
103    let path = binding
104        .storage
105        .settings
106        .as_ref()
107        .unwrap_or(&default_remote)
108        .path
109        .as_ref()
110        .unwrap_or(&into);
111    if path == &into {
112        // if /etc/security/rootasrole.json then you need to consider the settings to save in addition to the config
113        return save_settings(settings.clone());
114    }
115
116    debug!("Writing config file");
117    let versionned: Versioning<Rc<RefCell<SConfig>>> = Versioning {
118        version: PACKAGE_VERSION.to_owned().parse()?,
119        data: config,
120    };
121    if let Some(settings) = &settings.as_ref().borrow().storage.settings {
122        if settings.immutable.unwrap_or(true) {
123            debug!("Toggling immutable on for config file");
124            toggle_lock_config(path, ImmutableLock::Unset)?;
125        }
126    }
127    write_sconfig(&settings.as_ref().borrow(), versionned)?;
128    if let Some(settings) = &settings.as_ref().borrow().storage.settings {
129        if settings.immutable.unwrap_or(true) {
130            debug!("Toggling immutable off for config file");
131            toggle_lock_config(path, ImmutableLock::Set)?;
132        }
133    }
134    debug!("Resetting immutable privilege");
135    immutable_effective(false)?;
136    Ok(())
137}
138
139fn write_sconfig(
140    settings: &SettingsFile,
141    config: Versioning<Rc<RefCell<SConfig>>>,
142) -> Result<(), Box<dyn Error>> {
143    let default_remote = RemoteStorageSettings::default();
144    let binding = ROOTASROLE.into();
145    let path = settings
146        .storage
147        .settings
148        .as_ref()
149        .unwrap_or(&default_remote)
150        .path
151        .as_ref()
152        .unwrap_or(&binding);
153    write_json_config(&config, path)?;
154    Ok(())
155}
156
157// deserialize the linked hash set
158fn lhs_deserialize_envkey<'de, D>(
159    deserializer: D,
160) -> Result<Option<LinkedHashSet<EnvKey>>, D::Error>
161where
162    D: de::Deserializer<'de>,
163{
164    if let Ok(v) = Vec::<EnvKey>::deserialize(deserializer) {
165        Ok(Some(v.into_iter().collect()))
166    } else {
167        Ok(None)
168    }
169}
170
171// serialize the linked hash set
172fn lhs_serialize_envkey<S>(
173    value: &Option<LinkedHashSet<EnvKey>>,
174    serializer: S,
175) -> Result<S::Ok, S::Error>
176where
177    S: serde::Serializer,
178{
179    if let Some(v) = value {
180        let v: Vec<EnvKey> = v.iter().cloned().collect();
181        v.serialize(serializer)
182    } else {
183        serializer.serialize_none()
184    }
185}
186
187// deserialize the linked hash set
188fn lhs_deserialize<'de, D>(deserializer: D) -> Result<Option<LinkedHashSet<String>>, D::Error>
189where
190    D: de::Deserializer<'de>,
191{
192    if let Ok(v) = Vec::<String>::deserialize(deserializer) {
193        Ok(Some(v.into_iter().collect()))
194    } else {
195        Ok(None)
196    }
197}
198
199// serialize the linked hash set
200fn lhs_serialize<S>(value: &Option<LinkedHashSet<String>>, serializer: S) -> Result<S::Ok, S::Error>
201where
202    S: serde::Serializer,
203{
204    if let Some(v) = value {
205        let v: Vec<String> = v.iter().cloned().collect();
206        v.serialize(serializer)
207    } else {
208        serializer.serialize_none()
209    }
210}
211
/// Returns `true` when `t` compares equal to `T::default()`.
pub fn is_default<T: PartialEq + Default>(t: &T) -> bool {
    let default_value = T::default();
    *t == default_value
}
215
216fn serialize_duration<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
217where
218    S: serde::Serializer,
219{
220    // hh:mm:ss format
221    match value {
222        Some(value) => serializer.serialize_str(&format!(
223            "{:#02}:{:#02}:{:#02}",
224            value.num_hours(),
225            value.num_minutes() % 60,
226            value.num_seconds() % 60
227        )),
228        None => serializer.serialize_none(),
229    }
230}
231
232fn deserialize_duration<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
233where
234    D: de::Deserializer<'de>,
235{
236    let s = String::deserialize(deserializer)?;
237    let mut parts = s.split(':');
238    //unwrap or error
239    if let (Some(hours), Some(minutes), Some(seconds)) = (parts.next(), parts.next(), parts.next())
240    {
241        let hours: i64 = hours.parse().map_err(de::Error::custom)?;
242        let minutes: i64 = minutes.parse().map_err(de::Error::custom)?;
243        let seconds: i64 = seconds.parse().map_err(de::Error::custom)?;
244        return Ok(Some(
245            Duration::hours(hours) + Duration::minutes(minutes) + Duration::seconds(seconds),
246        ));
247    }
248    Err(de::Error::custom("Invalid duration format"))
249}
250
251fn serialize_capset<S>(value: &capctl::CapSet, serializer: S) -> Result<S::Ok, S::Error>
252where
253    S: serde::Serializer,
254{
255    let v: Vec<String> = value.iter().map(|cap| cap.to_string()).collect();
256    v.serialize(serializer)
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Adapter that routes (de)serialization of a `LinkedHashSet` through the module's
    /// `lhs_*` helpers so they can be exercised with `serde_json` directly.
    struct LinkedHashSetTester<T>(LinkedHashSet<T>);

    impl<'de> Deserialize<'de> for LinkedHashSetTester<EnvKey> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let set = lhs_deserialize_envkey(deserializer)?;
            Ok(Self(set.expect("a JSON array must deserialize to Some")))
        }
    }

    impl Serialize for LinkedHashSetTester<EnvKey> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            lhs_serialize_envkey(&Some(self.0.clone()), serializer)
        }
    }

    impl<'de> Deserialize<'de> for LinkedHashSetTester<String> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let set = lhs_deserialize(deserializer)?;
            Ok(Self(set.expect("a JSON array must deserialize to Some")))
        }
    }

    impl Serialize for LinkedHashSetTester<String> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            lhs_serialize(&Some(self.0.clone()), serializer)
        }
    }

    /// Adapter that routes (de)serialization of a `Duration` through
    /// `serialize_duration` / `deserialize_duration`.
    struct DurationTester(Duration);

    impl<'de> Deserialize<'de> for DurationTester {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let duration = deserialize_duration(deserializer)?;
            Ok(Self(duration.expect("a duration string must deserialize to Some")))
        }
    }

    impl Serialize for DurationTester {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // `Duration` is `Copy`, so no clone is needed.
            serialize_duration(&Some(self.0), serializer)
        }
    }

    #[test]
    fn test_lhs_deserialize_envkey() {
        let json = r#"["key1", "key2", "key3"]"#;
        let deserialized: Option<LinkedHashSetTester<EnvKey>> = serde_json::from_str(json).unwrap();
        assert!(deserialized.is_some());
        let set = deserialized.unwrap();
        assert_eq!(set.0.len(), 3);
        assert!(set.0.contains(&EnvKey::from("key1")));
        assert!(set.0.contains(&EnvKey::from("key2")));
        assert!(set.0.contains(&EnvKey::from("key3")));
    }

    #[test]
    fn test_lhs_serialize_envkey() {
        let mut set = LinkedHashSetTester(LinkedHashSet::new());
        set.0.insert(EnvKey::from("key1"));
        set.0.insert(EnvKey::from("key2"));
        set.0.insert(EnvKey::from("key3"));
        let serialized = serde_json::to_string(&Some(set)).unwrap();
        assert_eq!(serialized, r#"["key1","key2","key3"]"#);
    }

    #[test]
    fn test_lhs_deserialize() {
        let json = r#"["value1", "value2", "value3"]"#;
        let deserialized: Option<LinkedHashSetTester<String>> = serde_json::from_str(json).unwrap();
        assert!(deserialized.is_some());
        let set = deserialized.unwrap();
        assert_eq!(set.0.len(), 3);
        assert!(set.0.contains("value1"));
        assert!(set.0.contains("value2"));
        assert!(set.0.contains("value3"));
    }

    #[test]
    fn test_lhs_serialize() {
        let mut set = LinkedHashSetTester(LinkedHashSet::new());
        set.0.insert("value1".to_string());
        set.0.insert("value2".to_string());
        set.0.insert("value3".to_string());
        let serialized = serde_json::to_string(&Some(set)).unwrap();
        assert_eq!(serialized, r#"["value1","value2","value3"]"#);
    }

    #[test]
    fn test_lhs_serialize_keeps_insertion_order() {
        let mut set = LinkedHashSetTester(LinkedHashSet::new());
        set.0.insert("b".to_string());
        set.0.insert("a".to_string());
        set.0.insert("c".to_string());
        assert_eq!(serde_json::to_string(&set).unwrap(), r#"["b","a","c"]"#);
    }

    #[test]
    fn test_lhs_serialize_empty() {
        let set: LinkedHashSetTester<String> = LinkedHashSetTester(LinkedHashSet::new());
        assert_eq!(serde_json::to_string(&set).unwrap(), "[]");
    }

    #[test]
    fn test_serialize_duration() {
        let duration = Some(DurationTester(Duration::seconds(3661)));
        let serialized = serde_json::to_string(&duration).unwrap();
        assert_eq!(serialized, r#""01:01:01""#);
    }

    #[test]
    fn test_serialize_zero_duration() {
        let serialized = serde_json::to_string(&DurationTester(Duration::seconds(0))).unwrap();
        assert_eq!(serialized, r#""00:00:00""#);
    }

    #[test]
    fn test_deserialize_duration() {
        let json = r#""01:01:01""#;
        let deserialized: Option<DurationTester> = serde_json::from_str(json).unwrap();
        assert!(deserialized.is_some());
        let duration = deserialized.unwrap();
        assert_eq!(duration.0.num_seconds(), 3661);
    }

    #[test]
    fn test_duration_roundtrip() {
        for secs in [0, 59, 60, 3_599, 3_661, 86_400, 360_061] {
            let serialized =
                serde_json::to_string(&DurationTester(Duration::seconds(secs))).unwrap();
            let back: DurationTester = serde_json::from_str(&serialized).unwrap();
            assert_eq!(back.0.num_seconds(), secs, "round-trip of {serialized}");
        }
    }

    #[test]
    fn test_deserialize_duration_rejects_malformed_input() {
        assert!(serde_json::from_str::<DurationTester>(r#""01:02""#).is_err());
        assert!(serde_json::from_str::<DurationTester>(r#""aa:00:00""#).is_err());
        assert!(serde_json::from_str::<DurationTester>(r#""""#).is_err());
    }

    #[test]
    fn test_is_default() {
        assert!(is_default(&0));
        assert!(is_default(&String::new()));
        assert!(!is_default(&1));
        assert!(!is_default(&"non-default".to_string()));
    }
}