use serde::{Deserialize, Serialize, Serializer};
use std::{fmt::Display, ops::RangeInclusive, str::FromStr};
pub use ipnet::Ipv4Net;
/// Identifier of a firewall template, a newtype around the raw `u32` value.
///
/// The inner value is public; conversions to and from `u32` are provided via `From`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TemplateId(pub u32);
impl From<u32> for TemplateId {
fn from(value: u32) -> Self {
TemplateId(value)
}
}
impl From<TemplateId> for u32 {
fn from(value: TemplateId) -> Self {
value.0
}
}
impl Display for TemplateId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl PartialEq<u32> for TemplateId {
fn eq(&self, other: &u32) -> bool {
self.0.eq(other)
}
}
/// Status value carried by [`Firewall`] and [`FirewallConfig`].
///
/// Each variant is (de)serialized as the exact string given in its `serde` rename.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum State {
/// Serialized as `"active"`.
#[serde(rename = "active")]
Active,
/// Serialized as `"in process"` (note the space).
#[serde(rename = "in process")]
InProcess,
/// Serialized as `"disabled"`.
#[serde(rename = "disabled")]
Disabled,
}
impl Display for State {
    /// Writes the same lowercase label that is used for the variant's serialized form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            State::Active => "active",
            State::InProcess => "in process",
            State::Disabled => "disabled",
        };
        f.write_str(label)
    }
}
/// Switch port selector stored in [`Firewall::port`].
///
/// Variants are (de)serialized in lowercase (`"main"`, `"kvm"`).
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SwitchPort {
/// The default variant.
#[default]
Main,
/// NOTE(review): presumably the KVM console port — confirm against the API documentation.
Kvm,
}
/// Protocol a firewall filter can be restricted to.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum Protocol {
    /// TCP, optionally carrying a flags string.
    Tcp {
        /// Flags string, if any; stored verbatim and not validated here.
        flags: Option<String>,
    },
    Udp,
    Gre,
    Icmp,
    Ipip,
    Ah,
    Esp,
}

impl Protocol {
    /// Builds a [`Protocol::Tcp`] whose `flags` is `Some` copy of the given string.
    pub fn tcp_with_flags(flags: &str) -> Self {
        let flags = Some(String::from(flags));
        Protocol::Tcp { flags }
    }

    /// Returns a clone of the TCP flags, or `None` for TCP without flags and for
    /// every non-TCP protocol.
    pub(crate) fn flags(&self) -> Option<String> {
        match self {
            Protocol::Tcp { flags } => flags.as_ref().cloned(),
            Protocol::Udp
            | Protocol::Gre
            | Protocol::Icmp
            | Protocol::Ipip
            | Protocol::Ah
            | Protocol::Esp => None,
        }
    }
}
/// What a [`Rule`] does with matching traffic.
///
/// Variants are (de)serialized in lowercase (`"accept"`, `"discard"`).
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Action {
/// The default action.
#[default]
Accept,
Discard,
}
impl AsRef<str> for Action {
    /// Returns the lowercase label of the action (`"accept"` or `"discard"`).
    fn as_ref(&self) -> &str {
        let label: &'static str = match *self {
            Action::Accept => "accept",
            Action::Discard => "discard",
        };
        label
    }
}
impl Display for Action {
    /// Writes the label produced by the `AsRef<str>` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = self.as_ref();
        f.write_str(label)
    }
}
/// Firewall template metadata without its rules.
///
/// Carries the same descriptive fields as [`FirewallTemplate`], minus `rules`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FirewallTemplateReference {
pub id: TemplateId,
pub name: String,
pub filter_ipv6: bool,
// The serialized field name is `whitelist_hos`; the Rust name spells it out.
#[serde(rename = "whitelist_hos")]
pub whitelist_hetzner_services: bool,
pub is_default: bool,
}
/// A firewall template: identifier, descriptive settings and the rule set it contains.
#[derive(Debug, Clone)]
pub struct FirewallTemplate {
pub id: TemplateId,
pub name: String,
pub filter_ipv6: bool,
pub whitelist_hetzner_services: bool,
pub is_default: bool,
pub rules: Rules,
}
/// Settings and rules of a firewall template, without an identifier.
///
/// This is the type produced by [`FirewallConfig::to_template_config`].
#[derive(Debug, Clone)]
pub struct FirewallTemplateConfig {
pub name: String,
pub filter_ipv6: bool,
pub whitelist_hetzner_services: bool,
pub is_default: bool,
pub rules: Rules,
}
#[derive(Debug, Clone)]
pub struct Firewall {
pub status: State,
pub filter_ipv6: bool,
pub whitelist_hetzner_services: bool,
pub port: SwitchPort,
pub rules: Rules,
}
impl Firewall {
pub fn config(&self) -> FirewallConfig {
self.into()
}
}
/// The subset of a [`Firewall`] that makes up its configuration: status, settings and rules.
#[derive(Debug)]
pub struct FirewallConfig {
    pub status: State,
    pub filter_ipv6: bool,
    pub whitelist_hetzner_services: bool,
    pub rules: Rules,
}

impl FirewallConfig {
    /// Builds a [`FirewallTemplateConfig`] named `name` from this configuration.
    ///
    /// The settings are copied and the rules are cloned; `is_default` is always `false`
    /// and `status` is not carried over.
    #[must_use = "This doesn't create the template, only produces a config which you can then upload with AsyncRobot::create_firewall_template"]
    pub fn to_template_config(&self, name: &str) -> FirewallTemplateConfig {
        let FirewallConfig {
            filter_ipv6,
            whitelist_hetzner_services,
            rules,
            ..
        } = self;

        FirewallTemplateConfig {
            name: name.to_owned(),
            filter_ipv6: *filter_ipv6,
            whitelist_hetzner_services: *whitelist_hetzner_services,
            is_default: false,
            rules: rules.clone(),
        }
    }
}
impl From<&Firewall> for FirewallConfig {
fn from(value: &Firewall) -> Self {
FirewallConfig {
status: value.status,
filter_ipv6: value.filter_ipv6,
whitelist_hetzner_services: value.whitelist_hetzner_services,
rules: value.rules.clone(),
}
}
}
/// Rule set of a firewall or template, split into ingress and egress rule lists.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rules {
pub ingress: Vec<Rule>,
pub egress: Vec<Rule>,
}
/// An inclusive range of ports; a single port is a range whose start equals its end.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PortRange(RangeInclusive<u16>);

impl PortRange {
    /// Range containing exactly one port.
    pub fn port(port: u16) -> Self {
        Self(port..=port)
    }

    /// Inclusive range from `start` to `end`. The bounds are stored as given;
    /// no ordering check is performed.
    pub fn range(start: u16, end: u16) -> Self {
        Self(start..=end)
    }

    /// Lower bound of the range.
    pub fn start(&self) -> u16 {
        let Self(span) = self;
        *span.start()
    }

    /// Upper (inclusive) bound of the range.
    pub fn end(&self) -> u16 {
        let Self(span) = self;
        *span.end()
    }
}
impl Display for PortRange {
    /// Writes `start` for a single-port range and `start-end` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (first, last) = (self.start(), self.end());
        if first == last {
            write!(f, "{}", first)
        } else {
            write!(f, "{}-{}", first, last)
        }
    }
}
impl From<u16> for PortRange {
fn from(value: u16) -> Self {
PortRange::port(value)
}
}
impl From<RangeInclusive<u16>> for PortRange {
fn from(value: RangeInclusive<u16>) -> Self {
PortRange(value)
}
}
impl From<&RangeInclusive<u16>> for PortRange {
fn from(value: &RangeInclusive<u16>) -> Self {
PortRange(value.clone())
}
}
impl From<PortRange> for RangeInclusive<u16> {
fn from(value: PortRange) -> Self {
value.0
}
}
impl From<&PortRange> for RangeInclusive<u16> {
fn from(value: &PortRange) -> Self {
value.0.clone()
}
}
impl From<&PortRange> for Vec<PortRange> {
    /// Produces a one-element vector holding a clone of the range.
    fn from(ports: &PortRange) -> Self {
        Vec::from([ports.clone()])
    }
}
impl IntoIterator for PortRange {
type Item = u16;
type IntoIter = <RangeInclusive<u16> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0
}
}
/// Error produced when part of a port string cannot be parsed as a `u16`.
///
/// Field 0 is the text that failed to parse, field 1 the underlying integer parse error.
#[derive(Debug, thiserror::Error)]
#[error("invalid port '{0}': {1}")]
pub struct InvalidPort(String, <u16 as FromStr>::Err);
impl FromStr for PortRange {
type Err = InvalidPort;
fn from_str(value: &str) -> Result<Self, Self::Err> {
if let Some((start, end)) = value.split_once('-') {
let start = start
.parse::<u16>()
.map_err(|err| InvalidPort(start.to_string(), err))?;
let end = end
.parse::<u16>()
.map_err(|err| InvalidPort(end.to_string(), err))?;
Ok(PortRange(RangeInclusive::new(start, end)))
} else {
let port = value
.parse::<u16>()
.map_err(|err| InvalidPort(value.to_string(), err))?;
Ok(PortRange(RangeInclusive::new(port, port)))
}
}
}
impl<'de> Deserialize<'de> for PortRange {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value: &str = Deserialize::deserialize(deserializer)?;
PortRange::from_str(value).map_err(D::Error::custom)
}
}
impl Serialize for PortRange {
    /// Serializes the range as a string: `start` for a single port, `start-end` otherwise.
    ///
    /// This is exactly the text produced by the `Display` implementation, so the
    /// serializer is asked to collect that output directly instead of re-implementing
    /// the formatting here. The wire format and `Display` therefore cannot drift apart.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.collect_str(self)
    }
}
/// Match criteria of a [`Rule`], in one of three flavours.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Filter {
/// Port-only criteria (see [`AnyFilter`]); this is the default flavour.
Any(AnyFilter),
/// Criteria with IPv4 addresses, ports and protocol (see [`Ipv4Filter`]).
Ipv4(Ipv4Filter),
/// Criteria with ports and protocol (see [`Ipv6Filter`]).
Ipv6(Ipv6Filter),
}
impl Default for Filter {
    /// The default filter is [`Filter::Any`] wrapping a default [`AnyFilter`].
    fn default() -> Self {
        let unrestricted = AnyFilter::default();
        Self::Any(unrestricted)
    }
}
impl From<Ipv4Filter> for Filter {
fn from(value: Ipv4Filter) -> Self {
Filter::Ipv4(value)
}
}
impl From<Ipv6Filter> for Filter {
fn from(value: Ipv6Filter) -> Self {
Filter::Ipv6(value)
}
}
/// Port-only match criteria, used by [`Filter::Any`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AnyFilter {
    /// Destination port ranges.
    pub dst_port: Vec<PortRange>,
    /// Source port ranges.
    pub src_port: Vec<PortRange>,
}

impl AnyFilter {
    /// Adds a source port (or port range) to the filter and returns it for chaining.
    ///
    /// The range is appended to `src_port`, matching the behaviour of the builders of
    /// the same name on `Ipv4Filter` and `Ipv6Filter`; calling this repeatedly
    /// accumulates ranges instead of discarding the ones added earlier.
    pub fn from_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
        self.src_port.push(range.into());
        self
    }

    /// Adds a destination port (or port range) to the filter and returns it for chaining.
    ///
    /// The range is appended to `dst_port`, matching the behaviour of the builders of
    /// the same name on `Ipv4Filter` and `Ipv6Filter`; calling this repeatedly
    /// accumulates ranges instead of discarding the ones added earlier.
    pub fn to_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
        self.dst_port.push(range.into());
        self
    }
}
/// Match criteria used by [`Filter::Ipv6`]: an optional protocol plus port ranges.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Ipv6Filter {
    pub protocol: Option<Protocol>,
    pub dst_port: Vec<PortRange>,
    pub src_port: Vec<PortRange>,
}

impl Ipv6Filter {
    /// Filter with no protocol and no port restrictions.
    pub fn any() -> Self {
        Self::default()
    }

    /// Filter restricted to [`Protocol::Ah`], without port restrictions.
    pub fn ah() -> Self {
        Self {
            protocol: Some(Protocol::Ah),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Esp`], without port restrictions.
    pub fn esp() -> Self {
        Self {
            protocol: Some(Protocol::Esp),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Ipip`], without port restrictions.
    pub fn ipip() -> Self {
        Self {
            protocol: Some(Protocol::Ipip),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Gre`], without port restrictions.
    pub fn gre() -> Self {
        Self {
            protocol: Some(Protocol::Gre),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Udp`], without port restrictions.
    pub fn udp() -> Self {
        Self {
            protocol: Some(Protocol::Udp),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Tcp`] with the given optional flags,
    /// without port restrictions.
    pub fn tcp(flags: Option<String>) -> Self {
        Self {
            protocol: Some(Protocol::Tcp { flags }),
            ..Self::default()
        }
    }

    /// Appends a source port (or port range) and returns the filter for chaining.
    pub fn from_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
        let source: PortRange = range.into();
        self.src_port.push(source);
        self
    }

    /// Appends a destination port (or port range) and returns the filter for chaining.
    pub fn to_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
        let destination: PortRange = range.into();
        self.dst_port.push(destination);
        self
    }
}
/// Match criteria used by [`Filter::Ipv4`]: optional source/destination networks,
/// port ranges and an optional protocol.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Ipv4Filter {
    pub dst_ip: Option<Ipv4Net>,
    pub src_ip: Option<Ipv4Net>,
    pub dst_port: Vec<PortRange>,
    pub src_port: Vec<PortRange>,
    pub protocol: Option<Protocol>,
}

impl Ipv4Filter {
    /// Filter with no protocol, address or port restrictions.
    pub fn any() -> Self {
        Self::default()
    }

    /// Filter restricted to [`Protocol::Ah`], with no other restrictions.
    pub fn ah() -> Self {
        Self {
            protocol: Some(Protocol::Ah),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Esp`], with no other restrictions.
    pub fn esp() -> Self {
        Self {
            protocol: Some(Protocol::Esp),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Ipip`], with no other restrictions.
    pub fn ipip() -> Self {
        Self {
            protocol: Some(Protocol::Ipip),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Gre`], with no other restrictions.
    pub fn gre() -> Self {
        Self {
            protocol: Some(Protocol::Gre),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Udp`], with no other restrictions.
    pub fn udp() -> Self {
        Self {
            protocol: Some(Protocol::Udp),
            ..Self::default()
        }
    }

    /// Filter restricted to [`Protocol::Tcp`] with the given optional flags,
    /// with no other restrictions.
    pub fn tcp(flags: Option<String>) -> Self {
        Self {
            protocol: Some(Protocol::Tcp { flags }),
            ..Self::default()
        }
    }

    /// Appends a source port (or port range) and returns the filter for chaining.
    pub fn from_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
        let source: PortRange = range.into();
        self.src_port.push(source);
        self
    }

    /// Appends a destination port (or port range) and returns the filter for chaining.
    pub fn to_port<IntoPortRange: Into<PortRange>>(mut self, range: IntoPortRange) -> Self {
        let destination: PortRange = range.into();
        self.dst_port.push(destination);
        self
    }

    /// Sets the source network, replacing any previous one, and returns the filter.
    pub fn from_ip<IntoIpNet: Into<Ipv4Net>>(self, ip: IntoIpNet) -> Self {
        Self {
            src_ip: Some(ip.into()),
            ..self
        }
    }

    /// Sets the destination network, replacing any previous one, and returns the filter.
    pub fn to_ip<IntoIpNet: Into<Ipv4Net>>(self, ip: IntoIpNet) -> Self {
        Self {
            dst_ip: Some(ip.into()),
            ..self
        }
    }
}
/// A single named firewall rule: what to match and what to do with it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rule {
    pub name: String,
    pub filter: Filter,
    pub action: Action,
}

impl Rule {
    /// Rule named `name` that accepts traffic, starting with the default filter.
    pub fn accept(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            filter: Filter::default(),
            action: Action::Accept,
        }
    }

    /// Rule named `name` that discards traffic, starting with the default filter.
    pub fn discard(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            filter: Filter::default(),
            action: Action::Discard,
        }
    }

    /// Replaces the rule's filter with `filter`, keeping its name and action.
    pub fn matching<F: Into<Filter>>(self, filter: F) -> Self {
        Self {
            filter: filter.into(),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
use std::{net::Ipv4Addr, ops::RangeInclusive};
use ipnet::Ipv4Net;
use crate::api::firewall::{
Filter, Ipv4Filter, Ipv6Filter, PortRange, Protocol, State, TemplateId,
};
use super::AnyFilter;
#[test]
fn template_conversions() {
    // u32 -> TemplateId -> u32 must give back the original number.
    let id = TemplateId::from(1337u32);
    let raw = u32::from(id);
    assert_eq!(raw, 1337);
}
#[test]
fn template_id_equality() {
    // Exercises the `PartialEq<u32>` implementation.
    let id = TemplateId(1337);
    let raw = 1337u32;
    assert_eq!(id, raw);
}
#[test]
fn state_display() {
    // Each state must render as its lowercase label.
    let expectations = vec![
        (State::Active, "active"),
        (State::InProcess, "in process"),
        (State::Disabled, "disabled"),
    ];
    for (state, label) in expectations {
        assert_eq!(state.to_string(), label);
    }
}
#[test]
fn protocol_construction() {
    let with_flags = Protocol::tcp_with_flags("ack");
    let expected = Protocol::Tcp {
        flags: Some(String::from("ack")),
    };
    assert_eq!(with_flags, expected);

    // `flags()` reports the flags only when they were supplied.
    assert!(with_flags.flags().is_some());
    let without_flags = Protocol::Tcp { flags: None };
    assert!(without_flags.flags().is_none());
}
#[test]
fn range_conversion() {
    let span: RangeInclusive<u16> = 1000..=1005;

    // Owned and borrowed ranges convert to the same PortRange.
    assert_eq!(PortRange::from(span.clone()), PortRange::from(&span));
    // A one-element range equals the single-port conversion.
    assert_eq!(PortRange::from(1000..=1000), PortRange::from(1000));

    // Converting back, by value and by reference, restores the range.
    let by_value = RangeInclusive::from(PortRange::from(span.clone()));
    assert_eq!(by_value, span);
    let by_ref = RangeInclusive::from(&PortRange::from(span.clone()));
    assert_eq!(by_ref, span);
}
#[test]
fn range_iteration() {
    // Iteration yields every port, both bounds included.
    let ports: Vec<u16> = PortRange::from(100..=105).into_iter().collect();
    assert_eq!(ports, [100, 101, 102, 103, 104, 105]);
}
#[test]
fn ip_construction() {
    // `From` must pick the enum variant matching the filter type.
    let v6: Filter = Ipv6Filter::any().into();
    assert_eq!(v6, Filter::Ipv6(Ipv6Filter::any()));

    let v4: Filter = Ipv4Filter::any().into();
    assert_eq!(v4, Filter::Ipv4(Ipv4Filter::any()));
}
#[test]
fn anyfilter_construction() {
    let built = AnyFilter::default().from_port(100).to_port(200);
    let expected = AnyFilter {
        src_port: vec![PortRange::from(100)],
        dst_port: vec![PortRange::from(200)],
    };
    assert_eq!(built, expected);
}
#[test]
fn ipv6filter_construction() {
    // Expected shape of a filter that only restricts the protocol.
    let bare = |protocol: Option<Protocol>| Ipv6Filter {
        protocol,
        dst_port: Vec::new(),
        src_port: Vec::new(),
    };

    // Every constructor paired with the protocol it is expected to set.
    let cases = vec![
        (Ipv6Filter::any(), None),
        (Ipv6Filter::ah(), Some(Protocol::Ah)),
        (Ipv6Filter::esp(), Some(Protocol::Esp)),
        (Ipv6Filter::ipip(), Some(Protocol::Ipip)),
        (Ipv6Filter::gre(), Some(Protocol::Gre)),
        (Ipv6Filter::udp(), Some(Protocol::Udp)),
        (
            Ipv6Filter::tcp(None),
            Some(Protocol::Tcp { flags: None }),
        ),
    ];
    for (built, protocol) in cases {
        assert_eq!(built, bare(protocol));
    }

    // Port builders fill the matching port list.
    let with_ports = Ipv6Filter::any().from_port(100).to_port(200);
    let expected = Ipv6Filter {
        protocol: None,
        dst_port: vec![PortRange::from(200)],
        src_port: vec![PortRange::from(100)],
    };
    assert_eq!(with_ports, expected);
}
#[test]
fn ipv4filter_construction() {
    // Expected shape of a filter that only restricts the protocol.
    let bare = |protocol: Option<Protocol>| Ipv4Filter {
        protocol,
        dst_port: Vec::new(),
        src_port: Vec::new(),
        src_ip: None,
        dst_ip: None,
    };

    // Every constructor paired with the protocol it is expected to set.
    let cases = vec![
        (Ipv4Filter::any(), None),
        (Ipv4Filter::ah(), Some(Protocol::Ah)),
        (Ipv4Filter::esp(), Some(Protocol::Esp)),
        (Ipv4Filter::ipip(), Some(Protocol::Ipip)),
        (Ipv4Filter::gre(), Some(Protocol::Gre)),
        (Ipv4Filter::udp(), Some(Protocol::Udp)),
        (
            Ipv4Filter::tcp(None),
            Some(Protocol::Tcp { flags: None }),
        ),
    ];
    for (built, protocol) in cases {
        assert_eq!(built, bare(protocol));
    }

    // Port builders fill the matching port list.
    let with_ports = Ipv4Filter::any().from_port(100).to_port(200);
    let expected_ports = Ipv4Filter {
        protocol: None,
        dst_port: vec![PortRange::from(200)],
        src_port: vec![PortRange::from(100)],
        src_ip: None,
        dst_ip: None,
    };
    assert_eq!(with_ports, expected_ports);

    // Address builders fill the matching network field.
    let source = Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 0), 8).unwrap();
    let destination = Ipv4Net::new(Ipv4Addr::new(192, 168, 0, 0), 16).unwrap();
    let with_networks = Ipv4Filter::any().from_ip(source).to_ip(destination);
    let expected_networks = Ipv4Filter {
        protocol: None,
        dst_port: Vec::new(),
        src_port: Vec::new(),
        src_ip: Some(source),
        dst_ip: Some(destination),
    };
    assert_eq!(with_networks, expected_networks);
}
}