use crate::{
callbacks::{Callbacks, Status},
config::Config,
connection::Connection,
context::{Context, EomContext, NegotiateContext, ReplyCode},
macros::MacroStage,
message::{
command::{
CommandKind, CommandMessage, ConnInfoPayload, EnvAddrPayload, HeaderPayload,
HeloPayload, MacroPayload, OptNegPayload, ParseCommandError, UnknownPayload,
},
reply::Reply,
Byte, Version, PROTOCOL_VERSION,
},
proto_util::{Actions, ProtoOpts},
};
use bytes::Bytes;
use std::{
cmp,
collections::HashMap,
error::Error,
ffi::CString,
fmt::{self, Display, Formatter},
io,
sync::Arc,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
select,
sync::{watch, OwnedSemaphorePermit},
};
use tracing::{trace, warn};
/// Protocol state of one milter session.
///
/// Incoming commands are mapped to a target state via `CommandKind::as_state`;
/// which moves are legal is defined by `State::has_transition_to`. The
/// declaration order below is significant: `State::all` yields the variants in
/// this order, and `State::can_reach` walks it when skipping stages.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum State {
    /// Initial state before option negotiation.
    Init,
    /// Option negotiation has been processed.
    Opts,
    /// Connection information has been processed.
    Conn,
    /// HELO/EHLO has been processed; also the state a session returns to
    /// between messages.
    Helo,
    /// Envelope sender has been processed (start of a mail transaction).
    Mail,
    /// An envelope recipient has been processed.
    Rcpt,
    /// The DATA command has been processed.
    Data,
    /// A message header has been processed.
    Header,
    /// End of headers has been processed.
    Eoh,
    /// A body chunk has been processed.
    Body,
    /// End of message has been processed.
    Eom,
    /// The session is ending; command processing stops in this state.
    Quit,
    /// The current message was aborted.
    Abort,
    /// An unknown SMTP command has been processed.
    Unknown,
    /// Quit variant after which the transition table allows `Conn` again
    /// (presumably "quit, keep connection for a new session" — confirm).
    QuitNc,
}
impl State {
fn all() -> impl DoubleEndedIterator<Item = Self> {
use State::*;
[
Init, Opts, Conn, Helo, Mail, Rcpt, Data, Header, Eoh, Body, Eom, Quit, Abort, Unknown,
QuitNc,
]
.into_iter()
}
fn is_mail_transaction(&self) -> bool {
use State::*;
matches!(self, Mail | Rcpt | Data | Header | Eoh | Body)
}
fn can_reach(&self, target: Self, opts: ProtoOpts) -> bool {
let can_be_skipped = |s: &Self| match s {
Self::Conn => opts.contains(ProtoOpts::NO_CONNECT),
Self::Helo => opts.contains(ProtoOpts::NO_HELO),
Self::Mail => opts.contains(ProtoOpts::NO_MAIL),
Self::Rcpt => opts.contains(ProtoOpts::NO_RCPT),
Self::Data => opts.contains(ProtoOpts::NO_DATA),
Self::Header => opts.contains(ProtoOpts::NO_HEADER),
Self::Eoh => opts.contains(ProtoOpts::NO_EOH),
Self::Body => opts.contains(ProtoOpts::NO_BODY),
Self::Unknown => opts.contains(ProtoOpts::NO_UNKNOWN),
_ => false,
};
if self.has_transition_to(target) {
return true;
}
self.remaining()
.skip(1)
.take_while(can_be_skipped)
.any(|s| s.has_transition_to(target))
}
fn has_transition_to(&self, next: Self) -> bool {
use State::*;
match self {
Init => matches!(next, Opts),
Opts | QuitNc => matches!(next, Conn | Unknown),
Conn | Helo => matches!(next, Helo | Mail | Unknown),
Mail => matches!(next, Rcpt | Abort | Unknown),
Rcpt => matches!(next, Header | Eoh | Data | Body | Eom | Rcpt | Abort | Unknown),
Data | Header => matches!(next, Eoh | Header | Abort),
Eoh | Body => matches!(next, Body | Eom | Abort),
Eom => matches!(next, Quit | Mail | Unknown | QuitNc),
Quit | Abort => false,
Unknown => matches!(
next,
Helo | Mail | Rcpt | Data | Body | Unknown | Abort | Quit | QuitNc
),
}
}
fn remaining(&self) -> impl Iterator<Item = Self> + '_ {
Self::all().skip_while(move |s| s != self)
}
}
/// Derives the initial protocol options from the set of registered callbacks.
///
/// Each stage whose callback is absent gets its `NO_*` flag set. These options
/// are what the session offers in its option-negotiation reply unless a
/// `negotiate` callback overrides them.
fn opts_from_callbacks<T: Send>(callbacks: &Callbacks<T>) -> ProtoOpts {
    let absent_stages = [
        (ProtoOpts::NO_CONNECT, callbacks.connect.is_none()),
        (ProtoOpts::NO_HELO, callbacks.helo.is_none()),
        (ProtoOpts::NO_MAIL, callbacks.mail.is_none()),
        (ProtoOpts::NO_RCPT, callbacks.rcpt.is_none()),
        (ProtoOpts::NO_DATA, callbacks.data.is_none()),
        (ProtoOpts::NO_HEADER, callbacks.header.is_none()),
        (ProtoOpts::NO_EOH, callbacks.eoh.is_none()),
        (ProtoOpts::NO_BODY, callbacks.body.is_none()),
        (ProtoOpts::NO_UNKNOWN, callbacks.unknown.is_none()),
    ];
    let mut result = ProtoOpts::empty();
    for (flag, is_absent) in absent_stages {
        result.set(flag, is_absent);
    }
    result
}
/// Result type of the session's command-processing methods.
type SessionResult = Result<(), SessionError>;
/// Errors that end command processing for a milter session.
#[derive(Debug)]
enum SessionError {
    /// The shutdown flag was observed as `true` while processing commands.
    MilterShutDown,
    /// Reading from or writing to the MTA connection failed.
    Io(io::Error),
    /// A received message could not be converted into a known command; holds
    /// the offending byte (presumably the command byte — confirm).
    UnknownCommand(u8),
    /// A command that requires a payload arrived with an empty buffer.
    BufferEmpty,
    /// A command payload could not be parsed.
    ParseCommand(ParseCommandError),
    /// The MTA requested a milter protocol version below 2.
    ProtocolVersionNotSupported(Version),
    /// The `negotiate` callback returned a status other than `AllOpts` or
    /// `Continue`.
    InvalidNegotiateStatus(Status),
    /// The negotiated actions include some the MTA did not offer.
    ActionsNotSupported(Actions),
    /// The negotiated protocol options include some the MTA did not offer.
    ProtoOptsNotSupported(ProtoOpts),
}
impl Display for SessionError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::MilterShutDown => write!(f, "milter shut down"),
Self::Io(error) => write!(f, "I/O error: {error}"),
Self::UnknownCommand(byte) => write!(f, "unknown command: {:?}", Byte(*byte)),
Self::BufferEmpty => write!(f, "command with empty payload buffer"),
Self::ParseCommand(error) => write!(f, "command could not be parsed: {error}"),
Self::ProtocolVersionNotSupported(version) => {
write!(f, "requested milter protocol version {version} not supported")
}
Self::InvalidNegotiateStatus(status) => {
write!(f, "invalid status in negotiation: {status:?}")
}
Self::ActionsNotSupported(actions) => {
write!(f, "requested actions not supported: {actions:?}")
}
Self::ProtoOptsNotSupported(opts) => {
write!(f, "requested milter protocol options not supported: {opts:?}")
}
}
}
}
impl Error for SessionError {
    /// Exposes the wrapped I/O or parse error; all other variants have no
    /// underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(inner) => inner,
            Self::ParseCommand(inner) => inner,
            _ => return None,
        };
        Some(cause)
    }
}
impl From<io::Error> for SessionError {
fn from(error: io::Error) -> Self {
Self::Io(error)
}
}
impl From<ParseCommandError> for SessionError {
fn from(error: ParseCommandError) -> Self {
Self::ParseCommand(error)
}
}
/// Spawns a Tokio task that serves one milter session over `stream`.
///
/// The session subscribes to `shutdown_sender` to observe milter shutdown.
/// `permit` is moved into the task and released only once the session has
/// finished processing commands. The session outcome is logged at trace
/// level.
pub fn spawn<S, T>(
    stream: S,
    shutdown_sender: &watch::Sender<bool>,
    callbacks: &Arc<Callbacks<T>>,
    config: &Config,
    permit: OwnedSemaphorePermit,
) where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    T: Send + 'static,
{
    let shared_callbacks = Arc::clone(callbacks);
    let receiver = shutdown_sender.subscribe();
    let session = Session::new(stream, receiver, shared_callbacks, config);
    tokio::spawn(async move {
        // Held for the whole task; dropped (and the permit released) when the
        // task body completes.
        let _permit = permit;
        trace!("session beginning processing commands");
        if let Err(e) = session.process_commands().await {
            trace!("error in session while processing commands: {e}");
        } else {
            trace!("session done processing commands");
        }
    });
}
/// Per-connection milter session: drives the protocol state machine and
/// dispatches commands to the user callbacks.
struct Session<T: Send> {
    /// Connection to the MTA used for reading commands and writing replies.
    conn: Connection,
    /// Current protocol state.
    state: State,
    /// Shutdown flag; processing stops with `MilterShutDown` once it reads `true`.
    shutdown: watch::Receiver<bool>,
    /// User callbacks, shared with other sessions.
    callbacks: Arc<Callbacks<T>>,
    /// Per-session context (user data, macros, custom reply) passed to callbacks.
    context: Context<T>,
    /// Actions in effect; initialized from the config and possibly replaced
    /// during option negotiation.
    actions: Actions,
    /// Protocol options in effect; initialized from the registered callbacks
    /// and possibly replaced during option negotiation.
    opts: ProtoOpts,
}
impl<T: Send> Session<T> {
    /// Creates a new session over `stream`, starting in [`State::Init`].
    ///
    /// Actions are taken from `config`. Protocol options start as the `NO_*`
    /// flags for every unregistered callback (see `opts_from_callbacks`).
    /// Both may be replaced during option negotiation.
    fn new<S>(
        stream: S,
        shutdown: watch::Receiver<bool>,
        callbacks: Arc<Callbacks<T>>,
        config: &Config,
    ) -> Self
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let conn = Connection::new(stream, config.connection_timeout);
        let actions = config.actions;
        let opts = opts_from_callbacks(&callbacks);
        Self {
            conn,
            state: State::Init,
            shutdown,
            callbacks,
            context: Context::new(),
            actions,
            opts,
        }
    }

    /// Processes commands until the session ends, then runs cleanup callbacks.
    ///
    /// If processing failed inside a mail transaction, the `abort` callback is
    /// invoked. The `close` callback is invoked unless the session ended in
    /// [`State::Quit`], where `handle_command` has already invoked it.
    /// Statuses returned by these cleanup callbacks are ignored.
    ///
    /// # Errors
    ///
    /// Returns the error, if any, that ended command processing.
    async fn process_commands(mut self) -> SessionResult {
        let result = self.process_until_done().await;
        if result.is_err() && self.state.is_mail_transaction() {
            if let Some(abort) = &self.callbacks.abort {
                let _ = abort(&mut self.context).await;
            }
        }
        if self.state != State::Quit {
            if let Some(close) = &self.callbacks.close {
                let _ = close(&mut self.context).await;
            }
        }
        result
    }

    /// Reads and dispatches commands until the state machine reaches
    /// [`State::Quit`].
    ///
    /// Each command is checked against the state machine first. If the command
    /// is not reachable from the current state, any open mail transaction is
    /// aborted and the transition is retried from [`State::Helo`]. If it is
    /// still unreachable, the command is ignored (or, for `Quit`, processing
    /// ends).
    ///
    /// # Errors
    ///
    /// Returns [`SessionError::MilterShutDown`] once the shutdown flag is
    /// `true`. Also propagates I/O, unknown-command, empty-payload, parse and
    /// negotiation errors.
    async fn process_until_done(&mut self) -> SessionResult {
        // Cleared once the shutdown sender has been dropped. From then on
        // `changed()` fails immediately on every call, so leaving that branch
        // enabled would spin this loop without ever waiting for input.
        let mut watch_shutdown = true;
        while self.state != State::Quit {
            if *self.shutdown.borrow() {
                return Err(SessionError::MilterShutDown);
            }
            let msg = select! {
                msg = self.conn.read_message() => msg?,
                changed = self.shutdown.changed(), if watch_shutdown => {
                    if changed.is_err() {
                        watch_shutdown = false;
                    }
                    // Re-check the shutdown flag at the top of the loop.
                    continue;
                }
            };
            let msg = CommandMessage::try_from(msg)
                .map_err(|e| SessionError::UnknownCommand(e.byte()))?;
            let cmd = msg.kind;
            trace!(?cmd, "got next command");
            if let Some(next_state) = cmd.as_state() {
                if !self.state.can_reach(next_state, self.opts) {
                    if self.state.is_mail_transaction() {
                        if let Some(abort) = &self.callbacks.abort {
                            let _ = abort(&mut self.context).await;
                        }
                    }
                    // Retry the transition from Helo, the between-messages state.
                    self.state = State::Helo;
                    if !self.state.can_reach(next_state, self.opts) {
                        if next_state == State::Quit {
                            break;
                        } else {
                            trace!("ignoring unexpected command");
                            continue;
                        }
                    }
                }
            }
            ensure_buffer_present(cmd, &msg.buffer)?;
            if let Some(next_state) = cmd.as_state() {
                trace!(state = ?next_state, "transitioned to next state");
                self.state = next_state;
            }
            self.handle_command(msg).await?;
        }
        Ok(())
    }

    /// Runs the callback for one command and writes its reply.
    ///
    /// `OptNeg`, `DefMacros`, `Abort`, `Quit` and `QuitNc` return without
    /// writing a status reply here; option negotiation writes its own reply.
    /// After the reply is written, a status that ends the current message
    /// resets the state to [`State::Helo`].
    ///
    /// # Errors
    ///
    /// Propagates payload parse errors, negotiation errors, and I/O errors
    /// from writing the reply.
    async fn handle_command(&mut self, msg: CommandMessage) -> SessionResult {
        let status = match msg.kind {
            CommandKind::OptNeg => {
                self.handle_opt_neg_command(msg.buffer).await?;
                return Ok(());
            }
            CommandKind::DefMacros => {
                self.handle_def_macros_command(msg.buffer);
                return Ok(());
            }
            CommandKind::ConnInfo => {
                self.context.clear_macros_after(MacroStage::Connect);
                if let Some(connect) = &self.callbacks.connect {
                    let ConnInfoPayload {
                        hostname,
                        socket_info,
                    } = ConnInfoPayload::parse_buffer(msg.buffer)?;
                    connect(&mut self.context, hostname, socket_info).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::Helo => {
                self.context.clear_macros_after(MacroStage::Helo);
                if let Some(helo) = &self.callbacks.helo {
                    let HeloPayload { hostname } = HeloPayload::parse_buffer(msg.buffer)?;
                    helo(&mut self.context, hostname).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::Mail => {
                self.context.clear_macros_after(MacroStage::Mail);
                if let Some(mail) = &self.callbacks.mail {
                    let EnvAddrPayload { args } = EnvAddrPayload::parse_buffer(msg.buffer)?;
                    mail(&mut self.context, args).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::Rcpt => {
                self.context.clear_macros_after(MacroStage::Rcpt);
                if let Some(rcpt) = &self.callbacks.rcpt {
                    let EnvAddrPayload { args } = EnvAddrPayload::parse_buffer(msg.buffer)?;
                    rcpt(&mut self.context, args).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::Data => {
                if let Some(data) = &self.callbacks.data {
                    data(&mut self.context).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::Header => {
                if let Some(header) = &self.callbacks.header {
                    let HeaderPayload { name, value } = HeaderPayload::parse_buffer(msg.buffer)?;
                    header(&mut self.context, name, value).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::Eoh => {
                if let Some(eoh) = &self.callbacks.eoh {
                    eoh(&mut self.context).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::BodyChunk => {
                if let Some(body) = &self.callbacks.body {
                    body(&mut self.context, msg.buffer).await
                } else {
                    Status::Continue
                }
            }
            CommandKind::BodyEnd => self.handle_body_end_command(msg.buffer).await,
            CommandKind::Abort => {
                if let Some(abort) = &self.callbacks.abort {
                    let _ = abort(&mut self.context).await;
                }
                return Ok(());
            }
            CommandKind::Quit | CommandKind::QuitNc => {
                if let Some(close) = &self.callbacks.close {
                    let _ = close(&mut self.context).await;
                }
                self.context.clear_macros();
                return Ok(());
            }
            CommandKind::Unknown => {
                if let Some(unknown) = &self.callbacks.unknown {
                    let UnknownPayload { arg } = UnknownPayload::parse_buffer(msg.buffer)?;
                    unknown(&mut self.context, arg).await
                } else {
                    Status::Continue
                }
            }
        };
        self.write_reply(status).await?;
        // Accept ends the current message. Reject, Discard and Tempfail do too,
        // except in the Rcpt and Unknown states, which are left as they are so
        // the transaction can continue.
        let ends_message = status == Status::Accept
            || (matches!(status, Status::Reject | Status::Discard | Status::Tempfail)
                && !matches!(self.state, State::Rcpt | State::Unknown));
        if ends_message {
            self.state = State::Helo;
        }
        Ok(())
    }

    /// Handles `BodyEnd`: delivers any trailing body data, then runs `eom`.
    ///
    /// If the command carries data and a `body` callback is registered, that
    /// chunk goes to the callback first. A non-`Continue` status from it skips
    /// the `eom` callback and is returned unchanged. Otherwise the `eom`
    /// callback (if any) runs with an [`EomContext`], and the user data and
    /// reply are moved back into the session context afterwards.
    ///
    /// The returned status is written by the caller, exactly once. Writing it
    /// here as well would send two replies for a single command.
    async fn handle_body_end_command(&mut self, buffer: Bytes) -> Status {
        let mut status = Status::Continue;
        if let Some(body) = &self.callbacks.body {
            if !buffer.is_empty() {
                status = body(&mut self.context, buffer).await;
            }
        }
        if status == Status::Continue {
            if let Some(eom) = &self.callbacks.eom {
                let mut cx = EomContext::new(
                    self.conn.clone(),
                    self.context.data.take(),
                    self.context.macros.clone_internal(),
                    self.context.reply.clone_internal(),
                    self.actions,
                );
                status = eom(&mut cx).await;
                self.context.restore(cx.data, cx.reply);
            }
        }
        status
    }

    /// Handles option negotiation (`OptNeg`) and writes the negotiation reply.
    ///
    /// All macros are cleared first. The version is the lower of the MTA's and
    /// [`PROTOCOL_VERSION`]. Empty action/option sets from the MTA are
    /// replaced by their `min_flags()`.
    ///
    /// If a `negotiate` callback is registered, its status decides the
    /// outcome. `AllOpts` adopts the MTA's actions and the default options
    /// (the session options plus the MTA's `SKIP` flag). `Continue` adopts
    /// whatever the callback requested.
    ///
    /// # Errors
    ///
    /// Fails if the MTA's version is below 2, if the callback returns any
    /// other status, or if the chosen actions/options are not all offered by
    /// the MTA.
    async fn handle_opt_neg_command(&mut self, buffer: Bytes) -> SessionResult {
        self.context.clear_macros();
        let OptNegPayload {
            version: mta_version,
            actions: mut mta_actions,
            opts: mut mta_opts,
        } = OptNegPayload::parse_buffer(buffer)?;
        if mta_version < 2 {
            return Err(SessionError::ProtocolVersionNotSupported(mta_version));
        }
        let target_version = cmp::min(mta_version, PROTOCOL_VERSION);
        if mta_actions.is_empty() {
            mta_actions = Actions::min_flags();
        }
        if mta_opts.is_empty() {
            mta_opts = ProtoOpts::min_flags();
        }
        let target_opts;
        let requested_macros;
        if let Some(negotiate) = &self.callbacks.negotiate {
            let default_actions = mta_actions;
            let default_opts = self.opts | (mta_opts & ProtoOpts::SKIP);
            let mut cx = NegotiateContext::new(
                self.context.data.take(),
                self.context.reply.clone_internal(),
                default_actions,
                default_opts,
            );
            let status = negotiate(&mut cx, mta_actions, mta_opts).await;
            self.context.restore(cx.data, cx.reply);
            requested_macros = cx.requested_macros;
            match status {
                Status::AllOpts => {
                    self.actions = mta_actions;
                    target_opts = default_opts;
                }
                Status::Continue => {
                    self.actions = cx.requested_actions;
                    self.opts = cx.requested_opts;
                    target_opts = cx.requested_opts;
                }
                status => {
                    return Err(SessionError::InvalidNegotiateStatus(status));
                }
            }
        } else {
            target_opts = self.opts;
            requested_macros = Default::default();
        }
        if !mta_actions.contains(self.actions) {
            return Err(SessionError::ActionsNotSupported(self.actions));
        }
        if !mta_opts.contains(target_opts) {
            return Err(SessionError::ProtoOptsNotSupported(target_opts));
        }
        self.write_opt_neg_reply(target_version, self.actions, target_opts, requested_macros)
            .await?;
        Ok(())
    }

    /// Handles `DefMacros` by storing the macro definitions for their stage.
    ///
    /// A payload that fails to parse is logged at trace level and ignored.
    /// Names and values alternate in the payload. A trailing name without a
    /// value is dropped, and the first definition of a duplicated name wins.
    fn handle_def_macros_command(&mut self, buffer: Bytes) {
        let MacroPayload { stage, macros } = match MacroPayload::parse_buffer(buffer) {
            Ok(payload) => payload,
            Err(e) => {
                trace!("skipping unrecognized macro command: {e}");
                return;
            }
        };
        let mut entries = HashMap::new();
        let mut macros = macros.into_iter();
        while let (Some(k), Some(v)) = (macros.next(), macros.next()) {
            entries.entry(k).or_insert(v);
        }
        trace!(?stage, ?entries, "registered new macro definitions");
        self.context.insert_macros(stage, entries);
    }

    /// Writes the `OptNeg` reply with the agreed version, actions, options and
    /// the macros requested per stage.
    async fn write_opt_neg_reply(
        &mut self,
        version: Version,
        requested_actions: Actions,
        requested_opts: ProtoOpts,
        requested_macros: HashMap<MacroStage, CString>,
    ) -> io::Result<()> {
        let reply = Reply::OptNeg {
            version,
            actions: requested_actions,
            opts: requested_opts,
            macros: requested_macros,
        };
        self.conn.write_reply(reply).await
    }

    /// Writes the reply corresponding to `status` in the current state.
    ///
    /// If the session options mark the current state as `NOREPLY_*`, any
    /// status other than `Noreply` is logged as a warning and replaced by
    /// `Noreply`.
    ///
    /// `Reject` uses the context's custom reply if its code is
    /// `ReplyCode::Permanent`; `Tempfail` does the same for
    /// `ReplyCode::Transient`. Otherwise the plain reply is sent.
    ///
    /// `Noreply` and statuses without a reply mapping write nothing.
    async fn write_reply(&mut self, mut status: Status) -> io::Result<()> {
        fn needs_noreply(opts: ProtoOpts, state: State) -> bool {
            match state {
                State::Conn => opts.contains(ProtoOpts::NOREPLY_CONNECT),
                State::Helo => opts.contains(ProtoOpts::NOREPLY_HELO),
                State::Mail => opts.contains(ProtoOpts::NOREPLY_MAIL),
                State::Rcpt => opts.contains(ProtoOpts::NOREPLY_RCPT),
                State::Data => opts.contains(ProtoOpts::NOREPLY_DATA),
                State::Header => opts.contains(ProtoOpts::NOREPLY_HEADER),
                State::Eoh => opts.contains(ProtoOpts::NOREPLY_EOH),
                State::Body => opts.contains(ProtoOpts::NOREPLY_BODY),
                State::Unknown => opts.contains(ProtoOpts::NOREPLY_UNKNOWN),
                _ => false,
            }
        }
        if needs_noreply(self.opts, self.state) && status != Status::Noreply {
            warn!("status response Noreply requested but not used");
            status = Status::Noreply;
        }
        let reply = match status {
            Status::Accept => Reply::Accept,
            Status::Continue => Reply::Continue,
            Status::Reject => {
                self.context
                    .take_reply_if(|r| matches!(r, ReplyCode::Permanent(_)))
                    .unwrap_or(Reply::Reject)
            }
            Status::Tempfail => {
                self.context
                    .take_reply_if(|r| matches!(r, ReplyCode::Transient(_)))
                    .unwrap_or(Reply::Tempfail)
            }
            Status::Discard => Reply::Discard,
            Status::Skip => Reply::Skip,
            Status::Noreply => {
                return Ok(());
            }
            // Remaining statuses (e.g. the negotiation-only `AllOpts`) have no
            // reply on the wire.
            _ => return Ok(()),
        };
        self.conn.write_reply(reply).await
    }
}
/// Rejects commands that require a payload but arrived with an empty buffer.
///
/// # Errors
///
/// Returns [`SessionError::BufferEmpty`] for such commands; every other
/// command passes regardless of its buffer.
fn ensure_buffer_present(cmd: CommandKind, buf: &[u8]) -> SessionResult {
    use CommandKind::*;
    let requires_payload = matches!(
        cmd,
        DefMacros | BodyChunk | ConnInfo | Helo | Header | Mail | OptNeg | Rcpt | Unknown
    );
    if buf.is_empty() && requires_payload {
        Err(SessionError::BufferEmpty)
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reachability with every optional stage except connect and recipient
    /// flagged as skipped, so the walk may pass over Helo, Mail and the later
    /// message stages but must stop at Conn and Rcpt.
    #[test]
    fn can_reach_ok() {
        let opts = [
            ProtoOpts::NO_HELO,
            ProtoOpts::NO_MAIL,
            ProtoOpts::NO_DATA,
            ProtoOpts::NO_HEADER,
            ProtoOpts::NO_EOH,
            ProtoOpts::NO_BODY,
            ProtoOpts::NO_UNKNOWN,
        ]
        .into_iter()
        .fold(ProtoOpts::empty(), |acc, flag| acc | flag);

        let cases = [
            (State::Init, State::Opts, true),
            (State::Opts, State::Conn, true),
            (State::Conn, State::Conn, false),
            (State::Opts, State::Helo, false),
            (State::Conn, State::Helo, true),
            (State::Conn, State::Mail, true),
            (State::Conn, State::Rcpt, true),
            (State::Conn, State::Data, false),
        ];
        for (from, to, expected) in cases {
            assert_eq!(from.can_reach(to, opts), expected, "{from:?} -> {to:?}");
        }
    }
}