use crate::error::{InvalidInputError, ProtocolError};
/// Upper bound on the length of a single server reply line, in bytes.
///
/// NOTE(review): 998 equals the RFC 5322 §2.1.1 message line limit, not the 512-octet
/// reply-line limit of RFC 5321 §4.5.3.1.5. It is not referenced in this section — confirm
/// how the reply reader applies it.
pub const MAX_REPLY_LINE_LEN: usize = 998;
/// Maximum number of lines accepted in one multi-line reply.
///
/// NOTE(review): not referenced in this section; presumably enforced by the reply reader —
/// confirm.
pub const MAX_REPLY_LINES: usize = 128;
/// Maximum mailbox length in octets checked by the address validators.
///
/// RFC 5321 §4.5.3.1.3 limits a path to 256 octets including the surrounding `<` and `>`,
/// which leaves 254 for the address itself.
pub const MAX_ADDRESS_LEN: usize = 254;
/// Maximum local-part length in octets (RFC 5321 §4.5.3.1.1).
pub const MAX_LOCAL_PART_LEN: usize = 64;
/// Maximum domain length in octets (RFC 5321 §4.5.3.1.2).
pub const MAX_DOMAIN_LEN: usize = 255;
/// One parsed line of an SMTP server reply, borrowing its text from the input buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplyLine<'a> {
    /// Three-digit reply code as parsed. Each digit is 0–9, so the value is 0–999; the
    /// parser does not restrict it further.
    pub code: u16,
    /// `true` when the code is followed by a space or by nothing (final line of a reply).
    /// `false` when it is followed by `-` (continuation line).
    pub is_last: bool,
    /// Raw bytes after the separator. No trimming is applied, so a trailing CR/LF is still
    /// present if the caller left it in the input line.
    pub text: &'a [u8],
}
/// Parses one line of an SMTP reply: `NNN`, `NNN text` (final line) or `NNN-text`
/// (continuation line).
///
/// The returned [`ReplyLine`] borrows its `text` from `line`.
///
/// # Errors
///
/// Returns [`ProtocolError::Malformed`] (carrying a lossy UTF-8 rendering of `line`) when:
/// - the line is shorter than three bytes,
/// - any of the first three bytes is not an ASCII digit, or
/// - the fourth byte is present but is neither a space nor `-`.
pub fn parse_reply_line(line: &[u8]) -> Result<ReplyLine<'_>, ProtocolError> {
    let Some((digits, rest)) = line.split_first_chunk::<3>() else {
        return Err(malformed(line));
    };

    // Accumulate the three decimal digits most-significant first.
    let mut code: u16 = 0;
    for &digit in digits {
        let value = ascii_digit_value(digit).ok_or_else(|| malformed(line))?;
        code = code * 10 + u16::from(value);
    }

    let (is_last, text): (bool, &[u8]) = match rest.split_first() {
        // A bare code with nothing after it is treated as a final line with empty text.
        None => (true, &[]),
        Some((&b' ', tail)) => (true, tail),
        Some((&b'-', tail)) => (false, tail),
        Some(_) => return Err(malformed(line)),
    };

    Ok(ReplyLine { code, is_last, text })
}
/// Returns the numeric value (0–9) of an ASCII decimal digit, or `None` for any other byte.
fn ascii_digit_value(b: u8) -> Option<u8> {
    // Bytes below '0' fail the checked subtraction; bytes above '9' fail the range filter.
    b.checked_sub(b'0').filter(|&value| value < 10)
}
/// Builds a `ProtocolError::Malformed` that carries the offending line.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD so the message can be shown.
fn malformed(line: &[u8]) -> ProtocolError {
    let printable = String::from(String::from_utf8_lossy(line));
    ProtocolError::Malformed(printable)
}
/// An RFC 3463 enhanced mail system status code, written `class.subject.detail`
/// (for example `5.1.1`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EnhancedStatus {
    /// Status class (first component); RFC 3463 defines 2, 4 and 5.
    pub class: u8,
    /// Subject (second component).
    pub subject: u16,
    /// Detail (third component).
    pub detail: u16,
}

impl EnhancedStatus {
    /// Renders the status in dotted `class.subject.detail` form, the same text as `Display`.
    #[must_use]
    pub fn to_dotted(&self) -> String {
        self.to_string()
    }
}

impl core::fmt::Display for EnhancedStatus {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            class,
            subject,
            detail,
        } = self;
        write!(f, "{class}.{subject}.{detail}")
    }
}
/// Parses an RFC 3463 enhanced status code (`class.subject.detail`) at the start of `text`.
///
/// Follows the RFC 3463 grammar:
/// - `class` must be `2`, `4` or `5`;
/// - `subject` and `detail` are each one to three ASCII digits.
///
/// The code must be followed by end of text, a space, or a tab. Anything else (e.g. `5.7.1x`
/// or `4.2.2.1`) means the text does not begin with an enhanced status code.
///
/// Returns the status and the byte length of the prefix to skip: the code plus the single
/// space or tab that follows it, if any. Returns `None` when no valid code is present.
fn parse_enhanced_status_prefix(text: &str) -> Option<(EnhancedStatus, usize)> {
    /// Reads 1–3 ASCII digits at `start`; returns their value and the index just past them.
    fn read_field(bytes: &[u8], start: usize) -> Option<(u16, usize)> {
        let digits = bytes
            .get(start..)?
            .iter()
            .take_while(|b| b.is_ascii_digit())
            .count();
        // RFC 3463: subject = 1*3digit, detail = 1*3digit.
        if !(1..=3).contains(&digits) {
            return None;
        }
        let value = bytes[start..start + digits]
            .iter()
            .fold(0u16, |acc, &b| acc * 10 + u16::from(b - b'0'));
        Some((value, start + digits))
    }

    let bytes = text.as_bytes();
    let class = match *bytes.first()? {
        c @ (b'2' | b'4' | b'5') => c - b'0',
        _ => return None,
    };
    if bytes.get(1) != Some(&b'.') {
        return None;
    }
    let (subject, after_subject) = read_field(bytes, 2)?;
    if bytes.get(after_subject) != Some(&b'.') {
        return None;
    }
    let (detail, end) = read_field(bytes, after_subject + 1)?;

    // The code must stand alone as a token; consume one following space or tab.
    let prefix_len = match bytes.get(end).copied() {
        None => end,
        Some(b' ' | b'\t') => end + 1,
        Some(_) => return None,
    };

    Some((
        EnhancedStatus {
            class,
            subject,
            detail,
        },
        prefix_len,
    ))
}
/// A complete (possibly multi-line) SMTP server reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Reply {
    /// The numeric reply code.
    pub code: u16,
    /// Text of each reply line, in order.
    ///
    /// NOTE(review): presumably stripped of the `NNN`/`NNN-` prefix by whoever builds the
    /// reply — confirm against the reader.
    pub lines: Vec<String>,
    /// Enhanced status code, if one was attached via `attach_enhanced_status`.
    /// `Reply::new` leaves it `None`.
    enhanced: Option<EnhancedStatus>,
}
impl Reply {
    /// Builds a reply from its code and text lines, with no enhanced status attached.
    #[must_use]
    pub fn new(code: u16, lines: Vec<String>) -> Self {
        Self {
            enhanced: None,
            code,
            lines,
        }
    }

    /// Returns `code / 100`, which is the class digit for three-digit codes.
    ///
    /// Returns `0` if the quotient does not fit in a `u8` (codes of 25600 and above).
    pub fn class(&self) -> u8 {
        let hundreds = self.code / 100;
        u8::try_from(hundreds).unwrap_or_default()
    }

    /// Joins all lines with `\n`, leaving any enhanced-status prefixes in place.
    pub fn joined_text(&self) -> String {
        self.lines.join("\n")
    }

    /// Joins all lines with `\n`.
    ///
    /// When an enhanced status is attached, every line that begins with an enhanced status
    /// code has that prefix (and its following separator) removed first. Without one, this
    /// is identical to [`Reply::joined_text`].
    pub fn message_text(&self) -> String {
        match self.enhanced {
            None => self.joined_text(),
            Some(_) => self
                .iter_lines()
                .map(|line| match parse_enhanced_status_prefix(line) {
                    Some((_, skip)) => &line[skip..],
                    None => line,
                })
                .collect::<Vec<&str>>()
                .join("\n"),
        }
    }

    /// Returns the attached enhanced status code, if any.
    #[must_use]
    pub fn enhanced(&self) -> Option<EnhancedStatus> {
        self.enhanced
    }

    /// Records `status` as this reply's enhanced status, replacing any previous value.
    pub(crate) fn attach_enhanced_status(&mut self, status: EnhancedStatus) {
        self.enhanced = Some(status);
    }

    /// Iterates over the reply lines as string slices.
    pub fn iter_lines(&self) -> impl Iterator<Item = &str> {
        self.lines.iter().map(String::as_str)
    }

    /// Tries to parse an enhanced status code from the start of the first line.
    ///
    /// This does not consult or modify the attached status.
    #[must_use]
    pub fn try_parse_enhanced(&self) -> Option<EnhancedStatus> {
        let first = self.lines.first()?;
        parse_enhanced_status_prefix(first).map(|(status, _)| status)
    }
}
/// Encodes an argument-less SMTP command as `VERB\r\n`.
///
/// `verb` is copied verbatim; no validation or escaping is performed.
pub fn format_command(verb: &str) -> Vec<u8> {
    let parts: [&[u8]; 2] = [verb.as_bytes(), b"\r\n"];
    parts.concat()
}
/// Encodes an SMTP command with one argument as `VERB ARG\r\n`.
///
/// Both parts are copied verbatim; callers are responsible for validating `arg`
/// (e.g. ensuring it contains no CR/LF).
pub fn format_command_arg(verb: &str, arg: &str) -> Vec<u8> {
    let parts: [&[u8]; 4] = [verb.as_bytes(), b" ", arg.as_bytes(), b"\r\n"];
    parts.concat()
}
/// Encodes `MAIL FROM:<addr>\r\n`.
///
/// `addr` is inserted verbatim, so an empty string yields the null reverse-path `<>`.
pub fn format_mail_from(addr: &str) -> Vec<u8> {
    let parts: [&[u8]; 3] = [b"MAIL FROM:<", addr.as_bytes(), b">\r\n"];
    parts.concat()
}
/// Encodes `RCPT TO:<addr>\r\n`, inserting `addr` verbatim.
pub fn format_rcpt_to(addr: &str) -> Vec<u8> {
    let parts: [&[u8]; 3] = [b"RCPT TO:<", addr.as_bytes(), b">\r\n"];
    parts.concat()
}
/// Dot-stuffs a complete message body and appends the DATA terminator.
///
/// Every `.` that begins a line gets an extra `.` in front of it. A line begins at the
/// start of the body or right after a CRLF pair; a bare LF is *not* treated as a line
/// break, so callers should normalise line endings to CRLF first. If the body does not
/// already end in CRLF one is appended, then `.\r\n` closes the data. An empty body
/// therefore yields `\r\n.\r\n`.
pub fn dot_stuff_and_terminate(body: &[u8]) -> Vec<u8> {
    let mut stuffed = Vec::with_capacity(body.len() + 8);

    for (index, &byte) in body.iter().enumerate() {
        let starts_line = index == 0 || body[..index].ends_with(b"\r\n");
        if starts_line && byte == b'.' {
            stuffed.push(b'.');
        }
        stuffed.push(byte);
    }

    if !stuffed.ends_with(b"\r\n") {
        stuffed.extend_from_slice(b"\r\n");
    }
    stuffed.extend_from_slice(b".\r\n");
    stuffed
}
/// Streaming dot-stuffer: the chunk-at-a-time counterpart of `dot_stuff_and_terminate`.
///
/// Feed the body through [`DotStufferState::process_chunk`] in any number of pieces, then
/// call [`DotStufferState::finish`] for the terminator. The last two bytes are remembered
/// across calls, so a CRLF split between two chunks is recognised as a line break. As with
/// the one-shot function, a bare LF does not start a new line.
#[derive(Debug, Clone)]
pub struct DotStufferState {
    /// Whether the next byte begins a line (start of body, or right after CRLF).
    at_line_start: bool,
    /// Most recent body byte (0 before any input).
    prev: u8,
    /// Byte before `prev` (0 until two bytes have been seen).
    prev_prev: u8,
    /// `true` until a non-empty chunk has been processed.
    empty: bool,
}

impl DotStufferState {
    /// Creates a stuffer positioned at the start of an empty body.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            empty: true,
            at_line_start: true,
            prev_prev: 0,
            prev: 0,
        }
    }

    /// Dot-stuffs one chunk of the body and returns the bytes to transmit for it.
    ///
    /// An empty chunk yields an empty vector and leaves the state unchanged.
    pub fn process_chunk(&mut self, chunk: &[u8]) -> Vec<u8> {
        let mut stuffed = Vec::new();
        if chunk.is_empty() {
            return stuffed;
        }
        stuffed.reserve(chunk.len() + 4);

        for &byte in chunk {
            if byte == b'.' && self.at_line_start {
                stuffed.push(b'.');
            }
            stuffed.push(byte);
            // Must read the old `prev` before shifting the history window below.
            self.at_line_start = self.prev == b'\r' && byte == b'\n';
            self.prev_prev = core::mem::replace(&mut self.prev, byte);
        }

        self.empty = false;
        stuffed
    }

    /// Consumes the stuffer and returns the bytes that close the DATA section.
    ///
    /// The result is `.\r\n` when the body ended in CRLF, and `\r\n.\r\n` otherwise
    /// (including when no body bytes were ever processed).
    #[must_use]
    pub fn finish(self) -> Vec<u8> {
        let body_ended_with_crlf = !self.empty && [self.prev_prev, self.prev] == *b"\r\n";
        let terminator: &[u8] = if body_ended_with_crlf {
            b".\r\n"
        } else {
            b"\r\n.\r\n"
        };
        terminator.to_vec()
    }
}

impl Default for DotStufferState {
    fn default() -> Self {
        Self::new()
    }
}
/// Standard base64 alphabet (RFC 4648 §4, using `+` and `/`).
const BASE64_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as padded standard base64 with no line breaks.
pub fn base64_encode(input: &[u8]) -> String {
    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);

    for group in input.chunks(3) {
        // Left-align the 1–3 input bytes in a 24-bit word; missing bytes are zero.
        let mut padded = [0u8; 3];
        padded[..group.len()].copy_from_slice(group);
        let word = u32::from_be_bytes([0, padded[0], padded[1], padded[2]]);

        // 1 byte → 2 symbols, 2 bytes → 3 symbols, 3 bytes → 4 symbols.
        let symbols: u8 = match group.len() {
            1 => 2,
            2 => 3,
            _ => 4,
        };
        push_b64(&mut encoded, word, symbols);
        for _ in symbols..4 {
            encoded.push('=');
        }
    }

    encoded
}

/// Appends the top `count` 6-bit groups of the 24-bit value `n` as base64 symbols.
fn push_b64(out: &mut String, n: u32, count: u8) {
    out.extend((0..count).map(|position| {
        let shift = 18 - 6 * position;
        let index = ((n >> shift) & 0x3F) as usize;
        char::from(BASE64_ALPHABET[index])
    }));
}
/// Decodes padded standard base64 (RFC 4648 §4) into bytes.
///
/// Input rules:
/// - the length must be a multiple of four;
/// - `=` padding is accepted only as a trailing run of one or two characters in the final
///   four-character group;
/// - whitespace and URL-safe symbols are rejected.
///
/// Non-zero bits in the unused tail of a padded group are ignored rather than rejected.
/// An empty input decodes to an empty vector.
///
/// # Errors
///
/// Returns `Err("invalid base64")` for any malformed input, including misplaced or
/// excessive `=` padding (e.g. `====`, `A===`, `=AAA`, or padding before the last group).
pub fn base64_decode(input: &str) -> Result<Vec<u8>, &'static str> {
    const INVALID: &str = "invalid base64";

    /// Maps one base64 symbol to its 6-bit value; `None` for anything else (including `=`).
    fn sextet(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let bytes = input.as_bytes();
    if bytes.len() % 4 != 0 {
        return Err(INVALID);
    }

    let groups = bytes.len() / 4;
    let mut out = Vec::with_capacity(groups * 3);
    for (index, group) in bytes.chunks_exact(4).enumerate() {
        // `=` is only legal as a 1–2 character trailing run in the final group.
        let pad = group.iter().rev().take_while(|&&c| c == b'=').count();
        if pad > 2 || (pad > 0 && index + 1 != groups) {
            return Err(INVALID);
        }

        // Any `=` left in the data part is rejected by `sextet`.
        let mut n: u32 = 0;
        for &c in &group[..4 - pad] {
            n = (n << 6) | u32::from(sextet(c).ok_or(INVALID)?);
        }
        // Re-align a short group so its decoded bytes occupy the top of the 24-bit word.
        n <<= 6 * pad;

        let [_, b0, b1, b2] = n.to_be_bytes();
        out.push(b0);
        if pad < 2 {
            out.push(b1);
        }
        if pad == 0 {
            out.push(b2);
        }
    }
    Ok(out)
}
/// Validates an ASCII mail address before it is embedded in `MAIL FROM:<…>` or
/// `RCPT TO:<…>`.
///
/// This is a safety check against command injection and oversize values, not a full
/// RFC 5321 syntax check. The address must:
/// - be non-empty and pure ASCII;
/// - respect the RFC 5321 length limits for the whole address, and (when an `@` is present)
///   for the local-part before the last `@` and the domain after it;
/// - contain no CR, LF, NUL, `<`, `>`, space, tab or other ASCII control character
///   (0x01–0x1F, 0x7F).
///
/// # Errors
///
/// Returns an [`InvalidInputError`] describing the first rule that is violated.
pub fn validate_address(addr: &str) -> Result<(), InvalidInputError> {
    if addr.is_empty() {
        return Err(InvalidInputError::new("mail address must not be empty"));
    }
    if !addr.is_ascii() {
        return Err(InvalidInputError::new(
            "mail address must be ASCII (SMTPUTF8 is not supported)",
        ));
    }
    if addr.len() > MAX_ADDRESS_LEN {
        return Err(InvalidInputError::new(
            "mail address exceeds RFC 5321 §4.5.3.1.3 length limit (254 octets)",
        ));
    }
    // Split on the last '@' so a quoted local-part containing '@' stays intact.
    if let Some((local, domain)) = addr.rsplit_once('@') {
        if local.len() > MAX_LOCAL_PART_LEN {
            return Err(InvalidInputError::new(
                "mail address local-part exceeds RFC 5321 §4.5.3.1.1 length limit (64 octets)",
            ));
        }
        if domain.len() > MAX_DOMAIN_LEN {
            return Err(InvalidInputError::new(
                "mail address domain exceeds RFC 5321 §4.5.3.1.2 length limit (255 octets)",
            ));
        }
    }
    for b in addr.bytes() {
        match b {
            b'\r' | b'\n' => {
                return Err(InvalidInputError::new(
                    "mail address must not contain CR or LF",
                ));
            }
            0 => {
                return Err(InvalidInputError::new(
                    "mail address must not contain a NUL byte",
                ));
            }
            b'<' | b'>' => {
                return Err(InvalidInputError::new(
                    "mail address must not contain '<' or '>'",
                ));
            }
            b' ' | b'\t' => {
                return Err(InvalidInputError::new(
                    "mail address must not contain whitespace",
                ));
            }
            // Remaining C0 controls and DEL are not valid anywhere in an RFC 5321 mailbox
            // and could confuse the server's command parser.
            0x01..=0x1F | 0x7F => {
                return Err(InvalidInputError::new(
                    "mail address must not contain ASCII control characters",
                ));
            }
            _ => {}
        }
    }
    Ok(())
}
/// Checks the EHLO/HELO argument before it is written on the wire.
///
/// The value must be non-empty, ASCII, and made only of printable non-space characters
/// (0x21–0x7E).
///
/// # Errors
///
/// Returns an [`InvalidInputError`] naming the first rule that fails.
pub fn validate_ehlo_domain(domain: &str) -> Result<(), InvalidInputError> {
    let problem = if domain.is_empty() {
        Some("EHLO domain must not be empty")
    } else if !domain.is_ascii() {
        Some("EHLO domain must be ASCII")
    } else if !domain.bytes().all(|b| b.is_ascii_graphic()) {
        Some("EHLO domain must contain only printable ASCII characters")
    } else {
        None
    };
    match problem {
        Some(message) => Err(InvalidInputError::new(message)),
        None => Ok(()),
    }
}
/// Validates a username for `AUTH LOGIN`; applies the same rules as
/// [`validate_plain_username`].
///
/// # Errors
///
/// Propagates the error from [`validate_plain_username`].
pub fn validate_login_username(user: &str) -> Result<(), InvalidInputError> {
    validate_plain_username(user)?;
    Ok(())
}
/// Validates a password for `AUTH LOGIN`; applies the same rules as
/// [`validate_plain_password`].
///
/// # Errors
///
/// Propagates the error from [`validate_plain_password`].
pub fn validate_login_password(pass: &str) -> Result<(), InvalidInputError> {
    validate_plain_password(pass)?;
    Ok(())
}
/// Reports whether any EHLO capability line is an `AUTH` line listing `mechanism`.
///
/// Both the `AUTH` keyword and the mechanism name are compared ASCII case-insensitively.
/// Only the space-separated form (`AUTH PLAIN LOGIN`) is recognised, not the legacy
/// `AUTH=PLAIN` form.
pub fn ehlo_advertises_auth<S: AsRef<str>>(capability_lines: &[S], mechanism: &str) -> bool {
    capability_lines.iter().any(|line| {
        let mut words = line.as_ref().split_ascii_whitespace();
        words
            .next()
            .is_some_and(|keyword| keyword.eq_ignore_ascii_case("AUTH"))
            && words.any(|offered| offered.eq_ignore_ascii_case(mechanism))
    })
}
/// Reports whether any EHLO capability line starts with the `STARTTLS` keyword
/// (ASCII case-insensitive).
pub fn ehlo_advertises_starttls<S: AsRef<str>>(capability_lines: &[S]) -> bool {
    capability_lines
        .iter()
        .filter_map(|line| line.as_ref().split_ascii_whitespace().next())
        .any(|keyword| keyword.eq_ignore_ascii_case("STARTTLS"))
}
/// Reports whether any EHLO capability line starts with the `ENHANCEDSTATUSCODES` keyword
/// (RFC 2034), compared ASCII case-insensitively.
pub fn ehlo_advertises_enhanced_status_codes<S: AsRef<str>>(capability_lines: &[S]) -> bool {
    capability_lines.iter().any(|line| {
        matches!(
            line.as_ref().split_ascii_whitespace().next(),
            Some(keyword) if keyword.eq_ignore_ascii_case("ENHANCEDSTATUSCODES")
        )
    })
}
/// SASL mechanisms this client knows how to name on the `AUTH` command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMechanism {
    /// `PLAIN` (RFC 4616).
    Plain,
    /// `LOGIN` (non-standard, widely deployed).
    Login,
    /// Google's `XOAUTH2`.
    XOAuth2,
    /// `OAUTHBEARER` (RFC 7628).
    OAuthBearer,
    /// `SCRAM-SHA-256` (RFC 7677).
    ScramSha256,
}

impl AuthMechanism {
    /// Returns the mechanism name exactly as sent after `AUTH`.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::ScramSha256 => "SCRAM-SHA-256",
            Self::OAuthBearer => "OAUTHBEARER",
            Self::XOAuth2 => "XOAUTH2",
            Self::Login => "LOGIN",
            Self::Plain => "PLAIN",
        }
    }
}

impl core::fmt::Display for AuthMechanism {
    /// Writes [`AuthMechanism::name`] verbatim; width/fill flags are not applied.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = self.name();
        f.write_str(name)
    }
}
/// Picks a password-based SASL mechanism from the server's EHLO capabilities.
///
/// Preference order:
/// 1. `SCRAM-SHA-256`, only when the `scram-sha-256` feature is enabled;
/// 2. `PLAIN`;
/// 3. `LOGIN`.
///
/// Token-based mechanisms (XOAUTH2, OAUTHBEARER) are never chosen here. Returns `None`
/// when no supported mechanism is advertised.
pub fn select_auth_mechanism<S: AsRef<str>>(capability_lines: &[S]) -> Option<AuthMechanism> {
    let offered =
        |mechanism: AuthMechanism| ehlo_advertises_auth(capability_lines, mechanism.name());

    #[cfg(feature = "scram-sha-256")]
    if offered(AuthMechanism::ScramSha256) {
        return Some(AuthMechanism::ScramSha256);
    }

    [AuthMechanism::Plain, AuthMechanism::Login]
        .into_iter()
        .find(|&mechanism| offered(mechanism))
}
/// Builds the base64 initial response for `AUTH PLAIN` (RFC 4616).
///
/// The payload is `NUL user NUL pass`, i.e. an empty authorization identity followed by the
/// authentication identity and password. Inputs are not validated here; see
/// [`validate_plain_username`] and [`validate_plain_password`].
#[must_use]
pub fn build_auth_plain_initial_response(user: &str, pass: &str) -> String {
    let parts: [&[u8]; 4] = [b"\0", user.as_bytes(), b"\0", pass.as_bytes()];
    base64_encode(&parts.concat())
}
/// Validates a SASL username: it must be non-empty and free of NUL bytes.
///
/// NUL is rejected because it delimits fields in the PLAIN payload.
///
/// # Errors
///
/// Returns an [`InvalidInputError`] for an empty username or one containing NUL.
pub fn validate_plain_username(user: &str) -> Result<(), InvalidInputError> {
    match user {
        "" => Err(InvalidInputError::new("AUTH username must not be empty")),
        name if name.contains('\0') => Err(InvalidInputError::new(
            "AUTH username must not contain a NUL byte",
        )),
        _ => Ok(()),
    }
}
/// Validates a SASL password: it must be non-empty and free of NUL bytes.
///
/// NUL is rejected because it delimits fields in the PLAIN payload.
///
/// # Errors
///
/// Returns an [`InvalidInputError`] for an empty password or one containing NUL.
pub fn validate_plain_password(pass: &str) -> Result<(), InvalidInputError> {
    match pass {
        "" => Err(InvalidInputError::new("AUTH password must not be empty")),
        secret if secret.contains('\0') => Err(InvalidInputError::new(
            "AUTH password must not contain a NUL byte",
        )),
        _ => Ok(()),
    }
}
#[cfg(feature = "xoauth2")]
#[must_use]
/// Builds the base64 initial response for Google's `XOAUTH2` mechanism.
///
/// The payload is `user={user}^Aauth=Bearer {token}^A^A`, where `^A` is the 0x01 byte.
/// Inputs are not validated here; see `validate_xoauth2_user` and `validate_oauth2_token`.
pub fn build_xoauth2_initial_response(user: &str, token: &str) -> String {
    let payload = format!("user={user}\x01auth=Bearer {token}\x01\x01");
    base64_encode(payload.as_bytes())
}
#[cfg(feature = "xoauth2")]
/// Validates the user name embedded in an XOAUTH2 payload.
///
/// The name must be non-empty. It must not contain NUL, CR, LF, or SOH (0x01), because SOH
/// separates payload fields.
///
/// # Errors
///
/// Returns an [`InvalidInputError`] for an empty name or one containing a forbidden byte.
pub fn validate_xoauth2_user(user: &str) -> Result<(), InvalidInputError> {
    const FORBIDDEN: &[char] = &['\0', '\r', '\n', '\u{1}'];
    match user {
        "" => Err(InvalidInputError::new("XOAUTH2 user must not be empty")),
        name if name.contains(FORBIDDEN) => Err(InvalidInputError::new(
            "XOAUTH2 user must not contain NUL, CR, LF, or SOH",
        )),
        _ => Ok(()),
    }
}
#[cfg(feature = "xoauth2")]
/// Validates an OAuth2 bearer token.
///
/// The token must be non-empty and consist only of printable, non-space ASCII (0x21–0x7E).
///
/// # Errors
///
/// Returns an [`InvalidInputError`] for an empty token or one with whitespace, control, or
/// non-ASCII bytes.
pub fn validate_oauth2_token(token: &str) -> Result<(), InvalidInputError> {
    if token.is_empty() {
        return Err(InvalidInputError::new(
            "OAuth2 access token must not be empty",
        ));
    }
    if token.bytes().all(|b| b.is_ascii_graphic()) {
        Ok(())
    } else {
        Err(InvalidInputError::new(
            "OAuth2 access token must contain only printable ASCII (no whitespace or control bytes)",
        ))
    }
}
#[cfg(feature = "oauthbearer")]
#[must_use]
/// Builds the base64 initial response for `OAUTHBEARER` (RFC 7628 §3.1).
///
/// The payload is `n,a={saslname},^Aauth=Bearer {token}^A^A`, where `^A` is the 0x01 byte.
/// The user goes into the GS2 header as an RFC 5801 `saslname`. In that header `,` and `=`
/// are structural, so they are escaped as `=2C` and `=3D`; without this a username
/// containing either character would produce a malformed header. The token is not
/// validated here; see `validate_oauth2_token`.
pub fn build_oauthbearer_initial_response(user: &str, token: &str) -> String {
    // Fixed framing: "n,a=" + "," + 0x01 + "auth=Bearer " + 0x01 0x01 = 20 bytes.
    let mut payload = Vec::with_capacity(20 + user.len() + token.len());
    payload.extend_from_slice(b"n,a=");
    for b in user.bytes() {
        match b {
            b'=' => payload.extend_from_slice(b"=3D"),
            b',' => payload.extend_from_slice(b"=2C"),
            _ => payload.push(b),
        }
    }
    payload.push(b',');
    payload.push(0x01);
    payload.extend_from_slice(b"auth=Bearer ");
    payload.extend_from_slice(token.as_bytes());
    payload.push(0x01);
    payload.push(0x01);
    base64_encode(&payload)
}
#[cfg(feature = "pipelining")]
#[must_use]
/// Reports whether any EHLO capability line starts with the `PIPELINING` keyword
/// (RFC 2920), compared ASCII case-insensitively.
///
/// Like the other `ehlo_advertises_*` helpers, only the first whitespace-separated word is
/// compared. Surrounding whitespace or trailing parameters therefore do not hide the
/// capability.
pub fn ehlo_advertises_pipelining(caps: &[String]) -> bool {
    caps.iter().any(|line| {
        line.split_ascii_whitespace()
            .next()
            .is_some_and(|keyword| keyword.eq_ignore_ascii_case("PIPELINING"))
    })
}
#[cfg(feature = "smtputf8")]
/// Reports whether any EHLO capability line starts with the `SMTPUTF8` keyword (RFC 6531),
/// compared ASCII case-insensitively.
pub fn ehlo_advertises_smtputf8<S: AsRef<str>>(capability_lines: &[S]) -> bool {
    for line in capability_lines {
        let keyword = line.as_ref().split_ascii_whitespace().next();
        if keyword.is_some_and(|word| word.eq_ignore_ascii_case("SMTPUTF8")) {
            return true;
        }
    }
    false
}
#[cfg(feature = "smtputf8")]
/// Validates a possibly non-ASCII mail address for use with the SMTPUTF8 extension.
///
/// Length limits are the same RFC 5321 octet limits as for ASCII addresses (measured in
/// UTF-8 bytes). The address must not contain:
/// - CR, LF or NUL;
/// - ASCII `<` or `>`;
/// - ASCII space or tab;
/// - any other ASCII (C0/DEL) or C1 control character.
///
/// # Errors
///
/// Returns an [`InvalidInputError`] describing the first rule violated; for character rules
/// this is the first offending character in the address.
pub fn validate_address_utf8(addr: &str) -> Result<(), InvalidInputError> {
    if addr.is_empty() {
        return Err(InvalidInputError::new("mail address must not be empty"));
    }
    if addr.len() > MAX_ADDRESS_LEN {
        return Err(InvalidInputError::new(
            "mail address exceeds RFC 5321 §4.5.3.1.3 length limit (254 octets)",
        ));
    }
    // The last '@' separates local-part from domain.
    if let Some((local, domain)) = addr.rsplit_once('@') {
        if local.len() > MAX_LOCAL_PART_LEN {
            return Err(InvalidInputError::new(
                "mail address local-part exceeds RFC 5321 §4.5.3.1.1 length limit (64 octets)",
            ));
        }
        if domain.len() > MAX_DOMAIN_LEN {
            return Err(InvalidInputError::new(
                "mail address domain exceeds RFC 5321 §4.5.3.1.2 length limit (255 octets)",
            ));
        }
    }

    // Arms are ordered so the more specific messages win over the generic control-char ones.
    let rejection = addr.chars().find_map(|ch| match ch {
        '\r' | '\n' => Some("mail address must not contain CR or LF"),
        '\0' => Some("mail address must not contain a NUL byte"),
        '<' | '>' => Some("mail address must not contain ASCII < or >"),
        ' ' | '\t' => Some("mail address must not contain ASCII whitespace"),
        '\u{0}'..='\u{1F}' | '\u{7F}' => {
            Some("mail address must not contain ASCII control characters")
        }
        '\u{80}'..='\u{9F}' => Some("mail address must not contain C1 control characters"),
        _ => None,
    });

    match rejection {
        Some(message) => Err(InvalidInputError::new(message)),
        None => Ok(()),
    }
}
#[cfg(feature = "smtputf8")]
#[must_use]
/// Encodes `MAIL FROM:<addr> SMTPUTF8\r\n`, requesting the SMTPUTF8 extension (RFC 6531)
/// for this transaction. `addr` is inserted verbatim.
pub fn format_mail_from_smtputf8(addr: &str) -> Vec<u8> {
    let parts: [&[u8]; 3] = [b"MAIL FROM:<", addr.as_bytes(), b"> SMTPUTF8\r\n"];
    parts.concat()
}