use crate::{
connection::Connection,
ffi_util::IntoCString,
macros::{MacroStage, Macros},
message::reply::Reply,
proto_util::{Actions, ProtoOpts},
};
use async_trait::async_trait;
use bytes::Bytes;
use std::{
collections::HashMap,
error::Error,
ffi::CString,
fmt::{self, Display, Formatter},
io::{self, Write},
str::FromStr,
};
/// Types that can store an SMTP error reply.
pub trait SetErrorReply {
    /// Sets the SMTP error reply.
    ///
    /// The reply is made up of the reply code `rcode`, an optional enhanced status code
    /// `xcode`, and the lines of reply text yielded by `message`.
    ///
    /// # Errors
    ///
    /// Returns an [`SmtpReplyError`] when one of the arguments is rejected. The exact
    /// validation rules are defined by the implementation.
    fn set_error_reply<I, T>(
        &mut self,
        rcode: &str,
        xcode: Option<&str>,
        message: I,
    ) -> Result<(), SmtpReplyError>
    where
        I: IntoIterator<Item = T>,
        T: IntoCString;
}
/// Reasons why an SMTP error reply could not be set.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum SmtpReplyError {
    /// The SMTP reply code was rejected.
    InvalidReplyCode,
    /// The enhanced status code was rejected.
    InvalidEnhancedStatusCode,
    /// The reply text was rejected.
    InvalidReplyText,
}

impl Display for SmtpReplyError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::InvalidReplyCode => "invalid SMTP reply code",
            Self::InvalidEnhancedStatusCode => "invalid enhanced status code",
            Self::InvalidReplyText => "invalid SMTP reply text",
        };
        f.write_str(description)
    }
}

impl Error for SmtpReplyError {}
/// Actions that can be performed on the current message through a context.
///
/// Every method reports failure as an [`ActionError`]. The implementation in this
/// module (`EomActions`) returns:
/// - [`ActionError::NotAvailable`] when the corresponding action was not made
///   available;
/// - [`ActionError::InvalidParam`] for rejected arguments (such as an empty name or a
///   negative index);
/// - [`ActionError::Io`] when writing to the connection fails.
#[async_trait]
pub trait ContextActions {
    /// Adds a header with the given `name` and `value`.
    async fn add_header<'cx, 'k, 'v>(
        &'cx self,
        name: impl IntoCString + Send + 'k,
        value: impl IntoCString + Send + 'v,
    ) -> Result<(), ActionError>;

    /// Inserts a header with the given `name` and `value` at position `index`.
    async fn insert_header<'cx, 'k, 'v>(
        &'cx self,
        index: i32,
        name: impl IntoCString + Send + 'k,
        value: impl IntoCString + Send + 'v,
    ) -> Result<(), ActionError>;

    /// Changes the header `name` at `index` to `value`.
    ///
    /// `EomActions` sends a `None` value as an empty value.
    /// NOTE(review): presumably this removes the header; confirm with the protocol.
    async fn change_header<'cx, 'k, 'v>(
        &'cx self,
        name: impl IntoCString + Send + 'k,
        index: i32,
        value: Option<impl IntoCString + Send + 'v>,
    ) -> Result<(), ActionError>;

    /// Changes the sender to `mail`, with optional arguments `args`.
    async fn change_sender<'cx, 'a, 'b>(
        &'cx self,
        mail: impl IntoCString + Send + 'a,
        args: Option<impl IntoCString + Send + 'b>,
    ) -> Result<(), ActionError>;

    /// Adds the recipient `rcpt`.
    async fn add_recipient<'cx, 'a>(
        &'cx self,
        rcpt: impl IntoCString + Send + 'a,
    ) -> Result<(), ActionError>;

    /// Adds the recipient `rcpt`, with optional arguments `args`.
    async fn add_recipient_ext<'cx, 'a, 'b>(
        &'cx self,
        rcpt: impl IntoCString + Send + 'a,
        args: Option<impl IntoCString + Send + 'b>,
    ) -> Result<(), ActionError>;

    /// Deletes the recipient `rcpt`.
    async fn delete_recipient<'cx, 'a>(
        &'cx self,
        rcpt: impl IntoCString + Send + 'a,
    ) -> Result<(), ActionError>;

    /// Sends `chunk` as replacement message body content.
    ///
    /// `EomActions` splits large input into several messages and sends a single empty
    /// chunk when `chunk` is empty.
    async fn replace_body<'cx, 'a>(&'cx self, chunk: &'a [u8]) -> Result<(), ActionError>;

    /// Sends a progress notification.
    ///
    /// `EomActions` does not require any particular action to be available for this.
    async fn progress<'cx>(&'cx self) -> Result<(), ActionError>;

    /// Quarantines the message, giving `reason`.
    async fn quarantine<'cx, 'a>(
        &'cx self,
        reason: impl IntoCString + Send + 'a,
    ) -> Result<(), ActionError>;
}
/// Failure of a context action.
#[derive(Debug)]
pub enum ActionError {
    /// The requested action is not among the available actions.
    NotAvailable,
    /// An argument to the action was rejected.
    InvalidParam,
    /// An underlying I/O operation failed.
    Io(io::Error),
}

impl Display for ActionError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(inner) => write!(f, "I/O error: {inner}"),
            Self::NotAvailable => f.write_str("action has not been enabled"),
            Self::InvalidParam => f.write_str("invalid context action parameter"),
        }
    }
}

impl Error for ActionError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Only the I/O variant wraps another error.
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

impl From<io::Error> for ActionError {
    fn from(e: io::Error) -> Self {
        ActionError::Io(e)
    }
}
/// Context used during negotiation.
///
/// NOTE(review): the field descriptions below are derived from the field names and
/// types; confirm them against the negotiation code that consumes this struct.
pub struct NegotiateContext<T: Send> {
    /// Optional user data carried on the context.
    pub data: Option<T>,
    /// The SMTP error reply holder.
    pub reply: SmtpReply,
    /// The actions being requested.
    pub requested_actions: Actions,
    /// The protocol options being requested.
    pub requested_opts: ProtoOpts,
    /// Requested macros, keyed by macro stage (empty when constructed via `new`).
    pub requested_macros: HashMap<MacroStage, CString>,
}
impl<T: Send> NegotiateContext<T> {
pub(crate) fn new(
data: Option<T>,
reply: SmtpReply,
requested_actions: Actions,
requested_opts: ProtoOpts,
) -> Self {
Self {
data,
reply,
requested_actions,
requested_opts,
requested_macros: HashMap::new(),
}
}
}
/// General-purpose context holding user data, macros and an SMTP error reply.
pub struct Context<T: Send> {
    /// Optional user data carried on the context.
    pub data: Option<T>,
    /// Macros collected so far, stored by macro stage.
    pub macros: Macros,
    /// The SMTP error reply holder.
    pub reply: SmtpReply,
}
impl<T: Send> Context<T> {
pub(crate) fn new() -> Self {
Self {
data: None,
macros: Macros::new(),
reply: SmtpReply::new(),
}
}
pub(crate) fn clear_macros(&mut self) {
self.macros.clear();
}
pub(crate) fn clear_macros_after(&mut self, stage: MacroStage) {
self.macros.clear_after(stage);
}
pub(crate) fn insert_macros(&mut self, stage: MacroStage, entries: HashMap<CString, CString>) {
self.macros.insert(stage, entries);
}
pub(crate) fn restore(&mut self, data: Option<T>, reply: SmtpReply) {
self.data = data;
self.reply = reply;
}
pub(crate) fn take_reply_if<F>(&mut self, predicate: F) -> Option<Reply>
where
F: FnMut(&ReplyCode) -> bool,
{
self.reply
.take_error_reply_if(predicate)
.map(|reply| Reply::ReplyCode { reply })
}
}
/// Context that additionally exposes message-modifying `actions`.
///
/// NOTE(review): "Eom" presumably stands for end-of-message; confirm.
pub struct EomContext<T: Send> {
    /// Optional user data carried on the context.
    pub data: Option<T>,
    /// Macros collected so far, stored by macro stage.
    pub macros: Macros,
    /// The SMTP error reply holder.
    pub reply: SmtpReply,
    /// Handle for performing [`ContextActions`] on the connection.
    pub actions: EomActions,
}
impl<T: Send> EomContext<T> {
pub(crate) fn new(
conn: Connection,
data: Option<T>,
macros: Macros,
reply: SmtpReply,
available_actions: Actions,
) -> Self {
Self {
data,
macros,
reply,
actions: EomActions {
conn,
available_actions,
},
}
}
}
/// Holder for an optional SMTP error reply.
///
/// Use [`SetErrorReply::set_error_reply`] to set the reply. The type does not derive
/// `Clone`; copies within the crate go through `clone_internal`.
#[derive(Debug, Eq, Hash, PartialEq)]
pub struct SmtpReply {
    // The pending error reply, if one has been set and not yet taken.
    error_reply: Option<ErrorReply>,
}
impl SmtpReply {
pub(crate) fn new() -> Self {
Self { error_reply: None }
}
pub(crate) fn clone_internal(&self) -> Self {
Self {
error_reply: self.error_reply.clone(),
}
}
pub(crate) fn take_error_reply_if<F>(&mut self, mut predicate: F) -> Option<CString>
where
F: FnMut(&ReplyCode) -> bool,
{
if let Some(reply) = self.error_reply.as_ref() {
if predicate(&reply.rcode) {
return self.error_reply.take().map(|r| r.make_error_reply());
}
}
None
}
}
impl SetErrorReply for SmtpReply {
fn set_error_reply<I, T>(
&mut self,
rcode: &str,
xcode: Option<&str>,
message: I,
) -> Result<(), SmtpReplyError>
where
I: IntoIterator<Item = T>,
T: IntoCString,
{
let rcode = rcode
.parse()
.map_err(|_| SmtpReplyError::InvalidReplyCode)?;
let xcode = xcode
.map(str::parse)
.transpose()
.map_err(|_| SmtpReplyError::InvalidEnhancedStatusCode)?;
let mut msg_lines = Vec::new();
for (i, line) in message.into_iter().enumerate() {
if i >= 32 {
return Err(SmtpReplyError::InvalidReplyText);
}
let line = line.into_c_string();
if line.as_bytes().len() > 980 {
return Err(SmtpReplyError::InvalidReplyText);
}
if line.as_bytes().iter().any(|&b| matches!(b, b'\r' | b'\n')) {
return Err(SmtpReplyError::InvalidReplyText);
}
msg_lines.push(line);
}
self.error_reply = Some(ErrorReply {
rcode,
xcode,
message: msg_lines,
});
Ok(())
}
}
/// A validated SMTP error reply, as stored by [`SmtpReply`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ErrorReply {
    // The three-digit SMTP reply code.
    rcode: ReplyCode,
    // Optional enhanced status code. Multi-line replies fall back to 4.0.0/5.0.0
    // when this is absent.
    xcode: Option<EnhancedStatusCode>,
    // Reply text, one entry per line (possibly empty).
    message: Vec<CString>,
}
impl ErrorReply {
    /// Renders this reply as reply text.
    ///
    /// With at most one line of text, the result is a single line made of:
    /// - `"<rcode> "`;
    /// - the enhanced status code, if one is set;
    /// - if there is text, a space (only when an enhanced status code was written)
    ///   followed by the text.
    ///
    /// With two or more lines, every line starts with the reply code and an enhanced
    /// status code. When no enhanced status code is set, `4.0.0` or `5.0.0` is used
    /// according to the reply code's class. The reply code is followed by `-` on every
    /// line except the last, which uses a space. Lines are joined with `"\r\n"`.
    pub(crate) fn make_error_reply(&self) -> CString {
        let rcode: &str = self.rcode.as_ref();
        let mut out: Vec<u8> = Vec::new();

        if self.message.len() < 2 {
            out.extend_from_slice(rcode.as_bytes());
            out.push(b' ');
            if let Some(xcode) = &self.xcode {
                let xcode: &str = xcode.as_ref();
                out.extend_from_slice(xcode.as_bytes());
            }
            if let Some(text) = self.message.first() {
                if self.xcode.is_some() {
                    out.push(b' ');
                }
                out.extend_from_slice(text.as_bytes());
            }
        } else {
            let xcode: &str = match (&self.xcode, &self.rcode) {
                (Some(code), _) => code.as_ref(),
                (None, ReplyCode::Transient(_)) => "4.0.0",
                (None, ReplyCode::Permanent(_)) => "5.0.0",
            };
            let last = self.message.len() - 1;
            for (i, line) in self.message.iter().enumerate() {
                if i > 0 {
                    out.extend_from_slice(b"\r\n");
                }
                let separator = if i == last { ' ' } else { '-' };
                write!(out, "{rcode}{separator}{xcode} ").unwrap();
                out.extend_from_slice(line.as_bytes());
            }
        }

        // Codes consist of ASCII digits and dots, and every text line is already a
        // `CString`, so the rendered bytes cannot contain an interior NUL.
        CString::new(out).expect("invalid error reply text")
    }
}
/// [`ContextActions`] implementation that writes each action to a connection.
///
/// An action is only written when it is contained in `available_actions`.
pub struct EomActions {
    // Connection the action replies are written to.
    conn: Connection,
    // Actions that may be performed; other actions fail with `NotAvailable`.
    available_actions: Actions,
}
#[async_trait]
impl ContextActions for EomActions {
async fn add_header<'cx, 'k, 'v>(
&'cx self,
name: impl IntoCString + Send + 'k,
value: impl IntoCString + Send + 'v,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::ADD_HEADER) {
return Err(ActionError::NotAvailable);
}
let name = name.into_c_string();
if name.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
let value = value.into_c_string();
self.conn.write_reply(Reply::AddHeader { name, value }).await?;
Ok(())
}
async fn insert_header<'cx, 'k, 'v>(
&'cx self,
index: i32,
name: impl IntoCString + Send + 'k,
value: impl IntoCString + Send + 'v,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::ADD_HEADER) {
return Err(ActionError::NotAvailable);
}
if index < 0 {
return Err(ActionError::InvalidParam);
}
let name = name.into_c_string();
if name.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
let value = value.into_c_string();
self.conn.write_reply(Reply::InsertHeader { index, name, value }).await?;
Ok(())
}
async fn change_header<'cx, 'k, 'v>(
&'cx self,
name: impl IntoCString + Send + 'k,
index: i32,
value: Option<impl IntoCString + Send + 'v>,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::CHANGE_HEADER) {
return Err(ActionError::NotAvailable);
}
if index < 0 {
return Err(ActionError::InvalidParam);
}
let name = name.into_c_string();
if name.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
let value = value.map_or_else(Default::default, |v| v.into_c_string());
self.conn.write_reply(Reply::ChangeHeader { name, index, value }).await?;
Ok(())
}
async fn change_sender<'cx, 'a, 'b>(
&'cx self,
mail: impl IntoCString + Send + 'a,
args: Option<impl IntoCString + Send + 'b>,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::CHANGE_SENDER) {
return Err(ActionError::NotAvailable);
}
let mail = mail.into_c_string();
if mail.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
let args = args.map(|a| a.into_c_string());
self.conn.write_reply(Reply::ChangeSender { mail, args }).await?;
Ok(())
}
async fn add_recipient<'cx, 'a>(
&'cx self,
rcpt: impl IntoCString + Send + 'a,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::ADD_RCPT) {
return Err(ActionError::NotAvailable);
}
let rcpt = rcpt.into_c_string();
if rcpt.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
self.conn.write_reply(Reply::AddRcpt { rcpt }).await?;
Ok(())
}
async fn add_recipient_ext<'cx, 'a, 'b>(
&'cx self,
rcpt: impl IntoCString + Send + 'a,
args: Option<impl IntoCString + Send + 'b>,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::ADD_RCPT_EXT) {
return Err(ActionError::NotAvailable);
}
let rcpt = rcpt.into_c_string();
if rcpt.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
let args = args.map(|a| a.into_c_string());
self.conn.write_reply(Reply::AddRcptExt { rcpt, args }).await?;
Ok(())
}
async fn delete_recipient<'cx, 'a>(
&'cx self,
rcpt: impl IntoCString + Send + 'a,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::DELETE_RCPT) {
return Err(ActionError::NotAvailable);
}
let rcpt = rcpt.into_c_string();
if rcpt.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
self.conn.write_reply(Reply::DeleteRcpt { rcpt }).await?;
Ok(())
}
async fn replace_body<'cx, 'a>(&'cx self, chunk: &'a [u8]) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::REPLACE_BODY) {
return Err(ActionError::NotAvailable);
}
const CHUNK_SIZE: usize = 65535;
if chunk.is_empty() {
let chunk = Bytes::new();
self.conn.write_reply(Reply::ReplaceBody { chunk }).await?;
} else {
for chunk in chunk.chunks(CHUNK_SIZE) {
let chunk = Bytes::copy_from_slice(chunk);
self.conn.write_reply(Reply::ReplaceBody { chunk }).await?;
}
}
Ok(())
}
async fn progress<'cx>(&'cx self) -> Result<(), ActionError> {
self.conn.write_reply(Reply::Progress).await?;
Ok(())
}
async fn quarantine<'cx, 'a>(
&'cx self,
reason: impl IntoCString + Send + 'a,
) -> Result<(), ActionError> {
if !self.available_actions.contains(Actions::QUARANTINE) {
return Err(ActionError::NotAvailable);
}
let reason = reason.into_c_string();
if reason.as_bytes().is_empty() {
return Err(ActionError::InvalidParam);
}
self.conn.write_reply(Reply::Quarantine { reason }).await?;
Ok(())
}
}
/// Error returned when a [`ReplyCode`] or [`EnhancedStatusCode`] cannot be parsed.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct ParseStatusCodeError;

impl Error for ParseStatusCodeError {}

impl Display for ParseStatusCodeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("failed to parse status code")
    }
}

/// A three-digit SMTP reply code.
///
/// The first digit is `4` or `5`, the second is `0`-`5` and the third is `0`-`9`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ReplyCode {
    /// A reply code starting with `4`.
    Transient(String),
    /// A reply code starting with `5`.
    Permanent(String),
}

impl AsRef<str> for ReplyCode {
    fn as_ref(&self) -> &str {
        let (Self::Transient(code) | Self::Permanent(code)) = self;
        code
    }
}

impl FromStr for ReplyCode {
    type Err = ParseStatusCodeError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Checks the second and third digits of the code.
        fn valid_tail(second: u8, third: u8) -> bool {
            matches!(second, b'0'..=b'5') && third.is_ascii_digit()
        }

        match *s.as_bytes() {
            [b'4', second, third] if valid_tail(second, third) => Ok(Self::Transient(s.into())),
            [b'5', second, third] if valid_tail(second, third) => Ok(Self::Permanent(s.into())),
            _ => Err(ParseStatusCodeError),
        }
    }
}

/// An enhanced status code of the form `class.subject.detail`, with class `4` or `5`.
///
/// Example: `5.7.1`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum EnhancedStatusCode {
    /// An enhanced status code of class `4`.
    Transient(String),
    /// An enhanced status code of class `5`.
    Permanent(String),
}

impl AsRef<str> for EnhancedStatusCode {
    fn as_ref(&self) -> &str {
        let (Self::Transient(code) | Self::Permanent(code)) = self;
        code
    }
}

impl FromStr for EnhancedStatusCode {
    type Err = ParseStatusCodeError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // A subject or detail part is either exactly "0", or one to three ASCII digits
        // without a leading zero.
        fn valid_part(part: &str) -> bool {
            match part.as_bytes() {
                [b'0'] => true,
                [first, rest @ ..] => {
                    rest.len() <= 2
                        && (b'1'..=b'9').contains(first)
                        && rest.iter().all(u8::is_ascii_digit)
                }
                [] => false,
            }
        }

        // At most three fields; any extra dots stay in `detail` and make it invalid.
        let mut fields = s.splitn(3, '.');
        match (fields.next(), fields.next(), fields.next()) {
            (Some("4"), Some(subject), Some(detail))
                if valid_part(subject) && valid_part(detail) =>
            {
                Ok(Self::Transient(s.into()))
            }
            (Some("5"), Some(subject), Some(detail))
                if valid_part(subject) && valid_part(detail) =>
            {
                Ok(Self::Permanent(s.into()))
            }
            _ => Err(ParseStatusCodeError),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use byte_strings::c_str;
    use std::ffi::CStr;

    /// Builds an `ErrorReply` from string codes and text lines, then renders it.
    fn render(rcode: &str, xcode: Option<&str>, message: &[&CStr]) -> CString {
        let reply = ErrorReply {
            rcode: rcode.parse().unwrap(),
            xcode: xcode.map(|code| code.parse().unwrap()),
            message: message.iter().map(|&line| line.to_owned()).collect(),
        };
        reply.make_error_reply()
    }

    #[test]
    fn error_reply_ok() {
        assert_eq!(render("550", None, &[]).as_ref(), c_str!("550 "));
        assert_eq!(render("550", Some("5.0.0"), &[]).as_ref(), c_str!("550 5.0.0"));
        assert_eq!(
            render("550", None, &[c_str!("failure")]).as_ref(),
            c_str!("550 failure")
        );
        assert_eq!(
            render("550", Some("5.0.0"), &[c_str!("failure")]).as_ref(),
            c_str!("550 5.0.0 failure")
        );
    }

    #[test]
    fn error_reply_multi_ok() {
        let lines = [c_str!("complete"), c_str!("failure")];
        assert_eq!(
            render("400", None, &lines).as_ref(),
            c_str!("400-4.0.0 complete\r\n400 4.0.0 failure")
        );
        assert_eq!(
            render("411", Some("4.1.1"), &lines).as_ref(),
            c_str!("411-4.1.1 complete\r\n411 4.1.1 failure")
        );
    }
}