use std::collections::VecDeque;
use tracing::warn;
use super::{NotifyFlags, SessionState};
use crate::error::Error;
use crate::types::response::{
Capability, GreetingResponse, GreetingStatus, ResponseCode, TaggedResponse, UntaggedResponse,
UntaggedStatus,
};
use crate::types::validated::MailboxName;
use crate::types::Response;
/// Client-side view of the IMAP session, updated as server responses are applied.
///
/// Besides the observable session data (state, capabilities, NOTIFY registration,
/// selected mailbox, enabled extensions), it holds one `in_*` marker per kind of
/// pending command. The tagged-response handling in `apply_side_effects` uses these
/// markers to decide which state change a tagged completion should trigger.
#[derive(Debug)]
// Each `in_*` marker has its own setter and is resolved independently when a
// tagged response arrives, so they are kept as separate fields; hence the allow.
#[allow(clippy::struct_excessive_bools)] pub(crate) struct ProtocolState {
/// Current session state; changed by the greeting, tagged completions of
/// pending commands, untagged BYE, and infrastructure failure.
state: SessionState,
/// Latest capability list: replaced from a greeting CAPABILITY code, untagged
/// CAPABILITY data, CAPABILITY response codes, or an explicit capability fetch.
capabilities: Vec<Capability>,
/// Active NOTIFY registration; reset to the default on NOTIFICATIONOVERFLOW or
/// a successful UNAUTHENTICATE.
notify: NotifyFlags,
/// Mailbox recorded when a pending select completes with OK; cleared by a
/// select that fails with NO, a successful CLOSE, or a successful UNAUTHENTICATE.
selected: Option<MailboxName>,
/// Value of `notify` captured before each untagged response is applied,
/// bounded to the most recent 1024 entries (oldest dropped first).
notify_history: VecDeque<NotifyFlags>,
/// Extensions reported by untagged ENABLED responses, de-duplicated
/// case-insensitively; cleared on a successful UNAUTHENTICATE.
enabled: Vec<String>,
/// Set while a LOGOUT is pending; a tagged OK moves the session to `Logout`.
in_logout: bool,
/// Set while authentication is pending; a tagged OK moves the session to
/// `Authenticated`.
in_auth: bool,
/// Mailbox of a pending select; taken by the next tagged response.
in_select: Option<MailboxName>,
/// Set while a CLOSE is pending; a tagged OK returns to `Authenticated`.
in_close: bool,
/// Flags of a pending NOTIFY SET; committed to `notify` on a tagged OK.
in_notify_set: Option<NotifyFlags>,
/// Set while UNAUTHENTICATE is pending; a tagged OK resets to `NotAuthenticated`.
in_unauthenticate: bool,
}
impl ProtocolState {
pub(crate) fn new() -> Self {
Self {
state: SessionState::NotAuthenticated,
capabilities: Vec::new(),
notify: NotifyFlags::default(),
selected: None,
notify_history: VecDeque::new(),
enabled: Vec::new(),
in_logout: false,
in_auth: false,
in_select: None,
in_close: false,
in_notify_set: None,
in_unauthenticate: false,
}
}
pub(crate) fn session_state(&self) -> SessionState {
self.state
}
pub(crate) fn capabilities(&self) -> &[Capability] {
&self.capabilities
}
pub(crate) fn notify(&self) -> NotifyFlags {
self.notify
}
pub(crate) fn enabled(&self) -> &[String] {
&self.enabled
}
pub(in crate::connection) fn snapshot(&self) -> super::driver::ConnectionStateSnapshot {
super::driver::ConnectionStateSnapshot {
session_state: self.state,
capabilities: self.capabilities.clone(),
enabled: self.enabled.clone(),
}
}
pub(in crate::connection) fn apply_greeting(
&mut self,
g: &GreetingResponse,
) -> Result<Option<String>, Error> {
match g.status {
GreetingStatus::Ok => {
self.state = SessionState::NotAuthenticated;
}
GreetingStatus::PreAuth => {
self.state = SessionState::Authenticated;
}
GreetingStatus::Bye => {
self.state = SessionState::Logout;
return Err(Error::bye_with_code(g.text.clone(), g.code.clone()));
}
}
if let Some(ResponseCode::Capability(caps)) = &g.code {
self.capabilities.clone_from(caps);
}
if g.code == Some(ResponseCode::Alert) {
return Ok(Some(g.text.clone()));
}
Ok(None)
}
pub(in crate::connection) fn apply_infrastructure_failure(&mut self) {
self.state = SessionState::Logout;
}
pub(in crate::connection) fn apply_capability_fetch(&mut self, caps: Vec<Capability>) {
self.capabilities = caps;
}
pub(crate) fn set_in_logout(&mut self, val: bool) {
self.in_logout = val;
}
pub(crate) fn set_in_auth(&mut self, val: bool) {
self.in_auth = val;
}
pub(crate) fn set_in_select(&mut self, mailbox: Option<MailboxName>) {
self.in_select = mailbox;
}
pub(crate) fn set_in_close(&mut self, val: bool) {
self.in_close = val;
}
pub(crate) fn set_in_notify_set(&mut self, flags: Option<NotifyFlags>) {
self.in_notify_set = flags;
}
pub(crate) fn set_in_unauthenticate(&mut self, val: bool) {
self.in_unauthenticate = val;
}
pub(crate) fn apply_side_effects(&mut self, resp: &Response) -> SideEffectDigest {
let mut digest = SideEffectDigest::default();
let notify_snapshot = self.notify;
match resp {
Response::Untagged(u) => {
self.apply_untagged(u, &mut digest);
self.notify_history.push_back(notify_snapshot);
if self.notify_history.len() > 1024 {
self.notify_history.pop_front();
}
}
Response::Tagged(t) => {
self.apply_tagged(t, &mut digest);
}
Response::Continuation(_) | Response::Greeting(_) => {}
}
digest
}
fn apply_untagged(&mut self, u: &UntaggedResponse, digest: &mut SideEffectDigest) {
match u {
UntaggedResponse::Capability(caps)
| UntaggedResponse::Status {
code: Some(ResponseCode::Capability(caps)),
..
} => {
self.capabilities.clone_from(caps);
}
UntaggedResponse::Status {
status: UntaggedStatus::Bye,
..
} => {
self.state = SessionState::Logout;
digest.had_bye = true;
}
UntaggedResponse::Status {
code: Some(ResponseCode::NotificationOverflow(_)),
..
} => {
warn!(
"server sent NOTIFICATIONOVERFLOW — NOTIFY registration \
cleared (RFC 5465 §5.8)"
);
self.notify = NotifyFlags::default();
self.in_notify_set = None;
digest.had_notification_overflow = true;
}
UntaggedResponse::Enabled(ref exts) => {
for ext in exts {
if !self.enabled.iter().any(|e| e.eq_ignore_ascii_case(ext)) {
self.enabled.push(ext.clone());
}
}
}
_ => {}
}
}
fn apply_tagged(&mut self, t: &TaggedResponse, digest: &mut SideEffectDigest) {
match &t.code {
Some(ResponseCode::Capability(caps)) => {
self.capabilities.clone_from(caps);
}
Some(ResponseCode::NotificationOverflow(_)) => {
warn!(
"tagged NOTIFICATIONOVERFLOW — NOTIFY registration \
cleared (RFC 5465 §5.8)"
);
self.notify = NotifyFlags::default();
self.in_notify_set = None;
digest.had_notification_overflow = true;
}
_ => {}
}
if self.in_logout && t.status == crate::types::response::StatusKind::Ok {
self.state = SessionState::Logout;
self.in_logout = false;
}
if self.in_auth && t.status == crate::types::response::StatusKind::Ok {
self.state = SessionState::Authenticated;
self.in_auth = false;
} else if self.in_auth {
self.in_auth = false;
}
if let Some(mailbox) = self.in_select.take() {
if t.status == crate::types::response::StatusKind::Ok {
self.state = SessionState::Selected;
self.selected = Some(mailbox);
} else if t.status == crate::types::response::StatusKind::No {
self.state = SessionState::Authenticated;
self.selected = None;
}
}
if self.in_close && t.status == crate::types::response::StatusKind::Ok {
self.state = SessionState::Authenticated;
self.selected = None;
self.in_close = false;
} else if self.in_close {
self.in_close = false;
}
if self.in_unauthenticate && t.status == crate::types::response::StatusKind::Ok {
self.state = SessionState::NotAuthenticated;
self.selected = None;
self.notify = NotifyFlags::default();
self.notify_history.clear();
self.enabled.clear();
self.in_notify_set = None;
self.in_unauthenticate = false;
} else if self.in_unauthenticate {
self.in_unauthenticate = false;
}
if let Some(pending_flags) = self.in_notify_set.take() {
if t.status == crate::types::response::StatusKind::Ok {
self.notify = pending_flags;
}
}
}
}
/// Notable events seen while applying a single server response.
///
/// Produced by `ProtocolState::apply_side_effects`; both flags start out `false`.
#[derive(Default, Debug)]
pub(crate) struct SideEffectDigest {
    /// True when a NOTIFICATIONOVERFLOW response code dropped the NOTIFY registration.
    pub(crate) had_notification_overflow: bool,
    /// True when an untagged BYE moved the session to `Logout`.
    pub(crate) had_bye: bool,
}
#[cfg(test)]
#[path = "state_tests.rs"]
mod tests;