use std::any::{type_name, Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// A component that `BeanRegistry` can construct from values already stored in a
/// [`BeanContext`].
pub trait Bean: Clone + Send + Sync + 'static {
    /// Lists the `(TypeId, type name)` of every value `build` reads from the context.
    ///
    /// `BeanRegistry::resolve` uses this list to reject missing dependencies and cycles and to
    /// build registered dependencies before this bean. The name is only used in error messages.
    fn dependencies() -> Vec<(TypeId, &'static str)>;

    /// Constructs the bean, reading its dependencies from `ctx`.
    ///
    /// Only types listed by `dependencies` are guaranteed to be present; reading anything else
    /// with `BeanContext::get` may panic.
    fn build(ctx: &BeanContext) -> Self;
}
/// A value that can be assembled from a resolved [`BeanContext`].
///
/// NOTE(review): not used anywhere in this file; presumably implemented by application-level
/// state types that bundle several beans — confirm with callers.
pub trait BeanState: Clone + Send + Sync + 'static {
    /// Builds `Self` from the values stored in `ctx`.
    fn from_context(ctx: &BeanContext) -> Self;
}
/// Type-keyed store of the values produced by `BeanRegistry::resolve`.
///
/// Each value is stored type-erased under the `TypeId` of its concrete type; lookups clone the
/// stored value out.
pub struct BeanContext {
    entries: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl fmt::Debug for BeanContext {
    /// Shows only how many entries are stored; the values themselves are type-erased.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entry_count = self.entries.len();
        let mut out = f.debug_struct("BeanContext");
        out.field("entry_count", &entry_count);
        out.finish()
    }
}

impl BeanContext {
    /// Returns a clone of the stored value of type `T`.
    ///
    /// # Panics
    ///
    /// Panics if the context holds no value of type `T`. Use [`BeanContext::try_get`] for a
    /// non-panicking lookup.
    pub fn get<T: Clone + 'static>(&self) -> T {
        match self.try_get::<T>() {
            Some(found) => found,
            None => panic!(
                "Bean of type `{}` not found in context",
                type_name::<T>()
            ),
        }
    }

    /// Returns a clone of the stored value of type `T`, or `None` when it is absent.
    pub fn try_get<T: Clone + 'static>(&self) -> Option<T> {
        let key = TypeId::of::<T>();
        let erased = self.entries.get(&key)?;
        erased.downcast_ref::<T>().cloned()
    }
}
/// Type-erased, one-shot constructor for a registered bean.
///
/// It receives the context holding everything built so far and returns the new bean boxed as
/// `dyn Any`, so it can be stored in the context's type-keyed map. It is `FnOnce` because
/// `BeanRegistry::resolve` runs each factory at most once.
type Factory = Box<dyn FnOnce(&BeanContext) -> Box<dyn Any + Send + Sync> + Send>;
/// Collects provided instances and bean registrations, then resolves them into a
/// [`BeanContext`].
pub struct BeanRegistry {
    // Beans to construct, in registration order.
    beans: Vec<BeanRegistration>,
    // Ready-made values supplied via `provide`, keyed by the `TypeId` of their type.
    provided: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
/// Everything `BeanRegistry::resolve` needs to know about one registered bean type.
struct BeanRegistration {
    // `TypeId` of the bean; the key its value is stored under in the context.
    type_id: TypeId,
    // Human-readable type name, used in error messages.
    type_name: &'static str,
    // Dependencies as reported by `Bean::dependencies`.
    dependencies: Vec<(TypeId, &'static str)>,
    // Builds the bean once its dependencies are present in the context.
    factory: Factory,
}
/// Reasons `BeanRegistry::resolve` can fail.
#[derive(Debug)]
pub enum BeanError {
    /// Registered beans depend on each other in a loop.
    CyclicDependency {
        /// Type names of the beans reported as part of the cycle.
        cycle: Vec<String>,
    },
    /// A bean depends on a type that was neither provided nor registered.
    MissingDependency {
        /// Type name of the bean whose dependency is missing.
        bean: String,
        /// Type name of the missing dependency.
        dependency: String,
    },
    /// The same type was registered twice, or was both provided and registered.
    DuplicateBean {
        /// Type name of the duplicated bean.
        type_name: String,
    },
}

impl fmt::Display for BeanError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BeanError::CyclicDependency { cycle } => {
                write!(f, "Circular dependency detected: {}", cycle.join(" -> "))
            }
            // The hint names the registry's actual API: values are supplied with
            // `provide(instance)` and bean types with `register::<Type>()`.
            BeanError::MissingDependency { bean, dependency } => write!(
                f,
                "Missing dependency for bean '{}': type '{}' is not registered. \
                 Use .provide(instance) or .register::<Type>()",
                bean, dependency
            ),
            BeanError::DuplicateBean { type_name } => {
                write!(f, "Bean of type '{}' registered twice", type_name)
            }
        }
    }
}

impl std::error::Error for BeanError {}
impl BeanRegistry {
pub fn new() -> Self {
Self {
beans: Vec::new(),
provided: HashMap::new(),
}
}
pub fn provide<T: Clone + Send + Sync + 'static>(&mut self, value: T) -> &mut Self {
self.provided.insert(TypeId::of::<T>(), Box::new(value));
self
}
pub fn register<T: Bean>(&mut self) -> &mut Self {
self.beans.push(BeanRegistration {
type_id: TypeId::of::<T>(),
type_name: type_name::<T>(),
dependencies: T::dependencies(),
factory: Box::new(|ctx| Box::new(T::build(ctx))),
});
self
}
pub fn resolve(self) -> Result<BeanContext, BeanError> {
let mut entries: HashMap<TypeId, Box<dyn Any + Send + Sync>> = HashMap::new();
for (tid, value) in self.provided {
entries.insert(tid, value);
}
let bean_count = self.beans.len();
if bean_count == 0 {
return Ok(BeanContext { entries });
}
let mut seen: HashMap<TypeId, &str> = HashMap::new();
for reg in &self.beans {
if entries.contains_key(®.type_id) {
return Err(BeanError::DuplicateBean {
type_name: reg.type_name.to_string(),
});
}
if seen.insert(reg.type_id, reg.type_name).is_some() {
return Err(BeanError::DuplicateBean {
type_name: reg.type_name.to_string(),
});
}
}
let id_to_idx: HashMap<TypeId, usize> = self
.beans
.iter()
.enumerate()
.map(|(i, r)| (r.type_id, i))
.collect();
for reg in &self.beans {
for (dep_id, dep_name) in ®.dependencies {
if !entries.contains_key(dep_id) && !id_to_idx.contains_key(dep_id) {
return Err(BeanError::MissingDependency {
bean: reg.type_name.to_string(),
dependency: dep_name.to_string(),
});
}
}
}
let mut in_degree: Vec<usize> = Vec::with_capacity(bean_count);
for reg in &self.beans {
let deg = reg
.dependencies
.iter()
.filter(|(d, _)| id_to_idx.contains_key(d))
.count();
in_degree.push(deg);
}
let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); bean_count];
for (i, reg) in self.beans.iter().enumerate() {
for (dep_id, _) in ®.dependencies {
if let Some(&dep_idx) = id_to_idx.get(dep_id) {
dependents[dep_idx].push(i);
}
}
}
let mut queue: Vec<usize> = (0..bean_count)
.filter(|&i| in_degree[i] == 0)
.collect();
let mut sorted_order: Vec<usize> = Vec::with_capacity(bean_count);
while let Some(idx) = queue.pop() {
sorted_order.push(idx);
for &dep_idx in &dependents[idx] {
in_degree[dep_idx] -= 1;
if in_degree[dep_idx] == 0 {
queue.push(dep_idx);
}
}
}
if sorted_order.len() != bean_count {
let cycle: Vec<String> = (0..bean_count)
.filter(|i| in_degree[*i] > 0)
.map(|i| self.beans[i].type_name.to_string())
.collect();
return Err(BeanError::CyclicDependency { cycle });
}
let mut bean_data: Vec<Option<(TypeId, Factory)>> = self
.beans
.into_iter()
.map(|r| Some((r.type_id, r.factory)))
.collect();
for idx in sorted_order {
let (type_id, factory) = bean_data[idx].take().unwrap();
let ctx = BeanContext { entries };
let bean_value = factory(&ctx);
entries = ctx.entries;
entries.insert(type_id, bean_value);
}
Ok(BeanContext { entries })
}
}
impl Default for BeanRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Plain value supplied through `provide`.
    #[derive(Clone)]
    struct Dep {
        value: i32,
    }

    // Bean that depends only on a provided value.
    #[derive(Clone)]
    struct ServiceA {
        dep: Dep,
    }

    impl Bean for ServiceA {
        fn dependencies() -> Vec<(TypeId, &'static str)> {
            vec![(TypeId::of::<Dep>(), type_name::<Dep>())]
        }
        fn build(ctx: &BeanContext) -> Self {
            Self {
                dep: ctx.get::<Dep>(),
            }
        }
    }

    // Bean that depends on another bean and on a provided value.
    #[derive(Clone)]
    struct ServiceB {
        a: ServiceA,
        dep: Dep,
    }

    impl Bean for ServiceB {
        fn dependencies() -> Vec<(TypeId, &'static str)> {
            vec![
                (TypeId::of::<ServiceA>(), type_name::<ServiceA>()),
                (TypeId::of::<Dep>(), type_name::<Dep>()),
            ]
        }
        fn build(ctx: &BeanContext) -> Self {
            Self {
                a: ctx.get::<ServiceA>(),
                dep: ctx.get::<Dep>(),
            }
        }
    }

    #[test]
    fn resolve_simple_graph() {
        let mut reg = BeanRegistry::new();
        reg.provide(Dep { value: 42 });
        reg.register::<ServiceA>();
        reg.register::<ServiceB>();
        let ctx = reg.resolve().unwrap();
        let b: ServiceB = ctx.get();
        assert_eq!(b.dep.value, 42);
        assert_eq!(b.a.dep.value, 42);
    }

    // A dependent registered before its dependency must still be built after it.
    #[test]
    fn registration_order_does_not_matter() {
        let mut reg = BeanRegistry::new();
        reg.provide(Dep { value: 9 });
        reg.register::<ServiceB>();
        reg.register::<ServiceA>();
        let ctx = reg.resolve().unwrap();
        let b: ServiceB = ctx.get();
        assert_eq!(b.a.dep.value, 9);
        assert_eq!(ctx.get::<ServiceA>().dep.value, 9);
    }

    #[test]
    fn missing_dependency() {
        let mut reg = BeanRegistry::new();
        reg.register::<ServiceA>();
        let err = reg.resolve().unwrap_err();
        match &err {
            BeanError::MissingDependency { dependency, .. } => {
                assert!(dependency.contains("Dep"), "error should name the missing type: {}", err);
            }
            _ => panic!("expected MissingDependency, got {:?}", err),
        }
    }

    #[test]
    fn duplicate_bean_registered_twice() {
        let mut reg = BeanRegistry::new();
        reg.provide(Dep { value: 1 });
        reg.register::<ServiceA>();
        reg.register::<ServiceA>();
        let err = reg.resolve().unwrap_err();
        assert!(matches!(err, BeanError::DuplicateBean { .. }));
    }

    #[test]
    fn duplicate_provided_and_bean() {
        let mut reg = BeanRegistry::new();
        reg.provide(Dep { value: 1 });
        reg.provide(ServiceA {
            dep: Dep { value: 2 },
        });
        reg.register::<ServiceA>();
        let err = reg.resolve().unwrap_err();
        assert!(matches!(err, BeanError::DuplicateBean { .. }));
    }

    // Providing the same type twice is not an error; the later value replaces the earlier one.
    #[test]
    fn provide_same_type_twice_keeps_last_value() {
        let mut reg = BeanRegistry::new();
        reg.provide(Dep { value: 1 });
        reg.provide(Dep { value: 2 });
        let ctx = reg.resolve().unwrap();
        assert_eq!(ctx.get::<Dep>().value, 2);
    }

    // Two beans that depend on each other.
    #[derive(Clone)]
    struct CycleA;
    #[derive(Clone)]
    struct CycleB;

    impl Bean for CycleA {
        fn dependencies() -> Vec<(TypeId, &'static str)> {
            vec![(TypeId::of::<CycleB>(), type_name::<CycleB>())]
        }
        fn build(ctx: &BeanContext) -> Self {
            let _ = ctx.get::<CycleB>();
            Self
        }
    }

    impl Bean for CycleB {
        fn dependencies() -> Vec<(TypeId, &'static str)> {
            vec![(TypeId::of::<CycleA>(), type_name::<CycleA>())]
        }
        fn build(ctx: &BeanContext) -> Self {
            let _ = ctx.get::<CycleA>();
            Self
        }
    }

    #[test]
    fn cyclic_dependency() {
        let mut reg = BeanRegistry::new();
        reg.register::<CycleA>();
        reg.register::<CycleB>();
        let err = reg.resolve().unwrap_err();
        match &err {
            BeanError::CyclicDependency { cycle } => {
                assert!(cycle.iter().any(|name| name.contains("CycleA")), "{}", err);
                assert!(cycle.iter().any(|name| name.contains("CycleB")), "{}", err);
            }
            _ => panic!("expected CyclicDependency, got {:?}", err),
        }
    }

    // A bean that lists itself as a dependency.
    #[derive(Clone)]
    struct SelfLoop;

    impl Bean for SelfLoop {
        fn dependencies() -> Vec<(TypeId, &'static str)> {
            vec![(TypeId::of::<SelfLoop>(), type_name::<SelfLoop>())]
        }
        fn build(ctx: &BeanContext) -> Self {
            ctx.get::<SelfLoop>()
        }
    }

    #[test]
    fn self_dependency_is_cycle() {
        let mut reg = BeanRegistry::new();
        reg.register::<SelfLoop>();
        let err = reg.resolve().unwrap_err();
        match &err {
            BeanError::CyclicDependency { cycle } => {
                assert!(cycle.iter().any(|name| name.contains("SelfLoop")), "{}", err);
            }
            _ => panic!("expected CyclicDependency, got {:?}", err),
        }
    }

    #[test]
    fn provided_only() {
        let mut reg = BeanRegistry::new();
        reg.provide(Dep { value: 7 });
        let ctx = reg.resolve().unwrap();
        let d: Dep = ctx.get();
        assert_eq!(d.value, 7);
    }

    #[test]
    fn try_get_none() {
        let reg = BeanRegistry::new();
        let ctx = reg.resolve().unwrap();
        assert!(ctx.try_get::<Dep>().is_none());
    }

    #[test]
    #[should_panic(expected = "not found in context")]
    fn get_panics_for_absent_type() {
        let ctx = BeanRegistry::new().resolve().unwrap();
        let _: Dep = ctx.get();
    }

    #[test]
    fn empty_registry() {
        let reg = BeanRegistry::new();
        let ctx = reg.resolve().unwrap();
        assert!(ctx.try_get::<i32>().is_none());
    }
}