Skip to main content

injectium_core/
container.rs

1use std::any::{TypeId, type_name};
2
3use cfg_block::cfg_block;
4
5use crate::provider::Provider;
6use crate::types::{AnyDyn, ErasedProvider, SyncBounds};
7
8cfg_block! {
9    #[cfg(feature = "validation")] {
10        /// A dependency declaration collected at link time via [`inventory`].
11        ///
12        /// Available when the `validation` feature is enabled.
13        ///
14        /// Each `#[derive(Injectable)]` struct registers one `DeclaredDependency`
15        /// entry per field type. [`Container::validate`] iterates all collected
16        /// entries at startup to confirm every required type is present.
17        ///
18        /// You rarely construct this manually; use `declare_dependency!` instead.
19        pub struct DeclaredDependency {
20            /// A function pointer returning the [`TypeId`] of the required type.
21            /// Stored as a fn pointer rather than a value because [`TypeId`] is not
22            /// usable in `const` contexts on stable/nightly without a feature flag.
23            pub type_id: fn() -> TypeId,
24            /// Human-readable name of the type, produced by `stringify!`.
25            pub type_name: &'static str,
26        }
27
28        inventory::collect!(DeclaredDependency);
29    }
30}
31
32struct RegistrationEntry {
33    type_id: TypeId,
34    provider: Box<ErasedProvider>,
35}
36
37#[inline]
38fn registrations_contain_type_id<'a>(
39    registrations: impl IntoIterator<Item = &'a RegistrationEntry>,
40    type_id: TypeId,
41) -> bool {
42    registrations
43        .into_iter()
44        .any(|entry| entry.type_id == type_id)
45}
46
47/// A runtime dependency-injection container.
48///
49/// `Container` stores one provider per type.
50///
51/// A provider can be a closure that constructs a fresh value, or an [`Arc`]
52/// that clones and returns shared state. Consumers use [`get`](Container::get)
53/// to retrieve an owned value regardless of how that provider is implemented.
54///
55/// Containers are built through [`ContainerBuilder`] (or the
56/// [`container!`](crate::container!) macro) and are typically wrapped in an
57/// [`Arc`](std::sync::Arc) for sharing across threads.
58///
59/// # Example
60///
61/// ```
62/// use injectium_core::{Container, container};
63///
64/// let c = container! {
65///     providers: [|_: &Container| 42_u32, |_: &Container| "hello"],
66/// };
67///
68/// assert_eq!(c.get::<u32>(), 42);
69/// assert_eq!(c.get::<&str>(), "hello");
70/// ```
71pub struct Container {
72    registrations: Box<[RegistrationEntry]>,
73}
74
75cfg_block! {
76    #[cfg(feature = "debug")] {
77        use std::fmt;
78
79        impl fmt::Debug for Container {
80            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81                f.debug_struct("Container")
82                    .field("providers", &self.provider_count())
83                    .finish()
84            }
85        }
86
87        impl fmt::Display for Container {
88            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89                write!(
90                    f,
91                    "Container ({} providers)",
92                    self.provider_count()
93                )
94            }
95        }
96    }
97}
98
99/// A builder for [`Container`].
100///
101/// Obtain one via [`Container::builder`] or the
102/// [`container!`](crate::container!) macro. All methods take `self` by value
103/// and return `Self`, enabling a fluent builder chain. Call
104/// [`build`](ContainerBuilder::build) to finalise.
105///
106/// # Example
107///
108/// ```
109/// use std::sync::Arc;
110///
111/// use injectium_core::Container;
112///
113/// let c = Container::builder()
114///     .provider(Arc::new(42_u32))
115///     .provider(|_: &Container| "hello")
116///     .build();
117///
118/// assert_eq!(c.get::<Arc<u32>>().as_ref(), &42);
119/// ```
120pub struct ContainerBuilder {
121    registrations: Vec<RegistrationEntry>,
122}
123
124impl Default for ContainerBuilder {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl ContainerBuilder {
131    /// Creates an empty builder with no registrations.
132    #[must_use]
133    pub fn new() -> Self {
134        Self::with_capacity(0)
135    }
136
137    /// Creates an empty builder with preallocated storage capacity for the
138    /// expected number of providers.
139    ///
140    /// These values are only allocation hints. Duplicate registrations are not
141    /// allowed, so the final number of stored registrations may still be
142    /// smaller than the requested capacity.
143    #[must_use]
144    pub fn with_capacity(capacity: usize) -> Self {
145        Self {
146            registrations: Vec::with_capacity(capacity),
147        }
148    }
149
150    #[inline]
151    fn contains_type_id(&self, type_id: TypeId) -> bool {
152        registrations_contain_type_id(self.registrations.iter(), type_id)
153    }
154
155    /// Returns `true` if a provider producing `T` is already registered in this
156    /// builder.
157    #[must_use]
158    pub fn contains<T: 'static>(&self) -> bool {
159        self.contains_type_id(TypeId::of::<T>())
160    }
161
162    /// Returns `true` if a provider of type `P` would conflict with an existing
163    /// registration in this builder.
164    #[must_use]
165    pub fn contains_provider<P>(&self) -> bool
166    where
167        P: Provider + SyncBounds,
168    {
169        self.contains::<P::Output>()
170    }
171
172    /// Registers a provider for values of type `T`.
173    ///
174    /// Providers are automatically supported for closures and [`Arc`] values.
175    ///
176    /// # Panics
177    ///
178    /// Panics if a provider for the same output type has already been
179    /// registered in this builder.
180    #[must_use]
181    pub fn provider<P>(self, provider: P) -> Self
182    where
183        P: Provider + SyncBounds,
184    {
185        let type_id = TypeId::of::<P::Output>();
186        assert!(
187            !self.contains_type_id(type_id),
188            "provider already registered for `{}`",
189            type_name::<P::Output>()
190        );
191
192        let provider =
193            move |container: &Container| -> Box<AnyDyn> { Box::new(provider.provide(container)) };
194
195        let mut registrations = self.registrations;
196        registrations.push(RegistrationEntry {
197            type_id,
198            provider: Box::new(provider),
199        });
200
201        Self { registrations }
202    }
203
204    /// Consumes the builder and returns the finished [`Container`].
205    #[must_use]
206    pub fn build(self) -> Container {
207        Container {
208            registrations: self.registrations.into_boxed_slice(),
209        }
210    }
211}
212
213impl Container {
214    #[inline]
215    fn provider_for(&self, type_id: TypeId) -> Option<&ErasedProvider> {
216        self.registrations
217            .iter()
218            .rev()
219            .find(|entry| entry.type_id == type_id)
220            .map(|entry| entry.provider.as_ref())
221    }
222
223    #[inline]
224    fn contains_type_id(&self, type_id: TypeId) -> bool {
225        registrations_contain_type_id(self.registrations.iter(), type_id)
226    }
227
228    #[inline]
229    fn cast_owned_unchecked<T: SyncBounds>(owned_erased: Box<AnyDyn>) -> T {
230        debug_assert!((*owned_erased).is::<T>());
231        let ptr = Box::into_raw(owned_erased).cast::<T>();
232
233        unsafe {
234            // SAFETY: `provider::<T>` stores providers under `TypeId::of::<T>()` and
235            // each stored provider returns `Box::new(provider.provide(c))` where the
236            // concrete value is exactly `T`. `try_get::<T>` uses the same key, so
237            // this cast is valid and ownership is preserved when reconstructing the box.
238            *Box::from_raw(ptr)
239        }
240    }
241
242    /// Returns a new [`ContainerBuilder`].
243    ///
244    /// Equivalent to [`ContainerBuilder::new`].
245    #[must_use]
246    pub fn builder() -> ContainerBuilder {
247        ContainerBuilder::new()
248    }
249
250    /// Returns a new [`ContainerBuilder`] with preallocated storage capacity.
251    #[must_use]
252    pub fn builder_with_capacity(capacity: usize) -> ContainerBuilder {
253        ContainerBuilder::with_capacity(capacity)
254    }
255
256    /// Returns the current value for type `T`.
257    ///
258    /// # Panics
259    ///
260    /// Panics if no provider of type `T` has been registered. Use
261    /// [`try_get`](Container::try_get) for a non-panicking alternative.
262    #[must_use]
263    pub fn get<T: SyncBounds>(&self) -> T {
264        self.try_get().expect("dependency not registered")
265    }
266
267    /// Returns the current value for type `T`, or `None` if no provider is
268    /// registered.
269    #[must_use]
270    pub fn try_get<T: SyncBounds>(&self) -> Option<T> {
271        let boxed = (self.provider_for(TypeId::of::<T>())?)(self);
272
273        Some(Self::cast_owned_unchecked::<T>(boxed))
274    }
275
276    /// Returns `true` if a provider is registered for `T`.
277    #[must_use]
278    pub fn contains<T: 'static>(&self) -> bool {
279        self.contains_type_id(TypeId::of::<T>())
280    }
281
282    /// Validates that every dependency declared via `declare_dependency!` is
283    /// registered in this container.
284    ///
285    /// Available when the `validation` feature is enabled.
286    ///
287    /// Intended to be called once at application startup, immediately after
288    /// the container is built. If any declared dependency is absent, the
289    /// method panics with a message that lists every missing type by name,
290    /// making misconfiguration easy to diagnose.
291    ///
292    /// `#[derive(Injectable)]` automatically calls `declare_dependency!` for
293    /// each field type, so this check covers all structs that use the derive
294    /// macro without any manual bookkeeping.
295    ///
296    /// # Panics
297    ///
298    /// Panics if one or more declared dependencies are not registered,
299    /// printing the names of all missing types.
300    #[cfg(feature = "validation")]
301    pub fn validate(&self) {
302        let mut missing: Vec<&'static str> = Vec::new();
303
304        for dep in inventory::iter::<DeclaredDependency> {
305            let type_id = (dep.type_id)();
306            let registered = self.contains_type_id(type_id);
307
308            if !registered {
309                missing.push(dep.type_name);
310            }
311        }
312
313        if !missing.is_empty() {
314            panic!(
315                "Container is missing {} declared dependenc{}: [{}]",
316                missing.len(),
317                if missing.len() == 1 { "y" } else { "ies" },
318                missing.join(", ")
319            );
320        }
321    }
322
323    /// Returns the number of registered providers.
324    #[must_use]
325    pub fn provider_count(&self) -> usize {
326        self.registrations.len()
327    }
328}