pub use indymilter::{message::Version, Actions, IntoCString, MacroStage, ProtoOpts, SocketInfo};
use bytes::{Bytes, BytesMut};
use indymilter::{
message::{
self,
command::{
Command, ConnInfoPayload, EnvAddrPayload, HeaderPayload, HeloPayload, MacroPayload,
OptNegPayload, UnknownPayload,
},
reply::Reply,
PROTOCOL_VERSION,
},
EitherStream,
};
use std::{
cmp,
collections::HashMap,
error::Error,
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::Duration,
};
use tokio::{
io::{AsyncWriteExt, BufStream},
net::{TcpStream, ToSocketAddrs},
task, time,
};
/// Result type returned by every fallible operation of the test harness.
pub type TestResult<T> = Result<T, TestError>;

/// Errors produced while driving a milter through a test connection.
#[derive(Debug)]
pub enum TestError {
    /// The caller used the connection in a way the negotiated protocol does
    /// not permit (for example, sending a stage the milter disabled).
    MilterUsage,
    /// The milter answered with a reply that could not be parsed or that is
    /// not acceptable at this point of the conversation.
    InvalidReply,
    /// An underlying I/O operation failed; timeouts are reported this way too.
    Io(io::Error),
}

impl Display for TestError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::MilterUsage => f.write_str("misuse of the MTA/milter interface"),
            Self::InvalidReply => f.write_str("milter sent invalid or unexpected reply"),
            Self::Io(inner) => write!(f, "I/O error: {}", inner),
        }
    }
}

impl Error for TestError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Only the I/O variant wraps a lower-level cause.
        match self {
            Self::Io(inner) => Some(inner),
            Self::MilterUsage => None,
            Self::InvalidReply => None,
        }
    }
}

impl From<io::Error> for TestError {
    fn from(error: io::Error) -> Self {
        TestError::Io(error)
    }
}
// Transport to the milter under test. The first variant is a buffered TCP
// stream; on Unix the second variant is a buffered Unix-domain socket.
#[cfg(unix)]
type Stream = EitherStream<
    BufStream<TcpStream>,
    BufStream<tokio::net::UnixStream>,
>;
// Off Unix there is no Unix-domain transport (`open_unix` is Unix-only), so
// the second variant is presumably just a placeholder that is never built.
#[cfg(not(unix))]
type Stream = EitherStream<
    BufStream<TcpStream>,
    TcpStream,
>;
/// Settings used when opening a test connection to a milter.
///
/// Create one with `TestConnection::configure`, adjust it with the setter
/// methods, then call `open_tcp` (or `open_unix` on Unix).
#[derive(Clone, Debug)]
pub struct TestConnectionBuilder {
    /// Upper bound on waiting for each reply read from the milter.
    read_timeout: Duration,
    /// Upper bound on each command write; it also bounds connecting.
    write_timeout: Duration,
    /// Protocol version offered in the option negotiation command.
    version: Version,
    /// Actions offered in the option negotiation command.
    actions: Actions,
    /// Protocol options offered in the option negotiation command.
    opts: ProtoOpts,
    /// If set, header values are never given an extra leading space.
    exact_leading_space: bool,
}
impl TestConnectionBuilder {
    /// Sets the maximum time to wait for each reply from the milter.
    #[must_use]
    pub fn read_timeout(mut self, value: Duration) -> Self {
        self.read_timeout = value;
        self
    }

    /// Sets the maximum time allowed for sending each command.
    ///
    /// The same timeout also bounds establishing the connection in
    /// `open_tcp` / `open_unix`.
    #[must_use]
    pub fn write_timeout(mut self, value: Duration) -> Self {
        self.write_timeout = value;
        self
    }

    /// Sets the protocol version offered during option negotiation.
    #[must_use]
    pub fn protocol_version(mut self, value: Version) -> Self {
        self.version = value;
        self
    }

    /// Sets the actions offered to the milter during option negotiation.
    #[must_use]
    pub fn available_actions(mut self, value: Actions) -> Self {
        self.actions = value;
        self
    }

    /// Sets the protocol options offered to the milter during option
    /// negotiation.
    #[must_use]
    pub fn available_opts(mut self, value: ProtoOpts) -> Self {
        self.opts = value;
        self
    }

    /// Controls how header values are sent.
    ///
    /// When `false` (the default), a single space is prepended to every
    /// header value if the milter negotiated `LEADING_SPACE`. When `true`,
    /// values are always sent exactly as given.
    #[must_use]
    pub fn exact_leading_space(mut self, value: bool) -> Self {
        self.exact_leading_space = value;
        self
    }

    /// Connects to the milter at `addr` over TCP and performs option
    /// negotiation.
    ///
    /// Connecting is bounded by the *write* timeout.
    ///
    /// # Errors
    ///
    /// Returns [`TestError::Io`] if connecting, sending or receiving fails or
    /// times out. Returns [`TestError::InvalidReply`] if the milter's answer
    /// cannot be parsed or is not an option negotiation reply.
    pub async fn open_tcp(self, addr: impl ToSocketAddrs) -> TestResult<TestConnection> {
        let stream = time::timeout(self.write_timeout, TcpStream::connect(addr))
            .await
            .map_err(io::Error::from)??;
        self.open(Stream::Tcp(BufStream::new(stream))).await
    }

    /// Connects to the milter listening on the Unix-domain socket at `addr`
    /// and performs option negotiation.
    ///
    /// Connecting is bounded by the *write* timeout.
    ///
    /// # Errors
    ///
    /// Same as [`open_tcp`](Self::open_tcp).
    #[cfg(unix)]
    pub async fn open_unix(self, addr: impl AsRef<std::path::Path>) -> TestResult<TestConnection> {
        let stream = time::timeout(self.write_timeout, tokio::net::UnixStream::connect(addr))
            .await
            .map_err(io::Error::from)??;
        self.open(Stream::Unix(BufStream::new(stream))).await
    }

    /// Sends the option negotiation command over `stream`. It then records
    /// the actions, options and macro requests from the milter's reply on the
    /// returned connection.
    async fn open(self, stream: Stream) -> TestResult<TestConnection> {
        let mut conn = TestConnection {
            stream,
            read_timeout: self.read_timeout,
            write_timeout: self.write_timeout,
            actions: Default::default(),
            opts: Default::default(),
            macros: Default::default(),
            exact_leading_space: self.exact_leading_space,
            skip_seen: false,
        };

        conn.write_command(Command::OptNeg(OptNegPayload {
            version: self.version,
            actions: self.actions,
            opts: self.opts,
        }))
        .await?;

        match conn.read_reply().await? {
            Reply::OptNeg { actions, opts, macros, .. } => {
                conn.actions = actions;
                conn.opts = opts;
                conn.macros = macros;
                Ok(conn)
            }
            _ => Err(TestError::InvalidReply),
        }
    }
}
/// An open, negotiated connection to a milter, driven from the MTA side.
#[derive(Debug)]
pub struct TestConnection {
    /// Transport to the milter.
    stream: Stream,
    /// Timeout applied to each reply read.
    read_timeout: Duration,
    /// Timeout applied to each command write.
    write_timeout: Duration,
    /// Actions taken from the milter's option negotiation reply.
    actions: Actions,
    /// Protocol options taken from the milter's option negotiation reply.
    opts: ProtoOpts,
    /// Per-stage macro requests taken from the option negotiation reply.
    macros: HashMap<MacroStage, CString>,
    /// If set, header values are sent verbatim even under `LEADING_SPACE`.
    exact_leading_space: bool,
    /// Set when the milter answered a body chunk with `Skip`. The next body
    /// chunk is then refused as misuse.
    skip_seen: bool,
}
impl TestConnection {
pub fn configure() -> TestConnectionBuilder {
TestConnectionBuilder {
read_timeout: Duration::from_secs(30),
write_timeout: Duration::from_secs(30),
version: PROTOCOL_VERSION,
actions: Actions::all(),
opts: ProtoOpts::all(),
exact_leading_space: false,
}
}
pub async fn open(addr: impl ToSocketAddrs) -> TestResult<Self> {
Self::configure().open_tcp(addr).await
}
pub fn negotiated_actions(&self) -> Actions {
self.actions
}
pub fn negotiated_opts(&self) -> ProtoOpts {
self.opts
}
pub fn requested_macros(&self) -> &HashMap<MacroStage, CString> {
&self.macros
}
pub async fn connect(
&mut self,
hostname: impl IntoCString,
socket_info: impl IntoSocketInfo,
) -> TestResult<Status> {
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_CONNECT) {
return Err(TestError::MilterUsage);
}
let hostname = hostname.into_c_string();
let socket_info = socket_info.into_socket_info();
self.write_command(Command::ConnInfo(ConnInfoPayload {
hostname,
socket_info,
}))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_CONNECT) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn helo(&mut self, hostname: impl IntoCString) -> TestResult<Status> {
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_HELO) {
return Err(TestError::MilterUsage);
}
let hostname = hostname.into_c_string();
self.write_command(Command::Helo(HeloPayload { hostname }))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_HELO) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn mail<I, T>(&mut self, args: I) -> TestResult<Status>
where
I: IntoIterator<Item = T>,
T: IntoCString,
{
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_MAIL) {
return Err(TestError::MilterUsage);
}
let args: Vec<_> = args.into_iter().map(|x| x.into_c_string()).collect();
if args.is_empty() {
return Err(TestError::MilterUsage);
}
self.write_command(Command::Mail(EnvAddrPayload { args }))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_MAIL) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn rcpt<I, T>(&mut self, args: I) -> TestResult<Status>
where
I: IntoIterator<Item = T>,
T: IntoCString,
{
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_RCPT) {
return Err(TestError::MilterUsage);
}
let args: Vec<_> = args.into_iter().map(|x| x.into_c_string()).collect();
if args.is_empty() {
return Err(TestError::MilterUsage);
}
self.write_command(Command::Rcpt(EnvAddrPayload { args }))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_RCPT) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn data(&mut self) -> TestResult<Status> {
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_DATA) {
return Err(TestError::MilterUsage);
}
self.write_command(Command::Data).await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_DATA) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn header(
&mut self,
name: impl IntoCString,
value: impl IntoCString,
) -> TestResult<Status> {
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_HEADER) {
return Err(TestError::MilterUsage);
}
let name = name.into_c_string();
if name.as_bytes().is_empty() {
return Err(TestError::MilterUsage);
}
let mut value = value.into_c_string();
if !self.exact_leading_space && self.opts.contains(ProtoOpts::LEADING_SPACE) {
let mut bytes = value.into_bytes_with_nul();
bytes.insert(0, b' ');
value = CString::from_vec_with_nul(bytes).unwrap();
}
self.write_command(Command::Header(HeaderPayload { name, value }))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_HEADER) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn eoh(&mut self) -> TestResult<Status> {
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_EOH) {
return Err(TestError::MilterUsage);
}
self.write_command(Command::Eoh).await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_EOH) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn body(&mut self, bytes: impl Into<Bytes>) -> TestResult<Status> {
if self.skip_seen {
self.skip_seen = false;
return Err(TestError::MilterUsage);
}
if self.opts.contains(ProtoOpts::NO_BODY) {
return Err(TestError::MilterUsage);
}
let mut bytes = bytes.into();
if bytes.is_empty() {
self.write_command(Command::BodyChunk(Default::default()))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_BODY) {
Status::Noreply
} else {
let status = self.read_status().await?;
if status == Status::Skip {
self.skip_seen = true;
}
status
};
return Ok(status);
}
const CHUNK_SIZE: usize = 65535;
let mut status = Status::Continue;
loop {
let next_chunk_len = cmp::min(CHUNK_SIZE, bytes.len());
if next_chunk_len == 0 {
break;
}
let chunk = bytes.split_to(next_chunk_len);
self.write_command(Command::BodyChunk(chunk)).await?;
if self.opts.contains(ProtoOpts::NOREPLY_BODY) {
status = Status::Noreply;
} else {
status = self.read_status().await?;
if status == Status::Skip {
self.skip_seen = true;
}
if status != Status::Continue {
break;
}
}
}
Ok(status)
}
pub async fn eom(&mut self) -> TestResult<(EomActions, Status)> {
self.skip_seen = false;
self.write_command(Command::BodyEnd(Default::default()))
.await?;
let mut actions = vec![];
let final_reply = loop {
let reply = self.read_reply().await?;
match reply {
Reply::Accept
| Reply::Continue
| Reply::Discard
| Reply::Reject
| Reply::Skip
| Reply::Tempfail
| Reply::ReplyCode { .. } => {
break reply;
}
reply => {
let action = EomAction::from_reply(reply)?;
actions.push(action);
}
}
};
let status = self.make_into_status(final_reply)?;
let actions = EomActions { actions };
Ok((actions, status))
}
pub async fn abort(&mut self) -> TestResult<()> {
self.skip_seen = false;
self.write_command(Command::Abort).await?;
Ok(())
}
pub async fn unknown(&mut self, arg: impl IntoCString) -> TestResult<Status> {
self.skip_seen = false;
if self.opts.contains(ProtoOpts::NO_UNKNOWN) {
return Err(TestError::MilterUsage);
}
let arg = arg.into_c_string();
self.write_command(Command::Unknown(UnknownPayload { arg }))
.await?;
let status = if self.opts.contains(ProtoOpts::NOREPLY_UNKNOWN) {
Status::Noreply
} else {
self.read_status().await?
};
Ok(status)
}
pub async fn close(mut self) -> TestResult<()> {
self.write_command(Command::Quit).await?;
self.stream.shutdown().await?;
for _ in 0..11 {
task::yield_now().await;
}
Ok(())
}
pub async fn macros<I, K, V>(&mut self, stage: MacroStage, macros: I) -> TestResult<()>
where
I: IntoIterator<Item = (K, V)>,
K: IntoCString,
V: IntoCString,
{
self.skip_seen = false;
let macros: Vec<_> = macros
.into_iter()
.flat_map(|(k, v)| [k.into_c_string(), v.into_c_string()])
.collect();
if macros.is_empty() {
return Err(TestError::MilterUsage);
}
self.write_command(Command::DefMacros(MacroPayload { stage, macros }))
.await?;
Ok(())
}
async fn write_command(&mut self, cmd: Command) -> io::Result<()> {
let msg = cmd.into_message();
let f = message::write(&mut self.stream, msg);
time::timeout(self.write_timeout, f).await??;
Ok(())
}
async fn read_status(&mut self) -> TestResult<Status> {
let reply = self.read_reply().await?;
self.make_into_status(reply)
}
fn make_into_status(&self, reply: Reply) -> TestResult<Status> {
let status = Status::from_reply(reply)?;
if status == Status::Skip && !self.opts.contains(ProtoOpts::SKIP) {
return Err(TestError::InvalidReply);
}
Ok(status)
}
async fn read_reply(&mut self) -> TestResult<Reply> {
let f = message::read(&mut self.stream);
let msg = time::timeout(self.read_timeout, f)
.await
.map_err(io::Error::from)??;
let reply = Reply::parse_reply(msg).map_err(|_| TestError::InvalidReply)?;
Ok(reply)
}
}
/// Conversion into the `SocketInfo` sent with the connect stage.
///
/// Implemented for:
/// - `SocketInfo` itself and `SocketAddr`;
/// - `(ip, port)` pairs;
/// - bare IP addresses and their array forms (sent with port 0);
/// - `CString`, `&CStr` and `&str` (converted via `Into<SocketInfo>` on a
///   `CString`).
pub trait IntoSocketInfo {
    /// Performs the conversion.
    fn into_socket_info(self) -> SocketInfo;
}

impl IntoSocketInfo for SocketInfo {
    fn into_socket_info(self) -> SocketInfo {
        self
    }
}

impl IntoSocketInfo for SocketAddr {
    fn into_socket_info(self) -> SocketInfo {
        self.into()
    }
}

impl<I: Into<IpAddr>> IntoSocketInfo for (I, u16) {
    fn into_socket_info(self) -> SocketInfo {
        let (ip, port) = self;
        SocketAddr::new(ip.into(), port).into()
    }
}

impl IntoSocketInfo for [u8; 4] {
    fn into_socket_info(self) -> SocketInfo {
        SocketAddr::new(IpAddr::from(self), 0).into()
    }
}

impl IntoSocketInfo for [u16; 8] {
    fn into_socket_info(self) -> SocketInfo {
        SocketAddr::new(IpAddr::from(self), 0).into()
    }
}

impl IntoSocketInfo for [u8; 16] {
    fn into_socket_info(self) -> SocketInfo {
        SocketAddr::new(IpAddr::from(self), 0).into()
    }
}

impl IntoSocketInfo for Ipv4Addr {
    fn into_socket_info(self) -> SocketInfo {
        SocketAddr::new(IpAddr::from(self), 0).into()
    }
}

impl IntoSocketInfo for Ipv6Addr {
    fn into_socket_info(self) -> SocketInfo {
        SocketAddr::new(IpAddr::from(self), 0).into()
    }
}

impl IntoSocketInfo for IpAddr {
    fn into_socket_info(self) -> SocketInfo {
        SocketAddr::new(self, 0).into()
    }
}

impl IntoSocketInfo for CString {
    fn into_socket_info(self) -> SocketInfo {
        self.into()
    }
}

impl IntoSocketInfo for &CStr {
    fn into_socket_info(self) -> SocketInfo {
        let owned = self.into_c_string();
        owned.into()
    }
}

impl IntoSocketInfo for &str {
    fn into_socket_info(self) -> SocketInfo {
        let owned = self.into_c_string();
        owned.into()
    }
}
/// Status reported by the milter for a stage.
///
/// `Noreply` is used when no status was read, because the stage was
/// negotiated as reply-less.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Status {
    Accept,
    Continue,
    Discard,
    Skip,
    /// Rejection; `message` carries the custom reply-code text, if one was
    /// sent.
    Reject { message: Option<CString> },
    /// Temporary failure; `message` carries the custom reply-code text, if
    /// one was sent.
    Tempfail { message: Option<CString> },
    /// No reply was read because a `NOREPLY_*` option was negotiated.
    Noreply,
}

impl Status {
    /// Maps a status-type reply to a [`Status`].
    ///
    /// A custom reply code is classified by its first byte: `4` gives
    /// `Tempfail`, `5` gives `Reject`. Any other code, and any non-status
    /// reply, yields [`TestError::InvalidReply`].
    fn from_reply(reply: Reply) -> Result<Self, TestError> {
        let status = match reply {
            Reply::Accept => Self::Accept,
            Reply::Continue => Self::Continue,
            Reply::Discard => Self::Discard,
            Reply::Skip => Self::Skip,
            Reply::Reject => Self::Reject { message: None },
            Reply::Tempfail => Self::Tempfail { message: None },
            Reply::ReplyCode { reply } => {
                let leading = reply.as_bytes().first().copied();
                match leading {
                    Some(b'4') => Self::Tempfail { message: Some(reply) },
                    Some(b'5') => Self::Reject { message: Some(reply) },
                    _ => return Err(TestError::InvalidReply),
                }
            }
            _ => return Err(TestError::InvalidReply),
        };
        Ok(status)
    }
}
/// A modification requested by the milter during end-of-message processing.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum EomAction {
    AddHeader {
        name: CString,
        value: CString,
    },
    InsertHeader {
        index: i32,
        name: CString,
        value: CString,
    },
    ChangeHeader {
        name: CString,
        index: i32,
        value: CString,
    },
    /// Produced from a change-header reply whose value is empty.
    DeleteHeader {
        name: CString,
        index: i32,
    },
    ChangeSender {
        mail: CString,
        args: Option<CString>,
    },
    /// Produced from both the plain and the extended add-recipient replies;
    /// `args` is `None` for the plain form.
    AddRecipient {
        rcpt: CString,
        args: Option<CString>,
    },
    DeleteRecipient {
        rcpt: CString,
    },
    /// One chunk of replacement body.
    ReplaceBody {
        chunk: Bytes,
    },
    Progress,
    Quarantine {
        reason: CString,
    },
}

impl EomAction {
    /// Converts a modification reply into an [`EomAction`].
    ///
    /// Replies that are not modification actions yield
    /// [`TestError::InvalidReply`].
    fn from_reply(reply: Reply) -> Result<Self, TestError> {
        let action = match reply {
            Reply::AddHeader { name, value } => Self::AddHeader { name, value },
            Reply::InsertHeader { index, name, value } => {
                Self::InsertHeader { index, name, value }
            }
            // An empty replacement value means "delete this header".
            Reply::ChangeHeader { name, index, value } if value.as_bytes().is_empty() => {
                Self::DeleteHeader { name, index }
            }
            Reply::ChangeHeader { name, index, value } => {
                Self::ChangeHeader { name, index, value }
            }
            Reply::ChangeSender { mail, args } => Self::ChangeSender { mail, args },
            Reply::AddRcpt { rcpt } => Self::AddRecipient { rcpt, args: None },
            Reply::AddRcptExt { rcpt, args } => Self::AddRecipient { rcpt, args },
            Reply::DeleteRcpt { rcpt } => Self::DeleteRecipient { rcpt },
            Reply::ReplaceBody { chunk } => Self::ReplaceBody { chunk },
            Reply::Progress => Self::Progress,
            Reply::Quarantine { reason } => Self::Quarantine { reason },
            _ => return Err(TestError::InvalidReply),
        };
        Ok(action)
    }
}
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EomActions {
pub actions: Vec<EomAction>,
}
impl EomActions {
pub fn has_add_header<M1, M2>(&self, mname: M1, mvalue: M2) -> bool
where
M1: for<'a> Matcher<&'a CStr>,
M2: for<'a> Matcher<&'a CStr>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::AddHeader { name, value }
if mname.matches(name) && mvalue.matches(value))
})
}
pub fn has_insert_header<M1, M2, M3>(&self, mindex: M1, mname: M2, mvalue: M3) -> bool
where
M1: Matcher<i32>,
M2: for<'a> Matcher<&'a CStr>,
M3: for<'a> Matcher<&'a CStr>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::InsertHeader { index, name, value }
if mindex.matches(*index) && mname.matches(name) && mvalue.matches(value))
})
}
pub fn has_change_header<M1, M2, M3>(&self, mname: M1, mindex: M2, mvalue: M3) -> bool
where
M1: for<'a> Matcher<&'a CStr>,
M2: Matcher<i32>,
M3: for<'a> Matcher<&'a CStr>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::ChangeHeader { name, index, value }
if mname.matches(name) && mindex.matches(*index) && mvalue.matches(value))
})
}
pub fn has_delete_header<M1, M2>(&self, mname: M1, mindex: M2) -> bool
where
M1: for<'a> Matcher<&'a CStr>,
M2: Matcher<i32>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::DeleteHeader { name, index }
if mname.matches(name) && mindex.matches(*index))
})
}
pub fn has_change_sender<M1, M2>(&self, mmail: M1, margs: M2) -> bool
where
M1: for<'a> Matcher<&'a CStr>,
M2: for<'a> Matcher<Option<&'a CStr>>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::ChangeSender { mail, args }
if mmail.matches(mail) && margs.matches(args.as_deref()))
})
}
pub fn has_add_recipient<M1, M2>(&self, mrcpt: M1, margs: M2) -> bool
where
M1: for<'a> Matcher<&'a CStr>,
M2: for<'a> Matcher<Option<&'a CStr>>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::AddRecipient { rcpt, args }
if mrcpt.matches(rcpt) && margs.matches(args.as_deref()))
})
}
pub fn has_delete_recipient<M>(&self, mrcpt: M) -> bool
where
M: for<'a> Matcher<&'a CStr>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::DeleteRecipient { rcpt } if mrcpt.matches(rcpt))
})
}
pub fn has_replaced_body<M>(&self, body: M) -> bool
where
M: for<'a> Matcher<&'a [u8]>,
{
if !self.actions.iter().any(|a| matches!(a, EomAction::ReplaceBody { .. })) {
return false;
}
let mut replaced_body = BytesMut::new();
for a in &self.actions {
if let EomAction::ReplaceBody { chunk } = a {
replaced_body.extend(chunk);
}
}
body.matches(&replaced_body)
}
pub fn has_quarantine<M>(&self, mreason: M) -> bool
where
M: for<'a> Matcher<&'a CStr>,
{
self.actions.iter().any(|a| {
matches!(a, EomAction::Quarantine { reason } if mreason.matches(reason))
})
}
}
/// A predicate that tests a single field of an end-of-message action.
///
/// Implemented for:
/// - exact values: `i32`, `&str`, byte slices and byte arrays;
/// - `Fn(i32) -> bool` closures;
/// - `Option<&str>`, where `None` matches only an absent value;
/// - regular expressions, with the `regex` feature;
/// - [`AnyMatcher`].
pub trait Matcher<T> {
    /// Returns `true` if `value` satisfies this matcher.
    fn matches(&self, value: T) -> bool;
}

// An integer matches exactly itself.
impl Matcher<i32> for i32 {
    fn matches(&self, value: i32) -> bool {
        value == *self
    }
}

// Any `Fn(i32) -> bool` closure acts as an index predicate.
impl<F> Matcher<i32> for F
where
    F: Fn(i32) -> bool,
{
    fn matches(&self, value: i32) -> bool {
        self(value)
    }
}

// A string matches a C string with exactly the same bytes (NUL excluded).
impl<'a, 'b> Matcher<&'b CStr> for &'a str {
    fn matches(&self, value: &'b CStr) -> bool {
        value.to_bytes() == self.as_bytes()
    }
}

// A string never matches an absent value.
impl<'a, 'b> Matcher<Option<&'b CStr>> for &'a str {
    fn matches(&self, value: Option<&'b CStr>) -> bool {
        match value {
            Some(present) => Matcher::<&CStr>::matches(self, present),
            None => false,
        }
    }
}

// `None` matches only an absent value; `Some(s)` matches like `s`.
impl<'a, 'b> Matcher<Option<&'b CStr>> for Option<&'a str> {
    fn matches(&self, value: Option<&'b CStr>) -> bool {
        match (self.as_ref(), value) {
            (Some(expected), Some(actual)) => Matcher::<&CStr>::matches(expected, actual),
            (None, None) => true,
            (Some(_), None) | (None, Some(_)) => false,
        }
    }
}

impl<'a, 'b> Matcher<&'b [u8]> for &'a [u8] {
    fn matches(&self, value: &'b [u8]) -> bool {
        value == *self
    }
}

impl<'a, 'b, const N: usize> Matcher<&'b [u8]> for &'a [u8; N] {
    fn matches(&self, value: &'b [u8]) -> bool {
        value == &self[..]
    }
}

impl<'a, 'b> Matcher<&'b [u8]> for &'a str {
    fn matches(&self, value: &'b [u8]) -> bool {
        value == self.as_bytes()
    }
}

// A value that is not valid UTF-8 is treated as not matching.
#[cfg(feature = "regex")]
impl<'a, 'b> Matcher<&'b CStr> for &'a regex::Regex {
    fn matches(&self, value: &'b CStr) -> bool {
        matches!(value.to_str(), Ok(text) if self.is_match(text))
    }
}

#[cfg(feature = "regex")]
impl<'a, 'b> Matcher<Option<&'b CStr>> for &'a regex::Regex {
    fn matches(&self, value: Option<&'b CStr>) -> bool {
        match value {
            Some(present) => Matcher::<&CStr>::matches(self, present),
            None => false,
        }
    }
}

#[cfg(feature = "regex")]
impl<'a, 'b> Matcher<&'b CStr> for &'a regex::bytes::Regex {
    fn matches(&self, value: &'b CStr) -> bool {
        let bytes = value.to_bytes();
        self.is_match(bytes)
    }
}

#[cfg(feature = "regex")]
impl<'a, 'b> Matcher<Option<&'b CStr>> for &'a regex::bytes::Regex {
    fn matches(&self, value: Option<&'b CStr>) -> bool {
        match value {
            Some(present) => Matcher::<&CStr>::matches(self, present),
            None => false,
        }
    }
}

#[cfg(feature = "regex")]
impl<'a, 'b> Matcher<&'b [u8]> for &'a regex::bytes::Regex {
    fn matches(&self, value: &'b [u8]) -> bool {
        self.is_match(value)
    }
}

/// A matcher that accepts every value; obtain one with [`any`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct AnyMatcher;

impl<T> Matcher<T> for AnyMatcher {
    fn matches(&self, _value: T) -> bool {
        true
    }
}

/// Returns a matcher that accepts any value, for fields a query ignores.
pub fn any() -> AnyMatcher {
    AnyMatcher
}