use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
/// ESMTP extensions advertised by a remote server in its `EHLO` reply.
#[derive(Debug, Clone, Default)]
pub struct SmtpExtensions {
    /// Numeric parameter of the `SIZE` keyword, if advertised and parseable.
    pub max_size: Option<usize>,
    /// `PIPELINING` was advertised.
    pub pipelining: bool,
    /// `8BITMIME` was advertised.
    pub eight_bit_mime: bool,
    /// `STARTTLS` was advertised.
    pub starttls: bool,
}

impl SmtpExtensions {
    /// Parses a raw (possibly multi-line) `EHLO` reply into the recognised
    /// extensions.
    ///
    /// Each line is expected to look like `250-KEYWORD [params]` or
    /// `250 KEYWORD [params]`. The keyword is the first whitespace-separated
    /// token after the reply code and is matched case-insensitively and
    /// exactly, so `SIZEX` is not mistaken for `SIZE` and trailing whitespace
    /// does not hide a keyword. Unknown keywords are ignored.
    pub fn from_ehlo(ehlo_text: &str) -> Self {
        let mut ext = SmtpExtensions::default();
        for line in ehlo_text.lines() {
            let mut words = Self::line_body(line).split_whitespace();
            let Some(keyword) = words.next() else {
                continue;
            };
            match keyword.to_ascii_uppercase().as_str() {
                "SIZE" => {
                    // A bare `SIZE` (no limit) leaves `max_size` untouched.
                    if let Some(limit) = words.next() {
                        ext.max_size = limit.parse().ok();
                    }
                }
                "PIPELINING" => ext.pipelining = true,
                "8BITMIME" => ext.eight_bit_mime = true,
                "STARTTLS" => ext.starttls = true,
                _ => {}
            }
        }
        ext
    }

    /// Strips a leading three-digit reply code and its `-`/space separator
    /// from one reply line, leaving `KEYWORD [params]`. Lines that do not
    /// start with three ASCII digits are returned unchanged.
    fn line_body(line: &str) -> &str {
        match line.as_bytes() {
            [a, b, c, ..] if a.is_ascii_digit() && b.is_ascii_digit() && c.is_ascii_digit() => {
                // The first three bytes are ASCII, so index 3 is a char boundary.
                line[3..].trim_start_matches(['-', ' '])
            }
            _ => line,
        }
    }
}
pub struct PooledConn {
pub reader: BufReader<TcpStream>,
pub last_used: SystemTime,
pub extensions: SmtpExtensions,
pub remote_key: String,
}
impl PooledConn {
pub fn stream_mut(&mut self) -> &mut TcpStream {
self.reader.get_mut()
}
}
/// Capacity limits and idle timing for `OutboundPool`.
#[derive(Debug, Clone)]
pub struct OutboundPoolConfig {
    /// Maximum number of idle connections parked per remote key.
    pub per_remote_cap: usize,
    /// Maximum number of idle connections parked across all remotes.
    pub global_cap: usize,
    /// How long a parked connection may stay idle before it is discarded.
    pub idle_timeout: Duration,
}

impl Default for OutboundPoolConfig {
    /// 8 connections per remote, 256 overall, 30-second idle timeout.
    fn default() -> Self {
        let idle_timeout = Duration::from_secs(30);
        let global_cap = 256;
        let per_remote_cap = 8;
        OutboundPoolConfig {
            per_remote_cap,
            global_cap,
            idle_timeout,
        }
    }
}
pub struct OutboundPool {
conns: DashMap<String, Mutex<VecDeque<PooledConn>>>,
config: OutboundPoolConfig,
total_idle: Arc<AtomicUsize>,
}
impl OutboundPool {
pub fn new(
config: OutboundPoolConfig,
mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
) -> Arc<Self> {
let pool = Arc::new(Self {
conns: DashMap::new(),
config: config.clone(),
total_idle: Arc::new(AtomicUsize::new(0)),
});
let reaper_pool = pool.clone();
let reap_interval = config.idle_timeout / 2;
tokio::spawn(async move {
loop {
tokio::select! {
_ = tokio::time::sleep(reap_interval) => {}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
}
}
reaper_pool.reap_idle().await;
}
});
pool
}
pub async fn get_or_connect(&self, remote_key: &str) -> anyhow::Result<PooledConn> {
if let Some(bucket) = self.conns.get(remote_key) {
let mut deque = bucket.lock().await;
if let Some(conn) = deque.pop_front() {
self.total_idle.fetch_sub(1, Ordering::Relaxed);
return Ok(conn);
}
}
self.open_fresh(remote_key).await
}
pub async fn return_conn(&self, mut conn: PooledConn) {
if let Err(e) = rset_connection(&mut conn).await {
tracing::debug!(
remote = conn.remote_key.as_str(),
"dropping connection after failed RSET: {}",
e
);
return;
}
if self.total_idle.load(Ordering::Relaxed) >= self.config.global_cap {
tracing::debug!(
remote = conn.remote_key.as_str(),
"global pool cap reached, dropping connection"
);
return;
}
let remote_key = conn.remote_key.clone();
let bucket = self
.conns
.entry(remote_key.clone())
.or_insert_with(|| Mutex::new(VecDeque::new()));
let mut deque = bucket.lock().await;
if deque.len() >= self.config.per_remote_cap {
tracing::debug!(
remote = remote_key.as_str(),
"per-remote cap reached, dropping connection"
);
return;
}
conn.last_used = SystemTime::now();
deque.push_back(conn);
self.total_idle.fetch_add(1, Ordering::Relaxed);
}
pub fn idle_count(&self) -> usize {
self.total_idle.load(Ordering::Relaxed)
}
async fn open_fresh(&self, remote_key: &str) -> anyhow::Result<PooledConn> {
let stream = TcpStream::connect(remote_key)
.await
.map_err(|e| anyhow::anyhow!("SMTP outbound connect to {}: {}", remote_key, e))?;
let mut reader = BufReader::new(stream);
let greeting = smtp_read_response_raw(&mut reader).await?;
if !greeting.starts_with("220") {
anyhow::bail!(
"unexpected SMTP greeting from {}: {}",
remote_key,
greeting.trim()
);
}
smtp_write(&mut reader, "EHLO localhost\r\n").await?;
let ehlo_text = smtp_read_response_raw(&mut reader).await?;
if !ehlo_text.starts_with("250") {
anyhow::bail!("EHLO rejected by {}: {}", remote_key, ehlo_text.trim());
}
let extensions = SmtpExtensions::from_ehlo(&ehlo_text);
Ok(PooledConn {
reader,
last_used: SystemTime::now(),
extensions,
remote_key: remote_key.to_string(),
})
}
async fn reap_idle(&self) {
let now = SystemTime::now();
let mut total_reaped = 0usize;
for bucket_ref in self.conns.iter() {
let mut deque = bucket_ref.value().lock().await;
let before = deque.len();
deque.retain(|conn| {
match conn.last_used.elapsed() {
Ok(elapsed) => elapsed <= self.config.idle_timeout,
Err(_) => true,
}
});
let reaped = before - deque.len();
total_reaped += reaped;
}
if total_reaped > 0 {
self.total_idle.fetch_sub(total_reaped, Ordering::Relaxed);
tracing::debug!(
"outbound pool idle reaper: closed {} connections",
total_reaped
);
}
let _ = now; }
}
/// Writes `cmd` verbatim to the socket underneath `reader`, then flushes.
///
/// `cmd` must already carry its `\r\n` terminator; nothing is appended.
pub(crate) async fn smtp_write(
    reader: &mut BufReader<TcpStream>,
    cmd: &str,
) -> std::io::Result<()> {
    let payload = cmd.as_bytes();
    let sock: &mut TcpStream = reader.get_mut();
    sock.write_all(payload).await?;
    sock.flush().await
}
/// Reads one complete, possibly multi-line, SMTP reply from `reader`.
///
/// Every line of a multi-line reply except the last has `-` right after the
/// three-digit code. The final line has a space there, or ends immediately
/// after the code (`"250\r\n"`). The returned string holds the raw text of
/// all lines, including their line terminators.
///
/// # Errors
/// Returns the underlying I/O error. Returns `UnexpectedEof` if the peer
/// closes the connection before the final reply line arrives; previously
/// this case spun forever, because `read_line` keeps returning 0 at EOF.
pub(crate) async fn smtp_read_response_raw(
    reader: &mut BufReader<TcpStream>,
) -> std::io::Result<String> {
    let mut full = String::new();
    loop {
        let start = full.len();
        // Append straight into `full`, avoiding a fresh String per line.
        let n = reader.read_line(&mut full).await?;
        if n == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "connection closed before end of SMTP reply",
            ));
        }
        let line = &full.as_bytes()[start..];
        // Only a `-` in the fourth position marks a continuation line.
        if line.get(3) != Some(&b'-') {
            break;
        }
    }
    Ok(full)
}
/// Sends `RSET` on `conn` and requires a `250` reply.
///
/// `OutboundPool::return_conn` calls this before parking a connection.
///
/// # Errors
/// Fails on any I/O error, or when the reply does not start with `250`.
async fn rset_connection(conn: &mut PooledConn) -> anyhow::Result<()> {
    let reader = &mut conn.reader;
    smtp_write(reader, "RSET\r\n").await?;
    let reply = smtp_read_response_raw(reader).await?;
    anyhow::ensure!(reply.starts_with("250"), "RSET rejected: {}", reply.trim());
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    /// Scripted replies for the fake SMTP server started by `spawn_fake_smtp`.
    #[derive(Debug, Clone)]
    struct FakeServerBehaviour {
        /// Number of TCP connections the listener accepts before its accept loop stops.
        accept_count: usize,
        /// Raw bytes sent in reply to `EHLO`/`HELO`; must be a complete reply.
        ehlo_response: String,
        /// `RSET` gets `250 OK` when true, `500 Command not recognized` otherwise.
        accept_rset: bool,
        /// `MAIL` gets `250 OK` when true, `550 Rejected` otherwise.
        accept_mail: bool,
        /// `RCPT` gets `250 OK` when true, `550 Rejected` otherwise.
        accept_rcpt: bool,
        /// `DATA` gets `354 Go ahead` (then `250 Queued` after the end-of-data
        /// marker) when true, `550 Rejected` otherwise.
        accept_data: bool,
    }
    impl Default for FakeServerBehaviour {
        /// One connection, a two-line EHLO reply advertising PIPELINING, every command accepted.
        fn default() -> Self {
            Self {
                accept_count: 1,
                ehlo_response: "250-localhost\r\n250 PIPELINING\r\n".to_string(),
                accept_rset: true,
                accept_mail: true,
                accept_rcpt: true,
                accept_data: true,
            }
        }
    }
    /// Starts a fake SMTP server on an ephemeral `127.0.0.1` port.
    ///
    /// Returns the port and a watch receiver holding the number of connections
    /// accepted so far. The count is published before the per-connection task
    /// writes the `220` greeting. Once a client has read the greeting, the
    /// count therefore already includes that connection.
    ///
    /// Limitations:
    /// - Each `read` is treated as exactly one command, so commands that arrive
    ///   together in one read are mis-parsed.
    /// - Unrecognised commands get no reply at all.
    /// - The DATA terminator is only detected within a single read chunk.
    async fn spawn_fake_smtp(
        behaviour: FakeServerBehaviour,
    ) -> (u16, tokio::sync::watch::Receiver<usize>) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind fake smtp");
        let port = listener.local_addr().expect("local addr").port();
        let (tx, rx) = tokio::sync::watch::channel(0usize);
        tokio::spawn(async move {
            let mut count = 0usize;
            while count < behaviour.accept_count {
                let Ok((mut socket, _)) = listener.accept().await else {
                    break;
                };
                count += 1;
                // Publish the count before the greeting (see doc comment above).
                let _ = tx.send(count);
                let beh = behaviour.clone();
                tokio::spawn(async move {
                    socket.write_all(b"220 localhost ESMTP\r\n").await.ok();
                    let mut buf = [0u8; 4096];
                    loop {
                        let n = match socket.read(&mut buf).await {
                            Ok(0) | Err(_) => break,
                            Ok(n) => n,
                        };
                        // One read == one command (see limitations above).
                        let raw = String::from_utf8_lossy(&buf[..n]);
                        let cmd = raw.trim().to_ascii_uppercase();
                        if cmd.starts_with("EHLO") || cmd.starts_with("HELO") {
                            socket.write_all(beh.ehlo_response.as_bytes()).await.ok();
                        } else if cmd.starts_with("RSET") {
                            if beh.accept_rset {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket
                                    .write_all(b"500 Command not recognized\r\n")
                                    .await
                                    .ok();
                            }
                        } else if cmd.starts_with("MAIL") {
                            if beh.accept_mail {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("RCPT") {
                            if beh.accept_rcpt {
                                socket.write_all(b"250 OK\r\n").await.ok();
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("DATA") {
                            if beh.accept_data {
                                socket.write_all(b"354 Go ahead\r\n").await.ok();
                                let mut data_buf = [0u8; 4096];
                                // Consume message body until a chunk contains the terminator.
                                loop {
                                    let dn = socket.read(&mut data_buf).await.unwrap_or(0);
                                    if dn == 0 {
                                        break;
                                    }
                                    let chunk = String::from_utf8_lossy(&data_buf[..dn]);
                                    if chunk.contains("\r\n.\r\n") || chunk.trim() == "." {
                                        socket.write_all(b"250 Queued\r\n").await.ok();
                                        break;
                                    }
                                }
                            } else {
                                socket.write_all(b"550 Rejected\r\n").await.ok();
                            }
                        } else if cmd.starts_with("QUIT") {
                            socket.write_all(b"221 Bye\r\n").await.ok();
                            break;
                        }
                    }
                });
            }
        });
        (port, rx)
    }
    /// Builds a pool wired to a fresh shutdown channel and returns the sender too.
    /// Tests bind the sender to `_tx` rather than `_`. A `_` pattern would drop
    /// it immediately, closing the channel the reaper watches.
    fn make_pool(
        config: OutboundPoolConfig,
    ) -> (Arc<OutboundPool>, tokio::sync::watch::Sender<bool>) {
        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let pool = OutboundPool::new(config, shutdown_rx);
        (pool, shutdown_tx)
    }
    /// A returned connection is handed out again without a second TCP connect.
    #[tokio::test]
    async fn test_outbound_pool_basic_reuse() {
        // Allow a second accept, so an unwanted fresh connect shows up as a
        // count of 2 rather than hanging without a greeting.
        let beh = FakeServerBehaviour {
            accept_count: 2, ..Default::default()
        };
        let (port, connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);
        let config = OutboundPoolConfig {
            per_remote_cap: 4,
            global_cap: 16,
            idle_timeout: Duration::from_secs(30),
        };
        let (pool, _tx) = make_pool(config);
        let conn1 = pool
            .get_or_connect(&remote)
            .await
            .expect("first connect should succeed");
        // Safe to read immediately: the server publishes the count before the greeting.
        assert_eq!(
            *connect_rx.borrow(),
            1,
            "one TCP connection after first get"
        );
        pool.return_conn(conn1).await;
        assert_eq!(pool.idle_count(), 1, "one idle conn after return");
        let conn2 = pool
            .get_or_connect(&remote)
            .await
            .expect("second get should succeed");
        assert_eq!(
            *connect_rx.borrow(),
            1,
            "connection count must stay at 1 (pooled reuse)"
        );
        pool.return_conn(conn2).await;
        assert_eq!(pool.idle_count(), 1);
    }
    /// A parked connection is removed by the background reaper after `idle_timeout`.
    #[tokio::test]
    async fn test_outbound_pool_idle_reaper() {
        let beh = FakeServerBehaviour {
            accept_count: 1,
            ..Default::default()
        };
        let (port, _connect_rx) = spawn_fake_smtp(beh).await;
        let remote = format!("127.0.0.1:{}", port);
        let idle_timeout = Duration::from_millis(80);
        let config = OutboundPoolConfig {
            per_remote_cap: 4,
            global_cap: 16,
            idle_timeout,
        };
        let (pool, _tx) = make_pool(config);
        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;
        assert_eq!(pool.idle_count(), 1, "one idle conn before timeout");
        // The reaper wakes every idle_timeout / 2. Sleeping 3x the timeout
        // leaves room for at least one sweep after the connection expires.
        tokio::time::sleep(idle_timeout * 3).await;
        assert_eq!(
            pool.idle_count(),
            0,
            "idle conn must be reaped after timeout"
        );
    }
    /// `return_conn` sends RSET. A bespoke server reports each RSET it sees over a channel.
    #[tokio::test]
    async fn test_outbound_pool_rset_on_return() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let port = listener.local_addr().expect("local_addr").port();
        let remote = format!("127.0.0.1:{}", port);
        let (seen_tx, mut seen_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
        tokio::spawn(async move {
            let Ok((mut socket, _)) = listener.accept().await else {
                return;
            };
            socket.write_all(b"220 localhost ESMTP\r\n").await.ok();
            let mut buf = [0u8; 4096];
            loop {
                let n = match socket.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => n,
                };
                let raw = String::from_utf8_lossy(&buf[..n]).to_string();
                let cmd = raw.trim().to_ascii_uppercase();
                if cmd.starts_with("EHLO") || cmd.starts_with("HELO") {
                    // Single-line EHLO reply: no extensions advertised.
                    socket.write_all(b"250 localhost\r\n").await.ok();
                } else if cmd.starts_with("RSET") {
                    let _ = seen_tx.send("RSET".to_string());
                    socket.write_all(b"250 OK\r\n").await.ok();
                } else if cmd.starts_with("QUIT") {
                    socket.write_all(b"221 Bye\r\n").await.ok();
                    break;
                }
            }
        });
        let config = OutboundPoolConfig::default();
        let (pool, _tx) = make_pool(config);
        let conn = pool
            .get_or_connect(&remote)
            .await
            .expect("connect must succeed");
        pool.return_conn(conn).await;
        let cmd = tokio::time::timeout(Duration::from_secs(2), seen_rx.recv())
            .await
            .expect("timed out waiting for RSET")
            .expect("channel closed");
        assert_eq!(cmd, "RSET");
    }
    /// A multi-line EHLO reply advertising all four tracked extensions is fully parsed.
    #[test]
    fn test_smtp_extensions_parsing() {
        let ehlo = "250-localhost\r\n250-SIZE 10240000\r\n250-PIPELINING\r\n250-8BITMIME\r\n250 STARTTLS\r\n";
        let ext = SmtpExtensions::from_ehlo(ehlo);
        assert_eq!(ext.max_size, Some(10_240_000));
        assert!(ext.pipelining);
        assert!(ext.eight_bit_mime);
        assert!(ext.starttls);
    }
}