dfdi_core/
context.rs

1use std::{
2    any::{type_name, TypeId},
3    collections::HashMap,
4    marker::PhantomData,
5    ptr::NonNull,
6};
7
8use crate::{BindError, ProvideFn, Provider, Service, UnbindError};
9
10/// A context in which to store providers for services
11pub struct Context<'pcx> {
12    /// Map `Service` `TypeId`s to a type-erased provider
13    //
14    // Note: Unfortunately, https://github.com/rust-lang/rust/issues/10389 is an I-unsound bug to
15    // keep an eye on. TL;DR: TypeId hash collisions are possible and there have been some (obscure)
16    // examples of this in the past.
17    providers: HashMap<TypeId, DynProvider>,
18
19    /// Ensure that this context does not outlive its parent. This is required since we only want to
20    /// drop providers once, on the parent scope.
21    _phantom: PhantomData<&'pcx ()>,
22}
23
24// SAFETY:
25// - All providers must be Send
26unsafe impl Send for Context<'_> {}
27
28// SAFETY:
29// - All providers mustbe Sync
30unsafe impl Sync for Context<'_> {}
31
32impl Context<'_> {
33    /// Create an empty context
34    pub fn new() -> Self {
35        Self {
36            providers: HashMap::new(),
37            _phantom: PhantomData,
38        }
39    }
40
41    /// Create a sub-context
42    ///
43    /// The retuned context will contain the same elements as the parent context and any elements
44    /// added to the sub context will not be visible on the original. However, the underlying
45    /// providers that were added before this call are shared between the two contexts.
46    pub fn scoped(&self) -> Context<'_> {
47        // Notes:
48        // - We are cloning the pointers, not the underlying data
49        // - Provider expects a shared reference
50        // - DynProvider's clone implementation skips the drop function for clones
51        Context {
52            providers: self.providers.clone(),
53            _phantom: PhantomData,
54        }
55    }
56
57    /// Register a new provider for the service `S`
58    ///
59    /// # Panics
60    /// If the service binding fails. See [`try_bind_with`](Self::try_bind_with) for a fallible
61    /// version of this function.
62    #[track_caller]
63    pub fn bind_with<'cx, S: Service>(&'cx mut self, provider: impl Provider<'cx, S>) {
64        if let Err(err) = self.try_bind_with::<S>(provider) {
65            panic!("{}", err)
66        }
67    }
68
69    /// Register a function as a provider for the service `S`
70    ///
71    /// # Panics
72    /// If the service binding fails. See [`try_bind_fn`](Self::try_bind_fn) for a fallible version
73    /// of this function.
74    #[track_caller]
75    pub fn bind_fn<'cx, S: Service>(
76        &'cx mut self,
77        provider_fn: impl Fn(&'cx Context, S::Argument<'_>) -> S::Output<'cx> + Send + Sync + 'cx,
78    ) {
79        if let Err(err) = self.try_bind_fn::<S>(provider_fn) {
80            panic!("{}", err)
81        }
82    }
83
84    /// Bind the provider `P` to the service `S`
85    ///
86    /// # Panics
87    /// If the service binding fails. See [`try_bind`](Self::try_bind) for a fallible version of
88    /// this function.
89    #[track_caller]
90    pub fn bind<'cx, S, P>(&'cx mut self)
91    where
92        S: Service,
93        P: Provider<'cx, S> + Default,
94    {
95        if let Err(err) = self.try_bind::<S, P>() {
96            panic!("{}", err)
97        }
98    }
99
100    /// Delete the provider bound to the service `S`
101    ///
102    /// # Panics
103    /// If the service unbinding fails. See [`try_unbind`](Self::try_unbind) for a fallible version
104    /// of this function.
105    #[track_caller]
106    pub fn unbind<S>(&mut self)
107    where
108        S: Service,
109    {
110        if let Err(err) = self.try_unbind::<S>() {
111            panic!("{}", err)
112        }
113    }
114
115    /// Resolve the service `S` using the default service argument.
116    ///
117    /// # Panics
118    /// If no provider is registered for this service. See [`try_resolve`](Self::try_resolve) for a
119    /// fallible version of this function.
120    #[inline(always)]
121    #[track_caller]
122    pub fn resolve<S>(&self) -> S::Output<'_>
123    where
124        S: Service,
125        S::Argument<'static>: Default,
126    {
127        self.resolve_with::<S>(Default::default())
128    }
129
130    /// Resolve the service `S` given the service argument.
131    ///
132    /// # Panics
133    /// If no provider is registered for this service. See [`try_resolve`](Self::try_resolve) for a
134    /// fallible version of this function.
135    #[track_caller]
136    pub fn resolve_with<S>(&self, arg: S::Argument<'_>) -> S::Output<'_>
137    where
138        S: Service,
139    {
140        match self.try_resolve_with::<S>(arg) {
141            Some(s) => s,
142            None => panic!("no provider for service `{}`", type_name::<S>()),
143        }
144    }
145
146    /// Try to register a new provider for the service `S`
147    ///
148    /// # Fails
149    /// This function will fail if a provider is already bound to the service.
150    ///
151    /// See [`bind_with`](Self::bind_with) for the panicking version of this function.
152    pub fn try_bind_with<'cx, S: Service>(
153        &'cx mut self,
154        provider: impl Provider<'cx, S>,
155    ) -> Result<(), BindError> {
156        use std::collections::hash_map::Entry::*;
157        match self.providers.entry(TypeId::of::<S>()) {
158            Vacant(e) => {
159                // SAFETY:
160                // - Due to the api provided by `Context`, all clones of `DynProvider` _will_ be
161                //   dropped before the original instance is dropped
162                e.insert(unsafe { DynProvider::new(provider) });
163                Ok(())
164            }
165            Occupied(_) => Err(BindError::ServiceBound(std::any::type_name::<S>())),
166        }
167    }
168
169    /// Try to register a function as a provider for the service `S`
170    ///
171    /// # Fails
172    /// This function will fail if a provider is already bound to the service.
173    ///
174    /// See [`bind_fn`](Self::bind_fn) for the panicking version of this function.
175    #[inline(always)]
176    pub fn try_bind_fn<'cx, S: Service>(
177        &'cx mut self,
178        provider_fn: impl Fn(&'cx Context, S::Argument<'_>) -> S::Output<'cx> + Send + Sync + 'cx,
179    ) -> Result<(), BindError> {
180        self.try_bind_with::<S>(provider_fn)
181    }
182
183    /// Try to bind the provider `P` to the service `S`
184    ///
185    /// # Fails
186    /// This function will fail if a provider is already bound to the service.
187    ///
188    /// See [`bind`](Self::bind) for the panicking version of this function.
189    #[inline(always)]
190    pub fn try_bind<'cx, S, P>(&'cx mut self) -> Result<(), BindError>
191    where
192        S: Service,
193        P: Provider<'cx, S> + Default,
194    {
195        self.try_bind_with(P::default())
196    }
197
198    /// Try to delete the provider bound to the service `S`.
199    ///
200    /// # Fails
201    /// This function will fail if no provider is bound to the service.
202    ///
203    /// See [`unbind`](Self::unbind) for the panicking version of this function.
204    pub fn try_unbind<S>(&mut self) -> Result<(), UnbindError>
205    where
206        S: Service,
207    {
208        match self.providers.remove(&TypeId::of::<S>()) {
209            Some(_) => Ok(()),
210            None => Err(UnbindError::ServiceUnbound(type_name::<S>())),
211        }
212    }
213
214    /// Try to resolve the service `S` using the default service argument.
215    ///
216    /// # Fails
217    /// This function will fail if no provider is bound to the service.
218    ///
219    /// See [`resolve`](Self::resolve) for the panicking version of this function.
220    #[inline(always)]
221    pub fn try_resolve<S>(&self) -> Option<S::Output<'_>>
222    where
223        S: Service,
224        S::Argument<'static>: Default,
225    {
226        self.try_resolve_with::<S>(Default::default())
227    }
228
229    /// Try to resolve the service `S` given the service argument.
230    ///
231    /// # Fails
232    /// This function will fail if no provider is bound to the service.
233    ///
234    /// See [`resolve_with`](Self::resolve_with) for the panicking version of this function.
235    pub fn try_resolve_with<S>(&self, arg: S::Argument<'_>) -> Option<S::Output<'_>>
236    where
237        S: Service,
238    {
239        let provider = self.providers.get(&TypeId::of::<S>())?;
240
241        // SAFETY:
242        // - We know that the provider was created for the service `S`, since it came from the
243        //   `self.providers` map
244        Some(unsafe { provider.provide::<S>(self, arg) })
245    }
246}
247
248impl Default for Context<'_> {
249    #[inline(always)]
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255struct DynProvider {
256    /// Type-erased pointer to the underlying provider data
257    this: NonNull<()>,
258
259    /// Type-erased function pointer to the provider's `provide` implementation
260    provide_fn: NonNull<()>,
261
262    /// Pointer to the provider's `drop` implementation
263    //
264    // SAFETY:
265    // - Must only be called with a valid `self.this` pointer
266    drop_fn: Option<unsafe fn(*mut ())>,
267}
268
269impl DynProvider {
270    /// Create a `DynProvider` for the service `S`
271    ///
272    /// SAFETY:
273    /// - This instance must live as long as all of its clones
274    unsafe fn new<'cx, S, P>(provider: P) -> Self
275    where
276        S: Service,
277        P: Provider<'cx, S>,
278    {
279        unsafe fn drop_provider<P>(this: *mut ()) {
280            std::mem::drop(Box::from_raw(this as *mut P));
281        }
282
283        // Create a pointer to a specialized `drop` function and store it.
284        let drop_fn = Some(drop_provider::<P> as _);
285
286        // Get the P::provide function pointer and store a type-erased version of it
287        //
288        // SAFETY:
289        // - fn pointers are always non-null
290        let provide_fn = unsafe { NonNull::new_unchecked(P::provide as fn(_, _, _) -> _ as _) };
291
292        // Create the `this` pointer.
293        //
294        // SAFETY:
295        // - A `Box`'s pointer is always guaranteed to be non-null
296        let this = unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(provider)) as *mut _) };
297
298        Self {
299            this,
300            drop_fn,
301            provide_fn,
302        }
303    }
304
305    /// Run the provider
306    ///
307    /// SAFETY:
308    /// - The `DynProvider` was created for the service `S`
309    unsafe fn provide<'cx, S>(&'cx self, cx: &'cx Context, arg: S::Argument<'_>) -> S::Output<'cx>
310    where
311        S: Service,
312    {
313        let this = self.this.as_ptr() as *const ();
314        let provide_fn: ProvideFn<'cx, S> = std::mem::transmute(self.provide_fn);
315
316        provide_fn(this, cx, arg)
317    }
318}
319
320impl Clone for DynProvider {
321    fn clone(&self) -> Self {
322        Self {
323            this: self.this,
324            provide_fn: self.provide_fn,
325            drop_fn: None, // drop should only run on the original instance
326        }
327    }
328}
329
330impl Drop for DynProvider {
331    fn drop(&mut self) {
332        if let Some(drop_fn) = self.drop_fn {
333            // SAFETY:
334            // - `drop_fn` can only be called with `self.this`, which it is.
335            // - We know drop has not been called because of the safety guarantees on new(), which
336            //   means that `self.this` points to valid memory.
337            unsafe { (drop_fn)(self.this.as_ptr()) }
338        }
339    }
340}