use crate::assertion::MailAssertion;
use crate::client_state::{AuthState, ClientState};
use crate::mail::Mail;
use crate::mailbox::Mailbox;
use crate::response;
use crate::stream::Stream;
use base64::{engine::general_purpose::STANDARD, Engine};
use lazy_static::lazy_static;
use native_tls::{Identity, TlsAcceptor};
use rcgen::{generate_simple_self_signed, CertifiedKey};
use regex::bytes::Regex;
use std::io;
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::sync::{Arc, Mutex};
use wg::WaitGroup;
lazy_static! {
    /// Splits DATA content into everything before the `<CRLF>.<CRLF>`
    /// end-of-data marker (group 1) and the marker itself (group 2); group 2 is
    /// empty when the input ends before any marker.
    static ref RE_BODY: Regex = Regex::new(r"(?s)(.*?)(\r\n\.\r\n|\z)").unwrap();
    /// Captures the address written between `<` and `>` in a `MAIL`/`RCPT`
    /// argument (group 1).
    static ref RE_MAIL_USER: Regex = Regex::new(r".*<(.*?)>").unwrap();
    /// Matches one space or one CRLF: the separator between an SMTP verb and
    /// its argument.
    static ref RE_SEP: Regex = Regex::new(r"( |\r\n)").unwrap();
}
/// Reasons an `AUTH` attempt can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthError {
    /// The credentials were malformed (e.g. not valid base64) or matched no
    /// configured mailbox.
    InvalidCredentials,
    /// The client asked for a SASL mechanism this server does not implement.
    UnsupportedMethod,
}
/// An in-process SMTP server for tests, created through [`MockServerBuilder`].
///
/// Cloning is cheap for the listener and the mail store (both behind `Arc`);
/// the remaining fields are cloned by value.
#[derive(Clone)]
pub struct MockServer {
    // --- network ---
    /// Listening socket, shared between clones.
    listener: Arc<TcpListener>,
    /// Address the listener is bound to.
    host: IpAddr,
    /// Port the listener is bound to.
    port: u16,

    // --- TLS ---
    /// Acceptor used to upgrade plain connections after `STARTTLS`.
    tls_acceptor: TlsAcceptor,
    /// PEM encoding of the certificate presented during the TLS handshake.
    cert_pem: String,
    /// When set, commands other than the session/TLS ones are refused until
    /// TLS is established.
    require_tls: bool,

    // --- SMTP identity and accounts ---
    /// Domain announced in the greeting and `EHLO`/`HELO` replies.
    domain: String,
    /// Accounts that `AUTH PLAIN` credentials are checked against.
    mailboxes: Vec<Mailbox>,
    /// When set, `AUTH` succeeds without checking credentials.
    no_verify_credentials: bool,

    // --- received data and completion tracking ---
    /// Every mail accepted so far, shared between clones.
    mails: Arc<Mutex<Vec<Mail>>>,
    /// How many more mails to receive before `assert` may proceed, if limited.
    remaining_emails: Option<u8>,
    /// Signals `assert` that the server has finished its expected work.
    wg: WaitGroup,
}
/// Configuration collected before a [`MockServer`] is built.
#[derive(Default)]
pub struct MockServerBuilder {
    /// Accounts that `AUTH PLAIN` credentials are checked against.
    mailboxes: Vec<Mailbox>,
    /// Domain announced in the greeting and `EHLO`/`HELO` replies.
    domain: String,
    /// Accept any credentials on `AUTH`.
    no_verify_credentials: bool,
    /// Refuse most commands until `STARTTLS` has completed.
    require_tls: bool,
    /// Number of mails `assert` waits for, if set.
    remaining_emails: Option<u8>,
}
impl MockServerBuilder {
    /// Creates a builder whose domain defaults to the name of the crate this
    /// file is compiled into; every other option starts at its `Default` value.
    fn new() -> Self {
        Self {
            domain: env!("CARGO_CRATE_NAME").to_string(),
            ..Default::default()
        }
    }

    /// Binds a listener on an OS-chosen loopback port, generates a self-signed
    /// certificate for that address, and assembles the [`MockServer`].
    ///
    /// The server does not accept connections until [`MockServer::start`] is
    /// called.
    ///
    /// # Panics
    /// Panics if the listener cannot be bound, or if generating the
    /// certificate or building the TLS identity/acceptor fails.
    pub fn build(self) -> MockServer {
        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind tcp listener");
        let addr = listener.local_addr().unwrap();
        let host = addr.ip();
        // The certificate's subject alternative name is the bound IP, so
        // clients connecting by that address can verify it once they trust it.
        let CertifiedKey { cert, key_pair } = generate_simple_self_signed([host.to_string()])
            .expect("Failed to generate certificate");
        let cert_pem = cert.pem();
        let identity =
            Identity::from_pkcs8(cert_pem.as_bytes(), key_pair.serialize_pem().as_bytes())
                .expect("Failed to create ssl identity");
        MockServer {
            listener: Arc::new(listener),
            host,
            port: addr.port(),
            domain: self.domain,
            tls_acceptor: TlsAcceptor::new(identity).expect("Failed to create tls acceptor"),
            cert_pem,
            require_tls: self.require_tls,
            mailboxes: self.mailboxes,
            mails: Arc::new(Mutex::new(Vec::new())),
            no_verify_credentials: self.no_verify_credentials,
            remaining_emails: self.remaining_emails,
            wg: WaitGroup::new(),
        }
    }

    /// Sets the domain announced in the greeting and `EHLO`/`HELO` replies.
    pub fn domain(mut self, domain: &str) -> Self {
        self.domain = String::from(domain);
        self
    }

    /// Refuses commands other than `EHLO`, `HELO`, `QUIT`, `NOOP` and
    /// `STARTTLS` until the connection has been upgraded to TLS.
    pub fn require_tls(mut self) -> Self {
        self.require_tls = true;
        self
    }

    /// Makes `AUTH` succeed without checking the supplied credentials.
    pub fn no_verify_credentials(mut self) -> Self {
        self.no_verify_credentials = true;
        self
    }

    /// Makes [`MockServer::assert`] wait until `n` mails have been received,
    /// instead of waiting for client connections to close.
    ///
    /// # Panics
    /// Panics if `n` is 0.
    pub fn assert_after_n_emails(mut self, n: u8) -> Self {
        assert!(n > 0, "assert_after_n_emails: n must be greater than 0");
        self.remaining_emails = Some(n);
        self
    }

    /// Registers an account that `AUTH PLAIN` can authenticate against.
    pub fn add_mailbox(mut self, user: &str, password: &str) -> Self {
        self.mailboxes.push(Mailbox::new(user, password));
        self
    }
}
impl MockServer {
    /// Returns a [`MockServerBuilder`] used to configure and create a server.
    pub fn builder() -> MockServerBuilder {
        MockServerBuilder::new()
    }

    /// Returns the IP address the listener is bound to.
    pub fn host(&self) -> IpAddr {
        self.host
    }

    /// Returns the port the listener is bound to.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Returns the server certificate in PEM form. A client can trust it
    /// before upgrading a connection with `STARTTLS`.
    pub fn cert_pem(&self) -> &[u8] {
        self.cert_pem.as_bytes()
    }

    /// Waits on the server's wait group, then evaluates `assertion` against
    /// every mail received so far.
    ///
    /// With [`MockServerBuilder::assert_after_n_emails`] (and after
    /// [`MockServer::start`]), the wait ends once that many mails have arrived.
    /// Otherwise it ends when every connection accepted so far has been closed.
    ///
    /// # Panics
    /// Panics if the mail store's mutex is poisoned.
    pub fn assert(&self, assertion: MailAssertion) -> bool {
        self.wg.wait();
        assertion.assert(&self.mails.lock().unwrap())
    }

    /// Starts accepting connections on a background thread.
    ///
    /// Connections are served one at a time, in accept order, by a clone of
    /// this server. The clone shares the mail store (`Arc`) and a cloned
    /// `WaitGroup` handle with `self`. `WaitGroup` clones are assumed to share
    /// one counter, which is what lets `assert` observe the thread's progress.
    pub fn start(&self) {
        let mut server = self.clone();
        if server.remaining_emails.is_some() {
            // Register the expected mails once, up front and synchronously.
            // `assert` then waits for them even if it runs before any client
            // connects, and however many connections deliver them.
            server.wg.add(1);
        }
        std::thread::spawn(move || {
            let track_connections = server.remaining_emails.is_none();
            while let Ok((tcp_stream, socket)) = server.listener.accept() {
                if track_connections {
                    server.wg.add(1);
                }
                server.handle_client(socket, Stream::Plain(tcp_stream));
                if track_connections {
                    server.wg.done();
                }
            }
        });
    }

    /// Builds the `220` greeting that names the configured domain.
    fn ready_message(&self) -> Vec<u8> {
        format!("220 {} Service ready\r\n", self.domain).into_bytes()
    }

    /// Serves one SMTP session. It stops when the client closes the
    /// connection, a read fails, or a reply cannot be written.
    fn handle_client(&mut self, socket: SocketAddr, mut stream: Stream) {
        if let Err(e) = stream.write_all(&self.ready_message()) {
            println!("Failed to send response to {}: {}", socket, e);
        }
        let mut buffer = vec![0; 128];
        let mut client_state = ClientState::new();
        loop {
            let size = match stream.read(&mut buffer) {
                // `Ok(0)` means the peer closed the connection. Treating it as
                // input would feed empty data to the command parser.
                Ok(0) | Err(_) => break,
                Ok(size) => size,
            };
            let data = &buffer[..size];
            match std::str::from_utf8(data) {
                Ok(data_str) => println!("{} sent: {}", socket, data_str),
                Err(_) => println!("{} sent invalid UTF-8 data", socket),
            }
            match self.handle_stream(stream, data, &mut client_state) {
                Ok(next_stream) => stream = next_stream,
                Err(e) => {
                    println!("Failed to send response to {}: {}", socket, e);
                    break;
                }
            }
        }
        println!("Connection closed");
    }

    /// Routes one chunk of client input by session state: message content
    /// during `DATA`, a credential line during a two-step `AUTH PLAIN`, or a
    /// command otherwise.
    ///
    /// Returns the stream to keep using. It changes after `STARTTLS`.
    fn handle_stream(
        &mut self,
        stream: Stream,
        data: &[u8],
        client_state: &mut ClientState,
    ) -> Result<Stream, io::Error> {
        if data.is_empty() {
            // Nothing to act on; end-of-stream is detected by the caller.
            return Ok(stream);
        }
        if client_state.mail_transaction.is_receiving_data {
            return self.receive_data(stream, data, client_state);
        }
        if matches!(client_state.auth_state, AuthState::Plain) {
            return self.continue_auth_plain(stream, data, client_state);
        }
        let (command, arg) = parse_command(data);
        self.handle_command(stream, command, arg, client_state)
    }

    /// Buffers message content received after `DATA`.
    ///
    /// Raw bytes accumulate in the transaction body until the `<CRLF>.<CRLF>`
    /// terminator has arrived, possibly split across several reads. Then the
    /// content is dot-unstuffed and stored as a [`Mail`], the transaction is
    /// reset, and a single `250` is sent. No reply is sent for incomplete
    /// content. Bytes after the terminator in the same read are discarded.
    fn receive_data(
        &mut self,
        mut stream: Stream,
        data: &[u8],
        client_state: &mut ClientState,
    ) -> Result<Stream, io::Error> {
        const END_OF_DATA: &[u8] = b"\r\n.\r\n";
        let transaction = &mut client_state.mail_transaction;
        let already_buffered = transaction.body.len();
        transaction.body.extend_from_slice(data);

        let content_end = if transaction.body.starts_with(b".\r\n") {
            // A lone "." as the first line ends an empty message. Its leading
            // CRLF was the one that ended the DATA command line.
            Some(0)
        } else {
            // Only bytes from this read can complete a terminator not seen
            // before. Back up a few bytes in case it straddles two reads.
            let search_from = already_buffered.saturating_sub(END_OF_DATA.len() - 1);
            RE_BODY
                .captures(&transaction.body[search_from..])
                .filter(|captures| captures.get(2).map(|m| m.as_bytes()) == Some(END_OF_DATA))
                .and_then(|captures| captures.get(1))
                .map(|content| search_from + content.end())
        };
        let content_end = match content_end {
            Some(end) => end,
            None => return Ok(stream),
        };

        let body = Self::unstuff_dots(&transaction.body[..content_end]);
        let sender = transaction
            .sender
            .clone()
            .expect("DATA is only accepted after a successful MAIL command");
        let mail = Mail::new(sender, transaction.recipients.clone(), body);
        self.mails.lock().unwrap().push(mail);
        transaction.reset();

        if let Some(remaining) = self.remaining_emails.as_mut() {
            // Count down only while mails are still expected. Extra mails must
            // neither underflow the counter nor release the wait group twice.
            if *remaining > 0 {
                *remaining -= 1;
                if *remaining == 0 {
                    self.wg.done();
                }
            }
        }
        stream.write_all(response::OK)?;
        Ok(stream)
    }

    /// Undoes SMTP transparency (RFC 5321 §4.5.2). A period at the start of a
    /// line is removed.
    fn unstuff_dots(content: &[u8]) -> Vec<u8> {
        content
            .iter()
            .enumerate()
            .filter(|&(i, &byte)| {
                let at_line_start = i == 0 || content[..i].ends_with(b"\r\n");
                !(byte == b'.' && at_line_start)
            })
            .map(|(_, &byte)| byte)
            .collect()
    }

    /// Handles the credential line sent after the `334` challenge of a
    /// two-step `AUTH PLAIN`.
    ///
    /// `*` cancels the exchange. Both cancellation and failure return the
    /// session to its initial auth state, so later input is parsed as
    /// commands again.
    fn continue_auth_plain(
        &self,
        mut stream: Stream,
        data: &[u8],
        client_state: &mut ClientState,
    ) -> Result<Stream, io::Error> {
        // The token arrives as a full line; the line terminator is not base64.
        let line = data.strip_suffix(b"\n").unwrap_or(data);
        let token = line.strip_suffix(b"\r").unwrap_or(line);
        if token == b"*" {
            Self::reset_auth(client_state);
            stream.write_all(b"501 Authentication cancelled by client\r\n")?;
        } else if self.no_verify_credentials {
            // Mirrors the one-step AUTH path: accept without recording a user.
            client_state.auth_state = AuthState::Completed;
            stream.write_all(response::AUTH_SUCCESS)?;
        } else {
            match self.authenticate_plain(token) {
                Ok(user) => {
                    client_state.authed_user = Some(user);
                    client_state.auth_state = AuthState::Completed;
                    stream.write_all(response::AUTH_SUCCESS)?;
                }
                Err(_) => {
                    Self::reset_auth(client_state);
                    stream.write_all(response::AUTH_INVALID_CREDENTIALS)?;
                }
            }
        }
        Ok(stream)
    }

    /// Puts `client_state.auth_state` back to the value a fresh session
    /// starts with, whatever `ClientState::new` uses.
    fn reset_auth(client_state: &mut ClientState) {
        client_state.auth_state = ClientState::new().auth_state;
    }

    /// Executes one parsed command. `command` is the lowercased verb and `arg`
    /// the rest of the line.
    ///
    /// `EHLO`, `HELO`, `QUIT`, `NOOP` and `STARTTLS` are always accepted.
    /// When TLS is required but not yet established, every other command gets
    /// `530`. Returns the stream to keep using; it is a TLS stream after a
    /// successful `STARTTLS`.
    fn handle_command(
        &mut self,
        mut stream: Stream,
        command: Vec<u8>,
        arg: Vec<u8>,
        client_state: &mut ClientState,
    ) -> Result<Stream, io::Error> {
        match command.as_slice() {
            b"ehlo" => {
                client_state.reset();
                client_state.has_ehloed = true;
                let auth_methods = "AUTH PLAIN";
                let response = if client_state.is_tls_established {
                    format!("250-{}\r\n250 {}\r\n", self.domain, auth_methods)
                } else {
                    format!(
                        "250-{}\r\n250-{}\r\n250 STARTTLS\r\n",
                        self.domain, auth_methods
                    )
                };
                stream.write_all(response.as_bytes())?;
            }
            b"helo" => {
                client_state.reset();
                client_state.has_ehloed = true;
                stream.write_all(format!("250 {}\r\n", &self.domain).as_bytes())?;
            }
            b"quit" => {
                stream.write_all(b"221 OK\r\n")?;
                stream.shutdown()?;
            }
            b"noop" => stream.write_all(response::OK)?,
            b"starttls" => {
                if !arg.is_empty() {
                    stream.write_all(b"501 Syntax error (no parameters allowed)\r\n")?;
                } else {
                    stream.write_all(b"220 Ready to start TLS\r\n")?;
                    let stream = self.upgrade_to_tls(stream)?;
                    client_state.is_tls_established = true;
                    // RFC 3207: the client must greet again after the upgrade.
                    client_state.has_ehloed = false;
                    return Ok(stream);
                }
            }
            _ if self.require_tls && !client_state.is_tls_established => {
                stream.write_all(b"530 5.7.0 Must issue a STARTTLS command first\r\n")?;
            }
            b"auth" => self.handle_auth(&mut stream, &arg, client_state)?,
            b"mail" => self.handle_mail(&mut stream, &arg, client_state)?,
            b"rcpt" => self.handle_rcpt(&mut stream, &arg, client_state)?,
            b"data" => {
                let transaction = &mut client_state.mail_transaction;
                if transaction.sender.is_none() {
                    stream.write_all(response::NO_MAIL_TRANSACTION)?;
                } else if transaction.recipients.is_empty() {
                    stream.write_all(b"503 No recipients\r\n")?;
                } else {
                    transaction.is_receiving_data = true;
                    stream.write_all(b"354 Start mail input; end with <CRLF>.<CRLF>\r\n")?;
                }
            }
            b"rset" => {
                client_state.reset();
                stream.write_all(response::OK)?;
            }
            b"vrfy" => stream.write_all(b"252\r\n")?,
            _ => stream.write_all(b"500 Unrecognized command\r\n")?,
        }
        Ok(stream)
    }

    /// Handles `AUTH <method> [initial-response]`.
    ///
    /// Without an initial response, `PLAIN` switches the session to
    /// [`AuthState::Plain`] and sends the `334` challenge. With one, the
    /// credentials are checked immediately, or accepted outright when
    /// `no_verify_credentials` is set.
    fn handle_auth(
        &self,
        stream: &mut Stream,
        arg: &[u8],
        client_state: &mut ClientState,
    ) -> Result<(), io::Error> {
        if client_state.authed_user.is_some() {
            return stream.write_all(b"503 Already authenticated\r\n");
        }
        if client_state.mail_transaction.sender.is_some() {
            return stream.write_all(b"503 Cannot authenticate during a mail transaction\r\n");
        }
        let mut parts = arg.splitn(2, |&c| c == b' ');
        let method = parts.next().unwrap_or_default().to_ascii_lowercase();
        let initial_response = parts.next().unwrap_or_default();

        if initial_response.is_empty() {
            return match method.as_slice() {
                b"plain" => {
                    client_state.auth_state = AuthState::Plain;
                    stream.write_all(b"334 \r\n")
                }
                _ => stream.write_all(response::AUTH_UNSUPPORTED_METHOD),
            };
        }
        if initial_response == b"*" {
            return stream.write_all(b"501 Authentication cancelled by client\r\n");
        }
        if self.no_verify_credentials {
            client_state.auth_state = AuthState::Completed;
            return stream.write_all(response::AUTH_SUCCESS);
        }
        let outcome = match method.as_slice() {
            b"plain" => self.authenticate_plain(initial_response),
            _ => Err(AuthError::UnsupportedMethod),
        };
        match outcome {
            Ok(user) => {
                client_state.authed_user = Some(user);
                client_state.auth_state = AuthState::Completed;
                stream.write_all(response::AUTH_SUCCESS)
            }
            Err(AuthError::UnsupportedMethod) => stream.write_all(response::AUTH_UNSUPPORTED_METHOD),
            Err(AuthError::InvalidCredentials) => {
                stream.write_all(response::AUTH_INVALID_CREDENTIALS)
            }
        }
    }

    /// Handles `MAIL FROM:<address>`. It starts a transaction when the session
    /// is greeted and authenticated. Unless `no_verify_credentials` is set,
    /// the address must equal the authenticated user.
    fn handle_mail(
        &self,
        stream: &mut Stream,
        arg: &[u8],
        client_state: &mut ClientState,
    ) -> Result<(), io::Error> {
        if !client_state.has_ehloed {
            return stream.write_all(b"503 Session has not been opened with EHLO/HELO\r\n");
        }
        if client_state.mail_transaction.sender.is_some() {
            return stream.write_all(b"503 A mail transaction is already in progress\r\n");
        }
        if client_state.auth_state != AuthState::Completed {
            return stream.write_all(b"530 5.7.0 Authentication required\r\n");
        }
        // Skip the 5-byte "FROM:" prefix of the argument.
        match Self::angle_address(arg, 5) {
            Some(address)
                if self.no_verify_credentials
                    || client_state.authed_user.as_deref() == Some(address) =>
            {
                client_state.mail_transaction.sender = Some(address.to_vec());
                stream.write_all(response::OK)
            }
            _ => stream.write_all(response::BAD_MAILBOX_SYNTAX),
        }
    }

    /// Handles `RCPT TO:<address>` by adding the address to the current
    /// transaction's recipients.
    fn handle_rcpt(
        &self,
        stream: &mut Stream,
        arg: &[u8],
        client_state: &mut ClientState,
    ) -> Result<(), io::Error> {
        if client_state.mail_transaction.sender.is_none() {
            return stream.write_all(response::NO_MAIL_TRANSACTION);
        }
        // Skip the 3-byte "TO:" prefix of the argument.
        match Self::angle_address(arg, 3) {
            Some(recipient) => {
                client_state
                    .mail_transaction
                    .recipients
                    .insert(recipient.to_vec());
                stream.write_all(response::OK)
            }
            None => stream.write_all(response::BAD_MAILBOX_SYNTAX),
        }
    }

    /// Returns the text between `<` and `>` in `arg` after its first
    /// `prefix_len` bytes. It is `None` when `arg` is shorter than the prefix
    /// (instead of panicking) or has no bracketed address.
    fn angle_address(arg: &[u8], prefix_len: usize) -> Option<&[u8]> {
        let path = arg.get(prefix_len..)?;
        RE_MAIL_USER
            .captures(path)
            .and_then(|captures| captures.get(1))
            .map(|address| address.as_bytes())
    }

    /// Performs the server side of the TLS handshake on a plain stream.
    ///
    /// # Errors
    /// Returns `AlreadyExists` if the stream is already TLS. A failed
    /// handshake is returned as an error, so only this connection is dropped;
    /// the server thread keeps running.
    fn upgrade_to_tls(&self, stream: Stream) -> Result<Stream, io::Error> {
        match stream {
            Stream::Plain(plain_stream) => self
                .tls_acceptor
                .accept(plain_stream)
                .map(Stream::Tls)
                .map_err(|e| {
                    io::Error::new(
                        io::ErrorKind::Other,
                        format!("TLS handshake failed: {}", e),
                    )
                }),
            Stream::Tls(_) => Err(io::Error::new(
                io::ErrorKind::AlreadyExists,
                "Connection is already tls",
            )),
        }
    }

    /// Checks a base64 `AUTH PLAIN` token (`[authzid] NUL authcid NUL passwd`,
    /// RFC 4616) against the configured mailboxes. The authzid is ignored.
    ///
    /// Returns the authenticated user name on success.
    fn authenticate_plain(&self, token: &[u8]) -> Result<Vec<u8>, AuthError> {
        let creds = STANDARD
            .decode(token)
            .map_err(|_| AuthError::InvalidCredentials)?;
        let mut fields = creds.split(|&b| b == 0).skip(1);
        let user = fields.next().unwrap_or_default();
        let password = fields.next().unwrap_or_default();
        if self
            .mailboxes
            .iter()
            .any(|mailbox| mailbox.authenticate(user, password))
        {
            Ok(user.to_vec())
        } else {
            Err(AuthError::InvalidCredentials)
        }
    }
}
/// Splits one SMTP command line into a lowercased verb and its raw argument.
///
/// First a trailing line terminator (`\r\n`, `\n`, or `\r`) is removed, if
/// present. Input without one is used as-is rather than losing its last two
/// bytes, and empty input yields two empty vectors instead of panicking. The
/// rest is split at the first space or CRLF (`RE_SEP`). The verb is
/// ASCII-lowercased so callers can match it case-insensitively. The argument
/// is returned unchanged and is empty when there is no separator.
fn parse_command(command: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let line = command.strip_suffix(b"\n").unwrap_or(command);
    let line = line.strip_suffix(b"\r").unwrap_or(line);
    let mut parts = RE_SEP.splitn(line, 2);
    let verb = parts.next().unwrap_or_default().to_ascii_lowercase();
    let arg = parts.next().unwrap_or_default().to_vec();
    (verb, arg)
}