Skip to main content

nidus_core/container/
request_scope.rs

1use std::{
2    any::{Any, TypeId, type_name},
3    collections::HashMap,
4    sync::{Arc, Condvar, Mutex, MutexGuard},
5};
6
7use crate::{
8    Container, Inject, NidusError, Optional, ProviderLifetime, Result, Scoped, resolution,
9};
10
11use super::downcast;
12
/// Per-request dependency scope.
///
/// Wraps a [`Container`] and keeps a per-scope cache of instances whose provider
/// lifetime is `ProviderLifetime::Request`, so each such type is constructed at
/// most once for the lifetime of this scope.
pub struct RequestScope<'a> {
    /// Container that providers are looked up in, either borrowed or shared.
    container: RequestScopeContainer<'a>,
    /// Request-lifetime instances keyed by type, including in-progress placeholders.
    request_instances: Mutex<HashMap<TypeId, RequestInstanceState>>,
    /// Notified whenever an entry in `request_instances` becomes ready or is removed.
    request_instance_ready: Condvar,
}
19
/// State of one request-lifetime cache slot inside a [`RequestScope`].
enum RequestInstanceState {
    /// A resolver has claimed the slot and is currently running the factory.
    Initializing,
    /// The factory finished; the type-erased instance is shared with every caller.
    Ready(Arc<dyn Any + Send + Sync>),
}
24
/// Shared request scope handle suitable for HTTP request extensions.
///
/// The `'static` scope lifetime means the scope cannot hold a borrowed container;
/// see [`RequestScope::from_shared_container`] for the matching constructor.
pub type SharedRequestScope = Arc<RequestScope<'static>>;
27
/// How a [`RequestScope`] holds on to its [`Container`].
enum RequestScopeContainer<'a> {
    /// The container is borrowed for `'a`; the scope cannot outlive it.
    Borrowed(&'a Container),
    /// The container is kept alive by a reference-counted handle owned by the scope.
    Shared(Arc<Container>),
}
32
33impl RequestScopeContainer<'_> {
34    fn as_ref(&self) -> &Container {
35        match self {
36            Self::Borrowed(container) => container,
37            Self::Shared(container) => container,
38        }
39    }
40}
41
42impl<'a> RequestScope<'a> {
43    pub(super) fn borrowed(container: &'a Container) -> Self {
44        Self {
45            container: RequestScopeContainer::Borrowed(container),
46            request_instances: Mutex::new(HashMap::new()),
47            request_instance_ready: Condvar::new(),
48        }
49    }
50}
51
52impl RequestScope<'_> {
53    pub(crate) fn container(&self) -> &Container {
54        self.container.as_ref()
55    }
56
57    /// Resolves a dependency in this request scope.
58    pub fn resolve<T>(&self) -> Result<Arc<T>>
59    where
60        T: Send + Sync + 'static,
61    {
62        let entry = self.container().entry::<T>()?;
63        let erased = match entry.lifetime() {
64            ProviderLifetime::Request => {
65                let type_id = TypeId::of::<T>();
66                self.resolve_request_instance(type_id, type_name::<T>(), || {
67                    entry.resolve_erased_in_scope(self)
68                })?
69            }
70            ProviderLifetime::Singleton | ProviderLifetime::Transient => {
71                entry.resolve_erased(self.container())?
72            }
73        };
74
75        downcast::<T>(erased)
76    }
77
78    /// Resolves a typed dependency reference in this request scope.
79    pub fn inject<T>(&self) -> Result<Inject<T>>
80    where
81        T: Send + Sync + 'static,
82    {
83        self.resolve::<T>().map(Inject::new)
84    }
85
86    /// Resolves an optional typed dependency reference in this request scope.
87    ///
88    /// Missing providers become `Optional::new(None)`, while registered providers
89    /// that fail to construct still return their original error.
90    pub fn optional<T>(&self) -> Result<Optional<T>>
91    where
92        T: Send + Sync + 'static,
93    {
94        match self.inject::<T>() {
95            Ok(value) => Ok(Optional::new(Some(value))),
96            Err(NidusError::MissingProvider { .. }) => Ok(Optional::new(None)),
97            Err(error) => Err(error),
98        }
99    }
100
101    /// Resolves a request-scoped dependency wrapper in this request scope.
102    pub fn scoped<T>(&self) -> Result<Scoped<T>>
103    where
104        T: Send + Sync + 'static,
105    {
106        self.inject::<T>().map(Scoped::new)
107    }
108
109    fn resolve_request_instance(
110        &self,
111        type_id: TypeId,
112        type_name: &'static str,
113        create: impl FnOnce() -> Result<Arc<dyn Any + Send + Sync>>,
114    ) -> Result<Arc<dyn Any + Send + Sync>> {
115        let mut create = Some(create);
116        loop {
117            let mut instances = lock_unpoisoned(&self.request_instances);
118            match instances.get(&type_id) {
119                Some(RequestInstanceState::Ready(instance)) => return Ok(Arc::clone(instance)),
120                Some(RequestInstanceState::Initializing) => {
121                    if resolution::is_active(type_id) {
122                        return Err(NidusError::CircularProviderResolution { type_name });
123                    }
124                    drop(wait_unpoisoned(&self.request_instance_ready, instances));
125                }
126                None => {
127                    let _guard = resolution::enter(type_id, type_name)?;
128                    instances.insert(type_id, RequestInstanceState::Initializing);
129                    drop(instances);
130
131                    let initializer = create
132                        .take()
133                        .expect("request instance factory can only be used by initializer");
134                    let instance = initializer();
135                    let mut instances = lock_unpoisoned(&self.request_instances);
136                    match instance {
137                        Ok(instance) => {
138                            instances.insert(
139                                type_id,
140                                RequestInstanceState::Ready(Arc::clone(&instance)),
141                            );
142                            self.request_instance_ready.notify_all();
143                            return Ok(instance);
144                        }
145                        Err(error) => {
146                            instances.remove(&type_id);
147                            self.request_instance_ready.notify_all();
148                            return Err(error);
149                        }
150                    }
151                }
152            }
153        }
154    }
155}
156
157impl RequestScope<'static> {
158    /// Creates a request scope that owns a shared container handle.
159    pub fn from_shared_container(container: Arc<Container>) -> Self {
160        Self {
161            container: RequestScopeContainer::Shared(container),
162            request_instances: Mutex::new(HashMap::new()),
163            request_instance_ready: Condvar::new(),
164        }
165    }
166}
167
/// Locks `mutex`, recovering the guard even if a previous holder panicked and
/// left the mutex poisoned.
fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
173
/// Waits on `condvar` with `guard`, handing the guard back even if the mutex was
/// poisoned while this thread was waiting.
fn wait_unpoisoned<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
179
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use super::RequestScope;
    use crate::Container;

    #[derive(Debug, Eq, PartialEq)]
    struct RequestValue(u64);

    /// A thread that panics while holding the instance-cache lock poisons it;
    /// resolving through the same scope afterwards must still succeed.
    #[test]
    fn request_scope_recovers_from_poisoned_instance_cache() {
        let mut container = Container::new();
        container
            .register_request_scoped::<RequestValue, _>(|_scope| Ok(RequestValue(42)))
            .unwrap();
        let scope = Arc::new(RequestScope::from_shared_container(Arc::new(container)));

        // Poison the cache mutex before the first resolve.
        let poisoner = {
            let scope = Arc::clone(&scope);
            thread::spawn(move || {
                let _held = scope.request_instances.lock().unwrap();
                panic!("poison request scope cache");
            })
        };
        assert!(poisoner.join().is_err());

        let resolved = scope.resolve::<RequestValue>().unwrap();
        assert_eq!(*resolved, RequestValue(42));
    }
}
207}