use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use tracing::{Instrument, error, info, warn};
/// Type-erased error that can be moved across threads and tasks.
///
/// The trait object's lifetime defaults to `'static` inside `Box`.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// Result type used throughout this module; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Error type for the provider registry.
///
/// The named variants record which lifecycle stage failed and for which component.
/// `Other` holds an error that has not (yet) been given that context.
#[derive(Debug)]
pub enum Error {
    /// A provider's `boot` hook failed.
    Boot {
        /// Name reported by the failing provider.
        name: &'static str,
        /// Underlying cause.
        source: BoxError
    },
    /// A provider's `validate` hook failed.
    Validate {
        /// Name reported by the failing provider.
        name: &'static str,
        /// Underlying cause.
        source: BoxError
    },
    /// Reloading a named provider failed.
    Reload {
        /// Name reported by the failing provider.
        name: &'static str,
        /// Underlying cause.
        source: BoxError
    },
    /// A runnable's task finished with an error.
    Run {
        /// Name reported by the failing provider.
        name: &'static str,
        /// Underlying cause.
        source: BoxError
    },
    /// A runnable failure flagged as recoverable.
    ///
    /// Built by `Error::recoverable` with an empty `name`, which `into_run` fills in.
    /// NOTE(review): the registry itself does not treat this variant specially; any
    /// retry policy presumably lives in the caller.
    Recoverable {
        /// Name of the runnable, or `""` until it has been attached.
        name: &'static str,
        /// Underlying cause.
        source: BoxError
    },
    /// An error without lifecycle context.
    ///
    /// For example, one converted through `From` / `?` or built with `Error::msg`.
    Other(BoxError)
}
impl std::fmt::Display for Error {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>
) -> std::fmt::Result {
match self {
Error::Boot { name, source } => {
write!(f, "provider '{name}' failed during boot: {source}")
}
Error::Validate { name, source } => {
write!(f, "provider '{name}' failed during validate: {source}")
}
Error::Reload { name, source } => {
write!(f, "reload of '{name}' failed: {source}")
}
Error::Run { name, source } => {
write!(f, "runnable '{name}' failed: {source}")
}
Error::Recoverable { name, source } => {
write!(f, "runnable '{name}' failed (recoverable): {source}")
}
Error::Other(e) => std::fmt::Display::fmt(e, f)
}
}
}
impl<E> From<E> for Error
where
E: std::error::Error + Send + Sync + 'static
{
fn from(e: E) -> Self {
Error::Other(Box::new(e))
}
}
impl Error {
pub fn msg(s: impl Into<String>) -> Self {
#[derive(Debug)]
struct MsgErr(String);
impl std::fmt::Display for MsgErr {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>
) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for MsgErr {}
Error::Other(Box::new(MsgErr(s.into())))
}
fn into_boot(
self,
name: &'static str
) -> Self {
match self {
Error::Other(source) => Error::Boot { name, source },
other => other
}
}
fn into_validate(
self,
name: &'static str
) -> Self {
match self {
Error::Other(source) => Error::Validate { name, source },
other => other
}
}
fn into_reload(
self,
name: &'static str
) -> Self {
match self {
Error::Other(source) => Error::Reload { name, source },
other => other
}
}
fn into_run(
self,
name: &'static str
) -> Self {
match self {
Error::Other(source) => Error::Run { name, source },
Error::Recoverable { name: "", source } => Error::Recoverable { name, source },
other => other
}
}
pub fn recoverable(s: impl Into<String>) -> Self {
#[derive(Debug)]
struct MsgErr(String);
impl std::fmt::Display for MsgErr {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>
) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for MsgErr {}
Error::Recoverable { name: "", source: Box::new(MsgErr(s.into())) }
}
}
/// Ordering hints for provider hooks.
///
/// The registry sorts ascending by `(priority, name)`, so lower values boot, run and reload
/// earlier; shutdown walks the boot order in reverse. A hook returning `None` is treated as
/// `NORMAL`.
pub mod priority {
    /// The smallest possible value: sorts ahead of everything else.
    pub const FIRST: u8 = u8::MIN;
    /// Ahead of the default slot.
    pub const EARLY: u8 = 50;
    /// The default slot used when a hook returns `None`.
    pub const NORMAL: u8 = 100;
    /// After the default slot.
    pub const LATE: u8 = 150;
    /// The largest possible value: sorts after everything else.
    pub const LAST: u8 = 255;
}
/// Application state that can refresh itself in place.
///
/// [`Registry::reload_all`] calls [`ReloadState::reload`] before touching any provider. It
/// aborts the whole reload if this call fails.
#[async_trait]
pub trait ReloadState: Send + Sync + Sized + 'static {
    /// Refreshes the state.
    async fn reload(&self) -> Result<()>;
}
/// Provider capability: re-initialise against the shared state on request.
///
/// Exposed to the registry through `Provider::as_reloadable`.
#[async_trait]
pub trait Reloadable<S>: Send + Sync + 'static {
    /// Ordering hint for [`Registry::reload_all`]; `None` means `priority::NORMAL`.
    fn priority(&self) -> Option<u8> {
        None
    }

    /// Performs the reload using the shared `state`.
    async fn reload(
        &self,
        state: &S
    ) -> Result<()>;
}
/// Boxed, type-erased future returned by [`Runnable::run`].
///
/// It is `Send` and (via the `Box` default) `'static`, so it can be spawned as an
/// independent task.
pub type TaskFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
/// Provider capability: a long-running task started by [`Registry::run_all`].
///
/// Each call receives its own clone of the state and must return a `'static` future. Any
/// data the task needs from `self` therefore has to be cloned or shared into the future.
pub trait Runnable<S>: Send + Sync + 'static {
    /// Builds the task's future from an owned copy of the state.
    fn run(
        &self,
        state: S
    ) -> TaskFuture;
}
/// A component managed by [`Registry`].
///
/// Every hook has a default, so implementors override only what they need. Optional
/// capabilities are advertised by overriding [`Provider::as_reloadable`] /
/// [`Provider::as_runnable`] to return `Some(self)`.
#[async_trait]
pub trait Provider<S>: Any + Send + Sync + 'static {
    /// Name used in logs and error context, and as the ordering tie-breaker.
    ///
    /// It is also the lookup key for [`Registry::reload_one`]. Defaults to `"provider"`, so
    /// override it to keep names distinguishable.
    fn name(&self) -> &'static str {
        "provider"
    }

    /// Ordering hint for boot (shutdown uses the reverse); `None` means `priority::NORMAL`.
    fn boot_priority(&self) -> Option<u8> {
        None
    }

    /// Ordering hint for [`Registry::run_all`]; `None` means `priority::NORMAL`.
    fn run_priority(&self) -> Option<u8> {
        None
    }

    /// Asynchronous start-up hook. The default does nothing.
    async fn boot(
        &self,
        _state: &S
    ) -> Result<()> {
        Ok(())
    }

    /// Asynchronous tear-down hook. The default does nothing.
    async fn shutdown(
        &self,
        _state: &S
    ) -> Result<()> {
        Ok(())
    }

    /// Synchronous consistency check. The default accepts everything.
    fn validate(
        &self,
        _state: &S
    ) -> Result<()> {
        Ok(())
    }

    /// Upcasts to `&dyn Any`.
    ///
    /// Only callable on concrete (sized) implementors, not through `dyn Provider<S>`.
    fn as_any(&self) -> &dyn Any
    where
        Self: Sized
    {
        self as &dyn Any
    }

    /// `Some` when this provider supports reloading; the default is `None`.
    fn as_reloadable(&self) -> Option<&dyn Reloadable<S>> {
        None
    }

    /// `Some` when this provider has a long-running task; the default is `None`.
    fn as_runnable(&self) -> Option<&dyn Runnable<S>> {
        None
    }
}
/// Type-keyed container of providers for application state `S`.
///
/// Each provider is stored twice under its concrete `TypeId`. One copy is a trait object
/// used for lifecycle iteration. The other is a `dyn Any` so [`Registry::resolve`] can
/// downcast back to the concrete type.
pub struct Registry<S> {
    /// Lifecycle view: every registered provider as `dyn Provider<S>`.
    providers: RwLock<HashMap<TypeId, Arc<dyn Provider<S>>>>,
    /// Typed view: the same providers as `dyn Any`, used for downcasting lookups.
    by_type: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>
}
impl<S: 'static> Registry<S> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self { providers: RwLock::new(HashMap::new()), by_type: RwLock::new(HashMap::new()) }
    }

    /// Registers `item` under its concrete type `C`.
    ///
    /// Registering an already-registered type again is skipped with a warning, so the first
    /// instance stays in place. Returns `&self` so registrations can be chained.
    ///
    /// # Panics
    /// Panics if either internal lock is poisoned.
    pub fn insert<C>(
        &self,
        item: Arc<C>
    ) -> &Self
    where
        C: Provider<S> + 'static
    {
        use std::collections::hash_map::Entry;

        let type_id = TypeId::of::<C>();
        let mut by_type = self.by_type.write().expect("registry by_type lock poisoned");
        match by_type.entry(type_id) {
            Entry::Occupied(_) => {
                warn!(
                    "⚠️ duplicate provider type '{}' — skipping registration",
                    std::any::type_name::<C>()
                );
            }
            Entry::Vacant(slot) => {
                let any: Arc<dyn Any + Send + Sync> = item.clone();
                slot.insert(any);
                // Publish into `providers` while the `by_type` write guard is still held.
                // `resolve` blocks on that guard, so once it can see the provider,
                // `providers()` lists it too. No other method holds both locks at once, so
                // this nesting cannot deadlock.
                let it: Arc<dyn Provider<S>> = item;
                self.providers
                    .write()
                    .expect("registry providers lock poisoned")
                    .insert(type_id, it);
            }
        }
        self
    }

    /// Runs `f` on the provider registered under concrete type `T`.
    ///
    /// Returns `None` when no provider of that type is registered.
    pub fn with_typed<T, R>(
        &self,
        f: impl FnOnce(&T) -> R
    ) -> Option<R>
    where
        T: Provider<S> + 'static
    {
        let typed = self.resolve::<T>()?;
        Some(f(typed.as_ref()))
    }

    /// Returns a typed handle to the provider registered under concrete type `T`, if any.
    ///
    /// # Panics
    /// Panics if the `by_type` lock is poisoned.
    pub fn resolve<T>(&self) -> Option<Arc<T>>
    where
        T: Provider<S> + 'static
    {
        let any = self
            .by_type
            .read()
            .expect("registry by_type lock poisoned")
            .get(&TypeId::of::<T>())?
            .clone();
        Arc::downcast::<T>(any).ok()
    }

    /// Snapshot of every registered provider, in unspecified (hash-map) order.
    #[allow(unused)]
    pub fn providers(&self) -> Vec<Arc<dyn Provider<S>>> {
        self.providers.read().expect("registry providers lock poisoned").values().cloned().collect()
    }

    /// Names of every registered provider, in unspecified order.
    #[allow(unused)]
    pub fn list_names(&self) -> Vec<&'static str> {
        self.providers().iter().map(|c| c.name()).collect()
    }

    /// Snapshot of all providers sorted by `(priority, name)`.
    ///
    /// `priority_of` selects the relevant hook priority. A missing priority counts as
    /// `priority::NORMAL`.
    fn providers_ordered_by(
        &self,
        priority_of: impl Fn(&dyn Provider<S>) -> Option<u8>
    ) -> Vec<Arc<dyn Provider<S>>> {
        let mut providers = self.providers();
        providers.sort_by_key(|provider| {
            (priority_of(&**provider).unwrap_or(priority::NORMAL), provider.name())
        });
        providers
    }

    /// Spawns every runnable provider's task onto `join_set`.
    ///
    /// Tasks are spawned in `(run_priority, name)` order. Each task gets its own clone of
    /// `state` and runs inside a `provider` debug span. A task error without context is
    /// wrapped as [`Error::Run`]. Returns the number of tasks spawned.
    pub fn run_all(
        &self,
        state: S,
        join_set: &mut tokio::task::JoinSet<Result<()>>
    ) -> usize
    where
        S: Clone + Send + 'static
    {
        let mut spawned = 0usize;
        let ordered = self.providers_ordered_by(|provider| provider.run_priority());
        for provider in ordered {
            let Some(runnable) = provider.as_runnable() else {
                continue;
            };
            let name = provider.name();
            let fut = runnable
                .run(state.clone())
                .instrument(tracing::debug_span!("provider", provider = %name));
            join_set.spawn(async move { fut.await.map_err(|e| e.into_run(name)) });
            spawned += 1;
        }
        spawned
    }

    /// Runs every provider's synchronous `validate` hook, stopping at the first failure.
    ///
    /// Providers are visited in boot order (`boot_priority`, then name). This makes the
    /// reported failure deterministic rather than dependent on hash-map iteration order.
    ///
    /// # Errors
    /// Returns the first validation error. If it carried no lifecycle context, it is
    /// wrapped as [`Error::Validate`].
    pub fn validate_all(
        &self,
        state: &S
    ) -> Result<()> {
        let ordered = self.providers_ordered_by(|provider| provider.boot_priority());
        for provider in ordered {
            let name = provider.name();
            provider.validate(state).map_err(|e| e.into_validate(name))?;
        }
        Ok(())
    }

    /// Boots every provider sequentially in `(boot_priority, name)` order.
    ///
    /// Stops at the first failure; providers already booted are left as they are.
    ///
    /// # Errors
    /// Returns the first boot error. If it carried no lifecycle context, it is wrapped as
    /// [`Error::Boot`].
    pub async fn boot_all(
        &self,
        state: &S
    ) -> Result<()> {
        let ordered = self.providers_ordered_by(|provider| provider.boot_priority());
        for provider in ordered {
            let name = provider.name();
            if let Err(e) = provider.boot(state).await {
                error!("❌ boot provider '{}' failed: {}", name, e);
                return Err(e.into_boot(name));
            }
        }
        Ok(())
    }

    /// Shuts every provider down in the reverse of boot order.
    ///
    /// Best-effort: individual failures are logged and skipped, and this always returns
    /// `Ok(())`.
    pub async fn shutdown_all(
        &self,
        state: &S
    ) -> Result<()> {
        let ordered = self.providers_ordered_by(|provider| provider.boot_priority());
        for provider in ordered.into_iter().rev() {
            let name = provider.name();
            if let Err(e) = provider.shutdown(state).await {
                warn!("shutdown of provider '{}' failed: {}", name, e);
            }
        }
        Ok(())
    }

    /// Reloads the provider whose [`Provider::name`] equals `name`.
    ///
    /// Names are not guaranteed unique (the trait default is `"provider"`). When several
    /// providers match, the first reloadable one found (in map order) is used.
    ///
    /// # Errors
    /// - [`Error::Other`] if no provider is named `name`, or none of the matches is
    ///   reloadable.
    /// - The provider's reload error. If it carried no lifecycle context, it is wrapped as
    ///   [`Error::Reload`].
    pub async fn reload_one(
        &self,
        name: &str,
        state: &S
    ) -> Result<()> {
        let providers = self.providers();
        let mut candidates =
            providers.iter().filter(|provider| provider.name() == name).peekable();
        if candidates.peek().is_none() {
            return Err(Error::msg(format!(
                "reload_by_name: no provider registered with name '{}'",
                name
            )));
        }
        // Prefer a reloadable match so a same-named, non-reloadable provider cannot shadow it.
        let Some((provider, reloadable)) = candidates
            .find_map(|provider| provider.as_reloadable().map(|reloadable| (provider, reloadable)))
        else {
            return Err(Error::msg(format!(
                "reload_by_name: provider '{}' is not reloadable",
                name
            )));
        };
        info!("♻️ reloading service '{}'", name);
        match reloadable.reload(state).await {
            Ok(()) => {
                info!("♻️ {} reloaded", name);
                Ok(())
            }
            Err(e) => {
                warn!("❌ reload of {} failed: {e}", name);
                // The caller's `name` is only borrowed; error context needs `&'static str`.
                Err(e.into_reload(provider.name()))
            }
        }
    }
}
impl<S> Registry<S>
where
    S: ReloadState + 'static
{
    /// Reloads the shared state, then every reloadable provider.
    ///
    /// The state reload runs first, and its failure aborts the whole operation. Providers
    /// are then reloaded one at a time in ascending `(Reloadable::priority, name)` order; a
    /// missing priority is treated as `priority::NORMAL`. Provider failures are logged and
    /// skipped: this pass is best-effort.
    ///
    /// # Errors
    /// Only an error from `state.reload()` is propagated.
    pub async fn reload_all(
        &self,
        state: &S
    ) -> Result<()> {
        state.reload().await?;
        info!("✅ state reloaded");
        // Keep the snapshot alive: the `Reloadable` handles collected below borrow from it.
        let providers = self.providers();
        let mut reloadables: Vec<_> = providers
            .iter()
            .filter_map(|provider| {
                let reloadable = provider.as_reloadable()?;
                let prio = reloadable.priority().unwrap_or(priority::NORMAL);
                Some((prio, provider.name(), reloadable))
            })
            .collect();
        reloadables.sort_by_key(|&(prio, name, _)| (prio, name));
        for (_, name, reloadable) in reloadables {
            match reloadable.reload(state).await {
                Ok(()) => info!("♻️ {} reloaded", name),
                Err(e) => warn!("❌ reload of {} failed: {e}", name)
            }
        }
        Ok(())
    }
}
impl<S: 'static> Default for Registry<S> {
fn default() -> Self {
Self::new()
}
}