use serde::Deserialize;
use crate::{
auth::AuthContext,
policy,
sanitize,
config::AppConfig,
error::AppError,
sanitize::{contains_control_chars, contains_header_injection},
};
/// Newtype wrapping a list of recipient address strings.
///
/// The wrapper exists so that deserialization can be customized: the
/// hand-written `Deserialize` impl in this file accepts either a single
/// string or a sequence of strings. The addresses are stored exactly as
/// received; no validation is performed by this type.
#[derive(Debug, Clone)]
pub struct Recipients(pub Vec<String>);
/// Deserializes [`Recipients`] from either a bare string or a sequence of strings.
///
/// A bare string yields a one-element list; a sequence is taken as-is.
/// Address contents are not inspected here.
impl<'de> serde::Deserialize<'de> for Recipients {
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        // Helper shape used only for parsing. With `untagged`, serde tries the
        // variants in declaration order and keeps the first one that matches.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum OneOrMany {
            One(String),
            Many(Vec<String>),
        }

        // Normalize both accepted shapes into a single Vec<String>.
        OneOrMany::deserialize(de).map(|parsed| {
            Self(match parsed {
                OneOrMany::One(single) => vec![single],
                OneOrMany::Many(list) => list,
            })
        })
    }
}
/// Attachment as supplied in the incoming request, before validation.
///
/// All three fields are required strings in the deserialized payload.
#[derive(Debug, Deserialize)]
pub struct AttachmentSpec {
/// File name of the attachment. `validate_mail_request` rejects an empty
/// name or one whose `len()` exceeds 255, and checks for `/` and `\`.
pub filename: String,
/// Content type string as provided by the caller.
// NOTE(review): presumably a MIME type; how it is validated is not visible in this chunk.
pub content_type: String,
/// Attachment payload carried as text.
// NOTE(review): presumably base64-encoded (the validator imports `base64::Engine`) — confirm.
pub data: String,
}
/// Attachment in its validated form, as stored in
/// `ValidatedMailRequest::attachments`.
#[derive(Debug, Clone)]
pub struct ValidatedAttachment {
/// File name of the attachment.
pub filename: String,
/// Content type of the attachment.
pub content_type: String,
/// Raw attachment bytes.
// NOTE(review): presumably the decoded form of `AttachmentSpec::data`; the decoding step is not visible here — confirm.
pub decoded: Vec<u8>,
}
/// Mail request in its deserialized, not-yet-validated form.
///
/// `deny_unknown_fields` makes deserialization fail when the payload contains
/// a field that is not declared here. Field contents are checked later by
/// `validate_mail_request`.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MailRequest {
/// Primary recipients (required). Accepts a single string or a list; the
/// validator requires at least one entry and at most `max_recipients`.
pub to: Recipients,
/// Subject line (required); validated against `max_subject_chars`.
pub subject: String,
/// Message body (required); validated against `max_body_bytes`.
pub body: String,
/// Optional sender display name; validated as a display name when present.
pub from_name: Option<String>,
/// Optional reply-to address(es); each is validated as an email address
/// and rejected if it contains header CR/LF.
pub reply_to: Option<Recipients>,
/// Optional HTML body; the validator rejects NUL characters and a length
/// above `max_body_bytes`.
pub body_html: Option<String>,
/// Optional carbon-copy recipients; `to` and `cc` together must not exceed
/// `max_recipients`.
pub cc: Option<Recipients>,
/// Optional attachments; the count is limited by `max_attachments`.
pub attachments: Option<Vec<AttachmentSpec>>,
/// Optional free-form JSON. The visible part of the validator reads only a
/// string-valued `request_id` key from it.
pub metadata: Option<serde_json::Value>,
}
/// Success value returned by `validate_mail_request`.
///
/// Mirrors `MailRequest`, but the optional recipient lists and attachments
/// are plain `Vec`s and `metadata` is replaced by `client_request_id`.
// NOTE(review): the construction site is outside the visible chunk; the field
// notes below assume each field is filled from the same-named validated local.
#[derive(Debug)]
pub struct ValidatedMailRequest {
/// Primary recipient addresses.
pub to: Vec<String>,
/// Subject line as returned by subject validation.
pub subject: String,
/// Message body as returned by body validation.
pub body: String,
/// Sender display name, if one was supplied.
pub from_name: Option<String>,
/// Reply-to addresses; presumably empty when the request had none.
pub reply_to: Vec<String>,
/// HTML body, if one was supplied.
pub body_html: Option<String>,
/// Carbon-copy addresses; presumably empty when the request had none.
pub cc: Vec<String>,
/// Validated attachments.
pub attachments: Vec<ValidatedAttachment>,
/// Client-supplied request identifier; presumably the string `request_id`
/// taken from the request's `metadata`, when present.
pub client_request_id: Option<String>,
}
pub fn validate_mail_request(
req: MailRequest,
config: &AppConfig,
auth: &AuthContext,
) -> Result<ValidatedMailRequest, AppError> {
let mail_cfg = &config.mail;
{
let recipients = &req.to.0;
if recipients.is_empty() {
return Err(AppError::Validation("to: at least one recipient is required".into()));
}
if recipients.len() > config.mail.max_recipients {
return Err(AppError::Validation(format!(
"to: too many recipients (max {})",
config.mail.max_recipients
)));
}
for addr in recipients {
validate_email_address(addr, "to")?;
sanitize::reject_header_crlf("to", addr)?;
check_recipient_domain_or_address(addr, config, auth)?;
}
}
let to = req.to.0;
let cc: Vec<String> = if let Some(cc_recipients) = req.cc {
let cc_addrs = cc_recipients.0;
let total = to.len() + cc_addrs.len();
if total > config.mail.max_recipients {
return Err(AppError::Validation(format!(
"to + cc: too many recipients (max {})",
config.mail.max_recipients
)));
}
for addr in &cc_addrs {
validate_email_address(addr, "cc")?;
sanitize::reject_header_crlf("cc", addr)?;
check_recipient_domain_or_address(addr, config, auth)?;
}
cc_addrs
} else {
vec![]
};
let subject = validate_subject(&req.subject, mail_cfg.max_subject_chars)?;
let body = validate_body(&req.body, mail_cfg.max_body_bytes)?;
if let Some(ref html) = req.body_html {
if html.contains('\0') {
return Err(AppError::Validation("body_html: contains NUL character".into()));
}
if html.len() > mail_cfg.max_body_bytes {
return Err(AppError::Validation(format!(
"body_html: exceeds maximum of {} bytes",
mail_cfg.max_body_bytes
)));
}
}
let from_name = req
.from_name
.as_deref()
.map(|n| validate_display_name(n, "from_name"))
.transpose()?;
let reply_to: Vec<String> = if let Some(recipients) = req.reply_to {
let addrs = recipients.0;
for addr in &addrs {
validate_email_address(addr, "reply_to")?;
sanitize::reject_header_crlf("reply_to", addr)?;
}
addrs
} else {
vec![]
};
let client_request_id = req
.metadata
.as_ref()
.and_then(|m| m.get("request_id"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let attachments: Vec<ValidatedAttachment> = {
use base64::Engine as _;
let specs = req.attachments.unwrap_or_default();
if specs.len() > mail_cfg.max_attachments {
return Err(AppError::Validation(format!(
"attachments: too many (max {})", mail_cfg.max_attachments
)));
}
let mut validated = Vec::with_capacity(specs.len());
for spec in specs {
if spec.filename.is_empty() || spec.filename.len() > 255 {
return Err(AppError::Validation("attachments[].filename: must be 1–255 chars".into()));
}
if spec.filename.contains('/') || spec.filename.contains('\\') || spec.filename.contains('