nidus_core/container/
request_scope.rs1use std::{
2 any::{Any, TypeId, type_name},
3 collections::HashMap,
4 sync::{Arc, Condvar, Mutex, MutexGuard},
5};
6
7use crate::{
8 Container, Inject, NidusError, Optional, ProviderLifetime, Result, Scoped, resolution,
9};
10
11use super::downcast;
12
13pub struct RequestScope<'a> {
15 container: RequestScopeContainer<'a>,
16 request_instances: Mutex<HashMap<TypeId, RequestInstanceState>>,
17 request_instance_ready: Condvar,
18}
19
/// State of one request-scoped type inside a scope's instance cache.
enum RequestInstanceState {
    /// Some caller is currently running the provider; others must wait.
    Initializing,
    /// The provider finished; this type-erased instance is handed out to
    /// every later caller.
    Ready(Arc<dyn Any + Send + Sync>),
}
24
25pub type SharedRequestScope = Arc<RequestScope<'static>>;
27
28enum RequestScopeContainer<'a> {
29 Borrowed(&'a Container),
30 Shared(Arc<Container>),
31}
32
33impl RequestScopeContainer<'_> {
34 fn as_ref(&self) -> &Container {
35 match self {
36 Self::Borrowed(container) => container,
37 Self::Shared(container) => container,
38 }
39 }
40}
41
42impl<'a> RequestScope<'a> {
43 pub(super) fn borrowed(container: &'a Container) -> Self {
44 Self {
45 container: RequestScopeContainer::Borrowed(container),
46 request_instances: Mutex::new(HashMap::new()),
47 request_instance_ready: Condvar::new(),
48 }
49 }
50}
51
52impl RequestScope<'_> {
53 pub(crate) fn container(&self) -> &Container {
54 self.container.as_ref()
55 }
56
57 pub fn resolve<T>(&self) -> Result<Arc<T>>
59 where
60 T: Send + Sync + 'static,
61 {
62 let entry = self.container().entry::<T>()?;
63 let erased = match entry.lifetime() {
64 ProviderLifetime::Request => {
65 let type_id = TypeId::of::<T>();
66 self.resolve_request_instance(type_id, type_name::<T>(), || {
67 entry.resolve_erased_in_scope(self)
68 })?
69 }
70 ProviderLifetime::Singleton | ProviderLifetime::Transient => {
71 entry.resolve_erased(self.container())?
72 }
73 };
74
75 downcast::<T>(erased)
76 }
77
78 pub fn inject<T>(&self) -> Result<Inject<T>>
80 where
81 T: Send + Sync + 'static,
82 {
83 self.resolve::<T>().map(Inject::new)
84 }
85
86 pub fn optional<T>(&self) -> Result<Optional<T>>
91 where
92 T: Send + Sync + 'static,
93 {
94 match self.inject::<T>() {
95 Ok(value) => Ok(Optional::new(Some(value))),
96 Err(NidusError::MissingProvider { .. }) => Ok(Optional::new(None)),
97 Err(error) => Err(error),
98 }
99 }
100
101 pub fn scoped<T>(&self) -> Result<Scoped<T>>
103 where
104 T: Send + Sync + 'static,
105 {
106 self.inject::<T>().map(Scoped::new)
107 }
108
109 fn resolve_request_instance(
110 &self,
111 type_id: TypeId,
112 type_name: &'static str,
113 create: impl FnOnce() -> Result<Arc<dyn Any + Send + Sync>>,
114 ) -> Result<Arc<dyn Any + Send + Sync>> {
115 let mut create = Some(create);
116 loop {
117 let mut instances = lock_unpoisoned(&self.request_instances);
118 match instances.get(&type_id) {
119 Some(RequestInstanceState::Ready(instance)) => return Ok(Arc::clone(instance)),
120 Some(RequestInstanceState::Initializing) => {
121 if resolution::is_active(type_id) {
122 return Err(NidusError::CircularProviderResolution { type_name });
123 }
124 drop(wait_unpoisoned(&self.request_instance_ready, instances));
125 }
126 None => {
127 let _guard = resolution::enter(type_id, type_name)?;
128 instances.insert(type_id, RequestInstanceState::Initializing);
129 drop(instances);
130
131 let initializer = create
132 .take()
133 .expect("request instance factory can only be used by initializer");
134 let instance = initializer();
135 let mut instances = lock_unpoisoned(&self.request_instances);
136 match instance {
137 Ok(instance) => {
138 instances.insert(
139 type_id,
140 RequestInstanceState::Ready(Arc::clone(&instance)),
141 );
142 self.request_instance_ready.notify_all();
143 return Ok(instance);
144 }
145 Err(error) => {
146 instances.remove(&type_id);
147 self.request_instance_ready.notify_all();
148 return Err(error);
149 }
150 }
151 }
152 }
153 }
154 }
155}
156
157impl RequestScope<'static> {
158 pub fn from_shared_container(container: Arc<Container>) -> Self {
160 Self {
161 container: RequestScopeContainer::Shared(container),
162 request_instances: Mutex::new(HashMap::new()),
163 request_instance_ready: Condvar::new(),
164 }
165 }
166}
167
/// Locks `mutex`, taking the guard anyway if a previous holder panicked.
///
/// Poisoning is deliberately ignored so that a panic while the lock was held
/// does not make the protected data permanently unusable.
fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
173
/// Blocks on `condvar`, releasing `guard` while waiting, and returns the
/// reacquired guard even if the mutex was poisoned in the meantime.
///
/// As with any condvar wait, callers must re-check their condition after
/// this returns.
fn wait_unpoisoned<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
179
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use super::RequestScope;
    use crate::Container;

    /// Minimal request-scoped payload used to check resolved values.
    #[derive(Debug, Eq, PartialEq)]
    struct RequestValue(u64);

    /// Poisons the instance cache by panicking while its lock is held, then
    /// checks that resolution still succeeds on the same scope.
    #[test]
    fn request_scope_recovers_from_poisoned_instance_cache() {
        let scope = {
            let mut container = Container::new();
            container
                .register_request_scoped::<RequestValue, _>(|_scope| Ok(RequestValue(42)))
                .unwrap();
            Arc::new(RequestScope::from_shared_container(Arc::new(container)))
        };

        let poisoner = {
            let scope = Arc::clone(&scope);
            thread::spawn(move || {
                let _cache = scope.request_instances.lock().unwrap();
                panic!("poison request scope cache");
            })
        };
        assert!(poisoner.join().is_err());

        let resolved = scope.resolve::<RequestValue>().unwrap();
        assert_eq!(*resolved, RequestValue(42));
    }
}