use std::sync::Arc;
use axum::{extract::State, http::StatusCode, Json};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::{
api::{send::MailRequest, ExtractRequestId},
auth::AuthContext,
error::AppError,
mail, smtp, validation,
request_id::RequestId,
status::{
recipient_domains_from, ErrorCode, StatusUpdate, SubmissionStatus,
SubmissionStatusRecord,
},
AppState,
};
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BulkMailRequest {
pub messages: Vec<MailRequest>,
}
#[derive(Debug, Serialize)]
pub struct MessageResult {
pub index: usize,
pub request_id: String,
pub status: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
pub async fn send_bulk(
State(state): State<Arc<AppState>>,
ExtractRequestId(bulk_req_id): ExtractRequestId,
auth: AuthContext,
Json(payload): Json<BulkMailRequest>,
) -> Result<(StatusCode, Json<Value>), AppError> {
let cfg = state.config();
if payload.messages.is_empty() {
return Err(AppError::BadRequest);
}
if payload.messages.len() > cfg.mail.max_bulk_messages {
tracing::warn!(
event = "bulk_too_large",
key_id = %auth.key_id,
count = payload.messages.len(),
max = cfg.mail.max_bulk_messages,
);
return Err(AppError::PayloadTooLarge);
}
let total = payload.messages.len();
let mut results: Vec<MessageResult> = Vec::with_capacity(total);
let mut accepted = 0usize;
let mut rejected = 0usize;
tracing::info!(
event = "bulk_started",
key_id = %auth.key_id,
client_ip = %auth.client_ip,
total = total,
bulk_request_id = %bulk_req_id,
);
for (index, raw_msg) in payload.messages.into_iter().enumerate() {
let req_id = RequestId::new();
let result = process_one(
index, raw_msg, req_id.clone(), &auth, &state, &cfg,
).await;
match result {
Ok(()) => {
accepted += 1;
results.push(MessageResult {
index,
request_id: req_id.to_string(),
status: "accepted",
code: None,
message: None,
});
}
Err((code, msg)) => {
rejected += 1;
results.push(MessageResult {
index,
request_id: req_id.to_string(),
status: "rejected",
code: Some(error_code_str(&code)),
message: Some(msg),
});
}
}
}
state.metrics.inc_request("2xx");
tracing::info!(
event = "bulk_completed",
key_id = %auth.key_id,
total = total,
accepted = accepted,
rejected = rejected,
bulk_request_id = %bulk_req_id,
);
Ok((
StatusCode::ACCEPTED,
Json(json!({
"bulk_request_id": bulk_req_id.to_string(),
"total": total,
"accepted": accepted,
"rejected": rejected,
"results": results,
})),
))
}
async fn process_one(
index: usize,
raw_msg: MailRequest,
req_id: RequestId,
auth: &AuthContext,
state: &Arc<AppState>,
cfg: &Arc<crate::config::AppConfig>,
) -> Result<(), (ErrorCode, String)> {
if let Err(e) = state.rate_limiter.check_global() {
state.metrics.inc_rate_limited("global");
tracing::warn!(event = "rate_limited", tier = "global", index, retry_after = e.retry_after_secs);
record_rejected(&state, &req_id, &auth.key_id, ErrorCode::RateLimited,
"Global rate limit exceeded.", cfg.status.ttl_seconds);
return Err((ErrorCode::RateLimited, "Global rate limit exceeded.".into()));
}
if let Err(_e) = state.rate_limiter.check_ip(auth.client_ip) {
state.metrics.inc_rate_limited("ip");
tracing::warn!(event = "rate_limited", tier = "ip", index, client_ip = %auth.client_ip);
record_rejected(&state, &req_id, &auth.key_id, ErrorCode::RateLimited,
"IP rate limit exceeded.", cfg.status.ttl_seconds);
return Err((ErrorCode::RateLimited, "IP rate limit exceeded.".into()));
}
if let Err(_e) = state.rate_limiter.check_key(&auth.key_id, auth.key_rate_limit_per_min, auth.key_burst) {
state.metrics.inc_rate_limited("key");
tracing::warn!(event = "rate_limited", tier = "key", index, key_id = %auth.key_id);
record_rejected(&state, &req_id, &auth.key_id, ErrorCode::RateLimited,
"Per-key rate limit exceeded.", cfg.status.ttl_seconds);
return Err((ErrorCode::RateLimited, "Per-key rate limit exceeded.".into()));
}
state.status_store.put(SubmissionStatusRecord::new(
req_id.clone(), auth.key_id.clone(), vec![], 0, cfg.status.ttl_seconds,
));
let validated = match validation::validate_mail_request(raw_msg, cfg, auth) {
Ok(v) => v,
Err(e) => {
state.metrics.inc_validation_failure("request");
tracing::warn!(event = "validation_failure", index, key_id = %auth.key_id, error = %e);
let code = app_error_to_code(&e);
state.status_store.update_status(&req_id, &auth.key_id, StatusUpdate {
status: SubmissionStatus::Rejected,
code: Some(code.clone()),
message: Some("Request rejected during validation.".into()),
});
return Err((code, e.to_string()));
}
};
let domains = recipient_domains_from(&validated.to, &validated.cc);
let count = (validated.to.len() + validated.cc.len()) as u32;
state.status_store.put(SubmissionStatusRecord {
request_id: req_id.clone(),
key_id: auth.key_id.clone(),
status: SubmissionStatus::Received,
code: None,
message: Some("Validated.".into()),
recipient_domains: domains,
recipient_count: count,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
expires_at: chrono::Utc::now()
+ chrono::Duration::seconds(cfg.status.ttl_seconds as i64),
});
let message = match mail::build_message(&validated, cfg) {
Ok(m) => m,
Err(e) => {
state.status_store.update_status(&req_id, &auth.key_id, StatusUpdate {
status: SubmissionStatus::Rejected,
code: Some(ErrorCode::InternalError),
message: Some("Failed to build mail message.".into()),
});
return Err((ErrorCode::InternalError, e.to_string()));
}
};
state.status_store.update_status(&req_id, &auth.key_id, StatusUpdate {
status: SubmissionStatus::SmtpSubmissionStarted,
code: None,
message: Some("Submitting to SMTP.".into()),
});
let recipient_domain = validated.to.first()
.and_then(|a| a.split('@').nth(1)).unwrap_or("unknown");
let effective_mask = cfg.security.api_keys.iter()
.find(|k| k.id == auth.key_id)
.and_then(|k| k.mask_recipient)
.unwrap_or(cfg.logging.mask_recipient);
let logged_domain = if effective_mask { "***" } else { recipient_domain };
let smtp_start = std::time::Instant::now();
let smtp_res = if cfg.smtp.mode == "pipe" {
smtp::submit_pipe(message, &cfg.smtp.pipe_command, cfg.smtp.submission_timeout_seconds).await
} else {
smtp::submit(&state.smtp, message, cfg.smtp.submission_timeout_seconds).await
};
state.metrics.observe_smtp_duration(smtp_start.elapsed().as_secs_f64());
match smtp_res {
Ok(()) => {
state.status_store.update_status(&req_id, &auth.key_id, StatusUpdate {
status: SubmissionStatus::SmtpAccepted,
code: None,
message: Some("Accepted by SMTP server.".into()),
});
state.metrics.inc_smtp_ok();
tracing::info!(
event = "smtp_submitted",
key_id = %auth.key_id,
recipient_domain = logged_domain,
index,
request_id = %req_id,
);
Ok(())
}
Err(e) => {
state.metrics.inc_smtp_error();
tracing::error!(
event = "smtp_failure",
key_id = %auth.key_id,
recipient_domain = logged_domain,
index,
error = %e,
);
state.status_store.update_status(&req_id, &auth.key_id, StatusUpdate {
status: SubmissionStatus::SmtpFailed,
code: Some(ErrorCode::SmtpUnavailable),
message: Some("SMTP server unavailable.".into()),
});
Err((ErrorCode::SmtpUnavailable, "SMTP server unavailable.".into()))
}
}
}
fn record_rejected(
state: &Arc<AppState>,
req_id: &RequestId,
key_id: &str,
code: ErrorCode,
msg: &str,
ttl: u64,
) {
state.status_store.put(SubmissionStatusRecord {
request_id: req_id.clone(),
key_id: key_id.to_string(),
status: SubmissionStatus::Rejected,
code: Some(code),
message: Some(msg.to_string()),
recipient_domains: vec![],
recipient_count: 0,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
expires_at: chrono::Utc::now()
+ chrono::Duration::seconds(ttl as i64),
});
}
fn app_error_to_code(e: &AppError) -> ErrorCode {
match e {
AppError::Validation(_) => ErrorCode::ValidationFailed,
AppError::BadRequest => ErrorCode::BadRequest,
AppError::PayloadTooLarge => ErrorCode::PayloadTooLarge,
AppError::UnsupportedMediaType => ErrorCode::UnsupportedMediaType,
AppError::RateLimited { .. } => ErrorCode::RateLimited,
AppError::Forbidden => ErrorCode::Forbidden,
AppError::SmtpUnavailable => ErrorCode::SmtpUnavailable,
_ => ErrorCode::InternalError,
}
}
fn error_code_str(c: &ErrorCode) -> String {
serde_json::to_value(c)
.ok()
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_default()
}