use std::sync::Arc;
use bytes::BytesMut;
use tokio::sync::{mpsc, oneshot, watch};
use tracing::{debug, trace, warn};
use crate::codec::classification::{self, ClassificationContext, SolicitationRule};
use crate::codec::encode::{encode_command, EncodeOptions, LiteralMode};
use crate::error::Error;
use crate::types::response::{Capability, ResponseCode, StatusKind, UntaggedResponse};
use crate::types::validated::MailboxName;
use crate::types::Command;
use super::dispatch::{
CapabilityConsumer, Consumer, ConsumerContext, ContinuationConsumer, ContinuationReply,
Finalized, TaggedOkConsumer,
};
use super::typed_event::TypedEvent;
use super::NotifyFlags;
pub(super) type PipelineResults = Vec<Result<Box<dyn std::any::Any + Send>, Error>>;
/// A request sent to the background [`driver_task`] over its `mpsc` channel.
///
/// Every variant carries a `oneshot` sender through which the driver reports
/// the outcome once the request has been fully processed (send failures are
/// ignored by the driver, e.g. when the requester has gone away).
pub(super) enum DriverCommand {
    /// Run a single command, routing its solicited untagged responses to
    /// `consumer`.
    Run {
        payload: DriverCommandPayload,
        consumer: DriverConsumer,
        // Receives the consumer's type-erased output, or the error.
        result_tx: oneshot::Sender<Result<Box<dyn std::any::Any + Send>, Error>>,
    },
    /// Perform a stream upgrade (STARTTLS or COMPRESS). On success the result
    /// carries a boxed `()`.
    Upgrade {
        payload: UpgradePayload,
        result_tx: oneshot::Sender<Result<Box<dyn std::any::Any + Send>, Error>>,
    },
    /// Send several commands pipelined; `consumers[i]` handles `commands[i]`
    /// and the results come back in the same order.
    Pipeline {
        commands: Vec<Command>,
        consumers: Vec<Box<dyn ConsumerErased>>,
        result_tx: oneshot::Sender<Result<PipelineResults, Error>>,
    },
    /// Apply keepalive settings; forwarded to `WireReader::set_keepalive`.
    SetKeepalive {
        keepalive: super::TcpKeepalive,
        result_tx: oneshot::Sender<Result<(), Error>>,
    },
    /// Enter IDLE. Resolving `done_rx` (or dropping its sender) makes the
    /// driver send DONE and finish the command.
    Idle {
        done_rx: oneshot::Receiver<()>,
        result_tx: oneshot::Sender<Result<IdleTermination, Error>>,
    },
}
/// How an IDLE command ended when it ended without an error.
#[derive(Debug)]
pub(super) enum IdleTermination {
    /// The client signalled done; DONE was sent and IDLE's tagged OK was read.
    ClientDone,
    /// The server sent IDLE's tagged OK before the client signalled done.
    ServerTerminated,
}
/// The command part of a [`DriverCommand::Run`] request.
pub(super) enum DriverCommandPayload {
    /// A structured command: the driver allocates its tag and encodes it for
    /// the negotiated capabilities.
    Standard(Command),
    /// A command already encoded by the caller; the bytes are sent as-is.
    PreBuilt {
        // Complete wire form of the command.
        wire_bytes: BytesMut,
        // Tag carried inside `wire_bytes`; used to recognise the completion.
        tag: String,
        // Used to classify untagged responses as solicited or unsolicited.
        cmd_kind: crate::types::CommandKind,
        // Mailbox the command targets, if any (also used for classification).
        cmd_target: Option<MailboxName>,
    },
}
/// Which stream upgrade a [`DriverCommand::Upgrade`] performs.
pub(super) enum UpgradePayload {
    /// Issue STARTTLS, then run a TLS handshake over the plain TCP stream.
    StartTls {
        // rustls client configuration used to build the `TlsConnector`.
        tls_config: Arc<rustls::ClientConfig>,
        // Passed to `TlsConnector::connect` for the handshake.
        server_name: rustls_pki_types::ServerName<'static>,
    },
    /// Issue COMPRESS and wrap the current stream in a `CompressedStream`.
    Compress,
}
/// Object-safe form of [`Consumer`] whose typed output is boxed as
/// `dyn Any + Send`, so commands with different output types can be driven
/// by the same code. A blanket impl covers every `'static` `Consumer`.
pub(super) trait ConsumerErased: Send {
    /// Receives an untagged response that was classified as belonging to this
    /// command, together with the NOTIFY flags in effect before it was applied.
    fn on_response(
        &mut self,
        resp: UntaggedResponse,
        notify_snapshot: NotifyFlags,
        ctx: &ConsumerContext,
    );
    /// Consumes the consumer on the command's tagged completion, returning its
    /// boxed output plus any responses it hands back to be emitted as events.
    fn finalize_erased(
        self: Box<Self>,
        tagged: crate::types::response::TaggedResponse,
        ctx: &ConsumerContext,
    ) -> Result<Finalized<Box<dyn std::any::Any + Send>>, Error>;
}
/// Type-erasing adapter: any `'static` [`Consumer`] can be driven as a
/// `Box<dyn ConsumerErased>`.
impl<C> ConsumerErased for C
where
    C: Consumer + 'static,
    C::Output: 'static,
{
    /// Forwards straight to [`Consumer::on_response`].
    fn on_response(
        &mut self,
        resp: UntaggedResponse,
        notify_snapshot: NotifyFlags,
        ctx: &ConsumerContext,
    ) {
        <C as Consumer>::on_response(self, resp, notify_snapshot, ctx);
    }

    /// Runs [`Consumer::finalize`] and boxes the typed output as `dyn Any`;
    /// the reclassified responses are passed through untouched.
    fn finalize_erased(
        self: Box<Self>,
        tagged: crate::types::response::TaggedResponse,
        ctx: &ConsumerContext,
    ) -> Result<Finalized<Box<dyn std::any::Any + Send>>, Error> {
        let Finalized {
            output,
            reclassified_as_events,
        } = <C as Consumer>::finalize(self, tagged, ctx)?;
        let erased = Box::new(output) as Box<dyn std::any::Any + Send>;
        Ok(Finalized {
            output: erased,
            reclassified_as_events,
        })
    }
}
/// Object-safe form of [`ContinuationConsumer`]: an erased consumer that can
/// also answer server continuation requests (`+`). A blanket impl covers
/// every `'static` `ContinuationConsumer`.
pub(super) trait ContinuationConsumerErased: ConsumerErased {
    /// Produces the bytes to write in reply to a continuation request.
    fn on_continuation_erased(
        &mut self,
        cont: crate::types::response::ContinuationRequest,
        ctx: &ConsumerContext,
    ) -> Result<ContinuationReply, Error>;
}
/// Type-erasing adapter for consumers that answer continuation requests.
impl<C> ContinuationConsumerErased for C
where
    C: ContinuationConsumer + 'static,
    C::Output: 'static,
{
    /// Delegates to [`ContinuationConsumer::on_continuation`] unchanged.
    fn on_continuation_erased(
        &mut self,
        cont: crate::types::response::ContinuationRequest,
        ctx: &ConsumerContext,
    ) -> Result<ContinuationReply, Error> {
        let reply = <C as ContinuationConsumer>::on_continuation(self, cont, ctx);
        reply
    }
}
/// The consumer attached to a single [`DriverCommand::Run`] request.
pub(super) enum DriverConsumer {
    /// For commands that never expect a continuation request; receiving one
    /// is reported as a protocol error.
    Regular(Box<dyn ConsumerErased>),
    /// For commands whose consumer answers continuation requests.
    WithContinuations(Box<dyn ContinuationConsumerErased>),
}
impl DriverConsumer {
fn on_response(
&mut self,
resp: UntaggedResponse,
notify_snapshot: NotifyFlags,
ctx: &ConsumerContext,
) {
match self {
Self::Regular(c) => c.on_response(resp, notify_snapshot, ctx),
Self::WithContinuations(c) => c.on_response(resp, notify_snapshot, ctx),
}
}
fn finalize_erased(
self,
tagged: crate::types::response::TaggedResponse,
ctx: &ConsumerContext,
) -> Result<Finalized<Box<dyn std::any::Any + Send>>, Error> {
match self {
Self::Regular(c) => c.finalize_erased(tagged, ctx),
Self::WithContinuations(c) => c.finalize_erased(tagged, ctx),
}
}
fn on_continuation(
&mut self,
cont: crate::types::response::ContinuationRequest,
ctx: &ConsumerContext,
) -> Result<ContinuationReply, Error> {
match self {
Self::Regular(_) => Err(Error::Protocol(
"unexpected continuation during command that does not expect one".into(),
)),
Self::WithContinuations(c) => c.on_continuation_erased(cont, ctx),
}
}
}
/// Connection state that the driver publishes on its `watch` channel after
/// each processed command (except keepalive changes).
#[derive(Debug, Clone)]
pub(super) struct ConnectionStateSnapshot {
    /// Current session state of the protocol state machine.
    pub session_state: super::SessionState,
    /// Capabilities currently known for the server.
    pub capabilities: Vec<Capability>,
    /// Names of enabled extensions (compared case-insensitively elsewhere,
    /// e.g. against `UTF8=ACCEPT` and `IMAP4rev2`).
    pub enabled: Vec<String>,
}
impl Default for ConnectionStateSnapshot {
    /// State of a fresh connection: not authenticated, with no capabilities
    /// known and nothing enabled yet.
    fn default() -> Self {
        let session_state = super::SessionState::NotAuthenticated;
        Self {
            session_state,
            capabilities: Vec::default(),
            enabled: Vec::default(),
        }
    }
}
pub(super) async fn driver_task(
mut wire_reader: super::wire::WireReader,
mut state: super::state::ProtocolState,
mut tag_gen: super::tag::TagGenerator,
mut cmd_rx: mpsc::Receiver<DriverCommand>,
state_tx: watch::Sender<ConnectionStateSnapshot>,
mut event_sink: event_sink::DriverEventSink,
) {
loop {
let _ = event_sink.drain_pending_nonblocking();
tokio::select! {
biased;
maybe_cmd = cmd_rx.recv() => {
let Some(cmd) = maybe_cmd else { break; };
match cmd {
DriverCommand::Run { payload, consumer, result_tx } => {
let result = match payload {
DriverCommandPayload::Standard(command) => {
run_one_command(
&mut wire_reader,
&mut state,
&mut tag_gen,
&mut event_sink,
command,
consumer,
).await
}
DriverCommandPayload::PreBuilt {
wire_bytes, tag, cmd_kind, cmd_target,
} => {
run_prebuilt_command(
&mut wire_reader,
&mut state,
&mut event_sink,
wire_bytes,
&tag,
cmd_kind,
cmd_target,
consumer,
).await
}
};
let _ = result_tx.send(result);
}
DriverCommand::Upgrade { payload, result_tx } => {
let result = run_upgrade(
&mut wire_reader,
&mut state,
&mut tag_gen,
&mut event_sink,
payload,
).await;
let _ = result_tx.send(result.map(|()| {
Box::new(()) as Box<dyn std::any::Any + Send>
}));
}
DriverCommand::Pipeline { commands, consumers, result_tx } => {
let result = run_pipeline(
&mut wire_reader,
&mut state,
&mut tag_gen,
&mut event_sink,
commands,
consumers,
).await;
let _ = result_tx.send(result);
}
DriverCommand::SetKeepalive { keepalive, result_tx } => {
let result = wire_reader.set_keepalive(&keepalive);
let _ = result_tx.send(result);
continue;
}
DriverCommand::Idle { done_rx, result_tx } => {
let result = run_idle(
&mut wire_reader,
&mut state,
&mut tag_gen,
&mut event_sink,
done_rx,
).await;
let _ = result_tx.send(result);
}
}
let _ = state_tx.send_replace(state.snapshot());
if state.session_state() == super::SessionState::Logout {
break;
}
}
}
}
let _ = logout_best_effort(&mut wire_reader, &mut state, &mut tag_gen).await;
}
/// Sends `cmd` under a freshly allocated tag and drives it to completion.
///
/// Before sending, the protocol state is told which state-changing command is
/// in flight (LOGOUT, LOGIN/AUTHENTICATE, SELECT/EXAMINE, CLOSE/UNSELECT,
/// UNAUTHENTICATE, NOTIFY SET/NONE), so that `apply_side_effects` can act on
/// the responses. Responses are then handled by [`drive_to_completion`].
/// Continuation requests are forwarded to `consumer`; a
/// [`DriverConsumer::Regular`] consumer rejects them.
///
/// # Errors
/// Encoding or I/O failures, protocol violations (foreign tag, greeting
/// mid-command, unexpected continuation) and errors returned by the consumer.
pub(in crate::connection) async fn run_one_command(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
    cmd: Command,
    consumer: DriverConsumer,
) -> Result<Box<dyn std::any::Any + Send>, Error> {
    let cmd_kind = cmd.kind();
    let cmd_target: Option<MailboxName> = cmd.mailbox_target().cloned();
    if matches!(cmd, Command::Logout) {
        state.set_in_logout(true);
    }
    if matches!(cmd, Command::Login { .. } | Command::Authenticate { .. }) {
        state.set_in_auth(true);
    }
    if matches!(cmd, Command::Select { .. } | Command::Examine { .. }) {
        state.set_in_select(cmd_target.clone());
    }
    if matches!(cmd, Command::Close | Command::Unselect) {
        state.set_in_close(true);
    }
    if matches!(cmd, Command::Unauthenticate) {
        state.set_in_unauthenticate(true);
    }
    if let Command::NotifySet(ref params) = cmd {
        let (list, status, metadata) = super::extensions::compute_notify_flags(params);
        state.set_in_notify_set(Some(super::NotifyFlags {
            list,
            status,
            metadata,
        }));
    }
    if matches!(cmd, Command::NotifyNone) {
        state.set_in_notify_set(Some(super::NotifyFlags::default()));
    }
    let tag = send_command_on_wire(wire_reader, state, tag_gen, event_sink, &cmd).await?;
    drive_to_completion(
        wire_reader,
        state,
        event_sink,
        &tag,
        cmd_kind,
        cmd_target.as_ref(),
        consumer,
        None,
    )
    .await
}

/// Sends caller-encoded `wire_bytes` (already carrying `tag`) and drives the
/// command to completion.
///
/// Literals in `wire_bytes` are synchronized via `send_with_literal_sync`.
/// Once the bytes are fully sent, any further continuation request is a
/// protocol error.
///
/// # Errors
/// I/O failures, protocol violations and errors returned by the consumer.
#[allow(clippy::too_many_arguments)]
pub(in crate::connection) async fn run_prebuilt_command(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    event_sink: &mut event_sink::DriverEventSink,
    wire_bytes: BytesMut,
    tag: &str,
    cmd_kind: crate::types::CommandKind,
    cmd_target: Option<MailboxName>,
    consumer: DriverConsumer,
) -> Result<Box<dyn std::any::Any + Send>, Error> {
    trace!(tag, ?cmd_kind, "driver: sending pre-built command");
    send_with_literal_sync(wire_reader, state, event_sink, &wire_bytes).await?;
    drive_to_completion(
        wire_reader,
        state,
        event_sink,
        tag,
        cmd_kind,
        cmd_target.as_ref(),
        consumer,
        Some("unexpected continuation after pre-built command fully sent"),
    )
    .await
}

/// Shared response loop for a single in-flight command.
///
/// Reads responses until the tagged completion for `tag` arrives, then
/// finalizes `consumer` and returns its output.
/// * Untagged responses are classified against `cmd_kind`/`cmd_target`.
///   Solicited ones go to the consumer; unsolicited ones become events,
///   unless an ALERT/NOTIFICATIONOVERFLOW event was already emitted for them.
/// * Continuation requests fail with `reject_continuation` as the protocol
///   error message when it is `Some`. Otherwise the consumer answers them
///   and its reply is written to the wire.
/// * A tagged response for another tag, or a greeting, is a protocol error.
#[allow(clippy::too_many_arguments)]
async fn drive_to_completion(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    event_sink: &mut event_sink::DriverEventSink,
    tag: &str,
    cmd_kind: crate::types::CommandKind,
    cmd_target: Option<&MailboxName>,
    mut consumer: DriverConsumer,
    reject_continuation: Option<&'static str>,
) -> Result<Box<dyn std::any::Any + Send>, Error> {
    loop {
        // Captured before `apply_side_effects`, so classification and the
        // consumer see the NOTIFY flags as they were before this response.
        let notify_before = state.notify();
        let utf8 = utf8_mode(state);
        let resp = wire_reader.read_one(utf8).await?;
        let _digest = state.apply_side_effects(&resp);
        match resp {
            crate::types::Response::Tagged(t) if t.tag == tag => {
                emit_tagged_response_code_events(&t, event_sink);
                let ctx = build_consumer_context(state, cmd_target, tag);
                let finalized = consumer.finalize_erased(t, &ctx)?;
                for resp in finalized.reclassified_as_events {
                    // Non-BYE ALERT/NOTIFICATIONOVERFLOW responses already
                    // produced their event when they were first received.
                    if !has_critical_response_code(&resp) {
                        let _ = event_sink.emit(resp.into());
                    }
                }
                return Ok(finalized.output);
            }
            crate::types::Response::Tagged(t) => {
                return Err(Error::Protocol(format!(
                    "unexpected tag {:?} (expected {:?})",
                    t.tag, tag,
                )));
            }
            crate::types::Response::Untagged(u) => {
                let code_emitted = emit_untagged_response_code_events(&u, event_sink);
                let class_ctx = ClassificationContext {
                    notify: notify_before,
                    command_target: cmd_target,
                };
                match classification::classify(cmd_kind, &u, &class_ctx) {
                    SolicitationRule::OnlySolicited | SolicitationRule::Either => {
                        let ctx = build_consumer_context(state, cmd_target, tag);
                        consumer.on_response(*u, notify_before, &ctx);
                    }
                    SolicitationRule::OnlyUnsolicited | SolicitationRule::Impossible => {
                        if !code_emitted {
                            let _ = event_sink.emit((*u).into());
                        }
                    }
                }
            }
            crate::types::Response::Continuation(c) => {
                if let Some(message) = reject_continuation {
                    return Err(Error::Protocol(message.into()));
                }
                let ctx = build_consumer_context(state, cmd_target, tag);
                let ContinuationReply::Write(bytes) = consumer.on_continuation(c, &ctx)?;
                wire_reader.write_all(&bytes).await?;
            }
            crate::types::Response::Greeting(_) => {
                return Err(Error::Protocol("unexpected greeting mid-command".into()));
            }
        }
    }
}
type SubBatchEntry = (usize, Command, Box<dyn ConsumerErased>);
/// Splits a pipeline into ordered sub-batches in which no two commands share
/// a `CommandKind`.
///
/// Each entry keeps its original position so results can be scattered back.
/// A command joins the first batch, at or after the batch of the previous
/// command, that does not yet contain its kind; a new batch is opened if none
/// qualifies. Commands therefore never move to an earlier batch than their
/// predecessors, which preserves relative order across batches.
fn group_into_sub_batches(
    commands: Vec<Command>,
    consumers: Vec<Box<dyn ConsumerErased>>,
) -> Vec<Vec<SubBatchEntry>> {
    use std::collections::HashSet;

    let mut batches: Vec<(HashSet<crate::types::CommandKind>, Vec<SubBatchEntry>)> =
        vec![(HashSet::new(), Vec::new())];
    // Lowest batch index the next command may join; never decreases.
    let mut floor = 0usize;
    for (original_idx, (cmd, consumer)) in commands.into_iter().zip(consumers).enumerate() {
        let kind = cmd.kind();
        let free_slot = batches[floor..]
            .iter()
            .position(|(seen, _)| !seen.contains(&kind));
        let target = match free_slot {
            Some(offset) => floor + offset,
            None => {
                batches.push((HashSet::new(), Vec::new()));
                batches.len() - 1
            }
        };
        let (seen, entries) = &mut batches[target];
        seen.insert(kind);
        floor = target;
        entries.push((original_idx, cmd, consumer));
    }
    batches.into_iter().map(|(_, entries)| entries).collect()
}
async fn run_pipeline(
wire_reader: &mut super::wire::WireReader,
state: &mut super::state::ProtocolState,
tag_gen: &mut super::tag::TagGenerator,
event_sink: &mut event_sink::DriverEventSink,
commands: Vec<Command>,
consumers: Vec<Box<dyn ConsumerErased>>,
) -> Result<PipelineResults, Error> {
let count = commands.len();
if count == 0 {
return Ok(Vec::new());
}
let has_duplicates = {
let mut seen = std::collections::HashSet::with_capacity(count);
commands.iter().any(|cmd| !seen.insert(cmd.kind()))
};
if !has_duplicates {
return run_pipeline_batch(wire_reader, state, tag_gen, event_sink, commands, consumers)
.await;
}
let sub_batches = group_into_sub_batches(commands, consumers);
let num_batches = sub_batches.len();
trace!(
count,
num_batches,
"driver: splitting pipeline into sub-batches"
);
let mut all_results: Vec<Option<Result<Box<dyn std::any::Any + Send>, Error>>> =
(0..count).map(|_| None).collect();
for entries in sub_batches {
let original_indices: Vec<usize> = entries.iter().map(|(idx, _, _)| *idx).collect();
let (batch_cmds, batch_consumers): (Vec<Command>, Vec<Box<dyn ConsumerErased>>) = entries
.into_iter()
.map(|(_, cmd, cons)| (cmd, cons))
.unzip();
let batch_results = run_pipeline_batch(
wire_reader,
state,
tag_gen,
event_sink,
batch_cmds,
batch_consumers,
)
.await?;
for (batch_pos, result) in batch_results.into_iter().enumerate() {
all_results[original_indices[batch_pos]] = Some(result);
}
}
Ok(all_results
.into_iter()
.map(|r| r.unwrap_or_else(|| Err(Error::Internal("missing pipeline result".into()))))
.collect())
}
/// Sends `commands` as one pipelined batch and collects one result per
/// command, in the same order as `commands`.
///
/// Every command gets its own tag and is encoded up front. How the batch is
/// written depends on the negotiated literal mode:
/// * `LiteralPlus`: each command is patched with
///   `patch_literals_to_plus_with_binary` and the whole batch is written with
///   a single `write_all`.
/// * `LiteralMinus`: each command is patched with
///   `patch_small_literals_to_plus_with_binary` and written on its own via
///   `send_with_literal_sync`, which waits for `+` before any literal it
///   still finds.
/// * `Synchronizing`: each command's segments are written via
///   `send_encoded_segments`.
///
/// Responses are then read until every tag has completed. An untagged
/// response is first classified against the earliest command still running
/// (the "head"). If it is not the head's, it goes to the first later running
/// command that classifies it as `OnlySolicited`; otherwise it becomes an
/// event.
///
/// NOTE(review): in the `LiteralMinus` and `Synchronizing` paths, later
/// commands are written after earlier ones are already on the wire, and
/// `wait_for_continuation` turns any tagged response into an error. If the
/// server completes an earlier command of this batch while a later command is
/// waiting for its literal continuation, that completion becomes a send
/// failure instead of being routed to its consumer. Confirm this cannot
/// happen for the commands pipelined here.
///
/// # Errors
/// * Encoding, I/O and protocol errors (unknown tag, continuation or greeting
///   in the response loop) abort the whole batch.
/// * Errors from an individual consumer's `finalize_erased` are stored in
///   that command's result slot.
#[allow(clippy::too_many_lines)]
async fn run_pipeline_batch(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
    commands: Vec<Command>,
    consumers: Vec<Box<dyn ConsumerErased>>,
) -> Result<PipelineResults, Error> {
    let count = commands.len();
    if count == 0 {
        return Ok(Vec::new());
    }
    let opts = build_encode_options(state);
    // Passed to the literal patchers: true when BINARY is advertised and
    // IMAP4rev2 semantics are not in effect.
    let allow_literal8 = state.capabilities().contains(&Capability::Binary) && !is_rev2(state);
    // Per-command metadata, indexed by position in `commands`.
    let mut tags: Vec<String> = Vec::with_capacity(count);
    let mut kinds: Vec<crate::types::CommandKind> = Vec::with_capacity(count);
    let mut targets: Vec<Option<MailboxName>> = Vec::with_capacity(count);
    let mut encoded_commands = Vec::with_capacity(count);
    // Encode everything before writing, so an encoding error sends nothing.
    for cmd in &commands {
        let tag = tag_gen.next();
        let kind = cmd.kind();
        let target = cmd.mailbox_target().cloned();
        let encoded = encode_command(&tag, cmd, &opts)?;
        tags.push(tag);
        kinds.push(kind);
        targets.push(target);
        encoded_commands.push(encoded);
    }
    trace!(count, "driver: sending pipelined batch");
    match opts.literal_mode {
        LiteralMode::LiteralPlus => {
            let bufs: Vec<BytesMut> = encoded_commands
                .into_iter()
                .map(|e| {
                    let flat = e.into_buf();
                    super::patch_literals_to_plus_with_binary(&flat, allow_literal8)
                })
                .collect();
            // Concatenate so the whole batch goes out in one write.
            let total: usize = bufs.iter().map(BytesMut::len).sum();
            let mut batch = BytesMut::with_capacity(total);
            for buf in bufs {
                batch.extend_from_slice(&buf);
            }
            wire_reader.write_all(&batch).await?;
        }
        LiteralMode::LiteralMinus => {
            for encoded in encoded_commands {
                let flat = encoded.into_buf();
                let patched =
                    super::patch_small_literals_to_plus_with_binary(&flat, allow_literal8);
                send_with_literal_sync(wire_reader, state, event_sink, &patched).await?;
            }
        }
        LiteralMode::Synchronizing => {
            for encoded in encoded_commands {
                send_encoded_segments(wire_reader, state, event_sink, encoded.segments()).await?;
            }
        }
    }
    let mut tag_to_idx: std::collections::HashMap<String, usize> =
        std::collections::HashMap::with_capacity(count);
    for (i, tag) in tags.iter().enumerate() {
        tag_to_idx.insert(tag.clone(), i);
    }
    // `None` marks a command whose tagged completion has been processed.
    let mut consumers: Vec<Option<Box<dyn ConsumerErased>>> =
        consumers.into_iter().map(Some).collect();
    let mut results: Vec<Option<Result<Box<dyn std::any::Any + Send>, Error>>> =
        (0..count).map(|_| None).collect();
    let mut completed = 0usize;
    while completed < count {
        // Captured before `apply_side_effects` for classification below.
        let notify_before = state.notify();
        let utf8 = utf8_mode(state);
        let resp = wire_reader.read_one(utf8).await?;
        // NOTIFY SET / NOTIFY NONE: stage the new flags only when that
        // command's own tagged response arrives, just before
        // `apply_side_effects` processes it.
        if let crate::types::Response::Tagged(ref t) = resp {
            if let Some(&idx) = tag_to_idx.get(&t.tag) {
                match commands[idx] {
                    Command::NotifySet(ref params) => {
                        let (list, status, metadata) =
                            super::extensions::compute_notify_flags(params);
                        state.set_in_notify_set(Some(super::NotifyFlags {
                            list,
                            status,
                            metadata,
                        }));
                    }
                    Command::NotifyNone => {
                        state.set_in_notify_set(Some(super::NotifyFlags::default()));
                    }
                    _ => {}
                }
            }
        }
        let _digest = state.apply_side_effects(&resp);
        match resp {
            crate::types::Response::Tagged(t) => {
                emit_tagged_response_code_events(&t, event_sink);
                if let Some(&idx) = tag_to_idx.get(&t.tag) {
                    // A repeated completion for an already-finished tag is
                    // ignored (the consumer slot is already `None`).
                    if let Some(consumer) = consumers[idx].take() {
                        let ctx = build_consumer_context(state, targets[idx].as_ref(), &tags[idx]);
                        match consumer.finalize_erased(t, &ctx) {
                            Ok(finalized) => {
                                for ev in finalized.reclassified_as_events {
                                    // Non-BYE ALERT/NOTIFICATIONOVERFLOW
                                    // responses were already surfaced on receipt.
                                    if !has_critical_response_code(&ev) {
                                        let _ = event_sink.emit(ev.into());
                                    }
                                }
                                results[idx] = Some(Ok(finalized.output));
                            }
                            Err(e) => {
                                results[idx] = Some(Err(e));
                            }
                        }
                        completed += 1;
                    }
                } else {
                    return Err(Error::Protocol(format!(
                        "unknown tag in pipeline response: {:?}",
                        t.tag,
                    )));
                }
            }
            crate::types::Response::Untagged(u) => {
                let code_emitted = emit_untagged_response_code_events(&u, event_sink);
                // The head is the earliest command that has not completed yet.
                let head_idx = consumers.iter().position(Option::is_some);
                if let Some(idx) = head_idx {
                    let class_ctx = ClassificationContext {
                        notify: notify_before,
                        command_target: targets[idx].as_ref(),
                    };
                    let rule = classification::classify(kinds[idx], &u, &class_ctx);
                    match rule {
                        SolicitationRule::OnlySolicited | SolicitationRule::Either => {
                            let ctx =
                                build_consumer_context(state, targets[idx].as_ref(), &tags[idx]);
                            if let Some(ref mut consumer) = consumers[idx] {
                                consumer.on_response(*u, notify_before, &ctx);
                            }
                        }
                        SolicitationRule::OnlyUnsolicited | SolicitationRule::Impossible => {
                            // Not the head's: offer it to the first later running
                            // command that claims it as strictly solicited.
                            let mut u = Some(u);
                            for later in (idx + 1)..consumers.len() {
                                if consumers[later].is_none() {
                                    continue;
                                }
                                let claimed = match u {
                                    Some(ref inner) => {
                                        let later_ctx = ClassificationContext {
                                            notify: notify_before,
                                            command_target: targets[later].as_ref(),
                                        };
                                        matches!(
                                            classification::classify(
                                                kinds[later],
                                                inner,
                                                &later_ctx,
                                            ),
                                            SolicitationRule::OnlySolicited
                                        )
                                    }
                                    None => false,
                                };
                                if claimed {
                                    if let Some(taken) = u.take() {
                                        let ctx = build_consumer_context(
                                            state,
                                            targets[later].as_ref(),
                                            &tags[later],
                                        );
                                        if let Some(ref mut consumer) = consumers[later] {
                                            consumer.on_response(*taken, notify_before, &ctx);
                                        }
                                    }
                                    break;
                                }
                            }
                            // Nobody claimed it: surface it as an event.
                            if let Some(unclaimed) = u {
                                if !code_emitted {
                                    let _ = event_sink.emit((*unclaimed).into());
                                }
                            }
                        }
                    }
                } else {
                    if !code_emitted {
                        let _ = event_sink.emit((*u).into());
                    }
                }
            }
            crate::types::Response::Continuation(_) => {
                return Err(Error::Protocol(
                    "unexpected continuation in pipeline response loop".into(),
                ));
            }
            crate::types::Response::Greeting(_) => {
                return Err(Error::Protocol("unexpected greeting mid-pipeline".into()));
            }
        }
    }
    Ok(results
        .into_iter()
        .map(|r| r.unwrap_or_else(|| Err(Error::Internal("missing pipeline result".into()))))
        .collect())
}
/// Runs an IDLE command until either the client or the server ends it.
///
/// Sends IDLE and waits for the continuation request. It then repeatedly
/// flushes pending events without blocking and races `done_rx` against the
/// next server response; `done_rx` is polled first because of `biased`.
/// * When `done_rx` resolves (including when its sender is dropped), DONE is
///   sent, responses are drained up to IDLE's tagged completion, and
///   [`IdleTermination::ClientDone`] is returned.
/// * A tagged OK for the IDLE tag returns [`IdleTermination::ServerTerminated`];
///   a tagged NO/BAD becomes an error.
/// * Untagged responses become events, unless their response code already
///   produced one. A BYE aborts with `Error::bye_with_code`.
/// * Stray continuations are logged and ignored.
///
/// NOTE(review): when `done_rx` wins the race, the in-progress `read_one`
/// future is dropped. That is only lossless if `WireReader::read_one` is
/// cancellation-safe; confirm.
///
/// # Errors
/// * I/O and protocol errors, tagged NO/BAD, or BYE.
/// * Any failure while waiting for the initial continuation or draining
///   after DONE.
async fn run_idle(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
    done_rx: oneshot::Receiver<()>,
) -> Result<IdleTermination, Error> {
    let tag = send_command_on_wire(wire_reader, state, tag_gen, event_sink, &Command::Idle).await?;
    wait_for_continuation(wire_reader, state, event_sink).await?;
    trace!(tag, "driver: entered IDLE mode");
    let mut done_rx = done_rx;
    loop {
        let _ = event_sink.drain_pending_nonblocking();
        // `true` means the client asked to stop (or dropped the sender).
        let done_signaled = tokio::select! {
            biased;
            _ = &mut done_rx => true,
            result = wire_reader.read_one(utf8_mode(state)) => {
                let resp = result?;
                let digest = state.apply_side_effects(&resp);
                match resp {
                    crate::types::Response::Tagged(t) if t.tag == tag => {
                        emit_tagged_response_code_events(&t, event_sink);
                        match t.status {
                            StatusKind::Ok => {
                                trace!(tag, "driver: server terminated IDLE");
                                return Ok(IdleTermination::ServerTerminated);
                            }
                            StatusKind::No => {
                                return Err(Error::no_with_code(t.text, t.code));
                            }
                            StatusKind::Bad => {
                                return Err(Error::bad_with_code(t.text, t.code));
                            }
                        }
                    }
                    crate::types::Response::Tagged(t) => {
                        return Err(Error::Protocol(format!(
                            "unexpected tag {:?} during IDLE (expected {:?})",
                            t.tag, tag,
                        )));
                    }
                    crate::types::Response::Untagged(u) => {
                        let code_emitted =
                            emit_untagged_response_code_events(&u, event_sink);
                        if digest.had_bye {
                            // Extract the BYE text/code for the error.
                            let (text, code) = match *u {
                                UntaggedResponse::Status { text, code, .. } => {
                                    (text, code)
                                }
                                _ => (String::new(), None),
                            };
                            warn!(text, "received BYE during IDLE");
                            return Err(Error::bye_with_code(text, code));
                        }
                        if !code_emitted {
                            let _ = event_sink.emit((*u).into());
                        }
                    }
                    crate::types::Response::Continuation(_) => {
                        debug!(tag, "ignoring unexpected continuation during IDLE");
                    }
                    crate::types::Response::Greeting(_) => {
                        return Err(Error::Protocol(
                            "unexpected greeting during IDLE".into(),
                        ));
                    }
                }
                false
            }
        };
        if done_signaled {
            trace!(tag, "driver: sending DONE");
            wire_reader.write_all(b"DONE\r\n").await?;
            drain_idle_responses(wire_reader, state, event_sink, &tag).await?;
            trace!(tag, "driver: exited IDLE mode");
            return Ok(IdleTermination::ClientDone);
        }
    }
}
async fn drain_idle_responses(
wire_reader: &mut super::wire::WireReader,
state: &mut super::state::ProtocolState,
event_sink: &mut event_sink::DriverEventSink,
tag: &str,
) -> Result<(), Error> {
loop {
let utf8 = utf8_mode(state);
let resp = wire_reader.read_one(utf8).await?;
let digest = state.apply_side_effects(&resp);
match resp {
crate::types::Response::Tagged(t) if t.tag == tag => {
emit_tagged_response_code_events(&t, event_sink);
match t.status {
StatusKind::Ok => return Ok(()),
StatusKind::No => return Err(Error::no_with_code(t.text, t.code)),
StatusKind::Bad => return Err(Error::bad_with_code(t.text, t.code)),
}
}
crate::types::Response::Tagged(t) => {
return Err(Error::Protocol(format!(
"unexpected tag {:?} during IDLE drain (expected {:?})",
t.tag, tag,
)));
}
crate::types::Response::Untagged(u) => {
let code_emitted = emit_untagged_response_code_events(&u, event_sink);
if digest.had_bye {
let (text, code) = match *u {
UntaggedResponse::Status { text, code, .. } => (text, code),
_ => (String::new(), None),
};
warn!(text, "received BYE during IDLE drain");
return Err(Error::bye_with_code(text, code));
}
if !code_emitted {
let _ = event_sink.emit((*u).into());
}
}
crate::types::Response::Continuation(_) | crate::types::Response::Greeting(_) => {
}
}
}
}
/// Allocates the next tag, encodes `cmd` for the negotiated capabilities,
/// writes it to the wire and returns the tag.
///
/// * LITERAL+: the encoded buffer goes through
///   `patch_literals_to_plus_with_binary`.
/// * LITERAL-: it goes through `patch_small_literals_to_plus_with_binary`.
///
/// In both cases it is then sent by `send_with_literal_sync`. In
/// synchronizing mode the pre-split segments are sent with a continuation
/// wait between each.
async fn send_command_on_wire(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
    cmd: &Command,
) -> Result<String, Error> {
    let tag = tag_gen.next();
    trace!(tag, ?cmd, "driver: sending IMAP command");
    let opts = build_encode_options(state);
    let encoded = encode_command(&tag, cmd, &opts)?;
    let allow_literal8 = state.capabilities().contains(&Capability::Binary) && !is_rev2(state);
    let wire = match opts.literal_mode {
        LiteralMode::Synchronizing => {
            send_encoded_segments(wire_reader, state, event_sink, encoded.segments()).await?;
            return Ok(tag);
        }
        LiteralMode::LiteralPlus => {
            super::patch_literals_to_plus_with_binary(&encoded.into_buf(), allow_literal8)
        }
        LiteralMode::LiteralMinus => {
            super::patch_small_literals_to_plus_with_binary(&encoded.into_buf(), allow_literal8)
        }
    };
    send_with_literal_sync(wire_reader, state, event_sink, &wire).await?;
    Ok(tag)
}
async fn send_with_literal_sync(
wire_reader: &mut super::wire::WireReader,
state: &mut super::state::ProtocolState,
event_sink: &mut event_sink::DriverEventSink,
buf: &[u8],
) -> Result<(), Error> {
let mut pos = 0;
while pos < buf.len() {
if let Some((marker_end, literal_size)) = super::find_literal_boundary(&buf[pos..]) {
let send_end = pos + marker_end;
wire_reader.write_all(&buf[pos..send_end]).await?;
wait_for_continuation(wire_reader, state, event_sink).await?;
wire_reader
.write_all(&buf[send_end..send_end + literal_size])
.await?;
pos = send_end + literal_size;
} else {
wire_reader.write_all(&buf[pos..]).await?;
break;
}
}
Ok(())
}
/// Writes pre-split command segments, waiting for a continuation between
/// each pair of consecutive segments. The segments are presumably split at
/// synchronizing-literal markers by the encoder.
async fn send_encoded_segments(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    event_sink: &mut event_sink::DriverEventSink,
    segments: &[BytesMut],
) -> Result<(), Error> {
    let Some((last, leading)) = segments.split_last() else {
        return Ok(());
    };
    for segment in leading {
        wire_reader.write_all(segment).await?;
        wait_for_continuation(wire_reader, state, event_sink).await?;
    }
    wire_reader.write_all(last).await?;
    Ok(())
}
async fn wait_for_continuation(
wire_reader: &mut super::wire::WireReader,
state: &mut super::state::ProtocolState,
event_sink: &mut event_sink::DriverEventSink,
) -> Result<(), Error> {
loop {
let utf8 = utf8_mode(state);
let resp = wire_reader.read_one(utf8).await?;
let digest = state.apply_side_effects(&resp);
match resp {
crate::types::Response::Continuation(_) => return Ok(()),
crate::types::Response::Tagged(t) => {
emit_tagged_response_code_events(&t, event_sink);
return match t.status {
StatusKind::No => Err(Error::no_with_code(t.text, t.code)),
StatusKind::Bad => Err(Error::bad_with_code(t.text, t.code)),
StatusKind::Ok => Err(Error::Protocol(
"unexpected OK before literal continuation \
(RFC 3501 §4.3)"
.into(),
)),
};
}
crate::types::Response::Untagged(u) => {
let code_emitted = emit_untagged_response_code_events(&u, event_sink);
if digest.had_bye {
let (text, code) = match *u {
UntaggedResponse::Status { text, code, .. } => (text, code),
_ => (String::new(), None),
};
warn!(text, "received BYE during literal sync");
return Err(Error::bye_with_code(text, code));
}
if !code_emitted {
let _ = event_sink.emit((*u).into());
}
}
crate::types::Response::Greeting(_) => {
return Err(Error::Protocol("unexpected greeting".into()));
}
}
}
}
/// Performs the stream upgrade described by `payload` (COMPRESS or STARTTLS).
async fn run_upgrade(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
    payload: UpgradePayload,
) -> Result<(), Error> {
    match payload {
        UpgradePayload::Compress => {
            run_compress_upgrade(wire_reader, state, tag_gen, event_sink).await
        }
        UpgradePayload::StartTls {
            tls_config,
            server_name,
        } => {
            let upgrade = run_starttls_upgrade(
                wire_reader,
                state,
                tag_gen,
                event_sink,
                tls_config,
                server_name,
            );
            upgrade.await
        }
    }
}
/// Upgrades a plain TCP connection to TLS via STARTTLS.
///
/// Steps:
/// 1. Send STARTTLS and drive it with a `TaggedOkConsumer`.
/// 2. Refuse to continue if any bytes are already buffered after the tagged
///    response.
/// 3. Take the TCP stream out of the reader and run the TLS handshake with
///    `tls_config` / `server_name`.
/// 4. Clear the known capabilities and re-issue CAPABILITY over the encrypted
///    stream, storing its result.
///
/// If the buffer check, the TCP extraction or the handshake fails, the
/// reader is left as `ImapStream::Poisoned` and
/// `apply_infrastructure_failure` is recorded.
///
/// # Errors
/// * Errors from the STARTTLS or CAPABILITY commands.
/// * Leftover buffered bytes, or a stream that is not plain TCP.
/// * A handshake failure, wrapped as `Error::Io`.
/// * A CAPABILITY output of an unexpected type.
pub(in crate::connection) async fn run_starttls_upgrade(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
    tls_config: Arc<rustls::ClientConfig>,
    server_name: rustls_pki_types::ServerName<'static>,
) -> Result<(), Error> {
    let consumer =
        DriverConsumer::Regular(Box::new(TaggedOkConsumer::default()) as Box<dyn ConsumerErased>);
    run_one_command(
        wire_reader,
        state,
        tag_gen,
        event_sink,
        Command::StartTls,
        consumer,
    )
    .await?;
    // Plaintext bytes already buffered past the tagged response would be
    // processed as if they came over TLS; treat them as an attack signal.
    if !wire_reader.buffer_is_empty() {
        *wire_reader = super::wire::WireReader::new(super::ImapStream::Poisoned);
        state.apply_infrastructure_failure();
        return Err(Error::Protocol(
            "STARTTLS: unexpected bytes in buffer at upgrade boundary \
             (possible MITM — RFC 3501 Section 6.2.1)"
                .into(),
        ));
    }
    // Leave a poisoned reader in place until the TLS stream is ready, so
    // every early return below leaves the connection unusable.
    let old_reader = std::mem::replace(
        wire_reader,
        super::wire::WireReader::new(super::ImapStream::Poisoned),
    );
    let old_stream = old_reader.into_stream();
    let Some(tcp) = old_stream.into_tcp() else {
        state.apply_infrastructure_failure();
        return Err(Error::Protocol(
            "STARTTLS requires a plain TCP stream (already TLS or compressed)".into(),
        ));
    };
    let connector = tokio_rustls::TlsConnector::from(tls_config);
    let tls_stream = match connector.connect(server_name, tcp).await {
        Ok(s) => s,
        Err(e) => {
            state.apply_infrastructure_failure();
            return Err(Error::Io(Arc::new(std::io::Error::other(e))));
        }
    };
    *wire_reader = super::wire::WireReader::new(super::ImapStream::Tls(tls_stream));
    // Capabilities learned before TLS are discarded and fetched again.
    state.apply_capability_fetch(Vec::new());
    let cap_consumer =
        DriverConsumer::Regular(Box::new(CapabilityConsumer::default()) as Box<dyn ConsumerErased>);
    let result = run_one_command(
        wire_reader,
        state,
        tag_gen,
        event_sink,
        Command::Capability,
        cap_consumer,
    )
    .await?;
    let caps = result
        .downcast::<Vec<Capability>>()
        .map_err(|_| Error::Internal("CapabilityConsumer output downcast failed".into()))?;
    state.apply_capability_fetch(*caps);
    debug!("STARTTLS upgrade complete (RFC 3501 Section 6.2.1)");
    Ok(())
}
/// Activates COMPRESS=DEFLATE (RFC 4978) on the current connection.
///
/// Steps:
/// 1. Send COMPRESS and drive it with a `TaggedOkConsumer`.
/// 2. Take any bytes already buffered past the tagged response from the old
///    reader.
/// 3. Wrap the plain or TLS stream in a `CompressedStream`, seeding its raw
///    read buffer with those bytes. Per RFC 4978 they are presumably already
///    compressed.
///
/// On the error paths the reader is left as `ImapStream::Poisoned` and
/// `apply_infrastructure_failure` is recorded.
///
/// NOTE(review): the stream-type checks run only after the server has
/// already accepted COMPRESS, so on those error paths the server has switched
/// to compression while the client abandons the connection.
///
/// # Errors
/// * Errors from the COMPRESS command.
/// * A stream that is already compressed or poisoned.
/// * An in-memory test stream (test builds only).
async fn run_compress_upgrade(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
    event_sink: &mut event_sink::DriverEventSink,
) -> Result<(), Error> {
    let consumer =
        DriverConsumer::Regular(Box::new(TaggedOkConsumer::default()) as Box<dyn ConsumerErased>);
    run_one_command(
        wire_reader,
        state,
        tag_gen,
        event_sink,
        Command::Compress,
        consumer,
    )
    .await?;
    // Bytes read past the tagged OK belong to the compressed stream.
    let remaining = wire_reader.take_buffer();
    let old_reader = std::mem::replace(
        wire_reader,
        super::wire::WireReader::new(super::ImapStream::Poisoned),
    );
    let old_stream = old_reader.into_stream();
    let inner = match old_stream {
        super::ImapStream::Plain(tcp) => super::InnerStream::Plain(tcp),
        super::ImapStream::Tls(tls) => super::InnerStream::Tls(tls),
        super::ImapStream::Compressed(_) => {
            state.apply_infrastructure_failure();
            return Err(Error::Protocol(
                "COMPRESS=DEFLATE already active on this connection".into(),
            ));
        }
        super::ImapStream::Poisoned => {
            state.apply_infrastructure_failure();
            return Err(Error::Protocol(
                "stream poisoned — connection is dead".into(),
            ));
        }
        #[cfg(test)]
        super::ImapStream::Memory(_) => {
            state.apply_infrastructure_failure();
            return Err(Error::Protocol(
                "COMPRESS=DEFLATE not supported on in-memory test streams".into(),
            ));
        }
    };
    let mut compressed = super::CompressedStream::new(inner);
    if !remaining.is_empty() {
        compressed.raw_read_buf.extend_from_slice(&remaining);
    }
    *wire_reader = super::wire::WireReader::new(super::ImapStream::Compressed(compressed));
    debug!("COMPRESS=DEFLATE activated (RFC 4978)");
    Ok(())
}
/// Sends LOGOUT during shutdown, unless the session is already logged out.
///
/// It then reads until the first tagged response arrives (any tag) or
/// reading fails. The only caller, `driver_task`, ignores the result.
async fn logout_best_effort(
    wire_reader: &mut super::wire::WireReader,
    state: &mut super::state::ProtocolState,
    tag_gen: &mut super::tag::TagGenerator,
) -> Result<(), Error> {
    if state.session_state() == super::SessionState::Logout {
        return Ok(());
    }
    let tag = tag_gen.next();
    let logout_line = format!("{tag} LOGOUT\r\n");
    wire_reader.write_all(logout_line.as_bytes()).await?;
    loop {
        let resp = wire_reader.read_one(utf8_mode(state)).await?;
        let _digest = state.apply_side_effects(&resp);
        // Any tagged response (ours or a stray one) ends the wait; untagged
        // data, including the BYE, is consumed and ignored.
        if matches!(resp, crate::types::Response::Tagged(_)) {
            return Ok(());
        }
    }
}
/// Whether IMAP4rev2 semantics are in effect.
///
/// A server that advertises only IMAP4rev2 is rev2. One that advertises both
/// revisions counts as rev2 only once `IMAP4rev2` has been enabled
/// (case-insensitive match).
fn is_rev2(state: &super::state::ProtocolState) -> bool {
    let caps = state.capabilities();
    if !caps.contains(&Capability::Imap4Rev2) {
        return false;
    }
    !caps.contains(&Capability::Imap4Rev1)
        || state
            .enabled()
            .iter()
            .any(|e| e.eq_ignore_ascii_case("IMAP4rev2"))
}
/// Whether UTF-8 mode applies: `UTF8=ACCEPT` is enabled (case-insensitive)
/// or IMAP4rev2 semantics are in effect.
fn utf8_mode(state: &super::state::ProtocolState) -> bool {
    let utf8_enabled = state
        .enabled()
        .iter()
        .any(|ext| ext.eq_ignore_ascii_case("UTF8=ACCEPT"));
    utf8_enabled || is_rev2(state)
}
/// Picks the literal encoding mode.
///
/// LITERAL+ is preferred. LITERAL- applies when it is advertised or when
/// IMAP4rev2 is in effect. Otherwise all literals are synchronizing.
fn literal_mode(state: &super::state::ProtocolState) -> LiteralMode {
    let caps = state.capabilities();
    if caps.contains(&Capability::LiteralPlus) {
        return LiteralMode::LiteralPlus;
    }
    if caps.contains(&Capability::LiteralMinus) || is_rev2(state) {
        return LiteralMode::LiteralMinus;
    }
    LiteralMode::Synchronizing
}
/// Encoder settings derived from the current capabilities and enabled
/// extensions; the capability list is copied.
fn build_encode_options(state: &super::state::ProtocolState) -> EncodeOptions {
    let capabilities = state.capabilities().to_vec();
    EncodeOptions {
        capabilities,
        literal_mode: literal_mode(state),
        utf8_mode: utf8_mode(state),
    }
}
/// Borrowed view of connection state handed to a consumer for one command.
fn build_consumer_context<'a>(
    state: &'a super::state::ProtocolState,
    command_target: Option<&'a MailboxName>,
    command_tag: &'a str,
) -> ConsumerContext<'a> {
    let (capabilities, enabled) = (state.capabilities(), state.enabled());
    ConsumerContext {
        command_tag,
        command_target,
        enabled,
        capabilities,
    }
}
/// Emits a dedicated event for an untagged status response whose code is
/// `ALERT` or `NOTIFICATIONOVERFLOW`.
///
/// Returns `true` if such an event was emitted; callers then skip emitting
/// the raw response as an event.
///
/// NOTE(review): this returns `true` for BYE responses too, so an
/// unsolicited `BYE [ALERT]` yields only the Alert event. In contrast,
/// `has_critical_response_code` excludes BYE. Confirm which behavior is
/// intended.
fn emit_untagged_response_code_events(
    u: &UntaggedResponse,
    event_sink: &mut event_sink::DriverEventSink,
) -> bool {
    let UntaggedResponse::Status {
        code: Some(code),
        text,
        ..
    } = u
    else {
        return false;
    };
    let event = match code {
        ResponseCode::Alert => TypedEvent::Alert(text.clone()),
        ResponseCode::NotificationOverflow(detail) => TypedEvent::NotificationOverflow {
            code: detail.clone(),
            text: text.clone(),
        },
        _ => return false,
    };
    let _ = event_sink.emit(event);
    true
}
/// Emits an `Alert` or `NotificationOverflow` event when a tagged response
/// carries the corresponding response code; other codes are ignored.
fn emit_tagged_response_code_events(
    t: &crate::types::response::TaggedResponse,
    event_sink: &mut event_sink::DriverEventSink,
) {
    let event = match &t.code {
        Some(ResponseCode::Alert) => TypedEvent::Alert(t.text.clone()),
        Some(ResponseCode::NotificationOverflow(detail)) => TypedEvent::NotificationOverflow {
            code: detail.clone(),
            text: t.text.clone(),
        },
        _ => return,
    };
    let _ = event_sink.emit(event);
}
/// Whether `u` is a non-BYE untagged status response carrying `ALERT` or
/// `NOTIFICATIONOVERFLOW`.
///
/// Callers use this to avoid re-emitting reclassified responses whose event
/// was already produced on receipt. BYE responses never count, so they are
/// still emitted.
fn has_critical_response_code(u: &UntaggedResponse) -> bool {
    let UntaggedResponse::Status {
        status,
        code: Some(code),
        ..
    } = u
    else {
        return false;
    };
    let critical_code = matches!(
        code,
        ResponseCode::Alert | ResponseCode::NotificationOverflow(_)
    );
    critical_code && !matches!(status, crate::types::response::UntaggedStatus::Bye)
}
pub(super) mod event_sink;
#[cfg(test)]
#[path = "mod_tests.rs"]
mod tests;