use std::fmt;
use std::io;
use std::result;
use std::str::FromStr;
use async_stream_packed::{TlsClientUpgrader, UpgradableAsyncStream};
use futures_util::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use lettre::transport::smtp::{
authentication::{Credentials, Mechanism},
commands::*,
error::Error,
extension::{ClientId, Extension, MailBodyParameter, MailParameter, ServerInfo},
response::Response,
};
use lettre::Envelope;
#[cfg(feature = "async_native_tls")]
pub use async_stream_tls_upgrader::AsyncNativeTlsClientTlsUpgrader;
#[cfg(feature = "async_tls")]
pub use async_stream_tls_upgrader::AsyncTlsClientTlsUpgrader;
use self::codec::ClientCodec;
/// Stream type held by [`AsyncConnection`]: an alias for
/// [`UpgradableAsyncStream`] over the raw stream `S` and the TLS upgrader `STU`.
pub type AsyncStream<S, STU> = UpgradableAsyncStream<S, STU>;
/// SMTP client connection over a stream that can be upgraded to TLS in place.
///
/// `S` is the underlying stream type and `STU` the upgrader used to turn it
/// into a TLS stream.
pub struct AsyncConnection<S, STU>
where
STU: TlsClientUpgrader<S>,
{
/// The (possibly already upgraded) transport all commands are written to
/// and all responses are read from.
stream: AsyncStream<S, STU>,
/// Set to `true` by `abort`; reported through `has_broken`.
panic: bool,
/// Server capabilities; starts as the default value and is replaced by
/// `ehlo` with the information parsed from the EHLO response.
server_info_: ServerInfo,
}
impl<S, STU> AsyncConnection<S, STU>
where
STU: TlsClientUpgrader<S>,
{
pub fn server_info(&self) -> &ServerInfo {
&self.server_info_
}
fn from_parts(stream: AsyncStream<S, STU>) -> Self {
Self {
stream,
panic: false,
server_info_: Default::default(),
}
}
pub fn new(stream: S, upgrader: STU) -> Self {
Self::from_parts(AsyncStream::new(stream, upgrader))
}
}
impl<S> AsyncConnection<S, ()>
where
S: Send + 'static,
{
pub fn with_tls_stream(stream: S) -> Self {
Self::from_parts(AsyncStream::with_upgraded_stream(stream))
}
}
#[cfg(feature = "async_native_tls")]
impl<S> AsyncConnection<S, AsyncNativeTlsClientTlsUpgrader>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Creates a connection whose stream is upgraded with the given
    /// `async-native-tls` based upgrader.
    pub fn with_async_native_tls_upgrader(
        stream: S,
        upgrader: AsyncNativeTlsClientTlsUpgrader,
    ) -> Self {
        // Same construction path as `new`, spelled out for this upgrader type.
        Self::from_parts(AsyncStream::new(stream, upgrader))
    }
}
#[cfg(feature = "async_tls")]
impl<S> AsyncConnection<S, AsyncTlsClientTlsUpgrader>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Creates a connection whose stream is upgraded with the given
    /// `async-tls` based upgrader.
    pub fn with_async_tls_upgrader(stream: S, upgrader: AsyncTlsClientTlsUpgrader) -> Self {
        // Same construction path as `new`, spelled out for this upgrader type.
        Self::from_parts(AsyncStream::new(stream, upgrader))
    }
}
impl<S, STU> AsyncConnection<S, STU>
where
    STU: TlsClientUpgrader<S>,
{
    /// Upgrades the underlying stream in place.
    ///
    /// # Errors
    ///
    /// Returns the upgrade failure converted into the SMTP [`Error`] type.
    pub async fn stream_tls_upgrade(&mut self) -> result::Result<(), Error> {
        self.stream.upgrade().await?;
        Ok(())
    }
}
/// Unwraps a `Result` inside a function that returns `Result`.
///
/// On `Ok` the macro evaluates to the contained value. On `Err` it first
/// calls `$client.abort()` and then returns the error, converted with
/// `From`, from the enclosing function.
#[macro_export]
macro_rules! try_smtp {
    ($err:expr, $client:ident) => {{
        match $err {
            Err(failure) => {
                $client.abort();
                return Err(::std::convert::From::from(failure));
            }
            Ok(value) => value,
        }
    }};
}
impl<S, STU> AsyncConnection<S, STU>
where
STU: TlsClientUpgrader<S> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
STU::Output: AsyncRead + AsyncWrite + Unpin,
{
pub async fn handshake(
&mut self,
is_smtps: bool,
hello_name: ClientId,
) -> result::Result<(), Error> {
if is_smtps && !self.stream.is_upgraded() {
self.stream_tls_upgrade().await?;
}
let _ = self.read_response().await?;
self.ehlo(&hello_name).await?;
if self.can_starttls() {
self.starttls().await?;
self.stream_tls_upgrade().await?;
self.ehlo(&hello_name).await?;
}
Ok(())
}
pub async fn send(
&mut self,
envelope: &Envelope,
email: &[u8],
) -> result::Result<Response, Error> {
let mut mail_options = vec![];
if self.server_info().supports_feature(Extension::EightBitMime) {
mail_options.push(MailParameter::Body(MailBodyParameter::EightBitMime));
}
try_smtp!(
self.command(Mail::new(envelope.from().cloned(), mail_options,))
.await,
self
);
for to_address in envelope.to() {
try_smtp!(
self.command(Rcpt::new(to_address.clone(), vec![])).await,
self
);
}
try_smtp!(self.command(Data).await, self);
let result = try_smtp!(self.message(email).await, self);
Ok(result)
}
pub fn has_broken(&self) -> bool {
self.panic
}
pub fn can_starttls(&self) -> bool {
!self.is_encrypted() && self.server_info().supports_feature(Extension::StartTls)
}
pub async fn starttls(&mut self) -> result::Result<(), Error> {
if self.server_info().supports_feature(Extension::StartTls) {
try_smtp!(self.command(Starttls).await, self);
Ok(())
} else {
Err(Error::Client("STARTTLS is not supported on this server"))
}
}
pub async fn ehlo(&mut self, hello_name: &ClientId) -> result::Result<(), Error> {
let ehlo_response = try_smtp!(
self.command(Ehlo::new(ClientId::new(hello_name.to_string())))
.await,
self
);
self.server_info_ = try_smtp!(ServerInfo::from_response(&ehlo_response), self);
Ok(())
}
pub async fn quit(&mut self) -> result::Result<Response, Error> {
Ok(try_smtp!(self.command(Quit).await, self))
}
pub fn abort(&mut self) {
if !self.panic {
self.panic = true;
let _ = self.command(Quit);
}
}
pub fn is_encrypted(&self) -> bool {
self.stream.is_upgraded()
}
pub async fn test_connected(&mut self) -> bool {
self.command(Noop).await.is_ok()
}
pub async fn auth(
&mut self,
mechanisms: &[Mechanism],
credentials: &Credentials,
) -> result::Result<Response, Error> {
let mechanism = match self.server_info_.get_auth_mechanism(mechanisms) {
Some(m) => m,
None => {
return Err(Error::Client(
"No compatible authentication mechanism was found",
))
}
};
let mut challenges = 10;
let mut response = self
.command(Auth::new(mechanism, credentials.clone(), None)?)
.await?;
while challenges > 0 && response.has_code(334) {
challenges -= 1;
response = try_smtp!(
self.command(Auth::new_from_response(
mechanism,
credentials.clone(),
&response,
)?)
.await,
self
);
}
if challenges == 0 {
Err(Error::ResponseParsing("Unexpected number of challenges"))
} else {
Ok(response)
}
}
pub async fn message(&mut self, message: &[u8]) -> result::Result<Response, Error> {
let mut out_buf: Vec<u8> = vec![];
let mut codec = ClientCodec::new();
codec.encode(message, &mut out_buf)?;
self.write(out_buf.as_slice()).await?;
self.write(b"\r\n.\r\n").await?;
self.read_response().await
}
pub async fn command<C: fmt::Display>(
&mut self,
command: C,
) -> result::Result<Response, Error> {
self.write(command.to_string().as_bytes()).await?;
self.read_response().await
}
async fn write(&mut self, string: &[u8]) -> result::Result<(), Error> {
self.stream.write_all(string).await?;
self.stream.flush().await?;
Ok(())
}
pub async fn read_response(&mut self) -> result::Result<Response, Error> {
let mut buffer = String::with_capacity(100);
let mut buf_reader = BufReader::new(&mut self.stream);
while buf_reader.read_line(&mut buffer).await? > 0 {
match Response::from_str(&buffer) {
Ok(response) => {
if response.is_positive() {
return Ok(response);
}
return Err(response.into());
}
Err(Error::Parsing(nom::error::ErrorKind::Complete)) => { }
Err(err) => return Err(err),
}
}
Err(io::Error::new(io::ErrorKind::Other, "incomplete").into())
}
}
mod codec {
    use std::io;

    /// Incremental encoder that applies SMTP transparency ("dot-stuffing") to
    /// message data: a `.` that directly follows `\r\n` is doubled.
    ///
    /// State is kept between calls, so a `\r\n` at the end of one frame and a
    /// `.` at the start of the next are still recognised.
    #[derive(Default, Clone, Copy, Debug)]
    pub struct ClientCodec {
        /// Progress through the `\r\n.` sequence:
        /// 0 = nothing matched, 1 = just saw `\r`, 2 = just saw `\r\n`.
        escape_count: u8,
    }

    impl ClientCodec {
        /// Creates a codec in its initial state (no pending `\r` or `\r\n`).
        pub fn new() -> Self {
            ClientCodec::default()
        }

        /// Appends the encoded form of `frame` to `buf`.
        ///
        /// An empty `frame` is a request to terminate the data: the missing
        /// part of `\r\n.\r\n`, given what was already emitted, is appended
        /// and the state is reset.
        ///
        /// # Errors
        ///
        /// Never fails; the `io::Result` return type is kept for callers.
        pub fn encode(&mut self, frame: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
            if frame.is_empty() {
                let terminator: &[u8] = match self.escape_count {
                    0 => b"\r\n.\r\n",
                    1 => b"\n.\r\n",
                    2 => b".\r\n",
                    // State 3 is reset immediately below and never stored.
                    _ => unreachable!(),
                };
                buf.extend_from_slice(terminator);
                self.escape_count = 0;
                return Ok(());
            }

            // Start of the part of `frame` not yet copied into `buf`.
            let mut start = 0;
            for (idx, &byte) in frame.iter().enumerate() {
                self.escape_count = match (self.escape_count, byte) {
                    // A `\r` always starts a new candidate line ending, even
                    // right after `\r` or `\r\n`. Without this, a dot after a
                    // blank line (`\r\n\r\n.`) would not be stuffed.
                    (_, b'\r') => 1,
                    (1, b'\n') => 2,
                    (2, b'.') => 3,
                    _ => 0,
                };
                if self.escape_count == 3 {
                    self.escape_count = 0;
                    // Copy up to (not including) the dot and add the extra
                    // dot; the original dot is copied with the next slice.
                    buf.extend_from_slice(&frame[start..idx]);
                    buf.push(b'.');
                    start = idx;
                }
            }
            buf.extend_from_slice(&frame[start..]);
            Ok(())
        }
    }
}