use std::num::{ParseIntError, TryFromIntError};
#[cfg(feature = "bounded-static")]
use bounded_static::{IntoBoundedStatic, ToStatic};
use imap_types::{
auth::AuthenticateData,
command::Command,
core::{LiteralMode, Tag},
extensions::idle::IdleDone,
response::{Greeting, Response},
};
use nom::error::{ErrorKind, FromExternalError, ParseError};
use crate::{
auth::authenticate_data,
command::command,
extensions::idle::idle_done,
response::{greeting, response},
AuthenticateDataCodec, CommandCodec, GreetingCodec, IdleDoneCodec, ResponseCodec,
};
/// Result type for nom parsers that use [`IMAPParseError`] as their error type.
///
/// On success it yields `(remaining_input, output)`; on failure a `nom::Err` wrapping an
/// [`IMAPParseError`] (`Incomplete`, recoverable `Error`, or fatal `Failure`).
pub(crate) type IMAPResult<'a, I, O> = Result<(I, O), nom::Err<IMAPParseError<'a, I>>>;
/// Error produced by the crate's nom parsers.
///
/// Pairs the input at which the error was raised with an [`IMAPErrorKind`] describing it.
/// The decoders in this module only inspect `kind`.
#[derive(Debug)]
pub(crate) struct IMAPParseError<'a, I> {
    /// Input handed to the failing parser when the error was created.
    // Never read by the decoders (they match on `kind` only); it is kept so the `Debug`
    // output shows where parsing stopped, hence the `allow(unused)`.
    #[allow(unused)]
    pub input: I,
    /// What went wrong.
    pub kind: IMAPErrorKind<'a>,
}
/// Classification of an [`IMAPParseError`].
#[derive(Debug)]
pub(crate) enum IMAPErrorKind<'a> {
    /// Parsing stopped at a literal announcement (e.g. `{5}\r\n`).
    ///
    /// The decoders surface this as a `LiteralFound` error so the caller can react before the
    /// literal data is available.
    Literal {
        /// Tag of the surrounding command, if known. `CommandCodec` requires it to be `Some`;
        /// `ResponseCodec` ignores it.
        tag: Option<Tag<'a>>,
        /// Announced length of the literal (the number inside the braces).
        length: u32,
        /// Announced literal mode (`LiteralMode::Sync` for a plain `{5}`, per the tests).
        mode: LiteralMode,
    },
    /// A number failed to parse or convert (`ParseIntError` / `TryFromIntError`).
    BadNumber,
    /// Base64 data failed to decode (`base64::DecodeError`).
    BadBase64,
    /// NOTE(review): presumably an invalid date/time value; not raised in this file.
    BadDateTime,
    /// NOTE(review): presumably a literal containing a NUL byte; not raised in this file.
    LiteralContainsNull,
    /// NOTE(review): presumably a nesting/recursion limit was hit; not raised in this file.
    RecursionLimitExceeded,
    /// A plain nom error with the given [`ErrorKind`].
    Nom(ErrorKind),
}
/// Lets nom's built-in combinators produce [`IMAPParseError`]s.
impl<'a, I> ParseError<I> for IMAPParseError<'a, I> {
    /// Wraps a bare nom [`ErrorKind`] as [`IMAPErrorKind::Nom`].
    fn from_error_kind(input: I, kind: ErrorKind) -> Self {
        let kind = IMAPErrorKind::Nom(kind);

        Self { input, kind }
    }

    /// Builds a fresh `Nom` error for `input`/`kind`.
    ///
    /// The previously accumulated error (`_other`) is discarded entirely, including its kind.
    fn append(input: I, kind: ErrorKind, _other: Self) -> Self {
        <Self as ParseError<I>>::from_error_kind(input, kind)
    }
}
/// Maps a failed integer parse (`ParseIntError`) to [`IMAPErrorKind::BadNumber`].
///
/// The nom error kind and the underlying error value are dropped.
impl<'a, I> FromExternalError<I, ParseIntError> for IMAPParseError<'a, I> {
    fn from_external_error(input: I, _kind: ErrorKind, _cause: ParseIntError) -> Self {
        let kind = IMAPErrorKind::BadNumber;

        Self { input, kind }
    }
}
impl<'a, I> FromExternalError<I, TryFromIntError> for IMAPParseError<'a, I> {
fn from_external_error(input: I, _: ErrorKind, _: TryFromIntError) -> Self {
Self {
input,
kind: IMAPErrorKind::BadNumber,
}
}
}
impl<'a, I> FromExternalError<I, base64::DecodeError> for IMAPParseError<'a, I> {
fn from_external_error(input: I, _: ErrorKind, _: base64::DecodeError) -> Self {
Self {
input,
kind: IMAPErrorKind::BadBase64,
}
}
}
/// Decodes one IMAP message of a specific type from a byte buffer.
///
/// Implemented by the codec types (`GreetingCodec`, `CommandCodec`, `ResponseCodec`,
/// `AuthenticateDataCodec`, `IdleDoneCodec`).
pub trait Decoder {
    /// The decoded message; may borrow from the input for `'a`.
    type Message<'a>: Sized;
    /// The error returned when decoding does not succeed.
    type Error<'a>;

    /// Decodes a single message from the start of `input`.
    ///
    /// On success, returns the unconsumed remainder of `input` together with the message.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when `input` holds no complete, valid message. What that error
    /// distinguishes (e.g. "need more data" vs. "invalid") depends on the implementation.
    fn decode<'a>(&self, input: &'a [u8])
        -> Result<(&'a [u8], Self::Message<'a>), Self::Error<'a>>;

    /// Like [`Decoder::decode`], but converts the message and the error into their `'static`
    /// counterparts via `IntoBoundedStatic`, so they no longer borrow from `input`.
    ///
    /// The returned remainder still borrows from `input`.
    #[cfg(feature = "bounded-static")]
    #[cfg_attr(docsrs, doc(cfg(feature = "bounded-static")))]
    fn decode_static<'a>(
        &self,
        input: &'a [u8],
    ) -> Result<(&'a [u8], Self::Message<'static>), Self::Error<'static>>
    where
        Self::Message<'a>: IntoBoundedStatic<Static = Self::Message<'static>>,
        Self::Error<'a>: IntoBoundedStatic<Static = Self::Error<'static>>,
    {
        match self.decode(input) {
            Ok((remaining, message)) => Ok((remaining, message.into_static())),
            Err(error) => Err(error.into_static()),
        }
    }
}
/// Error returned by [`GreetingCodec`]'s [`Decoder::decode`].
#[cfg_attr(feature = "bounded-static", derive(ToStatic))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GreetingDecodeError {
    /// The input is a valid prefix of a greeting; more bytes are needed.
    Incomplete,
    /// The input is not a valid greeting (e.g. `* NO x\r\n`).
    Failed,
}
/// Error returned by [`CommandCodec`]'s [`Decoder::decode`].
#[cfg_attr(feature = "bounded-static", derive(ToStatic))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CommandDecodeError<'a> {
    /// The input is a valid prefix of a command; more bytes are needed.
    Incomplete,
    /// Decoding stopped at a literal announcement.
    ///
    /// For example, `a select {5}\r\n` (nothing after the CRLF) yields `tag = "a"`, `length = 5`,
    /// and `mode = LiteralMode::Sync`.
    LiteralFound {
        /// Tag of the command that contains the literal.
        tag: Tag<'a>,
        /// Announced length of the literal.
        length: u32,
        /// Announced literal mode.
        mode: LiteralMode,
    },
    /// The input is not a valid command.
    Failed,
}
/// Error returned by [`AuthenticateDataCodec`]'s [`Decoder::decode`].
#[cfg_attr(feature = "bounded-static", derive(ToStatic))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthenticateDataDecodeError {
    /// The input is a valid prefix of authenticate data; more bytes are needed.
    Incomplete,
    /// The input is not valid authenticate data (e.g. surrounding spaces).
    Failed,
}
/// Error returned by [`ResponseCodec`]'s [`Decoder::decode`].
#[cfg_attr(feature = "bounded-static", derive(ToStatic))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ResponseDecodeError {
    /// The input is a valid prefix of a response; more bytes are needed.
    Incomplete,
    /// Decoding stopped at a literal announcement, e.g. `* 1 FETCH (RFC822 {5}\r\n` yields
    /// `length = 5`.
    LiteralFound {
        /// Announced length of the literal.
        length: u32,
    },
    /// The input is not a valid response.
    Failed,
}
/// Error returned by [`IdleDoneCodec`]'s [`Decoder::decode`].
#[cfg_attr(feature = "bounded-static", derive(ToStatic))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum IdleDoneDecodeError {
    /// The input is a valid prefix of `done\r\n`; more bytes are needed.
    Incomplete,
    /// The input is not `done\r\n` (e.g. extra characters or spaces).
    Failed,
}
impl Decoder for GreetingCodec {
type Message<'a> = Greeting<'a>;
type Error<'a> = GreetingDecodeError;
fn decode<'a>(
&self,
input: &'a [u8],
) -> Result<(&'a [u8], Self::Message<'a>), Self::Error<'static>> {
match greeting(input) {
Ok((rem, grt)) => Ok((rem, grt)),
Err(nom::Err::Incomplete(_)) => Err(GreetingDecodeError::Incomplete),
Err(nom::Err::Failure(_)) | Err(nom::Err::Error(_)) => Err(GreetingDecodeError::Failed),
}
}
}
impl Decoder for CommandCodec {
type Message<'a> = Command<'a>;
type Error<'a> = CommandDecodeError<'a>;
fn decode<'a>(
&self,
input: &'a [u8],
) -> Result<(&'a [u8], Self::Message<'a>), Self::Error<'a>> {
match command(input) {
Ok((rem, cmd)) => Ok((rem, cmd)),
Err(nom::Err::Incomplete(_)) => Err(CommandDecodeError::Incomplete),
Err(nom::Err::Failure(error)) => match error {
IMAPParseError {
input: _,
kind: IMAPErrorKind::Literal { tag, length, mode },
} => Err(CommandDecodeError::LiteralFound {
tag: tag.expect("Expected `Some(tag)` in `IMAPErrorKind::Literal`, got `None`"),
length,
mode,
}),
_ => Err(CommandDecodeError::Failed),
},
Err(nom::Err::Error(_)) => Err(CommandDecodeError::Failed),
}
}
}
impl Decoder for ResponseCodec {
type Message<'a> = Response<'a>;
type Error<'a> = ResponseDecodeError;
fn decode<'a>(
&self,
input: &'a [u8],
) -> Result<(&'a [u8], Self::Message<'a>), Self::Error<'static>> {
match response(input) {
Ok((rem, rsp)) => Ok((rem, rsp)),
Err(nom::Err::Incomplete(_)) => Err(ResponseDecodeError::Incomplete),
Err(nom::Err::Error(error) | nom::Err::Failure(error)) => match error {
IMAPParseError {
kind: IMAPErrorKind::Literal { length, .. },
..
} => Err(ResponseDecodeError::LiteralFound { length }),
_ => Err(ResponseDecodeError::Failed),
},
}
}
}
impl Decoder for AuthenticateDataCodec {
type Message<'a> = AuthenticateData;
type Error<'a> = AuthenticateDataDecodeError;
fn decode<'a>(
&self,
input: &'a [u8],
) -> Result<(&'a [u8], Self::Message<'a>), Self::Error<'static>> {
match authenticate_data(input) {
Ok((rem, rsp)) => Ok((rem, rsp)),
Err(nom::Err::Incomplete(_)) => Err(AuthenticateDataDecodeError::Incomplete),
Err(nom::Err::Failure(_)) | Err(nom::Err::Error(_)) => {
Err(AuthenticateDataDecodeError::Failed)
}
}
}
}
impl Decoder for IdleDoneCodec {
type Message<'a> = IdleDone;
type Error<'a> = IdleDoneDecodeError;
fn decode<'a>(
&self,
input: &'a [u8],
) -> Result<(&'a [u8], Self::Message<'a>), Self::Error<'static>> {
match idle_done(input) {
Ok((rem, rsp)) => Ok((rem, rsp)),
Err(nom::Err::Incomplete(_)) => Err(IdleDoneDecodeError::Incomplete),
Err(nom::Err::Failure(_)) | Err(nom::Err::Error(_)) => Err(IdleDoneDecodeError::Failed),
}
}
}
#[cfg(test)]
mod tests {
    // Table-driven tests for every `Decoder` implementation above. Each table pairs an input
    // with the exact expected `decode` result. When the `bounded-static` feature is enabled,
    // `decode_static` is checked against the same expectation.
    use std::num::NonZeroU32;

    use imap_types::{
        command::{Command, CommandBody},
        core::{IString, Literal, NString, NonEmptyVec},
        extensions::idle::IdleDone,
        fetch::MessageDataItem,
        mailbox::Mailbox,
        response::{Data, Greeting, GreetingKind, Response},
        secret::Secret,
    };

    use super::*;

    /// Greetings: OK/BYE/PREAUTH in mixed case, trailing data left in the remainder,
    /// truncated prefixes, and invalid input.
    #[test]
    fn test_decode_greeting() {
        let tests = [
            (
                b"* OK ...\r\n".as_ref(),
                Ok((
                    b"".as_ref(),
                    Greeting::new(GreetingKind::Ok, None, "...").unwrap(),
                )),
            ),
            (
                b"* ByE .\r\n???".as_ref(),
                Ok((
                    b"???".as_ref(),
                    Greeting::new(GreetingKind::Bye, None, ".").unwrap(),
                )),
            ),
            (
                b"* preaUth x\r\n?".as_ref(),
                Ok((
                    b"?".as_ref(),
                    Greeting::new(GreetingKind::PreAuth, None, "x").unwrap(),
                )),
            ),
            // Strict prefixes of a valid greeting ask for more data.
            (b"*".as_ref(), Err(GreetingDecodeError::Incomplete)),
            (b"* ".as_ref(), Err(GreetingDecodeError::Incomplete)),
            (b"* O".as_ref(), Err(GreetingDecodeError::Incomplete)),
            (b"* OK".as_ref(), Err(GreetingDecodeError::Incomplete)),
            (b"* OK ".as_ref(), Err(GreetingDecodeError::Incomplete)),
            (b"* OK .".as_ref(), Err(GreetingDecodeError::Incomplete)),
            (b"* OK .\r".as_ref(), Err(GreetingDecodeError::Incomplete)),
            // Invalid input; `NO` is not an accepted greeting status.
            (b"**".as_ref(), Err(GreetingDecodeError::Failed)),
            (b"* NO x\r\n".as_ref(), Err(GreetingDecodeError::Failed)),
        ];

        for (test, expected) in tests {
            let got = GreetingCodec::default().decode(test);
            dbg!((std::str::from_utf8(test).unwrap(), &expected, &got));
            assert_eq!(expected, got);

            #[cfg(feature = "bounded-static")]
            {
                let got = GreetingCodec::default().decode_static(test);
                assert_eq!(expected, got);
            }
        }
    }

    /// Commands: plain and literal-bearing commands, trailing data, truncated prefixes,
    /// literal announcements, and invalid input.
    #[test]
    fn test_decode_command() {
        let tests = [
            (
                b"a noop\r\n".as_ref(),
                Ok((b"".as_ref(), Command::new("a", CommandBody::Noop).unwrap())),
            ),
            (
                b"a noop\r\n???".as_ref(),
                Ok((
                    b"???".as_ref(),
                    Command::new("a", CommandBody::Noop).unwrap(),
                )),
            ),
            (
                b"a select {5}\r\ninbox\r\n".as_ref(),
                Ok((
                    b"".as_ref(),
                    Command::new(
                        "a",
                        CommandBody::Select {
                            mailbox: Mailbox::Inbox,
                        },
                    )
                    .unwrap(),
                )),
            ),
            (
                b"a select {5}\r\ninbox\r\nxxx".as_ref(),
                Ok((
                    b"xxx".as_ref(),
                    Command::new(
                        "a",
                        CommandBody::Select {
                            mailbox: Mailbox::Inbox,
                        },
                    )
                    .unwrap(),
                )),
            ),
            // Strict prefixes of a valid command ask for more data.
            (b"a".as_ref(), Err(CommandDecodeError::Incomplete)),
            (b"a ".as_ref(), Err(CommandDecodeError::Incomplete)),
            (b"a n".as_ref(), Err(CommandDecodeError::Incomplete)),
            (b"a no".as_ref(), Err(CommandDecodeError::Incomplete)),
            (b"a noo".as_ref(), Err(CommandDecodeError::Incomplete)),
            (b"a noop".as_ref(), Err(CommandDecodeError::Incomplete)),
            (b"a noop\r".as_ref(), Err(CommandDecodeError::Incomplete)),
            // Literal announced with nothing after the CRLF: reported as `LiteralFound`.
            (
                b"a select {5}\r\n".as_ref(),
                Err(CommandDecodeError::LiteralFound {
                    tag: Tag::try_from("a").unwrap(),
                    length: 5,
                    mode: LiteralMode::Sync,
                }),
            ),
            // Literal data partially present (3 of 5 bytes): more data needed.
            (
                b"a select {5}\r\nxxx".as_ref(),
                Err(CommandDecodeError::Incomplete),
            ),
            (b"* noop\r\n".as_ref(), Err(CommandDecodeError::Failed)),
            // NOTE(review): `A` looks like a valid tag; confirm this input (e.g. its internal
            // whitespace) is really the intended failing case.
            (b"A noop\r\n".as_ref(), Err(CommandDecodeError::Failed)),
        ];

        for (test, expected) in tests {
            let got = CommandCodec::default().decode(test);
            dbg!((std::str::from_utf8(test).unwrap(), &expected, &got));
            assert_eq!(expected, got);

            #[cfg(feature = "bounded-static")]
            {
                let got = CommandCodec::default().decode_static(test);
                assert_eq!(expected, got);
            }
        }
    }

    /// Authenticate data: base64 lines, trailing data, truncated prefixes, and lines with
    /// stray spaces.
    #[test]
    fn test_decode_authenticate_data() {
        let tests = [
            (
                b"VGVzdA==\r\n".as_ref(),
                Ok((
                    b"".as_ref(),
                    AuthenticateData(Secret::new(b"Test".to_vec())),
                )),
            ),
            (
                b"VGVzdA==\r\nx".as_ref(),
                Ok((
                    b"x".as_ref(),
                    AuthenticateData(Secret::new(b"Test".to_vec())),
                )),
            ),
            // Strict prefixes of a valid line ask for more data.
            (b"V".as_ref(), Err(AuthenticateDataDecodeError::Incomplete)),
            (b"VG".as_ref(), Err(AuthenticateDataDecodeError::Incomplete)),
            (
                b"VGV".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            (
                b"VGVz".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            (
                b"VGVzd".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            (
                b"VGVzdA".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            (
                b"VGVzdA=".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            (
                b"VGVzdA==".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            (
                b"VGVzdA==\r".as_ref(),
                Err(AuthenticateDataDecodeError::Incomplete),
            ),
            // Same input and expectation as the first case in this table.
            (
                b"VGVzdA==\r\n".as_ref(),
                Ok((
                    b"".as_ref(),
                    AuthenticateData(Secret::new(b"Test".to_vec())),
                )),
            ),
            // Spaces anywhere in the line make it invalid.
            (
                b"VGVzdA== \r\n".as_ref(),
                Err(AuthenticateDataDecodeError::Failed),
            ),
            (
                b" VGVzdA== \r\n".as_ref(),
                Err(AuthenticateDataDecodeError::Failed),
            ),
            (
                b" V GVzdA== \r\n".as_ref(),
                Err(AuthenticateDataDecodeError::Failed),
            ),
            (
                b" V GVzdA= \r\n".as_ref(),
                Err(AuthenticateDataDecodeError::Failed),
            ),
        ];

        for (test, expected) in tests {
            let got = AuthenticateDataCodec::default().decode(test);
            dbg!((std::str::from_utf8(test).unwrap(), &expected, &got));
            assert_eq!(expected, got);

            #[cfg(feature = "bounded-static")]
            {
                let got = AuthenticateDataCodec::default().decode_static(test);
                assert_eq!(expected, got);
            }
        }
    }

    /// IDLE termination: `done\r\n`, trailing data, truncated prefixes, and extra
    /// characters/spaces.
    #[test]
    fn test_decode_idle_done() {
        let tests = [
            (b"done\r\n".as_ref(), Ok((b"".as_ref(), IdleDone))),
            (b"done\r\n?".as_ref(), Ok((b"?".as_ref(), IdleDone))),
            (b"d".as_ref(), Err(IdleDoneDecodeError::Incomplete)),
            (b"do".as_ref(), Err(IdleDoneDecodeError::Incomplete)),
            (b"don".as_ref(), Err(IdleDoneDecodeError::Incomplete)),
            (b"done".as_ref(), Err(IdleDoneDecodeError::Incomplete)),
            (b"done\r".as_ref(), Err(IdleDoneDecodeError::Incomplete)),
            (b"donee\r\n".as_ref(), Err(IdleDoneDecodeError::Failed)),
            (b" done\r\n".as_ref(), Err(IdleDoneDecodeError::Failed)),
            (b"done \r\n".as_ref(), Err(IdleDoneDecodeError::Failed)),
            (b" done \r\n".as_ref(), Err(IdleDoneDecodeError::Failed)),
        ];

        for (test, expected) in tests {
            let got = IdleDoneCodec::default().decode(test);
            dbg!((std::str::from_utf8(test).unwrap(), &expected, &got));
            assert_eq!(expected, got);

            #[cfg(feature = "bounded-static")]
            {
                let got = IdleDoneCodec::default().decode_static(test);
                assert_eq!(expected, got);
            }
        }
    }

    /// Responses: truncated prefixes (including empty input), SEARCH and FETCH data,
    /// trailing data, literal announcements, and invalid input.
    #[test]
    fn test_decode_response() {
        let tests = [
            (b"".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"*".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* ".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* S".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SE".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SEA".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SEAR".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SEARC".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SEARCH".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SEARCH ".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (b"* SEARCH 1".as_ref(), Err(ResponseDecodeError::Incomplete)),
            (
                b"* SEARCH 1\r".as_ref(),
                Err(ResponseDecodeError::Incomplete),
            ),
            (
                b"* SEARCH 1\r\n".as_ref(),
                Ok((
                    b"".as_ref(),
                    Response::Data(Data::Search(vec![NonZeroU32::new(1).unwrap()])),
                )),
            ),
            (
                b"* SEARCH 1\r\n???".as_ref(),
                Ok((
                    b"???".as_ref(),
                    Response::Data(Data::Search(vec![NonZeroU32::new(1).unwrap()])),
                )),
            ),
            (
                b"* 1 FETCH (RFC822 {5}\r\nhello)\r\n".as_ref(),
                Ok((
                    b"".as_ref(),
                    Response::Data(Data::Fetch {
                        seq: NonZeroU32::new(1).unwrap(),
                        items: NonEmptyVec::from(MessageDataItem::Rfc822(NString(Some(
                            IString::Literal(Literal::try_from(b"hello".as_ref()).unwrap()),
                        )))),
                    }),
                )),
            ),
            // Literal announced with nothing after the CRLF: reported as `LiteralFound`.
            (
                b"* 1 FETCH (RFC822 {5}\r\n".as_ref(),
                Err(ResponseDecodeError::LiteralFound { length: 5 }),
            ),
            // NOTE(review): keywords parse case-insensitively elsewhere (see the greeting
            // tests); confirm this input (e.g. its internal whitespace) is the intended
            // failing case.
            (
                b"* search 1 2 3\r\n".as_ref(),
                Err(ResponseDecodeError::Failed),
            ),
            (b"A search\r\n".as_ref(), Err(ResponseDecodeError::Failed)),
        ];

        for (test, expected) in tests {
            let got = ResponseCodec::default().decode(test);
            dbg!((std::str::from_utf8(test).unwrap(), &expected, &got));
            assert_eq!(expected, got);

            #[cfg(feature = "bounded-static")]
            {
                let got = ResponseCodec::default().decode_static(test);
                assert_eq!(expected, got);
            }
        }
    }
}