Skip to main content

ruest/di/
container.rs

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use super::{DiError, ProviderDescriptor, Scope};
6
/// Type-erased instance storage: maps a concrete type's `TypeId` to a shared,
/// thread-safe handle that can later be downcast back to that type.
type InstanceMap = HashMap<TypeId, Arc<dyn Any + Send + Sync>>;
8
9/// Thread-safe DI container with **monomorphized** resolution via `get::<T>()`.
10///
11/// Prefer [`register_singleton`](Self::register_singleton) for services — no `dyn` factory on the hot path.
12#[derive(Clone, Default)]
13pub struct Container {
14    providers: Arc<RwLock<HashMap<TypeId, ProviderDescriptor>>>,
15    singletons: Arc<RwLock<InstanceMap>>,
16    resolving: Arc<RwLock<Vec<TypeId>>>,
17}
18
19impl Container {
20    pub fn new() -> Self {
21        Self::default()
22    }
23
24    /// Register a provider descriptor (advanced / async factories).
25    pub fn register<T: Send + Sync + 'static>(&self, descriptor: ProviderDescriptor) {
26        let type_id = TypeId::of::<T>();
27        self.providers
28            .write()
29            .expect("container providers lock poisoned")
30            .insert(type_id, descriptor);
31    }
32
33    /// Register a singleton directly (compile-time friendly, no factory trait object).
34    pub fn register_singleton<T: Send + Sync + 'static>(&self, instance: Arc<T>) {
35        let type_id = TypeId::of::<T>();
36        self.singletons
37            .write()
38            .expect("container singletons lock poisoned")
39            .insert(type_id, instance);
40    }
41
42    /// Register `T: Default` as a singleton (used by `#[service]` macro).
43    pub fn register_default<T: Default + Send + Sync + 'static>(&self) {
44        self.register_singleton(Arc::new(T::default()));
45    }
46
47    /// Register a pre-built singleton instance (alias).
48    pub fn register_instance<T: Send + Sync + 'static>(&self, instance: Arc<T>) {
49        self.register_singleton(instance);
50    }
51
52    /// Resolve type `T` — monomorphized at each call site (no runtime type name).
53    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, DiError> {
54        let type_id = TypeId::of::<T>();
55
56        if let Some(existing) = self.singletons.read().expect("lock").get(&type_id) {
57            return downcast_arc::<T>(existing.clone());
58        }
59
60        let descriptor = self
61            .providers
62            .read()
63            .expect("lock")
64            .get(&type_id)
65            .cloned()
66            .ok_or(DiError::not_found::<T>())?;
67
68        self.guard_circular(type_id)?;
69
70        if descriptor.scope == Scope::Singleton {
71            if let Some(existing) = self.singletons.read().expect("lock").get(&type_id) {
72                self.resolving.write().expect("lock").pop();
73                return downcast_arc::<T>(existing.clone());
74            }
75        }
76
77        let instance = self.resolve_descriptor(&descriptor)?;
78        self.resolving.write().expect("lock").pop();
79
80        let arc = downcast_arc::<T>(instance)?;
81
82        if descriptor.scope == Scope::Singleton {
83            self.singletons
84                .write()
85                .expect("lock")
86                .insert(type_id, arc.clone());
87        }
88
89        Ok(arc)
90    }
91
92    /// Async resolution for providers with async factories.
93    pub async fn get_async<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, DiError> {
94        let type_id = TypeId::of::<T>();
95
96        if let Some(existing) = self.singletons.read().expect("lock").get(&type_id) {
97            return downcast_arc::<T>(existing.clone());
98        }
99
100        let descriptor = self
101            .providers
102            .read()
103            .expect("lock")
104            .get(&type_id)
105            .cloned()
106            .ok_or(DiError::not_found::<T>())?;
107
108        self.guard_circular(type_id)?;
109
110        let instance = if let Some(fut) = descriptor.factory.create_async(self) {
111            fut.await
112        } else if let Some(sync) = descriptor.factory.create_sync(self) {
113            sync
114        } else {
115            self.resolving.write().expect("lock").pop();
116            return Err(DiError::ResolutionFailed {
117                type_name: descriptor.type_name,
118                reason: "async factory required".into(),
119            });
120        };
121
122        self.resolving.write().expect("lock").pop();
123
124        let arc = downcast_arc::<T>(instance)?;
125
126        if descriptor.scope == Scope::Singleton {
127            self.singletons
128                .write()
129                .expect("lock")
130                .insert(type_id, arc.clone());
131        }
132
133        Ok(arc)
134    }
135
136    pub fn request_scope(&self) -> RequestScope<'_> {
137        RequestScope {
138            parent: self,
139            request_instances: RwLock::new(HashMap::new()),
140        }
141    }
142
143    fn guard_circular(&self, type_id: TypeId) -> Result<(), DiError> {
144        let mut resolving = self.resolving.write().expect("lock");
145        if resolving.contains(&type_id) {
146            return Err(DiError::CircularDependency(format!("{type_id:?}")));
147        }
148        resolving.push(type_id);
149        Ok(())
150    }
151
152    fn resolve_descriptor(
153        &self,
154        descriptor: &ProviderDescriptor,
155    ) -> Result<Arc<dyn Any + Send + Sync>, DiError> {
156        if let Some(instance) = descriptor.factory.create_sync(self) {
157            return Ok(instance);
158        }
159        Err(DiError::ResolutionFailed {
160            type_name: descriptor.type_name,
161            reason: "sync factory required; use get_async for async providers".into(),
162        })
163    }
164}
165
166/// Request-scoped resolution context.
167pub struct RequestScope<'a> {
168    parent: &'a Container,
169    request_instances: RwLock<InstanceMap>,
170}
171
172impl<'a> RequestScope<'a> {
173    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, DiError> {
174        let type_id = TypeId::of::<T>();
175
176        if let Some(existing) = self.request_instances.read().expect("lock").get(&type_id) {
177            return downcast_arc::<T>(existing.clone());
178        }
179
180        let instance = self.parent.get::<T>()?;
181        self.request_instances
182            .write()
183            .expect("lock")
184            .insert(type_id, instance.clone());
185        downcast_arc::<T>(instance)
186    }
187}
188
189fn downcast_arc<T: Send + Sync + 'static>(
190    value: Arc<dyn Any + Send + Sync>,
191) -> Result<Arc<T>, DiError> {
192    Arc::downcast::<T>(value)
193        .map_err(|_| DiError::ResolutionFailed {
194            type_name: std::any::type_name::<T>(),
195            reason: "type mismatch in DI container".into(),
196        })
197}