use std::{
io,
sync::{Arc, Mutex},
};
use base64::{Engine as _, engine::general_purpose};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{TcpListener, TcpStream},
sync::broadcast,
};
use crate::auth::SharedUserStore;
use crate::delivery::{RcptContext, RcptDecision, SharedDeliveryPolicy, build_dsn_message};
use crate::store::{Store, current_internal_date};
use crate::tls::{SmtpStream, build_tls_acceptor};
use crate::{MailboxEvent, MailboxNotifier};
type SmtpReader = BufReader<SmtpStream>;
/// Writes `data` directly to the stream wrapped by `writer` (bypassing the read
/// buffer) and then flushes it, so the reply reaches the client before the
/// session blocks on its next read.
///
/// The wrapped `SmtpStream` may be a TLS stream (`SmtpStream::Tls`). Under the
/// `AsyncWrite` contract, written bytes are only guaranteed to reach their
/// destination once `flush` completes.
///
/// # Errors
///
/// Returns any I/O error raised while writing or flushing.
async fn write_raw(writer: &mut SmtpReader, data: &[u8]) -> io::Result<()> {
    let stream = writer.get_mut();
    stream.write_all(data).await?;
    stream.flush().await
}
/// Accepts SMTP connections on `listener` and spawns one `handle_smtp` task per
/// connection, each with its own clones of the shared store, user store,
/// delivery policy and mailbox notifier.
///
/// The loop ends when anything is received on `shutdown_rx` (a message, or a
/// closed/lagged error). It also ends when `accept` fails with an error that
/// is not tied to a single incoming connection. Errors returned by individual
/// sessions are discarded.
pub(crate) async fn run_smtp(
    listener: TcpListener,
    store: Arc<Mutex<Store>>,
    auth: SharedUserStore,
    delivery_policy: SharedDeliveryPolicy,
    mut shutdown_rx: broadcast::Receiver<()>,
    mailbox_notifier: MailboxNotifier,
) {
    loop {
        tokio::select! {
            accept = listener.accept() => {
                let stream = match accept {
                    Ok((stream, _peer)) => stream,
                    // These kinds describe one connection that failed before it was
                    // handed over (e.g. the peer reset it during the handshake).
                    // The listener itself is still healthy, so keep serving rather
                    // than shutting down the whole SMTP service.
                    Err(err) if matches!(
                        err.kind(),
                        io::ErrorKind::ConnectionAborted
                            | io::ErrorKind::ConnectionReset
                            | io::ErrorKind::Interrupted
                    ) => continue,
                    Err(_) => break,
                };
                let store = store.clone();
                let auth = auth.clone();
                let notifier = mailbox_notifier.clone();
                let delivery_policy = delivery_policy.clone();
                tokio::spawn(async move {
                    let _ = handle_smtp(stream, store, auth, delivery_policy, notifier).await;
                });
            }
            _ = shutdown_rx.recv() => {
                break;
            }
        }
    }
}
/// Runs one SMTP session over `stream` until the client sends QUIT or closes
/// the connection.
///
/// Supported commands are EHLO/HELO, STARTTLS, AUTH PLAIN, MAIL FROM, RCPT TO,
/// DATA, RSET, NOOP and QUIT; anything else gets `502`.
///
/// Delivery works as follows:
/// - `delivery_policy` decides for each recipient at RCPT time.
/// - After DATA, the message is appended to the `INBOX` of every accepted
///   recipient's user in `store`, and a `MailboxEvent` is sent for each of
///   those deliveries.
/// - If the policy reports DSN recipients and the sender's parsed user is
///   non-empty, a DSN built by `build_dsn_message` is appended to the sender's
///   `INBOX`.
///
/// # Errors
///
/// Returns any I/O error from reading or writing the connection, from
/// `build_tls_acceptor`, or from the TLS handshake.
///
/// # Panics
///
/// Panics if the store mutex is poisoned.
async fn handle_smtp(
    stream: TcpStream,
    store: Arc<Mutex<Store>>,
    auth: SharedUserStore,
    delivery_policy: SharedDeliveryPolicy,
    mailbox_notifier: MailboxNotifier,
) -> io::Result<()> {
    let mut reader: SmtpReader = BufReader::new(SmtpStream::Plain(stream));
    let mut tls_active = false;
    write_raw(&mut reader, b"220 elektromail SMTP ready\r\n").await?;
    let mut line = String::new();
    // Envelope of the mail transaction in progress: reverse-path and the
    // forward-paths accepted so far.
    let mut current_sender: Option<AddressInfo> = None;
    let mut recipients: Vec<RcptContext> = Vec::new();
    let mut authenticated = false;
    loop {
        line.clear();
        let bytes = reader.read_line(&mut line).await?;
        if bytes == 0 {
            // The client closed the connection.
            break;
        }
        let trimmed = line.trim_end_matches(&['\r', '\n'][..]);
        // Verbs are matched case-insensitively; `trimmed` keeps the original
        // case for argument parsing.
        let upper = trimmed.to_ascii_uppercase();
        if upper.starts_with("EHLO") || upper.starts_with("HELO") {
            // RFC 5321 §4.1.4: an EHLO/HELO issued during the session clears
            // the transaction state exactly as RSET would.
            current_sender = None;
            recipients.clear();
            // STARTTLS is only advertised while the session is still plaintext.
            let starttls_line = if tls_active { "" } else { "250-STARTTLS\r\n" };
            write_raw(
                &mut reader,
                format!(
                    "250-localhost\r\n250-AUTH PLAIN\r\n{}250 SIZE 10485760\r\n",
                    starttls_line
                )
                .as_bytes(),
            )
            .await?;
        } else if upper == "STARTTLS" {
            if tls_active {
                write_raw(&mut reader, b"503 TLS already active\r\n").await?;
                continue;
            }
            write_raw(&mut reader, b"220 Ready to start TLS\r\n").await?;
            let acceptor = build_tls_acceptor()?;
            // `BufReader::into_inner` drops any bytes still buffered from the
            // plaintext phase, so commands pipelined before the handshake are
            // never executed inside the TLS session.
            let inner = reader.into_inner();
            let SmtpStream::Plain(stream) = inner else {
                return Err(io::Error::other("STARTTLS requires plaintext stream"));
            };
            let tls_stream = acceptor.accept(stream).await.map_err(io::Error::other)?;
            reader = BufReader::new(SmtpStream::Tls(tls_stream));
            tls_active = true;
            // RFC 3207 §4.2: after the handshake the session restarts from its
            // initial state, forgetting anything learned over plaintext.
            current_sender = None;
            recipients.clear();
            authenticated = false;
        } else if upper.starts_with("AUTH ") || upper == "AUTH" {
            if authenticated {
                write_raw(&mut reader, b"503 Already authenticated\r\n").await?;
                continue;
            }
            let mut parts = trimmed.split_whitespace();
            // Skip the "AUTH" verb itself.
            let _ = parts.next();
            let mechanism = parts.next().unwrap_or("").to_ascii_uppercase();
            if mechanism != "PLAIN" {
                write_raw(&mut reader, b"504 Unrecognized authentication type\r\n").await?;
                continue;
            }
            // Credentials arrive either as an initial response on the AUTH line
            // or, if absent, in reply to an empty `334` challenge.
            let mut response = parts.next().map(str::to_string);
            if response.is_none() {
                write_raw(&mut reader, b"334 \r\n").await?;
                line.clear();
                let bytes = reader.read_line(&mut line).await?;
                if bytes == 0 {
                    break;
                }
                let resp_trim = line.trim_end_matches(&['\r', '\n'][..]);
                if resp_trim == "*" {
                    // The client cancelled the exchange.
                    write_raw(&mut reader, b"501 Authentication canceled\r\n").await?;
                    continue;
                }
                response = Some(resp_trim.to_string());
            }
            let response = response.unwrap_or_default();
            let Ok(decoded) = general_purpose::STANDARD.decode(response.as_bytes()) else {
                write_raw(&mut reader, b"501 Invalid base64 data\r\n").await?;
                continue;
            };
            // SASL PLAIN payload: authzid NUL authcid NUL passwd (RFC 4616);
            // the authorization identity is ignored.
            let decoded = String::from_utf8_lossy(&decoded);
            let mut cred_parts = decoded.split('\0');
            let _authzid = cred_parts.next().unwrap_or("");
            let authcid = cred_parts.next().unwrap_or("");
            let passwd = cred_parts.next().unwrap_or("");
            if auth.authenticate(authcid, passwd) {
                authenticated = true;
                write_raw(&mut reader, b"235 Authentication successful\r\n").await?;
            } else {
                write_raw(&mut reader, b"535 Authentication credentials invalid\r\n").await?;
            }
        } else if upper.starts_with("MAIL FROM:") {
            // A new MAIL starts a fresh transaction. A null (`<>`) or unparsable
            // reverse-path leaves the sender unset.
            current_sender = parse_smtp_address(trimmed);
            recipients.clear();
            write_raw(&mut reader, b"250 OK\r\n").await?;
        } else if upper.starts_with("RCPT TO:") {
            let Some(address) = parse_smtp_address(trimmed) else {
                write_raw(&mut reader, b"501 Invalid address\r\n").await?;
                continue;
            };
            let decision = delivery_policy.on_rcpt(
                &address.address,
                current_sender.as_ref().map(|s| s.address.as_str()),
            );
            // Rejected recipients are answered immediately and never join the envelope.
            if let RcptDecision::Reject { code, message } = &decision {
                write_raw(&mut reader, format!("{code} {message}\r\n").as_bytes()).await?;
            } else {
                recipients.push(RcptContext {
                    address: address.address,
                    user: address.user,
                    decision,
                });
                write_raw(&mut reader, b"250 OK\r\n").await?;
            }
        } else if upper == "DATA" {
            if recipients.is_empty() {
                write_raw(&mut reader, b"503 Bad sequence of commands\r\n").await?;
                continue;
            }
            write_raw(&mut reader, b"354 End data with <CR><LF>.<CR><LF>\r\n").await?;
            let data = read_smtp_data(&mut reader).await?;
            let accepted_users: Vec<String> = recipients
                .iter()
                .filter_map(|recipient| match recipient.decision {
                    RcptDecision::Accept => Some(recipient.user.clone()),
                    _ => None,
                })
                .collect();
            // Deliver under a single lock and remember each INBOX's new size for
            // the notifications sent after the lock is released.
            let delivery_counts = {
                let mut guard = store.lock().expect("store lock poisoned");
                let mut counts = Vec::new();
                for user in &accepted_users {
                    guard.append(user, "INBOX", data.clone(), current_internal_date());
                    let count = guard.list(user, "INBOX").len();
                    counts.push((user.clone(), count));
                }
                drop(guard);
                counts
            };
            let dsn_recipients = delivery_policy.on_data(
                current_sender.as_ref().map(|s| s.address.as_str()),
                &recipients,
                &data,
            );
            if !dsn_recipients.is_empty() {
                if let Some(sender) = current_sender.as_ref().filter(|s| !s.user.is_empty()) {
                    let dsn = build_dsn_message(&sender.address, &dsn_recipients);
                    let mut guard = store.lock().expect("store lock poisoned");
                    guard.append(&sender.user, "INBOX", dsn, current_internal_date());
                }
            }
            for (user, new_count) in delivery_counts {
                let _ = mailbox_notifier.send(MailboxEvent {
                    user,
                    mailbox: "INBOX".to_string(),
                    new_count,
                });
            }
            // RFC 5321 §4.1.1.4: completing DATA consumes the transaction's
            // reverse-path and forward-path buffers. Clearing them stops a
            // following DATA (or RCPT + DATA) from redelivering to these recipients.
            current_sender = None;
            recipients.clear();
            write_raw(&mut reader, b"250 OK\r\n").await?;
        } else if upper == "RSET" {
            current_sender = None;
            recipients.clear();
            write_raw(&mut reader, b"250 OK\r\n").await?;
        } else if upper == "NOOP" {
            write_raw(&mut reader, b"250 OK\r\n").await?;
        } else if upper == "QUIT" {
            write_raw(&mut reader, b"221 Bye\r\n").await?;
            break;
        } else {
            write_raw(&mut reader, b"502 Command not implemented\r\n").await?;
        }
    }
    Ok(())
}
/// An address parsed from a `MAIL FROM:` / `RCPT TO:` argument.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AddressInfo {
    /// The address as written between the angle brackets (or the first token).
    address: String,
    /// The local part before `@`, with any `+tag` suffix removed.
    user: String,
}
/// Test-only helper: returns the mailbox user parsed from an SMTP command line,
/// or an empty string when no address can be extracted.
#[cfg(test)]
pub(crate) fn parse_rcpt_user(line: &str) -> String {
    match parse_smtp_address(line) {
        Some(parsed) => parsed.user,
        None => String::new(),
    }
}
/// Parses the argument of a `MAIL FROM:` / `RCPT TO:` line into the full address
/// and its derived mailbox user. Returns `None` when no non-blank address is present.
fn parse_smtp_address(line: &str) -> Option<AddressInfo> {
    parse_address_value(line)
        .filter(|candidate| !candidate.trim().is_empty())
        .map(|address| {
            let user = parse_mailbox_user(&address);
            AddressInfo { address, user }
        })
}
/// Extracts the address from an SMTP command line.
///
/// The argument is everything after the first `:`, or the whole line if there
/// is no colon. If it contains `<`, the text up to the next `>` (or the end of
/// the line) is used, trimmed. Otherwise the first whitespace-separated token is
/// used. Returns `None` for an empty argument or address, e.g. the null path `<>`.
fn parse_address_value(line: &str) -> Option<String> {
    let argument = match line.split_once(':') {
        Some((_, rest)) => rest,
        None => line,
    };
    let argument = argument.trim();
    if argument.is_empty() {
        return None;
    }
    let candidate = match argument.find('<') {
        Some(open) => {
            let inner = &argument[open + 1..];
            let inner = match inner.find('>') {
                Some(close) => &inner[..close],
                None => inner,
            };
            inner.trim()
        }
        None => argument.split_whitespace().next().unwrap_or(""),
    };
    (!candidate.is_empty()).then(|| candidate.to_string())
}
/// Derives the mailbox user from an address: the local part before the first
/// `@`, truncated at the first `+` (sub-addressing tag).
fn parse_mailbox_user(address: &str) -> String {
    let local_part = address.split_once('@').map_or(address, |(local, _)| local);
    let base = local_part
        .split_once('+')
        .map_or(local_part, |(base, _)| base);
    base.to_owned()
}
/// Reads an SMTP DATA payload from `reader` up to the end-of-data line and
/// returns the message bytes with dot-stuffing removed.
///
/// Lines are read as raw bytes, so 8-bit (non-UTF-8) message content is accepted
/// verbatim instead of aborting the session. Per RFC 5321 §4.5.2, a line
/// consisting of a single `.` (terminated by `\r\n` or `\n`) ends the data. On
/// every other line, a leading `.` is removed. Line endings are kept as received.
///
/// # Errors
///
/// Returns any I/O error from `reader`. Also returns `io::ErrorKind::UnexpectedEof`
/// if the stream ends before the end-of-data line, so a truncated message is
/// never handed back as if it were complete.
async fn read_smtp_data<R: AsyncBufReadExt + Unpin>(reader: &mut R) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    let mut line: Vec<u8> = Vec::new();
    loop {
        line.clear();
        if reader.read_until(b'\n', &mut line).await? == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "connection closed before end of DATA",
            ));
        }
        if line == b".\r\n" || line == b".\n" {
            return Ok(data);
        }
        // Undo transparency (dot-stuffing): drop a leading period.
        let payload = match line.split_first() {
            Some((&b'.', rest)) => rest,
            _ => &line[..],
        };
        data.extend_from_slice(payload);
    }
}
#[cfg(test)]
mod tests {
    use super::parse_rcpt_user;

    /// Asserts that the mailbox user parsed from `line` equals `expected`.
    fn check(line: &str, expected: &str) {
        let actual = parse_rcpt_user(line);
        assert_eq!(actual, expected, "unexpected user for {line:?}");
    }

    #[test]
    fn parse_rcpt_user_with_brackets() {
        check("RCPT TO:<user@example.com>", "user");
    }

    #[test]
    fn parse_rcpt_user_without_brackets() {
        check("RCPT TO:user@example.com", "user");
    }

    #[test]
    fn parse_rcpt_user_strips_plus_tag() {
        check("RCPT TO:<user+news@example.com>", "user");
    }
}