use core::{any::Any, fmt};
#[cfg(any(
feature = "rustls-aws",
feature = "rustls-ring",
feature = "native-tls"
))]
use alloc::string::ToString;
use alloc::{borrow::Cow, boxed::Box, string::String, vec::Vec};
use std::io::{self, Read, Write};
use secrecy::SecretString;
use thiserror::Error;
#[cfg(feature = "scram")]
#[cfg(any(
feature = "rustls-aws",
feature = "rustls-ring",
feature = "native-tls"
))]
use pimalaya_stream::sasl::SaslScramSha256;
#[cfg(any(
feature = "rustls-aws",
feature = "rustls-ring",
feature = "native-tls"
))]
use pimalaya_stream::{
sasl::{Sasl, SaslAnonymous, SaslLogin, SaslOauthbearer, SaslPlain, SaslXoauth2},
std::stream::StreamStd,
tls::Tls,
};
#[cfg(any(
feature = "rustls-aws",
feature = "rustls-ring",
feature = "native-tls"
))]
use url::Url;
#[cfg(feature = "scram")]
use crate::rfc7677::auth_scram_sha_256::*;
use crate::{
coroutine::*,
message::*,
rfc3207::starttls::*,
rfc5321::{
data::*,
ehlo::*,
greeting::*,
helo::*,
mail::*,
noop::*,
quit::*,
rcpt::*,
rset::*,
types::{
domain::Domain, ehlo_domain::EhloDomain, forward_path::ForwardPath, greeting::Greeting,
parameter::Parameter, reverse_path::ReversePath,
},
},
rfc7628::auth_oauthbearer::*,
sasl::{auth_anonymous::*, auth_login::*, auth_plain::*, auth_xoauth2::*},
};
/// Errors produced by [`SmtpClientStd`].
///
/// Command-level failures are forwarded unchanged (`#[error(transparent)]`)
/// from the coroutine that raised them. The remaining variants cover raw I/O,
/// the TLS-enabled connection helper and SMTP URL validation.
#[derive(Debug, Error)]
pub enum SmtpClientStdError {
    /// Reading the server greeting failed.
    #[error(transparent)]
    Greeting(#[from] SmtpGreetingGetError),
    /// The `EHLO` exchange failed.
    #[error(transparent)]
    Ehlo(#[from] SmtpEhloError),
    /// The `HELO` exchange failed.
    #[error(transparent)]
    Helo(#[from] SmtpHeloError),
    /// The `STARTTLS` exchange failed.
    #[error(transparent)]
    StartTls(#[from] SmtpStartTlsError),
    /// SASL ANONYMOUS authentication failed.
    #[error(transparent)]
    AuthAnonymous(#[from] SmtpAuthAnonymousError),
    /// SASL LOGIN authentication failed.
    #[error(transparent)]
    AuthLogin(#[from] SmtpAuthLoginError),
    /// SASL PLAIN authentication failed.
    #[error(transparent)]
    AuthPlain(#[from] SmtpAuthPlainError),
    /// SASL OAUTHBEARER authentication failed.
    #[error(transparent)]
    AuthOAuthBearer(#[from] SmtpAuthOauthbearerError),
    /// SASL XOAUTH2 authentication failed.
    #[error(transparent)]
    AuthXOAuth2(#[from] SmtpAuthXoauth2Error),
    /// SASL SCRAM-SHA-256 authentication failed.
    #[cfg(feature = "scram")]
    #[error(transparent)]
    AuthScramSha256(#[from] SmtpAuthScramSha256Error),
    /// A SCRAM-SHA-256 SASL configuration was requested, but this build was
    /// compiled without the `scram` cargo feature.
    #[cfg(all(
        any(
            feature = "rustls-aws",
            feature = "rustls-ring",
            feature = "native-tls"
        ),
        not(feature = "scram")
    ))]
    #[error("SCRAM-SHA-256 SASL mechanism requires the `scram` cargo feature")]
    ScramSha256NotEnabled,
    /// The `MAIL FROM` command failed.
    #[error(transparent)]
    Mail(#[from] SmtpMailError),
    /// The `RCPT TO` command failed.
    #[error(transparent)]
    Rcpt(#[from] SmtpRcptError),
    /// The `DATA` exchange failed.
    #[error(transparent)]
    Data(#[from] SmtpDataError),
    /// The `NOOP` command failed.
    #[error(transparent)]
    Noop(#[from] SmtpNoopError),
    /// The `RSET` command failed.
    #[error(transparent)]
    Rset(#[from] SmtpRsetError),
    /// The `QUIT` command failed.
    #[error(transparent)]
    Quit(#[from] SmtpQuitError),
    /// The high-level message-send sequence failed.
    #[error(transparent)]
    MessageSend(#[from] SmtpMessageSendError),
    /// Reading from or writing to the underlying stream failed.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Any `anyhow::Error` converted via `?`; presumably raised by the
    /// TLS-capable stream layer while connecting or upgrading (confirm).
    #[cfg(any(
        feature = "rustls-aws",
        feature = "rustls-ring",
        feature = "native-tls"
    ))]
    #[error(transparent)]
    Tls(#[from] anyhow::Error),
    /// The SMTP URL (carried as a string) has no host component.
    #[cfg(any(
        feature = "rustls-aws",
        feature = "rustls-ring",
        feature = "native-tls"
    ))]
    #[error("SMTP URL `{0}` has no host")]
    UrlMissingHost(String),
    /// The SMTP URL (first field) uses a scheme (second field) other than
    /// `smtp` or `smtps`.
    #[cfg(any(
        feature = "rustls-aws",
        feature = "rustls-ring",
        feature = "native-tls"
    ))]
    #[error("SMTP URL `{0}` has unsupported scheme `{1}` (expected `smtp` or `smtps`)")]
    UrlUnsupportedScheme(String, String),
    /// STARTTLS was requested on an `smtps://` URL, which is already TLS.
    #[cfg(any(
        feature = "rustls-aws",
        feature = "rustls-ring",
        feature = "native-tls"
    ))]
    #[error("STARTTLS requested on an `smtps://` URL: TLS is already active")]
    StartTlsOverTls,
}
/// Size in bytes (16 KiB) of the stack buffer filled by each `read` while a
/// coroutine is being driven.
const READ_BUFFER_SIZE: usize = 16_384;
/// Returns the default ALPN protocol list used for SMTP connections, which
/// contains the single identifier `"smtp"`.
pub fn default_alpn() -> Vec<String> {
    const SMTP_ALPN_ID: &str = "smtp";
    Vec::from([String::from(SMTP_ALPN_ID)])
}
/// Blocking SMTP client that drives the crate's SMTP coroutines over an
/// arbitrary `Read + Write + Send + 'static` stream.
pub struct SmtpClientStd {
    /// Underlying connection, boxed behind [`SmtpStream`] so that it can be
    /// replaced at runtime with [`SmtpClientStd::set_stream`].
    ///
    /// To recover the concrete stream type, downcast through the trait object
    /// (`(*client.stream).as_any_mut()`), not through the `Box` itself.
    /// `Box<dyn SmtpStream>` also satisfies the blanket `SmtpStream` impl, so
    /// calling the method directly on the box yields the box.
    pub stream: Box<dyn SmtpStream>,
}
impl SmtpClientStd {
    /// Wraps an already-connected blocking byte stream.
    ///
    /// No I/O is performed here.
    pub fn new<S: Read + Write + Send + 'static>(stream: S) -> Self {
        Self {
            stream: Box::new(stream),
        }
    }

    /// Replaces the underlying stream, dropping the previous one.
    ///
    /// For example, use this to install a TLS stream after [`Self::starttls`],
    /// which only runs the SMTP exchange and does not touch the stream.
    pub fn set_stream<S: Read + Write + Send + 'static>(&mut self, stream: S) {
        self.stream = Box::new(stream);
    }

    /// Drives an SMTP coroutine to completion over the underlying stream.
    ///
    /// Each `WantsRead` is answered with a single `read` into a stack buffer
    /// of `READ_BUFFER_SIZE` bytes.
    ///
    /// Each `WantsWrite` is written in full and then flushed. The flush
    /// guarantees that the command leaves any intermediate write buffer before
    /// the client blocks waiting for the server's reply. It also surfaces write
    /// errors that a buffering or TLS stream may otherwise defer.
    ///
    /// NOTE(review): a zero-length read (EOF) is forwarded to the coroutine as
    /// `Some(&[])`. This relies on the coroutine treating an empty chunk as
    /// end-of-stream; confirm, otherwise EOF could make it yield `WantsRead`
    /// indefinitely.
    ///
    /// # Errors
    ///
    /// Returns the coroutine's own error (converted through `From`), or any
    /// `io::Error` raised while reading, writing or flushing.
    pub fn run<C, T, E>(&mut self, mut coroutine: C) -> Result<T, SmtpClientStdError>
    where
        C: SmtpCoroutine<Yield = SmtpYield, Return = Result<T, E>>,
        SmtpClientStdError: From<E>,
    {
        let mut buf = [0u8; READ_BUFFER_SIZE];
        // Input for the next `resume`: freshly read bytes, or `None` after a
        // write. `take()` resets it to `None` on every iteration.
        let mut arg: Option<&[u8]> = None;
        loop {
            match coroutine.resume(arg.take()) {
                SmtpCoroutineState::Complete(Ok(out)) => return Ok(out),
                SmtpCoroutineState::Complete(Err(err)) => return Err(err.into()),
                SmtpCoroutineState::Yielded(SmtpYield::WantsRead) => {
                    let n = self.stream.read(&mut buf)?;
                    arg = Some(&buf[..n]);
                }
                SmtpCoroutineState::Yielded(SmtpYield::WantsWrite(bytes)) => {
                    self.stream.write_all(&bytes)?;
                    // Without this flush, a buffered stream could keep the
                    // command locally while the next `read` waits forever.
                    self.stream.flush()?;
                }
            }
        }
    }

    /// Reads the server greeting by running [`SmtpGreetingGet`].
    pub fn greeting(&mut self) -> Result<Greeting<'static>, SmtpClientStdError> {
        self.run(SmtpGreetingGet::new())
    }

    /// Sends `EHLO` for `domain` by running [`SmtpEhlo`].
    ///
    /// Returns the strings produced by the coroutine. These are presumably
    /// the extensions advertised by the server; confirm against `SmtpEhlo`.
    pub fn ehlo(
        &mut self,
        domain: EhloDomain<'_>,
    ) -> Result<Vec<Cow<'static, str>>, SmtpClientStdError> {
        self.run(SmtpEhlo::new(domain))
    }

    /// Sends `HELO` for `domain` by running [`SmtpHelo`].
    pub fn helo(&mut self, domain: Domain<'_>) -> Result<(), SmtpClientStdError> {
        self.run(SmtpHelo::new(domain))
    }

    /// Runs the [`SmtpStartTls`] exchange.
    ///
    /// The stream itself is not upgraded here. Install the TLS stream
    /// afterwards with [`Self::set_stream`].
    ///
    /// NOTE(review): the returned bytes come straight from the coroutine,
    /// presumably data read past the STARTTLS reply; confirm.
    pub fn starttls(&mut self) -> Result<Vec<u8>, SmtpClientStdError> {
        self.run(SmtpStartTls::new())
    }

    /// Sends `QUIT` by running [`SmtpQuit`].
    pub fn quit(&mut self) -> Result<(), SmtpClientStdError> {
        self.run(SmtpQuit::new())
    }

    /// Authenticates with SASL ANONYMOUS ([`SmtpAuthAnonymous`], default
    /// options), passing the optional `trace` string and `domain`.
    pub fn auth_anonymous(
        &mut self,
        trace: Option<&str>,
        domain: EhloDomain<'_>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpAuthAnonymous::new(
            trace,
            domain,
            SmtpAuthAnonymousOptions::default(),
        ))
    }

    /// Authenticates with SASL LOGIN ([`SmtpAuthLogin`], default options).
    pub fn auth_login(
        &mut self,
        login: &str,
        password: &SecretString,
        domain: EhloDomain<'_>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpAuthLogin::new(
            login,
            password,
            domain,
            SmtpAuthLoginOptions::default(),
        ))
    }

    /// Authenticates with SASL PLAIN ([`SmtpAuthPlain`], default options).
    pub fn auth_plain(
        &mut self,
        login: &str,
        password: &SecretString,
        domain: EhloDomain<'_>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpAuthPlain::new(
            login,
            password,
            domain,
            SmtpAuthPlainOptions::default(),
        ))
    }

    /// Authenticates with SASL OAUTHBEARER ([`SmtpAuthOauthbearer`], default
    /// options) using `token` and an optional `username`.
    pub fn auth_oauthbearer(
        &mut self,
        token: &SecretString,
        username: Option<&str>,
        domain: EhloDomain<'_>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpAuthOauthbearer::new(
            token,
            username,
            domain,
            SmtpAuthOauthbearerOptions::default(),
        ))
    }

    /// Authenticates with SASL XOAUTH2 ([`SmtpAuthXoauth2`], default options).
    pub fn auth_xoauth2(
        &mut self,
        username: &str,
        token: &SecretString,
        domain: EhloDomain<'_>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpAuthXoauth2::new(
            username,
            token,
            domain,
            SmtpAuthXoauth2Options::default(),
        ))
    }

    /// Authenticates with SASL SCRAM-SHA-256 ([`SmtpAuthScramSha256`],
    /// default options) using the caller-supplied client `nonce`.
    #[cfg(feature = "scram")]
    pub fn auth_scram_sha256(
        &mut self,
        username: &str,
        password: &SecretString,
        nonce: &[u8],
        domain: EhloDomain<'_>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpAuthScramSha256::new(
            username,
            password,
            nonce,
            domain,
            SmtpAuthScramSha256Options::default(),
        ))
    }

    /// Sends `MAIL FROM` with `reverse_path` and extra `parameters`.
    pub fn mail(
        &mut self,
        reverse_path: ReversePath<'_>,
        parameters: Vec<Parameter<'_>>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpMail::new(reverse_path, parameters))
    }

    /// Sends `RCPT TO` with `forward_path` and extra `parameters`.
    pub fn rcpt(
        &mut self,
        forward_path: ForwardPath<'_>,
        parameters: Vec<Parameter<'_>>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpRcpt::new(forward_path, parameters))
    }

    /// Runs the `DATA` exchange ([`SmtpData`]) for the raw `message` bytes.
    pub fn data(&mut self, message: Vec<u8>) -> Result<(), SmtpClientStdError> {
        self.run(SmtpData::new(message))
    }

    /// Sends `RSET` by running [`SmtpRset`].
    pub fn rset(&mut self) -> Result<(), SmtpClientStdError> {
        self.run(SmtpRset::new())
    }

    /// Sends `NOOP` by running [`SmtpNoop`].
    pub fn noop(&mut self) -> Result<(), SmtpClientStdError> {
        self.run(SmtpNoop::new())
    }

    /// Sends `message` from `reverse_path` to every path in `forward_paths`
    /// by running the composite [`SmtpMessageSend`] coroutine.
    pub fn send<'a>(
        &mut self,
        reverse_path: ReversePath<'_>,
        forward_paths: impl IntoIterator<Item = ForwardPath<'a>>,
        message: Vec<u8>,
    ) -> Result<(), SmtpClientStdError> {
        self.run(SmtpMessageSend::new(reverse_path, forward_paths, message))
    }
}
impl fmt::Debug for SmtpClientStd {
    /// Formats as `SmtpClientStd { .. }`.
    ///
    /// The boxed stream is not required to implement `Debug` (`SmtpStream`
    /// has no such bound), so its contents are elided.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("SmtpClientStd");
        builder.finish_non_exhaustive()
    }
}
#[cfg(any(
    feature = "rustls-aws",
    feature = "rustls-ring",
    feature = "native-tls"
))]
impl SmtpClientStd {
    /// Opens a connection to the SMTP server described by `url` and returns a
    /// client that is ready to send mail.
    ///
    /// Steps, in order:
    /// 1. Validate the URL before any socket is opened:
    ///    - `smtp://` means plain TCP, default port 25;
    ///    - `smtps://` means implicit TLS, default port 465;
    ///    - `starttls` on an `smtps://` URL is rejected.
    /// 2. Connect, read the greeting and send `EHLO domain`.
    /// 3. If `starttls` is set, run `STARTTLS`, upgrade the stream with `tls`
    ///    and send `EHLO` again.
    /// 4. If `sasl` is given, authenticate with the matching mechanism. For
    ///    SCRAM-SHA-256, a 24-character alphanumeric client nonce is generated
    ///    with `thread_rng`.
    ///
    /// Note: the `authzid` of a PLAIN configuration and the `host`/`port` of an
    /// OAUTHBEARER configuration are not forwarded to the auth coroutines.
    ///
    /// # Errors
    ///
    /// - `UrlMissingHost` / `UrlUnsupportedScheme` for invalid URLs.
    /// - `StartTlsOverTls` when STARTTLS is requested on `smtps://`.
    /// - `ScramSha256NotEnabled` when SCRAM is requested without the `scram`
    ///   feature.
    /// - Otherwise any connection, I/O, TLS or SMTP command error.
    pub fn connect(
        url: &Url,
        tls: &Tls,
        starttls: bool,
        domain: EhloDomain<'_>,
        sasl: Option<impl Into<Sasl>>,
    ) -> Result<Self, SmtpClientStdError> {
        use bounded_static::IntoBoundedStatic;

        let Some(host) = url.host_str() else {
            return Err(SmtpClientStdError::UrlMissingHost(url.to_string()));
        };

        // Resolve the transport from the scheme first, so that argument errors
        // are reported without opening (and wasting) a network connection.
        let (default_port, implicit_tls) = match url.scheme() {
            scheme if scheme.eq_ignore_ascii_case("smtp") => (25, false),
            scheme if scheme.eq_ignore_ascii_case("smtps") => (465, true),
            scheme => {
                let url = url.to_string();
                let scheme = scheme.to_string();
                return Err(SmtpClientStdError::UrlUnsupportedScheme(url, scheme));
            }
        };

        if starttls && implicit_tls {
            return Err(SmtpClientStdError::StartTlsOverTls);
        }

        let port = url.port().unwrap_or(default_port);
        let mut stream = if implicit_tls {
            StreamStd::connect_tls(host, port, tls)?
        } else {
            StreamStd::connect_tcp(host, port)?
        };

        let domain = domain.into_static();

        // The pre-STARTTLS exchange runs on the concrete `StreamStd`, so the
        // stream can still be upgraded before it gets boxed into the client.
        run_smtp_inline(&mut stream, SmtpGreetingGet::new())?;
        run_smtp_inline(&mut stream, SmtpEhlo::new(domain.clone()))?;

        let stream = if starttls {
            run_smtp_inline(&mut stream, SmtpStartTls::new())?;
            stream.upgrade_tls(tls)?
        } else {
            stream
        };

        let mut client = Self::new(stream);

        // A fresh EHLO is required once the session runs over TLS.
        if starttls {
            client.ehlo(domain.clone())?;
        }

        if let Some(sasl) = sasl.map(Into::into) {
            match sasl {
                Sasl::Anonymous(SaslAnonymous { message }) => {
                    client.auth_anonymous(message.as_deref(), domain.clone())?;
                }
                Sasl::Login(SaslLogin { username, password }) => {
                    client.auth_login(&username, &password, domain.clone())?;
                }
                Sasl::Plain(SaslPlain {
                    authzid: _,
                    authcid,
                    passwd,
                }) => {
                    client.auth_plain(&authcid, &passwd, domain.clone())?;
                }
                Sasl::Oauthbearer(SaslOauthbearer {
                    username,
                    host: _,
                    port: _,
                    token,
                }) => {
                    client.auth_oauthbearer(&token, Some(&username), domain.clone())?;
                }
                Sasl::Xoauth2(SaslXoauth2 { username, token }) => {
                    client.auth_xoauth2(&username, &token, domain.clone())?;
                }
                #[cfg(feature = "scram")]
                Sasl::ScramSha256(SaslScramSha256 { username, password }) => {
                    use rand::{Rng, distributions::Alphanumeric, thread_rng};
                    let nonce = thread_rng()
                        .sample_iter(&Alphanumeric)
                        .take(24)
                        .collect::<Vec<u8>>();
                    client.auth_scram_sha256(&username, &password, &nonce, domain.clone())?;
                }
                #[cfg(not(feature = "scram"))]
                Sasl::ScramSha256(_) => {
                    return Err(SmtpClientStdError::ScramSha256NotEnabled);
                }
            }
        }

        Ok(client)
    }
}
#[cfg(any(
    feature = "rustls-aws",
    feature = "rustls-ring",
    feature = "native-tls"
))]
/// Drives `coroutine` to completion directly over a concrete [`StreamStd`].
///
/// This is used for the pre-STARTTLS exchange in `SmtpClientStd::connect`,
/// where the stream must remain a `StreamStd` so it can still be upgraded
/// with `upgrade_tls`.
///
/// Reads go into a stack buffer of `READ_BUFFER_SIZE` bytes. Writes are sent
/// with `write_all` followed by `flush`, so a command never sits in a write
/// buffer while the next `read` waits for the server's reply.
///
/// # Errors
///
/// Returns the coroutine's own error (converted through `From`), or any
/// `io::Error` raised while reading, writing or flushing.
fn run_smtp_inline<C, T, E>(
    stream: &mut StreamStd,
    mut coroutine: C,
) -> Result<T, SmtpClientStdError>
where
    C: SmtpCoroutine<Yield = SmtpYield, Return = Result<T, E>>,
    SmtpClientStdError: From<E>,
{
    let mut buf = [0u8; READ_BUFFER_SIZE];
    // Bytes from the last read, handed to the next `resume`. `take()` resets
    // it to `None`, which is what the coroutine receives after a write.
    let mut arg: Option<&[u8]> = None;
    loop {
        match coroutine.resume(arg.take()) {
            SmtpCoroutineState::Complete(Ok(out)) => return Ok(out),
            SmtpCoroutineState::Complete(Err(err)) => return Err(err.into()),
            SmtpCoroutineState::Yielded(SmtpYield::WantsRead) => {
                let n = stream.read(&mut buf)?;
                arg = Some(&buf[..n]);
            }
            SmtpCoroutineState::Yielded(SmtpYield::WantsWrite(bytes)) => {
                stream.write_all(&bytes)?;
                // Push the command out before the coroutine asks to read
                // the reply; otherwise a buffering stream could deadlock.
                stream.flush()?;
            }
        }
    }
}
/// Object-safe byte stream accepted by [`SmtpClientStd`]: any blocking
/// `Read + Write` type that is `Send` and `'static`.
///
/// [`SmtpStream::as_any_mut`] allows downcasting a boxed stream back to its
/// concrete type. Call it on the trait object (`(*boxed).as_any_mut()` or
/// `boxed.as_mut().as_any_mut()`), not on the `Box` itself.
/// `Box<dyn SmtpStream>` also satisfies the blanket impl below, so
/// `boxed.as_any_mut()` returns the box rather than the inner stream.
pub trait SmtpStream: Any + Send + Read + Write {
    /// Returns `self` as a mutable `Any` for downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

/// Every suitable stream type is automatically an [`SmtpStream`].
impl<T> SmtpStream for T
where
    T: Read + Write + Send + Any,
{
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }
}