use crate::{SmtpConfig, SmtpSpecRegistry};
use mockforge_core::protocol_abstraction::{
MessagePattern, MiddlewareChain, Protocol, ProtocolRequest, SpecRegistry,
};
use mockforge_core::Result;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{rustls, server::TlsStream, TlsAcceptor};
use tracing::{debug, error, info, warn};
/// Transport underneath one SMTP session: every accepted connection starts as
/// `Plain`, and is replaced by `Tls` after a successful `STARTTLS` upgrade.
pub enum SmtpStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TCP connection wrapped in a server-side TLS stream (heap-allocated).
    Tls(Box<TlsStream<TcpStream>>),
}
impl AsyncRead for SmtpStream {
    /// Delegates the read to whichever transport currently backs the session.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_read(cx, buf),
        }
    }
}
impl AsyncWrite for SmtpStream {
    /// Delegates the write to whichever transport currently backs the session.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_write(cx, buf),
        }
    }

    /// Delegates the flush to the active transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => Pin::new(tcp).poll_flush(cx),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_flush(cx),
        }
    }

    /// Delegates the shutdown to the active transport.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Self::Tls(tls) => Pin::new(&mut **tls).poll_shutdown(cx),
        }
    }
}
/// Mock SMTP server: captures incoming mail into the spec registry and answers
/// each message with a registry-generated (or default `250 OK`) reply.
pub struct SmtpServer {
    /// Listener host/port, greeting hostname and STARTTLS settings.
    config: SmtpConfig,
    /// Stores captured emails and generates mock replies.
    spec_registry: Arc<SmtpSpecRegistry>,
    /// Middleware run on each message request and mock response.
    middleware_chain: Arc<MiddlewareChain>,
    /// Acceptor used for in-session STARTTLS upgrades; `None` when STARTTLS is disabled.
    tls_acceptor: Option<TlsAcceptor>,
}
impl SmtpServer {
pub fn new(config: SmtpConfig, spec_registry: Arc<SmtpSpecRegistry>) -> Result<Self> {
let middleware_chain = Arc::new(MiddlewareChain::new());
let tls_acceptor = if config.enable_starttls {
Some(Self::load_tls_acceptor(&config)?)
} else {
None
};
Ok(Self {
config,
spec_registry,
middleware_chain,
tls_acceptor,
})
}
fn load_tls_acceptor(config: &SmtpConfig) -> Result<TlsAcceptor> {
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::fs::File;
use std::io::BufReader;
let cert_path = config.tls_cert_path.as_ref().ok_or_else(|| {
mockforge_core::Error::internal("TLS certificate path not configured")
})?;
let key_path = config.tls_key_path.as_ref().ok_or_else(|| {
mockforge_core::Error::internal("TLS private key path not configured")
})?;
let cert_file = File::open(cert_path)?;
let mut cert_reader = BufReader::new(cert_file);
let certs: Vec<Vec<u8>> = certs(&mut cert_reader)?;
let certs: Vec<rustls::Certificate> = certs.into_iter().map(rustls::Certificate).collect();
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
let mut keys: Vec<Vec<u8>> = pkcs8_private_keys(&mut key_reader)?;
if keys.is_empty() {
return Err(mockforge_core::Error::internal("No private keys found"));
}
let mut server_config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, rustls::PrivateKey(keys.remove(0)))
.map_err(|e| mockforge_core::Error::internal(format!("TLS config error: {}", e)))?;
server_config.alpn_protocols = vec![b"smtp".to_vec()];
Ok(TlsAcceptor::from(Arc::new(server_config)))
}
pub fn with_middleware(
config: SmtpConfig,
spec_registry: Arc<SmtpSpecRegistry>,
middleware_chain: Arc<MiddlewareChain>,
) -> Result<Self> {
let tls_acceptor = if config.enable_starttls {
Some(Self::load_tls_acceptor(&config)?)
} else {
None
};
Ok(Self {
config,
spec_registry,
middleware_chain,
tls_acceptor,
})
}
pub async fn start(&self) -> Result<()> {
let addr = format!("{}:{}", self.config.host, self.config.port);
let listener = TcpListener::bind(&addr).await?;
info!("SMTP server listening on {}", addr);
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
debug!("New SMTP connection from {}", peer_addr);
let registry = self.spec_registry.clone();
let middleware = self.middleware_chain.clone();
let hostname = self.config.hostname.clone();
let tls_acceptor = self.tls_acceptor.clone();
tokio::spawn(async move {
if let Err(e) = handle_smtp_session(
SmtpStream::Plain(stream),
peer_addr,
registry,
middleware,
hostname,
tls_acceptor,
)
.await
{
error!("SMTP session error from {}: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("Failed to accept SMTP connection: {}", e);
}
}
}
}
}
async fn handle_smtp_session(
stream: SmtpStream,
peer_addr: SocketAddr,
registry: Arc<SmtpSpecRegistry>,
middleware: Arc<MiddlewareChain>,
hostname: String,
tls_acceptor: Option<TlsAcceptor>,
) -> Result<()> {
let mut reader = BufReader::new(stream);
let greeting = format!("220 {} ESMTP MockForge SMTP Server\r\n", hostname);
reader.get_mut().write_all(greeting.as_bytes()).await?;
let mut session_state = SessionState::new();
let mut line: Vec<u8> = Vec::new();
while reader.read_until(b'\n', &mut line).await? > 0 {
if session_state.in_data_mode {
let trimmed = strip_line_terminator(&line);
if trimmed == b"." {
session_state.in_data_mode = false;
let response =
process_email(&session_state, ®istry, &middleware, peer_addr).await?;
reader.get_mut().write_all(response.as_bytes()).await?;
session_state.reset();
} else {
session_state.data.extend_from_slice(trimmed);
session_state.data.push(b'\n');
}
line.clear();
continue;
}
let as_str = String::from_utf8_lossy(&line);
let command = as_str.trim();
debug!("SMTP command from {}: {}", peer_addr, command);
if let Some(stage) = session_state.pending_auth.clone() {
handle_auth_continuation(stage, command, &mut session_state, reader.get_mut()).await?;
line.clear();
continue;
}
if command.is_empty() {
line.clear();
continue;
}
if command.eq_ignore_ascii_case("STARTTLS") {
if !matches!(reader.get_ref(), SmtpStream::Plain(_)) {
reader.get_mut().write_all(b"503 Command not allowed\r\n").await?;
} else if let Some(acceptor) = tls_acceptor.clone() {
reader.get_mut().write_all(b"220 Ready to start TLS\r\n").await?;
reader.get_mut().flush().await?;
let inner = reader.into_inner();
let tcp = match inner {
SmtpStream::Plain(t) => t,
SmtpStream::Tls(_) => unreachable!("checked is Plain above"),
};
let tls_stream = acceptor.accept(tcp).await.map_err(|e| {
mockforge_core::Error::internal(format!("TLS accept failed: {e}"))
})?;
reader = BufReader::new(SmtpStream::Tls(Box::new(tls_stream)));
session_state = SessionState::new();
line.clear();
continue;
} else {
reader
.get_mut()
.write_all(b"454 TLS not available due to temporary reason\r\n")
.await?;
}
line.clear();
continue;
}
match handle_smtp_command(
command,
&mut session_state,
reader.get_mut(),
&hostname,
®istry,
&middleware,
peer_addr,
)
.await
{
Ok(should_continue) => {
if !should_continue {
debug!("SMTP session ended for {}", peer_addr);
break;
}
}
Err(e) => {
error!("Error handling SMTP command: {}", e);
let error_response = "500 Internal server error\r\n";
reader.get_mut().write_all(error_response.as_bytes()).await?;
}
}
line.clear();
}
Ok(())
}
/// Decodes a SASL PLAIN credentials blob (RFC 4616) and returns its authentication identity.
///
/// The blob is the base64 encoding of `[authzid] NUL authcid NUL passwd`.
/// Returns `None` in any of these cases:
/// - the base64 is invalid;
/// - the message does not have exactly three NUL-separated fields;
/// - the authcid is empty (RFC 4616 requires at least one character).
///
/// The password is not verified. Invalid UTF-8 in the authcid is replaced lossily.
fn decode_plain_auth(b64: &str) -> Option<String> {
    use base64::Engine as _;
    let decoded = base64::engine::general_purpose::STANDARD.decode(b64.trim()).ok()?;
    let mut fields = decoded.split(|b| *b == 0);
    let _authzid = fields.next()?;
    let authcid = fields.next()?;
    let _passwd = fields.next()?;
    if fields.next().is_some() || authcid.is_empty() {
        return None;
    }
    Some(String::from_utf8_lossy(authcid).into_owned())
}
/// Handles one client line sent in reply to a `334` challenge during `AUTH`.
///
/// `stage` is the step the exchange was waiting on.
/// - A line consisting solely of `*` cancels the exchange and is answered with
///   `501`, as RFC 4954 §4 requires.
/// - PLAIN: the credentials are decoded with [`decode_plain_auth`].
/// - LOGIN: the username is stored provisionally and the password challenge is
///   sent. The password only has to be valid base64.
///
/// No password is ever checked against a user store.
///
/// # Errors
/// Propagates failures writing the reply to `writer`.
async fn handle_auth_continuation<W: AsyncWriteExt + Unpin>(
    stage: AuthStage,
    line: &str,
    state: &mut SessionState,
    writer: &mut W,
) -> Result<()> {
    use base64::Engine as _;
    if line.trim() == "*" {
        state.pending_auth = None;
        if stage == AuthStage::AwaitingLoginPassword {
            // Drop the username stored provisionally by the previous LOGIN step.
            state.authenticated_user = None;
        }
        writer.write_all(b"501 5.7.0 Authentication cancelled\r\n").await?;
        return Ok(());
    }
    match stage {
        AuthStage::AwaitingPlainCredentials => {
            state.pending_auth = None;
            match decode_plain_auth(line) {
                Some(user) => {
                    state.authenticated_user = Some(user);
                    writer.write_all(b"235 2.7.0 Authentication successful\r\n").await?;
                }
                None => {
                    writer.write_all(b"535 5.7.8 Authentication credentials invalid\r\n").await?;
                }
            }
        }
        AuthStage::AwaitingLoginUsername => {
            let decoded = base64::engine::general_purpose::STANDARD
                .decode(line.trim())
                .ok()
                .and_then(|b| String::from_utf8(b).ok());
            match decoded {
                Some(u) => {
                    state.authenticated_user = Some(u);
                    state.pending_auth = Some(AuthStage::AwaitingLoginPassword);
                    // "UGFzc3dvcmQ6" is base64 for "Password:".
                    writer.write_all(b"334 UGFzc3dvcmQ6\r\n").await?;
                }
                None => {
                    state.pending_auth = None;
                    state.authenticated_user = None;
                    writer.write_all(b"535 5.7.8 Authentication credentials invalid\r\n").await?;
                }
            }
        }
        AuthStage::AwaitingLoginPassword => {
            state.pending_auth = None;
            if base64::engine::general_purpose::STANDARD.decode(line.trim()).is_ok() {
                writer.write_all(b"235 2.7.0 Authentication successful\r\n").await?;
            } else {
                state.authenticated_user = None;
                writer.write_all(b"535 5.7.8 Authentication credentials invalid\r\n").await?;
            }
        }
    }
    Ok(())
}
/// Removes one trailing `\n` and then one trailing `\r` from `line`.
///
/// Handles CRLF, bare LF and a bare trailing CR; all other bytes (including
/// non-UTF-8) are returned untouched.
fn strip_line_terminator(line: &[u8]) -> &[u8] {
    match line {
        [body @ .., b'\r', b'\n'] => body,
        [body @ .., b'\n'] | [body @ .., b'\r'] => body,
        _ => line,
    }
}
async fn handle_smtp_command<W: AsyncWriteExt + Unpin>(
command: &str,
state: &mut SessionState,
writer: &mut W,
hostname: &str,
registry: &Arc<SmtpSpecRegistry>,
middleware: &Arc<MiddlewareChain>,
peer_addr: SocketAddr,
) -> Result<bool> {
let parts: Vec<&str> = command.splitn(2, ' ').collect();
let cmd = parts[0].to_uppercase();
match cmd.as_str() {
"HELLO" | "EHLO" => {
let domain = parts.get(1).unwrap_or(&hostname);
let response = if cmd == "EHLO" {
format!(
"250-{} Hello {}\r\n\
250-SIZE 10485760\r\n\
250-8BITMIME\r\n\
250-STARTTLS\r\n\
250-AUTH PLAIN LOGIN\r\n\
250 HELP\r\n",
hostname, domain
)
} else {
format!("250 {} Hello {}\r\n", hostname, domain)
};
writer.write_all(response.as_bytes()).await?;
Ok(true)
}
"MAIL" => {
if let Some(from_part) = parts.get(1) {
let from = extract_email_address(from_part);
state.mail_from = Some(from);
writer.write_all(b"250 OK\r\n").await?;
} else {
writer.write_all(b"501 Syntax error in parameters\r\n").await?;
}
Ok(true)
}
"RCPT" => {
if let Some(to_part) = parts.get(1) {
let to = extract_email_address(to_part);
state.rcpt_to.push(to);
writer.write_all(b"250 OK\r\n").await?;
} else {
writer.write_all(b"501 Syntax error in parameters\r\n").await?;
}
Ok(true)
}
"DATA" => {
writer.write_all(b"354 Start mail input; end with <CRLF>.<CRLF>\r\n").await?;
state.in_data_mode = true;
Ok(true)
}
"RSET" => {
state.reset();
writer.write_all(b"250 OK\r\n").await?;
Ok(true)
}
"NOOP" => {
writer.write_all(b"250 OK\r\n").await?;
Ok(true)
}
"QUIT" => {
writer.write_all(b"221 Bye\r\n").await?;
Ok(false) }
"STARTTLS" => {
writer.write_all(b"220 Ready to start TLS\r\n").await?;
Ok(true)
}
"AUTH" => {
let rest = parts.get(1).copied().unwrap_or("");
let mut auth_args = rest.splitn(2, ' ');
let mechanism = auth_args.next().map(|s| s.to_ascii_uppercase()).unwrap_or_default();
let initial_response = auth_args.next().map(str::trim).filter(|s| !s.is_empty());
match mechanism.as_str() {
"PLAIN" => {
if let Some(b64) = initial_response {
match decode_plain_auth(b64) {
Some(user) => {
state.authenticated_user = Some(user);
writer
.write_all(b"235 2.7.0 Authentication successful\r\n")
.await?;
}
None => {
writer
.write_all(b"535 5.7.8 Authentication credentials invalid\r\n")
.await?;
}
}
} else {
state.pending_auth = Some(AuthStage::AwaitingPlainCredentials);
writer.write_all(b"334 \r\n").await?;
}
Ok(true)
}
"LOGIN" => {
state.pending_auth = Some(AuthStage::AwaitingLoginUsername);
writer.write_all(b"334 VXNlcm5hbWU6\r\n").await?;
Ok(true)
}
_ => {
writer
.write_all(b"504 5.5.4 Authentication mechanism not supported\r\n")
.await?;
Ok(true)
}
}
}
"HELP" => {
let help_text = "214-Commands supported:\r\n\
214- HELLO EHLO MAIL RCPT DATA\r\n\
214- RSET NOOP QUIT HELP STARTTLS\r\n\
214 End of HELP info\r\n";
writer.write_all(help_text.as_bytes()).await?;
Ok(true)
}
_ => {
if state.in_data_mode {
if command == "." {
state.in_data_mode = false;
let response = process_email(state, registry, middleware, peer_addr).await?;
writer.write_all(response.as_bytes()).await?;
state.reset();
} else {
state.data.extend_from_slice(command.as_bytes());
state.data.push(b'\n');
}
Ok(true)
} else {
warn!("Unknown SMTP command: {}", command);
writer.write_all(b"502 Command not implemented\r\n").await?;
Ok(true)
}
}
}
}
async fn process_email(
state: &SessionState,
registry: &Arc<SmtpSpecRegistry>,
middleware: &Arc<MiddlewareChain>,
peer_addr: SocketAddr,
) -> Result<String> {
let from = state
.mail_from
.as_ref()
.ok_or_else(|| mockforge_core::Error::internal("Missing MAIL FROM"))?;
let to = state.rcpt_to.join(", ");
let subject = extract_subject(&state.data);
let captured = crate::fixtures::StoredEmail {
id: uuid::Uuid::new_v4().to_string(),
from: from.clone(),
to: state.rcpt_to.clone(),
subject: subject.clone(),
body: String::from_utf8_lossy(&state.data).into_owned(),
headers: HashMap::from([
("from".to_string(), from.clone()),
("to".to_string(), to.clone()),
("subject".to_string(), subject.clone()),
]),
received_at: chrono::Utc::now(),
raw: Some(state.data.clone()),
};
if let Err(e) = registry.store_email(captured) {
warn!("Failed to store email in mailbox: {}", e);
}
let mut request = ProtocolRequest {
protocol: Protocol::Smtp,
pattern: MessagePattern::OneWay,
operation: "SEND".to_string(),
path: from.clone(),
topic: None,
routing_key: None,
partition: None,
qos: None,
metadata: HashMap::from([
("from".to_string(), from.clone()),
("to".to_string(), to.clone()),
("subject".to_string(), subject.clone()),
]),
body: Some(state.data.clone()),
client_ip: Some(peer_addr.ip().to_string()),
};
if let Some(short_circuit_response) = middleware.process_request(&mut request).await? {
return Ok(String::from_utf8_lossy(&short_circuit_response.body).to_string());
}
let response = match registry.generate_mock_response(&request) {
Ok(mut resp) => {
middleware.process_response(&request, &mut resp).await?;
String::from_utf8_lossy(&resp.body).to_string()
}
Err(_) => "250 OK\r\n".to_string(),
};
Ok(response)
}
/// Extracts the mailbox from a `MAIL FROM:` / `RCPT TO:` argument.
///
/// If the argument contains `<...>`, the trimmed text between the first `<`
/// and the next `>` is returned. `<>` yields the empty null reverse-path.
/// Otherwise three steps are applied:
/// 1. The argument is trimmed.
/// 2. A leading `FROM:` or `TO:` (any ASCII case) is dropped.
/// 3. The first whitespace-separated token is returned, which discards ESMTP
///    parameters such as `SIZE=`.
///
/// Never panics, even on malformed input such as `>x<`.
fn extract_email_address(param: &str) -> String {
    /// Returns `s` without `prefix` when `s` starts with it, ignoring ASCII case.
    fn strip_prefix_ignore_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
        let head = s.get(..prefix.len())?;
        if head.eq_ignore_ascii_case(prefix) {
            Some(&s[prefix.len()..])
        } else {
            None
        }
    }

    if let Some(start) = param.find('<') {
        // Search for '>' only after '<' so reversed brackets cannot produce an
        // inverted (panicking) slice range.
        let after_open = &param[start + 1..];
        if let Some(len) = after_open.find('>') {
            return after_open[..len].trim().to_string();
        }
    }
    let trimmed = param.trim();
    let unprefixed = strip_prefix_ignore_case(trimmed, "FROM:")
        .or_else(|| strip_prefix_ignore_case(trimmed, "TO:"))
        .unwrap_or(trimmed);
    unprefixed.split_whitespace().next().unwrap_or("").to_string()
}
/// Returns the trimmed value of the first `Subject:` header in a raw message.
///
/// Parsing rules:
/// - The header name is matched ASCII-case-insensitively, without allocating.
/// - Folded continuation lines (starting with a space or tab, RFC 5322 §2.2.3)
///   are unfolded into the value.
/// - Scanning stops at the first empty line, which ends the header section, so
///   a `Subject:` line in the body is ignored.
///
/// Non-UTF-8 bytes are replaced lossily. Returns an empty string when there is
/// no subject header.
fn extract_subject(data: &[u8]) -> String {
    const HEADER: &str = "subject:";
    let text = String::from_utf8_lossy(data);
    let mut lines = text.lines().peekable();
    while let Some(line) = lines.next() {
        if line.is_empty() {
            break;
        }
        let is_subject =
            matches!(line.get(..HEADER.len()), Some(name) if name.eq_ignore_ascii_case(HEADER));
        if !is_subject {
            continue;
        }
        let mut value = line[HEADER.len()..].to_string();
        // Unfolding removes only the line break; the leading whitespace stays.
        while let Some(continuation) =
            lines.next_if(|next| next.starts_with(|c: char| c == ' ' || c == '\t'))
        {
            value.push_str(continuation);
        }
        return value.trim().to_string();
    }
    String::new()
}
/// The client reply an in-progress `AUTH` exchange is waiting for.
///
/// Every variant deliberately starts with `Awaiting` (each names the next
/// expected reply), which is why clippy's `enum_variant_names` is silenced.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(clippy::enum_variant_names)]
enum AuthStage {
    /// `AUTH LOGIN` was issued; the base64 username comes next.
    AwaitingLoginUsername,
    /// The LOGIN username was received; the base64 password comes next.
    AwaitingLoginPassword,
    /// `AUTH PLAIN` was issued without an initial response; the credentials blob comes next.
    AwaitingPlainCredentials,
}
/// Per-connection SMTP state: the current mail transaction plus AUTH progress.
struct SessionState {
    /// Reverse-path recorded by `MAIL FROM`, if a transaction has started.
    mail_from: Option<String>,
    /// Forward-paths recorded by `RCPT TO`, in arrival order.
    rcpt_to: Vec<String>,
    /// Message content collected after `DATA`.
    data: Vec<u8>,
    /// True while incoming lines are message content rather than commands.
    in_data_mode: bool,
    /// Next step of an unfinished `AUTH` exchange, if any.
    pending_auth: Option<AuthStage>,
    /// Identity from the latest AUTH exchange; deliberately not cleared by `reset`.
    authenticated_user: Option<String>,
}

impl SessionState {
    /// Returns a state with no transaction, no AUTH step pending and nobody authenticated.
    fn new() -> Self {
        SessionState {
            authenticated_user: None,
            pending_auth: None,
            in_data_mode: false,
            data: Vec::new(),
            rcpt_to: Vec::new(),
            mail_from: None,
        }
    }

    /// Abandons the current transaction and any pending AUTH step.
    ///
    /// Keeps `authenticated_user` and the buffers' allocations.
    fn reset(&mut self) {
        let Self {
            mail_from,
            rcpt_to,
            data,
            in_data_mode,
            pending_auth,
            authenticated_user: _,
        } = self;
        *mail_from = None;
        rcpt_to.clear();
        data.clear();
        *in_data_mode = false;
        *pending_auth = None;
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the SMTP parsing helpers, session state and server constructors.
    use super::*;

    #[test]
    fn test_extract_email_address() {
        let cases = [
            ("FROM:<user@example.com>", "user@example.com"),
            ("TO:<admin@test.com>", "admin@test.com"),
            ("user@example.com", "user@example.com"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(extract_email_address(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_extract_email_address_whitespace() {
        let padded = " user@example.com ";
        assert_eq!(extract_email_address(padded), "user@example.com");
    }

    #[test]
    fn test_extract_email_address_no_brackets() {
        let bare = "plain@email.com";
        assert_eq!(extract_email_address(bare), bare);
    }

    #[test]
    fn test_extract_email_address_mail_from_format() {
        let parsed = extract_email_address("FROM:<sender@domain.com>");
        assert_eq!(parsed, "sender@domain.com");
    }

    #[test]
    fn test_extract_subject() {
        let message = concat!(
            "From: sender@example.com\n",
            "Subject: Test Email\n",
            "To: recipient@example.com\n",
            "\n",
            "Body text"
        );
        assert_eq!(extract_subject(message.as_bytes()), "Test Email");
    }

    #[test]
    fn test_extract_subject_not_found() {
        let message = concat!(
            "From: sender@example.com\n",
            "To: recipient@example.com\n",
            "\n",
            "Body text"
        );
        assert_eq!(extract_subject(message.as_bytes()), "");
    }

    #[test]
    fn test_extract_subject_lowercase() {
        let message = "subject: lowercase subject\nFrom: sender@example.com";
        assert_eq!(extract_subject(message.as_bytes()), "lowercase subject");
    }

    #[test]
    fn test_extract_subject_mixed_case() {
        let message = "SUBJECT: UPPERCASE SUBJECT\nFrom: sender@example.com";
        assert_eq!(extract_subject(message.as_bytes()), "UPPERCASE SUBJECT");
    }

    #[test]
    fn test_session_state() {
        let mut state = SessionState::new();
        assert!(state.mail_from.is_none());
        assert!(state.rcpt_to.is_empty());

        state.mail_from = Some("sender@example.com".to_string());
        state.rcpt_to.push("recipient@example.com".to_string());
        state.reset();

        assert!(state.mail_from.is_none());
        assert!(state.rcpt_to.is_empty());
    }

    #[test]
    fn test_session_state_new() {
        let fresh = SessionState::new();
        assert!(fresh.mail_from.is_none());
        assert!(fresh.rcpt_to.is_empty());
        assert!(fresh.data.is_empty());
        assert!(!fresh.in_data_mode);
    }

    #[test]
    fn test_session_state_reset() {
        let mut state = SessionState::new();
        state.mail_from = Some("test@example.com".to_string());
        for rcpt in &["recipient1@example.com", "recipient2@example.com"] {
            state.rcpt_to.push(rcpt.to_string());
        }
        state.data = b"Email body content".to_vec();
        state.in_data_mode = true;

        state.reset();

        assert!(state.mail_from.is_none());
        assert!(state.rcpt_to.is_empty());
        assert!(state.data.is_empty());
        assert!(!state.in_data_mode);
    }

    #[test]
    fn test_session_state_multiple_recipients() {
        let mut state = SessionState::new();
        for rcpt in &["a@example.com", "b@example.com", "c@example.com"] {
            state.rcpt_to.push(rcpt.to_string());
        }
        assert_eq!(state.rcpt_to.len(), 3);
    }

    #[test]
    fn test_session_state_data_accumulation() {
        let mut state = SessionState::new();
        for chunk in &[b"Line 1\n", b"Line 2\n", b"Line 3\n"] {
            state.data.extend_from_slice(*chunk);
        }
        assert_eq!(state.data, b"Line 1\nLine 2\nLine 3\n");
    }

    #[test]
    fn test_strip_line_terminator() {
        let check = |input: &[u8], expected: &[u8]| {
            assert_eq!(strip_line_terminator(input), expected);
        };
        check(b"hello\r\n", b"hello");
        check(b"hello\n", b"hello");
        check(b"hello", b"hello");
        check(b"", b"");
        check(b"\xff\xfe\r\n", b"\xff\xfe");
    }

    #[test]
    fn test_extract_subject_from_bytes_with_non_utf8_body() {
        let parts: [&[u8]; 4] = [
            b"From: a@example.test\r\n",
            b"Subject: 8BITMIME body below\r\n",
            b"\r\n",
            &[0xff, 0xfe, 0xfd],
        ];
        let message = parts.concat();
        assert_eq!(extract_subject(&message), "8BITMIME body below");
    }

    #[tokio::test]
    async fn test_smtp_server_new() {
        let registry = Arc::new(SmtpSpecRegistry::new());
        let result = SmtpServer::new(SmtpConfig::default(), registry);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_smtp_server_with_middleware() {
        let registry = Arc::new(SmtpSpecRegistry::new());
        let chain = Arc::new(MiddlewareChain::new());
        let result = SmtpServer::with_middleware(SmtpConfig::default(), registry, chain);
        assert!(result.is_ok());
    }
}