use crate::builder::BuilderEntry;
use crate::error::{FirewallObjectError, Result};
use crate::ip::network::{NetworkObj, NetworkObjGroup};
use crate::service::{ApplicationObj, ServiceObj, ServiceObjGroup};
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
8
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12#[cfg(feature = "serde")]
13use serde_json;
14
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17#[cfg_attr(feature = "serde", serde(tag = "kind", rename_all = "snake_case"))]
18#[derive(Debug, Clone)]
19pub enum ObjectRecord {
20 Network(NetworkObj),
21 NetworkGroup(NetworkObjGroup),
22 Service(ServiceObj),
23 ServiceGroup(ServiceObjGroup),
24 Application(ApplicationObj),
25}
26
/// Selects which collection of an [`ObjectStore`] an operation targets.
///
/// Every kind has its own name space, so the same name may exist under several kinds.
/// With the `serde` feature enabled, variants (de)serialize as snake_case strings
/// (e.g. `"service_group"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum ObjectKind {
    /// Selects single network objects.
    Network,
    /// Selects network object groups.
    NetworkGroup,
    /// Selects single service objects.
    Service,
    /// Selects service object groups.
    ServiceGroup,
    /// Selects application definitions.
    Application,
}
37
38#[derive(Debug, Default)]
40pub struct ObjectStore {
41 networks: BTreeMap<String, NetworkObj>,
42 network_groups: BTreeMap<String, NetworkObjGroup>,
43 services: BTreeMap<String, ServiceObj>,
44 service_groups: BTreeMap<String, ServiceObjGroup>,
45 applications: BTreeMap<String, ApplicationObj>,
46}
47
48impl ObjectStore {
49 pub fn new() -> Self {
50 Self::default()
51 }
52
53 #[cfg(feature = "serde")]
54 pub fn create_from_json(&mut self, payload: &str) -> Result<()> {
69 let obj: ObjectRecord = serde_json::from_str(payload)?;
70 self.create(obj)
71 }
72
73 #[cfg(feature = "serde")]
74 pub fn update_from_json(&mut self, payload: &str) -> Result<()> {
75 let obj: ObjectRecord = serde_json::from_str(payload)?;
76 self.update(obj)
77 }
78
79 pub fn create(&mut self, obj: ObjectRecord) -> Result<()> {
80 match obj {
81 ObjectRecord::Network(net) => {
82 if self.networks.contains_key(&net.name) {
83 return Err(FirewallObjectError::Message(format!(
84 "network '{}' already exists",
85 net.name
86 )));
87 }
88 self.networks.insert(net.name.clone(), net);
89 }
90 ObjectRecord::NetworkGroup(group) => {
91 if self.network_groups.contains_key(&group.name) {
92 return Err(FirewallObjectError::Message(format!(
93 "network group '{}' already exists",
94 group.name
95 )));
96 }
97 self.network_groups.insert(group.name.clone(), group);
98 }
99 ObjectRecord::Service(service) => {
100 if self.services.contains_key(&service.name) {
101 return Err(FirewallObjectError::Message(format!(
102 "service '{}' already exists",
103 service.name
104 )));
105 }
106 self.services.insert(service.name.clone(), service);
107 }
108 ObjectRecord::ServiceGroup(group) => {
109 if self.service_groups.contains_key(&group.name) {
110 return Err(FirewallObjectError::Message(format!(
111 "service group '{}' already exists",
112 group.name
113 )));
114 }
115 self.service_groups.insert(group.name.clone(), group);
116 }
117 ObjectRecord::Application(app) => {
118 if self.applications.contains_key(&app.name) {
119 return Err(FirewallObjectError::Message(format!(
120 "application '{}' already exists",
121 app.name
122 )));
123 }
124 self.applications.insert(app.name.clone(), app);
125 }
126 }
127 Ok(())
128 }
129
130 pub fn add<T>(&mut self, entry: T) -> Result<()>
146 where
147 T: Into<BuilderEntry>,
148 {
149 match entry.into() {
150 BuilderEntry::Network(net) => self.insert_network(net),
151 BuilderEntry::NetworkGroup(group) => self.insert_network_group(group),
152 BuilderEntry::Service(service) => self.insert_service(service),
153 BuilderEntry::ServiceGroup(group) => self.insert_service_group(group),
154 BuilderEntry::Application(app) => self.insert_application(app),
155 }
156 }
157
158 pub fn update(&mut self, obj: ObjectRecord) -> Result<()> {
159 match obj {
160 ObjectRecord::Network(net) => match self.networks.get_mut(&net.name) {
161 Some(existing) => {
162 *existing = net;
163 }
164 None => return Err(not_found("network", &net.name)),
165 },
166 ObjectRecord::NetworkGroup(group) => match self.network_groups.get_mut(&group.name) {
167 Some(existing) => {
168 *existing = group;
169 }
170 None => return Err(not_found("network group", &group.name)),
171 },
172 ObjectRecord::Service(service) => match self.services.get_mut(&service.name) {
173 Some(existing) => {
174 *existing = service;
175 }
176 None => return Err(not_found("service", &service.name)),
177 },
178 ObjectRecord::ServiceGroup(group) => match self.service_groups.get_mut(&group.name) {
179 Some(existing) => {
180 *existing = group;
181 }
182 None => return Err(not_found("service group", &group.name)),
183 },
184 ObjectRecord::Application(app) => match self.applications.get_mut(&app.name) {
185 Some(existing) => {
186 *existing = app;
187 }
188 None => return Err(not_found("application", &app.name)),
189 },
190 }
191 Ok(())
192 }
193
194 pub fn delete(&mut self, kind: ObjectKind, name: &str) -> Result<()> {
195 let removed = match kind {
196 ObjectKind::Network => self.networks.remove(name).is_some(),
197 ObjectKind::NetworkGroup => self.network_groups.remove(name).is_some(),
198 ObjectKind::Service => self.services.remove(name).is_some(),
199 ObjectKind::ServiceGroup => self.service_groups.remove(name).is_some(),
200 ObjectKind::Application => self.applications.remove(name).is_some(),
201 };
202
203 if removed {
204 Ok(())
205 } else {
206 Err(not_found(kind_name(kind), name))
207 }
208 }
209
210 pub fn get(&self, kind: ObjectKind, name: &str) -> Result<ObjectRecord> {
211 let cloned = match kind {
212 ObjectKind::Network => self
213 .networks
214 .get(name)
215 .cloned()
216 .map(ObjectRecord::Network)
217 .ok_or_else(|| not_found("network", name))?,
218 ObjectKind::NetworkGroup => self
219 .network_groups
220 .get(name)
221 .cloned()
222 .map(ObjectRecord::NetworkGroup)
223 .ok_or_else(|| not_found("network group", name))?,
224 ObjectKind::Service => self
225 .services
226 .get(name)
227 .cloned()
228 .map(ObjectRecord::Service)
229 .ok_or_else(|| not_found("service", name))?,
230 ObjectKind::ServiceGroup => self
231 .service_groups
232 .get(name)
233 .cloned()
234 .map(ObjectRecord::ServiceGroup)
235 .ok_or_else(|| not_found("service group", name))?,
236 ObjectKind::Application => self
237 .applications
238 .get(name)
239 .cloned()
240 .map(ObjectRecord::Application)
241 .ok_or_else(|| not_found("application", name))?,
242 };
243 Ok(cloned)
244 }
245
246 #[cfg(feature = "serde")]
247 pub fn to_json(&self, kind: ObjectKind, name: &str) -> Result<String> {
263 let obj = self.get(kind, name)?;
264 serde_json::to_string_pretty(&obj).map_err(FirewallObjectError::from)
265 }
266
267 pub fn insert_network(&mut self, obj: NetworkObj) -> Result<()> {
269 self.create(ObjectRecord::Network(obj))
270 }
271
272 pub fn insert_network_group(&mut self, group: NetworkObjGroup) -> Result<()> {
274 self.create(ObjectRecord::NetworkGroup(group))
275 }
276
277 pub fn insert_service(&mut self, obj: ServiceObj) -> Result<()> {
279 self.create(ObjectRecord::Service(obj))
280 }
281
282 pub fn insert_service_group(&mut self, group: ServiceObjGroup) -> Result<()> {
284 self.create(ObjectRecord::ServiceGroup(group))
285 }
286
287 pub fn insert_application(&mut self, app: ApplicationObj) -> Result<()> {
289 self.create(ObjectRecord::Application(app))
290 }
291
292 pub fn network(&self, name: &str) -> Result<NetworkObj> {
294 match self.get(ObjectKind::Network, name)? {
295 ObjectRecord::Network(obj) => Ok(obj),
296 _ => unreachable!(),
297 }
298 }
299
300 pub fn network_group(&self, name: &str) -> Result<NetworkObjGroup> {
302 match self.get(ObjectKind::NetworkGroup, name)? {
303 ObjectRecord::NetworkGroup(obj) => Ok(obj),
304 _ => unreachable!(),
305 }
306 }
307
308 pub fn service(&self, name: &str) -> Result<ServiceObj> {
310 match self.get(ObjectKind::Service, name)? {
311 ObjectRecord::Service(obj) => Ok(obj),
312 _ => unreachable!(),
313 }
314 }
315
316 pub fn service_group(&self, name: &str) -> Result<ServiceObjGroup> {
318 match self.get(ObjectKind::ServiceGroup, name)? {
319 ObjectRecord::ServiceGroup(obj) => Ok(obj),
320 _ => unreachable!(),
321 }
322 }
323
324 pub fn application(&self, name: &str) -> Result<ApplicationObj> {
326 match self.get(ObjectKind::Application, name)? {
327 ObjectRecord::Application(obj) => Ok(obj),
328 _ => unreachable!(),
329 }
330 }
331}
332
333fn not_found(entity: &str, name: &str) -> FirewallObjectError {
334 FirewallObjectError::Message(format!("{entity} '{name}' not found"))
335}
336
337fn kind_name(kind: ObjectKind) -> &'static str {
338 match kind {
339 ObjectKind::Network => "network",
340 ObjectKind::NetworkGroup => "network group",
341 ObjectKind::Service => "service",
342 ObjectKind::ServiceGroup => "service group",
343 ObjectKind::Application => "application",
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service::{ApplicationMatchInput, ApplicationObj, TransportService};
    use std::collections::BTreeSet;

    #[cfg(feature = "serde")]
    use serde_json;

    #[test]
    fn create_and_get_objects() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("db1", "192.0.2.10")).unwrap(),
            ))
            .unwrap();
        store
            .create(ObjectRecord::Service(ServiceObj::new(
                "web".into(),
                TransportService::tcp(443),
            )))
            .unwrap();

        assert!(store.get(ObjectKind::Network, "db1").is_ok());
        assert!(store.get(ObjectKind::Service, "web").is_ok());
    }

    /// A second create with an existing (kind, name) must fail and keep the original object.
    #[test]
    fn create_rejects_duplicate_names() {
        let mut store = ObjectStore::new();
        let original = NetworkObj::try_from(("db1", "192.0.2.10")).unwrap();
        store.insert_network(original.clone()).unwrap();

        let err = store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("db1", "192.0.2.11")).unwrap(),
            ))
            .unwrap_err();
        assert!(format!("{err}").contains("already exists"));
        assert_eq!(store.network("db1").unwrap(), original);

        // Names are scoped per kind, so a service may reuse the network's name.
        store
            .insert_service(ServiceObj::new("db1".into(), TransportService::tcp(5432)))
            .unwrap();
    }

    #[test]
    fn update_requires_existing() {
        let mut store = ObjectStore::new();
        let err = store
            .update(ObjectRecord::Service(ServiceObj::new(
                "dns".into(),
                TransportService::udp(53),
            )))
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    #[test]
    fn update_replaces_existing() {
        let mut store = ObjectStore::new();
        store
            .insert_service(ServiceObj::new("web".into(), TransportService::tcp(443)))
            .unwrap();

        let replacement = ServiceObj::new("web".into(), TransportService::tcp(8443));
        store
            .update(ObjectRecord::Service(replacement.clone()))
            .unwrap();
        assert_eq!(store.service("web").unwrap(), replacement);
    }

    #[test]
    fn delete_objects() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("api", "192.0.2.5")).unwrap(),
            ))
            .unwrap();
        store.delete(ObjectKind::Network, "api").unwrap();
        assert!(store.get(ObjectKind::Network, "api").is_err());

        let mut members = BTreeSet::new();
        members.insert(NetworkObj::try_from(("app", "192.0.2.6")).unwrap());
        let group = NetworkObjGroup::new("tier1", members).unwrap();
        store.create(ObjectRecord::NetworkGroup(group)).unwrap();
        store.delete(ObjectKind::NetworkGroup, "tier1").unwrap();
        assert!(store.get(ObjectKind::NetworkGroup, "tier1").is_err());
    }

    #[test]
    fn delete_missing_reports_not_found() {
        let mut store = ObjectStore::new();
        let err = store
            .delete(ObjectKind::ServiceGroup, "ghost")
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    #[test]
    fn applications_supported() {
        let mut store = ObjectStore::new();
        let app = ApplicationObj {
            name: "metrics-ui".into(),
            category: "internal".into(),
            transports: vec![TransportService::tcp(443)],
            dns_suffixes: vec![".corp.local".into()],
            tls_sni_suffixes: vec![],
            http_hosts: vec!["metrics.corp.local".into()],
        };

        store.create(ObjectRecord::Application(app)).unwrap();

        let fetched = store.get(ObjectKind::Application, "metrics-ui").unwrap();
        if let ObjectRecord::Application(obj) = fetched {
            assert!(obj.matches(&ApplicationMatchInput {
                http_host: Some("metrics.corp.local"),
                ..Default::default()
            }));
        } else {
            panic!("expected application");
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn json_round_trip() {
        let mut store = ObjectStore::new();
        let payload = serde_json::to_string(&ObjectRecord::Application(ApplicationObj {
            name: "docs".into(),
            category: "external".into(),
            transports: vec![TransportService::tcp(443)],
            dns_suffixes: vec![".example.com".into()],
            tls_sni_suffixes: vec![".example.com".into()],
            http_hosts: vec!["docs.example.com".into()],
        }))
        .unwrap();

        store.create_from_json(&payload).unwrap();
        let json = store.to_json(ObjectKind::Application, "docs").unwrap();
        assert!(json.contains("docs"));
    }

    #[test]
    fn helper_methods_insert_and_get() {
        let mut store = ObjectStore::new();
        let network = NetworkObj::try_from(("app", "192.0.2.20")).unwrap();
        store.insert_network(network.clone()).unwrap();
        assert_eq!(store.network("app").unwrap(), network);

        let service = ServiceObj::new("web".into(), TransportService::tcp(443));
        store.insert_service(service.clone()).unwrap();
        assert_eq!(store.service("web").unwrap(), service);

        let app = ApplicationObj {
            name: "helper-app".into(),
            category: "misc".into(),
            transports: vec![TransportService::udp(6000)],
            dns_suffixes: vec![".helpers.local".into()],
            tls_sni_suffixes: vec![],
            http_hosts: vec![],
        };
        store.insert_application(app).unwrap();
        assert!(
            store
                .application("helper-app")
                .unwrap()
                .matches(&ApplicationMatchInput {
                    dns_query: Some("svc.helpers.local"),
                    ..Default::default()
                })
        );
    }
}
481}