1use serde::{Deserialize, Serialize, Serializer};
2use std::{fmt::Display, ops::RangeInclusive, str::FromStr};
3
4pub use ipnet::Ipv4Net;
5
6#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
16pub struct TemplateId(pub u32);
17
18impl From<u32> for TemplateId {
19 fn from(value: u32) -> Self {
20 TemplateId(value)
21 }
22}
23
24impl From<TemplateId> for u32 {
25 fn from(value: TemplateId) -> Self {
26 value.0
27 }
28}
29
30impl Display for TemplateId {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 self.0.fmt(f)
33 }
34}
35
36impl PartialEq<u32> for TemplateId {
37 fn eq(&self, other: &u32) -> bool {
38 self.0.eq(other)
39 }
40}
41
42#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
44pub enum State {
45 #[serde(rename = "active")]
47 Active,
48
49 #[serde(rename = "in process")]
51 InProcess,
52
53 #[serde(rename = "disabled")]
55 Disabled,
56}
57
58impl Display for State {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(
61 f,
62 "{}",
63 match self {
64 State::Active => "active",
65 State::InProcess => "in process",
66 State::Disabled => "disabled",
67 }
68 )
69 }
70}
71
72#[derive(Default, Clone, Debug, Serialize, Deserialize)]
74#[serde(rename_all = "lowercase")]
75pub enum SwitchPort {
76 #[default]
77 Main,
79 Kvm,
81}
82
/// IP protocol matched by a firewall filter.
///
/// Only TCP carries extra data: an optional TCP flags expression.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum Protocol {
    /// TCP, optionally restricted by a flags expression (e.g. `"ack"`).
    Tcp {
        /// TCP flags to match, kept as an opaque string.
        flags: Option<String>,
    },

    /// UDP.
    Udp,

    /// GRE.
    Gre,

    /// ICMP.
    Icmp,

    /// IP-in-IP.
    Ipip,

    /// IPsec AH.
    Ah,

    /// IPsec ESP.
    Esp,
}

impl Protocol {
    /// Builds a [`Protocol::Tcp`] matching the given TCP flags expression.
    pub fn tcp_with_flags(flags: &str) -> Self {
        let flags = Some(String::from(flags));
        Protocol::Tcp { flags }
    }

    /// Returns an owned copy of the TCP flags.
    ///
    /// Yields `None` both for non-TCP protocols and for TCP without flags.
    pub(crate) fn flags(&self) -> Option<String> {
        if let Protocol::Tcp { flags } = self {
            flags.clone()
        } else {
            None
        }
    }
}
128
129#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)]
131#[serde(rename_all = "lowercase")]
132pub enum Action {
133 #[default]
135 Accept,
136
137 Discard,
139}
140
141impl AsRef<str> for Action {
142 fn as_ref(&self) -> &str {
143 match self {
144 Action::Accept => "accept",
145 Action::Discard => "discard",
146 }
147 }
148}
149
150impl Display for Action {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.write_str(self.as_ref())
153 }
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct FirewallTemplateReference {
162 pub id: TemplateId,
165
166 pub name: String,
168
169 pub filter_ipv6: bool,
171
172 #[serde(rename = "whitelist_hos")]
175 pub whitelist_hetzner_services: bool,
176
177 pub is_default: bool,
179}
180
181#[derive(Debug, Clone)]
183pub struct FirewallTemplate {
184 pub id: TemplateId,
186
187 pub name: String,
189
190 pub filter_ipv6: bool,
192
193 pub whitelist_hetzner_services: bool,
196
197 pub is_default: bool,
200
201 pub rules: Rules,
203}
204
205#[derive(Debug, Clone)]
207pub struct FirewallTemplateConfig {
208 pub name: String,
210
211 pub filter_ipv6: bool,
213
214 pub whitelist_hetzner_services: bool,
217
218 pub is_default: bool,
221
222 pub rules: Rules,
224}
225
226#[derive(Debug, Clone)]
232pub struct Firewall {
233 pub status: State,
235
236 pub filter_ipv6: bool,
238
239 pub whitelist_hetzner_services: bool,
242
243 pub port: SwitchPort,
245
246 pub rules: Rules,
248}
249
250impl Firewall {
251 pub fn config(&self) -> FirewallConfig {
253 self.into()
254 }
255}
256
257#[derive(Debug)]
259pub struct FirewallConfig {
260 pub status: State,
262
263 pub filter_ipv6: bool,
265
266 pub whitelist_hetzner_services: bool,
269
270 pub rules: Rules,
272}
273
274impl FirewallConfig {
275 #[must_use = "This doesn't create the template, only produces a config which you can then upload with AsyncRobot::create_firewall_template"]
277 pub fn to_template_config(&self, name: &str) -> FirewallTemplateConfig {
278 FirewallTemplateConfig {
279 name: name.to_string(),
280 filter_ipv6: self.filter_ipv6,
281 whitelist_hetzner_services: self.whitelist_hetzner_services,
282 is_default: false,
283 rules: self.rules.clone(),
284 }
285 }
286}
287
288impl From<&Firewall> for FirewallConfig {
289 fn from(value: &Firewall) -> Self {
290 FirewallConfig {
291 status: value.status,
292 filter_ipv6: value.filter_ipv6,
293 whitelist_hetzner_services: value.whitelist_hetzner_services,
294 rules: value.rules.clone(),
295 }
296 }
297}
298
299#[derive(Debug, Clone, PartialEq, Eq)]
301pub struct Rules {
302 pub ingress: Vec<Rule>,
304
305 pub egress: Vec<Rule>,
307}
308
/// Inclusive range of ports; a single port is a range whose ends are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PortRange(RangeInclusive<u16>);

impl PortRange {
    /// Range covering exactly `port`.
    pub fn port(port: u16) -> Self {
        Self::range(port, port)
    }

    /// Range covering `start..=end`.
    ///
    /// No ordering check is made: if `start > end` the range contains no ports.
    pub fn range(start: u16, end: u16) -> Self {
        PortRange(start..=end)
    }

    /// First port of the range.
    pub fn start(&self) -> u16 {
        let first = self.0.start();
        *first
    }

    /// Last port of the range (inclusive).
    pub fn end(&self) -> u16 {
        let last = self.0.end();
        *last
    }
}
350
351impl Display for PortRange {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353 write!(f, "{}", self.start())?;
354 if self.end() != self.start() {
355 write!(f, "-{}", self.end())?;
356 }
357
358 Ok(())
359 }
360}
361
362impl From<u16> for PortRange {
363 fn from(value: u16) -> Self {
364 PortRange::port(value)
365 }
366}
367
368impl From<RangeInclusive<u16>> for PortRange {
369 fn from(value: RangeInclusive<u16>) -> Self {
370 PortRange(value)
371 }
372}
373
374impl From<&RangeInclusive<u16>> for PortRange {
375 fn from(value: &RangeInclusive<u16>) -> Self {
376 PortRange(value.clone())
377 }
378}
379
380impl From<PortRange> for RangeInclusive<u16> {
381 fn from(value: PortRange) -> Self {
382 value.0
383 }
384}
385
386impl From<&PortRange> for RangeInclusive<u16> {
387 fn from(value: &PortRange) -> Self {
388 value.0.clone()
389 }
390}
391
392impl From<&PortRange> for Vec<PortRange> {
393 fn from(value: &PortRange) -> Self {
394 vec![value.clone()]
395 }
396}
397
398impl IntoIterator for PortRange {
399 type Item = u16;
400
401 type IntoIter = <RangeInclusive<u16> as IntoIterator>::IntoIter;
402
403 fn into_iter(self) -> Self::IntoIter {
404 self.0
405 }
406}
407
408#[derive(Debug, thiserror::Error)]
410#[error("invalid port '{0}': {1}")]
411pub struct InvalidPort(String, <u16 as FromStr>::Err);
412
413impl FromStr for PortRange {
414 type Err = InvalidPort;
415
416 fn from_str(value: &str) -> Result<Self, Self::Err> {
417 if let Some((start, end)) = value.split_once('-') {
418 let start = start
419 .parse::<u16>()
420 .map_err(|err| InvalidPort(start.to_string(), err))?;
421 let end = end
422 .parse::<u16>()
423 .map_err(|err| InvalidPort(end.to_string(), err))?;
424
425 Ok(PortRange(RangeInclusive::new(start, end)))
426 } else {
427 let port = value
428 .parse::<u16>()
429 .map_err(|err| InvalidPort(value.to_string(), err))?;
430
431 Ok(PortRange(RangeInclusive::new(port, port)))
432 }
433 }
434}
435
436impl<'de> Deserialize<'de> for PortRange {
437 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
438 where
439 D: serde::Deserializer<'de>,
440 {
441 use serde::de::Error;
442
443 let value: &str = Deserialize::deserialize(deserializer)?;
444
445 PortRange::from_str(value).map_err(D::Error::custom)
446 }
447}
448
449impl Serialize for PortRange {
450 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
451 where
452 S: Serializer,
453 {
454 if self.0.start() == self.0.end() {
455 serializer.serialize_str(&format!("{}", self.start()))
456 } else {
457 serializer.serialize_str(&format!("{}-{}", self.start(), self.end()))
458 }
459 }
460}
461
462#[derive(Debug, Clone, PartialEq, Eq)]
464pub enum Filter {
465 Any(AnyFilter),
467 Ipv4(Ipv4Filter),
469 Ipv6(Ipv6Filter),
471}
472
473impl Default for Filter {
474 fn default() -> Self {
475 Filter::Any(AnyFilter::default())
476 }
477}
478
479impl From<Ipv4Filter> for Filter {
480 fn from(value: Ipv4Filter) -> Self {
481 Filter::Ipv4(value)
482 }
483}
484
485impl From<Ipv6Filter> for Filter {
486 fn from(value: Ipv6Filter) -> Self {
487 Filter::Ipv6(value)
488 }
489}
490
491#[derive(Debug, Clone, Default, PartialEq, Eq)]
493pub struct AnyFilter {
494 pub dst_port: Vec<PortRange>,
496
497 pub src_port: Vec<PortRange>,
499}
500
501impl AnyFilter {
502 pub fn from_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
504 self.src_port = vec![range.into()];
505 self
506 }
507
508 pub fn to_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
510 self.dst_port = vec![range.into()];
511 self
512 }
513}
514
515#[derive(Debug, Clone, Default, PartialEq, Eq)]
517pub struct Ipv6Filter {
518 pub protocol: Option<Protocol>,
520
521 pub dst_port: Vec<PortRange>,
523
524 pub src_port: Vec<PortRange>,
526}
527
528impl Ipv6Filter {
529 pub fn any() -> Self {
531 Ipv6Filter {
532 protocol: None,
533 dst_port: Vec::new(),
534 src_port: Vec::new(),
535 }
536 }
537
538 pub fn ah() -> Self {
540 Ipv6Filter {
541 protocol: Some(Protocol::Ah),
542 dst_port: Vec::new(),
543 src_port: Vec::new(),
544 }
545 }
546
547 pub fn esp() -> Self {
549 Ipv6Filter {
550 protocol: Some(Protocol::Esp),
551 dst_port: Vec::new(),
552 src_port: Vec::new(),
553 }
554 }
555
556 pub fn ipip() -> Self {
558 Ipv6Filter {
559 protocol: Some(Protocol::Ipip),
560 dst_port: Vec::new(),
561 src_port: Vec::new(),
562 }
563 }
564
565 pub fn gre() -> Self {
567 Ipv6Filter {
568 protocol: Some(Protocol::Gre),
569 dst_port: Vec::new(),
570 src_port: Vec::new(),
571 }
572 }
573
574 pub fn udp() -> Self {
576 Ipv6Filter {
577 protocol: Some(Protocol::Udp),
578 dst_port: Vec::new(),
579 src_port: Vec::new(),
580 }
581 }
582
583 pub fn tcp(flags: Option<String>) -> Self {
585 Ipv6Filter {
586 protocol: Some(Protocol::Tcp { flags }),
587 dst_port: Vec::new(),
588 src_port: Vec::new(),
589 }
590 }
591
592 pub fn from_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
594 self.src_port.push(range.into());
595 self
596 }
597
598 pub fn to_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
600 self.dst_port.push(range.into());
601 self
602 }
603}
604
605#[derive(Debug, Clone, Default, PartialEq, Eq)]
607pub struct Ipv4Filter {
608 pub dst_ip: Option<Ipv4Net>,
610
611 pub src_ip: Option<Ipv4Net>,
616
617 pub dst_port: Vec<PortRange>,
619
620 pub src_port: Vec<PortRange>,
622
623 pub protocol: Option<Protocol>,
625}
626
627impl Ipv4Filter {
628 pub fn any() -> Self {
630 Ipv4Filter {
631 protocol: None,
632 dst_port: Vec::new(),
633 src_port: Vec::new(),
634 src_ip: None,
635 dst_ip: None,
636 }
637 }
638
639 pub fn ah() -> Self {
641 Ipv4Filter {
642 protocol: Some(Protocol::Ah),
643 dst_port: Vec::new(),
644 src_port: Vec::new(),
645 src_ip: None,
646 dst_ip: None,
647 }
648 }
649
650 pub fn esp() -> Self {
652 Ipv4Filter {
653 protocol: Some(Protocol::Esp),
654 dst_port: Vec::new(),
655 src_port: Vec::new(),
656 src_ip: None,
657 dst_ip: None,
658 }
659 }
660
661 pub fn ipip() -> Self {
663 Ipv4Filter {
664 protocol: Some(Protocol::Ipip),
665 dst_port: Vec::new(),
666 src_port: Vec::new(),
667 src_ip: None,
668 dst_ip: None,
669 }
670 }
671
672 pub fn gre() -> Self {
674 Ipv4Filter {
675 protocol: Some(Protocol::Gre),
676 dst_port: Vec::new(),
677 src_port: Vec::new(),
678 src_ip: None,
679 dst_ip: None,
680 }
681 }
682
683 pub fn udp() -> Self {
685 Ipv4Filter {
686 protocol: Some(Protocol::Udp),
687 dst_port: Vec::new(),
688 src_port: Vec::new(),
689 src_ip: None,
690 dst_ip: None,
691 }
692 }
693
694 pub fn tcp(flags: Option<String>) -> Self {
696 Ipv4Filter {
697 protocol: Some(Protocol::Tcp { flags }),
698 dst_port: Vec::new(),
699 src_port: Vec::new(),
700 src_ip: None,
701 dst_ip: None,
702 }
703 }
704
705 pub fn from_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
707 self.src_port.push(range.into());
708 self
709 }
710
711 pub fn to_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
713 self.dst_port.push(range.into());
714 self
715 }
716
717 pub fn from_ip<IntoIpNet: Into<Ipv4Net>>(mut self, ip: IntoIpNet) -> Self {
719 self.src_ip = Some(ip.into());
720 self
721 }
722
723 pub fn to_ip<IntoIpNet: Into<Ipv4Net>>(mut self, ip: IntoIpNet) -> Self {
725 self.dst_ip = Some(ip.into());
726 self
727 }
728}
729
730#[derive(Debug, Clone, PartialEq, Eq)]
732pub struct Rule {
733 pub name: String,
735
736 pub filter: Filter,
738
739 pub action: Action,
741}
742
743impl Rule {
744 pub fn accept(name: &str) -> Self {
749 Rule {
750 name: name.to_string(),
751 filter: Filter::default(),
752 action: Action::Accept,
753 }
754 }
755
756 pub fn discard(name: &str) -> Self {
761 Rule {
762 name: name.to_string(),
763 filter: Filter::default(),
764 action: Action::Discard,
765 }
766 }
767
768 pub fn matching<F: Into<Filter>>(self, filter: F) -> Self {
770 Rule {
771 name: self.name,
772 action: self.action,
773 filter: filter.into(),
774 }
775 }
776}
777
#[cfg(test)]
mod tests {
    use std::{net::Ipv4Addr, ops::RangeInclusive};

    use ipnet::Ipv4Net;

    use crate::api::firewall::{
        Action, Filter, Ipv4Filter, Ipv6Filter, PortRange, Protocol, Rule, State, TemplateId,
    };

    use super::AnyFilter;

    #[test]
    fn template_conversions() {
        assert_eq!(u32::from(TemplateId::from(1337u32)), 1337)
    }

    #[test]
    fn template_id_equality() {
        assert_eq!(TemplateId(1337), 1337u32);
    }

    #[test]
    fn state_display() {
        assert_eq!(State::Active.to_string(), "active");
        assert_eq!(State::InProcess.to_string(), "in process");
        assert_eq!(State::Disabled.to_string(), "disabled");
    }

    #[test]
    fn action_display() {
        assert_eq!(Action::Accept.to_string(), "accept");
        assert_eq!(Action::Discard.to_string(), "discard");
        assert_eq!(Action::default(), Action::Accept);
    }

    #[test]
    fn protocol_construction() {
        assert_eq!(
            Protocol::tcp_with_flags("ack"),
            Protocol::Tcp {
                flags: Some("ack".to_string())
            }
        );

        assert!(Protocol::tcp_with_flags("ack").flags().is_some());
        assert!(Protocol::Tcp { flags: None }.flags().is_none());
    }

    #[test]
    fn range_conversion() {
        assert_eq!(
            PortRange::from(1000..=1005),
            PortRange::from(&(1000..=1005))
        );

        assert_eq!(PortRange::from(1000..=1000), PortRange::from(1000),);

        assert_eq!(
            RangeInclusive::from(PortRange::from(1000..=1005)),
            1000..=1005
        );

        assert_eq!(
            RangeInclusive::from(&(PortRange::from(1000..=1005))),
            1000..=1005
        );
    }

    #[test]
    fn range_iteration() {
        assert_eq!(
            PortRange::from(100..=105).into_iter().collect::<Vec<_>>(),
            vec![100, 101, 102, 103, 104, 105]
        );
    }

    #[test]
    fn range_parsing() {
        assert_eq!("80".parse::<PortRange>().unwrap(), PortRange::port(80));
        assert_eq!(
            "1000-1005".parse::<PortRange>().unwrap(),
            PortRange::range(1000, 1005)
        );

        assert!("http".parse::<PortRange>().is_err());
        assert!("1000-".parse::<PortRange>().is_err());
        assert!("-1000".parse::<PortRange>().is_err());
        assert!("70000".parse::<PortRange>().is_err());
    }

    #[test]
    fn range_display() {
        assert_eq!(PortRange::port(80).to_string(), "80");
        assert_eq!(PortRange::range(1000, 1005).to_string(), "1000-1005");

        // Display output must parse back to the same range.
        let range = PortRange::range(1000, 1005);
        assert_eq!(range.to_string().parse::<PortRange>().unwrap(), range);
    }

    #[test]
    fn ip_construction() {
        assert_eq!(
            Filter::from(Ipv6Filter::any()),
            Filter::Ipv6(Ipv6Filter::any())
        );

        assert_eq!(
            Filter::from(Ipv4Filter::any()),
            Filter::Ipv4(Ipv4Filter::any())
        );
    }

    #[test]
    fn anyfilter_construction() {
        assert_eq!(
            AnyFilter::default().from_port(100).to_port(200),
            AnyFilter {
                src_port: vec![PortRange::from(100)],
                dst_port: vec![PortRange::from(200)],
            }
        );
    }

    #[test]
    fn rule_construction() {
        let rule = Rule::discard("block ssh").matching(Ipv4Filter::tcp(None).to_port(22));
        assert_eq!(rule.name, "block ssh");
        assert_eq!(rule.action, Action::Discard);
        assert_eq!(
            rule.filter,
            Filter::Ipv4(Ipv4Filter::tcp(None).to_port(22))
        );

        let rule = Rule::accept("anything");
        assert_eq!(rule.name, "anything");
        assert_eq!(rule.action, Action::Accept);
        assert_eq!(rule.filter, Filter::Any(AnyFilter::default()));
    }

    #[test]
    fn ipv6filter_construction() {
        assert_eq!(
            Ipv6Filter::any(),
            Ipv6Filter {
                protocol: None,
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::ah(),
            Ipv6Filter {
                protocol: Some(Protocol::Ah),
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::esp(),
            Ipv6Filter {
                protocol: Some(Protocol::Esp),
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::ipip(),
            Ipv6Filter {
                protocol: Some(Protocol::Ipip),
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::gre(),
            Ipv6Filter {
                protocol: Some(Protocol::Gre),
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::udp(),
            Ipv6Filter {
                protocol: Some(Protocol::Udp),
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::tcp(None),
            Ipv6Filter {
                protocol: Some(Protocol::Tcp { flags: None }),
                dst_port: Vec::new(),
                src_port: Vec::new(),
            }
        );

        assert_eq!(
            Ipv6Filter::any().from_port(100).to_port(200),
            Ipv6Filter {
                protocol: None,
                dst_port: vec![PortRange::from(200)],
                src_port: vec![PortRange::from(100)]
            }
        )
    }

    #[test]
    fn ipv4filter_construction() {
        assert_eq!(
            Ipv4Filter::any(),
            Ipv4Filter {
                protocol: None,
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::ah(),
            Ipv4Filter {
                protocol: Some(Protocol::Ah),
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::esp(),
            Ipv4Filter {
                protocol: Some(Protocol::Esp),
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::ipip(),
            Ipv4Filter {
                protocol: Some(Protocol::Ipip),
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::gre(),
            Ipv4Filter {
                protocol: Some(Protocol::Gre),
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::udp(),
            Ipv4Filter {
                protocol: Some(Protocol::Udp),
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::tcp(None),
            Ipv4Filter {
                protocol: Some(Protocol::Tcp { flags: None }),
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::any().from_port(100).to_port(200),
            Ipv4Filter {
                protocol: None,
                dst_port: vec![PortRange::from(200)],
                src_port: vec![PortRange::from(100)],
                src_ip: None,
                dst_ip: None,
            }
        );

        assert_eq!(
            Ipv4Filter::any()
                .from_ip(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 0), 8).unwrap())
                .to_ip(Ipv4Net::new(Ipv4Addr::new(192, 168, 0, 0), 16).unwrap()),
            Ipv4Filter {
                protocol: None,
                dst_port: Vec::new(),
                src_port: Vec::new(),
                src_ip: Some(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 0), 8).unwrap()),
                dst_ip: Some(Ipv4Net::new(Ipv4Addr::new(192, 168, 0, 0), 16).unwrap()),
            }
        )
    }
}
1050}