use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use crate::codec::{decode, encode};
use crate::deliver_by::validate_deliver_by_value;
use crate::error::Error;
use crate::future_release::{
parse_rfc3339_to_utc_key, validate_hold_for_seconds, validate_hold_until_datetime,
};
use crate::types::{
DomainOrLiteral, EnhancedStatusCode, ForwardPath, Protocol, RecipientResult, ReversePath,
ServerCapabilities, SmtpResponse,
};
mod auth;
mod bdat;
mod helpers;
mod lifecycle;
mod lmtp;
mod sending;
mod session;
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::similar_names,
clippy::wildcard_in_or_patterns,
clippy::items_after_statements,
clippy::manual_let_else,
clippy::match_wild_err_arm
)]
#[path = "tests.rs"]
mod tests;
pub use daaki_message::TlsMode;
/// Byte transport underneath a session: either a raw TCP socket or a TCP
/// socket wrapped in a client-side TLS session.
enum SmtpStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS-wrapped TCP connection. The TLS state is boxed, presumably to keep
    /// the enum's size close to that of the `Plain` variant.
    Tls(Box<TlsStream<TcpStream>>),
}
impl SmtpStream {
    /// Writes every byte of `buf` to whichever transport is active.
    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        match self {
            Self::Tls(tls) => tls.write_all(buf).await,
            Self::Plain(tcp) => tcp.write_all(buf).await,
        }
    }

    /// Flushes any bytes buffered by the active transport.
    async fn flush(&mut self) -> std::io::Result<()> {
        match self {
            Self::Tls(tls) => tls.flush().await,
            Self::Plain(tcp) => tcp.flush().await,
        }
    }

    /// Reads up to `buf.len()` bytes from the active transport and returns
    /// how many were read (`0` signals end of stream).
    async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            Self::Tls(tls) => tls.read(buf).await,
            Self::Plain(tcp) => tcp.read(buf).await,
        }
    }

    /// Returns `true` when the stream is the TLS-wrapped variant.
    fn is_tls(&self) -> bool {
        match self {
            Self::Tls(_) => true,
            Self::Plain(_) => false,
        }
    }
}
/// Mutable per-connection state, guarded by the mutex in `SmtpConnection`.
struct SmtpInner {
    /// The active transport (plain TCP or TLS).
    stream: SmtpStream,
    /// Bytes received from the server that have not yet been parsed into a
    /// complete response.
    read_buf: BytesMut,
    /// Server capabilities, presumably taken from the EHLO reply; confirm
    /// against the session code.
    capabilities: ServerCapabilities,
    /// Domain or address literal, presumably announced in EHLO/HELO.
    ehlo_domain: DomainOrLiteral,
    /// Whether authentication has completed on this connection.
    authenticated: bool,
    /// Set once the server has replied with code 421.
    server_shutting_down: bool,
    /// Presumably `true` when the session fell back to plain HELO; verify.
    helo_mode: bool,
}
impl SmtpInner {
async fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> {
self.stream.write_all(buf).await?;
self.stream.flush().await?;
Ok(())
}
async fn read_response(&mut self) -> Result<SmtpResponse, Error> {
loop {
if let Some((resp, consumed)) = SmtpConnection::try_parse_response(&self.read_buf)? {
let _ = self.read_buf.split_to(consumed);
if resp.code == 421 {
self.server_shutting_down = true;
}
return Ok(resp);
}
if self.read_buf.len() > SmtpConnection::MAX_RESPONSE_BUFFER {
return Err(Error::Protocol(format!(
"response exceeds maximum buffer size ({} bytes) \
(RFC 5321 Section 4.5.3.1.5)",
SmtpConnection::MAX_RESPONSE_BUFFER
)));
}
let mut tmp = [0u8; 4096];
let n = self.stream.read(&mut tmp).await?;
if n == 0 {
return Err(Error::Closed);
}
self.read_buf.extend_from_slice(&tmp[..n]);
}
}
async fn rset_best_effort(&mut self) {
let mut buf = BytesMut::new();
encode::encode_rset(&mut buf);
if self.write_all(&buf).await.is_ok() {
let _ = self.read_response().await;
}
}
async fn quit_best_effort(&mut self) {
let mut buf = BytesMut::new();
encode::encode_quit(&mut buf);
if self.write_all(&buf).await.is_ok() {
let _ = self.read_response().await;
}
}
}
pub(super) fn default_ehlo_domain() -> Result<DomainOrLiteral, Error> {
DomainOrLiteral::new("[127.0.0.1]")
.map_err(|err| Error::Protocol(format!("invalid default EHLO domain: {err}")))
}
/// A client connection to a mail server.
///
/// All mutable session state lives behind a `tokio::sync::Mutex`, so methods
/// taking `&self` serialize on that lock. The lock can be held across `.await`
/// points.
pub struct SmtpConnection {
    /// Session state: transport, read buffer, capabilities and flags.
    inner: tokio::sync::Mutex<SmtpInner>,
    /// Protocol spoken on this connection (presumably SMTP or LMTP; see
    /// `Protocol`).
    protocol: Protocol,
}
const _: fn() = || {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<SmtpConnection>();
};
// Debug output never blocks. If the inner state is currently locked (for
// example by an in-flight command), placeholders are printed instead of
// waiting for the lock.
impl std::fmt::Debug for SmtpConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SmtpConnection");
        out.field("protocol", &self.protocol);
        match self.inner.try_lock() {
            Ok(inner) => {
                let transport = match &inner.stream {
                    SmtpStream::Tls(_) => "tls",
                    SmtpStream::Plain(_) => "plain",
                };
                out.field("ehlo_domain", &inner.ehlo_domain)
                    .field("transport", &transport)
                    .field("capabilities", &inner.capabilities);
            }
            Err(_) => {
                let locked = "<locked>";
                out.field("ehlo_domain", &locked)
                    .field("transport", &locked)
                    .field("capabilities", &locked);
            }
        }
        out.finish_non_exhaustive()
    }
}