use serde::{
Deserialize, Serialize,
de::{Error as DeError, Unexpected},
};
use std::{fmt::Display, str::FromStr};
use crate::util::serialization::{SerializeAsString, ToStringForFigment};
/// A single port number or an inclusive range of port numbers.
///
/// A single port is represented with `begin == end`. The default value (`0`, displayed as `0`)
/// is treated as "unset" by `PortRange::combine`. The parser describes port 0 as meaning "any",
/// so it may appear only on its own (`0` or `0-0`), never as the start of a wider range.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PortRange {
/// First port of the range (inclusive).
pub begin: u16,
/// Last port of the range (inclusive); equal to `begin` for a single port.
pub end: u16,
}
// Marker impls from `crate::util::serialization`. NOTE(review): these presumably route
// serialization / figment conversion through the `Display` form — confirm in that module.
impl SerializeAsString for PortRange {}
impl ToStringForFigment for PortRange {}
impl Display for PortRange {
    /// Writes `N` when the range covers exactly one port, otherwise `A-B`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { begin, end } = *self;
        if begin != end {
            return write!(f, "{begin}-{end}");
        }
        write!(f, "{begin}")
    }
}
/// Description of the accepted syntax, reported in `invalid_value` parse errors.
static PR_EXPECTED: &str = "a single port number [0..65535] or a range `a-b`";

impl FromStr for PortRange {
    type Err = figment::Error;

    /// Parses either a single port (`"8080"`) or an inclusive range (`"1000-2000"`).
    ///
    /// # Errors
    /// - a plain number too large for `u16`: `invalid_value` with `Unexpected::Unsigned`;
    /// - a range whose start is greater than its end;
    /// - a range that starts at port 0 but ends elsewhere (port 0 means "any");
    /// - anything else: `invalid_value` with `Unexpected::Str`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        use figment::error::Error as FigmentError;

        if let Ok(port) = s.parse::<u16>() {
            return Ok(Self {
                begin: port,
                end: port,
            });
        }
        // Numeric but out of the port range: report the number itself rather than the raw text.
        if let Ok(n) = s.parse::<u64>() {
            return Err(FigmentError::invalid_value(
                Unexpected::Unsigned(n),
                &PR_EXPECTED,
            ));
        }
        if let Some((a, b)) = s.split_once('-') {
            // Both halves must be valid ports; otherwise fall through to the generic error.
            if let (Ok(begin), Ok(end)) = (a.parse::<u16>(), b.parse::<u16>()) {
                if begin > end {
                    return Err(FigmentError::custom(format!(
                        "invalid port range `{s}` (must be increasing)"
                    )));
                }
                if begin == 0 && end != 0 {
                    return Err(FigmentError::custom(format!(
                        "invalid port range `{s}` (port 0 means \"any\" so cannot be part of a range)"
                    )));
                }
                return Ok(Self { begin, end });
            }
        }
        Err(FigmentError::invalid_value(
            Unexpected::Str(s),
            &PR_EXPECTED,
        ))
    }
}
impl PortRange {
    /// True for the default value `0` (i.e. `0-0`), which `combine` treats as "unset".
    pub(crate) fn is_default(self) -> bool {
        self.begin == 0 && self.end == 0
    }

    /// Intersects our range with `theirs`.
    ///
    /// If either side is the default ("unset") value, the other side is returned unchanged.
    ///
    /// # Errors
    /// Fails when the two ranges do not overlap.
    pub(crate) fn combine(self, theirs: PortRange) -> anyhow::Result<PortRange> {
        if self.is_default() {
            return Ok(theirs);
        }
        if theirs.is_default() {
            return Ok(self);
        }
        let lower = self.begin.max(theirs.begin);
        let upper = self.end.min(theirs.end);
        anyhow::ensure!(
            lower <= upper,
            "requested port range {theirs} could not be satisfied (our config: {self})"
        );
        Ok(PortRange {
            begin: lower,
            end: upper,
        })
    }
}
impl From<PortRange> for figment::value::Value {
    /// Converts to a figment string value holding the `Display` form (`"80"` or `"1000-2000"`).
    fn from(range: PortRange) -> Self {
        let rendered: String = range.to_string();
        rendered.into()
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Unit tests for `PortRange`: display, parsing, figment conversion and combination.
    use pretty_assertions::assert_eq;
    use std::str::FromStr;

    type Uut = super::PortRange;

    /// Shorthand constructor for test values.
    fn pr(begin: u16, end: u16) -> Uut {
        Uut { begin, end }
    }

    #[test]
    fn output_single() {
        let uut = pr(123, 123);
        assert_eq!(format!("{uut}"), "123");
    }

    #[test]
    fn output_range() {
        let uut = pr(123, 456);
        assert_eq!(format!("{uut}"), "123-456");
    }

    #[test]
    fn parse_single() {
        let uut = Uut::from_str("1234").unwrap();
        assert_eq!(uut.begin, 1234);
        assert_eq!(uut.end, 1234);
    }

    #[test]
    fn parse_range() {
        let uut = Uut::from_str("1234-2345").unwrap();
        assert_eq!(uut.begin, 1234);
        assert_eq!(uut.end, 2345);
        let v = figment::value::Value::from(uut);
        assert_eq!(v.as_str(), Some("1234-2345"));
    }

    #[test]
    fn parse_zero_is_default() {
        let uut = Uut::from_str("0").unwrap();
        assert_eq!(uut, Uut::default());
        assert!(uut.is_default());
        // A degenerate 0-0 range is just "port 0", so it is accepted.
        assert_eq!(Uut::from_str("0-0").unwrap(), Uut::default());
    }

    #[test]
    fn parse_upper_bound() {
        assert_eq!(Uut::from_str("65535").unwrap(), pr(65535, 65535));
        assert_eq!(Uut::from_str("1-65535").unwrap(), pr(1, 65535));
        let _ = Uut::from_str("1-65536").expect_err("should have failed");
    }

    #[test]
    fn display_parse_round_trip() {
        for uut in [pr(1, 1), pr(1, 2), pr(1000, 65535)] {
            assert_eq!(Uut::from_str(&uut.to_string()).unwrap(), uut);
        }
    }

    #[test]
    fn invalid_range() {
        let _ = Uut::from_str("1000-999").expect_err("should have failed");
    }

    #[test]
    fn invalid_negative() {
        let _ = Uut::from_str("-500").expect_err("should have failed");
    }

    #[test]
    fn invalid_half_open() {
        let _ = Uut::from_str("5-").expect_err("should have failed");
        let _ = Uut::from_str("-").expect_err("should have failed");
    }

    #[test]
    fn invalid_out_of_range() {
        let _ = Uut::from_str("65537").expect_err("should have failed");
    }

    #[test]
    fn invalid_unparsable() {
        let _ = Uut::from_str("fdsfdsfds").expect_err("should have failed");
    }

    #[test]
    fn port_range_not_zero() {
        let _ = Uut::from_str("0-1000").expect_err("should have failed");
    }

    #[test]
    fn port_range_combine() {
        let config = pr(42, 88);
        assert_eq!(Uut::default().combine(config).unwrap(), config);
        assert_eq!(config.combine(Uut::default()).unwrap(), config);
        assert_eq!(config.combine(pr(77, 99)).unwrap(), pr(77, 88));
        assert_eq!(config.combine(pr(5, 49)).unwrap(), pr(42, 49));
        assert_eq!(config.combine(pr(5, 123)).unwrap(), pr(42, 88));
        assert_eq!(config.combine(pr(51, 62)).unwrap(), pr(51, 62));
        let _ = config.combine(pr(123, 456)).expect_err("failure expected");
        // Non-overlap must be detected whichever side is "ours".
        let _ = pr(123, 456).combine(config).expect_err("failure expected");
        // Single ports only combine with themselves.
        assert_eq!(pr(80, 80).combine(pr(80, 80)).unwrap(), pr(80, 80));
        let _ = pr(80, 80).combine(pr(81, 81)).expect_err("failure expected");
    }
}