use tokio_rustls::rustls;
use vqueue::GenericQueueManager;
use vsmtp_common::{
status::Status, Address, ContextFinished, Domain, Reply, Stage, TransactionType,
};
use vsmtp_config::Config;
use vsmtp_delivery::Deliver;
use vsmtp_mail_parser::{MailParser, MessageBody};
use vsmtp_protocol::{
AcceptArgs, AuthArgs, AuthError, CallbackWrap, EhloArgs, Error, HeloArgs, MailFromArgs,
RcptToArgs, ReceiverContext,
};
use vsmtp_rule_engine::{ExecutionStage, RuleEngine, RuleState};
use crate::scheduler;
/// Per-connection implementation of [`vsmtp_protocol::ReceiverHandler`].
///
/// The handler owns the rule-engine state of the connection and the shared
/// services it needs while answering SMTP commands.
///
/// `Parser` is the mail parser type produced by `ParserFactory`.
pub struct Handler<
    Parser: MailParser + Send + Sync,
    ParserFactory: Fn() -> Parser + Send + Sync,
> {
    /// Main rule state of the connection, created by
    /// `RuleEngine::spawn_at_connect` in [`Handler::new`].
    pub(super) state: std::sync::Arc<RuleState>,
    /// Secondary rule state, created lazily in `on_rcpt_to` when the sender
    /// and the recipient share the same handled domain; dropped on `RSET`.
    pub(super) state_internal: Option<std::sync::Arc<RuleState>>,
    /// Status handed by mutable reference to `RuleEngine::run_when` on every
    /// rule evaluation.
    // NOTE(review): presumably remembers a status that short-circuits later
    // stages -- confirm against the rule engine.
    pub(super) skipped: Option<Status>,
    /// Shared server configuration (recipient limit, error delay, ...).
    pub(super) config: std::sync::Arc<Config>,
    /// TLS server configuration, if any; not read in this section of the file.
    pub(super) rustls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
    /// Rule engine used to evaluate every SMTP stage.
    pub(super) rule_engine: std::sync::Arc<RuleEngine>,
    /// Queue manager; not read in this section of the file.
    pub(super) queue_manager: std::sync::Arc<dyn GenericQueueManager>,
    /// Factory producing a fresh `Parser`; not called in this section of the file.
    pub(super) message_parser_factory: ParserFactory,
    /// Scheduler emitter; not read in this section of the file.
    pub(super) emitter: std::sync::Arc<scheduler::Emitter>,
}
impl<Parser: MailParser + Send + Sync, ParserFactory: Fn() -> Parser + Send + Sync>
    Handler<Parser, ParserFactory>
{
    /// Creates the handler of a freshly accepted connection.
    ///
    /// The connection metadata (`client_addr`, `server_addr`, `server_name`,
    /// `timestamp`, `uuid`) is only used to spawn the initial rule state via
    /// `RuleEngine::spawn_at_connect`; every other argument is stored as is.
    /// The handler starts without an internal state and without a skipped
    /// status.
    #[must_use]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: std::sync::Arc<Config>,
        rustls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
        rule_engine: std::sync::Arc<RuleEngine>,
        queue_manager: std::sync::Arc<dyn GenericQueueManager>,
        message_parser_factory: ParserFactory,
        emitter: std::sync::Arc<scheduler::Emitter>,
        client_addr: std::net::SocketAddr,
        server_addr: std::net::SocketAddr,
        server_name: Domain,
        timestamp: time::OffsetDateTime,
        uuid: uuid::Uuid,
    ) -> Self {
        // The state must be spawned before `rule_engine` is moved into the struct.
        let state =
            rule_engine.spawn_at_connect(client_addr, server_addr, server_name, timestamp, uuid);

        Self {
            state,
            state_internal: None,
            skipped: None,
            config,
            rustls_config,
            rule_engine,
            queue_manager,
            message_parser_factory,
            emitter,
        }
    }
}
#[async_trait::async_trait]
impl<Parser: MailParser + Send + Sync, ParserFactory: Fn() -> Parser + Send + Sync>
vsmtp_protocol::ReceiverHandler for Handler<Parser, ParserFactory>
{
/// Returns the SASL callback wrapper for this handler.
///
/// Pure delegation to `generate_sasl_callback_inner`, which is defined
/// outside this section of the file.
fn generate_sasl_callback(&self) -> CallbackWrap {
self.generate_sasl_callback_inner()
}
/// Trait hook for an accepted connection; returns the reply to send.
///
/// Delegates to `on_accept_inner` (defined outside this section), lending
/// it `args` by reference. Nothing is awaited here.
async fn on_accept(&mut self, ctx: &mut ReceiverContext, args: AcceptArgs) -> Reply {
self.on_accept_inner(ctx, &args)
}
/// Trait hook receiving the parameters of a TLS handshake; returns the
/// reply to send.
///
/// All arguments (SNI, protocol version, cipher suite, peer certificates,
/// ALPN protocol) are forwarded unchanged to `on_post_tls_handshake_inner`,
/// which is defined outside this section of the file.
async fn on_post_tls_handshake(
&mut self,
sni: Option<String>,
protocol_version: rustls::ProtocolVersion,
cipher_suite: rustls::CipherSuite,
peer_certificates: Option<Vec<rustls::Certificate>>,
alpn_protocol: Option<Vec<u8>>,
) -> Reply {
self.on_post_tls_handshake_inner(
sni,
protocol_version,
cipher_suite,
peer_certificates,
alpn_protocol,
)
}
/// Trait hook for the `STARTTLS` command; returns the reply to send.
///
/// Pure delegation to `on_starttls_inner`, defined outside this section.
async fn on_starttls(&mut self, ctx: &mut ReceiverContext) -> Reply {
self.on_starttls_inner(ctx)
}
/// Trait hook for the `AUTH` command; may return a reply.
///
/// Pure delegation to `on_auth_inner`, defined outside this section.
// NOTE(review): the meaning of a `None` result is decided by the protocol
// layer and is not visible from here.
async fn on_auth(&mut self, ctx: &mut ReceiverContext, args: AuthArgs) -> Option<Reply> {
self.on_auth_inner(ctx, args)
}
/// Trait hook receiving the outcome of an authentication exchange;
/// returns the reply to send.
///
/// `result` is forwarded unchanged to `on_post_auth_inner`, which is
/// defined outside this section of the file.
async fn on_post_auth(
&mut self,
ctx: &mut ReceiverContext,
result: Result<(), AuthError>,
) -> Reply {
self.on_post_auth_inner(ctx, result)
}
/// Trait hook for the `HELO` command; returns the reply to send.
///
/// Pure delegation to `on_helo_inner`, defined outside this section.
async fn on_helo(&mut self, ctx: &mut ReceiverContext, args: HeloArgs) -> Reply {
self.on_helo_inner(ctx, args)
}
/// Trait hook for the `EHLO` command; returns the reply to send.
///
/// Pure delegation to `on_ehlo_inner`, defined outside this section.
async fn on_ehlo(&mut self, ctx: &mut ReceiverContext, args: EhloArgs) -> Reply {
self.on_ehlo_inner(ctx, args)
}
async fn on_mail_from(&mut self, ctx: &mut ReceiverContext, args: MailFromArgs) -> Reply {
let reverse_path = args
.reverse_path
.map(|reverse_path| reverse_path.parse().expect("handle invalid mailbox"));
self.state
.context()
.write()
.expect("state poisoned")
.to_mail_from(reverse_path)
.expect("bad state");
match self
.rule_engine
.run_when(&self.state, &mut self.skipped, ExecutionStage::MailFrom)
{
Status::Faccept(reply) | Status::Accept(reply) => reply,
Status::Quarantine(_) | Status::Next | Status::DelegationResult => {
"250 Ok\r\n".parse::<Reply>().unwrap()
}
Status::Deny(reply) => {
ctx.deny();
reply
}
Status::Delegated(_) => unreachable!(),
}
}
/// Handles the `RCPT TO` command.
///
/// Steps:
/// 1. Refuse with `452` once the main context already holds
///    `config.server.smtp.rcpt_count_max` recipients.
/// 2. Parse the forward path; an invalid mailbox is answered with `553`
///    without touching any state.
/// 3. Classify the recipient from the sender / recipient domains:
///    * **internal** -- sender and recipient share the same handled domain:
///      the recipient is stored in `state_internal` (created on first use
///      from a copy of the current context and message), which is typed
///      `Internal`, while the main context is typed `Outgoing`;
///    * **outgoing** -- the sender domain is handled: the recipient is added
///      to the main context, typed `Outgoing`;
///    * **incoming** -- anything else: the recipient is added to the main
///      context, typed `Incoming` (carrying the recipient domain when it is
///      a handled one).
/// 4. Run the rule engine for `ExecutionStage::RcptTo` on the internal state
///    for an internal recipient, on the main state otherwise.
///
/// # Panics
///
/// Panics if a lock is poisoned or if a context is not in a stage accepting
/// recipients (`"bad state"`).
// NOTE(review): the recipient cap only counts recipients of the main state;
// recipients stored in `state_internal` are not included.
#[allow(clippy::too_many_lines)]
async fn on_rcpt_to(&mut self, ctx: &mut ReceiverContext, args: RcptToArgs) -> Reply {
    if self
        .state
        .context()
        .read()
        .expect("state poisoned")
        .forward_paths()
        .map_or(0, Vec::len)
        >= self.config.server.smtp.rcpt_count_max
    {
        return "452 Requested action not taken: too many recipients\r\n"
            .parse::<Reply>()
            .unwrap();
    }

    // The forward path comes straight from the client: a malformed mailbox
    // must be answered with an SMTP error instead of aborting the task.
    let forward_path = match args.forward_path.parse::<Address>() {
        Ok(forward_path) => forward_path,
        Err(_) => {
            tracing::warn!("RCPT TO rejected: forward path is not a valid mailbox");
            return "553 Requested action not taken: mailbox name not allowed\r\n"
                .parse::<Reply>()
                .unwrap();
        }
    };

    let is_internal = {
        // Named differently from the `ctx` parameter, which is still needed
        // after this block to deny the connection.
        let shared_ctx = self.state.context();
        let mut mail_ctx = shared_ctx.write().expect("state poisoned");

        let reverse_path = mail_ctx.reverse_path().expect("bad state").clone();
        let reverse_path_domain = reverse_path.as_ref().map(Address::domain);

        let (is_outgoing, is_handled) = (
            reverse_path.as_ref().map_or(false, |reverse_path| {
                self.rule_engine.is_handled_domain(&reverse_path.domain())
            }),
            self.rule_engine.is_handled_domain(&forward_path.domain()),
        );

        match (is_outgoing, is_handled) {
            (true, true) if Some(forward_path.domain()) == reverse_path_domain => {
                tracing::debug!(
                    "INTERNAL: forward and reverse path domain are both: {}",
                    forward_path.domain()
                );

                if self.state_internal.is_none() {
                    tracing::debug!("No previous `internal_state`. Copying...");

                    // The internal transaction starts from a copy of the
                    // current context, with its own message id and an empty
                    // recipient list.
                    let mut ctx_internal = mail_ctx.clone();
                    ctx_internal.generate_message_id().expect("bad state");
                    if let Ok(rcpt) = ctx_internal.forward_paths_mut() {
                        rcpt.clear();
                    }

                    self.state_internal = Some(
                        self.rule_engine.spawn_finished(
                            ctx_internal,
                            self.state
                                .message()
                                .read()
                                .expect("message poisoned")
                                .clone(),
                        ),
                    );
                }

                let internal_ctx = self
                    .state_internal
                    .as_ref()
                    .expect("has been set above")
                    .context();
                let mut internal_guard = internal_ctx.write().expect("state poisoned");

                internal_guard
                    .add_forward_path(
                        forward_path,
                        std::sync::Arc::new(Deliver::new(
                            self.rule_engine.srv().resolvers.get_resolver_root(),
                            self.config.clone(),
                        )),
                    )
                    .expect("bad state");
                internal_guard
                    .set_transaction_type(TransactionType::Internal)
                    .expect("bad state");

                mail_ctx
                    .set_transaction_type(TransactionType::Outgoing {
                        domain: reverse_path.expect("none-null reverse path").domain(),
                    })
                    .expect("bad state");

                true
            }
            (true, _) => {
                tracing::debug!(
                    "OUTGOING: reverse:${} => forward:${}",
                    reverse_path_domain.map_or_else(|| "none".to_string(), |d| d.to_string()),
                    forward_path.domain()
                );

                mail_ctx
                    .add_forward_path(
                        forward_path,
                        std::sync::Arc::new(Deliver::new(
                            self.rule_engine.srv().resolvers.get_resolver_root(),
                            self.config.clone(),
                        )),
                    )
                    .expect("bad state");
                mail_ctx
                    .set_transaction_type(reverse_path.as_ref().map_or(
                        TransactionType::Incoming(None),
                        |reverse_path| TransactionType::Outgoing {
                            domain: reverse_path.domain(),
                        },
                    ))
                    .expect("bad state");

                false
            }
            (false, forward_path_is_handled) => {
                tracing::debug!(
                    "INCOMING: reverse:${:?} => forward:${}",
                    reverse_path,
                    forward_path.domain()
                );

                mail_ctx
                    .set_transaction_type(TransactionType::Incoming(
                        if forward_path_is_handled {
                            Some(forward_path.domain())
                        } else {
                            None
                        },
                    ))
                    .expect("bad state");
                mail_ctx
                    .add_forward_path(
                        forward_path,
                        std::sync::Arc::new(Deliver::new(
                            self.rule_engine.srv().resolvers.get_resolver_root(),
                            self.config.clone(),
                        )),
                    )
                    .expect("bad state");

                false
            }
        }
    };

    // Rules are evaluated against the state that received the recipient.
    let state = match self.state_internal.as_mut() {
        Some(state_internal) if is_internal => state_internal,
        _ => &mut self.state,
    };

    match self
        .rule_engine
        .run_when(state, &mut self.skipped, ExecutionStage::RcptTo)
    {
        Status::Faccept(reply) | Status::Accept(reply) => reply,
        Status::Quarantine(_) | Status::Next | Status::DelegationResult => {
            "250 Ok\r\n".parse::<Reply>().unwrap()
        }
        Status::Deny(reply) => {
            ctx.deny();
            reply
        }
        Status::Delegated(_) => unreachable!(),
    }
}
/// Handles the `RSET` command.
///
/// Resets the context of the main state, discards the internal state (if
/// any) and answers `250 Ok`.
///
/// # Panics
///
/// Panics if the context lock is poisoned.
async fn on_rset(&mut self) -> Reply {
    {
        // Keep the write lock only for the duration of the reset.
        let shared_ctx = self.state.context();
        let mut mail_ctx = shared_ctx.write().expect("state poisoned");
        mail_ctx.reset();
    }

    self.state_internal = None;

    "250 Ok\r\n".parse::<Reply>().unwrap()
}
/// Trait hook receiving the message as a stream of byte chunks.
///
/// Awaits `on_message_inner` (defined outside this section) and returns its
/// result unchanged: the reply to send and, optionally, a list of
/// `(ContextFinished, MessageBody)` pairs.
async fn on_message(
&mut self,
ctx: &mut ReceiverContext,
stream: impl tokio_stream::Stream<Item = Result<Vec<u8>, Error>> + Send + Unpin,
) -> (Reply, Option<Vec<(ContextFinished, MessageBody)>>) {
self.on_message_inner(ctx, stream).await
}
/// Trait hook receiving a finished context together with its message body;
/// may return a reply.
///
/// Awaits `on_message_completed_inner` (defined outside this section) and
/// returns its result unchanged.
async fn on_message_completed(
&mut self,
ctx: ContextFinished,
msg: MessageBody,
) -> Option<Reply> {
self.on_message_completed_inner(ctx, msg).await
}
/// Trait hook for a hard error: marks the connection as denied and returns
/// `reply` extended with a `451 Too many errors from the client` reply.
async fn on_hard_error(&mut self, ctx: &mut ReceiverContext, reply: Reply) -> Reply {
    ctx.deny();

    let too_many_errors = "451 Too many errors from the client\r\n"
        .parse::<Reply>()
        .unwrap();
    reply.extended(&too_many_errors)
}
/// Trait hook for a soft error: waits for the configured
/// `server.smtp.error.delay`, then returns `reply` unchanged.
async fn on_soft_error(&mut self, _: &mut ReceiverContext, reply: Reply) -> Reply {
    let delay = self.config.server.smtp.error.delay;
    tokio::time::sleep(delay).await;
    reply
}
/// Returns the stage reported by the context of the main state.
///
/// Only a shared (read) lock is taken: the stage is merely queried, so
/// there is no reason to exclude concurrent readers of the context.
///
/// # Panics
///
/// Panics if the context lock is poisoned.
fn get_stage(&self) -> Stage {
    self.state
        .context()
        .read()
        .expect("state poisoned")
        .stage()
}
}