use crate::{
macros::MacroStage,
message::{Byte, Message, TryFromByteError, Version},
proto_util::{Actions, ProtoOpts, SocketInfo},
session::State,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{
error::Error,
ffi::CString,
fmt::{self, Display, Formatter},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
str::FromStr,
};
/// The kind of a command message, identified on the wire by a single byte.
///
/// The byte codes are defined by the `From<CommandKind> for u8` and
/// `TryFrom<u8> for CommandKind` conversions in this module.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CommandKind {
    /// Wire code `b'A'`.
    Abort,
    /// Wire code `b'B'`.
    BodyChunk,
    /// Wire code `b'C'`.
    ConnInfo,
    /// Wire code `b'D'`.
    DefMacros,
    /// Wire code `b'E'`.
    BodyEnd,
    /// Wire code `b'H'`.
    Helo,
    /// Wire code `b'K'`.
    QuitNc,
    /// Wire code `b'L'`.
    Header,
    /// Wire code `b'M'`.
    Mail,
    /// Wire code `b'N'`.
    Eoh,
    /// Wire code `b'O'`.
    OptNeg,
    /// Wire code `b'Q'`.
    Quit,
    /// Wire code `b'R'`.
    Rcpt,
    /// Wire code `b'T'`.
    Data,
    /// Wire code `b'U'`.
    Unknown,
}
impl CommandKind {
    /// Returns the session [`State`] associated with this command kind.
    ///
    /// `DefMacros` has no associated state and yields `None`; every other
    /// kind maps to exactly one state.
    pub(crate) fn as_state(&self) -> Option<State> {
        let state = match *self {
            Self::DefMacros => return None,
            Self::Abort => State::Abort,
            Self::BodyChunk => State::Body,
            Self::ConnInfo => State::Conn,
            Self::BodyEnd => State::Eom,
            Self::Helo => State::Helo,
            Self::QuitNc => State::QuitNc,
            Self::Header => State::Header,
            Self::Mail => State::Mail,
            Self::Eoh => State::Eoh,
            Self::OptNeg => State::Opts,
            Self::Quit => State::Quit,
            Self::Rcpt => State::Rcpt,
            Self::Data => State::Data,
            Self::Unknown => State::Unknown,
        };
        Some(state)
    }
}
impl From<CommandKind> for u8 {
fn from(kind: CommandKind) -> Self {
match kind {
CommandKind::Abort => b'A',
CommandKind::BodyChunk => b'B',
CommandKind::ConnInfo => b'C',
CommandKind::DefMacros => b'D',
CommandKind::BodyEnd => b'E',
CommandKind::Helo => b'H',
CommandKind::QuitNc => b'K',
CommandKind::Header => b'L',
CommandKind::Mail => b'M',
CommandKind::Eoh => b'N',
CommandKind::OptNeg => b'O',
CommandKind::Quit => b'Q',
CommandKind::Rcpt => b'R',
CommandKind::Data => b'T',
CommandKind::Unknown => b'U',
}
}
}
impl TryFrom<u8> for CommandKind {
    type Error = TryFromByteError;

    /// Decodes a single-byte wire code; any byte that names no command is
    /// returned inside a `TryFromByteError`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let kind = match value {
            b'A' => Self::Abort,
            b'B' => Self::BodyChunk,
            b'C' => Self::ConnInfo,
            b'D' => Self::DefMacros,
            b'E' => Self::BodyEnd,
            b'H' => Self::Helo,
            b'K' => Self::QuitNc,
            b'L' => Self::Header,
            b'M' => Self::Mail,
            b'N' => Self::Eoh,
            b'O' => Self::OptNeg,
            b'Q' => Self::Quit,
            b'R' => Self::Rcpt,
            b'T' => Self::Data,
            b'U' => Self::Unknown,
            other => return Err(TryFromByteError(other)),
        };
        Ok(kind)
    }
}
/// A message whose kind byte has been validated as a [`CommandKind`]; the
/// payload buffer has not been parsed yet.
pub(crate) struct CommandMessage {
    /// The decoded command kind.
    pub kind: CommandKind,
    /// The raw, still unparsed payload.
    pub buffer: Bytes,
}
impl TryFrom<Message> for CommandMessage {
type Error = TryFromByteError;
fn try_from(msg: Message) -> Result<Self, Self::Error> {
let kind = msg.kind.try_into()?;
Ok(Self {
kind,
buffer: msg.buffer,
})
}
}
/// Errors produced while parsing a [`Message`] into a [`Command`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ParseCommandError {
    /// The message kind byte does not name a known command; carries that byte.
    UnknownCommand(u8),
    /// The protocol family byte of a connect payload is not recognized;
    /// carries that byte.
    UnknownFamily(u8),
    /// The address text of a connect payload is not valid UTF-8 or does not
    /// parse as an IP address of the announced family.
    InvalidSocketAddr,
    /// The stage byte of a macro definition payload is not recognized;
    /// carries that byte.
    UnknownStage(u8),
    /// An option negotiation payload is shorter than the required 12 bytes.
    NoOptNegPayload,
    /// A header name was empty.
    EmptyCString,
    /// A buffer that must end with a NUL byte does not.
    NotNulTerminated,
    /// A `u8` was expected but the buffer was exhausted.
    NoU8Found,
    /// A `u16` was expected but fewer than two bytes remained.
    NoU16Found,
    /// A NUL-terminated string was expected but could not be extracted.
    NoCStringFound,
}
impl Display for ParseCommandError {
    /// Writes a short, lowercase description of the error. Variants that
    /// carry the offending byte include it, rendered via `Byte`'s `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let fixed = match *self {
            Self::UnknownCommand(byte) => {
                return write!(f, "unknown command: {:?}", Byte(byte));
            }
            Self::UnknownFamily(byte) => {
                return write!(f, "unknown protocol family: {:?}", Byte(byte));
            }
            Self::UnknownStage(byte) => {
                return write!(f, "unknown macro stage: {:?}", Byte(byte));
            }
            Self::InvalidSocketAddr => "invalid socket address",
            Self::NoOptNegPayload => "no option negotiation payload found",
            Self::EmptyCString => "empty string",
            Self::NotNulTerminated => "not nul terminated",
            Self::NoU8Found => "no u8 found",
            Self::NoU16Found => "no u16 found",
            Self::NoCStringFound => "no C string found",
        };
        f.write_str(fixed)
    }
}

impl Error for ParseCommandError {}
/// A fully parsed command, one variant per [`CommandKind`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Command {
    /// `b'A'`; carries no payload.
    Abort,
    /// `b'B'`; the payload bytes are kept unparsed.
    BodyChunk(Bytes),
    /// `b'C'`; hostname and socket information.
    ConnInfo(ConnInfoPayload),
    /// `b'D'`; a macro stage plus NUL-terminated strings.
    DefMacros(MacroPayload),
    /// `b'E'`; the payload bytes are kept unparsed.
    BodyEnd(Bytes),
    /// `b'H'`; a single hostname string.
    Helo(HeloPayload),
    /// `b'K'`; carries no payload.
    QuitNc,
    /// `b'L'`; a header name and value.
    Header(HeaderPayload),
    /// `b'M'`; one or more NUL-terminated argument strings.
    Mail(EnvAddrPayload),
    /// `b'N'`; carries no payload.
    Eoh,
    /// `b'O'`; version, action flags and protocol option flags.
    OptNeg(OptNegPayload),
    /// `b'Q'`; carries no payload.
    Quit,
    /// `b'R'`; one or more NUL-terminated argument strings.
    Rcpt(EnvAddrPayload),
    /// `b'T'`; carries no payload.
    Data,
    /// `b'U'`; a single string argument.
    Unknown(UnknownPayload),
}
impl Command {
pub fn parse_command(msg: Message) -> Result<Self, ParseCommandError> {
let msg = CommandMessage::try_from(msg)
.map_err(|e| ParseCommandError::UnknownCommand(e.byte()))?;
Ok(match msg.kind {
CommandKind::Abort => Self::Abort,
CommandKind::BodyChunk => Self::BodyChunk(msg.buffer),
CommandKind::ConnInfo => Self::ConnInfo(ConnInfoPayload::parse_buffer(msg.buffer)?),
CommandKind::DefMacros => Self::DefMacros(MacroPayload::parse_buffer(msg.buffer)?),
CommandKind::BodyEnd => Self::BodyEnd(msg.buffer),
CommandKind::Helo => Self::Helo(HeloPayload::parse_buffer(msg.buffer)?),
CommandKind::QuitNc => Self::QuitNc,
CommandKind::Header => Self::Header(HeaderPayload::parse_buffer(msg.buffer)?),
CommandKind::Mail => Self::Mail(EnvAddrPayload::parse_buffer(msg.buffer)?),
CommandKind::Eoh => Self::Eoh,
CommandKind::OptNeg => Self::OptNeg(OptNegPayload::parse_buffer(msg.buffer)?),
CommandKind::Quit => Self::Quit,
CommandKind::Rcpt => Self::Rcpt(EnvAddrPayload::parse_buffer(msg.buffer)?),
CommandKind::Data => Self::Data,
CommandKind::Unknown => Self::Unknown(UnknownPayload::parse_buffer(msg.buffer)?),
})
}
pub fn into_message(self) -> Message {
match self {
Self::Abort => Message::new(CommandKind::Abort, Bytes::new()),
Self::BodyChunk(chunk) => Message::new(CommandKind::BodyChunk, chunk),
Self::ConnInfo(ConnInfoPayload { hostname, socket_info }) => {
let mut buf = BytesMut::with_capacity(64);
buf.put(hostname.to_bytes_with_nul());
match socket_info {
SocketInfo::Unknown => buf.put_u8(b'U'),
SocketInfo::Inet(addr) => {
buf.put_u8(match addr {
SocketAddr::V4(_) => b'4',
SocketAddr::V6(_) => b'6',
});
buf.put_u16(addr.port());
let ip = CString::new(addr.ip().to_string()).unwrap();
buf.put(ip.to_bytes_with_nul());
}
SocketInfo::Unix(path) => {
buf.put_u8(b'L');
buf.put_u16(0);
buf.put(path.to_bytes_with_nul());
}
}
Message::new(CommandKind::ConnInfo, buf)
}
Self::DefMacros(MacroPayload { stage, macros }) => {
let mut buf = BytesMut::new();
buf.put_u8(stage.into());
for m in macros {
buf.put(m.to_bytes_with_nul());
}
Message::new(CommandKind::DefMacros, buf)
}
Self::BodyEnd(chunk) => Message::new(CommandKind::BodyEnd, chunk),
Self::Helo(HeloPayload { hostname }) => {
let hostname = hostname.to_bytes_with_nul();
Message::new(CommandKind::Helo, Bytes::copy_from_slice(hostname))
}
Self::QuitNc => Message::new(CommandKind::QuitNc, Bytes::new()),
Self::Header(HeaderPayload { name, value }) => {
let name = name.to_bytes_with_nul();
let value = value.to_bytes_with_nul();
let mut buf = BytesMut::with_capacity(name.len() + value.len());
buf.put(name);
buf.put(value);
Message::new(CommandKind::Header, buf)
}
Self::Mail(EnvAddrPayload { args }) => {
let mut buf = BytesMut::new();
for arg in args {
buf.put(arg.to_bytes_with_nul());
}
Message::new(CommandKind::Mail, buf)
}
Self::Eoh => Message::new(CommandKind::Eoh, Bytes::new()),
Self::OptNeg(OptNegPayload { version, actions, opts }) => {
let mut buf = BytesMut::with_capacity(12);
buf.put_u32(version);
buf.put_u32(actions.bits());
buf.put_u32(opts.bits());
Message::new(CommandKind::OptNeg, buf)
}
Self::Quit => Message::new(CommandKind::Quit, Bytes::new()),
Self::Rcpt(EnvAddrPayload { args }) => {
let mut buf = BytesMut::new();
for arg in args {
buf.put(arg.to_bytes_with_nul());
}
Message::new(CommandKind::Rcpt, buf)
}
Self::Data => Message::new(CommandKind::Data, Bytes::new()),
Self::Unknown(UnknownPayload { arg }) => {
let arg = arg.to_bytes_with_nul();
Message::new(CommandKind::Unknown, Bytes::copy_from_slice(arg))
}
}
}
}
/// Protocol family byte of a connect (`C`) payload.
enum Family {
    /// Wire byte `b'U'`; no socket data follows.
    Unknown,
    /// Wire byte `b'4'`; a port and IPv4 address text follow.
    Ipv4,
    /// Wire byte `b'6'`; a port and IPv6 address text follow.
    Ipv6,
    /// Wire byte `b'L'`; an unused port field and a socket path follow.
    Unix,
}
impl TryFrom<u8> for Family {
    type Error = TryFromByteError;

    /// Decodes the family byte of a connect payload. Unrecognized bytes are
    /// returned inside a `TryFromByteError`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let family = match value {
            b'4' => Self::Ipv4,
            b'6' => Self::Ipv6,
            b'L' => Self::Unix,
            b'U' => Self::Unknown,
            other => return Err(TryFromByteError(other)),
        };
        Ok(family)
    }
}
/// Payload of a connect (`C`) command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ConnInfoPayload {
    /// The NUL-terminated hostname at the start of the payload.
    pub hostname: CString,
    /// Socket information decoded from the family byte and the data after it.
    pub socket_info: SocketInfo,
}
impl ConnInfoPayload {
    /// Parses a connect (`C`) payload.
    ///
    /// The payload is a NUL-terminated hostname, then a family byte, then
    /// family-specific socket data:
    /// - `U`: nothing further is read.
    /// - `4` / `6`: a `u16` port and NUL-terminated address text.
    /// - `L`: a `u16` field (read and discarded) and a NUL-terminated path.
    ///
    /// # Errors
    ///
    /// - `NoCStringFound` if the hostname or path is missing.
    /// - `NoU8Found` if the family byte is missing.
    /// - `UnknownFamily` if the family byte is unrecognized.
    /// - Otherwise, any error from the port/address decoding.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        let hostname = get_c_string(&mut buf)?;

        let family_byte = get_u8(&mut buf)?;
        let socket_info = match Family::try_from(family_byte) {
            Err(e) => return Err(ParseCommandError::UnknownFamily(e.byte())),
            Ok(Family::Unknown) => SocketInfo::Unknown,
            Ok(Family::Ipv4) => SocketInfo::Inet(parse_socket_addr::<Ipv4Addr>(buf)?),
            Ok(Family::Ipv6) => SocketInfo::Inet(parse_socket_addr::<Ipv6Addr>(buf)?),
            Ok(Family::Unix) => {
                // The port field is present for Unix sockets too; its value is discarded.
                get_u16(&mut buf)?;
                ensure_nul_terminated(&buf)?;
                SocketInfo::Unix(get_c_string(&mut buf)?)
            }
        };

        Ok(Self { hostname, socket_info })
    }
}
/// Parses the port and address text that follow the family byte of a connect
/// payload into a [`SocketAddr`].
///
/// The expected layout is a `u16` port (read with `Buf::get_u16`, big-endian)
/// followed by NUL-terminated address text. The text is parsed as `T`
/// (`Ipv4Addr` or `Ipv6Addr`).
///
/// Sendmail tags IPv6 address text with an `IPv6:` prefix, and libmilter
/// strips that tag before calling `inet_pton`. A leading `IPv6:` tag, matched
/// case-insensitively, is therefore removed before parsing. Untagged text is
/// parsed as before.
///
/// # Errors
///
/// - `NoU16Found` if the port is missing.
/// - `NotNulTerminated` if the remaining buffer does not end with a NUL byte.
/// - `NoCStringFound` if no string can be extracted.
/// - `InvalidSocketAddr` if the text is not valid UTF-8 or does not parse as `T`.
fn parse_socket_addr<T>(mut buf: Bytes) -> Result<SocketAddr, ParseCommandError>
where
    T: FromStr + Into<IpAddr>,
{
    const IPV6_TAG: &str = "IPv6:";

    let port = get_u16(&mut buf)?;
    ensure_nul_terminated(&buf)?;
    let addr = get_c_string(&mut buf)?
        .into_string()
        .map_err(|_| ParseCommandError::InvalidSocketAddr)?;

    // `str::get` returns `None` for short strings and for non-char-boundary
    // indices, so this never panics.
    let text = match addr.get(..IPV6_TAG.len()) {
        Some(tag) if tag.eq_ignore_ascii_case(IPV6_TAG) => &addr[IPV6_TAG.len()..],
        _ => addr.as_str(),
    };

    let ip = text
        .parse::<T>()
        .map_err(|_| ParseCommandError::InvalidSocketAddr)?;
    Ok(SocketAddr::from((ip, port)))
}
/// Payload of a macro definition (`D`) command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct MacroPayload {
    /// The stage decoded from the leading byte.
    pub stage: MacroStage,
    /// The NUL-terminated strings after the stage byte, in order. At least one
    /// is present when parsed.
    // NOTE(review): presumably alternating macro names and values; the pairing
    // is not validated here — confirm against callers.
    pub macros: Vec<CString>,
}
impl MacroPayload {
    /// Parses a macro definition (`D`) payload: a stage byte followed by one
    /// or more NUL-terminated strings.
    ///
    /// After the first (mandatory) string, strings are collected until the
    /// next extraction fails.
    ///
    /// # Errors
    ///
    /// - `NoU8Found` if the stage byte is missing.
    /// - `UnknownStage` if the stage byte is unrecognized.
    /// - `NoCStringFound` if not even one string can be extracted.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        let stage = MacroStage::try_from(get_u8(&mut buf)?)
            .map_err(|e| ParseCommandError::UnknownStage(e.byte()))?;

        let first = get_c_string(&mut buf)?;
        let rest = std::iter::from_fn(|| get_c_string(&mut buf).ok());
        let macros = std::iter::once(first).chain(rest).collect();

        Ok(Self { stage, macros })
    }
}
/// Payload of a HELO (`H`) command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HeloPayload {
    /// The first NUL-terminated string of the payload.
    pub hostname: CString,
}
impl HeloPayload {
    /// Parses a HELO (`H`) payload.
    ///
    /// The whole buffer must end with a NUL byte. Only the first string is
    /// kept; anything after it is ignored.
    ///
    /// # Errors
    ///
    /// - `NotNulTerminated` if the buffer does not end with NUL.
    /// - `NoCStringFound` if no string can be extracted.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        ensure_nul_terminated(&buf)?;
        get_c_string(&mut buf).map(|hostname| Self { hostname })
    }
}
/// Payload of a header (`L`) command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HeaderPayload {
    /// The header name; non-empty when parsed.
    pub name: CString,
    /// The header value; may be empty.
    pub value: CString,
}
impl HeaderPayload {
    /// Parses a header (`L`) payload: a non-empty NUL-terminated name followed
    /// by a NUL-terminated value.
    ///
    /// # Errors
    ///
    /// - `NotNulTerminated` if the buffer does not end with NUL.
    /// - `EmptyCString` if the name is empty; this is checked before the value
    ///   is read.
    /// - `NoCStringFound` if either string is missing.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        ensure_nul_terminated(&buf)?;

        let name = Some(get_c_string(&mut buf)?)
            .filter(|candidate| !candidate.as_bytes().is_empty())
            .ok_or(ParseCommandError::EmptyCString)?;
        let value = get_c_string(&mut buf)?;

        Ok(Self { name, value })
    }
}
/// Payload shared by the envelope sender (`M`) and recipient (`R`) commands.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EnvAddrPayload {
    /// The NUL-terminated strings of the payload, in order. At least one is
    /// present when parsed.
    // NOTE(review): presumably the address followed by its ESMTP arguments —
    // confirm against callers.
    pub args: Vec<CString>,
}
impl EnvAddrPayload {
    /// Parses an envelope address payload (`M` or `R`).
    ///
    /// One mandatory NUL-terminated string is read, and further strings are
    /// collected until the next extraction fails.
    ///
    /// # Errors
    ///
    /// Returns `NoCStringFound` if not even one string can be extracted.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        let mut args = Vec::new();
        args.push(get_c_string(&mut buf)?);
        args.extend(std::iter::from_fn(|| get_c_string(&mut buf).ok()));
        Ok(Self { args })
    }
}
/// Payload of an option negotiation (`O`) command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct OptNegPayload {
    /// The first 32-bit word of the payload.
    pub version: Version,
    /// The second 32-bit word; unknown bits are dropped when parsed.
    pub actions: Actions,
    /// The third 32-bit word; unknown bits are dropped when parsed.
    pub opts: ProtoOpts,
}
impl OptNegPayload {
    /// Parses an option negotiation (`O`) payload.
    ///
    /// The payload is three `u32` words, read with `Buf::get_u32`
    /// (big-endian): version, action flags and protocol option flags.
    /// - Flag bits that `Actions` / `ProtoOpts` do not define are dropped
    ///   (`from_bits_truncate`).
    /// - Bytes after the first twelve are ignored.
    ///
    /// # Errors
    ///
    /// Returns `NoOptNegPayload` if fewer than twelve bytes are available.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        const PAYLOAD_LEN: usize = 3 * std::mem::size_of::<u32>();

        if buf.remaining() >= PAYLOAD_LEN {
            // Struct fields are evaluated in source order, so the three words
            // are consumed as version, actions, opts.
            Ok(Self {
                version: buf.get_u32(),
                actions: Actions::from_bits_truncate(buf.get_u32()),
                opts: ProtoOpts::from_bits_truncate(buf.get_u32()),
            })
        } else {
            Err(ParseCommandError::NoOptNegPayload)
        }
    }
}
/// Payload of an unknown-command (`U`) command.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UnknownPayload {
    /// The first NUL-terminated string of the payload.
    pub arg: CString,
}
impl UnknownPayload {
    /// Parses an unknown-command (`U`) payload by taking its first
    /// NUL-terminated string.
    ///
    /// # Errors
    ///
    /// Returns `NoCStringFound` if no string can be extracted.
    pub fn parse_buffer(mut buf: Bytes) -> Result<Self, ParseCommandError> {
        get_c_string(&mut buf).map(|arg| Self { arg })
    }
}
/// Checks that `bytes` ends with a NUL byte.
///
/// # Errors
///
/// Returns `NotNulTerminated` for an empty slice or one whose last byte is
/// not `0`.
fn ensure_nul_terminated(bytes: &[u8]) -> Result<(), ParseCommandError> {
    match bytes.last() {
        Some(&0) => Ok(()),
        _ => Err(ParseCommandError::NotNulTerminated),
    }
}
/// Reads one byte from the front of `buf`, advancing it.
///
/// # Errors
///
/// Returns `NoU8Found` if `buf` is empty; `buf` is then left untouched.
fn get_u8(buf: &mut Bytes) -> Result<u8, ParseCommandError> {
    if buf.has_remaining() {
        Ok(buf.get_u8())
    } else {
        Err(ParseCommandError::NoU8Found)
    }
}
/// Reads a `u16` from the front of `buf` with `Buf::get_u16` (big-endian),
/// advancing it by two bytes.
///
/// # Errors
///
/// Returns `NoU16Found` if fewer than two bytes remain; `buf` is then left
/// untouched.
fn get_u16(buf: &mut Bytes) -> Result<u16, ParseCommandError> {
    const WIDTH: usize = std::mem::size_of::<u16>();

    if buf.remaining() >= WIDTH {
        Ok(buf.get_u16())
    } else {
        Err(ParseCommandError::NoU16Found)
    }
}
/// Extracts the next NUL-terminated string from `buf` using the parent
/// module's `get_c_string`.
///
/// # Errors
///
/// Any failure from the parent helper is reported as `NoCStringFound`.
fn get_c_string(buf: &mut Bytes) -> Result<CString, ParseCommandError> {
    super::get_c_string(buf)
        .ok()
        .ok_or(ParseCommandError::NoCStringFound)
}
#[cfg(test)]
mod tests {
    use super::*;
    use byte_strings::c_str;

    #[test]
    fn parse_command_ok() {
        let msg = Message::new(b'L', Bytes::from_static(b"name\0value\0"));
        assert_eq!(
            Command::parse_command(msg),
            Ok(Command::Header(HeaderPayload {
                name: c_str!("name").into(),
                value: c_str!("value").into(),
            }))
        );
    }

    #[test]
    fn parse_command_unknown_kind() {
        let msg = Message::new(b'Z', Bytes::new());
        assert_eq!(
            Command::parse_command(msg),
            Err(ParseCommandError::UnknownCommand(b'Z'))
        );
    }

    #[test]
    fn header_payload() {
        assert_eq!(
            HeaderPayload::parse_buffer(Bytes::from_static(b"name\0value\0")),
            Ok(HeaderPayload {
                name: c_str!("name").into(),
                value: c_str!("value").into(),
            })
        );
        assert!(HeaderPayload::parse_buffer(Bytes::new()).is_err());
        assert!(HeaderPayload::parse_buffer(Bytes::from_static(b"name")).is_err());
        assert_eq!(
            HeaderPayload::parse_buffer(Bytes::from_static(b"\0value\0")),
            Err(ParseCommandError::EmptyCString)
        );
    }

    #[test]
    fn helo_payload() {
        assert_eq!(
            HeloPayload::parse_buffer(Bytes::from_static(b"hello\0")),
            Ok(HeloPayload {
                hostname: c_str!("hello").into()
            })
        );
        assert!(HeloPayload::parse_buffer(Bytes::new()).is_err());
        assert!(HeloPayload::parse_buffer(Bytes::from_static(b"hello")).is_err());
        assert!(HeloPayload::parse_buffer(Bytes::from_static(b"hello\0excess")).is_err());
        assert_eq!(
            HeloPayload::parse_buffer(Bytes::from_static(b"hello\0excess\0")),
            Ok(HeloPayload {
                hostname: c_str!("hello").into()
            })
        );
    }

    #[test]
    fn conn_info_payload() {
        assert_eq!(
            ConnInfoPayload::parse_buffer(Bytes::from_static(b"host\0U")),
            Ok(ConnInfoPayload {
                hostname: c_str!("host").into(),
                socket_info: SocketInfo::Unknown,
            })
        );
        // family '4', port 25 (0x0019), then the address text
        assert_eq!(
            ConnInfoPayload::parse_buffer(Bytes::from_static(b"host\x004\x00\x19127.0.0.1\x00")),
            Ok(ConnInfoPayload {
                hostname: c_str!("host").into(),
                socket_info: SocketInfo::Inet(SocketAddr::from((Ipv4Addr::LOCALHOST, 25))),
            })
        );
        assert_eq!(
            ConnInfoPayload::parse_buffer(Bytes::from_static(b"host\0X")),
            Err(ParseCommandError::UnknownFamily(b'X'))
        );
        assert!(ConnInfoPayload::parse_buffer(Bytes::from_static(b"host\0")).is_err());
    }

    #[test]
    fn opt_neg_payload_too_short() {
        assert_eq!(
            OptNegPayload::parse_buffer(Bytes::from_static(&[0; 11])),
            Err(ParseCommandError::NoOptNegPayload)
        );
    }

    #[test]
    fn into_message_round_trip() {
        let commands = vec![
            Command::Abort,
            Command::BodyChunk(Bytes::from_static(b"body")),
            Command::ConnInfo(ConnInfoPayload {
                hostname: c_str!("mail.example.org").into(),
                socket_info: SocketInfo::Inet(SocketAddr::from((Ipv4Addr::LOCALHOST, 2525))),
            }),
            Command::ConnInfo(ConnInfoPayload {
                hostname: c_str!("mail.example.org").into(),
                socket_info: SocketInfo::Inet(SocketAddr::from((Ipv6Addr::LOCALHOST, 2525))),
            }),
            Command::ConnInfo(ConnInfoPayload {
                hostname: c_str!("localhost").into(),
                socket_info: SocketInfo::Unix(c_str!("/run/milter.sock").into()),
            }),
            Command::Helo(HeloPayload {
                hostname: c_str!("client.example.org").into(),
            }),
            Command::Header(HeaderPayload {
                name: c_str!("Subject").into(),
                value: c_str!("test").into(),
            }),
            Command::Rcpt(EnvAddrPayload {
                args: vec![c_str!("<rcpt@example.org>").into()],
            }),
            Command::OptNeg(OptNegPayload {
                version: 6,
                actions: Actions::from_bits_truncate(0),
                opts: ProtoOpts::from_bits_truncate(0),
            }),
            Command::Quit,
        ];
        for cmd in commands {
            assert_eq!(Command::parse_command(cmd.clone().into_message()), Ok(cmd));
        }
    }
}