use crate::smtp::authentication::Authentication;
use crate::smtp::commands::*;
use crate::smtp::extension::{ClientId, ServerInfo};
use crate::smtp::response::parse_response;
use crate::smtp::response::Response;
use async_std::io::prelude::{ReadExt, WriteExt};
use bytes::{Buf, BufMut, BytesMut};
use samotop_core::common::*;
use std::fmt::Display;
use std::pin::Pin;
use std::time::Duration;
use crate::smtp::error::{Error, SmtpResult};
use std::result::Result;
/// Client side of an SMTP/LMTP conversation driven over a pinned, borrowed stream.
///
/// Keeps the bytes received from the server that have not yet been parsed into a
/// complete response, so that data read past one response is kept for the next.
#[derive(Debug)]
pub struct SmtpProto<'s, S> {
/// The underlying connection, pinned and borrowed for `'s`.
stream: Pin<&'s mut S>,
/// Bytes read from `stream` but not yet consumed by response parsing.
buffer: BytesMut,
/// Maximum number of unparsed bytes allowed to accumulate while waiting for a
/// complete response; reaching it fails the read with "Line limit reached".
line_limit: usize,
}
impl<'s, S> SmtpProto<'s, S> {
    /// Initial value of `line_limit` for a freshly created driver, in bytes.
    const DEFAULT_LINE_LIMIT: usize = 8000;

    /// Creates a protocol driver over `stream`.
    ///
    /// The driver starts with an empty input buffer and a line limit of 8000 bytes.
    pub fn new(stream: Pin<&'s mut S>) -> Self {
        let buffer = BytesMut::new();
        Self {
            line_limit: Self::DEFAULT_LINE_LIMIT,
            buffer,
            stream,
        }
    }

    /// Returns the bytes received from the stream that no response has consumed yet.
    pub fn buffer(&self) -> &[u8] {
        &self.buffer
    }

    /// Reborrows the underlying stream mutably, keeping it pinned.
    pub fn stream_mut(&mut self) -> Pin<&mut S> {
        Pin::as_mut(&mut self.stream)
    }

    /// Reborrows the underlying stream immutably, keeping it pinned.
    pub fn stream(&self) -> Pin<&S> {
        Pin::as_ref(&self.stream)
    }
}
impl<'s, S: io::Read + io::Write> SmtpProto<'s, S> {
    /// Reads the server's opening banner and checks it against code 220.
    pub async fn read_banner(&mut self, timeout: Duration) -> SmtpResult {
        let banner_response = self.read_response(timeout).await?;
        banner_response.is([220].as_ref())
    }

    /// Reads the reply that follows transmitted message data and checks it against code 250.
    pub async fn read_data_sent_response(&mut self, timeout: Duration) -> SmtpResult {
        let data_response = self.read_response(timeout).await?;
        data_response.is([250].as_ref())
    }

    /// Greets the server with EHLO, retrying with HELO when EHLO fails with
    /// `Error::Permanent`.
    ///
    /// Returns the greeting response together with the `ServerInfo` parsed from it.
    pub async fn execute_ehlo_or_helo(
        &mut self,
        me: ClientId,
        timeout: Duration,
    ) -> Result<(Response, ServerInfo), Error> {
        match self.execute_ehlo(me.clone(), timeout).await {
            Err(Error::Permanent(_resp)) => self.execute_helo(me, timeout).await,
            otherwise => otherwise,
        }
    }

    /// Sends EHLO, expects 250 and parses the server info from the reply.
    pub async fn execute_ehlo(
        &mut self,
        me: ClientId,
        timeout: Duration,
    ) -> Result<(Response, ServerInfo), Error> {
        self.execute_greeting(HeloCommand::ehlo(me), "ehlo", timeout)
            .await
    }

    /// Sends LHLO (the LMTP greeting), expects 250 and parses the server info from the reply.
    pub async fn execute_lhlo(
        &mut self,
        me: ClientId,
        timeout: Duration,
    ) -> Result<(Response, ServerInfo), Error> {
        self.execute_greeting(HeloCommand::lhlo(me), "lhlo", timeout)
            .await
    }

    /// Sends HELO, expects 250 and parses the server info from the reply.
    pub async fn execute_helo(
        &mut self,
        me: ClientId,
        timeout: Duration,
    ) -> Result<(Response, ServerInfo), Error> {
        self.execute_greeting(HeloCommand::helo(me), "helo", timeout)
            .await
    }

    /// Shared body of the EHLO/LHLO/HELO methods.
    ///
    /// Sends `command`, expects 250, parses `ServerInfo` from the reply and logs it.
    /// `verb` only labels the debug log line.
    async fn execute_greeting<C: Display>(
        &mut self,
        command: C,
        verb: &str,
        timeout: Duration,
    ) -> Result<(Response, ServerInfo), Error> {
        let response = self.execute_command(command, [250], timeout).await?;
        let server_info = ServerInfo::from_response(&response)?;
        debug!("{} server info {}", verb, server_info);
        Ok((response, server_info))
    }

    /// Sends STARTTLS and expects code 220.
    pub async fn execute_starttls(&mut self, timeout: Duration) -> SmtpResult {
        self.execute_command(StarttlsCommand, [220], timeout).await
    }

    /// Sends RSET and expects code 250.
    pub async fn execute_rset(&mut self, timeout: Duration) -> SmtpResult {
        self.execute_command(RsetCommand, [250], timeout).await
    }

    /// Sends QUIT and expects code 221.
    pub async fn execute_quit(&mut self, timeout: Duration) -> SmtpResult {
        self.execute_command(QuitCommand, [221], timeout).await
    }

    /// Runs an AUTH exchange driven by `authentication`.
    ///
    /// Each 334 challenge from the server is answered with `AuthResponse::new`. At most
    /// 10 challenges are answered. Each reply is checked against `[334, 2]`, presumably
    /// "334 or any 2xx".
    ///
    /// # Errors
    /// Returns `Error::ResponseParsing("Unexpected number of challenges")` if the server
    /// is still challenging after the budget is used up. Also propagates command,
    /// I/O and response errors.
    pub async fn authenticate<A: Authentication>(
        &mut self,
        mut authentication: A,
        timeout: Duration,
    ) -> SmtpResult {
        // Bound the exchange so a misbehaving server cannot challenge us forever.
        let mut challenges = 10u8;
        let mut response = self
            .execute_command(AuthCommand::new(&mut authentication)?, [334, 2], timeout)
            .await?;
        while challenges > 0 && response.has_code(334) {
            challenges -= 1;
            response = self
                .execute_command(
                    AuthResponse::new(&mut authentication, &response)?,
                    [334, 2],
                    timeout,
                )
                .await?;
        }
        // Only a reply that is *still* a challenge means the budget was exceeded. A
        // success that arrives on the last allowed round must not be reported as an
        // error, so the remaining-counter alone is not the right test.
        if response.has_code(334) {
            Err(Error::ResponseParsing("Unexpected number of challenges"))
        } else {
            Ok(response)
        }
    }

    /// Sends `command` and reads one response, checking it against the `expected` codes.
    ///
    /// The command is rendered with `Display`. Writing, flushing and reading are each
    /// bounded by `timeout` separately.
    pub async fn execute_command<C: Display, E: AsRef<[u16]>>(
        &mut self,
        command: C,
        expected: E,
        timeout: Duration,
    ) -> SmtpResult {
        let command = command.to_string();
        debug!("C: {}", escape_crlf(command.as_str()));
        let written = self.write_bytes(command.as_bytes(), timeout).await?;
        debug_assert_eq!(written, command.len(), "Make sure we write all the data");
        self.flush_stream(timeout).await?;
        let response = self.read_response(timeout).await?;
        response.is(expected)
    }

    /// Writes the whole of `buf` to the stream within `timeout`.
    ///
    /// Returns the number of bytes written, which is always `buf.len()` on success.
    async fn write_bytes(&mut self, buf: &[u8], timeout: Duration) -> Result<usize, Error> {
        // A single `write` may accept only part of the buffer; `write_all` keeps writing
        // until everything is sent, so commands are never silently truncated.
        with_timeout(timeout, self.stream.write_all(buf))
            .await
            .map(|()| buf.len())
    }

    /// Flushes the stream within `timeout`, so a stalled peer cannot block us indefinitely.
    async fn flush_stream(&mut self, timeout: Duration) -> Result<(), Error> {
        with_timeout(timeout, self.stream.flush()).await
    }

    /// Reads until one complete response can be parsed, then checks it against `[2, 3]`
    /// (presumably the 2xx/3xx classes).
    ///
    /// Bytes that follow the parsed response stay in the buffer for the next call. The
    /// whole operation is bounded by `timeout`.
    ///
    /// # Errors
    /// - The stream reaches EOF before a complete response arrives.
    /// - `line_limit` unparsed bytes accumulate without a complete response.
    /// - The data is not valid UTF-8.
    /// - The response cannot be parsed.
    async fn read_response(&mut self, timeout: Duration) -> SmtpResult {
        with_timeout(timeout, async move {
            // Leftover bytes from a previous read may already hold a full response, so
            // only touch the stream once parsing reports the buffer is incomplete.
            let mut need_more = self.buffer.remaining() == 0;
            // Read into an initialized scratch array instead of exposing the BytesMut's
            // uninitialized spare capacity as `&mut [u8]`, which is unsound.
            let mut chunk = [0u8; 1024];
            loop {
                if need_more {
                    let read = self.stream.read(&mut chunk).await?;
                    if read == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("incomplete after {} bytes", self.buffer().len()),
                        )
                        .into());
                    }
                    self.buffer.put_slice(&chunk[..read]);
                }
                let bytes = self.buffer.chunk();
                let text = match std::str::from_utf8(bytes) {
                    Ok(text) => text,
                    // `error_len() == None` means the buffer ends part-way through a
                    // multi-byte character. Parse the complete prefix and let the tail
                    // wait for the next read instead of failing.
                    Err(e) if e.error_len().is_none() => {
                        std::str::from_utf8(&bytes[..e.valid_up_to()])?
                    }
                    Err(e) => return Err(e.into()),
                };
                debug!("S: {}", escape_crlf(text));
                break match parse_response(text) {
                    Ok((remaining, response)) => {
                        // `remaining` is a suffix of `text`, which may itself be only a
                        // prefix of the buffer, so measure consumption against `text`.
                        let consumed = text.len() - remaining.len();
                        self.buffer.advance(consumed);
                        response.is([2, 3].as_ref())
                    }
                    Err(nom::Err::Incomplete(_)) => {
                        if self.buffer.remaining() >= self.line_limit {
                            Err(Error::ResponseParsing("Line limit reached"))
                        } else {
                            need_more = true;
                            continue;
                        }
                    }
                    Err(nom::Err::Failure(e)) => Err(Error::Parsing(e.code)),
                    Err(nom::Err::Error(e)) => Err(Error::Parsing(e.code)),
                };
            }
        })
        .await
    }
}
/// Awaits `f`, giving up once `timeout` has elapsed.
///
/// Both the deadline error (`async_std::future::TimeoutError`) and `f`'s own error
/// are converted into the caller-chosen error type `EOut` through `From`.
async fn with_timeout<T, F, E, EOut>(timeout: Duration, f: F) -> std::result::Result<T, EOut>
where
    F: Future<Output = std::result::Result<T, E>>,
    EOut: From<async_std::future::TimeoutError> + From<E>,
{
    // Outer layer: did the future finish before the deadline?
    let finished_in_time = async_std::future::timeout(timeout, f).await?;
    // Inner layer: did the future itself succeed?
    let value = finished_in_time?;
    Ok(value)
}
/// Makes CRLF line terminators visible in log output by replacing every `"\r\n"`
/// with the literal marker `"<CRLF>"`; all other characters are copied unchanged.
fn escape_crlf(string: &str) -> String {
    let mut escaped = String::with_capacity(string.len());
    let mut rest = string;
    while let Some(pos) = rest.find("\r\n") {
        escaped.push_str(&rest[..pos]);
        escaped.push_str("<CRLF>");
        // "\r\n" is two ASCII bytes, so `pos + 2` is always a char boundary.
        rest = &rest[pos + 2..];
    }
    escaped.push_str(rest);
    escaped
}