1use crate::error::{FirewallObjectError, Result};
4use crate::ip::network::{NetworkObj, NetworkObjGroup};
5use crate::service::{ApplicationObj, ServiceObj, ServiceObjGroup};
6use std::collections::BTreeMap;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11#[cfg(feature = "serde")]
12use serde_json;
13
14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16#[cfg_attr(feature = "serde", serde(tag = "kind", rename_all = "snake_case"))]
17#[derive(Debug, Clone)]
18pub enum ObjectRecord {
19 Network(NetworkObj),
20 NetworkGroup(NetworkObjGroup),
21 Service(ServiceObj),
22 ServiceGroup(ServiceObjGroup),
23 Application(ApplicationObj),
24}
25
/// The kind of a stored object, without the object itself.
///
/// Paired with a name, it addresses a single entry in an [`ObjectStore`] (see
/// [`ObjectStore::get`] and [`ObjectStore::delete`]). With the `serde` feature enabled, kinds
/// serialize as snake_case strings such as `"network_group"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum ObjectKind {
    /// Addresses [`NetworkObj`] entries.
    Network,
    /// Addresses [`NetworkObjGroup`] entries.
    NetworkGroup,
    /// Addresses [`ServiceObj`] entries.
    Service,
    /// Addresses [`ServiceObjGroup`] entries.
    ServiceGroup,
    /// Addresses [`ApplicationObj`] entries.
    Application,
}
36
37#[derive(Debug, Default)]
39pub struct ObjectStore {
40 networks: BTreeMap<String, NetworkObj>,
41 network_groups: BTreeMap<String, NetworkObjGroup>,
42 services: BTreeMap<String, ServiceObj>,
43 service_groups: BTreeMap<String, ServiceObjGroup>,
44 applications: BTreeMap<String, ApplicationObj>,
45}
46
47impl ObjectStore {
48 pub fn new() -> Self {
49 Self::default()
50 }
51
52 #[cfg(feature = "serde")]
53 pub fn create_from_json(&mut self, payload: &str) -> Result<()> {
68 let obj: ObjectRecord = serde_json::from_str(payload)?;
69 self.create(obj)
70 }
71
72 #[cfg(feature = "serde")]
73 pub fn update_from_json(&mut self, payload: &str) -> Result<()> {
74 let obj: ObjectRecord = serde_json::from_str(payload)?;
75 self.update(obj)
76 }
77
78 pub fn create(&mut self, obj: ObjectRecord) -> Result<()> {
79 match obj {
80 ObjectRecord::Network(net) => {
81 if self.networks.contains_key(&net.name) {
82 return Err(FirewallObjectError::Message(format!(
83 "network '{}' already exists",
84 net.name
85 )));
86 }
87 self.networks.insert(net.name.clone(), net);
88 }
89 ObjectRecord::NetworkGroup(group) => {
90 if self.network_groups.contains_key(&group.name) {
91 return Err(FirewallObjectError::Message(format!(
92 "network group '{}' already exists",
93 group.name
94 )));
95 }
96 self.network_groups.insert(group.name.clone(), group);
97 }
98 ObjectRecord::Service(service) => {
99 if self.services.contains_key(&service.name) {
100 return Err(FirewallObjectError::Message(format!(
101 "service '{}' already exists",
102 service.name
103 )));
104 }
105 self.services.insert(service.name.clone(), service);
106 }
107 ObjectRecord::ServiceGroup(group) => {
108 if self.service_groups.contains_key(&group.name) {
109 return Err(FirewallObjectError::Message(format!(
110 "service group '{}' already exists",
111 group.name
112 )));
113 }
114 self.service_groups.insert(group.name.clone(), group);
115 }
116 ObjectRecord::Application(app) => {
117 if self.applications.contains_key(&app.name) {
118 return Err(FirewallObjectError::Message(format!(
119 "application '{}' already exists",
120 app.name
121 )));
122 }
123 self.applications.insert(app.name.clone(), app);
124 }
125 }
126 Ok(())
127 }
128
129 pub fn update(&mut self, obj: ObjectRecord) -> Result<()> {
130 match obj {
131 ObjectRecord::Network(net) => match self.networks.get_mut(&net.name) {
132 Some(existing) => {
133 *existing = net;
134 }
135 None => return Err(not_found("network", &net.name)),
136 },
137 ObjectRecord::NetworkGroup(group) => match self.network_groups.get_mut(&group.name) {
138 Some(existing) => {
139 *existing = group;
140 }
141 None => return Err(not_found("network group", &group.name)),
142 },
143 ObjectRecord::Service(service) => match self.services.get_mut(&service.name) {
144 Some(existing) => {
145 *existing = service;
146 }
147 None => return Err(not_found("service", &service.name)),
148 },
149 ObjectRecord::ServiceGroup(group) => match self.service_groups.get_mut(&group.name) {
150 Some(existing) => {
151 *existing = group;
152 }
153 None => return Err(not_found("service group", &group.name)),
154 },
155 ObjectRecord::Application(app) => match self.applications.get_mut(&app.name) {
156 Some(existing) => {
157 *existing = app;
158 }
159 None => return Err(not_found("application", &app.name)),
160 },
161 }
162 Ok(())
163 }
164
165 pub fn delete(&mut self, kind: ObjectKind, name: &str) -> Result<()> {
166 let removed = match kind {
167 ObjectKind::Network => self.networks.remove(name).is_some(),
168 ObjectKind::NetworkGroup => self.network_groups.remove(name).is_some(),
169 ObjectKind::Service => self.services.remove(name).is_some(),
170 ObjectKind::ServiceGroup => self.service_groups.remove(name).is_some(),
171 ObjectKind::Application => self.applications.remove(name).is_some(),
172 };
173
174 if removed {
175 Ok(())
176 } else {
177 Err(not_found(kind_name(kind), name))
178 }
179 }
180
181 pub fn get(&self, kind: ObjectKind, name: &str) -> Result<ObjectRecord> {
182 let cloned = match kind {
183 ObjectKind::Network => self
184 .networks
185 .get(name)
186 .cloned()
187 .map(ObjectRecord::Network)
188 .ok_or_else(|| not_found("network", name))?,
189 ObjectKind::NetworkGroup => self
190 .network_groups
191 .get(name)
192 .cloned()
193 .map(ObjectRecord::NetworkGroup)
194 .ok_or_else(|| not_found("network group", name))?,
195 ObjectKind::Service => self
196 .services
197 .get(name)
198 .cloned()
199 .map(ObjectRecord::Service)
200 .ok_or_else(|| not_found("service", name))?,
201 ObjectKind::ServiceGroup => self
202 .service_groups
203 .get(name)
204 .cloned()
205 .map(ObjectRecord::ServiceGroup)
206 .ok_or_else(|| not_found("service group", name))?,
207 ObjectKind::Application => self
208 .applications
209 .get(name)
210 .cloned()
211 .map(ObjectRecord::Application)
212 .ok_or_else(|| not_found("application", name))?,
213 };
214 Ok(cloned)
215 }
216
217 #[cfg(feature = "serde")]
218 pub fn to_json(&self, kind: ObjectKind, name: &str) -> Result<String> {
234 let obj = self.get(kind, name)?;
235 serde_json::to_string_pretty(&obj).map_err(FirewallObjectError::from)
236 }
237
238 pub fn insert_network(&mut self, obj: NetworkObj) -> Result<()> {
240 self.create(ObjectRecord::Network(obj))
241 }
242
243 pub fn insert_network_group(&mut self, group: NetworkObjGroup) -> Result<()> {
245 self.create(ObjectRecord::NetworkGroup(group))
246 }
247
248 pub fn insert_service(&mut self, obj: ServiceObj) -> Result<()> {
250 self.create(ObjectRecord::Service(obj))
251 }
252
253 pub fn insert_service_group(&mut self, group: ServiceObjGroup) -> Result<()> {
255 self.create(ObjectRecord::ServiceGroup(group))
256 }
257
258 pub fn insert_application(&mut self, app: ApplicationObj) -> Result<()> {
260 self.create(ObjectRecord::Application(app))
261 }
262
263 pub fn network(&self, name: &str) -> Result<NetworkObj> {
265 match self.get(ObjectKind::Network, name)? {
266 ObjectRecord::Network(obj) => Ok(obj),
267 _ => unreachable!(),
268 }
269 }
270
271 pub fn network_group(&self, name: &str) -> Result<NetworkObjGroup> {
273 match self.get(ObjectKind::NetworkGroup, name)? {
274 ObjectRecord::NetworkGroup(obj) => Ok(obj),
275 _ => unreachable!(),
276 }
277 }
278
279 pub fn service(&self, name: &str) -> Result<ServiceObj> {
281 match self.get(ObjectKind::Service, name)? {
282 ObjectRecord::Service(obj) => Ok(obj),
283 _ => unreachable!(),
284 }
285 }
286
287 pub fn service_group(&self, name: &str) -> Result<ServiceObjGroup> {
289 match self.get(ObjectKind::ServiceGroup, name)? {
290 ObjectRecord::ServiceGroup(obj) => Ok(obj),
291 _ => unreachable!(),
292 }
293 }
294
295 pub fn application(&self, name: &str) -> Result<ApplicationObj> {
297 match self.get(ObjectKind::Application, name)? {
298 ObjectRecord::Application(obj) => Ok(obj),
299 _ => unreachable!(),
300 }
301 }
302}
303
304fn not_found(entity: &str, name: &str) -> FirewallObjectError {
305 FirewallObjectError::Message(format!("{entity} '{name}' not found"))
306}
307
308fn kind_name(kind: ObjectKind) -> &'static str {
309 match kind {
310 ObjectKind::Network => "network",
311 ObjectKind::NetworkGroup => "network group",
312 ObjectKind::Service => "service",
313 ObjectKind::ServiceGroup => "service group",
314 ObjectKind::Application => "application",
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service::{ApplicationMatchInput, ApplicationObj, TransportService};
    use std::collections::BTreeSet;

    #[cfg(feature = "serde")]
    use serde_json;

    #[test]
    fn create_and_get_objects() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("db1", "192.0.2.10")).unwrap(),
            ))
            .unwrap();
        store
            .create(ObjectRecord::Service(ServiceObj::new(
                "web".into(),
                TransportService::tcp(443),
            )))
            .unwrap();

        assert!(store.get(ObjectKind::Network, "db1").is_ok());
        assert!(store.get(ObjectKind::Service, "web").is_ok());
    }

    #[test]
    fn create_rejects_duplicate_names_within_a_kind() {
        let mut store = ObjectStore::new();
        let original = NetworkObj::try_from(("db1", "192.0.2.10")).unwrap();
        store.insert_network(original.clone()).unwrap();

        let err = store
            .insert_network(NetworkObj::try_from(("db1", "192.0.2.11")).unwrap())
            .unwrap_err();
        assert!(format!("{err}").contains("network 'db1' already exists"));
        // A rejected create must not clobber the stored object.
        assert_eq!(store.network("db1").unwrap(), original);
    }

    #[test]
    fn same_name_allowed_across_kinds() {
        let mut store = ObjectStore::new();
        store
            .insert_network(NetworkObj::try_from(("web", "192.0.2.30")).unwrap())
            .unwrap();
        store
            .insert_service(ServiceObj::new("web".into(), TransportService::tcp(443)))
            .unwrap();

        assert!(store.get(ObjectKind::Network, "web").is_ok());
        assert!(store.get(ObjectKind::Service, "web").is_ok());
    }

    #[test]
    fn update_requires_existing() {
        let mut store = ObjectStore::new();
        let err = store
            .update(ObjectRecord::Service(ServiceObj::new(
                "dns".into(),
                TransportService::udp(53),
            )))
            .unwrap_err();
        assert!(format!("{err}").contains("not found"));
    }

    #[test]
    fn update_replaces_existing() {
        let mut store = ObjectStore::new();
        store
            .insert_service(ServiceObj::new("web".into(), TransportService::tcp(443)))
            .unwrap();

        let replacement = ServiceObj::new("web".into(), TransportService::tcp(8443));
        store
            .update(ObjectRecord::Service(replacement.clone()))
            .unwrap();
        assert_eq!(store.service("web").unwrap(), replacement);
    }

    #[test]
    fn delete_objects() {
        let mut store = ObjectStore::new();
        store
            .create(ObjectRecord::Network(
                NetworkObj::try_from(("api", "192.0.2.5")).unwrap(),
            ))
            .unwrap();
        store.delete(ObjectKind::Network, "api").unwrap();
        assert!(store.get(ObjectKind::Network, "api").is_err());

        let mut members = BTreeSet::new();
        members.insert(NetworkObj::try_from(("app", "192.0.2.6")).unwrap());
        let group = NetworkObjGroup::new("tier1", members).unwrap();
        store.create(ObjectRecord::NetworkGroup(group)).unwrap();
        store.delete(ObjectKind::NetworkGroup, "tier1").unwrap();
    }

    #[test]
    fn delete_missing_reports_not_found() {
        let mut store = ObjectStore::new();
        let err = store
            .delete(ObjectKind::ServiceGroup, "nope")
            .unwrap_err();
        assert!(format!("{err}").contains("service group 'nope' not found"));
    }

    #[test]
    fn typed_getter_does_not_cross_kinds() {
        let mut store = ObjectStore::new();
        store
            .insert_service(ServiceObj::new("web".into(), TransportService::tcp(443)))
            .unwrap();

        let err = store.network("web").unwrap_err();
        assert!(format!("{err}").contains("network 'web' not found"));
    }

    #[test]
    fn applications_supported() {
        let mut store = ObjectStore::new();
        let app = ApplicationObj {
            name: "metrics-ui".into(),
            category: "internal".into(),
            transports: vec![TransportService::tcp(443)],
            dns_suffixes: vec![".corp.local".into()],
            tls_sni_suffixes: vec![],
            http_hosts: vec!["metrics.corp.local".into()],
        };

        store
            .create(ObjectRecord::Application(app.clone()))
            .unwrap();

        let fetched = store.get(ObjectKind::Application, "metrics-ui").unwrap();
        if let ObjectRecord::Application(obj) = fetched {
            assert!(obj.matches(&ApplicationMatchInput {
                http_host: Some("metrics.corp.local"),
                ..Default::default()
            }));
        } else {
            panic!("expected application");
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn json_round_trip() {
        let mut store = ObjectStore::new();
        let payload = serde_json::to_string(&ObjectRecord::Application(ApplicationObj {
            name: "docs".into(),
            category: "external".into(),
            transports: vec![TransportService::tcp(443)],
            dns_suffixes: vec![".example.com".into()],
            tls_sni_suffixes: vec![".example.com".into()],
            http_hosts: vec!["docs.example.com".into()],
        }))
        .unwrap();

        store.create_from_json(&payload).unwrap();
        let json = store.to_json(ObjectKind::Application, "docs").unwrap();
        assert!(json.contains("docs"));
    }

    #[test]
    fn helper_methods_insert_and_get() {
        let mut store = ObjectStore::new();
        let network = NetworkObj::try_from(("app", "192.0.2.20")).unwrap();
        store.insert_network(network.clone()).unwrap();
        assert_eq!(store.network("app").unwrap(), network);

        let service = ServiceObj::new("web".into(), TransportService::tcp(443));
        store.insert_service(service.clone()).unwrap();
        assert_eq!(store.service("web").unwrap(), service);

        let app = ApplicationObj {
            name: "helper-app".into(),
            category: "misc".into(),
            transports: vec![TransportService::udp(6000)],
            dns_suffixes: vec![".helpers.local".into()],
            tls_sni_suffixes: vec![],
            http_hosts: vec![],
        };
        store.insert_application(app.clone()).unwrap();
        assert!(
            store
                .application("helper-app")
                .unwrap()
                .matches(&ApplicationMatchInput {
                    dns_query: Some("svc.helpers.local"),
                    ..Default::default()
                })
        );
    }
}