use crate::{auth, config, net, State};
use futures::FutureExt;
use std::{fs, process};
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::runtime as rt;
use tokio::signal::unix;
use tokio::sync::{mpsc, Notify};
use tokio::task;
/// Instruction the control task sends to a running listener after a configuration reload.
///
/// It travels over the command channel whose receiving half is handed to `net::listen`.
pub enum Command {
    /// The binding is now configured without TLS.
    UsePlain,
    /// The binding is now configured with TLS, using this acceptor.
    UseTls(Arc<tokio_tls::TlsAcceptor>),
}
/// A listener prepared during a configuration reload but not spawned yet.
///
/// The rehash logic either spawns `future` (new address, or the previous listener no longer
/// accepts commands) or drops it and forwards `acceptor` to the listener already running on
/// `address`.
struct LoadedBinding<F> {
    /// Address this listener is for.
    address: SocketAddr,
    /// TLS acceptor for the binding, `None` for a plain-text binding.
    acceptor: Option<Arc<tokio_tls::TlsAcceptor>>,
    /// Sending half of the listener's command channel.
    handle: mpsc::Sender<Command>,
    /// The listener task, as returned by `net::listen`.
    future: F,
}
/// Reads the configuration file at `config_path` and sets up the SASL provider it selects.
///
/// # Errors
///
/// A configuration file that cannot be read is logged and its error returned. A SASL backend
/// that fails to initialize is NOT an error: the failure is logged as a warning and
/// `auth::DummyProvider` is used instead.
fn load_config(config_path: &str) -> config::Result<(config::Config, Box<dyn auth::Provider>)> {
    let cfg = match config::Config::from_file(config_path) {
        Ok(cfg) => cfg,
        Err(err) => {
            log::error!("Failed to read {:?}: {}", config_path, err);
            return Err(err.into());
        }
    };

    let backend = cfg.sasl_backend;
    let provider: Box<dyn auth::Provider> =
        match auth::choose_provider(backend, cfg.db_url.clone()) {
            Ok(provider) => provider,
            Err(err) => {
                log::warn!("Failed to initialize the {} SASL backend: {}", backend, err);
                Box::new(auth::DummyProvider)
            }
        };

    Ok((cfg, provider))
}
/// Builds the multi-threaded tokio runtime with the I/O and timer drivers enabled.
///
/// `workers` sets the number of core worker threads. `0` leaves the builder's default.
/// Exits the process with status 1 if the runtime cannot be built.
fn create_runtime(workers: usize) -> rt::Runtime {
    let mut builder = rt::Builder::new();
    if workers > 0 {
        builder.core_threads(workers);
    }
    builder.threaded_scheduler().enable_io().enable_time();

    match builder.build() {
        Ok(runtime) => runtime,
        Err(err) => {
            log::error!("Failed to start the tokio runtime: {}", err);
            process::exit(1);
        }
    }
}
/// Spawns one listener per configured binding on `runtime` (startup path).
///
/// Each listener gets a clone of `shared` and `stop`, plus the receiving half of a fresh
/// command channel. Returns each binding's address together with the sending half of that
/// channel.
///
/// Exits the process with status 1 if a binding's TLS identity cannot be loaded.
fn load_bindings(bindings: Vec<config::Binding>, shared: &State, stop: &mpsc::Sender<SocketAddr>,
                 runtime: &mut rt::Runtime) -> Vec<(SocketAddr, mpsc::Sender<Command>)>
{
    let mut store = net::TlsIdentityStore::default();
    let mut handles = Vec::with_capacity(bindings.len());

    for config::Binding { address, tls_identity } in bindings {
        // At startup, an unusable TLS identity is fatal.
        let acceptor = match tls_identity {
            Some(identity_path) => match store.acceptor(identity_path) {
                Ok(acceptor) => Some(acceptor),
                Err(_) => process::exit(1),
            },
            None => None,
        };

        let (handle, commands) = mpsc::channel(8);
        let server = net::listen(address, shared.clone(), acceptor, stop.clone(), commands);
        handles.push((address, handle));
        runtime.spawn(server);
    }

    handles
}
/// Owner of the server's listeners. Reloads the configuration on demand.
pub struct Control {
    /// Path of the configuration file, re-read on every reload.
    config_path: String,
    /// Server state; a clone is handed to every listener.
    shared: crate::State,
    /// Cloned into every listener. A listener presumably sends its address here when it stops;
    /// `run` then forgets that binding.
    stop: mpsc::Sender<SocketAddr>,
    /// Receiving half of `stop`.
    failures: mpsc::Receiver<SocketAddr>,
    /// Notified to request a configuration reload; the same `Notify` is given to `State`.
    rehash: Arc<Notify>,
    /// Address and command sender of every listener believed to be running.
    bindings: Vec<(SocketAddr, mpsc::Sender<Command>)>,
}
impl Control {
    /// Loads the configuration at `config_path`, builds the tokio runtime and spawns one
    /// listener per configured binding on it.
    ///
    /// Returns the runtime together with the control handle. NOTE(review): the caller is
    /// presumably expected to drive `Control::run` on that runtime — confirm at the call site.
    ///
    /// Exits the process with status 1 if the configuration cannot be loaded, the runtime
    /// cannot be built, or a TLS identity cannot be loaded.
    pub fn new<S>(config_path: S) -> (rt::Runtime, Self)
        where S: Into<String>,
    {
        let config_path = config_path.into();
        let (stop, failures) = mpsc::channel(8);
        let rehash = Arc::new(Notify::new());
        let (cfg, auth_provider) = load_config(&config_path).unwrap_or_else(|_| process::exit(1));
        let mut runtime = create_runtime(cfg.workers);
        let shared = State::new(cfg.state, auth_provider, rehash.clone());
        let bindings = load_bindings(cfg.bindings, &shared, &stop, &mut runtime);
        (runtime, Self { config_path, shared, stop, failures, rehash, bindings })
    }

    /// Runs the control loop until no listener is left.
    ///
    /// The configuration is reloaded when `rehash` is notified, or on unix when the process
    /// receives `SIGUSR1`. When an address arrives on the failure channel, the matching
    /// binding is forgotten. Once the last binding is gone, the loop logs an error and
    /// returns.
    ///
    /// Exits the process with status 1 if the signal handler cannot be installed.
    pub async fn run(self) {
        #[cfg(unix)]
        let mut signals = unix::signal(unix::SignalKind::user_defined1()).unwrap_or_else(|err| {
            log::error!("Cannot listen for signals to reload the configuration: {}", err);
            process::exit(1);
        });
        #[cfg(not(unix))]
        let signals = crate::util::PendingStream;

        let Self { config_path, shared, stop, mut failures, rehash, mut bindings } = self;

        loop {
            futures::select! {
                addr = failures.recv().fuse() => match addr {
                    Some(addr) => {
                        if let Some(i) = bindings.iter().position(|(bound, _)| *bound == addr) {
                            bindings.swap_remove(i);
                        }
                        // `stop` stays alive in this task for the whole loop, so `failures`
                        // can never yield `None`. The "no listener left" condition therefore
                        // has to be detected from `bindings` itself.
                        if bindings.is_empty() {
                            log::error!("No binding left, exiting.");
                            return;
                        }
                    }
                    None => {
                        log::error!("No binding left, exiting.");
                        return;
                    }
                },
                _ = rehash.notified().fuse() => {
                    do_rehash(&config_path, &shared, &stop, &mut bindings).await;
                },
                _ = signals.recv().fuse() => {
                    do_rehash(&config_path, &shared, &stop, &mut bindings).await;
                },
            }
        }
    }
}
/// Reloads the configuration from `config_path` and applies it to the running server.
///
/// - The file I/O runs inside `task::block_in_place`.
/// - If the configuration cannot be reloaded, nothing is changed.
/// - Otherwise, entries whose address is no longer configured are removed from `bindings`.
///   Their command sender is dropped; presumably this stops the listener, so confirm in
///   `net::listen`.
/// - Still-configured listeners are sent a [`Command`] carrying their (possibly new) TLS
///   acceptor.
/// - New addresses, and addresses whose listener no longer accepts commands, get a freshly
///   spawned listener.
/// - Finally, the shared state is updated with the new configuration and SASL provider.
async fn do_rehash(config_path: &str, shared: &State, stop: &mpsc::Sender<SocketAddr>,
                   bindings: &mut Vec<(SocketAddr, mpsc::Sender<Command>)>)
{
    log::info!("Reloading configuration from {:?}", config_path);

    // Reading the configuration and the MOTD is blocking file I/O.
    let (cfg, auth_provider, new_bindings) =
        match task::block_in_place(|| reload_config(config_path, shared, stop)) {
            Some(loaded) => loaded,
            None => return,
        };

    // Forget every running listener whose address is absent from the new configuration.
    let mut idx = 0;
    while let Some(&(old_address, _)) = bindings.get(idx) {
        let still_configured = new_bindings.iter().any(|fresh| fresh.address == old_address);
        if still_configured {
            idx += 1;
        } else {
            bindings.swap_remove(idx);
        }
    }

    for fresh in new_bindings {
        let running = bindings.iter().position(|(address, _)| *address == fresh.address);
        let reused = match running {
            Some(idx) => {
                let command = match fresh.acceptor {
                    Some(acceptor) => Command::UseTls(acceptor),
                    None => Command::UsePlain,
                };
                if bindings[idx].1.send(command).await.is_ok() {
                    true
                } else {
                    // The previous listener is gone: replace its entry with the new one.
                    bindings.swap_remove(idx);
                    false
                }
            }
            None => false,
        };
        if !reused {
            tokio::spawn(fresh.future);
            bindings.push((fresh.address, fresh.handle));
        }
    }

    shared.rehash(cfg.state, auth_provider).await;
    log::info!("Configuration reloaded");
}
/// Blocking half of a reload: reads the configuration, the SASL provider and the MOTD, and
/// prepares (without spawning) the listeners for the new bindings.
///
/// Returns `None` if the configuration cannot be loaded; `load_config` has already logged why.
/// On return, `cfg.state.motd_file` holds the contents of the file it named. If that file
/// cannot be read, it holds an empty string and a warning is logged.
fn reload_config(config_path: &str, shared: &State, stop: &mpsc::Sender<SocketAddr>)
    -> Option<(config::Config, Box<dyn auth::Provider>, Vec<LoadedBinding<impl Future<Output=()>>>)>
{
    let (mut cfg, auth_provider) = load_config(config_path).ok()?;

    let motd = fs::read_to_string(&cfg.state.motd_file).unwrap_or_else(|err| {
        log::warn!("Failed to read {:?}: {}", cfg.state.motd_file, err);
        String::new()
    });
    cfg.state.motd_file = motd;

    let listeners = reload_bindings(&cfg.bindings, shared, stop);
    Some((cfg, auth_provider, listeners))
}
/// Prepares, without spawning, one listener per configured binding (reload path).
///
/// Each prepared listener gets a clone of `shared` and `stop`, plus a fresh command channel.
/// Unlike `load_bindings`, a binding whose TLS identity cannot be loaded is skipped rather
/// than exiting the process.
fn reload_bindings(bindings: &[config::Binding], shared: &State, stop: &mpsc::Sender<SocketAddr>)
    -> Vec<LoadedBinding<impl Future<Output=()>>>
{
    let mut store = net::TlsIdentityStore::default();
    let mut loaded = Vec::with_capacity(bindings.len());

    for binding in bindings {
        let acceptor = match &binding.tls_identity {
            Some(identity_path) => match store.acceptor(identity_path) {
                Ok(acceptor) => Some(acceptor),
                Err(_) => continue,
            },
            None => None,
        };

        let (handle, commands) = mpsc::channel(8);
        let future = net::listen(binding.address, shared.clone(), acceptor.clone(),
                                 stop.clone(), commands);
        loaded.push(LoadedBinding { address: binding.address, acceptor, handle, future });
    }

    loaded
}