use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use tracing::{Instrument, error, info, warn};
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Errors produced by the provider registry and its lifecycle phases.
///
/// Each phase-specific variant records the provider `name` and the underlying
/// `source`; [`Error::Other`] holds an error not yet attributed to a provider.
#[derive(Debug)]
pub enum Error {
    /// A provider's `boot` hook failed.
    Boot { name: &'static str, source: BoxError },
    /// A provider's `validate` hook failed.
    Validate { name: &'static str, source: BoxError },
    /// A provider's `reload` hook failed.
    Reload { name: &'static str, source: BoxError },
    /// A runnable's `run` future returned an error.
    Run { name: &'static str, source: BoxError },
    /// A failure its author flagged as recoverable; `name` starts out empty
    /// and is filled in when the registry attributes the error.
    Recoverable { name: &'static str, source: BoxError },
    /// Any error that has not been attributed to a provider yet.
    Other(BoxError),
}
impl std::fmt::Display for Error {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match self {
Error::Boot { name, source } => {
write!(f, "provider '{name}' failed during boot: {source}")
}
Error::Validate { name, source } => {
write!(f, "provider '{name}' failed during validate: {source}")
}
Error::Reload { name, source } => {
write!(f, "reload of '{name}' failed: {source}")
}
Error::Run { name, source } => {
write!(f, "runnable '{name}' failed: {source}")
}
Error::Recoverable { name, source } => {
write!(f, "runnable '{name}' failed (recoverable): {source}")
}
Error::Other(e) => std::fmt::Display::fmt(e, f),
}
}
}
impl<E> From<E> for Error
where
E: std::error::Error + Send + Sync + 'static,
{
fn from(e: E) -> Self {
Error::Other(Box::new(e))
}
}
impl Error {
pub fn msg(s: impl Into<String>) -> Self {
#[derive(Debug)]
struct MsgErr(String);
impl std::fmt::Display for MsgErr {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for MsgErr {}
Error::Other(Box::new(MsgErr(s.into())))
}
fn into_boot(
self,
name: &'static str,
) -> Self {
match self {
Error::Other(source) => Error::Boot { name, source },
other => other,
}
}
fn into_validate(
self,
name: &'static str,
) -> Self {
match self {
Error::Other(source) => Error::Validate { name, source },
other => other,
}
}
fn into_reload(
self,
name: &'static str,
) -> Self {
match self {
Error::Other(source) => Error::Reload { name, source },
other => other,
}
}
fn into_run(
self,
name: &'static str,
) -> Self {
match self {
Error::Other(source) => Error::Run { name, source },
Error::Recoverable { name: "", source } => Error::Recoverable { name, source },
other => other,
}
}
pub fn recoverable(s: impl Into<String>) -> Self {
#[derive(Debug)]
struct MsgErr(String);
impl std::fmt::Display for MsgErr {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for MsgErr {}
Error::Recoverable { name: "", source: Box::new(MsgErr(s.into())) }
}
}
/// Priority values for providers; lower values are handled earlier.
pub mod priority {
    /// Smallest possible value (hidden from docs).
    #[doc(hidden)]
    pub const FIRST: u8 = u8::MIN;
    /// Handled ahead of [`NORMAL`] providers.
    pub const EARLY: u8 = 50;
    /// Value used when a provider does not specify a priority.
    pub const NORMAL: u8 = 100;
    /// Handled after [`NORMAL`] providers.
    pub const LATE: u8 = 150;
    /// Largest possible value (hidden from docs).
    #[doc(hidden)]
    pub const LAST: u8 = u8::MAX;
}
/// Explicit ordering constraints a provider declares against other provider types.
///
/// `before::<T>()` requests that this provider be handled ahead of `T`;
/// `after::<T>()` requests that `T` be handled first.
#[derive(Clone, Debug, Default)]
pub struct ProviderOrder {
    before: Vec<TypeId>,
    after: Vec<TypeId>,
}

impl ProviderOrder {
    /// Returns an empty set of constraints.
    pub fn new() -> Self {
        Self { before: Vec::new(), after: Vec::new() }
    }

    /// Adds the constraint "this provider comes before `T`".
    pub fn before<T: 'static>(mut self) -> Self {
        let id = TypeId::of::<T>();
        self.before.push(id);
        self
    }

    /// Adds the constraint "this provider comes after `T`".
    pub fn after<T: 'static>(mut self) -> Self {
        let id = TypeId::of::<T>();
        self.after.push(id);
        self
    }

    /// Types this provider must precede, in the order they were added.
    pub fn before_types(&self) -> &[TypeId] {
        self.before.as_slice()
    }

    /// Types this provider must follow, in the order they were added.
    pub fn after_types(&self) -> &[TypeId] {
        self.after.as_slice()
    }
}
/// Shared application state that can refresh itself.
///
/// `Registry::reload_all` requires this and calls [`ReloadState::reload`] before
/// reloading any provider.
#[async_trait]
pub trait ReloadState: Sized + Send + Sync + 'static {
    /// Refreshes the state. It takes `&self`, so any mutation must go through
    /// interior mutability.
    async fn reload(&self) -> Result<()>;
}
/// Optional provider capability: re-apply configuration from the shared state.
#[async_trait]
pub trait Reloadable<S>: Send + Sync + 'static {
    /// Lifecycle priority consulted when the provider has no `boot_priority`;
    /// `None` falls back to [`priority::NORMAL`].
    fn priority(&self) -> Option<u8> {
        Option::None
    }

    /// Reloads this provider against `state`.
    async fn reload(&self, state: &S) -> Result<()>;
}
/// Optional provider capability: long-running work spawned by `Registry::run_all`.
///
/// Each spawned task receives its own clone of the state.
#[async_trait]
pub trait Runnable<S>: Send + Sync + 'static {
    /// Runs the task; an untagged error is attributed to the provider as `Error::Run`.
    async fn run(self: Arc<Self>, state: S) -> Result<()>;
}
/// A unit of application wiring managed by [`Registry`].
///
/// Every hook has a default implementation, so implementors override only what
/// they need. Optional capabilities are exposed via `as_reloadable` and
/// `as_runnable`.
#[async_trait]
pub trait Provider<S>: Any + Send + Sync + 'static {
    /// Name used in log lines and errors, and matched by `Registry::reload_one`.
    fn name(&self) -> &'static str {
        "provider"
    }

    /// Priority for lifecycle ordering (validate/boot/shutdown/reload_all).
    /// `None` defers to the reloadable priority, then [`priority::NORMAL`].
    fn boot_priority(&self) -> Option<u8> {
        None
    }

    /// Priority used by `Registry::run_all` when spawning runnables.
    /// `None` means [`priority::NORMAL`].
    fn run_priority(&self) -> Option<u8> {
        None
    }

    /// Explicit before/after constraints relative to other provider types.
    fn order(&self) -> ProviderOrder {
        ProviderOrder::new()
    }

    /// Called by `Registry::boot_all` in lifecycle order; the default does nothing.
    async fn boot(&self, _state: &S) -> Result<()> {
        Ok(())
    }

    /// Called by `Registry::shutdown_all` in reverse lifecycle order; the default
    /// does nothing.
    async fn shutdown(&self, _state: &S) -> Result<()> {
        Ok(())
    }

    /// Called synchronously by `Registry::validate_all`; the default accepts everything.
    fn validate(&self, _state: &S) -> Result<()> {
        Ok(())
    }

    /// Upcasts to `&dyn Any`; only callable on sized implementors.
    fn as_any(&self) -> &dyn Any
    where
        Self: Sized,
    {
        self
    }

    /// The provider's [`Reloadable`] view, if it has one; the default reports none.
    fn as_reloadable(&self) -> Option<&dyn Reloadable<S>> {
        None
    }

    /// The provider's [`Runnable`] view, if it has one; the default reports none.
    fn as_runnable(self: Arc<Self>) -> Option<Arc<dyn Runnable<S>>> {
        None
    }
}
/// Type-keyed collection of providers that drives their shared lifecycle.
///
/// Every table sits behind an `RwLock`, so a shared `&Registry` can be both
/// populated and queried. The computed lifecycle order is cached until the next
/// successful insert.
pub struct Registry<S> {
    /// Providers keyed by concrete type, as lifecycle trait objects.
    providers: RwLock<HashMap<TypeId, Arc<dyn Provider<S>>>>,
    /// The same providers as `Any`, used for typed lookups by downcasting.
    by_type: RwLock<HashMap<TypeId, Arc<dyn Any + Sync + Send>>>,
    /// Type ids in the order they were first inserted (append-only).
    registration_order: RwLock<Vec<TypeId>>,
    /// Cached lifecycle ordering; `None` means it must be recomputed.
    lifecycle_order: RwLock<Option<Vec<TypeId>>>,
}
impl<S: 'static> Registry<S> {
pub fn new() -> Self {
Self {
providers: RwLock::new(HashMap::new()),
by_type: RwLock::new(HashMap::new()),
registration_order: RwLock::new(Vec::new()),
lifecycle_order: RwLock::new(None),
}
}
pub fn insert<C>(
&self,
item: Arc<C>,
) -> &Self
where
C: Provider<S> + 'static,
{
let type_id = TypeId::of::<C>();
let any: Arc<dyn Any + Send + Sync> = item.clone();
let mut by_type = self.by_type.write().expect("registry by_type lock poisoned");
if by_type.contains_key(&type_id) {
warn!(
"⚠️ duplicate provider type '{}' — skipping registration",
std::any::type_name::<C>()
);
return self;
}
by_type.insert(type_id, any);
drop(by_type);
let it: Arc<dyn Provider<S>> = item;
self.providers.write().expect("registry providers lock poisoned").insert(type_id, it);
self.registration_order.write().expect("registry order lock poisoned").push(type_id);
*self.lifecycle_order.write().expect("registry lifecycle order lock poisoned") = None;
self
}
pub fn with_typed<T, R>(
&self,
f: impl FnOnce(&T) -> R,
) -> Option<R>
where
T: Provider<S> + 'static,
{
let typed = self.resolve::<T>()?;
Some(f(typed.as_ref()))
}
pub fn resolve<T>(&self) -> Option<Arc<T>>
where
T: Provider<S> + 'static,
{
let any = self
.by_type
.read()
.expect("registry by_type lock poisoned")
.get(&TypeId::of::<T>())?
.clone();
Arc::downcast::<T>(any).ok()
}
#[allow(unused)]
pub fn providers(&self) -> Vec<Arc<dyn Provider<S>>> {
self.providers.read().expect("registry providers lock poisoned").values().cloned().collect()
}
fn provider_entries_snapshot(&self) -> Vec<ProviderEntry<S>> {
let providers = self.providers.read().expect("registry providers lock poisoned");
self.registration_order
.read()
.expect("registry order lock poisoned")
.iter()
.enumerate()
.filter_map(|(index, type_id)| {
providers.get(type_id).cloned().map(|provider| ProviderEntry {
type_id: *type_id,
index,
provider,
})
})
.collect()
}
fn lifecycle_plan(&self) -> Result<Vec<Arc<dyn Provider<S>>>> {
if let Some(type_ids) = self
.lifecycle_order
.read()
.expect("registry lifecycle order lock poisoned")
.as_ref()
.cloned()
{
return Ok(self.providers_from_type_ids(&type_ids));
}
let ordered = order_provider_entries(self.provider_entries_snapshot())?;
let type_ids = ordered.iter().map(|entry| entry.type_id).collect::<Vec<_>>();
let providers = ordered.iter().map(|entry| entry.provider.clone()).collect::<Vec<_>>();
#[cfg(debug_assertions)]
tracing::debug!(
providers = ?providers.iter().map(|provider| provider.name()).collect::<Vec<_>>(),
"provider lifecycle order"
);
*self.lifecycle_order.write().expect("registry lifecycle order lock poisoned") =
Some(type_ids);
Ok(providers)
}
fn providers_from_type_ids(
&self,
type_ids: &[TypeId],
) -> Vec<Arc<dyn Provider<S>>> {
let providers = self.providers.read().expect("registry providers lock poisoned");
type_ids.iter().filter_map(|type_id| providers.get(type_id).cloned()).collect()
}
#[allow(unused)]
pub fn list_names(&self) -> Vec<&'static str> {
self.providers().iter().map(|c| c.name()).collect()
}
pub fn lifecycle_names(&self) -> Result<Vec<&'static str>> {
Ok(self.lifecycle_plan()?.iter().map(|provider| provider.name()).collect())
}
pub fn run_all(
&self,
state: S,
join_set: &mut tokio::task::JoinSet<Result<()>>,
) -> usize
where
S: Clone + Send + 'static,
{
let mut spawned = 0usize;
let mut providers = self.providers();
providers.sort_by_key(|provider| {
(provider.run_priority().unwrap_or(priority::NORMAL), provider.name())
});
for provider in providers {
let Some(runnable) = provider.clone().as_runnable() else { continue };
let name = provider.name();
let state = state.clone();
join_set.spawn(
async move { runnable.run(state).await.map_err(|e| e.into_run(name)) }
.instrument(tracing::debug_span!("provider", provider = %name)),
);
spawned += 1;
}
spawned
}
pub fn validate_all(
&self,
state: &S,
) -> Result<()> {
for provider in self.lifecycle_plan()? {
let name = provider.name();
provider.validate(state).map_err(|e| e.into_validate(name))?;
}
Ok(())
}
pub async fn boot_all(
&self,
state: &S,
) -> Result<()> {
for provider in self.lifecycle_plan()? {
let name = provider.name();
if let Err(e) = provider.boot(state).await {
error!("❌ boot provider '{}' failed: {}", name, e);
return Err(e.into_boot(name));
}
}
Ok(())
}
pub async fn shutdown_all(
&self,
state: &S,
) -> Result<()> {
let mut providers = self.lifecycle_plan()?;
providers.reverse();
for provider in providers {
let name = provider.name();
if let Err(e) = provider.shutdown(state).await {
warn!("shutdown of provider '{}' failed: {}", name, e);
}
}
Ok(())
}
pub async fn reload_one(
&self,
name: &str,
state: &S,
) -> Result<()> {
let Some(provider) = self.providers().into_iter().find(|provider| provider.name() == name)
else {
return Err(Error::msg(format!(
"reload_by_name: no provider registered with name '{}'",
name
)));
};
let Some(reloadable) = provider.as_reloadable() else {
return Err(Error::msg(format!(
"reload_by_name: provider '{}' is not reloadable",
name
)));
};
info!("♻️ reloading service '{}'", name);
match reloadable.reload(state).await {
Ok(()) => {
info!("♻️ {} reloaded", name);
Ok(())
}
Err(e) => {
warn!("❌ reload of {} failed: {e}", name);
let static_name = provider.name();
Err(e.into_reload(static_name))
}
}
}
}
impl<S> Registry<S>
where
    S: ReloadState + 'static,
{
    /// Reloads the shared state, then every reloadable provider in lifecycle order.
    ///
    /// Provider reloads are best-effort: each failure is logged and the loop moves
    /// on.
    ///
    /// # Errors
    /// Returns the state's own reload error, or an ordering (cycle) error.
    pub async fn reload_all(&self, state: &S) -> Result<()> {
        state.reload().await?;
        info!("✅ state reloaded");
        let plan = self.lifecycle_plan()?;
        for provider in plan {
            let name = provider.name();
            let Some(reloadable) = provider.as_reloadable() else { continue };
            match reloadable.reload(state).await {
                Ok(()) => info!("♻️ {} reloaded", name),
                Err(e) => warn!("❌ reload of {} failed: {e}", name),
            }
        }
        Ok(())
    }
}
/// A provider paired with the bookkeeping needed to order it.
struct ProviderEntry<S> {
    /// Concrete type the provider was registered under.
    type_id: TypeId,
    /// Position in the registry's registration order; used as a tie-breaker.
    index: usize,
    /// The provider itself.
    provider: Arc<dyn Provider<S>>,
}
impl<S> Clone for ProviderEntry<S> {
fn clone(&self) -> Self {
Self { type_id: self.type_id, index: self.index, provider: self.provider.clone() }
}
}
/// Orders provider entries for the lifecycle phases.
///
/// Each provider's [`ProviderOrder`] becomes graph edges:
/// - `before::<T>()` adds an edge self → T.
/// - `after::<T>()` adds an edge T → self.
///
/// Constraints on unregistered types are ignored. The graph is then sorted
/// topologically with Kahn's algorithm. Among the providers whose constraints are
/// satisfied, the one with the smallest `(lifecycle priority, registration index,
/// name)` key goes next.
///
/// # Errors
/// Returns an error naming every provider that could not be placed when the
/// constraints contain a cycle.
fn order_provider_entries<S: 'static>(
    entries: Vec<ProviderEntry<S>>
) -> Result<Vec<ProviderEntry<S>>> {
    let len = entries.len();
    let positions: HashMap<TypeId, usize> =
        entries.iter().enumerate().map(|(idx, entry)| (entry.type_id, idx)).collect();
    // Compute each node's tie-break key once, not on every selection step.
    let keys: Vec<(u8, usize, &'static str)> = entries
        .iter()
        .map(|entry| (lifecycle_priority(&entry.provider), entry.index, entry.provider.name()))
        .collect();
    let mut outgoing: Vec<HashSet<usize>> = (0..len).map(|_| HashSet::new()).collect();
    let mut indegree = vec![0usize; len];
    let mut add_edge = |from: usize, to: usize| {
        // Ignore self-edges; `insert` dedups repeated constraints so each edge
        // counts toward the indegree once.
        if from != to && outgoing[from].insert(to) {
            indegree[to] += 1;
        }
    };
    for (idx, entry) in entries.iter().enumerate() {
        let order = entry.provider.order();
        for target in order.before_types() {
            if let Some(&target_idx) = positions.get(target) {
                add_edge(idx, target_idx);
            }
        }
        for target in order.after_types() {
            if let Some(&target_idx) = positions.get(target) {
                add_edge(target_idx, idx);
            }
        }
    }
    // Min-heap of ready nodes keyed by (tie-break key, node index). This replaces
    // the former re-sort plus `remove(0)` on every step.
    let mut ready = std::collections::BinaryHeap::with_capacity(len);
    for (idx, &degree) in indegree.iter().enumerate() {
        if degree == 0 {
            ready.push(std::cmp::Reverse((keys[idx], idx)));
        }
    }
    let mut ordered = Vec::with_capacity(len);
    while let Some(std::cmp::Reverse((_, idx))) = ready.pop() {
        ordered.push(idx);
        for &target in &outgoing[idx] {
            indegree[target] -= 1;
            if indegree[target] == 0 {
                ready.push(std::cmp::Reverse((keys[target], target)));
            }
        }
    }
    if ordered.len() != len {
        // Nodes still blocked are on a cycle or depend on one.
        let blocked = indegree
            .iter()
            .enumerate()
            .filter_map(|(idx, degree)| (*degree > 0).then_some(entries[idx].provider.name()))
            .collect::<Vec<_>>()
            .join(", ");
        return Err(Error::msg(format!("provider lifecycle order cycle detected: {blocked}")));
    }
    Ok(ordered.into_iter().map(|idx| entries[idx].clone()).collect())
}
/// Priority used for tie-breaking in lifecycle ordering (lower goes first).
///
/// Uses the provider's `boot_priority` if set, then its reloadable priority,
/// then [`priority::NORMAL`].
fn lifecycle_priority<S: 'static>(provider: &Arc<dyn Provider<S>>) -> u8 {
    if let Some(boot) = provider.boot_priority() {
        return boot;
    }
    match provider.as_reloadable() {
        Some(reloadable) => reloadable.priority().unwrap_or(priority::NORMAL),
        None => priority::NORMAL,
    }
}
impl<S: 'static> Default for Registry<S> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Test state that records the order in which providers were validated.
    #[derive(Clone, Default)]
    struct TestState {
        seen: Arc<Mutex<Vec<&'static str>>>,
    }

    impl TestState {
        /// Appends `label` to the validation log.
        fn record(&self, label: &'static str) {
            self.seen.lock().expect("test log poisoned").push(label);
        }

        /// Returns a copy of the validation log.
        fn log(&self) -> Vec<&'static str> {
            self.seen.lock().expect("test log poisoned").clone()
        }
    }

    struct DbProvider;
    struct CacheProvider;
    struct ApiProvider;

    #[async_trait]
    impl Provider<TestState> for DbProvider {
        fn name(&self) -> &'static str {
            "db"
        }

        fn validate(&self, state: &TestState) -> Result<()> {
            state.record("db");
            Ok(())
        }
    }

    #[async_trait]
    impl Provider<TestState> for CacheProvider {
        fn name(&self) -> &'static str {
            "cache"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<DbProvider>()
        }

        fn validate(&self, state: &TestState) -> Result<()> {
            state.record("cache");
            Ok(())
        }
    }

    #[async_trait]
    impl Provider<TestState> for ApiProvider {
        fn name(&self) -> &'static str {
            "api"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CacheProvider>()
        }

        fn validate(&self, state: &TestState) -> Result<()> {
            state.record("api");
            Ok(())
        }
    }

    /// Registration order is the reverse of the dependency chain; the lifecycle
    /// must still follow db -> cache -> api.
    #[test]
    fn lifecycle_order_uses_type_dependencies() {
        let state = TestState::default();
        let registry = Registry::<TestState>::new();
        registry
            .insert(Arc::new(ApiProvider))
            .insert(Arc::new(CacheProvider))
            .insert(Arc::new(DbProvider));
        registry.validate_all(&state).expect("validation should succeed");
        assert_eq!(state.log(), vec!["db", "cache", "api"]);
    }

    struct CycleA;
    struct CycleB;

    #[async_trait]
    impl Provider<TestState> for CycleA {
        fn name(&self) -> &'static str {
            "cycle-a"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CycleB>()
        }
    }

    #[async_trait]
    impl Provider<TestState> for CycleB {
        fn name(&self) -> &'static str {
            "cycle-b"
        }

        fn order(&self) -> ProviderOrder {
            ProviderOrder::new().after::<CycleA>()
        }
    }

    /// Two providers that each require the other first must be rejected.
    #[test]
    fn lifecycle_order_rejects_cycles() {
        let state = TestState::default();
        let registry = Registry::<TestState>::new();
        registry.insert(Arc::new(CycleA)).insert(Arc::new(CycleB));
        let err = registry.validate_all(&state).expect_err("cycle must be rejected");
        let message = err.to_string();
        assert!(message.contains("provider lifecycle order cycle detected"));
    }
}