dfdi_core/context.rs
1use std::{
2 any::{type_name, TypeId},
3 collections::HashMap,
4 marker::PhantomData,
5 ptr::NonNull,
6};
7
8use crate::{BindError, ProvideFn, Provider, Service, UnbindError};
9
10/// A context in which to store providers for services
11pub struct Context<'pcx> {
12 /// Map `Service` `TypeId`s to a type-erased provider
13 //
14 // Note: Unfortunately, https://github.com/rust-lang/rust/issues/10389 is an I-unsound bug to
15 // keep an eye on. TL;DR: TypeId hash collisions are possible and there have been some (obscure)
16 // examples of this in the past.
17 providers: HashMap<TypeId, DynProvider>,
18
19 /// Ensure that this context does not outlive its parent. This is required since we only want to
20 /// drop providers once, on the parent scope.
21 _phantom: PhantomData<&'pcx ()>,
22}
23
24// SAFETY:
25// - All providers must be Send
26unsafe impl Send for Context<'_> {}
27
28// SAFETY:
29// - All providers mustbe Sync
30unsafe impl Sync for Context<'_> {}
31
32impl Context<'_> {
33 /// Create an empty context
34 pub fn new() -> Self {
35 Self {
36 providers: HashMap::new(),
37 _phantom: PhantomData,
38 }
39 }
40
41 /// Create a sub-context
42 ///
43 /// The retuned context will contain the same elements as the parent context and any elements
44 /// added to the sub context will not be visible on the original. However, the underlying
45 /// providers that were added before this call are shared between the two contexts.
46 pub fn scoped(&self) -> Context<'_> {
47 // Notes:
48 // - We are cloning the pointers, not the underlying data
49 // - Provider expects a shared reference
50 // - DynProvider's clone implementation skips the drop function for clones
51 Context {
52 providers: self.providers.clone(),
53 _phantom: PhantomData,
54 }
55 }
56
57 /// Register a new provider for the service `S`
58 ///
59 /// # Panics
60 /// If the service binding fails. See [`try_bind_with`](Self::try_bind_with) for a fallible
61 /// version of this function.
62 #[track_caller]
63 pub fn bind_with<'cx, S: Service>(&'cx mut self, provider: impl Provider<'cx, S>) {
64 if let Err(err) = self.try_bind_with::<S>(provider) {
65 panic!("{}", err)
66 }
67 }
68
69 /// Register a function as a provider for the service `S`
70 ///
71 /// # Panics
72 /// If the service binding fails. See [`try_bind_fn`](Self::try_bind_fn) for a fallible version
73 /// of this function.
74 #[track_caller]
75 pub fn bind_fn<'cx, S: Service>(
76 &'cx mut self,
77 provider_fn: impl Fn(&'cx Context, S::Argument<'_>) -> S::Output<'cx> + Send + Sync + 'cx,
78 ) {
79 if let Err(err) = self.try_bind_fn::<S>(provider_fn) {
80 panic!("{}", err)
81 }
82 }
83
84 /// Bind the provider `P` to the service `S`
85 ///
86 /// # Panics
87 /// If the service binding fails. See [`try_bind`](Self::try_bind) for a fallible version of
88 /// this function.
89 #[track_caller]
90 pub fn bind<'cx, S, P>(&'cx mut self)
91 where
92 S: Service,
93 P: Provider<'cx, S> + Default,
94 {
95 if let Err(err) = self.try_bind::<S, P>() {
96 panic!("{}", err)
97 }
98 }
99
100 /// Delete the provider bound to the service `S`
101 ///
102 /// # Panics
103 /// If the service unbinding fails. See [`try_unbind`](Self::try_unbind) for a fallible version
104 /// of this function.
105 #[track_caller]
106 pub fn unbind<S>(&mut self)
107 where
108 S: Service,
109 {
110 if let Err(err) = self.try_unbind::<S>() {
111 panic!("{}", err)
112 }
113 }
114
115 /// Resolve the service `S` using the default service argument.
116 ///
117 /// # Panics
118 /// If no provider is registered for this service. See [`try_resolve`](Self::try_resolve) for a
119 /// fallible version of this function.
120 #[inline(always)]
121 #[track_caller]
122 pub fn resolve<S>(&self) -> S::Output<'_>
123 where
124 S: Service,
125 S::Argument<'static>: Default,
126 {
127 self.resolve_with::<S>(Default::default())
128 }
129
130 /// Resolve the service `S` given the service argument.
131 ///
132 /// # Panics
133 /// If no provider is registered for this service. See [`try_resolve`](Self::try_resolve) for a
134 /// fallible version of this function.
135 #[track_caller]
136 pub fn resolve_with<S>(&self, arg: S::Argument<'_>) -> S::Output<'_>
137 where
138 S: Service,
139 {
140 match self.try_resolve_with::<S>(arg) {
141 Some(s) => s,
142 None => panic!("no provider for service `{}`", type_name::<S>()),
143 }
144 }
145
146 /// Try to register a new provider for the service `S`
147 ///
148 /// # Fails
149 /// This function will fail if a provider is already bound to the service.
150 ///
151 /// See [`bind_with`](Self::bind_with) for the panicking version of this function.
152 pub fn try_bind_with<'cx, S: Service>(
153 &'cx mut self,
154 provider: impl Provider<'cx, S>,
155 ) -> Result<(), BindError> {
156 use std::collections::hash_map::Entry::*;
157 match self.providers.entry(TypeId::of::<S>()) {
158 Vacant(e) => {
159 // SAFETY:
160 // - Due to the api provided by `Context`, all clones of `DynProvider` _will_ be
161 // dropped before the original instance is dropped
162 e.insert(unsafe { DynProvider::new(provider) });
163 Ok(())
164 }
165 Occupied(_) => Err(BindError::ServiceBound(std::any::type_name::<S>())),
166 }
167 }
168
169 /// Try to register a function as a provider for the service `S`
170 ///
171 /// # Fails
172 /// This function will fail if a provider is already bound to the service.
173 ///
174 /// See [`bind_fn`](Self::bind_fn) for the panicking version of this function.
175 #[inline(always)]
176 pub fn try_bind_fn<'cx, S: Service>(
177 &'cx mut self,
178 provider_fn: impl Fn(&'cx Context, S::Argument<'_>) -> S::Output<'cx> + Send + Sync + 'cx,
179 ) -> Result<(), BindError> {
180 self.try_bind_with::<S>(provider_fn)
181 }
182
183 /// Try to bind the provider `P` to the service `S`
184 ///
185 /// # Fails
186 /// This function will fail if a provider is already bound to the service.
187 ///
188 /// See [`bind`](Self::bind) for the panicking version of this function.
189 #[inline(always)]
190 pub fn try_bind<'cx, S, P>(&'cx mut self) -> Result<(), BindError>
191 where
192 S: Service,
193 P: Provider<'cx, S> + Default,
194 {
195 self.try_bind_with(P::default())
196 }
197
198 /// Try to delete the provider bound to the service `S`.
199 ///
200 /// # Fails
201 /// This function will fail if no provider is bound to the service.
202 ///
203 /// See [`unbind`](Self::unbind) for the panicking version of this function.
204 pub fn try_unbind<S>(&mut self) -> Result<(), UnbindError>
205 where
206 S: Service,
207 {
208 match self.providers.remove(&TypeId::of::<S>()) {
209 Some(_) => Ok(()),
210 None => Err(UnbindError::ServiceUnbound(type_name::<S>())),
211 }
212 }
213
214 /// Try to resolve the service `S` using the default service argument.
215 ///
216 /// # Fails
217 /// This function will fail if no provider is bound to the service.
218 ///
219 /// See [`resolve`](Self::resolve) for the panicking version of this function.
220 #[inline(always)]
221 pub fn try_resolve<S>(&self) -> Option<S::Output<'_>>
222 where
223 S: Service,
224 S::Argument<'static>: Default,
225 {
226 self.try_resolve_with::<S>(Default::default())
227 }
228
229 /// Try to resolve the service `S` given the service argument.
230 ///
231 /// # Fails
232 /// This function will fail if no provider is bound to the service.
233 ///
234 /// See [`resolve_with`](Self::resolve_with) for the panicking version of this function.
235 pub fn try_resolve_with<S>(&self, arg: S::Argument<'_>) -> Option<S::Output<'_>>
236 where
237 S: Service,
238 {
239 let provider = self.providers.get(&TypeId::of::<S>())?;
240
241 // SAFETY:
242 // - We know that the provider was created for the service `S`, since it came from the
243 // `self.providers` map
244 Some(unsafe { provider.provide::<S>(self, arg) })
245 }
246}
247
248impl Default for Context<'_> {
249 #[inline(always)]
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
/// Type-erased handle to a boxed provider for some service.
///
/// The instance returned by `DynProvider::new` owns the boxed provider. Clones share the same
/// pointers but carry no drop function, so only the original frees the provider.
struct DynProvider {
    /// Erased pointer to the heap-allocated provider value.
    this: NonNull<()>,

    /// Frees the provider behind `this`; `None` for clones, which do not own it.
    //
    // SAFETY:
    // - Must only be called with a valid `self.this` pointer
    drop_fn: Option<unsafe fn(*mut ())>,

    /// Erased function pointer to the provider's `provide` implementation.
    provide_fn: NonNull<()>,
}
268
269impl DynProvider {
270 /// Create a `DynProvider` for the service `S`
271 ///
272 /// SAFETY:
273 /// - This instance must live as long as all of its clones
274 unsafe fn new<'cx, S, P>(provider: P) -> Self
275 where
276 S: Service,
277 P: Provider<'cx, S>,
278 {
279 unsafe fn drop_provider<P>(this: *mut ()) {
280 std::mem::drop(Box::from_raw(this as *mut P));
281 }
282
283 // Create a pointer to a specialized `drop` function and store it.
284 let drop_fn = Some(drop_provider::<P> as _);
285
286 // Get the P::provide function pointer and store a type-erased version of it
287 //
288 // SAFETY:
289 // - fn pointers are always non-null
290 let provide_fn = unsafe { NonNull::new_unchecked(P::provide as fn(_, _, _) -> _ as _) };
291
292 // Create the `this` pointer.
293 //
294 // SAFETY:
295 // - A `Box`'s pointer is always guaranteed to be non-null
296 let this = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(provider)) as *mut _) };
297
298 Self {
299 this,
300 drop_fn,
301 provide_fn,
302 }
303 }
304
305 /// Run the provider
306 ///
307 /// SAFETY:
308 /// - The `DynProvider` was created for the service `S`
309 unsafe fn provide<'cx, S>(&'cx self, cx: &'cx Context, arg: S::Argument<'_>) -> S::Output<'cx>
310 where
311 S: Service,
312 {
313 let this = self.this.as_ptr() as *const ();
314 let provide_fn: ProvideFn<'cx, S> = std::mem::transmute(self.provide_fn);
315
316 provide_fn(this, cx, arg)
317 }
318}
319
320impl Clone for DynProvider {
321 fn clone(&self) -> Self {
322 Self {
323 this: self.this,
324 provide_fn: self.provide_fn,
325 drop_fn: None, // drop should only run on the original instance
326 }
327 }
328}
329
330impl Drop for DynProvider {
331 fn drop(&mut self) {
332 if let Some(drop_fn) = self.drop_fn {
333 // SAFETY:
334 // - `drop_fn` can only be called with `self.this`, which it is.
335 // - We know drop has not been called because of the safety guarantees on new(), which
336 // means that `self.this` points to valid memory.
337 unsafe { (drop_fn)(self.this.as_ptr()) }
338 }
339 }
340}