Skip to main content

firewall_objects/objects/
mod.rs

1//! Object storage helpers, including optional JSON serialization (`serde` feature).
2
3use crate::builder::BuilderEntry;
4use crate::error::{FirewallObjectError, Result};
5use crate::ip::network::{NetworkObj, NetworkObjGroup};
6use crate::service::{ApplicationObj, ServiceObj, ServiceObjGroup};
7use std::collections::BTreeMap;
8
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12#[cfg(feature = "serde")]
13use serde_json;
14
15/// Unified representation of objects managed via CRUD helpers.
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17#[cfg_attr(feature = "serde", serde(tag = "kind", rename_all = "snake_case"))]
18#[derive(Debug, Clone)]
19pub enum ObjectRecord {
20    Network(NetworkObj),
21    NetworkGroup(NetworkObjGroup),
22    Service(ServiceObj),
23    ServiceGroup(ServiceObjGroup),
24    Application(ApplicationObj),
25}
26
/// Payload-free discriminant naming one of the object kinds held by the store.
///
/// Used as the selector argument of the keyed lookups and deletions. With the
/// `serde` feature enabled each variant (de)serializes as its snake_case name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum ObjectKind {
    /// Selects network objects.
    Network,
    /// Selects network groups.
    NetworkGroup,
    /// Selects service objects.
    Service,
    /// Selects service groups.
    ServiceGroup,
    /// Selects applications.
    Application,
}
37
38/// In-memory store for firewall objects.
39#[derive(Debug, Default)]
40pub struct ObjectStore {
41    networks: BTreeMap<String, NetworkObj>,
42    network_groups: BTreeMap<String, NetworkObjGroup>,
43    services: BTreeMap<String, ServiceObj>,
44    service_groups: BTreeMap<String, ServiceObjGroup>,
45    applications: BTreeMap<String, ApplicationObj>,
46}
47
48impl ObjectStore {
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    #[cfg(feature = "serde")]
54    /// Create an object from a JSON payload.
55    ///
56    /// ```
57    /// # use firewall_objects::objects::{ObjectStore, ObjectKind};
58    /// # use firewall_objects::service::TransportService;
59    /// let payload = r#"{
60    ///     "kind": "service",
61    ///     "name": "https",
62    ///     "value": { "Tcp": 443 }
63    /// }"#;
64    /// let mut store = ObjectStore::new();
65    /// store.create_from_json(payload).unwrap();
66    /// assert!(store.get(ObjectKind::Service, "https").is_ok());
67    /// ```
68    pub fn create_from_json(&mut self, payload: &str) -> Result<()> {
69        let obj: ObjectRecord = serde_json::from_str(payload)?;
70        self.create(obj)
71    }
72
73    #[cfg(feature = "serde")]
74    pub fn update_from_json(&mut self, payload: &str) -> Result<()> {
75        let obj: ObjectRecord = serde_json::from_str(payload)?;
76        self.update(obj)
77    }
78
79    pub fn create(&mut self, obj: ObjectRecord) -> Result<()> {
80        match obj {
81            ObjectRecord::Network(net) => {
82                if self.networks.contains_key(&net.name) {
83                    return Err(FirewallObjectError::Message(format!(
84                        "network '{}' already exists",
85                        net.name
86                    )));
87                }
88                self.networks.insert(net.name.clone(), net);
89            }
90            ObjectRecord::NetworkGroup(group) => {
91                if self.network_groups.contains_key(&group.name) {
92                    return Err(FirewallObjectError::Message(format!(
93                        "network group '{}' already exists",
94                        group.name
95                    )));
96                }
97                self.network_groups.insert(group.name.clone(), group);
98            }
99            ObjectRecord::Service(service) => {
100                if self.services.contains_key(&service.name) {
101                    return Err(FirewallObjectError::Message(format!(
102                        "service '{}' already exists",
103                        service.name
104                    )));
105                }
106                self.services.insert(service.name.clone(), service);
107            }
108            ObjectRecord::ServiceGroup(group) => {
109                if self.service_groups.contains_key(&group.name) {
110                    return Err(FirewallObjectError::Message(format!(
111                        "service group '{}' already exists",
112                        group.name
113                    )));
114                }
115                self.service_groups.insert(group.name.clone(), group);
116            }
117            ObjectRecord::Application(app) => {
118                if self.applications.contains_key(&app.name) {
119                    return Err(FirewallObjectError::Message(format!(
120                        "application '{}' already exists",
121                        app.name
122                    )));
123                }
124                self.applications.insert(app.name.clone(), app);
125            }
126        }
127        Ok(())
128    }
129
130    /// Insert any builder entry or object emitted by the [`builder`](crate::builder) helpers.
131    ///
132    /// ```
133    /// use firewall_objects::builder::{address, service, service_group};
134    /// use firewall_objects::objects::ObjectStore;
135    ///
136    /// let mut store = ObjectStore::new();
137    /// store.add(address("server1", "192.0.2.10").unwrap()).unwrap();
138    ///
139    /// let services = service_group("web").unwrap()
140    ///     .with_service(service::tcp(443)).unwrap()
141    ///     .with_service(service::udp(8443)).unwrap()
142    ///     .build().unwrap();
143    /// store.add(services).unwrap();
144    /// ```
145    pub fn add<T>(&mut self, entry: T) -> Result<()>
146    where
147        T: Into<BuilderEntry>,
148    {
149        match entry.into() {
150            BuilderEntry::Network(net) => self.insert_network(net),
151            BuilderEntry::NetworkGroup(group) => self.insert_network_group(group),
152            BuilderEntry::Service(service) => self.insert_service(service),
153            BuilderEntry::ServiceGroup(group) => self.insert_service_group(group),
154            BuilderEntry::Application(app) => self.insert_application(app),
155        }
156    }
157
158    pub fn update(&mut self, obj: ObjectRecord) -> Result<()> {
159        match obj {
160            ObjectRecord::Network(net) => match self.networks.get_mut(&net.name) {
161                Some(existing) => {
162                    *existing = net;
163                }
164                None => return Err(not_found("network", &net.name)),
165            },
166            ObjectRecord::NetworkGroup(group) => match self.network_groups.get_mut(&group.name) {
167                Some(existing) => {
168                    *existing = group;
169                }
170                None => return Err(not_found("network group", &group.name)),
171            },
172            ObjectRecord::Service(service) => match self.services.get_mut(&service.name) {
173                Some(existing) => {
174                    *existing = service;
175                }
176                None => return Err(not_found("service", &service.name)),
177            },
178            ObjectRecord::ServiceGroup(group) => match self.service_groups.get_mut(&group.name) {
179                Some(existing) => {
180                    *existing = group;
181                }
182                None => return Err(not_found("service group", &group.name)),
183            },
184            ObjectRecord::Application(app) => match self.applications.get_mut(&app.name) {
185                Some(existing) => {
186                    *existing = app;
187                }
188                None => return Err(not_found("application", &app.name)),
189            },
190        }
191        Ok(())
192    }
193
194    pub fn delete(&mut self, kind: ObjectKind, name: &str) -> Result<()> {
195        let removed = match kind {
196            ObjectKind::Network => self.networks.remove(name).is_some(),
197            ObjectKind::NetworkGroup => self.network_groups.remove(name).is_some(),
198            ObjectKind::Service => self.services.remove(name).is_some(),
199            ObjectKind::ServiceGroup => self.service_groups.remove(name).is_some(),
200            ObjectKind::Application => self.applications.remove(name).is_some(),
201        };
202
203        if removed {
204            Ok(())
205        } else {
206            Err(not_found(kind_name(kind), name))
207        }
208    }
209
210    pub fn get(&self, kind: ObjectKind, name: &str) -> Result<ObjectRecord> {
211        let cloned = match kind {
212            ObjectKind::Network => self
213                .networks
214                .get(name)
215                .cloned()
216                .map(ObjectRecord::Network)
217                .ok_or_else(|| not_found("network", name))?,
218            ObjectKind::NetworkGroup => self
219                .network_groups
220                .get(name)
221                .cloned()
222                .map(ObjectRecord::NetworkGroup)
223                .ok_or_else(|| not_found("network group", name))?,
224            ObjectKind::Service => self
225                .services
226                .get(name)
227                .cloned()
228                .map(ObjectRecord::Service)
229                .ok_or_else(|| not_found("service", name))?,
230            ObjectKind::ServiceGroup => self
231                .service_groups
232                .get(name)
233                .cloned()
234                .map(ObjectRecord::ServiceGroup)
235                .ok_or_else(|| not_found("service group", name))?,
236            ObjectKind::Application => self
237                .applications
238                .get(name)
239                .cloned()
240                .map(ObjectRecord::Application)
241                .ok_or_else(|| not_found("application", name))?,
242        };
243        Ok(cloned)
244    }
245
246    #[cfg(feature = "serde")]
247    /// Serialize an object to JSON.
248    ///
249    /// ```
250    /// # use firewall_objects::objects::{ObjectStore, ObjectRecord, ObjectKind};
251    /// # use firewall_objects::service::TransportService;
252    /// let mut store = ObjectStore::new();
253    /// store.create(ObjectRecord::Service(
254    ///     firewall_objects::service::ServiceObj::new(
255    ///         "dns".into(),
256    ///         TransportService::udp(53),
257    ///     )
258    /// )).unwrap();
259    /// let json = store.to_json(ObjectKind::Service, "dns").unwrap();
260    /// assert!(json.contains("\"dns\""));
261    /// ```
262    pub fn to_json(&self, kind: ObjectKind, name: &str) -> Result<String> {
263        let obj = self.get(kind, name)?;
264        serde_json::to_string_pretty(&obj).map_err(FirewallObjectError::from)
265    }
266
267    /// Convenience: insert a network object.
268    pub fn insert_network(&mut self, obj: NetworkObj) -> Result<()> {
269        self.create(ObjectRecord::Network(obj))
270    }
271
272    /// Convenience: insert a network group.
273    pub fn insert_network_group(&mut self, group: NetworkObjGroup) -> Result<()> {
274        self.create(ObjectRecord::NetworkGroup(group))
275    }
276
277    /// Convenience: insert a service object.
278    pub fn insert_service(&mut self, obj: ServiceObj) -> Result<()> {
279        self.create(ObjectRecord::Service(obj))
280    }
281
282    /// Convenience: insert a service group.
283    pub fn insert_service_group(&mut self, group: ServiceObjGroup) -> Result<()> {
284        self.create(ObjectRecord::ServiceGroup(group))
285    }
286
287    /// Convenience: insert an application.
288    pub fn insert_application(&mut self, app: ApplicationObj) -> Result<()> {
289        self.create(ObjectRecord::Application(app))
290    }
291
292    /// Convenience accessor for networks.
293    pub fn network(&self, name: &str) -> Result<NetworkObj> {
294        match self.get(ObjectKind::Network, name)? {
295            ObjectRecord::Network(obj) => Ok(obj),
296            _ => unreachable!(),
297        }
298    }
299
300    /// Convenience accessor for network groups.
301    pub fn network_group(&self, name: &str) -> Result<NetworkObjGroup> {
302        match self.get(ObjectKind::NetworkGroup, name)? {
303            ObjectRecord::NetworkGroup(obj) => Ok(obj),
304            _ => unreachable!(),
305        }
306    }
307
308    /// Convenience accessor for services.
309    pub fn service(&self, name: &str) -> Result<ServiceObj> {
310        match self.get(ObjectKind::Service, name)? {
311            ObjectRecord::Service(obj) => Ok(obj),
312            _ => unreachable!(),
313        }
314    }
315
316    /// Convenience accessor for service groups.
317    pub fn service_group(&self, name: &str) -> Result<ServiceObjGroup> {
318        match self.get(ObjectKind::ServiceGroup, name)? {
319            ObjectRecord::ServiceGroup(obj) => Ok(obj),
320            _ => unreachable!(),
321        }
322    }
323
324    /// Convenience accessor for applications.
325    pub fn application(&self, name: &str) -> Result<ApplicationObj> {
326        match self.get(ObjectKind::Application, name)? {
327            ObjectRecord::Application(obj) => Ok(obj),
328            _ => unreachable!(),
329        }
330    }
331}
332
333fn not_found(entity: &str, name: &str) -> FirewallObjectError {
334    FirewallObjectError::Message(format!("{entity} '{name}' not found"))
335}
336
337fn kind_name(kind: ObjectKind) -> &'static str {
338    match kind {
339        ObjectKind::Network => "network",
340        ObjectKind::NetworkGroup => "network group",
341        ObjectKind::Service => "service",
342        ObjectKind::ServiceGroup => "service group",
343        ObjectKind::Application => "application",
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service::{ApplicationMatchInput, ApplicationObj, TransportService};
    use std::collections::BTreeSet;

    #[cfg(feature = "serde")]
    use serde_json;

    /// Objects of different kinds can be created and fetched back.
    #[test]
    fn create_and_get_objects() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("db1", "192.0.2.10")).unwrap(),
            ))
            .unwrap();
        store
            .create(ObjectRecord::Service(ServiceObj::new(
                "web".into(),
                TransportService::tcp(443),
            )))
            .unwrap();

        assert!(store.get(ObjectKind::Network, "db1").is_ok());
        assert!(store.get(ObjectKind::Service, "web").is_ok());
    }

    /// Creating the same kind/name twice is rejected and keeps the first object.
    #[test]
    fn create_rejects_duplicates() {
        let mut store = ObjectStore::new();
        let original = ServiceObj::new("web".into(), TransportService::tcp(443));
        store
            .create(ObjectRecord::Service(original.clone()))
            .unwrap();

        let err = store
            .create(ObjectRecord::Service(ServiceObj::new(
                "web".into(),
                TransportService::tcp(8443),
            )))
            .unwrap_err();
        assert!(format!("{err}").contains("already exists"));
        assert_eq!(store.service("web").unwrap(), original);
    }

    /// Updating an object that was never created fails.
    #[test]
    fn update_requires_existing() {
        let mut store = ObjectStore::new();
        let err = store
            .update(ObjectRecord::Service(ServiceObj::new(
                "dns".into(),
                TransportService::udp(53),
            )))
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    /// Updating an existing object replaces the stored value.
    #[test]
    fn update_replaces_existing() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Service(ServiceObj::new(
                "web".into(),
                TransportService::tcp(443),
            )))
            .unwrap();

        let replacement = ServiceObj::new("web".into(), TransportService::tcp(8443));
        store
            .update(ObjectRecord::Service(replacement.clone()))
            .unwrap();
        assert_eq!(store.service("web").unwrap(), replacement);
    }

    /// Deleted objects can no longer be fetched, for single objects and groups.
    #[test]
    fn delete_objects() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("api", "192.0.2.5")).unwrap(),
            ))
            .unwrap();
        store.delete(ObjectKind::Network, "api").unwrap();
        assert!(store.get(ObjectKind::Network, "api").is_err());

        let mut members = BTreeSet::new();
        members.insert(NetworkObj::try_from(("app", "192.0.2.6")).unwrap());
        let group = NetworkObjGroup::new("tier1", members).unwrap();
        store.create(ObjectRecord::NetworkGroup(group)).unwrap();
        store.delete(ObjectKind::NetworkGroup, "tier1").unwrap();
        assert!(store.get(ObjectKind::NetworkGroup, "tier1").is_err());
    }

    /// Deleting something that is not stored reports "not found".
    #[test]
    fn delete_missing_reports_not_found() {
        let mut store = ObjectStore::new();
        let err = store.delete(ObjectKind::Service, "ghost").unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    /// Applications round-trip through the store and still match traffic.
    #[test]
    fn applications_supported() {
        let mut store = ObjectStore::new();
        let app = ApplicationObj {
            name: "metrics-ui".into(),
            category: "internal".into(),
            transports: vec![TransportService::tcp(443)],
            dns_suffixes: vec![".corp.local".into()],
            tls_sni_suffixes: vec![],
            http_hosts: vec!["metrics.corp.local".into()],
        };

        store.create(ObjectRecord::Application(app)).unwrap();

        let fetched = store.get(ObjectKind::Application, "metrics-ui").unwrap();
        if let ObjectRecord::Application(obj) = fetched {
            assert!(obj.matches(&ApplicationMatchInput {
                http_host: Some("metrics.corp.local"),
                ..Default::default()
            }));
        } else {
            panic!("expected application");
        }
    }

    /// A serialized record can be loaded from JSON and serialized again.
    #[cfg(feature = "serde")]
    #[test]
    fn json_round_trip() {
        let mut store = ObjectStore::new();
        let payload = serde_json::to_string(&ObjectRecord::Application(ApplicationObj {
            name: "docs".into(),
            category: "external".into(),
            transports: vec![TransportService::tcp(443)],
            dns_suffixes: vec![".example.com".into()],
            tls_sni_suffixes: vec![".example.com".into()],
            http_hosts: vec!["docs.example.com".into()],
        }))
        .unwrap();

        store.create_from_json(&payload).unwrap();
        let json = store.to_json(ObjectKind::Application, "docs").unwrap();
        assert!(json.contains("docs"));
    }

    /// The typed insert/accessor helpers agree with each other.
    #[test]
    fn helper_methods_insert_and_get() {
        let mut store = ObjectStore::new();
        let network = NetworkObj::try_from(("app", "192.0.2.20")).unwrap();
        store.insert_network(network.clone()).unwrap();
        assert_eq!(store.network("app").unwrap(), network);

        let service = ServiceObj::new("web".into(), TransportService::tcp(443));
        store.insert_service(service.clone()).unwrap();
        assert_eq!(store.service("web").unwrap(), service);

        let app = ApplicationObj {
            name: "helper-app".into(),
            category: "misc".into(),
            transports: vec![TransportService::udp(6000)],
            dns_suffixes: vec![".helpers.local".into()],
            tls_sni_suffixes: vec![],
            http_hosts: vec![],
        };
        store.insert_application(app).unwrap();
        assert!(
            store
                .application("helper-app")
                .unwrap()
                .matches(&ApplicationMatchInput {
                    dns_query: Some("svc.helpers.local"),
                    ..Default::default()
                })
        );
    }
}