use std::num::NonZeroU16;
use nom::{
character::complete::digit1,
error::{ErrorKind, ParseError},
Err as NomErr, IResult, Parser,
};
use crate::pgpass::DELIMITER_CHAR;
use super::field::wildcard;
fn non_zero_u16(s: &str) -> IResult<&str, NonZeroU16, PortError> {
let Ok((remaining, digits)) = digit1::<_, nom::error::Error<_>>.parse(s) else {
return Err(NomErr::Error(PortError::InvalidPort));
};
let Ok(num_u32) = digits.parse::<u32>() else {
return Err(NomErr::Error(PortError::InvalidPort));
};
let Ok(num_u16) = num_u32.try_into() else {
return Err(NomErr::Error(PortError::InvalidPortNumber(num_u32)));
};
let Some(num) = NonZeroU16::new(num_u16) else {
return Err(NomErr::Error(PortError::InvalidPortNumber(num_u32)));
};
Ok((remaining, num))
}
pub fn port_number(s: &str) -> IResult<&str, Option<NonZeroU16>, PortError> {
if let Ok((remaining, _)) = wildcard.parse(s) {
Ok((remaining, None))
} else if s.is_empty() || s.starts_with(DELIMITER_CHAR) {
Err(NomErr::Error(PortError::Empty))
} else {
let (remaining, num) = non_zero_u16.parse(s)?;
Ok((remaining, Some(num)))
}
}
/// Errors produced while parsing the port field.
///
/// Also serves as the nom error type for the port parsers in this module (see its
/// `ParseError<&str>` implementation).
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum PortError {
    /// The field was empty: the input was empty or began with the delimiter.
    #[error("Fields must not be empty (use * for wildcards).")]
    Empty,
    /// The field did not start with a digit, or its digit run overflowed `u32`.
    #[error("Could not parse the port number.")]
    InvalidPort,
    /// The field parsed as a `u32` but is zero or larger than `u16::MAX`.
    #[error("{0} is not a valid port number.")]
    InvalidPortNumber(u32),
    /// Created from nom's `ErrorKind::Tag`; presumably raised when an expected
    /// delimiter tag fails to match (NOTE(review): confirm against callers).
    #[error("The delimiter character was not found")]
    Undelimited,
    /// Any other nom error kind, preserved for diagnostics.
    #[error("An unknown error occurred during parsing (kind: {0:?}).")]
    Unknown(ErrorKind),
}
/// Lets `PortError` be used directly as the error type of nom parsers over `&str`.
impl ParseError<&str> for PortError {
    /// Converts a raw nom error kind into a [`PortError`].
    ///
    /// `ErrorKind::Tag` becomes [`PortError::Undelimited`]. Every other kind is
    /// wrapped unchanged in [`PortError::Unknown`]. The failing input is ignored.
    fn from_error_kind(_input: &str, kind: ErrorKind) -> Self {
        match kind {
            ErrorKind::Tag => Self::Undelimited,
            other => Self::Unknown(other),
        }
    }

    /// Returns `other` unchanged, discarding the outer combinator's context.
    fn append(_input: &str, _kind: ErrorKind, other: Self) -> Self {
        other
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `port_number` on `s` and returns the recoverable (`NomErr::Error`) error.
    ///
    /// Panics with the offending input and the actual result when parsing succeeds
    /// or yields `Failure`/`Incomplete`. That outcome is a test failure, not an
    /// impossible state.
    fn expect_error(s: &str) -> PortError {
        match port_number.parse(s) {
            Err(NomErr::Error(err)) => err,
            other => panic!("expected NomErr::Error for {s:?}, got {other:?}"),
        }
    }

    /// Builds the `Some(port)` value `port_number` yields for a concrete port.
    fn port(n: u16) -> Option<NonZeroU16> {
        Some(NonZeroU16::new(n).expect("test ports must be non-zero"))
    }

    #[test]
    fn simple() {
        assert_eq!(port_number.parse("123:abc").unwrap(), (":abc", port(123)));
        assert_eq!(port_number.parse("*:abc").unwrap(), (":abc", None));
    }

    #[test]
    fn boundaries_are_valid() {
        assert_eq!(port_number.parse("1").unwrap(), ("", port(1)));
        assert_eq!(port_number.parse("65535:").unwrap(), (":", port(65535)));
    }

    #[test]
    fn zero_is_invalid() {
        assert_eq!(expect_error("0"), PortError::InvalidPortNumber(0));
    }

    #[test]
    fn above_u16_max_is_invalid() {
        assert_eq!(expect_error("65536"), PortError::InvalidPortNumber(65536));
    }

    #[test]
    fn above_u32_max_is_unparsable() {
        // Too large for the `u32` carried by `InvalidPortNumber`, so it is reported
        // as an unparsable port instead.
        assert_eq!(expect_error("4294967296"), PortError::InvalidPort);
    }

    #[test]
    fn invalid_characters() {
        assert_eq!(expect_error("abc"), PortError::InvalidPort);
    }

    #[test]
    fn empty() {
        assert_eq!(expect_error(""), PortError::Empty);
        assert_eq!(expect_error(":"), PortError::Empty);
    }
}