use ipnet::IpNet;
use log::LevelFilter;
use std::{
collections::HashSet,
error::Error,
fmt::{self, Display, Formatter},
net::IpAddr,
str::FromStr,
};
use syslog::Facility;
use viaspf::{record::ExplainString, DomainName, ParseParamError, SpfResult};
/// Error returned when a string cannot be parsed as a [`Socket`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseSocketError;

impl Error for ParseSocketError {}

impl Display for ParseSocketError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "failed to parse socket")
    }
}

/// A listening socket specification, written `inet:<address>` or
/// `unix:<path>`.
///
/// The payload is the text after the scheme prefix, stored verbatim, so
/// [`Display`] reproduces the parsed input.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Socket {
    /// The address part following `inet:`.
    Inet(String),
    /// The path following `unix:`.
    Unix(String),
}

impl Display for Socket {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Inet(s) => write!(f, "inet:{s}"),
            Self::Unix(s) => write!(f, "unix:{s}"),
        }
    }
}

impl FromStr for Socket {
    type Err = ParseSocketError;

    /// Parses `inet:<address>` or `unix:<path>`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseSocketError`] if the prefix is missing or unknown, or
    /// if nothing follows the prefix.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // A bare `inet:` / `unix:` names no socket at all. Reject it here so
        // the configuration error is reported at parse time instead of
        // surfacing later when the socket is set up.
        match (s.strip_prefix("inet:"), s.strip_prefix("unix:")) {
            (Some(addr), _) if !addr.is_empty() => Ok(Self::Inet(addr.into())),
            (_, Some(path)) if !path.is_empty() => Ok(Self::Unix(path.into())),
            _ => Err(ParseSocketError),
        }
    }
}
/// Error returned when a string is not a valid reply code or enhanced
/// status code.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseStatusCodeError;

impl Error for ParseStatusCodeError {}

impl Display for ParseStatusCodeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse status code")
    }
}

/// A three-digit negative SMTP reply code such as `451` or `550`.
///
/// The text is kept verbatim. The variant records whether the first digit
/// is `4` (transient) or `5` (permanent).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ReplyCode {
    Transient(String),
    Permanent(String),
}

impl AsRef<str> for ReplyCode {
    fn as_ref(&self) -> &str {
        let (Self::Transient(code) | Self::Permanent(code)) = self;
        code
    }
}

impl Display for ReplyCode {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        <str as Display>::fmt(self.as_ref(), f)
    }
}

impl FromStr for ReplyCode {
    type Err = ParseStatusCodeError;

    /// Accepts exactly three ASCII digits: `4` or `5`, then `0`-`5`, then
    /// `0`-`9`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.as_bytes() {
            [b'4', b'0'..=b'5', b'0'..=b'9'] => Ok(Self::Transient(s.into())),
            [b'5', b'0'..=b'5', b'0'..=b'9'] => Ok(Self::Permanent(s.into())),
            _ => Err(ParseStatusCodeError),
        }
    }
}

/// A negative enhanced status code of the form `class.subject.detail`,
/// e.g. `5.7.1`, kept verbatim.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum EnhancedStatusCode {
    Transient(String),
    Permanent(String),
}

impl EnhancedStatusCode {
    /// Returns `true` if `reply_code` is of the same class as this code,
    /// i.e. both transient or both permanent.
    pub fn is_compatible_with(&self, reply_code: &ReplyCode) -> bool {
        match self {
            Self::Transient(_) => matches!(reply_code, ReplyCode::Transient(_)),
            Self::Permanent(_) => matches!(reply_code, ReplyCode::Permanent(_)),
        }
    }
}

impl AsRef<str> for EnhancedStatusCode {
    fn as_ref(&self) -> &str {
        let (Self::Transient(code) | Self::Permanent(code)) = self;
        code
    }
}

impl Display for EnhancedStatusCode {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        <str as Display>::fmt(self.as_ref(), f)
    }
}

impl FromStr for EnhancedStatusCode {
    type Err = ParseStatusCodeError;

    /// Accepts `4.x.y` or `5.x.y`, where `x` and `y` are each `0` or one to
    /// three ASCII digits without a leading zero.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `0` on its own, or 1-3 digits that do not start with `0`.
        fn is_valid_number(part: &str) -> bool {
            if part == "0" {
                return true;
            }
            (1..=3).contains(&part.len())
                && part.bytes().all(|b| b.is_ascii_digit())
                && !part.starts_with('0')
        }

        // At most three pieces: any extra dots stay in `detail` and make it
        // fail the digit check.
        let mut parts = s.splitn(3, '.');
        let class = parts.next();
        let subject = parts.next();
        let detail = parts.next();

        match (class, subject, detail) {
            (Some(c), Some(sub), Some(det)) if is_valid_number(sub) && is_valid_number(det) => {
                match c {
                    "4" => Ok(Self::Transient(s.into())),
                    "5" => Ok(Self::Permanent(s.into())),
                    _ => Err(ParseStatusCodeError),
                }
            }
            _ => Err(ParseStatusCodeError),
        }
    }
}
/// An ordered list of [`HeaderType`]s.
///
/// [`Header::new`] guarantees the list has no duplicates. The `From`
/// conversions do not check this.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Header(Vec<HeaderType>);

impl Header {
    /// Creates a `Header` if every entry of `header_types` is distinct.
    ///
    /// # Errors
    ///
    /// Returns `header_types` unchanged if it contains a duplicate.
    pub fn new(header_types: Vec<HeaderType>) -> Result<Self, Vec<HeaderType>> {
        let distinct_count = header_types.iter().collect::<HashSet<_>>().len();
        if distinct_count < header_types.len() {
            Err(header_types)
        } else {
            Ok(Self(header_types))
        }
    }

    /// Iterates over the header types in list order.
    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &HeaderType> {
        self.0.iter()
    }
}

impl Default for Header {
    /// Just [`HeaderType::ReceivedSpf`].
    fn default() -> Self {
        Self::from(HeaderType::ReceivedSpf)
    }
}

impl From<Vec<HeaderType>> for Header {
    fn from(header_types: Vec<HeaderType>) -> Self {
        Self(header_types)
    }
}

impl From<HeaderType> for Header {
    fn from(header_type: HeaderType) -> Self {
        Self(vec![header_type])
    }
}

/// Error returned when a string is not a recognized [`HeaderType`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseHeaderTypeError;

impl Error for ParseHeaderTypeError {}

impl Display for ParseHeaderTypeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse header type")
    }
}

/// A kind of header, identified by its header field name.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HeaderType {
    ReceivedSpf,
    AuthenticationResults,
}

impl Display for HeaderType {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::ReceivedSpf => "Received-SPF",
            Self::AuthenticationResults => "Authentication-Results",
        })
    }
}

impl FromStr for HeaderType {
    type Err = ParseHeaderTypeError;

    /// Matches the header field name, ignoring ASCII case.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let candidates = [
            ("Received-SPF", Self::ReceivedSpf),
            ("Authentication-Results", Self::AuthenticationResults),
        ];
        candidates
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, header_type)| header_type)
            .ok_or(ParseHeaderTypeError)
    }
}
/// Error returned when a string is not a recognized SPF result kind.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseResultKindError;

impl Error for ParseResultKindError {}

impl Display for ParseResultKindError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse SPF result kind")
    }
}
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RejectResultKind {
Fail,
Softfail,
Temperror,
Permerror,
}
impl RejectResultKind {
fn from_spf_result(spf_result: &SpfResult) -> Option<Self> {
use SpfResult::*;
match spf_result {
None | Neutral | Pass => Option::None,
Fail(_) => Some(Self::Fail),
Softfail => Some(Self::Softfail),
Temperror => Some(Self::Temperror),
Permerror => Some(Self::Permerror),
}
}
}
impl FromStr for RejectResultKind {
type Err = ParseResultKindError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"fail" => Ok(Self::Fail),
"softfail" => Ok(Self::Softfail),
"temperror" => Ok(Self::Temperror),
"permerror" => Ok(Self::Permerror),
_ => Err(ParseResultKindError),
}
}
}
/// A set of [`RejectResultKind`]s that [`RejectResults::includes`] tests
/// SPF results against.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RejectResults(HashSet<RejectResultKind>);

impl RejectResults {
    /// Returns `true` if the kind of `result` is in this set.
    ///
    /// `none`, `neutral` and `pass` results are never included.
    pub fn includes(&self, result: &SpfResult) -> bool {
        match RejectResultKind::from_spf_result(result) {
            Some(kind) => self.0.contains(&kind),
            None => false,
        }
    }
}

impl Default for RejectResults {
    /// `fail`, `temperror` and `permerror`. `softfail` is not included.
    fn default() -> Self {
        let kinds = [
            RejectResultKind::Fail,
            RejectResultKind::Temperror,
            RejectResultKind::Permerror,
        ];
        Self(kinds.iter().copied().collect())
    }
}

impl From<HashSet<RejectResultKind>> for RejectResults {
    fn from(kinds: HashSet<RejectResultKind>) -> Self {
        Self(kinds)
    }
}
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DefinitiveHeloResultKind {
Pass,
Fail,
Softfail,
Temperror,
Permerror,
}
impl DefinitiveHeloResultKind {
fn from_spf_result(spf_result: &SpfResult) -> Option<Self> {
use SpfResult::*;
match spf_result {
None | Neutral => Option::None,
Pass => Some(Self::Pass),
Fail(_) => Some(Self::Fail),
Softfail => Some(Self::Softfail),
Temperror => Some(Self::Temperror),
Permerror => Some(Self::Permerror),
}
}
}
impl FromStr for DefinitiveHeloResultKind {
type Err = ParseResultKindError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"pass" => Ok(Self::Pass),
"fail" => Ok(Self::Fail),
"softfail" => Ok(Self::Softfail),
"temperror" => Ok(Self::Temperror),
"permerror" => Ok(Self::Permerror),
_ => Err(ParseResultKindError),
}
}
}
/// A set of [`DefinitiveHeloResultKind`]s; empty by default.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DefinitiveHeloResults(HashSet<DefinitiveHeloResultKind>);

impl DefinitiveHeloResults {
    /// Returns `true` if the kind of `result` is in this set.
    ///
    /// `none` and `neutral` results are never included.
    pub fn includes(&self, result: &SpfResult) -> bool {
        match DefinitiveHeloResultKind::from_spf_result(result) {
            Some(kind) => self.0.contains(&kind),
            None => false,
        }
    }
}

impl From<HashSet<DefinitiveHeloResultKind>> for DefinitiveHeloResults {
    fn from(kinds: HashSet<DefinitiveHeloResultKind>) -> Self {
        Self(kinds)
    }
}
/// A modification to apply to an explanation, expressed with
/// [`ExplainString`]s.
///
/// NOTE(review): the variant semantics below are inferred from their names.
/// The code that applies them is not in this file; confirm there.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ExplainStringMod {
    /// Presumably replaces the explanation with this string.
    Substitute(ExplainString),
    /// Presumably surrounds the explanation with `prefix` and `suffix`.
    Decorate {
        prefix: ExplainString,
        suffix: ExplainString,
    },
}

/// Newtype over an [`ExplainStringMod`].
///
/// NOTE(review): the name suggests it is applied to the SPF `exp`
/// explanation; confirm against its users.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ExpExplainString(pub ExplainStringMod);

impl AsRef<ExplainStringMod> for ExpExplainString {
    fn as_ref(&self) -> &ExplainStringMod {
        let Self(inner) = self;
        inner
    }
}

impl From<ExplainStringMod> for ExpExplainString {
    fn from(inner: ExplainStringMod) -> Self {
        ExpExplainString(inner)
    }
}

/// Newtype over an [`ExplainStringMod`].
///
/// NOTE(review): the name suggests it is applied to a "reason" text;
/// confirm against its users.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ReasonExplainString(pub ExplainStringMod);

impl AsRef<ExplainStringMod> for ReasonExplainString {
    fn as_ref(&self) -> &ExplainStringMod {
        let Self(inner) = self;
        inner
    }
}

impl From<ExplainStringMod> for ReasonExplainString {
    fn from(inner: ExplainStringMod) -> Self {
        ReasonExplainString(inner)
    }
}
/// Addresses considered trusted: optionally every loopback address, plus any
/// address inside one of `networks`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TrustedNetworks {
    /// Whether loopback addresses count as trusted (default `true`).
    pub trust_loopback: bool,
    /// Additional trusted networks (default empty).
    pub networks: HashSet<IpNet>,
}

impl TrustedNetworks {
    /// Returns `true` if `addr` is loopback (and loopback is trusted) or lies
    /// within any configured network.
    pub fn contains(&self, addr: IpAddr) -> bool {
        if self.trust_loopback && addr.is_loopback() {
            return true;
        }
        self.networks.iter().any(|net| net.contains(&addr))
    }
}

impl Default for TrustedNetworks {
    /// Trusts loopback only.
    fn default() -> Self {
        Self {
            trust_loopback: true,
            networks: HashSet::new(),
        }
    }
}
/// One skip-senders entry: a domain, optionally restricted to a local part,
/// or (with `match_subdomains`) covering subdomains of the domain.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SkipEntry {
    /// Required local part, compared ignoring ASCII case. `None` accepts any
    /// local part, including a sender without one.
    pub local_part: Option<String>,
    /// The domain this entry refers to.
    pub domain: DomainName,
    /// If set, the sender's domain must satisfy `is_subdomain_of(domain)`.
    /// Otherwise it must equal `domain`.
    pub match_subdomains: bool,
}

impl SkipEntry {
    /// Returns `true` if `sender` (`local@domain` or a bare domain) is
    /// covered by this entry.
    ///
    /// The sender is split at its last `@`. A domain part that is not a
    /// valid [`DomainName`] never matches.
    fn matches(&self, sender: &str) -> bool {
        // Domain comparison: exact, or via `is_subdomain_of` when
        // `match_subdomains` is set. The tests in this module expect the
        // parent domain itself not to match in the latter case.
        fn domain_matches(entry: &SkipEntry, domain: &DomainName) -> bool {
            if !entry.match_subdomains {
                return *domain == entry.domain;
            }
            domain.as_ref().is_subdomain_of(entry.domain.as_ref())
        }

        // Local-part comparison: an entry without a local part accepts
        // anything. Otherwise the sender needs an ASCII-case-insensitively
        // equal one.
        fn local_part_matches(entry: &SkipEntry, local_part: Option<&str>) -> bool {
            match &entry.local_part {
                None => true,
                Some(wanted) => match local_part {
                    Some(given) => given.eq_ignore_ascii_case(wanted),
                    None => false,
                },
            }
        }

        let (local_part, domain) = sender
            .rsplit_once('@')
            .map_or((None, sender), |(local, domain)| (Some(local), domain));

        if let Ok(domain) = DomainName::new(domain) {
            domain_matches(self, &domain) && local_part_matches(self, local_part)
        } else {
            false
        }
    }
}
impl FromStr for SkipEntry {
    type Err = ParseParamError;

    /// Parses one of three forms:
    ///
    /// * `local@domain`: that local part at exactly `domain`;
    /// * `.domain`: subdomains of `domain`;
    /// * `domain`: exactly `domain`.
    ///
    /// # Errors
    ///
    /// Fails if the domain part is not a valid [`DomainName`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The `@` form takes precedence, so a leading dot is only treated as
        // the subdomain marker when there is no local part.
        let (local_part, match_subdomains, domain) =
            if let Some((local, domain)) = s.rsplit_once('@') {
                (Some(String::from(local)), false, domain)
            } else if let Some(parent) = s.strip_prefix('.') {
                (None, true, parent)
            } else {
                (None, false, s)
            };

        let domain = DomainName::new(domain)?;

        Ok(Self {
            local_part,
            domain,
            match_subdomains,
        })
    }
}
/// A set of [`SkipEntry`] values. A sender is included if any entry matches.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SkipSenders(HashSet<SkipEntry>);

impl SkipSenders {
    /// Returns the union of `self` and `other`.
    pub fn extended_with(self, other: Self) -> Self {
        let Self(mut entries) = self;
        entries.extend(other.0);
        Self(entries)
    }

    /// Returns `true` if at least one entry matches `sender`.
    pub fn includes(&self, sender: &str) -> bool {
        for entry in &self.0 {
            if entry.matches(sender) {
                return true;
            }
        }
        false
    }
}

impl From<HashSet<SkipEntry>> for SkipSenders {
    fn from(entries: HashSet<SkipEntry>) -> Self {
        Self(entries)
    }
}
/// Error returned when a string is not a recognized [`LogDestination`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseLogDestinationError;

impl Error for ParseLogDestinationError {}

impl Display for ParseLogDestinationError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse log destination")
    }
}

/// Where log output goes; defaults to [`LogDestination::Syslog`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum LogDestination {
    #[default]
    Syslog,
    Stderr,
}

impl FromStr for LogDestination {
    type Err = ParseLogDestinationError;

    /// Parses the lowercase names `syslog` and `stderr`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let destination = match s {
            "syslog" => Self::Syslog,
            "stderr" => Self::Stderr,
            _ => return Err(ParseLogDestinationError),
        };
        Ok(destination)
    }
}
/// Error returned when a string is not a recognized [`LogLevel`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseLogLevelError;

impl Error for ParseLogLevelError {}

impl Display for ParseLogLevelError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse log level")
    }
}

/// Logging verbosity; defaults to [`LogLevel::Info`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum LogLevel {
    Error,
    Warn,
    #[default]
    Info,
    Debug,
}

impl FromStr for LogLevel {
    type Err = ParseLogLevelError;

    /// Parses the lowercase names `error`, `warn`, `info`, `debug`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Accepted spellings paired with their levels.
        const LEVELS: [(&str, LogLevel); 4] = [
            ("error", LogLevel::Error),
            ("warn", LogLevel::Warn),
            ("info", LogLevel::Info),
            ("debug", LogLevel::Debug),
        ];
        LEVELS
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, level)| level)
            .ok_or(ParseLogLevelError)
    }
}
impl From<LogLevel> for LevelFilter {
fn from(log_level: LogLevel) -> Self {
match log_level {
LogLevel::Error => Self::Error,
LogLevel::Warn => Self::Warn,
LogLevel::Info => Self::Info,
LogLevel::Debug => Self::Debug,
}
}
}
/// Error returned when a string is not a recognized [`SyslogFacility`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseSyslogFacilityError;

impl Error for ParseSyslogFacilityError {}

impl Display for ParseSyslogFacilityError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse syslog facility")
    }
}

/// A syslog facility; defaults to [`SyslogFacility::Mail`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SyslogFacility {
    Auth,
    Authpriv,
    Cron,
    Daemon,
    Ftp,
    Kern,
    Local0,
    Local1,
    Local2,
    Local3,
    Local4,
    Local5,
    Local6,
    Local7,
    Lpr,
    #[default]
    Mail,
    News,
    Syslog,
    User,
    Uucp,
}

impl FromStr for SyslogFacility {
    type Err = ParseSyslogFacilityError;

    /// Parses a lowercase facility name such as `mail` or `local3`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Accepted spellings paired with their facilities.
        const FACILITIES: [(&str, SyslogFacility); 20] = [
            ("auth", SyslogFacility::Auth),
            ("authpriv", SyslogFacility::Authpriv),
            ("cron", SyslogFacility::Cron),
            ("daemon", SyslogFacility::Daemon),
            ("ftp", SyslogFacility::Ftp),
            ("kern", SyslogFacility::Kern),
            ("local0", SyslogFacility::Local0),
            ("local1", SyslogFacility::Local1),
            ("local2", SyslogFacility::Local2),
            ("local3", SyslogFacility::Local3),
            ("local4", SyslogFacility::Local4),
            ("local5", SyslogFacility::Local5),
            ("local6", SyslogFacility::Local6),
            ("local7", SyslogFacility::Local7),
            ("lpr", SyslogFacility::Lpr),
            ("mail", SyslogFacility::Mail),
            ("news", SyslogFacility::News),
            ("syslog", SyslogFacility::Syslog),
            ("user", SyslogFacility::User),
            ("uucp", SyslogFacility::Uucp),
        ];
        FACILITIES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, facility)| facility)
            .ok_or(ParseSyslogFacilityError)
    }
}
impl From<SyslogFacility> for Facility {
fn from(syslog_facility: SyslogFacility) -> Self {
match syslog_facility {
SyslogFacility::Auth => Self::LOG_AUTH,
SyslogFacility::Authpriv => Self::LOG_AUTHPRIV,
SyslogFacility::Cron => Self::LOG_CRON,
SyslogFacility::Daemon => Self::LOG_DAEMON,
SyslogFacility::Ftp => Self::LOG_FTP,
SyslogFacility::Kern => Self::LOG_KERN,
SyslogFacility::Local0 => Self::LOG_LOCAL0,
SyslogFacility::Local1 => Self::LOG_LOCAL1,
SyslogFacility::Local2 => Self::LOG_LOCAL2,
SyslogFacility::Local3 => Self::LOG_LOCAL3,
SyslogFacility::Local4 => Self::LOG_LOCAL4,
SyslogFacility::Local5 => Self::LOG_LOCAL5,
SyslogFacility::Local6 => Self::LOG_LOCAL6,
SyslogFacility::Local7 => Self::LOG_LOCAL7,
SyslogFacility::Lpr => Self::LOG_LPR,
SyslogFacility::Mail => Self::LOG_MAIL,
SyslogFacility::News => Self::LOG_NEWS,
SyslogFacility::Syslog => Self::LOG_SYSLOG,
SyslogFacility::User => Self::LOG_USER,
SyslogFacility::Uucp => Self::LOG_UUCP,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ipnet::Ipv4Net;
    use std::net::Ipv6Addr;

    /// Shorthand for building a [`SkipEntry`] fixture.
    fn skip_entry(local_part: Option<&str>, domain: &str, match_subdomains: bool) -> SkipEntry {
        SkipEntry {
            local_part: local_part.map(String::from),
            domain: DomainName::new(domain).unwrap(),
            match_subdomains,
        }
    }

    #[test]
    fn reply_code_parse_ok() {
        assert_eq!(
            "441".parse::<ReplyCode>(),
            Ok(ReplyCode::Transient("441".into()))
        );
        // A second digit above 5 is rejected.
        assert_eq!("499".parse::<ReplyCode>(), Err(ParseStatusCodeError));
    }

    #[test]
    fn enhanced_status_code_parse_ok() {
        for code in ["4.1.23", "4.0.23"] {
            assert_eq!(
                code.parse::<EnhancedStatusCode>(),
                Ok(EnhancedStatusCode::Transient(code.into()))
            );
        }
        // Leading zeros are rejected in the subject and detail.
        assert_eq!(
            "4.01.23".parse::<EnhancedStatusCode>(),
            Err(ParseStatusCodeError)
        );
    }

    #[test]
    fn trusted_networks_loopback_ok() {
        let trusted_networks = TrustedNetworks::default();
        let loopbacks = [IpAddr::from([127, 0, 1, 0]), IpAddr::from(Ipv6Addr::LOCALHOST)];
        for addr in loopbacks {
            assert!(trusted_networks.contains(addr), "{addr} should be trusted");
        }
    }

    #[test]
    fn trusted_networks_subnet_ok() {
        let net = Ipv4Net::new([43, 5, 0, 0].into(), 16).unwrap();
        let trusted_networks = TrustedNetworks {
            networks: HashSet::from([net.into()]),
            ..Default::default()
        };
        assert!(trusted_networks.contains(IpAddr::from([43, 5, 117, 8])));
    }

    #[test]
    fn skip_senders_ok() {
        let skip_senders = SkipSenders::from(HashSet::from([
            skip_entry(None, "example.com", false),
            skip_entry(None, "super.example.com", true),
            skip_entry(Some("from"), "example.org", false),
        ]));

        // (sender, expected `includes` result)
        let cases = [
            ("Example.Com", true),
            ("Me@Example.Com", true),
            ("super.example.com", false),
            ("mx1.super.example.com", true),
            ("me@super.example.com", false),
            ("me@mx1.super.example.com", true),
            ("example.org", false),
            ("mail.example.org", false),
            ("to@example.org", false),
            ("FROM@example.org", true),
        ];
        for (sender, expected) in cases {
            assert_eq!(skip_senders.includes(sender), expected, "sender: {sender}");
        }
    }
}