use std::{
borrow::Cow,
cmp::Ordering,
fmt::{self, Display, Formatter, Write},
hash::{Hash, Hasher},
net::{AddrParseError, IpAddr},
num::ParseIntError,
ops::{Add, AddAssign, RangeInclusive, Sub, SubAssign},
str::FromStr,
};
use compose_spec_macros::{AsShort, DeserializeTryFromString, FromShort, SerializeDisplay};
use indexmap::IndexSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
use crate::{impl_from_str, serde::FromStrOrU16Visitor, Extensions, ShortOrLong};
use super::strip_brackets;
/// An ordered set of port mappings, each written either in the short form ([`ShortPort`])
/// or in the long form ([`Port`]).
pub type Ports = IndexSet<ShortOrLong<ShortPort, Port>>;
pub fn into_short_iter(ports: Ports) -> impl Iterator<Item = Result<ShortPort, Port>> {
ports.into_iter().map(|port| match port {
ShortOrLong::Short(port) => Ok(port),
ShortOrLong::Long(port) => port.into_short(),
})
}
/// Expand every entry of `ports` into long-form [`Port`]s.
///
/// A short entry expands to one [`Port`] per container port (see
/// [`ShortPort::into_long_iter()`]); a long entry is yielded as-is.
pub fn into_long_iter(ports: Ports) -> impl Iterator<Item = Port> {
    ports.into_iter().flat_map(|entry| {
        // Both arms must have one type, so the two different iterators are wrapped back into
        // `ShortOrLong` (this relies on `ShortOrLong` being iterable when both sides are).
        match entry {
            ShortOrLong::Long(long) => ShortOrLong::Long(std::iter::once(long)),
            ShortOrLong::Short(short) => ShortOrLong::Short(short.into_long_iter()),
        }
    })
}
/// Long syntax for a port mapping.
///
/// `PartialEq` and `Hash` are implemented by hand (instead of derived) so that `extensions`
/// are compared and hashed through `Extensions::as_slice()`.
#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
pub struct Port {
/// Name of the port mapping, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Container port.
pub target: u16,
/// Host port, or range of host ports, that `target` is published on.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub published: Option<Range>,
/// Host IP address of the mapping, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub host_ip: Option<IpAddr>,
/// Port protocol, e.g. `tcp` or `udp`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub protocol: Option<Protocol>,
/// Application protocol, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub app_protocol: Option<String>,
/// Port mode (`host` or `ingress`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mode: Option<Mode>,
/// Extension fields, flattened into the same mapping when (de)serializing
/// (presumably the `x-` prefixed keys — confirm against `Extensions`).
#[serde(flatten)]
pub extensions: Extensions,
}
impl PartialEq for Port {
    /// Field-by-field equality. `extensions` are compared through `as_slice()`, so (assuming
    /// `Extensions` is an ordered map) their order is significant.
    fn eq(&self, other: &Self) -> bool {
        // Exhaustive destructuring: adding a field to `Port` makes this fail to compile until
        // the new field is compared too.
        let Self {
            name: other_name,
            target: other_target,
            published: other_published,
            host_ip: other_host_ip,
            protocol: other_protocol,
            app_protocol: other_app_protocol,
            mode: other_mode,
            extensions: other_extensions,
        } = other;

        self.name == *other_name
            && self.target == *other_target
            && self.published == *other_published
            && self.host_ip == *other_host_ip
            && self.protocol == *other_protocol
            && self.app_protocol == *other_app_protocol
            && self.mode == *other_mode
            && self.extensions.as_slice() == other_extensions.as_slice()
    }
}
impl Hash for Port {
    /// Hash every field, consistently with the hand-written [`PartialEq`] impl
    /// (`extensions` are hashed through `as_slice()`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Exhaustive destructuring keeps this in sync with the fields of `Port`.
        let Self {
            name,
            target,
            published,
            host_ip,
            protocol,
            app_protocol,
            mode,
            extensions,
        } = self;

        // A tuple hashes its elements one after another, so this feeds the hasher the same
        // sequence as hashing each field individually.
        (name, target, published, host_ip, protocol, app_protocol, mode).hash(state);
        extensions.as_slice().hash(state);
    }
}
impl Port {
    /// Create a long-form port mapping for container port `target` with every other field
    /// unset.
    #[must_use]
    pub fn new(target: u16) -> Self {
        Self {
            target,
            name: None,
            published: None,
            host_ip: None,
            protocol: None,
            app_protocol: None,
            mode: None,
            extensions: Extensions::default(),
        }
    }

    /// Convert into the short syntax, if it can represent this mapping.
    ///
    /// # Errors
    ///
    /// Returns `self` unchanged if it sets a `name`, `app_protocol`, `mode` or any extensions,
    /// or if `published` is a range with an `end` (the short form maps a host port to the
    /// single `target` port only one-to-one).
    pub fn into_short(self) -> Result<ShortPort, Self> {
        let uses_long_only_fields = self.name.is_some()
            || self.app_protocol.is_some()
            || self.mode.is_some()
            || !self.extensions.is_empty();
        let published_is_single = self.published.map_or(true, |range| range.end.is_none());

        if uses_long_only_fields || !published_is_single {
            return Err(self);
        }

        Ok(ShortPort {
            host_ip: self.host_ip,
            ranges: ShortRanges {
                host: self.published,
                container: self.target.into(),
            },
            protocol: self.protocol,
        })
    }
}
impl From<u16> for Port {
fn from(target: u16) -> Self {
Self::new(target)
}
}
/// Port mode, (de)serialized as the lowercase strings `"host"` and `"ingress"`.
#[derive(Serialize, Deserialize, Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Mode {
/// `host` mode.
Host,
/// `ingress` mode; this is the [`Default`].
#[default]
Ingress,
}
impl Mode {
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::Host => "host",
Self::Ingress => "ingress",
}
}
}
impl AsRef<str> for Mode {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Display for Mode {
    /// Writes the same text as [`Mode::as_str()`].
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let text = self.as_str();
        f.write_str(text)
    }
}
/// Short syntax for a port mapping: `[[{host_ip}:][{host}]:]{container}[/{protocol}]`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShortPort {
/// Host IP address, if given.
pub host_ip: Option<IpAddr>,
/// Host and container port ranges.
pub ranges: ShortRanges,
/// Protocol, if given.
pub protocol: Option<Protocol>,
}
impl ShortPort {
#[must_use]
pub const fn new(ranges: ShortRanges) -> Self {
Self {
host_ip: None,
ranges,
protocol: None,
}
}
pub fn into_long_iter(self) -> impl Iterator<Item = Port> {
let Self {
host_ip,
ranges,
protocol,
} = self;
ranges.into_iter().map(move |(host, container)| Port {
published: host.map(Into::into),
host_ip,
protocol: protocol.clone(),
..container.into()
})
}
}
impl From<ShortRanges> for ShortPort {
fn from(ranges: ShortRanges) -> Self {
Self::new(ranges)
}
}
impl From<Range> for ShortPort {
    /// Wrap the container [`Range`] in [`ShortRanges`] and build a [`ShortPort`] from that.
    fn from(container: Range) -> Self {
        Self::from(ShortRanges::from(container))
    }
}
impl From<u16> for ShortPort {
    /// Build from a single container port, going through [`Range::from()`].
    fn from(container: u16) -> Self {
        Self::from(Range::from(container))
    }
}
impl FromStr for ShortPort {
type Err = ParseShortPortError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (mut s, protocol) = s
.split_once('/')
.map_or((s, None), |(s, protocol)| (s, Some(protocol.into())));
let mut colon_seen = false;
let host_ip = s
.rsplit_once(|char| {
if char == ':' {
if colon_seen {
return true;
}
colon_seen = true;
}
false
})
.map(|(host_ip, rest)| {
s = rest;
strip_brackets(host_ip)
.parse()
.map_err(|source| ParseShortPortError::IpAddr {
source,
value: host_ip.to_owned(),
})
})
.transpose()?;
Ok(Self {
host_ip,
ranges: s.parse()?,
protocol,
})
}
}
impl TryFrom<&str> for ShortPort {
type Error = ParseShortPortError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}
/// Error returned when parsing a [`ShortPort`] from a string fails.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ParseShortPortError {
/// The host IP address could not be parsed.
#[error("error parsing host ip address")]
IpAddr {
/// Underlying address parse error.
source: AddrParseError,
/// The raw host IP text from the input, before `strip_brackets` was applied.
value: String,
},
/// The port ranges could not be parsed.
#[error("error parsing port ranges")]
ShortRanges(#[from] ParseShortRangesError),
}
impl Display for ShortPort {
    /// Format as `[[{host_ip}:][{host}]:]{container}[/{protocol}]`.
    ///
    /// When a host IP is present but there is no host port, an empty host port is written
    /// (`{host_ip}::{container}`) so that the output parses back with the same host IP.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        if let Some(ip) = &self.host_ip {
            write!(f, "{ip}:")?;
            // Without this separator, the IP would not be recognised when parsing back.
            if self.ranges.host.is_none() {
                f.write_char(':')?;
            }
        }

        Display::fmt(&self.ranges, f)?;

        match &self.protocol {
            Some(protocol) => write!(f, "/{protocol}"),
            None => Ok(()),
        }
    }
}
impl Serialize for ShortPort {
    /// Serialize as a string when a host IP or protocol is present. Otherwise, defer to the
    /// [`ShortRanges`] serialization.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let needs_full_string = self.host_ip.is_some() || self.protocol.is_some();
        if !needs_full_string {
            return self.ranges.serialize(serializer);
        }
        serializer.collect_str(self)
    }
}
impl<'de> Deserialize<'de> for ShortPort {
    /// Deserialize from either an integer or a string in the short syntax.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = FromStrOrU16Visitor::new(
            "an integer or string in the format \
            \"[[{host_ip}:][{host}]:]{container}[/{protocol}]\"",
        );
        visitor.deserialize(deserializer)
    }
}
/// Host and container port ranges of a [`ShortPort`], in the format `[[{host}]:]{container}`.
///
/// [`ShortRanges::new()`] and the `replace_*` methods check that `host`, when present, has
/// the same size as `container`.
#[derive(AsShort, FromShort, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ShortRanges {
/// Host port range, if any.
host: Option<Range>,
/// Container port range.
// `as_short(short)` presumably marks this as the field used for the container-only short
// form by the `AsShort`/`FromShort` derives — confirm in `compose_spec_macros`.
#[as_short(short)]
container: Range,
}
impl ShortRanges {
    /// Create [`ShortRanges`] from an optional host range and a container range.
    ///
    /// # Errors
    ///
    /// Returns an error if `host` is given and its size differs from the size of `container`.
    pub fn new(host: Option<Range>, container: Range) -> Result<Self, ShortRangesError> {
        range_size_eq(host, container).map(|()| Self { host, container })
    }

    /// The host port range, if any.
    #[must_use]
    pub const fn host(&self) -> Option<Range> {
        self.host
    }

    /// Set the host port range, returning the previous one.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if the size of `host` differs from the
    /// size of the container range.
    pub fn replace_host(&mut self, host: Range) -> Result<Option<Range>, ShortRangesError> {
        range_size_eq(Some(host), self.container)?;
        let previous = self.host;
        self.host = Some(host);
        Ok(previous)
    }

    /// Remove and return the host port range. This always succeeds, because having no host
    /// range imposes no size constraint.
    pub fn take_host(&mut self) -> Option<Range> {
        std::mem::take(&mut self.host)
    }

    /// The container port range.
    #[must_use]
    pub const fn container(&self) -> Range {
        self.container
    }

    /// Set the container port range, returning the previous one.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if a host range is set and its size
    /// differs from the size of `container`.
    pub fn replace_container(&mut self, container: Range) -> Result<Range, ShortRangesError> {
        range_size_eq(self.host, container)?;
        let previous = self.container;
        self.container = container;
        Ok(previous)
    }
}
/// Check that `host`, if present, covers the same number of ports as `container`.
///
/// # Errors
///
/// Returns a [`ShortRangesError`] with both sizes if they differ.
fn range_size_eq(host: Option<Range>, container: Range) -> Result<(), ShortRangesError> {
    /// Distance from the first to the last port (`0` for a single port).
    ///
    /// Unlike `size()`, this always fits in a `u16`, even for the full range `0-65535`.
    /// It relies on `start <= end`, which `Range::new()` guarantees.
    fn span(range: Range) -> u16 {
        range.end.map_or(0, |end| end - range.start)
    }

    match host {
        // The sizes are only computed for the error report.
        Some(host) if span(host) != span(container) => Err(ShortRangesError {
            host_size: host.size(),
            container_size: container.size(),
        }),
        _ => Ok(()),
    }
}
/// Error returned when a host port range and a container port range differ in size.
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
#[error(
"host port range size `{host_size}` must be equal to \
container port range size `{container_size}`"
)]
pub struct ShortRangesError {
/// Number of ports in the host range, as reported by `Range::size()`.
host_size: u16,
/// Number of ports in the container range, as reported by `Range::size()`.
container_size: u16,
}
impl From<u16> for ShortRanges {
    /// Build from a single container port, going through [`Range::from()`].
    fn from(container: u16) -> Self {
        Self::from(Range::from(container))
    }
}
impl FromStr for ShortRanges {
type Err = ParseShortRangesError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Some((host, container)) = s.split_once(':') {
let host = if host.is_empty() {
None
} else {
Some(host.parse()?)
};
Ok(Self::new(host, container.parse()?)?)
} else {
Ok(Range::from_str(s)?.into())
}
}
}
impl TryFrom<&str> for ShortRanges {
type Error = ParseShortRangesError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}
impl Display for ShortRanges {
    /// Format as `{host}:{container}`, or just `{container}` when there is no host range.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        if let Some(host) = &self.host {
            write!(f, "{host}:")?;
        }
        Display::fmt(&self.container, f)
    }
}
impl Serialize for ShortRanges {
    /// Serialize as a string when a host range is present. Otherwise, serialize just the
    /// container [`Range`].
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match self.host {
            Some(_) => serializer.collect_str(self),
            None => self.container.serialize(serializer),
        }
    }
}
impl<'de> Deserialize<'de> for ShortRanges {
    /// Deserialize from either an integer or a string in the format `[[{host}]:]{container}`.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor =
            FromStrOrU16Visitor::new("an integer or string in the format \"[[{host}]:]{container}\"");
        visitor.deserialize(deserializer)
    }
}
/// Error returned when parsing [`ShortRanges`] from a string fails.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ParseShortRangesError {
/// Both ranges parsed, but their sizes differ.
#[error("error creating the port ranges")]
ShortRanges(#[from] ShortRangesError),
/// The host or container range could not be parsed.
#[error("error parsing port range")]
Range(#[from] ParseRangeError),
}
impl IntoIterator for ShortRanges {
    type Item = (Option<u16>, u16);
    type IntoIter = ShortRangesIter;

    /// Iterate over `(host, container)` port pairs. `host` is `None` in every pair when there
    /// is no host range.
    fn into_iter(self) -> Self::IntoIter {
        ShortRangesIter {
            host: self.host.map(IntoIterator::into_iter),
            container: self.container.into_iter(),
        }
    }
}
/// Iterator over `(host, container)` port pairs of a [`ShortRanges`], created by its
/// [`IntoIterator`] impl.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShortRangesIter {
/// Remaining host ports, or `None` if there is no host range.
host: Option<RangeInclusive<u16>>,
/// Remaining container ports; iteration ends when this is exhausted.
container: RangeInclusive<u16>,
}
impl Iterator for ShortRangesIter {
    type Item = (Option<u16>, u16);

    /// Yield the next container port, paired with the next host port if there is a host
    /// range. The host iterator only advances when a container port is produced.
    fn next(&mut self) -> Option<Self::Item> {
        let container = self.container.next()?;
        let host = match &mut self.host {
            Some(host_ports) => host_ports.next(),
            None => None,
        };
        Some((host, container))
    }
}
/// A single port or an inclusive range of ports, in the format `{start}[-{end}]`.
///
/// Expected invariant (established by [`Range::new()`]): `end`, when present, is strictly
/// greater than `start`. A range with `start == end` is stored as a single port with
/// `end: None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Range {
/// First port.
start: u16,
/// Last port (inclusive), or `None` for a single port.
end: Option<u16>,
}
impl Range {
    /// Create a new [`Range`] from `start` to an optional inclusive `end`.
    ///
    /// If `end` equals `start`, the range is normalized to a single port (`end` becomes
    /// `None`).
    ///
    /// # Errors
    ///
    /// Returns a [`RangeError`] if `end` is less than `start`.
    pub fn new(start: u16, end: Option<u16>) -> Result<Self, RangeError> {
        end.map_or(Ok(Self { start, end: None }), |end| match start.cmp(&end) {
            Ordering::Less => Ok(Self {
                start,
                end: Some(end),
            }),
            Ordering::Equal => Ok(Self { start, end: None }),
            Ordering::Greater => Err(RangeError { start, end }),
        })
    }

    /// First port of the range.
    #[must_use]
    pub const fn start(&self) -> u16 {
        self.start
    }

    /// Last port of the range, or `None` if the range is a single port.
    #[must_use]
    pub const fn end(&self) -> Option<u16> {
        self.end
    }

    /// Number of ports in the range (`1` for a single port).
    ///
    /// The full range `0-65535` contains 65536 ports, which does not fit in a `u16`. For that
    /// range the size saturates to [`u16::MAX`] instead of overflowing.
    #[must_use]
    pub fn size(&self) -> u16 {
        // `end - start` relies on `start <= end`, which `Range::new()` guarantees; only the
        // `+ 1` can exceed `u16::MAX`.
        self.end
            .map_or(1, |end| (end - self.start).saturating_add(1))
    }
}
/// Error returned by [`Range::new()`] when `start` is greater than `end`.
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
#[error("the start `{start}` of the port range must be less than or equal to the end `{end}`")]
pub struct RangeError {
/// The requested start port.
start: u16,
/// The requested end port.
end: u16,
}
impl PartialEq<u16> for Range {
    /// A range equals a port number only if it is the single port with that number.
    fn eq(&self, other: &u16) -> bool {
        *self == Self::from(*other)
    }
}
impl PartialEq<RangeInclusive<u16>> for Range {
fn eq(&self, other: &RangeInclusive<u16>) -> bool {
self.start == *other.start() && self.end.unwrap_or(self.start) == *other.end()
}
}
impl From<u16> for Range {
    /// A range consisting of the single port `start`.
    fn from(start: u16) -> Self {
        Self { end: None, start }
    }
}
impl TryFrom<(u16, Option<u16>)> for Range {
type Error = RangeError;
fn try_from((start, end): (u16, Option<u16>)) -> Result<Self, Self::Error> {
Self::new(start, end)
}
}
impl TryFrom<(u16, u16)> for Range {
type Error = RangeError;
fn try_from((start, end): (u16, u16)) -> Result<Self, Self::Error> {
Self::new(start, Some(end))
}
}
impl TryFrom<RangeInclusive<u16>> for Range {
type Error = RangeError;
fn try_from(value: RangeInclusive<u16>) -> Result<Self, Self::Error> {
value.into_inner().try_into()
}
}
impl Add<u16> for Range {
    type Output = Self;

    /// Shift the whole range up by `rhs` ports.
    ///
    /// # Panics
    ///
    /// Panics if any port of the shifted range would exceed `u16::MAX`. The check also runs in
    /// release builds, so overflow can never silently wrap into a range whose `end` is less
    /// than its `start`.
    fn add(self, rhs: u16) -> Self::Output {
        let shift = |port: u16| {
            port.checked_add(rhs)
                .expect("shifting the port range up overflowed u16::MAX")
        };
        Self {
            start: shift(self.start),
            end: self.end.map(shift),
        }
    }
}

impl AddAssign<u16> for Range {
    /// In-place version of `Range + u16`; panics under the same conditions.
    fn add_assign(&mut self, rhs: u16) {
        *self = *self + rhs;
    }
}

impl Sub<u16> for Range {
    type Output = Self;

    /// Shift the whole range down by `rhs` ports.
    ///
    /// # Panics
    ///
    /// Panics if any port of the shifted range would go below `0`. The check also runs in
    /// release builds, so underflow can never silently wrap into an invalid range.
    fn sub(self, rhs: u16) -> Self::Output {
        let shift = |port: u16| {
            port.checked_sub(rhs)
                .expect("shifting the port range down underflowed 0")
        };
        Self {
            start: shift(self.start),
            end: self.end.map(shift),
        }
    }
}

impl SubAssign<u16> for Range {
    /// In-place version of `Range - u16`; panics under the same conditions.
    fn sub_assign(&mut self, rhs: u16) {
        *self = *self - rhs;
    }
}
impl FromStr for Range {
    type Err = ParseRangeError;

    /// Parse a range in the format `{start}[-{end}]`.
    ///
    /// The result is validated with [`Range::new()`]: `start > end` is rejected, and
    /// `start == end` (e.g. `"80-80"`) is normalized to the single port `80`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseRangeError::Int`] if either bound is not a valid `u16`, and
    /// [`ParseRangeError::Range`] if `start` is greater than `end`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (start, end) = s
            .split_once('-')
            .map_or((s, None), |(start, end)| (start, Some(end)));
        let start = parse_range_int(start)?;
        let end = end.map(parse_range_int).transpose()?;
        // Going through `Range::new()` keeps parsed ranges consistent with the invariant the
        // rest of the type relies on (for example, `size()` computes `end - start`).
        Ok(Self::new(start, end)?)
    }
}
/// Parse `value` as a port number. On failure, the offending text is kept in the error.
fn parse_range_int(value: &str) -> Result<u16, ParseRangeError> {
    match value.parse::<u16>() {
        Ok(port) => Ok(port),
        Err(source) => Err(ParseRangeError::Int {
            source,
            value: String::from(value),
        }),
    }
}
impl TryFrom<&str> for Range {
type Error = ParseRangeError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}
/// Error returned when parsing a [`Range`] from a string fails.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ParseRangeError {
/// The bounds parsed but do not form a valid range.
#[error("error creating the port range")]
Range(#[from] RangeError),
/// A bound could not be parsed as a `u16`.
#[error("error parsing `{value}` as an integer")]
Int {
/// Underlying integer parse error.
source: ParseIntError,
/// The text that failed to parse.
value: String,
},
}
impl Display for Range {
    /// Format as `{start}` or `{start}-{end}`. Formatter options apply to `start`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Display::fmt(&self.start, f)?;
        match self.end {
            Some(end) => write!(f, "-{end}"),
            None => Ok(()),
        }
    }
}
impl Serialize for Range {
    /// Serialize a single port as an integer and a multi-port range as a `{start}-{end}`
    /// string.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match self.end {
            Some(_) => serializer.collect_str(self),
            None => serializer.serialize_u16(self.start),
        }
    }
}
impl<'de> Deserialize<'de> for Range {
    /// Deserialize from either an integer or a `{start}[-{end}]` string.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = FromStrOrU16Visitor::new(
            "an integer or string in the format \"{start}[-{end}]\" \
            where start and end are integers",
        );
        visitor.deserialize(deserializer)
    }
}
impl IntoIterator for Range {
    type Item = u16;
    type IntoIter = RangeInclusive<u16>;

    /// Iterate over every port in the range, including both ends.
    fn into_iter(self) -> Self::IntoIter {
        self.start..=self.end.unwrap_or(self.start)
    }
}
/// Port protocol.
///
/// Serialized through its [`Display`] form and deserialized from a string (via the
/// `SerializeDisplay` and `DeserializeTryFromString` derives).
#[derive(
SerializeDisplay, DeserializeTryFromString, Debug, Default, Clone, PartialEq, Eq, Hash,
)]
pub enum Protocol {
/// `tcp`; this is the [`Default`].
#[default]
Tcp,
/// `udp`.
Udp,
/// Any other protocol, stored as given.
Other(String),
}
impl Protocol {
const TCP: &'static str = "tcp";
const UDP: &'static str = "udp";
pub fn parse<T>(protocol: T) -> Self
where
T: AsRef<str> + Into<String>,
{
match protocol.as_ref() {
Self::TCP => Self::Tcp,
Self::UDP => Self::Udp,
_ => Self::Other(protocol.into()),
}
}
#[must_use]
pub const fn is_tcp(&self) -> bool {
matches!(self, Self::Tcp)
}
#[must_use]
pub const fn is_udp(&self) -> bool {
matches!(self, Self::Udp)
}
#[must_use]
pub fn as_str(&self) -> &str {
match self {
Self::Tcp => Self::TCP,
Self::Udp => Self::UDP,
Self::Other(other) => other,
}
}
}
// NOTE(review): `impl_from_str!` is defined elsewhere in the crate. It presumably provides the
// string conversions for `Protocol` (e.g. `FromStr` and `From<&str>`/`From<String>`, used by
// `ShortPort::from_str` and the `protocol()` test strategy) in terms of `Protocol::parse` —
// confirm against the macro definition.
impl_from_str!(Protocol);
impl AsRef<str> for Protocol {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl Display for Protocol {
    /// Writes the same text as [`Protocol::as_str()`].
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let text = Self::as_str(self);
        f.write_str(text)
    }
}
impl From<Protocol> for String {
    /// The protocol's text. The `Other` case reuses its existing allocation.
    fn from(value: Protocol) -> Self {
        match value {
            Protocol::Other(other) => other,
            Protocol::Udp => String::from(Protocol::UDP),
            Protocol::Tcp => String::from(Protocol::TCP),
        }
    }
}
impl From<Protocol> for Cow<'static, str> {
    /// Borrow the static text for `tcp`/`udp`; take ownership of the string for `Other`.
    fn from(value: Protocol) -> Self {
        match value {
            Protocol::Other(other) => Cow::from(other),
            Protocol::Udp => Cow::from(Protocol::UDP),
            Protocol::Tcp => Cow::from(Protocol::TCP),
        }
    }
}
// Property-based tests for port parsing and formatting. The strategies below are
// `pub(in super::super)` so that tests elsewhere in the parent module can reuse them.
#[cfg(test)]
pub(super) mod tests {
use proptest::{
arbitrary::any,
option, prop_assert_eq, prop_compose, prop_oneof, proptest,
strategy::{Just, Strategy},
};
use super::*;
mod short_port {
use std::net::Ipv6Addr;
use super::*;
proptest! {
// Parsing arbitrary input may fail but must never panic.
#[test]
fn parse_no_panic(string: String) {
let _ = string.parse::<ShortPort>();
}
// `Display` output parses back to an equal `ShortPort`.
#[test]
fn round_trip(port in short_port()) {
prop_assert_eq!(&port, &port.to_string().parse()?);
}
}
// An IPv6 host IP parses the same with or without surrounding brackets.
#[test]
fn host_ip_brackets() -> Result<(), ParseShortPortError> {
let port = ShortPort {
host_ip: Some(IpAddr::V6(Ipv6Addr::LOCALHOST)),
ranges: ShortRanges {
host: Some(80.into()),
container: 80.into(),
},
protocol: None,
};
assert_eq!("[::1]:80:80".parse::<ShortPort>()?, port);
assert_eq!("::1:80:80".parse::<ShortPort>()?, port);
Ok(())
}
}
mod range {
use super::*;
proptest! {
// Parsing arbitrary input may fail but must never panic.
#[test]
fn parse_no_panic(string: String) {
let _ = string.parse::<Range>();
}
// `Display` output parses back to an equal `Range`.
#[test]
fn round_trip(range in range()) {
prop_assert_eq!(range, range.to_string().parse::<Range>()?);
}
}
}
proptest! {
// `ShortRanges::into_iter` pairs host and container ports in order, with `None` host
// ports when there is no host range.
#[test]
fn short_ranges_iter(ranges in short_ranges()) {
let iter: Vec<_> = ranges.host.map_or_else(
|| std::iter::repeat(None).zip(ranges.container).collect(),
|host| host.into_iter().map(Some).zip(ranges.container).collect(),
);
let ranges: Vec<_> = ranges.into_iter().collect();
prop_assert_eq!(ranges, iter);
}
}
// Arbitrary `ShortPort` with an optional host IP and protocol.
prop_compose! {
fn short_port()(
host_ip: Option<IpAddr>,
ranges in short_ranges(),
protocol in option::of(protocol())
) -> ShortPort {
ShortPort {
host_ip,
ranges,
protocol
}
}
}
// Arbitrary valid `ShortRanges`: the host range, when present, is the container range
// shifted up by a non-zero offset that keeps it within `u16`. An offset of zero means there
// is no host range.
fn short_ranges() -> impl Strategy<Value = ShortRanges> {
range()
.prop_flat_map(|range| {
let range_end = range.end.unwrap_or(range.start);
let offset = if range_end == u16::MAX {
Just(0).boxed()
} else {
(..u16::MAX - range_end).boxed()
};
(Just(range), offset)
})
.prop_map(|(range, offset)| ShortRanges {
host: (offset != 0).then(|| range + offset),
container: range,
})
}
// Arbitrary normalized `Range`: `end`, when present, is strictly greater than `start`.
pub(in super::super) fn range() -> impl Strategy<Value = Range> {
any::<u16>()
.prop_flat_map(|start| (Just(start), option::of(start..)))
.prop_map(|(start, end)| Range {
start,
end: end.filter(|end| *end != start),
})
}
// Arbitrary `Protocol`: TCP, UDP, or any string converted into a `Protocol`.
pub(in super::super) fn protocol() -> impl Strategy<Value = Protocol> {
prop_oneof![
Just(Protocol::Tcp),
Just(Protocol::Udp),
any::<String>().prop_map_into(),
]
}
}