#![forbid(unsafe_code)]
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use rustls_pki_types::ServerName;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use imap_client::{Capabilities, ClientError, RawClient, Session, Tls, Unauthenticated};
use imap_core::error::ParseError;
use imap_core::parser::parse_response;
/// Default bound on establishing the TCP connection, used by [`connect_tls`]
/// and [`connect_starttls`].
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default bound on the TLS handshake, used by [`connect_tls`] and
/// [`connect_starttls`].
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default timeout for the plaintext STARTTLS negotiation, used by
/// [`connect_starttls`].
pub const DEFAULT_PRE_TLS_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Connects to `domain:port` with TLS from the first byte (implicit TLS),
/// using [`DEFAULT_CONNECT_TIMEOUT`] and [`DEFAULT_HANDSHAKE_TIMEOUT`].
///
/// # Errors
///
/// See [`connect_tls_with_timeouts`].
pub async fn connect_tls(
    domain: &str,
    port: u16,
) -> Result<Session<Unauthenticated, Tls>, ClientError> {
    let (connect, handshake) = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_HANDSHAKE_TIMEOUT);
    connect_tls_with_timeouts(domain, port, connect, handshake).await
}
/// Connects to `domain:port` with implicit TLS using caller-supplied timeouts.
///
/// First opens the TCP connection (bounded by `connect_timeout`). Then it
/// builds the default connector and runs the TLS handshake and capability
/// discovery via [`handshake_with_connector`].
///
/// # Errors
///
/// - [`ClientError::Timeout`] / [`ClientError::Io`] from the TCP connect.
/// - [`ClientError::CommandFailed`] if the TLS configuration cannot be built.
/// - Any error from [`handshake_with_connector`].
pub async fn connect_tls_with_timeouts(
    domain: &str,
    port: u16,
    connect_timeout: Duration,
    handshake_timeout: Duration,
) -> Result<Session<Unauthenticated, Tls>, ClientError> {
    let stream = connect_tcp(domain, port, connect_timeout).await?;
    // The connector is built only after TCP succeeds; argument order keeps that sequencing.
    handshake_with_connector(default_tls_connector()?, domain, stream, handshake_timeout).await
}
/// Connects to `domain:port` in plaintext and upgrades the connection with
/// IMAP `STARTTLS`. Uses [`DEFAULT_CONNECT_TIMEOUT`],
/// [`DEFAULT_HANDSHAKE_TIMEOUT`] and [`DEFAULT_PRE_TLS_TIMEOUT`].
///
/// # Errors
///
/// See [`connect_starttls_with_timeouts`].
pub async fn connect_starttls(
    domain: &str,
    port: u16,
) -> Result<Session<Unauthenticated, Tls>, ClientError> {
    let (connect, handshake, pre_tls) = (
        DEFAULT_CONNECT_TIMEOUT,
        DEFAULT_HANDSHAKE_TIMEOUT,
        DEFAULT_PRE_TLS_TIMEOUT,
    );
    connect_starttls_with_timeouts(domain, port, connect, handshake, pre_tls).await
}
/// Connects to `domain:port` in plaintext, negotiates `STARTTLS`, then runs
/// the TLS handshake on the same socket.
///
/// Timeouts:
/// - `connect_timeout` bounds the TCP connect.
/// - `pre_tls_timeout` is handed to the plaintext negotiation.
/// - `handshake_timeout` bounds the TLS handshake.
///
/// # Errors
///
/// - [`ClientError::Timeout`] / [`ClientError::Io`] from the TCP connect.
/// - [`ClientError::CommandFailed`] if the TLS configuration cannot be built.
/// - Any error from [`starttls_with_connector`].
pub async fn connect_starttls_with_timeouts(
    domain: &str,
    port: u16,
    connect_timeout: Duration,
    handshake_timeout: Duration,
    pre_tls_timeout: Duration,
) -> Result<Session<Unauthenticated, Tls>, ClientError> {
    let plaintext = connect_tcp(domain, port, connect_timeout).await?;
    starttls_with_connector(
        default_tls_connector()?,
        domain,
        plaintext,
        handshake_timeout,
        pre_tls_timeout,
    )
    .await
}
/// Builds the [`TlsConnector`] used by the `connect_*` helpers.
///
/// The connector:
/// - trusts the root certificates bundled in `webpki_roots`;
/// - uses the `ring` crypto provider with its safe default protocol versions;
/// - presents no client certificate.
///
/// # Errors
///
/// Returns [`ClientError::CommandFailed`] (`"TLS config failed: ..."`) if the
/// provider rejects the default protocol versions.
pub fn default_tls_connector() -> Result<TlsConnector, ClientError> {
    let provider = Arc::new(tokio_rustls::rustls::crypto::ring::default_provider());

    let mut trust_anchors = RootCertStore::empty();
    trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let with_versions = ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| ClientError::CommandFailed(format!("TLS config failed: {e}")))?;
    let client_config = with_versions
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    Ok(TlsConnector::from(Arc::new(client_config)))
}
/// Runs a TLS client handshake for `domain` over `stream`, then wraps the
/// encrypted stream in an unauthenticated [`Session`] with discovered
/// capabilities.
///
/// Only the handshake itself is bounded by `handshake_timeout`.
/// NOTE(review): the capability discovery that follows runs outside that
/// bound.
///
/// # Errors
///
/// - [`ClientError::CommandFailed`] (`"Invalid domain for TLS"`) if `domain`
///   is not a valid TLS server name.
/// - [`ClientError::Timeout`] if the handshake exceeds `handshake_timeout`.
/// - The handshake's I/O error, converted via `From`, if the handshake fails.
/// - Any error raised while discovering capabilities.
pub async fn handshake_with_connector<S>(
    connector: TlsConnector,
    domain: &str,
    stream: S,
    handshake_timeout: Duration,
) -> Result<Session<Unauthenticated, Tls>, ClientError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let Ok(name) = ServerName::try_from(domain) else {
        return Err(ClientError::CommandFailed("Invalid domain for TLS".into()));
    };
    // The connector needs an owned ('static) server name.
    let handshake = connector.connect(name.to_owned(), stream);
    let tls_stream = match tokio::time::timeout(handshake_timeout, handshake).await {
        Ok(outcome) => outcome?,
        Err(_elapsed) => return Err(ClientError::Timeout),
    };
    finalize_session_after_handshake(tls_stream).await
}
/// Upgrades an already-connected plaintext `stream` to TLS via IMAP `STARTTLS`.
///
/// The plaintext negotiation runs first, under `pre_tls_timeout`. Only after
/// the server accepts `STARTTLS` is the same stream handed to
/// [`handshake_with_connector`], under `handshake_timeout`.
///
/// # Errors
///
/// Any error from the plaintext negotiation (timeouts, I/O, missing or
/// rejected `STARTTLS`), or from [`handshake_with_connector`].
pub async fn starttls_with_connector<S>(
    connector: TlsConnector,
    domain: &str,
    mut stream: S,
    handshake_timeout: Duration,
    pre_tls_timeout: Duration,
) -> Result<Session<Unauthenticated, Tls>, ClientError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    drive_starttls_exchange(&mut stream, pre_tls_timeout).await?;
    let session = handshake_with_connector(connector, domain, stream, handshake_timeout).await?;
    Ok(session)
}
/// Opens a TCP connection to `domain` on `port`, giving up after `timeout`.
///
/// `domain` may be a hostname or an IP literal (including a bare IPv6 address
/// such as `::1`).
///
/// # Errors
///
/// - [`ClientError::Timeout`] if address resolution plus connection takes
///   longer than `timeout`.
/// - [`ClientError::Io`] if resolution or connection fails.
async fn connect_tcp(domain: &str, port: u16, timeout: Duration) -> Result<TcpStream, ClientError> {
    // Pass host and port as a tuple instead of formatting "host:port". This
    // avoids an allocation and a re-parse. IP literals (notably IPv6, whose
    // colons make "host:port" ambiguous) are then parsed directly rather than
    // going through a blocking name lookup.
    tokio::time::timeout(timeout, TcpStream::connect((domain, port)))
        .await
        .map_err(|_| ClientError::Timeout)?
        .map_err(ClientError::Io)
}
/// Wraps an established TLS stream in a [`RawClient`] and determines the
/// server's capabilities before returning an unauthenticated [`Session`].
///
/// Capability discovery:
/// 1. Wait briefly for the first event. If it parses and
///    `Capabilities::try_update_from` accepts it, use that.
/// 2. Otherwise issue an explicit `CAPABILITY` command. Parse its response,
///    then drain queued events until one updates the capability set.
///
/// # Errors
///
/// Propagates any error from `RawClient::execute_command`.
///
/// NOTE(review): no timeout wraps the fallback `CAPABILITY` command here. If
/// `RawClient` does not enforce one itself, a silent server can stall this
/// call indefinitely.
async fn finalize_session_after_handshake<S>(
    tls_stream: S,
) -> Result<Session<Unauthenticated, Tls>, ClientError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    // Grace period for an initial event that already carries capabilities,
    // before falling back to an explicit CAPABILITY command.
    // NOTE(review): after STARTTLS, servers typically send no new greeting,
    // so on that path this wait is likely pure latency.
    const GREETING_WAIT: Duration = Duration::from_millis(250);

    let mut raw = RawClient::new(tls_stream);
    let mut events = raw.events();
    let mut capabilities = Capabilities::default();

    // A timeout, a channel error, or an unparseable frame all leave this
    // `false`, which triggers the explicit CAPABILITY fallback.
    let mut got_greeting_caps = false;
    if let Ok(Ok(frame)) = tokio::time::timeout(GREETING_WAIT, events.recv()).await
        && let Ok((_, response)) = parse_response(&frame)
    {
        got_greeting_caps = capabilities.try_update_from(&response);
    }

    if !got_greeting_caps {
        let cap_resp = raw.execute_command("CAPABILITY").await?;
        if let Ok((_, response)) = parse_response(&cap_resp) {
            capabilities.try_update_from(&response);
        }
        // The untagged `* CAPABILITY` data may arrive as a separate event;
        // stop at the first queued event that updates the set.
        while let Ok(event) = events.try_recv() {
            if let Ok((_, response)) = parse_response(&event)
                && capabilities.try_update_from(&response)
            {
                break;
            }
        }
    }

    Ok(Session::new(raw, capabilities))
}
async fn drive_starttls_exchange<S>(
stream: &mut S,
pre_tls_timeout: Duration,
) -> Result<(), ClientError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(4096);
let _greeting = read_one_frame(stream, &mut buf, pre_tls_timeout).await?;
write_all_with_timeout(stream, b"T0001 CAPABILITY\r\n", pre_tls_timeout).await?;
let mut caps = Capabilities::default();
loop {
let frame = read_one_frame(stream, &mut buf, pre_tls_timeout).await?;
if let Ok((_, response)) = parse_response(&frame) {
match &response {
imap_core::ast::Response::Status(s) if s.tag == Some("T0001") => {
if !matches!(s.status, imap_core::ast::Status::Ok) {
return Err(ClientError::CommandFailed(format!(
"CAPABILITY failed: {}",
s.text
)));
}
break;
}
_ => {
caps.try_update_from(&response);
}
}
}
}
if !caps.starttls {
return Err(ClientError::CommandFailed(
"server does not advertise STARTTLS".into(),
));
}
write_all_with_timeout(stream, b"T0002 STARTTLS\r\n", pre_tls_timeout).await?;
loop {
let frame = read_one_frame(stream, &mut buf, pre_tls_timeout).await?;
if let Ok((_, response)) = parse_response(&frame)
&& let imap_core::ast::Response::Status(s) = &response
&& s.tag == Some("T0002")
{
if !matches!(s.status, imap_core::ast::Status::Ok) {
return Err(ClientError::CommandFailed(format!(
"STARTTLS rejected: {}",
s.text
)));
}
return Ok(());
}
}
}
/// Writes all of `bytes` to `stream` and flushes it, with `timeout` bounding
/// the write and flush together.
///
/// The flush matters because `S` is generic: a buffering writer would
/// otherwise hold the command locally while we wait for the server's reply.
///
/// # Errors
///
/// - [`ClientError::Timeout`] if writing plus flushing exceeds `timeout`.
/// - [`ClientError::Io`] if either operation fails.
async fn write_all_with_timeout<S: AsyncWrite + Unpin>(
    stream: &mut S,
    bytes: &[u8],
    timeout: Duration,
) -> Result<(), ClientError> {
    let send = async {
        stream.write_all(bytes).await?;
        stream.flush().await
    };
    tokio::time::timeout(timeout, send)
        .await
        .map_err(|_| ClientError::Timeout)?
        .map_err(ClientError::Io)
}
/// Reads from `stream` until `buf` holds one complete IMAP response. That
/// response is split off the front of `buf` and returned.
///
/// Bytes after the returned frame stay in `buf` for the next call.
///
/// `timeout` bounds the whole call, not each individual read. A timeout too
/// large to add to "now" (e.g. `Duration::MAX`) is treated as effectively no
/// deadline instead of panicking.
///
/// # Errors
///
/// - [`ClientError::Timeout`] when the deadline passes before a full frame
///   arrives.
/// - [`ClientError::CommandFailed`] when the buffered bytes cannot be parsed,
///   or reach 64 KiB without forming a frame.
/// - [`ClientError::ConnectionClosed`] on EOF.
/// - [`ClientError::Io`] on read errors.
async fn read_one_frame<S: AsyncRead + Unpin>(
    stream: &mut S,
    buf: &mut BytesMut,
    timeout: Duration,
) -> Result<Vec<u8>, ClientError> {
    // Greetings and CAPABILITY lines are small. Cap buffering so a broken or
    // hostile pre-TLS peer (e.g. one announcing a huge `{n}` literal) cannot
    // grow `buf` without bound for the whole timeout.
    const MAX_PRE_TLS_FRAME: usize = 64 * 1024;

    let deadline = tokio::time::Instant::now().checked_add(timeout);
    loop {
        if !buf.is_empty() {
            match parse_response(buf) {
                Ok((remaining, _)) => {
                    let consumed = buf.len() - remaining.len();
                    return Ok(buf.split_to(consumed).to_vec());
                }
                Err(ParseError::Incomplete) => {}
                Err(_) => {
                    return Err(ClientError::CommandFailed(
                        "malformed pre-TLS response".into(),
                    ));
                }
            }
        }
        if buf.len() >= MAX_PRE_TLS_FRAME {
            return Err(ClientError::CommandFailed(
                "pre-TLS response too large".into(),
            ));
        }
        let budget = match deadline {
            Some(deadline) => {
                let now = tokio::time::Instant::now();
                if now >= deadline {
                    return Err(ClientError::Timeout);
                }
                deadline - now
            }
            // Deadline not representable: `tokio::time::timeout` itself
            // treats such a huge duration as "far future".
            None => timeout,
        };
        let n = tokio::time::timeout(budget, stream.read_buf(buf))
            .await
            .map_err(|_| ClientError::Timeout)?
            .map_err(ClientError::Io)?;
        if n == 0 {
            return Err(ClientError::ConnectionClosed);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::duplex;

    // A domain containing a space is not a valid TLS server name, so
    // `handshake_with_connector` must fail at name validation.
    #[tokio::test]
    async fn test_connect_tls_invalid_domain() {
        // `_server_io` is bound (not `_`), so the peer end stays open for the call.
        let (client_io, _server_io) = duplex(1024);
        let connector = default_tls_connector().unwrap();
        let r = handshake_with_connector(
            connector,
            "invalid domain",
            client_io,
            DEFAULT_HANDSHAKE_TIMEOUT,
        )
        .await;
        assert!(r.is_err());
    }

    // The peer end is dropped up front, so the TLS handshake cannot complete
    // and must surface an error. The 2 s timeout bounds the test if it stalls.
    #[tokio::test]
    async fn test_connect_tls_handshake_failure() {
        let (client_io, server_io) = duplex(1024);
        drop(server_io);
        let connector = default_tls_connector().unwrap();
        let r = handshake_with_connector(connector, "localhost", client_io, Duration::from_secs(2))
            .await;
        assert!(r.is_err());
    }

    // 192.0.2.1 is in TEST-NET-1 (RFC 5737) and should not be routable.
    // Depending on the host's network, the connect either hits the 100 ms
    // timeout or fails with an I/O error; both outcomes are accepted.
    // NOTE(review): this test touches the real network stack.
    #[tokio::test]
    async fn test_connect_tcp_timeout() {
        let r = connect_tcp("192.0.2.1", 993, Duration::from_millis(100)).await;
        assert!(matches!(
            r,
            Err(ClientError::Timeout) | Err(ClientError::Io(_))
        ));
    }

    // Scripted server: greeting, CAPABILITY advertising STARTTLS, then a
    // tagged OK to STARTTLS. The exchange must succeed, and the server task
    // checks that it received both commands.
    #[tokio::test]
    async fn test_starttls_drive_happy_path() {
        let (mut client_io, mut server_io) = duplex(4096);
        let server_task = tokio::spawn(async move {
            server_io
                .write_all(b"* OK IMAP service ready\r\n")
                .await
                .unwrap();
            let mut buf = [0u8; 1024];
            let n = server_io.read(&mut buf).await.unwrap();
            assert!(String::from_utf8_lossy(&buf[..n]).contains("CAPABILITY"));
            server_io
                .write_all(b"* CAPABILITY IMAP4rev2 STARTTLS\r\nT0001 OK done\r\n")
                .await
                .unwrap();
            let n = server_io.read(&mut buf).await.unwrap();
            assert!(String::from_utf8_lossy(&buf[..n]).contains("STARTTLS"));
            server_io
                .write_all(b"T0002 OK begin TLS\r\n")
                .await
                .unwrap();
        });
        drive_starttls_exchange(&mut client_io, Duration::from_secs(5))
            .await
            .unwrap();
        server_task.await.unwrap();
    }

    // The server never advertises STARTTLS. The exchange must fail with
    // `CommandFailed` right after CAPABILITY, without sending STARTTLS.
    #[tokio::test]
    async fn test_starttls_drive_rejects_no_starttls() {
        let (mut client_io, mut server_io) = duplex(4096);
        let server_task = tokio::spawn(async move {
            server_io.write_all(b"* OK ready\r\n").await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = server_io.read(&mut buf).await.unwrap();
            server_io
                .write_all(b"* CAPABILITY IMAP4rev2\r\nT0001 OK done\r\n")
                .await
                .unwrap();
        });
        let r = drive_starttls_exchange(&mut client_io, Duration::from_secs(5)).await;
        assert!(matches!(r, Err(ClientError::CommandFailed(_))));
        server_task.await.unwrap();
    }

    // The server advertises STARTTLS but answers the STARTTLS command with a
    // tagged NO, which must surface as `CommandFailed`.
    #[tokio::test]
    async fn test_starttls_drive_rejects_starttls_no() {
        let (mut client_io, mut server_io) = duplex(4096);
        let server_task = tokio::spawn(async move {
            server_io.write_all(b"* OK ready\r\n").await.unwrap();
            let mut buf = [0u8; 1024];
            let _ = server_io.read(&mut buf).await.unwrap();
            server_io
                .write_all(b"* CAPABILITY IMAP4rev2 STARTTLS\r\nT0001 OK done\r\n")
                .await
                .unwrap();
            let _ = server_io.read(&mut buf).await.unwrap();
            server_io.write_all(b"T0002 NO not now\r\n").await.unwrap();
        });
        let r = drive_starttls_exchange(&mut client_io, Duration::from_secs(5)).await;
        assert!(matches!(r, Err(ClientError::CommandFailed(_))));
        server_task.await.unwrap();
    }
}