use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use flate2::{Compress, Decompress, FlushCompress, FlushDecompress, Status};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use tracing::{debug, warn};
use crate::codec::encode::{
encode_multi_append_header_with_literal8, encode_quoted_or_literal,
encode_quoted_or_literal_utf8, LiteralMode,
};
use crate::error::Error;
use crate::types::{
format_fetch_attrs, AclEntry, AppendMessage, Capability, Command, CopyResult, EsearchResponse,
ExpungeResult, FetchAttr, FetchResponse, Flag, ListRightsResponse, MailboxAttribute,
MailboxFilter, MailboxInfo, MailboxName, MetadataEntry, MetadataResult, MoveResult,
NamespaceResponse, NotifyEvent, NotifySetParams, QresyncParams, QuotaResource,
QuotaRootResponse, Response, ResponseCode, SelectOptions, SelectedMailbox, SequenceSet,
StatusItem, StatusResult, StoreOperation, StoreResult, TaggedResponse, ThreadNode, UidRange,
UntaggedResponse, UntaggedStatus,
};
mod append;
mod auth;
pub(super) mod dispatch;
pub(super) mod driver;
mod extensions;
mod helpers;
mod idle;
mod lifecycle;
mod mailbox;
pub(super) mod pipeline;
mod seq_ops;
mod sort_thread;
pub(super) mod state;
mod tag;
pub mod typed_event;
mod uid_ops;
pub(super) mod wire;
#[cfg(test)]
#[path = "tests.rs"]
mod tests;
pub use daaki_message::TlsMode;
/// TCP keepalive settings applied to the socket underneath an IMAP connection.
///
/// Both durations are forwarded to `socket2::TcpKeepalive` when keepalive is
/// configured on the stream. The struct is `#[non_exhaustive]`, so code outside
/// this crate constructs it with [`TcpKeepalive::new`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TcpKeepalive {
/// Idle time before keepalive probing begins (passed to `socket2::TcpKeepalive::with_time`).
pub time: Duration,
/// Delay between keepalive probes (passed to `socket2::TcpKeepalive::with_interval`).
pub interval: Duration,
}
impl TcpKeepalive {
    /// Creates keepalive settings.
    ///
    /// * `time` — idle time before keepalive probing begins.
    /// * `interval` — delay between successive probes.
    ///
    /// Because the struct is `#[non_exhaustive]`, this is the only way to build a
    /// `TcpKeepalive` outside this crate; being `const`, it can also be used to
    /// define `const`/`static` keepalive presets.
    #[must_use]
    pub const fn new(time: Duration, interval: Duration) -> Self {
        Self { time, interval }
    }
}
/// IMAP session state of a connection.
///
/// The variant names follow the connection states of RFC 3501 §3
/// (not authenticated, authenticated, selected, logout).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SessionState {
/// Connected, no successful authentication yet.
NotAuthenticated,
/// Authenticated, no mailbox selected.
Authenticated,
/// A mailbox is currently selected.
Selected,
/// The session is logging out / has logged out.
Logout,
}
/// An event surfaced to the caller while waiting for server notifications.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names and payload types; confirm against the code that produces them.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IdleEvent {
/// Untagged `EXISTS` carrying the message count.
Exists(u32),
/// Untagged `EXPUNGE` carrying the expunged sequence number.
Expunge(u32),
/// Untagged `VANISHED`; `earlier` is presumably set for `VANISHED (EARLIER)`.
Vanished {
earlier: bool,
uids: Vec<UidRange>,
},
/// Unsolicited `FETCH` data (boxed to keep the enum small).
Fetch(Box<crate::types::FetchResponse>),
/// Untagged `RECENT` count.
Recent(u32),
/// Alert text (presumably from an `[ALERT]` response code).
Alert(String),
/// The wait ended without an event (assumed: the caller's timeout elapsed).
Timeout,
/// The wait was cancelled (assumed: by the caller).
Cancelled,
/// A mailbox LIST-style event (assumed: a NOTIFY mailbox-name event).
MailboxEvent(MailboxInfo),
/// A `STATUS` response for `mailbox`.
MailboxStatus {
mailbox: MailboxName,
items: Vec<StatusItem>,
},
/// A `METADATA` change notification for `mailbox`.
MetadataChange {
mailbox: MailboxName,
entries: Vec<MetadataEntry>,
},
/// An `ESEARCH` update (boxed to keep the enum small).
SearchUpdate(Box<crate::types::EsearchResponse>),
/// An event that does not map to a dedicated variant, as raw text.
ExtensionEvent(String),
/// An untagged status response carrying a response code.
StatusUpdate {
status: UntaggedStatus,
code: ResponseCode,
text: String,
},
/// The server reported that notifications overflowed
/// (presumably the RFC 5465 `NOTIFICATIONOVERFLOW` code).
NotificationOverflow {
code_text: Option<String>,
resp_text: String,
},
/// Untagged `BYE` with optional response code and text.
Bye {
code: Option<ResponseCode>,
text: String,
},
/// The server side went away (assumed: connection closed without a `BYE`).
ServerTerminated,
}
/// Result of a search command.
#[non_exhaustive]
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SearchResult {
/// Matching message numbers or UIDs, depending on the command issued.
pub ids: Vec<u32>,
/// Mod-sequence reported with the result, if any (CONDSTORE).
pub mod_seq: Option<u64>,
/// Set when `ids` is incomplete (assumed: e.g. a range too large to expand
/// or containing `*` — confirm with the producer).
pub truncated: bool,
}
/// The transport underneath a [`CompressedStream`]: plain TCP or TLS over TCP.
// The TLS variant is much larger than a bare `TcpStream`; the lint is silenced
// rather than boxing, since only one value of this enum exists per connection.
#[allow(clippy::large_enum_variant)]
enum InnerStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
}
impl InnerStream {
async fn read_buf(&mut self, buf: &mut BytesMut) -> std::io::Result<usize> {
match self {
Self::Plain(s) => s.read_buf(buf).await,
Self::Tls(s) => s.read_buf(buf).await,
}
}
async fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
match self {
Self::Plain(s) => s.write_all(data).await,
Self::Tls(s) => s.write_all(data).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Plain(s) => s.flush().await,
Self::Tls(s) => s.flush().await,
}
}
}
/// A transport wrapped in raw DEFLATE compression in both directions.
struct CompressedStream {
/// The transport the compressed bytes travel over.
inner: InnerStream,
/// Inflater for bytes read from `inner`.
decompress: Decompress,
/// Deflater for bytes written to `inner`.
compress: Compress,
/// Compressed bytes read from `inner` but not yet consumed by `decompress`.
raw_read_buf: BytesMut,
/// Scratch output buffer for `decompress`; at most this many bytes are produced per read.
inflate_buf: Vec<u8>,
}
/// Initial capacity of `CompressedStream::raw_read_buf` (compressed input).
const COMPRESSED_RAW_BUF_SIZE: usize = 8192;
/// Length of `CompressedStream::inflate_buf` (decompressed output per read).
const INFLATE_BUF_SIZE: usize = 16384;
/// Length of the scratch buffer `CompressedStream::write_all` deflates into.
const DEFLATE_BUF_SIZE: usize = 16384;
impl CompressedStream {
    /// Wraps `inner` with raw DEFLATE (no zlib header) in both directions, using
    /// the default compression level for outgoing data.
    fn new(inner: InnerStream) -> Self {
        Self {
            inner,
            decompress: Decompress::new(false),
            compress: Compress::new(flate2::Compression::default(), false),
            raw_read_buf: BytesMut::with_capacity(COMPRESSED_RAW_BUF_SIZE),
            inflate_buf: vec![0u8; INFLATE_BUF_SIZE],
        }
    }

    /// Appends decompressed bytes to `buf` and returns how many were appended.
    ///
    /// Returns `Ok(0)` when the underlying transport reaches EOF or the deflate
    /// stream has ended.
    ///
    /// # Errors
    /// Transport I/O errors are propagated; corrupt compressed input is reported
    /// as [`std::io::ErrorKind::InvalidData`].
    // The `u64 -> usize` casts are per-call deltas bounded by the buffer sizes.
    #[allow(clippy::cast_possible_truncation)]
    async fn read_buf(&mut self, buf: &mut BytesMut) -> std::io::Result<usize> {
        loop {
            // Run the inflater before touching the transport, even when
            // `raw_read_buf` is empty. If the previous call filled `inflate_buf`
            // completely, the decompressor may still hold output for input it has
            // already consumed. Reading the socket first would then block on bytes
            // the server has no reason to send, stalling the response.
            let before_in = self.decompress.total_in();
            let before_out = self.decompress.total_out();
            let status = self
                .decompress
                .decompress(
                    &self.raw_read_buf,
                    &mut self.inflate_buf,
                    FlushDecompress::Sync,
                )
                .map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        format!("deflate decompression error: {e}"),
                    )
                })?;
            let consumed = (self.decompress.total_in() - before_in) as usize;
            let produced = (self.decompress.total_out() - before_out) as usize;
            if consumed > 0 {
                let _ = self.raw_read_buf.split_to(consumed);
            }
            if produced > 0 {
                buf.extend_from_slice(&self.inflate_buf[..produced]);
                return Ok(produced);
            }
            if status == Status::StreamEnd {
                return Ok(0);
            }
            // Input was consumed without any output yet (e.g. only a block
            // header); keep inflating what is already buffered before reading more.
            if consumed > 0 && !self.raw_read_buf.is_empty() {
                continue;
            }
            let n = self.inner.read_buf(&mut self.raw_read_buf).await?;
            if n == 0 {
                // Transport EOF: any incomplete compressed tail is dropped.
                return Ok(0);
            }
        }
    }

    /// Compresses `data`, writes it to the transport, and then sync-flushes the
    /// compressor so the peer can decompress everything written so far.
    ///
    /// # Errors
    /// Transport I/O errors are propagated; compressor failures are reported as
    /// [`std::io::ErrorKind::InvalidData`].
    // The `u64 -> usize` casts are per-call deltas bounded by the buffer sizes.
    #[allow(clippy::cast_possible_truncation)]
    async fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
        let mut deflate_buf = vec![0u8; DEFLATE_BUF_SIZE];
        let mut input_offset = 0;
        // Feed all input without flushing, forwarding output as it is produced.
        while input_offset < data.len() {
            let before_in = self.compress.total_in();
            let before_out = self.compress.total_out();
            self.compress
                .compress(&data[input_offset..], &mut deflate_buf, FlushCompress::None)
                .map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        format!("deflate compression error: {e}"),
                    )
                })?;
            let consumed = (self.compress.total_in() - before_in) as usize;
            let produced = (self.compress.total_out() - before_out) as usize;
            input_offset += consumed;
            if produced > 0 {
                self.inner.write_all(&deflate_buf[..produced]).await?;
            }
        }
        // One sync flush, then drain any remaining output until none is produced.
        let mut flush = FlushCompress::Sync;
        loop {
            let before_out = self.compress.total_out();
            self.compress
                .compress(&[], &mut deflate_buf, flush)
                .map_err(|e| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        format!("deflate sync-flush error: {e}"),
                    )
                })?;
            let produced = (self.compress.total_out() - before_out) as usize;
            if produced == 0 {
                break;
            }
            self.inner.write_all(&deflate_buf[..produced]).await?;
            flush = FlushCompress::None;
        }
        Ok(())
    }

    /// Flushes the underlying transport (compressed bytes are already sync-flushed
    /// by `write_all`).
    async fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush().await
    }
}
/// The byte stream an IMAP connection talks over.
// Variants differ greatly in size (TLS and compression state are large); the lint
// is silenced rather than boxing.
#[allow(clippy::large_enum_variant)]
enum ImapStream {
/// Unencrypted TCP.
Plain(TcpStream),
/// TLS over TCP.
Tls(TlsStream<TcpStream>),
/// DEFLATE-compressed plain or TLS transport.
Compressed(CompressedStream),
/// Placeholder left while the stream is being upgraded; every I/O operation on
/// it fails with a "poisoned" error.
Poisoned,
/// In-memory duplex stream used by tests.
#[cfg(test)]
Memory(tokio::io::DuplexStream),
}
impl ImapStream {
async fn read_buf(&mut self, buf: &mut BytesMut) -> std::io::Result<usize> {
match self {
Self::Plain(s) => s.read_buf(buf).await,
Self::Tls(s) => s.read_buf(buf).await,
Self::Compressed(s) => s.read_buf(buf).await,
Self::Poisoned => Err(std::io::Error::other("stream poisoned during upgrade")),
#[cfg(test)]
Self::Memory(s) => s.read_buf(buf).await,
}
}
async fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
match self {
Self::Plain(s) => s.write_all(data).await,
Self::Tls(s) => s.write_all(data).await,
Self::Compressed(s) => s.write_all(data).await,
Self::Poisoned => Err(std::io::Error::other("stream poisoned during upgrade")),
#[cfg(test)]
Self::Memory(s) => s.write_all(data).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Plain(s) => s.flush().await,
Self::Tls(s) => s.flush().await,
Self::Compressed(s) => s.flush().await,
Self::Poisoned => Err(std::io::Error::other("stream poisoned during upgrade")),
#[cfg(test)]
Self::Memory(s) => s.flush().await,
}
}
fn set_keepalive(&self, ka: &TcpKeepalive) -> Result<(), Error> {
use socket2::SockRef;
let sock_ka = socket2::TcpKeepalive::new()
.with_time(ka.time)
.with_interval(ka.interval);
let result = match self {
Self::Plain(tcp) => SockRef::from(tcp).set_tcp_keepalive(&sock_ka),
Self::Tls(tls) => SockRef::from(tls.get_ref().0).set_tcp_keepalive(&sock_ka),
Self::Compressed(c) => {
let inner_result = match &c.inner {
InnerStream::Plain(tcp) => SockRef::from(tcp).set_tcp_keepalive(&sock_ka),
InnerStream::Tls(tls) => {
SockRef::from(tls.get_ref().0).set_tcp_keepalive(&sock_ka)
}
};
inner_result
}
Self::Poisoned => {
return Err(Error::Io(std::sync::Arc::new(std::io::Error::other(
"cannot set keepalive: stream is in upgrade transition",
))));
}
#[cfg(test)]
Self::Memory(_) => {
return Err(Error::Io(std::sync::Arc::new(std::io::Error::other(
"keepalive not supported on memory streams",
))));
}
};
result.map_err(|e| Error::Io(std::sync::Arc::new(e)))
}
fn into_tcp(self) -> Option<TcpStream> {
match self {
Self::Plain(s) => Some(s),
Self::Tls(_) | Self::Compressed(_) | Self::Poisoned => Option::None,
#[cfg(test)]
Self::Memory(_) => Option::None,
}
}
}
/// Handle to an IMAP connection whose protocol I/O runs in a separate driver task.
pub struct ImapConnection {
/// Sends commands to the driver task.
cmd_tx: tokio::sync::mpsc::Sender<driver::DriverCommand>,
/// Latest connection state published by the driver (session state, capabilities, ...).
state_rx: tokio::sync::watch::Receiver<driver::ConnectionStateSnapshot>,
/// Typed events from the driver; behind a mutex so `&self` methods can receive.
events_rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<typed_event::TypedEvent>>,
/// Join handle of the driver task, if it has not been taken yet.
driver_handle: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Counter behind `next_prebuilt_tag` (wraps on overflow).
prebuilt_tag_counter: std::sync::atomic::AtomicU32,
/// Host the connection was opened to (assumed — confirm against the constructor).
host: String,
}
/// Which kinds of notification events are enabled.
///
/// NOTE(review): presumably mirrors the NOTIFY event groups requested
/// (mailbox-list, status, metadata) — confirm against where this is filled in.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct NotifyFlags {
pub(crate) list: bool,
pub(crate) status: bool,
pub(crate) metadata: bool,
}
// Compile-time assertion that `ImapConnection` is `Send`; the build fails if a
// non-`Send` field is ever added. The closure is never called.
const _: fn() = || {
fn assert_send<T: Send>() {}
assert_send::<ImapConnection>();
};
/// Literal syntax used to send an APPEND message body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AppendLiteralKind {
/// Ordinary `{N}` literal.
Literal,
/// Binary `~{N}` literal8.
Literal8,
/// literal8 in its UTF-8 form (presumably RFC 6855 `UTF8 (~{N})` — confirm).
Utf8Literal8,
}
impl ImapConnection {
pub(super) fn next_prebuilt_tag(&self) -> String {
let n = self
.prebuilt_tag_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
.wrapping_add(1);
format!("P{n:03}")
}
pub async fn drain_events(&self) -> Vec<typed_event::TypedEvent> {
let mut rx = self.events_rx.lock().await;
let mut out = Vec::new();
while let Ok(ev) = rx.try_recv() {
out.push(ev);
}
out
}
pub async fn next_event(
&self,
timeout: std::time::Duration,
) -> Result<Option<typed_event::TypedEvent>, crate::error::Error> {
let mut rx = self.events_rx.lock().await;
match tokio::time::timeout(timeout, rx.recv()).await {
Ok(Some(ev)) => Ok(Some(ev)),
Ok(None) => Err(crate::error::Error::DriverGone),
Err(_) => Ok(None), }
}
}
/// Shows the session state, the number of known capabilities, and whether the
/// command channel to the driver is closed; other fields are elided.
impl std::fmt::Debug for ImapConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.state_rx.borrow();
        let capability_count = state.capabilities.len();
        let channel_closed = self.cmd_tx.is_closed();
        let mut out = f.debug_struct("ImapConnection");
        out.field("state", &state.session_state);
        out.field("capabilities_count", &capability_count);
        out.field("cmd_tx_closed", &channel_closed);
        out.finish_non_exhaustive()
    }
}
/// Builds a rustls client config that trusts the bundled webpki root certificates
/// and sends no client certificate.
fn build_default_tls_config() -> Arc<rustls::ClientConfig> {
    // Install ring as the process-wide crypto provider. The result is ignored on
    // purpose: installation fails only if some provider is already installed.
    let _ = rustls::crypto::ring::default_provider().install_default();
    let roots = webpki_roots::TLS_SERVER_ROOTS
        .iter()
        .cloned()
        .collect::<rustls::RootCertStore>();
    Arc::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth(),
    )
}
/// Locates the first literal marker `{N}\r\n` in `buf`.
///
/// Returns `(body_start, N)`, where `body_start` is the index just past the
/// marker's `\r\n`. Markers with no digits, a missing `}\r\n`, or a size that
/// does not fit in `usize` are skipped. Returns `None` if no marker is found.
fn find_literal_boundary(buf: &[u8]) -> Option<(usize, usize)> {
    buf.iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'{')
        .find_map(|(open, _)| {
            let digits_start = open + 1;
            let digit_count = buf[digits_start..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            let close = digits_start + digit_count;
            if digit_count == 0 || buf.get(close..close + 3) != Some(&b"}\r\n"[..]) {
                return None;
            }
            let size = std::str::from_utf8(&buf[digits_start..close])
                .ok()?
                .parse::<usize>()
                .ok()?;
            Some((close + 3, size))
        })
}
/// Rewrites synchronizing literal markers `{N}\r\n` in `buf` into the
/// non-synchronizing `{N+}\r\n` form (LITERAL+), copying literal bodies verbatim.
///
/// A marker immediately preceded by `~` is a binary literal8; it is converted
/// only when `allow_literal8` is `true`, otherwise it stays synchronizing.
fn patch_literals_to_plus_with_binary(buf: &[u8], allow_literal8: bool) -> BytesMut {
    patch_literal_markers(buf, allow_literal8, usize::MAX)
}

/// Like `patch_literals_to_plus_with_binary`, but converts only literals of at
/// most 4096 bytes (the LITERAL- limit); larger ones stay synchronizing.
fn patch_small_literals_to_plus_with_binary(buf: &[u8], allow_literal8: bool) -> BytesMut {
    const LITERAL_MINUS_MAX: usize = 4096;
    patch_literal_markers(buf, allow_literal8, LITERAL_MINUS_MAX)
}

/// Shared implementation of the two patchers above. Converts every marker whose
/// size is `<= max_plus_size` (and that is not a disallowed literal8) to the `+`
/// form. Literal bodies are skipped, so `{` bytes inside them are never treated
/// as markers.
fn patch_literal_markers(buf: &[u8], allow_literal8: bool, max_plus_size: usize) -> BytesMut {
    let mut result = BytesMut::with_capacity(buf.len() + 16);
    // Bytes of `buf` before `copied` are already in `result`; unmodified runs are
    // appended in bulk rather than one byte at a time.
    let mut copied = 0;
    let mut pos = 0;
    while pos < buf.len() {
        let Some((close, size)) = literal_marker_at(buf, pos) else {
            pos += 1;
            continue;
        };
        let is_literal8 = pos > 0 && buf[pos - 1] == b'~';
        // Everything up to and including the digits, i.e. `...{N`.
        result.extend_from_slice(&buf[copied..close]);
        if size <= max_plus_size && (!is_literal8 || allow_literal8) {
            result.extend_from_slice(b"+}\r\n");
        } else {
            result.extend_from_slice(b"}\r\n");
        }
        // Copy the body verbatim; a declared size running past the end of `buf`
        // is clamped to what is actually present.
        let body_start = close + 3;
        let body_end = body_start
            .checked_add(size)
            .map_or(buf.len(), |end| end.min(buf.len()));
        result.extend_from_slice(&buf[body_start..body_end]);
        copied = body_end;
        pos = body_end;
    }
    result.extend_from_slice(&buf[copied..]);
    result
}

/// If a literal marker `{N}\r\n` starts at `pos`, returns the index of its `}`
/// and the parsed size `N`. Returns `None` for no digits, a missing `}\r\n`, or a
/// size that overflows `usize`.
fn literal_marker_at(buf: &[u8], pos: usize) -> Option<(usize, usize)> {
    if buf.get(pos) != Some(&b'{') {
        return None;
    }
    let digits = &buf[pos + 1..];
    let digit_count = digits.iter().take_while(|b| b.is_ascii_digit()).count();
    if digit_count == 0 {
        return None;
    }
    let close = pos + 1 + digit_count;
    if buf.get(close..close + 3) != Some(&b"}\r\n"[..]) {
        return None;
    }
    let size = std::str::from_utf8(&digits[..digit_count]).ok()?.parse().ok()?;
    Some((close, size))
}
/// Returns a copy of `flags` without `Flag::Recent` and `Flag::Wildcard`
/// (presumably `\Recent` and `\*`, which are not valid STORE arguments),
/// preserving the order of the rest.
fn filter_store_flags(flags: &[Flag]) -> Vec<Flag> {
    let mut kept = Vec::with_capacity(flags.len());
    for flag in flags {
        if matches!(flag, Flag::Recent | Flag::Wildcard) {
            continue;
        }
        kept.push(flag.clone());
    }
    kept
}
/// Expands UID ranges into individual UIDs.
///
/// Returns `(uids, truncated)`:
/// * A range ending in `*` (represented as `u32::MAX`) contributes only its lower
///   bound and sets `truncated`.
/// * Expansion stops once 1,000,000 UIDs have been produced. The overflowing range
///   is partially expanded, later ranges are dropped, and `truncated` is set.
///
/// Reversed ranges (`5:3`) are treated like their ordered form (`3:5`), as
/// RFC 3501 defines `a:b` and `b:a` to denote the same set.
fn expand_uid_ranges(ranges: &[UidRange]) -> (Vec<u32>, bool) {
    const MAX_EXPANDED_UIDS: usize = 1_000_000;
    const STAR_SENTINEL: u32 = u32::MAX;
    let mut uids = Vec::new();
    let mut truncated = false;
    for range in ranges {
        let Some(end) = range.end else {
            uids.push(range.start);
            continue;
        };
        // Normalise the bounds: previously a reversed range produced no UIDs
        // because `start..=end` is empty when `start > end`.
        let (lo, hi) = if range.start <= end {
            (range.start, end)
        } else {
            (end, range.start)
        };
        if hi == STAR_SENTINEL {
            // `*` cannot be expanded without knowing the mailbox's highest UID.
            uids.push(lo);
            truncated = true;
            continue;
        }
        let count = (hi.saturating_sub(lo).saturating_add(1)) as usize;
        if uids.len().saturating_add(count) > MAX_EXPANDED_UIDS {
            warn!(
                start = range.start,
                end = end,
                "UID range too large to expand ({count} UIDs), \
                 truncating to {MAX_EXPANDED_UIDS} total"
            );
            let remaining = MAX_EXPANDED_UIDS.saturating_sub(uids.len());
            if remaining > 0 {
                // `remaining` is at most MAX_EXPANDED_UIDS, so it always fits in u32.
                let span = u32::try_from(remaining - 1).unwrap_or(u32::MAX);
                let last = hi.min(lo.saturating_add(span));
                uids.extend(lo..=last);
            }
            truncated = true;
            break;
        }
        uids.extend(lo..=hi);
    }
    (uids, truncated)
}
fn build_selected_mailbox(
untagged: &[UntaggedResponse],
tagged: &TaggedResponse,
read_only: bool,
) -> SelectedMailbox {
let mut exists = 0;
let mut recent = 0;
let mut uid_validity: Option<u32> = None;
let mut uid_next = Option::None;
let mut flags = Vec::new();
let mut permanent_flags = Vec::new();
let mut highest_mod_seq = Option::None;
let mut no_mod_seq = false;
let mut unseen: Option<u32> = None;
let mut mailbox_id: Option<String> = None;
let mut uid_not_sticky = false;
let mut vanished = Vec::new();
let mut changed_messages = Vec::new();
let effective_responses = selected_mailbox_effective_responses(untagged);
for resp in effective_responses {
match resp {
UntaggedResponse::Exists(n) => exists = *n,
UntaggedResponse::Recent(n) => recent = *n,
UntaggedResponse::Flags(f) => flags.clone_from(f),
UntaggedResponse::Status {
code: Some(code), ..
} => {
extract_selected_code(
code,
&mut uid_validity,
&mut uid_next,
&mut permanent_flags,
&mut highest_mod_seq,
&mut no_mod_seq,
&mut unseen,
&mut mailbox_id,
&mut uid_not_sticky,
);
}
UntaggedResponse::Vanished {
earlier: true,
uids,
} => {
vanished.extend_from_slice(uids);
}
UntaggedResponse::Fetch(fetch) => {
changed_messages.push((**fetch).clone());
}
_ => {}
}
}
if let Some(code) = &tagged.code {
extract_selected_code(
code,
&mut uid_validity,
&mut uid_next,
&mut permanent_flags,
&mut highest_mod_seq,
&mut no_mod_seq,
&mut unseen,
&mut mailbox_id,
&mut uid_not_sticky,
);
}
SelectedMailbox {
exists,
recent,
uid_validity,
uid_next,
flags,
permanent_flags,
highest_mod_seq,
no_mod_seq,
unseen,
mailbox_id,
read_only,
uid_not_sticky,
vanished,
changed_messages,
}
}
/// Returns the suffix of `untagged` that follows the last status response with a
/// `[CLOSED]` code, or the whole slice if there is none.
///
/// Responses before a `[CLOSED]` presumably belong to the previously selected
/// mailbox (RFC 7162 CLOSED semantics), so they are excluded.
fn selected_mailbox_effective_responses(untagged: &[UntaggedResponse]) -> &[UntaggedResponse] {
    let is_closed_marker = |response: &UntaggedResponse| {
        matches!(
            response,
            UntaggedResponse::Status {
                code: Some(ResponseCode::Closed),
                ..
            }
        )
    };
    untagged
        .iter()
        .rposition(is_closed_marker)
        .map_or(untagged, |last_closed| &untagged[last_closed + 1..])
}
/// Stores the value of a SELECT-relevant response code into the matching output.
///
/// Codes that are irrelevant to mailbox selection are ignored. A `HIGHESTMODSEQ`
/// of `0` is treated the same as `NOMODSEQ`.
// One out-parameter per tracked field keeps callers free to hold the values in
// locals or struct fields alike.
#[allow(clippy::too_many_arguments)]
fn extract_selected_code(
    code: &ResponseCode,
    uid_validity: &mut Option<u32>,
    uid_next: &mut Option<u32>,
    permanent_flags: &mut Vec<Flag>,
    highest_mod_seq: &mut Option<u64>,
    no_mod_seq: &mut bool,
    unseen: &mut Option<u32>,
    mailbox_id: &mut Option<String>,
    uid_not_sticky: &mut bool,
) {
    match code {
        ResponseCode::UidValidity(validity) => *uid_validity = Some(*validity),
        ResponseCode::UidNext(next) => *uid_next = Some(*next),
        ResponseCode::PermanentFlags(list) => permanent_flags.clone_from(list),
        ResponseCode::HighestModSeq(0) | ResponseCode::NoModSeq => *no_mod_seq = true,
        ResponseCode::HighestModSeq(modseq) => *highest_mod_seq = Some(*modseq),
        ResponseCode::Unseen(first_unseen) => *unseen = Some(*first_unseen),
        ResponseCode::MailboxId(id) => *mailbox_id = Some(id.clone()),
        ResponseCode::UidNotSticky => *uid_not_sticky = true,
        _ => {}
    }
}
/// Reports whether `info` fails any LIST selection option in `selection_options`
/// (matched case-insensitively):
/// * `SUBSCRIBED` fails unless the mailbox has the `Subscribed` attribute, or
///   `RECURSIVEMATCH` is also given and `info.child_info` is non-empty;
/// * `REMOTE` fails unless the mailbox has the `Remote` attribute;
/// * `SPECIAL-USE` fails unless some attribute is a special-use attribute.
///
/// Options not listed above never cause a mismatch.
pub(super) fn is_notify_selection_mismatch(info: &MailboxInfo, selection_options: &[&str]) -> bool {
    let recursive_match = selection_options
        .iter()
        .any(|option| option.eq_ignore_ascii_case("RECURSIVEMATCH"));
    let is_subscribed = info
        .attributes
        .iter()
        .any(|attr| matches!(attr, MailboxAttribute::Subscribed));
    let is_remote = info
        .attributes
        .iter()
        .any(|attr| matches!(attr, MailboxAttribute::Remote));
    let is_special_use = info.attributes.iter().any(MailboxAttribute::is_special_use);
    let subscribed_satisfied = is_subscribed || (recursive_match && !info.child_info.is_empty());

    selection_options.iter().any(|option| {
        (option.eq_ignore_ascii_case("SUBSCRIBED") && !subscribed_satisfied)
            || (option.eq_ignore_ascii_case("REMOTE") && !is_remote)
            || (option.eq_ignore_ascii_case("SPECIAL-USE") && !is_special_use)
    })
}
/// Decides whether a LIST response should be surfaced as a mailbox event.
///
/// It is surfaced when it carries an old name (presumably a rename), or, when
/// `filter_extended_markers` is set, when the mailbox is marked `NonExistent` or
/// `NoAccess`.
pub(super) fn is_notify_list_event(info: &MailboxInfo, filter_extended_markers: bool) -> bool {
    let has_old_name = info.old_name.is_some();
    let has_extended_marker = || {
        info.attributes.iter().any(|attr| {
            matches!(
                attr,
                MailboxAttribute::NonExistent | MailboxAttribute::NoAccess
            )
        })
    };
    has_old_name || (filter_extended_markers && has_extended_marker())
}