use std::{borrow::Cow, slice::Iter};
use crate::{Error, Request};
/// Largest number of bytes the line-oriented receivers in this module
/// (`RequestReceiver`, `LineReceiver`) will buffer for a single line.
pub const MAX_LINE_LENGTH: usize = 4 * 1024;
/// Incremental request parser that carries an incomplete command line over
/// between calls to `ingest`.
#[derive(Default)]
pub struct RequestReceiver {
    /// Set when the last returned request was parsed out of `buf`; the buffer
    /// is then cleared lazily on the next access instead of immediately.
    buf_used: bool,
    /// Bytes of a command line that has not been terminated yet.
    buf: Vec<u8>,
}
/// Streaming receiver for a DATA payload: bytes are copied into a
/// caller-supplied buffer until the end-of-data marker is detected.
pub struct DataReceiver {
    /// Most recently processed byte.
    last_ch: u8,
    /// Byte processed just before `last_ch`.
    prev_last_ch: u8,
    /// Set when a `.` directly followed a CRLF, i.e. the current line may be
    /// the end-of-data marker.
    crlf_dot: bool,
}
/// Receiver for one BDAT chunk whose size is known up front.
pub struct BdatReceiver {
    /// Bytes of the chunk that still have to be copied.
    bytes_left: usize,
    /// Flag supplied at construction; presumably mirrors the BDAT `LAST`
    /// keyword (TODO confirm against the caller).
    pub is_last: bool,
}
/// Receiver that consumes and throws away either a BDAT chunk or the rest of
/// a DATA payload.
pub struct DummyDataReceiver {
    /// `true`: count down `bdat_bytes_left`; `false`: scan for end-of-data.
    is_bdat: bool,
    /// BDAT mode only: bytes that still have to be discarded.
    bdat_bytes_left: usize,
    /// DATA mode only: terminator-scanning state, same meaning as the
    /// identically named fields of `DataReceiver`.
    last_ch: u8,
    prev_last_ch: u8,
    crlf_dot: bool,
}
/// Stateless receiver that discards input up to and including the next LF.
#[derive(Default)]
pub struct DummyLineReceiver {}
/// Collects one text line (CR bytes dropped) together with caller-defined
/// state `T`.
#[derive(Default)]
pub struct LineReceiver<T> {
    /// Arbitrary state the caller keeps alongside the line.
    pub state: T,
    /// Line bytes gathered so far.
    pub buf: Vec<u8>,
}
impl RequestReceiver {
/// Returns the internal line buffer.
///
/// If the previous `ingest` call returned a request parsed from this buffer
/// (`buf_used` is set), the buffer is cleared first. Clearing is deferred to
/// this point because that request may borrow from `buf` for as long as the
/// `&mut self` borrow taken by `ingest` lives.
pub fn buf(&mut self) -> &mut Vec<u8> {
if self.buf_used {
self.buf.clear();
self.buf_used = false;
}
&mut self.buf
}
/// Parses the next request from `bytes`, keeping an incomplete line between
/// calls.
///
/// * With no partial line buffered, the request is parsed straight from the
///   input slice, so the returned value may borrow from it (`'bytes: 'out`).
/// * Otherwise bytes are appended to the internal buffer until a `\n` is
///   seen, and the request is parsed from that buffer (`'this: 'out`).
///
/// The iterator is not reset, so calling again with the same iterator
/// continues with the bytes after the returned request.
///
/// # Errors
///
/// * `Error::NeedsMoreData { bytes_left: 0 }` when the input ran out before a
///   line was complete (an incomplete tail, if any, has been buffered).
/// * `Error::ResponseTooLong` when an incomplete line would reach
///   `MAX_LINE_LENGTH` bytes.
/// * Any other result of `Request::parse` is passed through unchanged.
///
/// NOTE(review): after `ResponseTooLong` the remainder of the over-long line
/// is not skipped; its remaining bytes are parsed as if they started a new
/// command. Confirm that callers close or resynchronise the session.
pub fn ingest<'this, 'bytes, 'out>(
&'this mut self,
bytes: &mut Iter<'bytes, u8>,
) -> Result<Request<Cow<'out, str>>, Error>
where
'this: 'out,
'bytes: 'out,
{
// Drop the buffer contents if a request parsed from it was returned last time.
self.buf();
if self.buf.is_empty() {
// Fast path: keep a view of the unparsed input so its tail can be
// copied if the line turns out to be incomplete.
let buf = bytes.as_slice();
match Request::parse(bytes) {
Err(Error::NeedsMoreData { bytes_left }) => {
if bytes_left > 0 {
// Presumably `bytes_left` is the length of the incomplete trailing
// line (TODO confirm in `Request::parse`); the last `bytes_left`
// input bytes are saved for the next call.
if bytes_left < MAX_LINE_LENGTH {
self.buf = buf[buf.len().saturating_sub(bytes_left)..].to_vec();
} else {
return Err(Error::ResponseTooLong);
}
}
}
result => return result,
}
} else {
// Slow path: finish the buffered line one byte at a time.
for &ch in bytes {
self.buf.push(ch);
if ch == b'\n' {
// The result may borrow `self.buf`, so it cannot be cleared here;
// `buf()` clears it at the start of the next call.
self.buf_used = true;
return Request::parse(&mut self.buf.iter());
} else if self.buf.len() == MAX_LINE_LENGTH {
self.buf.clear();
return Err(Error::ResponseTooLong);
}
}
}
Err(Error::NeedsMoreData { bytes_left: 0 })
}
}
impl DataReceiver {
    /// Creates a receiver for a new DATA payload.
    ///
    /// The state starts as if a CRLF had just been received (the one ending the
    /// `DATA` command line). The first payload line is therefore handled like
    /// every other line: a lone `.` ends the data at once (empty message), and
    /// a leading dot is removed (RFC 5321 §4.5.2).
    // NOTE(review): `Default` is not implemented and the lint is silenced; the
    // reason is not visible from this file.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            crlf_dot: false,
            last_ch: b'\n',
            prev_last_ch: b'\r',
        }
    }

    /// Appends DATA payload bytes from `bytes` to `buf` until the end-of-data
    /// sequence `<CRLF>.<CRLF>` has been read.
    ///
    /// A `.` at the start of a line is never copied (dot-unstuffing).
    ///
    /// When the terminator is found:
    /// * the CRLF before the dot and the CR after it are removed from `buf`;
    /// * the iterator is left just past the terminating LF;
    /// * `true` is returned.
    ///
    /// Returns `false` when `bytes` is exhausted first. The scanning state is
    /// kept, so the terminator may be split across calls.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>, buf: &mut Vec<u8>) -> bool {
        for &ch in bytes {
            match ch {
                b'.' if self.last_ch == b'\n' && self.prev_last_ch == b'\r' => {
                    // Line starts with a dot: drop it; this line may be the terminator.
                    self.crlf_dot = true;
                }
                b'\n' if self.crlf_dot && self.last_ch == b'\r' => {
                    // `buf` ends with the "\r\n" before the dot and the "\r" after it.
                    // Some of those bytes may live in an earlier buffer (empty
                    // message, or the caller switched buffers), so never underflow.
                    buf.truncate(buf.len().saturating_sub(3));
                    return true;
                }
                b'\r' => {
                    buf.push(ch);
                    // Only the CR directly after the line-leading dot can belong to
                    // the terminator; "\r\n.\r\r\n" must not end the data.
                    if self.last_ch != b'.' {
                        self.crlf_dot = false;
                    }
                }
                _ => {
                    buf.push(ch);
                    self.crlf_dot = false;
                }
            }
            self.prev_last_ch = self.last_ch;
            self.last_ch = ch;
        }
        false
    }
}
impl BdatReceiver {
    /// Prepares to receive a chunk of exactly `chunk_size` bytes; `is_last`
    /// is stored as given.
    pub fn new(chunk_size: usize, is_last: bool) -> Self {
        Self {
            is_last,
            bytes_left: chunk_size,
        }
    }

    /// Copies as much of the outstanding chunk as `bytes` holds into `buf`.
    ///
    /// Bytes beyond the end of the chunk are left in the iterator. Returns
    /// `true` once the whole chunk has been copied, and `false` if more input
    /// is still needed.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>, buf: &mut Vec<u8>) -> bool {
        let available = bytes.as_slice();
        let take = available.len().min(self.bytes_left);
        let (chunk, rest) = available.split_at(take);
        buf.extend_from_slice(chunk);
        self.bytes_left -= chunk.len();
        // Re-seat the iterator just past the bytes that were copied.
        *bytes = rest.iter();
        self.bytes_left == 0
    }
}
impl DummyDataReceiver {
    /// Creates a receiver that discards exactly `chunk_size` bytes of a BDAT
    /// chunk.
    pub fn new_bdat(chunk_size: usize) -> Self {
        Self {
            bdat_bytes_left: chunk_size,
            is_bdat: true,
            crlf_dot: false,
            last_ch: 0,
            prev_last_ch: 0,
        }
    }

    /// Creates a receiver that discards the rest of a DATA payload. It resumes
    /// end-of-data detection from the current scanning state of `data`.
    pub fn new_data(data: &DataReceiver) -> Self {
        Self {
            is_bdat: false,
            bdat_bytes_left: 0,
            crlf_dot: data.crlf_dot,
            last_ch: data.last_ch,
            prev_last_ch: data.prev_last_ch,
        }
    }

    /// Consumes and discards bytes from `bytes`.
    ///
    /// * BDAT mode: returns `true` once the whole chunk has been skipped; any
    ///   following bytes stay in the iterator.
    /// * DATA mode: returns `true` right after the LF that completes
    ///   `<CRLF>.<CRLF>`; the terminator may be split across calls.
    ///
    /// Returns `false` when `bytes` runs out first.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>) -> bool {
        if self.is_bdat {
            // Skip as much of the chunk as this input holds in a single step.
            let pending = bytes.as_slice();
            let skipped = pending.len().min(self.bdat_bytes_left);
            *bytes = pending[skipped..].iter();
            self.bdat_bytes_left -= skipped;
            return self.bdat_bytes_left == 0;
        }

        for &ch in bytes {
            match ch {
                b'.' if self.last_ch == b'\n' && self.prev_last_ch == b'\r' => {
                    self.crlf_dot = true;
                }
                b'\n' if self.crlf_dot && self.last_ch == b'\r' => {
                    return true;
                }
                b'\r' => {
                    // Only the CR directly after the line-leading dot can belong to
                    // the terminator; "\r\n.\r\r\n" must not end the data.
                    if self.last_ch != b'.' {
                        self.crlf_dot = false;
                    }
                }
                _ => {
                    self.crlf_dot = false;
                }
            }
            self.prev_last_ch = self.last_ch;
            self.last_ch = ch;
        }
        false
    }
}
impl<T> LineReceiver<T> {
    /// Wraps `state` with an empty line buffer that has room for 32 bytes.
    pub fn new(state: T) -> Self {
        let buf = Vec::with_capacity(32);
        Self { state, buf }
    }

    /// Reads bytes into `self.buf` until a line feed is consumed.
    ///
    /// Returns `true` after consuming the LF, which is not stored, and leaves
    /// the following bytes in the iterator. Returns `false` if `bytes` runs out
    /// first. CR bytes are dropped wherever they occur. Once the buffer holds
    /// `MAX_LINE_LENGTH` bytes, further bytes are silently discarded. The
    /// buffer is not cleared here.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>) -> bool {
        for &byte in bytes {
            if byte == b'\n' {
                return true;
            }
            let keep = byte != b'\r' && self.buf.len() < MAX_LINE_LENGTH;
            if keep {
                self.buf.push(byte);
            }
        }
        false
    }
}
impl DummyLineReceiver {
    /// Discards bytes up to and including the next line feed.
    ///
    /// Returns `true` once the LF has been consumed, leaving the following
    /// bytes in the iterator, or `false` if `bytes` ran out first.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>) -> bool {
        // `any` stops right after the first matching element.
        bytes.any(|&byte| byte == b'\n')
    }
}
#[cfg(test)]
mod tests {
use super::DataReceiver;
use crate::{Error, MailFrom, RcptTo, Request, request::receiver::RequestReceiver};
// Each case feeds its chunks, in order, to a fresh `DataReceiver` and compares
// the collected body with the expected message once `ingest` reports the end.
#[test]
fn data_receiver() {
'outer: for (data, message) in [
// Dot-stuffed lines ("..", ".a") lose their leading dot.
(
vec!["hi\r\n", "..\r\n", ".a\r\n", "\r\n.\r\n"],
"hi\r\n.\r\na\r\n",
),
// Bare CR and bare LF inside the body are kept; ".d" and ".." are unstuffed.
(
vec!["\r\na\rb\nc\r\n.d\r\n..\r\n", "\r\n.\r\n"],
"\r\na\rb\nc\r\nd\r\n.\r\n",
),
// The next four cases start with LF.CRLF, LF.LF, CR.CRLF and CR.CR. None of
// these is the CRLF.CRLF terminator, so the command-like lines that follow
// must stay part of the message body.
(
vec![
"\n.\r\n",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
"\r\n.\r\n",
],
concat!(
"\n.\r\n",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
),
),
(
vec![
"\n.\n",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
"\r\n.\r\n",
],
concat!(
"\n.\n",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
),
),
(
vec![
"\r.\r\n",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
"\r\n.\r\n",
],
concat!(
"\r.\r\n",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
),
),
(
vec![
"\r.\r",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
"\r\n.\r\n",
],
concat!(
"\r.\r",
"MAIL FROM:<hello@world.com>\r\n",
"RCPT TO:<test@domain.com\r\n",
"DATA\r\n",
),
),
] {
let mut r = DataReceiver::new();
let mut buf = Vec::new();
for data in &data {
if r.ingest(&mut data.as_bytes().iter(), &mut buf) {
assert_eq!(message, String::from_utf8(buf).unwrap());
continue 'outer;
}
}
// No chunk completed the message.
panic!("Failed for {data:?}");
}
}
// Requests split at arbitrary chunk boundaries, and several requests inside a
// single chunk, must all come out of `RequestReceiver` in order.
#[test]
fn request_receiver() {
for (data, expected_requests) in [
(
vec![
"data\n",
"start",
"tls\n",
"quit\nnoop",
" hello\nehlo test\nvrfy name\n",
"mail from:<hello",
"@world.com>\nrcpt to:<",
"test@domain.com>\n",
],
vec![
Request::Data,
Request::StartTls,
Request::Quit,
Request::Noop {
value: "hello".to_string(),
},
Request::Ehlo {
host: "test".to_string(),
},
Request::Vrfy {
value: "name".to_string(),
},
Request::Mail {
from: MailFrom {
address: "hello@world.com".to_string(),
flags: 0,
size: 0,
trans_id: None,
by: 0,
env_id: None,
solicit: None,
mtrk: None,
auth: None,
hold_for: 0,
hold_until: 0,
mt_priority: 0,
},
},
Request::Rcpt {
to: RcptTo {
address: "test@domain.com".to_string(),
orcpt: None,
rrvs: 0,
flags: 0,
},
},
],
),
// One byte per chunk for the first command.
(
vec!["d", "a", "t", "a", "\n", "quit", "\n"],
vec![Request::Data, Request::Quit],
),
] {
let mut requests = Vec::new();
let mut r = RequestReceiver::default();
for data in &data {
let mut bytes = data.as_bytes().iter();
// Drain every complete request in this chunk; NeedsMoreData means the
// chunk is used up and any partial line is buffered for the next one.
loop {
match r.ingest(&mut bytes) {
Ok(request) => {
requests.push(request.into_owned());
continue;
}
Err(Error::NeedsMoreData { .. }) => {
break;
}
err => panic!("Unexpected error for {data:?}: {err:?}"),
}
}
}
assert_eq!(expected_requests, requests);
}
}
}