use std::{
any::{Any, TypeId},
collections::HashMap,
future::Future,
sync::Arc,
};
use crate::error::Result;
/// Marker trait for types that may be stored in a [`Container`].
///
/// It declares no methods. Its supertraits (`Send + Sync + 'static`) are what
/// allow an `Arc<T>` to be type-erased into an `Arc<dyn Any + Send + Sync>`
/// and keyed by `TypeId`. Opt a type in with `impl Injectable for MyType {}`.
pub trait Injectable: Send + Sync + 'static {}
/// Type-erased, reference-counted handle to a registered service.
///
/// Despite the name this is an `Arc`, not a `Box`: every lookup hands out a
/// clone of the handle and downcasts it back to the concrete service type.
type ServiceBox = Arc<dyn Any + Send + Sync>;
/// Registry that stores at most one shared service instance per concrete type.
///
/// The derived `Clone` copies the map and clones each `Arc` handle, so a
/// cloned container refers to the same service instances as the original;
/// registrations made afterwards on either container do not affect the other.
#[derive(Clone, Default)]
pub struct Container {
// Registered services, keyed by the `TypeId` of their concrete type.
services: HashMap<TypeId, ServiceBox>,
}
impl Container {
    /// Creates a container with no services registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `service` under its concrete type `T`.
    ///
    /// If a service of type `T` was already registered it is replaced; the
    /// container drops its handle to the previous instance.
    pub fn register<T: Injectable>(&mut self, service: Arc<T>) {
        // Erase the concrete type so services of different types share one map.
        let erased: ServiceBox = service;
        self.services.insert(TypeId::of::<T>(), erased);
    }

    /// Calls `factory` immediately, exactly once, and registers the value it
    /// returns under type `T`.
    ///
    /// Replaces any service previously registered for `T`, like
    /// [`Container::register`].
    pub fn register_factory<T: Injectable, F>(&mut self, factory: F)
    where
        F: FnOnce() -> T,
    {
        self.register(Arc::new(factory()));
    }

    /// Awaits the future produced by `factory` and registers the resulting
    /// service under type `T`.
    ///
    /// The container stays mutably borrowed for as long as the returned
    /// future is alive.
    ///
    /// # Errors
    ///
    /// Returns the factory's error unchanged. In that case nothing is
    /// registered and the container is left as it was.
    pub async fn register_async_factory<T, F, Fut>(&mut self, factory: F) -> Result<()>
    where
        T: Injectable,
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        let service = factory().await?;
        self.register(Arc::new(service));
        Ok(())
    }

    /// Returns a shared handle to the service registered for `T`, or `None`
    /// when no such service has been registered.
    pub fn resolve<T: Injectable>(&self) -> Option<Arc<T>> {
        let erased = self.services.get(&TypeId::of::<T>())?;
        // Entries are keyed by their own `TypeId`, so the downcast is expected
        // to succeed; a mismatch is still reported as `None`, not a panic.
        Arc::clone(erased).downcast::<T>().ok()
    }

    /// Like [`Container::resolve`], but treats a missing service as a bug.
    ///
    /// # Panics
    ///
    /// Panics with `Service <type name> not registered` when `T` has not been
    /// registered. The panic location is the caller's call site.
    #[track_caller]
    pub fn resolve_or_panic<T: Injectable>(&self) -> Arc<T> {
        // `#[track_caller]` does not reach into closures, so the panic has to
        // be raised directly in this function body rather than inside
        // `unwrap_or_else(|| ..)`.
        match self.resolve() {
            Some(service) => service,
            None => panic!("Service {} not registered", std::any::type_name::<T>()),
        }
    }

    /// Returns `true` when a service of type `T` is registered.
    pub fn contains<T: Injectable>(&self) -> bool {
        self.services.contains_key(&TypeId::of::<T>())
    }

    /// Returns the number of registered services.
    pub fn len(&self) -> usize {
        self.services.len()
    }

    /// Returns `true` when no services are registered.
    pub fn is_empty(&self) -> bool {
        self.services.is_empty()
    }

    /// Removes every registration, dropping the container's handle to each
    /// service. Handles previously returned by `resolve` remain valid.
    pub fn clear(&mut self) {
        self.services.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for a database handle; it only remembers its connection string.
    struct MockDatabase {
        connection_string: String,
    }

    impl Injectable for MockDatabase {}

    impl MockDatabase {
        fn new(conn: &str) -> Self {
            let connection_string = String::from(conn);
            Self { connection_string }
        }
    }

    /// Stand-in for a service that depends on another registered service.
    struct MockUserService {
        db: Arc<MockDatabase>,
    }

    impl Injectable for MockUserService {}

    impl MockUserService {
        fn new(db: Arc<MockDatabase>) -> Self {
            Self { db }
        }
    }

    // A pre-built `Arc` can be registered and read back by type.
    #[test]
    fn test_register_and_resolve() {
        let mut registry = Container::new();
        let database = Arc::new(MockDatabase::new("postgres://localhost"));
        registry.register(Arc::clone(&database));

        let found = registry.resolve::<MockDatabase>().unwrap();
        assert_eq!(found.connection_string, "postgres://localhost");
    }

    // A factory closure's return value is registered under its type.
    #[test]
    fn test_register_factory() {
        let mut registry = Container::new();
        registry.register_factory(|| MockDatabase::new("sqlite::memory"));

        let found = registry.resolve::<MockDatabase>().unwrap();
        assert_eq!(found.connection_string, "sqlite::memory");
    }

    // Looking up a type that was never registered yields `None`.
    #[test]
    fn test_resolve_missing_service() {
        let registry = Container::new();
        assert!(registry.resolve::<MockDatabase>().is_none());
    }

    // The panicking lookup reports the missing service.
    #[test]
    #[should_panic(expected = "Service")]
    fn test_resolve_or_panic() {
        let registry = Container::new();
        let _ = registry.resolve_or_panic::<MockDatabase>();
    }

    // Two services of different types coexist, one holding a handle to the other.
    #[test]
    fn test_dependency_chain() {
        let mut registry = Container::new();

        let database = Arc::new(MockDatabase::new("postgres://localhost"));
        registry.register(Arc::clone(&database));
        registry.register(Arc::new(MockUserService::new(database)));

        let found_db = registry.resolve::<MockDatabase>().unwrap();
        let found_users = registry.resolve::<MockUserService>().unwrap();

        assert_eq!(found_db.connection_string, "postgres://localhost");
        assert_eq!(found_users.db.connection_string, "postgres://localhost");
    }

    // `contains` flips from false to true once the type is registered.
    #[test]
    fn test_contains() {
        let mut registry = Container::new();
        assert!(!registry.contains::<MockDatabase>());

        registry.register_factory(|| MockDatabase::new("test"));
        assert!(registry.contains::<MockDatabase>());
    }

    // `len`/`is_empty` track registrations and `clear` resets them.
    #[test]
    fn test_len_and_clear() {
        let mut registry = Container::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());

        registry.register_factory(|| MockDatabase::new("test"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        registry.clear();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }
}