use crate::{delivery, processing, ProcessMessage, Server};
use anyhow::Context;
use vsmtp_config::{Config, DnsResolvers};
use vsmtp_delivery::Sender;
use vsmtp_rule_engine::RuleEngine;
/// Build a dedicated multi-threaded tokio runtime and drive `future` on it
/// from a freshly spawned OS thread.
///
/// # Arguments
///
/// * `sender` - notified with `()` once the runtime has stopped driving `future`.
/// * `name` - prefix of the thread names: worker threads are called
///   `{name}-child`, the thread blocking on the runtime is called `{name}-main`.
/// * `worker_thread_count` - number of worker threads of the runtime.
/// * `future` - the task driven by the runtime.
/// * `timeout` - when set, `future` is driven for at most this duration.
///
/// # Returns
///
/// The handle of the `{name}-main` thread. That thread yields an error when
/// the notification could not be sent through `sender`.
///
/// # Errors
///
/// * the tokio runtime could not be built.
/// * the `{name}-main` thread could not be spawned.
fn init_runtime<F>(
    sender: tokio::sync::mpsc::Sender<()>,
    name: impl Into<String>,
    worker_thread_count: usize,
    future: F,
    timeout: Option<std::time::Duration>,
) -> anyhow::Result<std::thread::JoinHandle<anyhow::Result<()>>>
where
    F: std::future::Future<Output = ()> + Send + 'static,
{
    let name = name.into();

    // Each lifecycle callback must own its data, hence one copy of the name per callback.
    let name_park = name.clone();
    let name_unpark = name.clone();
    let name_start = name.clone();
    let name_stop = name.clone();

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(worker_thread_count)
        .enable_all()
        .thread_name(format!("{name}-child"))
        .on_thread_park(move || tracing::trace!("{}-child goes idle", name_park))
        .on_thread_unpark(move || tracing::trace!("{}-child starts executing tasks", name_unpark))
        .on_thread_start(move || tracing::trace!("{}-child start", name_start))
        .on_thread_stop(move || tracing::trace!("{}-child stop", name_stop))
        .build()?;

    std::thread::Builder::new()
        .name(format!("{name}-main"))
        .spawn(move || -> anyhow::Result<()> {
            // `name` is not used after this point: move it instead of cloning it.
            let name_rt = name;

            runtime.block_on(async move {
                tracing::info!(name = name_rt, "Runtime started successfully.");

                match timeout {
                    Some(duration) => {
                        // The timeout expiring is the expected outcome. A task finishing
                        // early is only reported: panicking here would skip the
                        // notification sent below.
                        if tokio::time::timeout(duration, future).await.is_ok() {
                            tracing::warn!(
                                name = name_rt,
                                "Runtime task completed before the timeout."
                            );
                        }
                    }
                    None => future.await,
                }
            });

            // Tell the owner of the channel that this runtime is done.
            sender.blocking_send(())?;
            Ok(())
        })
        .map_err(anyhow::Error::new)
}
/// Start the delivery, processing and receiver runtimes, then block the
/// calling thread until one of them stops or a `SIGTERM` / `SIGINT` is caught.
///
/// # Arguments
///
/// * `config` - the server configuration, shared between all the runtimes.
/// * `sockets` - the listeners handed over, untouched, to `Server::listen_and_serve`.
/// * `timeout` - when set, each runtime drives its task for at most this duration.
///
/// # Errors
///
/// * the queue manager could not be initialized.
/// * the dns resolvers could not be built from the configuration.
/// * the rule engine could not be created.
/// * one of the three runtimes could not be built or its thread could not be spawned.
/// * the signal handlers could not be registered.
#[allow(clippy::module_name_repetitions)]
pub fn start_runtime(
    config: Config,
    sockets: (
        Vec<std::net::TcpListener>,
        Vec<std::net::TcpListener>,
        Vec<std::net::TcpListener>,
    ),
    timeout: Option<std::time::Duration>,
) -> anyhow::Result<()> {
    let config = std::sync::Arc::new(config);

    // Every runtime thread and the signal handler own a clone of the sending half;
    // the first message received on the other half unblocks this function.
    let mut error_handler = tokio::sync::mpsc::channel::<()>(3);

    let (delivery_channel, working_channel) = (
        tokio::sync::mpsc::channel::<ProcessMessage>(config.server.queues.delivery.channel_size),
        tokio::sync::mpsc::channel::<ProcessMessage>(config.server.queues.working.channel_size),
    );

    let queue_manager =
        <vqueue::fs::QueueManager as vqueue::GenericQueueManager>::init(config.clone())?;

    let resolvers = std::sync::Arc::new(
        DnsResolvers::from_config(&config).context("could not initialize dns")?,
    );

    let rule_engine = std::sync::Arc::new(RuleEngine::new(
        config.clone(),
        resolvers.clone(),
        queue_manager.clone(),
    )?);

    let sender = std::sync::Arc::new(Sender::default());

    // The delivery runtime owns the receiving half of the delivery channel.
    let _tasks_delivery = init_runtime(
        error_handler.0.clone(),
        "delivery",
        config.server.system.thread_pool.delivery,
        delivery::start(
            config.clone(),
            rule_engine.clone(),
            resolvers,
            queue_manager.clone(),
            delivery_channel.1,
            sender,
        ),
        timeout,
    )?;

    // The processing runtime owns the receiving half of the working channel
    // and a sending half of the delivery channel.
    let _tasks_processing = init_runtime(
        error_handler.0.clone(),
        "processing",
        config.server.system.thread_pool.processing,
        processing::start(
            rule_engine.clone(),
            queue_manager.clone(),
            working_channel.1,
            delivery_channel.0.clone(),
        ),
        timeout,
    )?;

    // A failure to start the receiver runtime is propagated like the two above,
    // instead of being silently dropped.
    let _tasks_receiver = init_runtime(
        error_handler.0.clone(),
        "receiver",
        config.server.system.thread_pool.receiver,
        async move {
            let server = match Server::new(
                config.clone(),
                rule_engine.clone(),
                queue_manager.clone(),
                working_channel.0.clone(),
                delivery_channel.0.clone(),
            ) {
                Ok(server) => server,
                Err(error) => {
                    tracing::error!(%error, "Receiver build failure.");
                    return;
                }
            };
            if let Err(error) = server.listen_and_serve(sockets).await {
                tracing::error!(%error, "Receiver failure.");
            }
        },
        timeout,
    )?;

    let error_handler_sig = error_handler.0.clone();
    let mut signals = signal_hook::iterator::Signals::new([
        signal_hook::consts::SIGTERM,
        signal_hook::consts::SIGINT,
    ])?;

    // Dedicated thread turning each caught signal into a stop notification.
    let _signal_handler = std::thread::spawn(move || {
        for sig in signals.forever() {
            tracing::warn!(signal = sig, "Stopping vSMTP server.");
            error_handler_sig
                .blocking_send(())
                .expect("failed to send terminating instruction");
        }
    });

    // Block until a runtime reports that it stopped or a signal is caught.
    error_handler.1.blocking_recv();

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use vsmtp_test::config;

    /// Smoke test: start every runtime with the local test configuration and
    /// let them run for 100 milliseconds.
    #[test]
    fn basic() {
        // Bind one listener on the given address and wrap it in a list.
        let listen = |addr: &str| vec![std::net::TcpListener::bind(addr).unwrap()];

        let test_config = config::local_test();
        let sockets = (
            listen("0.0.0.0:22001"),
            listen("0.0.0.0:22002"),
            listen("0.0.0.0:22003"),
        );
        let run_for = std::time::Duration::from_millis(100);

        start_runtime(test_config, sockets, Some(run_for)).unwrap();
    }
}