use super::{AssertReply, tls::build_tls_connector};
use crate::{Credentials, SmtpClient, SmtpClientBuilder};
use smtp_proto::{EXT_START_TLS, EhloResponse};
use std::hash::Hash;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::time::Duration;
use tokio::net::TcpSocket;
use tokio::{
io,
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use tokio_rustls::client::TlsStream;
impl<T: AsRef<str> + PartialEq + Eq + Hash> SmtpClientBuilder<T> {
pub fn new(hostname: T, port: u16) -> Result<Self, String> {
Ok(SmtpClientBuilder {
addr: format!("{}:{}", hostname.as_ref(), port),
timeout: Duration::from_secs(60 * 60),
tls_connector: build_tls_connector(false)?,
tls_hostname: hostname,
tls_implicit: true,
is_lmtp: false,
local_host: gethostname::gethostname()
.to_str()
.unwrap_or("[127.0.0.1]")
.to_string(),
credentials: None,
say_ehlo: true,
local_ip: None,
})
}
pub fn allow_invalid_certs(mut self) -> Self {
self.tls_connector = build_tls_connector(true).unwrap();
self
}
pub fn implicit_tls(mut self, tls_implicit: bool) -> Self {
self.tls_implicit = tls_implicit;
self
}
pub fn lmtp(mut self, is_lmtp: bool) -> Self {
self.is_lmtp = is_lmtp;
self
}
pub fn say_ehlo(mut self, say_ehlo: bool) -> Self {
self.say_ehlo = say_ehlo;
self
}
pub fn helo_host(mut self, host: impl Into<String>) -> Self {
self.local_host = host.into();
self
}
pub fn credentials(mut self, credentials: impl Into<Credentials<T>>) -> Self {
self.credentials = Some(credentials.into());
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn local_ip(mut self, local_ip: IpAddr) -> Self {
self.local_ip = Some(local_ip);
self
}
async fn tcp_stream(&self) -> io::Result<TcpStream> {
if let Some(local_addr) = self.local_ip {
let remote_addrs = self.addr.to_socket_addrs()?;
let mut last_err = None;
for addr in remote_addrs {
let local_addr = SocketAddr::new(local_addr, 0);
let socket = match local_addr.ip() {
IpAddr::V4(_) => TcpSocket::new_v4()?,
IpAddr::V6(_) => TcpSocket::new_v6()?,
};
socket.bind(local_addr)?;
match socket.connect(addr).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
}
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any address",
)
}))
} else {
TcpStream::connect(&self.addr).await
}
}
pub async fn connect(&self) -> crate::Result<SmtpClient<TlsStream<TcpStream>>> {
tokio::time::timeout(self.timeout, async {
let mut client = SmtpClient {
stream: self.tcp_stream().await?,
timeout: self.timeout,
};
let mut client = if self.tls_implicit {
let mut client = client
.into_tls(&self.tls_connector, self.tls_hostname.as_ref())
.await?;
client.read().await?.assert_positive_completion()?;
client
} else {
client.read().await?.assert_positive_completion()?;
let response = if !self.is_lmtp {
client.ehlo(&self.local_host).await?
} else {
client.lhlo(&self.local_host).await?
};
if response.has_capability(EXT_START_TLS) {
client
.start_tls(&self.tls_connector, self.tls_hostname.as_ref())
.await?
} else {
return Err(crate::Error::MissingStartTls);
}
};
if self.say_ehlo {
let capabilities = client.capabilities(&self.local_host, self.is_lmtp).await?;
if let Some(credentials) = &self.credentials {
client.authenticate(&credentials, &capabilities).await?;
}
}
Ok(client)
})
.await
.map_err(|_| crate::Error::Timeout)?
}
pub async fn connect_plain(&self) -> crate::Result<SmtpClient<TcpStream>> {
let mut client = SmtpClient {
stream: tokio::time::timeout(self.timeout, async { self.tcp_stream().await })
.await
.map_err(|_| crate::Error::Timeout)??,
timeout: self.timeout,
};
client.read().await?.assert_positive_completion()?;
if self.say_ehlo {
let capabilities = client.capabilities(&self.local_host, self.is_lmtp).await?;
if let Some(credentials) = &self.credentials {
client.authenticate(&credentials, &capabilities).await?;
}
}
Ok(client)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> SmtpClient<T> {
    /// Performs the protocol-appropriate hello and returns the server's
    /// capability response.
    ///
    /// Delegates to [`Self::lhlo`] when `is_lmtp` is true and to
    /// [`Self::ehlo`] otherwise, announcing `local_host` in either case.
    pub async fn capabilities(
        &mut self,
        local_host: &str,
        is_lmtp: bool,
    ) -> crate::Result<EhloResponse<String>> {
        if is_lmtp {
            return self.lhlo(local_host).await;
        }
        self.ehlo(local_host).await
    }
}