use std::sync::Arc;
use arc_swap::ArcSwap;
/// A hook that can be run after a configuration has been reloaded.
///
/// This file implements it for `()` (a no-op) and for every `Fn()` value.
pub trait ReloadCallback {
    /// Runs the hook.
    fn invoke(&self);
}
/// The unit type acts as the "no callback" placeholder: invoking it does nothing.
impl ReloadCallback for () {
    fn invoke(&self) {}
}
/// Any value callable as `Fn()` is a callback; invoking it simply calls it.
impl<F> ReloadCallback for F
where
    F: Fn(),
{
    fn invoke(&self) {
        self();
    }
}
/// A configuration type that can be built from scratch on demand.
pub trait ReloadableConfig: Sized {
    /// Error returned when building the configuration fails.
    type Error;

    /// Builds a fresh instance of the configuration.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the configuration cannot be built.
    fn build() -> Result<Self, Self::Error>;

    /// Builds the configuration and wraps it in a [`ReloadingConfig`] that has
    /// no update callback (`()`).
    ///
    /// # Errors
    ///
    /// Propagates the error produced by [`ReloadableConfig::build`].
    fn reloading() -> Result<ReloadingConfig<Self, ()>, Self::Error> {
        ReloadingConfig::<Self, ()>::build()
    }
}
/// Handle to a configuration value that can be replaced at runtime.
///
/// `T` is the configuration type and `F` is the callback invoked after a
/// successful reload.
#[derive(Debug)]
pub struct ReloadingConfig<T, F> {
    // Current configuration; the outer `Arc` is what clones of this handle share.
    config: Arc<ArcSwap<T>>,
    // Callback run by `reload` after a new configuration has been stored.
    on_update: F,
}
impl<T, F> Clone for ReloadingConfig<T, F>
where
F: Clone,
{
fn clone(&self) -> Self {
ReloadingConfig {
config: Arc::clone(&self.config),
on_update: self.on_update.clone(),
}
}
}
impl<T: ReloadableConfig> ReloadingConfig<T, ()> {
    /// Builds the initial configuration via `T::build` and wraps it in a
    /// handle without an update callback.
    ///
    /// # Errors
    ///
    /// Returns the error from `T::build` if the initial build fails.
    pub fn build() -> Result<Self, T::Error> {
        let initial = T::build()?;
        let config = Arc::new(ArcSwap::new(Arc::new(initial)));
        Ok(Self {
            config,
            on_update: (),
        })
    }
}
impl<T, F> ReloadingConfig<T, F> {
    /// Consumes this handle and returns one that uses `new` as its update
    /// callback, keeping the same underlying configuration storage.
    ///
    /// The previous callback is dropped.
    #[must_use]
    pub fn with_on_update<U>(self, new: U) -> ReloadingConfig<T, U> {
        // Move only the storage out; the old callback stays behind and is dropped.
        let ReloadingConfig { config, .. } = self;
        ReloadingConfig {
            config,
            on_update: new,
        }
    }

    /// Returns the currently stored configuration as an `Arc<T>`.
    #[must_use]
    pub fn load(&self) -> Arc<T> {
        self.config.load_full()
    }
}
impl<T, F> ReloadingConfig<T, F>
where
    T: ReloadableConfig,
    F: ReloadCallback,
{
    /// Rebuilds the configuration and, on success, stores it and then invokes
    /// the update callback.
    ///
    /// # Errors
    ///
    /// Returns the error from `T::build`; in that case the stored
    /// configuration is left untouched and the callback is not invoked.
    pub fn reload(&self) -> Result<(), T::Error> {
        T::build().map(|fresh| {
            // Publish the new value first so the callback observes it via `load`.
            self.config.store(Arc::new(fresh));
            self.on_update.invoke();
        })
    }
}
#[cfg(feature = "signal")]
impl<T, F> ReloadingConfig<T, F>
where
    T: ReloadableConfig + Send + Sync + 'static,
    F: ReloadCallback + Clone + Send + Sync + 'static,
{
    /// Spawns a thread that calls `reload` every time `SIGHUP` is received.
    ///
    /// The thread operates on a clone of this handle. A failed reload is
    /// logged with `tracing::error!` when the `tracing` feature is enabled and
    /// is otherwise discarded; in both cases the thread keeps listening.
    ///
    /// # Errors
    ///
    /// Returns the `std::io::Error` from `Signals::new` if the signal
    /// listener cannot be created.
    pub fn spawn_signal_handler(&self) -> Result<std::thread::JoinHandle<()>, std::io::Error>
    where
        <T as ReloadableConfig>::Error: std::fmt::Display,
    {
        use signal_hook::{consts::SIGHUP, iterator::Signals};

        let mut signals = Signals::new([SIGHUP])?;
        let handle = self.clone();

        let worker = std::thread::spawn(move || {
            for received in &mut signals {
                if received != SIGHUP {
                    continue;
                }
                if let Err(err) = handle.reload() {
                    #[cfg(feature = "tracing")]
                    tracing::error!(%err, "Failed to reload configuration");
                    // Without `tracing` there is nowhere to report the error;
                    // bind it to `_` so it does not count as unused.
                    #[cfg(not(feature = "tracing"))]
                    {
                        let _ = err;
                    }
                }
            }
        });

        Ok(worker)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Configuration;

    // Configuration whose build always succeeds with `value == 42`.
    #[derive(Debug, Clone, PartialEq, Configuration)]
    struct TestConfig {
        value: u32,
    }

    impl ReloadableConfig for TestConfig {
        type Error = &'static str;

        fn build() -> Result<Self, Self::Error> {
            Ok(Self { value: 42 })
        }
    }

    // Building a handle makes the initial value available through `load`.
    #[test]
    fn test_build_and_load() {
        let handle = ReloadingConfig::<TestConfig, _>::build().unwrap();
        assert_eq!(handle.load().value, 42);
    }

    // Reloading works when the callback is the unit placeholder.
    #[test]
    fn test_reload_without_callback() {
        let handle = TestConfig::reloading().unwrap();
        handle.reload().unwrap();
        assert_eq!(handle.load().value, 42);
    }

    // The callback runs on reload, and not before.
    #[test]
    fn test_reload_with_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let fired = Arc::new(AtomicBool::new(false));
        let handle = {
            let fired = Arc::clone(&fired);
            TestConfig::reloading()
                .unwrap()
                .with_on_update(move || fired.store(true, Ordering::SeqCst))
        };

        assert!(!fired.load(Ordering::SeqCst));
        handle.reload().unwrap();
        assert!(fired.load(Ordering::SeqCst));
    }

    // A reload through one handle is visible through its clones.
    #[test]
    fn test_reload_updates_all_clones() {
        use std::sync::atomic::{AtomicU32, Ordering};

        // Each build takes the next id, so every reload yields a distinct value.
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        #[derive(Debug, serde::Deserialize, Configuration)]
        struct CountingConfig {
            id: u32,
        }

        impl ReloadableConfig for CountingConfig {
            type Error = std::convert::Infallible;

            fn build() -> Result<Self, Self::Error> {
                let id = COUNTER.fetch_add(1, Ordering::SeqCst);
                Ok(CountingConfig { id })
            }
        }

        let first = CountingConfig::reloading().unwrap();
        let second = first.clone();
        assert_eq!(first.load().id, 0);
        assert_eq!(second.load().id, 0);

        first.reload().unwrap();
        assert_eq!(first.load().id, 1);
        assert_eq!(second.load().id, 1);
    }

    // A failed reload keeps the previously stored configuration.
    #[test]
    fn test_reload_error_leaves_config_unchanged() {
        use std::sync::atomic::{AtomicBool, Ordering};

        // Toggles whether `FallibleConfig::build` fails.
        static SHOULD_FAIL: AtomicBool = AtomicBool::new(false);

        #[derive(Debug, serde::Deserialize, Configuration)]
        struct FallibleConfig {
            value: u32,
        }

        impl ReloadableConfig for FallibleConfig {
            type Error = &'static str;

            fn build() -> Result<Self, Self::Error> {
                if SHOULD_FAIL.load(Ordering::SeqCst) {
                    return Err("Build failed");
                }
                Ok(FallibleConfig { value: 42 })
            }
        }

        let handle = FallibleConfig::reloading().unwrap();
        assert_eq!(handle.load().value, 42);

        SHOULD_FAIL.store(true, Ordering::SeqCst);
        assert!(handle.reload().is_err());
        assert_eq!(handle.load().value, 42);

        SHOULD_FAIL.store(false, Ordering::SeqCst);
        handle.reload().unwrap();
        assert_eq!(handle.load().value, 42);
    }

    // The callback fires on a successful reload but not on a failed one.
    #[test]
    fn test_callback_not_invoked_on_reload_error() {
        use std::sync::atomic::{AtomicBool, Ordering};

        // Toggles whether `FallibleConfig::build` fails.
        static SHOULD_FAIL: AtomicBool = AtomicBool::new(false);

        #[derive(Debug, serde::Deserialize, Configuration)]
        struct FallibleConfig {
            value: u32,
        }

        impl ReloadableConfig for FallibleConfig {
            type Error = &'static str;

            fn build() -> Result<Self, Self::Error> {
                if SHOULD_FAIL.load(Ordering::SeqCst) {
                    return Err("Build failed");
                }
                Ok(FallibleConfig { value: 100 })
            }
        }

        let fired = Arc::new(AtomicBool::new(false));
        let handle = {
            let fired = Arc::clone(&fired);
            FallibleConfig::reloading()
                .unwrap()
                .with_on_update(move || fired.store(true, Ordering::SeqCst))
        };
        assert_eq!(handle.load().value, 100);

        handle.reload().unwrap();
        assert!(fired.load(Ordering::SeqCst));

        // Reset the flag, then make the next build fail.
        fired.store(false, Ordering::SeqCst);
        SHOULD_FAIL.store(true, Ordering::SeqCst);
        assert!(handle.reload().is_err());
        assert!(!fired.load(Ordering::SeqCst));
    }
}