use std::{net::SocketAddr, sync::Arc};
use tokio::io::AsyncWriteExt;
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use crate::{
network::{AsyncReadWrite, tls_handshake},
smtp::inbound::{
Session, SessionConfig, SessionData, SessionError, SessionResult, SessionTlsMode,
SessionUpgrade,
},
};
/// Stateless namespace for the inbound connection entry points.
///
/// The type carries no data; its associated functions (see
/// `SessionManager::handle_connection`) are called without an instance.
pub struct SessionManager;
impl SessionManager {
    /// Drives a single inbound connection from start to close.
    ///
    /// Builds a [`Session`] around `stream` (keyed by the IP part of
    /// `remote_addr`) and runs it until it finishes:
    ///
    /// * `Ok(SessionUpgrade::No)` - the stream is shut down on a best-effort
    ///   basis.
    /// * `Ok(SessionUpgrade::StartTls)` - the session is handed to the TLS
    ///   upgrade path, which then owns it for the rest of the connection.
    /// * `Err(_)` - the error is logged at `info` level unless it is
    ///   [`SessionError::Quit`], then the session is shut down; a failure to
    ///   shut down is logged at `debug` level.
    ///
    /// The session is given a child of `shutdown_token`. This function never
    /// returns an error: every failure is logged and swallowed here.
    pub async fn handle_connection<S: AsyncReadWrite>(
        stream: S,
        remote_addr: SocketAddr,
        params: Arc<SessionConfig>,
        shutdown_token: CancellationToken,
    ) {
        let mut session = Session::new(remote_addr.ip(), stream, params);
        match session.handle(shutdown_token.child_token()).await {
            Ok(SessionUpgrade::No) => {
                // Best-effort close: the session already ended without an
                // error, so a failed shutdown is not worth reporting.
                session.stream.shutdown().await.ok();
            }
            Ok(SessionUpgrade::StartTls) => {
                Self::starttls(session, shutdown_token.child_token()).await;
            }
            Err(e) => {
                // `Quit` is reported through the error channel but is not
                // treated as a failure worth logging.
                if !matches!(e, SessionError::Quit) {
                    info!("{session}: error: {e:#}");
                }
                if let Err(e) = session.shutdown().await {
                    debug!("{session}: error closing connection: {e:#}");
                }
            }
        }
    }

    /// Upgrades `session` to TLS and runs the upgraded session to completion.
    ///
    /// A handshake failure is logged at `info` level and the connection is
    /// dropped. After a successful handshake the session is run again on the
    /// encrypted stream (with a child of `shutdown_token`) and is always
    /// closed afterwards: on success the TLS stream is shut down on a
    /// best-effort basis, mirroring the plaintext path; on error the same
    /// log-then-shutdown sequence as in `handle_connection` is used.
    async fn starttls<S: AsyncReadWrite>(session: Session<S>, shutdown_token: CancellationToken) {
        // `into_tls` consumes the session, so capture its display form first
        // for the handshake-failure log line.
        let session_name = session.to_string();
        let mut session = match session.into_tls().await {
            Ok(session) => session,
            Err(e) => {
                info!("{session_name}: TLS handshake failed: {e:#}");
                return;
            }
        };

        match session.handle(shutdown_token.child_token()).await {
            // Whatever upgrade value comes back, the connection is finished
            // here. Shut the TLS stream down rather than just dropping it, so
            // the success path closes the stream explicitly like the
            // plaintext path does.
            Ok(_) => {
                session.stream.shutdown().await.ok();
            }
            Err(e) => {
                if !matches!(e, SessionError::Quit) {
                    info!("{session}: error: {e:#}");
                }
                if let Err(e) = session.shutdown().await {
                    debug!("{session}: error closing connection: {e:#}");
                }
            }
        }
    }
}
impl<S: AsyncReadWrite> Session<S> {
    /// Consumes this session and returns it rebuilt on top of a TLS stream.
    ///
    /// The TLS configuration is taken from `cfg.tls_mode` and the handshake
    /// is performed on the current stream via `tls_handshake`. The returned
    /// session keeps `id`, `remote_ip`, `state`, `counters` and `cfg`, starts
    /// with `SessionData::default()`, and has `tls_info` populated from the
    /// handshake.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `tls_handshake`.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.tls_mode` is [`SessionTlsMode::Disabled`]; callers are
    /// expected to request the upgrade only when TLS is enabled.
    pub async fn into_tls(self) -> SessionResult<Session<TlsStream<S>>> {
        // Take ownership of the pieces that carry over to the new session.
        let Session {
            id,
            remote_ip,
            stream: plain_stream,
            state,
            counters,
            cfg,
            ..
        } = self;

        let handshake_config = match &cfg.tls_mode {
            SessionTlsMode::Disabled => {
                unreachable!("Session::into_tls() called with TLS disabled")
            }
            SessionTlsMode::Allowed(tls) | SessionTlsMode::Required(tls) => tls.clone(),
        };

        let (tls_stream, handshake_info) = tls_handshake(handshake_config, plain_stream).await?;

        let upgraded = Session {
            id,
            remote_ip,
            stream: tls_stream,
            state,
            // Per-session data gathered before the upgrade is not carried over.
            // NOTE(review): presumably to satisfy RFC 3207's requirement to
            // discard pre-TLS knowledge - confirm.
            data: SessionData::default(),
            counters,
            cfg,
            tls_info: Some(handshake_info),
        };
        Ok(upgraded)
    }
}