//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/async--injector-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/async-injector)
//! [<img alt="crates.io" src="https://img.shields.io/crates/v/async-injector.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/async-injector)
//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-async--injector-66c2a5?style=for-the-badge&logoColor=white&logo=" height="20">](https://docs.rs/async-injector)
//!
//! Asynchronous dependency injection for Rust.
//!
//! This library provides the glue which allows for building robust decoupled
//! applications that can be reconfigured dynamically while they are running.
//!
//! For a real-world example of how this is used, see [`OxidizeBot`], the
//! project it was written for.
//!
//! <br>
//!
//! ## Usage
//!
//! Add `async-injector` to your `Cargo.toml`.
//!
//! ```toml
//! [dependencies]
//! async-injector = "0.19.4"
//! ```
//!
//! <br>
//!
//! ## Example
//!
//! In the following we'll showcase the injection of a *fake* `Database`. The
//! idea here would be that if something about the database connection changes,
//! a new instance of `Database` would be created and cause the application to
//! reconfigure itself.
//!
//! ```rust
//! use async_injector::{Key, Injector, Provider};
//!
//! #[derive(Debug, Clone)]
//! struct Database;
//!
//! #[derive(Debug, Provider)]
//! struct Service {
//!     #[dependency]
//!     database: Database,
//! }
//!
//! async fn service(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
//!     let mut provider = Service::provider(&injector).await?;
//!
//!     let Service { database } = provider.wait().await;
//!     println!("Service got initial database {database:?}!");
//!
//!     let Service { database } = provider.wait().await;
//!     println!("Service got new database {database:?}!");
//!
//!     Ok(())
//! }
//! ```
//!
//! > **Note:** This is available as the `database` example:
//! > ```sh
//! > cargo run --example database
//! > ```
//!
//! The [`Injector`] above provides a structured broadcasting system that allows
//! configuration updates to be cleanly integrated into asynchronous contexts.
//! The update itself is triggered by some other component which is responsible
//! for constructing the `Database` instance and injecting it, for example by
//! calling `Injector::update`.
//!
//! Building up the components of your application like this means that it can
//! be reconfigured without restarting it, which provides a much richer user
//! experience.
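//!
//! The broadcasting idea can be illustrated with a minimal, std-only sketch of
//! a single injector slot. This is not how the crate is implemented: `Slot`
//! and `SlotState` are hypothetical names, and an unbounded `std::sync::mpsc`
//! channel per subscriber stands in for the bounded `tokio::sync::broadcast`
//! channel which [`Injector`] uses, so unlike the real thing this sketch never
//! skips updates for slow subscribers.
//!
//! ```rust
//! use std::sync::mpsc::{channel, Receiver, Sender};
//! use std::sync::Mutex;
//!
//! // Hypothetical state of one slot: the latest value plus its subscribers.
//! struct SlotState<T> {
//!     value: Option<T>,
//!     subscribers: Vec<Sender<Option<T>>>,
//! }
//!
//! // One lock guards both fields, so subscribing and updating cannot race.
//! struct Slot<T> {
//!     state: Mutex<SlotState<T>>,
//! }
//!
//! impl<T: Clone> Slot<T> {
//!     fn new() -> Self {
//!         let state = SlotState { value: None, subscribers: Vec::new() };
//!         Self { state: Mutex::new(state) }
//!     }
//!
//!     /// Subscribe to updates and get the current value at the same time.
//!     fn stream(&self) -> (Receiver<Option<T>>, Option<T>) {
//!         let mut state = self.state.lock().unwrap();
//!         let (tx, rx) = channel();
//!         state.subscribers.push(tx);
//!         (rx, state.value.clone())
//!     }
//!
//!     /// Store a new value, notify subscribers and return the previous value.
//!     fn update(&self, value: T) -> Option<T> {
//!         let mut state = self.state.lock().unwrap();
//!         // Subscribers whose receiving end was dropped are removed.
//!         state.subscribers.retain(|tx| tx.send(Some(value.clone())).is_ok());
//!         state.value.replace(value)
//!     }
//!
//!     /// Remove the value, notifying subscribers with `None` if one was set.
//!     fn clear(&self) -> Option<T> {
//!         let mut state = self.state.lock().unwrap();
//!         let old = state.value.take()?;
//!         state.subscribers.retain(|tx| tx.send(None).is_ok());
//!         Some(old)
//!     }
//! }
//! ```
//!
//! With a single subscriber, calling `update(1)`, `update(2)` and then
//! `clear()` delivers `Some(1)`, `Some(2)` and `None` to it, in that order.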
//!
//! <br>
//!
//! ## Injecting multiple things of the same type
//!
//! In the previous section you might've noticed that the injected value was
//! solely discriminated by its type: `Database`. In this example we'll show how
//! [`Key`] can be used to *tag* values of the same type with different names to
//! discriminate them. This can be useful when dealing with overly generic types
//! like [`String`].
//!
//! The tag must be serializable with [`serde`], and it must not contain any
//! components which [cannot be hashed], such as `f32` and `f64`.
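//!
//! As a rough illustration of why these requirements exist, the following
//! std-only sketch models a key as the `std::any::TypeId` of the stored value
//! plus a tag. `SketchKey` is hypothetical, and a plain `String` stands in for
//! the serialized tag, which this crate produces using `serde_hashkey`. Since
//! keys are looked up in a hash map, every component of a key must implement
//! `Hash` and `Eq`, which `f32` and `f64` do not.
//!
//! ```rust
//! use std::any::TypeId;
//!
//! // Hypothetical key: the type of the stored value plus an optional tag.
//! #[derive(Debug, Clone, PartialEq, Eq, Hash)]
//! struct SketchKey {
//!     type_id: TypeId,
//!     tag: Option<String>,
//! }
//!
//! impl SketchKey {
//!     /// Discriminated by type only, similar to `Key::<T>::of()`.
//!     fn of<T: 'static>() -> Self {
//!         Self { type_id: TypeId::of::<T>(), tag: None }
//!     }
//!
//!     /// Discriminated by type and tag, similar to `Key::<T>::tagged(tag)`.
//!     fn tagged<T: 'static>(tag: &str) -> Self {
//!         Self { type_id: TypeId::of::<T>(), tag: Some(tag.to_owned()) }
//!     }
//! }
//! ```
//!
//! Two `String` values tagged `"name"` and `"fun"` therefore occupy different
//! entries, and neither collides with the untagged `String` key.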
//!
//! <br>
//!
//! ### A simple greeter
//!
//! The following example showcases the use of `Key` to inject two different
//! values into an asynchronous `greeter`.
//!
//! ```rust,no_run
//! use async_injector::{Key, Injector};
//!
//! async fn greeter(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
//!     let name = Key::<String>::tagged("name")?;
//!     let fun = Key::<String>::tagged("fun")?;
//!
//!     let (mut name_stream, mut name) = injector.stream_key(name).await;
//!     let (mut fun_stream, mut fun) = injector.stream_key(fun).await;
//!
//!     loop {
//!         tokio::select! {
//!             update = name_stream.recv() => {
//!                 name = update;
//!             }
//!             update = fun_stream.recv() => {
//!                 fun = update;
//!             }
//!         }
//!
//!         let (Some(name), Some(fun)) = (&name, &fun) else {
//!             continue;
//!         };
//!
//!         println!("Hi {name}! I see you do \"{fun}\" for fun!");
//!         return Ok(());
//!     }
//! }
//! ```
//!
//! > **Note:** you can run this using:
//! > ```sh
//! > cargo run --example greeter
//! > ```
//!
//! The loop above can be implemented more easily using the [`Provider`] derive,
//! so let's do that.
//!
//! ```rust,no_run
//! use async_injector::{Injector, Provider};
//!
//! #[derive(Provider)]
//! struct Dependencies {
//!     #[dependency(tag = "name")]
//!     name: String,
//!     #[dependency(tag = "fun")]
//!     fun: String,
//! }
//!
//! async fn greeter(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
//!     let mut provider = Dependencies::provider(&injector).await?;
//!     let Dependencies { name, fun } = provider.wait().await;
//!     println!("Hi {name}! I see you do \"{fun}\" for fun!");
//!     Ok(())
//! }
//! ```
//!
//! > **Note:** you can run this using:
//! > ```sh
//! > cargo run --example greeter_provider
//! > ```
//!
//! <br>
//!
//! ## The `Provider` derive
//!
//! The [`Provider`] derive can be used to conveniently implement the mechanism
//! necessary to wait for a specific set of dependencies to become available.
//!
//! It generates a companion struct named `<name>Provider` next to the type
//! being provided, which implements the following set of methods:
//!
//! ```rust,no_run
//! use async_injector::{Error, Injector};
//!
//! # struct Dependencies {}
//! impl Dependencies {
//!     /// Construct a new provider.
//!     async fn provider(injector: &Injector) -> Result<DependenciesProvider, Error>
//!     # { todo!() }
//! }
//!
//! struct DependenciesProvider {
//!     /* private fields */
//! }
//!
//! impl DependenciesProvider {
//!     /// Try to construct the current value. Returns [None] unless all
//!     /// required dependencies are available.
//!     fn build(&mut self) -> Option<Dependencies>
//!     # { todo!() }
//!
//!     /// Wait until we can successfully build the complete provided
//!     /// value.
//!     async fn wait(&mut self) -> Dependencies
//!     # { todo!() }
//!
//!     /// Wait until the provided value has changed. Returns `None` if some
//!     /// required dependency is no longer available, or the built value once
//!     /// all required dependencies are available.
//!     async fn wait_for_update(&mut self) -> Option<Dependencies>
//!     # { todo!() }
//! }
//! ```
//!
//! <br>
//!
//! ### Fixed arguments to `Provider`
//!
//! Any fields which do not have the `#[dependency]` attribute are known as
//! "fixed" fields. They must be passed as arguments when calling the
//! `provider` constructor, and they can also be used during tag construction,
//! as `name_tag` is in the following example.
//!
//! ```rust,no_run
//! use async_injector::{Injector, Key, Provider};
//!
//! #[derive(Provider)]
//! struct Dependencies {
//!     name_tag: &'static str,
//!     #[dependency(tag = name_tag)]
//!     name: String,
//! }
//!
//! async fn greeter(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
//!     let mut provider = Dependencies::provider(&injector, "name").await?;
//!     let Dependencies { name, .. } = provider.wait().await;
//!     println!("Hi {name}!");
//!     Ok(())
//! }
//! ```
//!
//! [`OxidizeBot`]: https://github.com/udoprog/OxidizeBot
//! [cannot be hashed]: https://internals.rust-lang.org/t/f32-f64-should-implement-hash/5436
//! [`Injector`]: https://docs.rs/async-injector/0/async_injector/struct.Injector.html
//! [`Key`]: https://docs.rs/async-injector/0/async_injector/struct.Key.html
//! [`Provider`]: https://docs.rs/async-injector/0/async_injector/derive.Provider.html
//! [`serde`]: https://serde.rs
//! [`Stream`]: https://docs.rs/futures-core/0/futures_core/stream/trait.Stream.html
//! [`String`]: https://doc.rust-lang.org/std/string/struct.String.html

#![deny(missing_docs)]

use hashbrown::HashMap;
use serde_hashkey as hashkey;
use std::any::{Any, TypeId};
use std::cmp;
use std::error;
use std::fmt;
use std::hash;
use std::marker;
use std::mem;
use std::ptr;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};

/// The read guard produced by [Ref::read].
pub type RefReadGuard<'a, T> = tokio::sync::RwLockReadGuard<'a, T>;

/// Re-exports used by the `Provider` derive.
#[doc(hidden)]
pub mod derive {
    pub use tokio::select;
}

/// Helper derive to implement a "provider".
///
/// The `Provider` derive can only be used on structs. Each field designates a
/// value that must either be injected, or provided during construction.
///
/// ```rust
/// use async_injector::Provider;
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// enum Tag {
///     Table,
///     Url,
/// }
///
/// #[derive(Provider)]
/// struct Deps {
///     fixed: String,
///     #[dependency(optional, tag = Tag::Table)]
///     table: Option<String>,
///     #[dependency(tag = Tag::Url)]
///     url: String,
///     #[dependency]
///     connection_limit: u32,
/// }
/// ```
///
/// This generates another struct named `DepsProvider`, with the following API:
///
/// ```rust,no_run
/// use async_injector::{Error, Injector};
///
/// # struct Deps {}
/// impl Deps {
///     /// Construct a new provider.
///     async fn provider(injector: &Injector, fixed: String) -> Result<DepsProvider, Error>
///     # { todo!() }
/// }
///
/// struct DepsProvider {
///     /* private fields */
/// }
///
/// impl DepsProvider {
///     /// Try to construct the current value. Returns [None] unless all
///     /// required dependencies are available.
///     fn build(&mut self) -> Option<Deps>
///     # { todo!() }
///
///     /// Wait until we can successfully build the complete provided
///     /// value.
///     async fn wait(&mut self) -> Deps
///     # { todo!() }
///
///     /// Wait until the provided value has changed. Returns `None` if some
///     /// required dependency is no longer available, or the built value once
///     /// all required dependencies are available.
///     async fn wait_for_update(&mut self) -> Option<Deps>
///     # { todo!() }
/// }
/// ```
///
/// The `provider` associated function takes a reference to an injector as its
/// first argument, followed by every field which is not marked as a
/// `#[dependency]`. These are called fixed fields.
///
/// <br>
///
/// # The `#[dependency]` field attribute
///
/// The `#[dependency]` attribute can be used to mark fields which need to be
/// injected. It accepts an optional tag, as in `#[dependency(tag = ...)]`,
/// which specifies the tag to use when constructing the injected [Key].
///
/// ```rust
/// use async_injector::Provider;
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// enum Tag {
///     First,
/// }
///
/// #[derive(Provider)]
/// struct Params {
///     #[dependency(tag = Tag::First)]
///     tagged: String,
///     #[dependency]
///     number: u32,
/// }
/// ```
///
/// Optional fields use the [Option] type and must be marked with the `optional`
/// meta attribute.
///
/// ```rust
/// use async_injector::Provider;
///
/// #[derive(Provider)]
/// struct Params {
///     #[dependency(optional)]
///     table: Option<String>,
/// }
/// ```
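///
/// To illustrate how `optional` changes what `build` requires, here is a
/// hand-written sketch of a provider for the dependencies of the `Deps` struct
/// from the first example, with the fixed field left out. The fields of this
/// `DepsProvider` are assumptions made for illustration and do not reflect the
/// code which the derive actually generates.
///
/// ```rust
/// #[derive(Debug, Clone, PartialEq)]
/// struct Deps {
///     table: Option<String>,
///     url: String,
///     connection_limit: u32,
/// }
///
/// // Hypothetical provider state: the latest injected value of each dependency.
/// #[derive(Default)]
/// struct DepsProvider {
///     table: Option<String>,
///     url: Option<String>,
///     connection_limit: Option<u32>,
/// }
///
/// impl DepsProvider {
///     /// Required dependencies must be present, optional ones pass through.
///     fn build(&mut self) -> Option<Deps> {
///         Some(Deps {
///             table: self.table.clone(),
///             url: self.url.clone()?,
///             connection_limit: self.connection_limit?,
///         })
///     }
/// }
/// ```
///
/// Until `url` and `connection_limit` have both been injected, `build` returns
/// `None`, while `table` is passed through whether or not it is set.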
///
/// [Key]: https://docs.rs/async-injector/0/async_injector/struct.Key.html
#[doc(inline)]
pub use async_injector_derive::Provider;

/// Errors that can be raised by various functions in the [`Injector`].
#[derive(Debug)]
pub enum Error {
    /// Error when serializing key.
    SerializationError(hashkey::Error),
}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::SerializationError(..) => "serialization error".fmt(fmt),
        }
    }
}

impl error::Error for Error {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::SerializationError(e) => Some(e),
        }
    }
}

impl From<hashkey::Error> for Error {
    fn from(value: hashkey::Error) -> Self {
        Error::SerializationError(value)
    }
}

/// A stream of updates to a value in the [`Injector`].
///
/// This is created using [Injector::stream] or [Injector::stream_key] and can
/// be used to make sure that an asynchronous process has access to the most
/// up-to-date value from the injector.
pub struct Stream<T> {
    rx: broadcast::Receiver<Option<Value>>,
    marker: marker::PhantomData<T>,
}

impl<T> Stream<T> {
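    /// Receive the next update from the stream.
    ///
    /// Returns `None` if the value was cleared or if the stream has been
    /// closed. If this receiver falls behind, missed updates are skipped.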
    pub async fn recv(&mut self) -> Option<T> {
        let value = loop {
            let value = self.rx.recv().await;

            match value {
                Ok(value) => break value,
                Err(broadcast::error::RecvError::Lagged { .. }) => continue,
                _ => return None,
            };
        };

        let value = match value {
            Some(value) => value,
            _ => return None,
        };

        // Safety: The expected type parameter is encoded and maintained in the
        // Stream<T> type.
        Some(unsafe { value.downcast::<T>() })
    }
}

/// An opaque value holder, similar to Any, but can be cloned and relies
/// entirely on external type information.
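///
/// For comparison, a minimal safe sketch of the same idea can be built on
/// `Box<dyn Any>`: a monomorphized clone function is captured when the value is
/// constructed, and downcasts are checked at runtime instead of relying on
/// external type information. `ErasedValue` is a hypothetical name, and this
/// is not how `Value` itself is implemented.
///
/// ```rust
/// use std::any::Any;
///
/// type Data = Box<dyn Any + Send + Sync>;
///
/// struct ErasedValue {
///     data: Data,
///     // Clone function for the concrete type, captured in `new`.
///     clone_fn: fn(&(dyn Any + Send + Sync)) -> Data,
/// }
///
/// impl ErasedValue {
///     fn new<T: Any + Clone + Send + Sync>(value: T) -> Self {
///         fn clone_fn<T: Any + Clone + Send + Sync>(data: &(dyn Any + Send + Sync)) -> Data {
///             // Cannot fail: `clone_fn::<T>` is only ever stored next to a `T`.
///             Box::new(data.downcast_ref::<T>().expect("type mismatch").clone())
///         }
///
///         Self { data: Box::new(value), clone_fn: clone_fn::<T> }
///     }
///
///     /// Checked downcast, returning `None` on a type mismatch.
///     fn downcast_ref<T: Any>(&self) -> Option<&T> {
///         (*self.data).downcast_ref::<T>()
///     }
/// }
///
/// impl Clone for ErasedValue {
///     fn clone(&self) -> Self {
///         Self { data: (self.clone_fn)(&*self.data), clone_fn: self.clone_fn }
///     }
/// }
/// ```
///
/// The trade-off is an extra runtime type check on every access, which the
/// unsafe `Value` avoids because its callers already know the type.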
struct Value {
    data: *const (),
    // clone function, to use when cloning the value.
    value_clone_fn: unsafe fn(*const ()) -> *const (),
    // drop function, to use when dropping the value.
    value_drop_fn: unsafe fn(*const ()),
}

impl Value {
    /// Construct a new opaque value.
    pub(crate) fn new<T>(data: T) -> Self
    where
        T: 'static + Clone + Send + Sync,
    {
        return Self {
            data: Box::into_raw(Box::new(data)) as *const (),
            value_clone_fn: value_clone_fn::<T>,
            value_drop_fn: value_drop_fn::<T>,
        };

        /// Clone implementation for a given value.
        unsafe fn value_clone_fn<T>(data: *const ()) -> *const ()
        where
            T: Clone,
        {
            let data = T::clone(&*(data as *const _));
            Box::into_raw(Box::new(data)) as *const ()
        }

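        /// Drop implementation for a given value.
        ///
        /// Reconstructs the `Box` allocated in `new` or `value_clone_fn`, so
        /// that both the value and its heap allocation are freed.
        unsafe fn value_drop_fn<T>(value: *const ()) {
            drop(Box::from_raw(value as *mut () as *mut T))
        }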
    }

    /// Downcast the given value reference.
    ///
    /// # Safety
    ///
    /// Assumes that we know the type of the underlying value.
    pub(crate) unsafe fn downcast_ref<T>(&self) -> &T {
        &*(self.data as *const T)
    }

    /// Downcast the given value to a mutable reference.
    ///
    /// # Safety
    ///
    /// Assumes that we know the type of the underlying value.
    pub(crate) unsafe fn downcast_mut<T>(&mut self) -> &mut T {
        &mut *(self.data as *const T as *mut T)
    }

    /// Downcast the given value.
    ///
    /// # Safety
    ///
    /// Assumes that we know the correct, underlying type of the value.
    pub(crate) unsafe fn downcast<T>(self) -> T {
        let value = Box::from_raw(self.data as *const T as *mut T);
        mem::forget(self);
        *value
    }
}

/// Safety: Send + Sync bound is enforced in all constructors of `Value`.
unsafe impl Send for Value {}
unsafe impl Sync for Value {}

impl Clone for Value {
    fn clone(&self) -> Self {
        let data = unsafe { (self.value_clone_fn)(self.data as *const _) };

        Self {
            data,
            value_clone_fn: self.value_clone_fn,
            value_drop_fn: self.value_drop_fn,
        }
    }
}

impl Drop for Value {
    fn drop(&mut self) {
        unsafe {
            (self.value_drop_fn)(self.data);
        }
    }
}

struct Storage {
    value: Arc<RwLock<Option<Value>>>,
    tx: broadcast::Sender<Option<Value>>,
}

impl Default for Storage {
    fn default() -> Self {
        let (tx, _) = broadcast::channel(1);
        Self {
            value: Arc::new(RwLock::new(None)),
            tx,
        }
    }
}

struct Inner {
    storage: RwLock<HashMap<RawKey, Storage>>,
}

/// An injector of dependencies.
///
/// A new, empty injector is created using [Injector::new]. Cloning an
/// injector is cheap, and every clone shares the same underlying storage.
#[derive(Clone)]
pub struct Injector {
    inner: Arc<Inner>,
}

impl Injector {
    /// Construct and use an [`Injector`].
    ///
    /// # Example
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main] async fn main() {
    /// let injector = Injector::new();
    ///
    /// assert_eq!(None, injector.get::<u32>().await);
    /// injector.update(1u32).await;
    /// assert_eq!(Some(1u32), injector.get::<u32>().await);
    /// assert!(injector.clear::<u32>().await.is_some());
    /// assert_eq!(None, injector.get::<u32>().await);
    /// # }
    /// ```
    ///
    /// Example using [`Injector::var`].
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// #[derive(Clone)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    ///
    /// let database = injector.var::<Database>().await;
    ///
    /// assert!(database.read().await.is_none());
    /// injector.update(Database).await;
    /// assert!(database.read().await.is_some());
    /// # Ok(()) }
    /// ```
    pub fn new() -> Self {
        Self::default()
    }

    /// Clear the value of the given type, returning the previous value if one
    /// was set.
    ///
    /// The clear is propagated to all streams set up using [`stream`], and
    /// future calls to [`get`] return `None` until a new value is injected.
    ///
    /// [`stream`]: Injector::stream
    /// [`get`]: Injector::get
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main] async fn main() {
    /// let injector = Injector::new();
    ///
    /// assert_eq!(None, injector.get::<u32>().await);
    /// injector.update(1u32).await;
    /// assert_eq!(Some(1u32), injector.get::<u32>().await);
    /// assert!(injector.clear::<u32>().await.is_some());
    /// assert_eq!(None, injector.get::<u32>().await);
    /// # }
    /// ```
    pub async fn clear<T>(&self) -> Option<T>
    where
        T: Clone + Any + Send + Sync,
    {
        self.clear_key(Key::<T>::of()).await
    }

    /// Clear the value associated with the given key, returning the previous
    /// value if one was set.
    ///
    /// The clear is propagated to all streams set up using [`stream`], and
    /// future calls to [`get`] return `None` until a new value is injected.
    ///
    /// [`stream`]: Injector::stream
    /// [`get`]: Injector::get
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    /// let k = Key::<u32>::tagged("first")?;
    ///
    /// assert_eq!(None, injector.get_key(&k).await);
    /// injector.update_key(&k, 1u32).await;
    /// assert_eq!(Some(1u32), injector.get_key(&k).await);
    /// assert!(injector.clear_key(&k).await.is_some());
    /// assert_eq!(None, injector.get_key(&k).await);
    /// # Ok(()) }
    /// ```
    pub async fn clear_key<T>(&self, key: impl AsRef<Key<T>>) -> Option<T>
    where
        T: Clone + Any + Send + Sync,
    {
        let key = key.as_ref().as_raw_key();

        let storage = self.inner.storage.read().await;
        let storage = storage.get(key)?;
        let value = storage.value.write().await.take()?;
        let _ = storage.tx.send(None);
        Some(unsafe { value.downcast() })
    }

    /// Set the given value and notify any subscribers, returning the previous
    /// value if one was set.
    ///
    /// The update is propagated to all streams set up using [`stream`], and
    /// future calls to [`get`] return the updated value.
    ///
    /// [`stream`]: Injector::stream
    /// [`get`]: Injector::get
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let injector = Injector::new();
    ///
    /// assert_eq!(None, injector.get::<u32>().await);
    /// injector.update(1u32).await;
    /// assert_eq!(Some(1u32), injector.get::<u32>().await);
    /// # }
    /// ```
    pub async fn update<T>(&self, value: T) -> Option<T>
    where
        T: Clone + Any + Send + Sync,
    {
        self.update_key(Key::<T>::of(), value).await
    }

    /// Update the value associated with the given key and notify any
    /// subscribers, returning the previous value if one was set.
    ///
    /// The update is propagated to all streams set up using [`stream`], and
    /// future calls to [`get`] return the updated value.
    ///
    /// [`stream`]: Injector::stream
    /// [`get`]: Injector::get
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    /// let k = Key::<u32>::tagged("first")?;
    ///
    /// assert_eq!(None, injector.get_key(&k).await);
    /// injector.update_key(&k, 1u32).await;
    /// assert_eq!(Some(1u32), injector.get_key(&k).await);
    /// # Ok(()) }
    /// ```
    pub async fn update_key<T>(&self, key: impl AsRef<Key<T>>, value: T) -> Option<T>
    where
        T: Clone + Any + Send + Sync,
    {
        let key = key.as_ref().as_raw_key();
        let value = Value::new(value);

        let mut storage = self.inner.storage.write().await;
        let storage = storage.entry(key.clone()).or_default();
        let _ = storage.tx.send(Some(value.clone()));
        let old = storage.value.write().await.replace(value)?;
        Some(unsafe { old.downcast() })
    }

    /// Test if a given value exists by type.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main] async fn main() {
    /// let injector = Injector::new();
    ///
    /// assert_eq!(false, injector.exists::<u32>().await);
    /// injector.update(1u32).await;
    /// assert_eq!(true, injector.exists::<u32>().await);
    /// # }
    /// ```
    pub async fn exists<T>(&self) -> bool
    where
        T: Clone + Any + Send + Sync,
    {
        self.exists_key(Key::<T>::of()).await
    }

    /// Test if a given value exists by key.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    /// let k = Key::<u32>::tagged("first")?;
    ///
    /// assert_eq!(false, injector.exists_key(&k).await);
    /// injector.update_key(&k, 1u32).await;
    /// assert_eq!(true, injector.exists_key(&k).await);
    /// # Ok(()) }
    /// ```
    pub async fn exists_key<T>(&self, key: impl AsRef<Key<T>>) -> bool
    where
        T: Clone + Any + Send + Sync,
    {
        let key = key.as_ref().as_raw_key();
        let storage = self.inner.storage.read().await;

        if let Some(s) = storage.get(key) {
            s.value.read().await.is_some()
        } else {
            false
        }
    }

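    /// Mutate the value of the given type in place and notify any subscribers
    /// of the change.
    ///
    /// Returns the output of `mutator`, or `None` if no value is present.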
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main] async fn main() {
    /// let injector = Injector::new();
    ///
    /// injector.update(1u32).await;
    ///
    /// let old = injector.mutate(|value: &mut u32| {
    ///     let old = *value;
    ///     *value += 1;
    ///     old
    /// }).await;
    ///
    /// assert_eq!(Some(1u32), old);
    /// # }
    /// ```
    pub async fn mutate<T, M, R>(&self, mutator: M) -> Option<R>
    where
        T: Clone + Any + Send + Sync,
        M: FnMut(&mut T) -> R,
    {
        self.mutate_key(Key::<T>::of(), mutator).await
    }

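    /// Mutate the value associated with the given key in place and notify any
    /// subscribers of the change.
    ///
    /// Returns the output of `mutator`, or `None` if no value is present.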
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    /// let k = Key::<u32>::tagged("first")?;
    ///
    /// injector.update_key(&k, 1u32).await;
    ///
    /// let old = injector.mutate_key(&k, |value| {
    ///     let old = *value;
    ///     *value += 1;
    ///     old
    /// }).await;
    ///
    /// assert_eq!(Some(1u32), old);
    /// # Ok(()) }
    /// ```
    pub async fn mutate_key<T, M, R>(&self, key: impl AsRef<Key<T>>, mut mutator: M) -> Option<R>
    where
        T: Clone + Any + Send + Sync,
        M: FnMut(&mut T) -> R,
    {
        let key = key.as_ref().as_raw_key();
        let storage = self.inner.storage.read().await;

        let storage = match storage.get(key) {
            Some(s) => s,
            None => return None,
        };

        let mut value = storage.value.write().await;

        if let Some(value) = &mut *value {
            let output = mutator(unsafe { value.downcast_mut() });
            let value = value.clone();
            let _ = storage.tx.send(Some(value));
            return Some(output);
        }

        None
    }

    /// Get a value from the injector.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main] async fn main() {
    /// let injector = Injector::new();
    ///
    /// assert_eq!(None, injector.get::<u32>().await);
    /// injector.update(1u32).await;
    /// assert_eq!(Some(1u32), injector.get::<u32>().await);
    /// # }
    /// ```
    pub async fn get<T>(&self) -> Option<T>
    where
        T: Clone + Any + Send + Sync,
    {
        self.get_key(Key::<T>::of()).await
    }

    /// Get a value from the injector with the given key.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let k1 = Key::<u32>::tagged("first")?;
    /// let k2 = Key::<u32>::tagged("second")?;
    ///
    /// let injector = Injector::new();
    ///
    /// assert_eq!(None, injector.get_key(&k1).await);
    /// assert_eq!(None, injector.get_key(&k2).await);
    ///
    /// injector.update_key(&k1, 1u32).await;
    ///
    /// assert_eq!(Some(1u32), injector.get_key(&k1).await);
    /// assert_eq!(None, injector.get_key(&k2).await);
    /// # Ok(()) }
    /// ```
    pub async fn get_key<T>(&self, key: impl AsRef<Key<T>>) -> Option<T>
    where
        T: Clone + Any + Send + Sync,
    {
        let key = key.as_ref().as_raw_key();
        let storage = self.inner.storage.read().await;

        let storage = match storage.get(key) {
            Some(storage) => storage,
            None => return None,
        };

        let value = storage.value.read().await;

        value.as_ref().map(|value| {
            // Safety: The raw key is derived from `Key<T>`, so the value stored
            // under it is known to be a `T`.
            unsafe { value.downcast_ref::<T>().clone() }
        })
    }

    /// Wait for a value to become available.
    ///
    /// Note that this could potentially wait forever if the value is never
    /// injected.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let injector = Injector::new();
    ///
    /// injector.update(1u32).await;
    /// assert_eq!(1u32, injector.wait::<u32>().await);
    /// # }
    /// ```
    #[inline]
    pub async fn wait<T>(&self) -> T
    where
        T: Clone + Any + Send + Sync,
    {
        self.wait_key(Key::<T>::of()).await
    }

    /// Wait for a value associated with the given key to become available.
    ///
    /// Note that this could potentially wait forever if the value is never
    /// injected.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    /// let tag = Key::<u32>::tagged("first")?;
    ///
    /// injector.update_key(&tag, 1u32).await;
    /// assert_eq!(1u32, injector.wait_key(tag).await);
    /// # Ok(()) }
    /// ```
    pub async fn wait_key<T>(&self, key: impl AsRef<Key<T>>) -> T
    where
        T: Clone + Any + Send + Sync,
    {
        let (mut stream, value) = self.stream_key(key).await;

        if let Some(value) = value {
            return value;
        }

        loop {
            if let Some(value) = stream.recv().await {
                return value;
            }
        }
    }

    /// Get an existing value and set up a stream for updates at the same time.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// #[derive(Debug, Clone, PartialEq, Eq)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() {
    /// let injector = Injector::new();
    /// let (mut database_stream, _) = injector.stream::<Database>().await;
    ///
    /// // Update the key somewhere else.
    /// tokio::spawn({
    ///     let injector = injector.clone();
    ///
    ///     async move {
    ///         injector.update(Database).await;
    ///     }
    /// });
    ///
    /// let database = loop {
    ///     if let Some(update) = database_stream.recv().await {
    ///         break update;
    ///     }
    /// };
    ///
    /// assert_eq!(database, Database);
    /// # }
    /// ```
    pub async fn stream<T>(&self) -> (Stream<T>, Option<T>)
    where
        T: Clone + Any + Send + Sync,
    {
        self.stream_key(Key::<T>::of()).await
    }

1014    /// Get an existing value and setup a stream for updates at the same time.
1015    ///
1016    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    ///
    /// #[derive(Debug, Clone, PartialEq, Eq)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    /// let db = Key::<Database>::tagged("first")?;
    /// let (mut database_stream, database) = injector.stream_key(&db).await;
    ///
    /// // Nothing has been injected yet.
    /// assert!(database.is_none());
    ///
    /// // Update the key somewhere else.
    /// tokio::spawn({
    ///     let db = db.clone();
    ///     let injector = injector.clone();
    ///
    ///     async move {
    ///         injector.update_key(&db, Database).await;
    ///     }
    /// });
    ///
    /// let database = loop {
    ///     if let Some(update) = database_stream.recv().await {
    ///         break update;
    ///     }
    /// };
    ///
    /// assert_eq!(database, Database);
    /// # Ok(()) }
    /// ```
    pub async fn stream_key<T>(&self, key: impl AsRef<Key<T>>) -> (Stream<T>, Option<T>)
    where
        T: Clone + Any + Send + Sync,
    {
        let key = key.as_ref().as_raw_key();

        let mut storage = self.inner.storage.write().await;
        let storage = storage.entry(key.clone()).or_default();

        let rx = storage.tx.subscribe();

        let value = storage.value.read().await;

        let value = value.as_ref().map(|value| {
            // Safety: The expected type parameter is encoded in the key used
            // to look up this storage entry.
            unsafe { value.downcast_ref::<T>().clone() }
        });

        let stream = Stream {
            rx,
            marker: marker::PhantomData,
        };

        (stream, value)
    }

    /// Get a synchronized reference ([`Ref`]) to the value of type `T`.
    ///
    /// The returned reference also observes values injected after it was
    /// created.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// #[derive(Clone)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    ///
    /// let database = injector.var::<Database>().await;
    ///
    /// assert!(database.read().await.is_none());
    /// injector.update(Database).await;
    /// assert!(database.read().await.is_some());
    /// # Ok(()) }
    /// ```
    pub async fn var<T>(&self) -> Ref<T>
    where
        T: Clone + Any + Send + Sync + Unpin,
    {
        self.var_key(Key::<T>::of()).await
    }

    /// Get a synchronized reference ([`Ref`]) to the value associated with the
    /// given key.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::{Key, Injector};
    /// use std::error::Error;
    ///
    /// #[derive(Clone)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn Error>> {
    /// let injector = Injector::new();
    /// let db = Key::<Database>::tagged("first")?;
    ///
    /// let database = injector.var_key(&db).await;
    ///
    /// assert!(database.read().await.is_none());
    /// injector.update_key(&db, Database).await;
    /// assert!(database.read().await.is_some());
    /// # Ok(()) }
    /// ```
    pub async fn var_key<T>(&self, key: impl AsRef<Key<T>>) -> Ref<T>
    where
        T: Clone + Any + Send + Sync + Unpin,
    {
        let key = key.as_ref().as_raw_key();

        let mut storage = self.inner.storage.write().await;
        let storage = storage.entry(key.clone()).or_default();

        Ref {
            value: storage.value.clone(),
            _m: marker::PhantomData,
        }
    }
}

impl Default for Injector {
    fn default() -> Self {
        Self {
            inner: Arc::new(Inner {
                storage: Default::default(),
            }),
        }
    }
}

/// The type-erased identity of a stored value: the [`TypeId`] of the value
/// type, the [`TypeId`] of the tag type, and the serialized tag.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)]
struct RawKey {
    type_id: TypeId,
    tag_type_id: TypeId,
    tag: hashkey::Key,
}

impl RawKey {
    /// Construct a new raw key.
    fn new<T, K>(tag: hashkey::Key) -> Self
    where
        T: Any,
        K: Any,
    {
        Self {
            type_id: TypeId::of::<T>(),
            tag_type_id: TypeId::of::<K>(),
            tag,
        }
    }
}

/// A key used to discriminate a value in the [`Injector`].
#[derive(Clone)]
pub struct Key<T>
where
    T: Any,
{
    raw_key: RawKey,
    _marker: marker::PhantomData<T>,
}

impl<T> fmt::Debug for Key<T>
where
    T: Any,
{
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.raw_key, fmt)
    }
}

impl<T> cmp::PartialEq for Key<T>
where
    T: Any,
{
    fn eq(&self, other: &Self) -> bool {
        self.as_raw_key().eq(other.as_raw_key())
    }
}

impl<T> cmp::Eq for Key<T> where T: Any {}

impl<T> cmp::PartialOrd for Key<T>
where
    T: Any,
{
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> cmp::Ord for Key<T>
where
    T: Any,
{
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.as_raw_key().cmp(other.as_raw_key())
    }
}

impl<T> hash::Hash for Key<T>
where
    T: Any,
{
    fn hash<H>(&self, state: &mut H)
    where
        H: hash::Hasher,
    {
        self.as_raw_key().hash(state);
    }
}

impl<T> Key<T>
where
    T: Any,
{
    /// Construct a new key without a tag.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Key;
    ///
    /// struct Foo;
    ///
    /// assert_eq!(Key::<Foo>::of(), Key::<Foo>::of());
    /// ```
    pub fn of() -> Self {
        Self {
            raw_key: RawKey::new::<T, ()>(hashkey::Key::Unit),
            _marker: marker::PhantomData,
        }
    }

    /// Construct a new key tagged with `tag`.
    ///
    /// The tag is serialized and stored in the key together with the type of
    /// the tag, so two keys for the same value type are only equal if both
    /// the tag type and the serialized tag match.
    ///
    /// # Errors
    ///
    /// Returns an error if `tag` cannot be serialized into a key.
    ///
    /// # Examples
    ///
    /// ```
    /// use serde::Serialize;
    /// use async_injector::Key;
    ///
    /// struct Foo;
    ///
    /// #[derive(Serialize)]
    /// enum Tag {
    ///     One,
    ///     Two,
    /// }
    ///
    /// #[derive(Serialize)]
    /// enum Tag2 {
    ///     One,
    ///     Two,
    /// }
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// assert_eq!(Key::<Foo>::tagged(Tag::One)?, Key::<Foo>::tagged(Tag::One)?);
    /// assert_ne!(Key::<Foo>::tagged(Tag::One)?, Key::<Foo>::tagged(Tag::Two)?);
    /// // Tags of different types produce different keys.
    /// assert_ne!(Key::<Foo>::tagged(Tag::One)?, Key::<Foo>::tagged(Tag2::One)?);
    /// # Ok(()) }
    /// ```
    pub fn tagged<K>(tag: K) -> Result<Self, Error>
    where
        K: Any + serde::Serialize,
    {
        let tag = hashkey::to_key(&tag)?;

        Ok(Self {
            raw_key: RawKey::new::<T, K>(tag),
            _marker: marker::PhantomData,
        })
    }

    /// Access the underlying raw key.
    fn as_raw_key(&self) -> &RawKey {
        &self.raw_key
    }
}

impl<T> AsRef<Key<T>> for Key<T>
where
    T: 'static,
{
    fn as_ref(&self) -> &Self {
        self
    }
}

/// A variable allowing for the synchronized reading of a value in the
/// [`Injector`].
///
/// This can be created through [`Injector::var`] or [`Injector::var_key`].
#[derive(Clone)]
pub struct Ref<T>
where
    T: Clone + Any + Send + Sync,
{
    value: Arc<RwLock<Option<Value>>>,
    _m: marker::PhantomData<T>,
}

impl<T> Ref<T>
where
    T: Clone + Any + Send + Sync,
{
    /// Read the synchronized variable without cloning it.
    ///
    /// Returns a guard borrowing the current value, or `None` if no value has
    /// been set.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// #[derive(Clone)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    ///
    /// let database = injector.var::<Database>().await;
    ///
    /// assert!(database.read().await.is_none());
    /// injector.update(Database).await;
    /// assert!(database.read().await.is_some());
    /// # Ok(()) }
    /// ```
    pub async fn read(&self) -> Option<RefReadGuard<'_, T>> {
        let value = self.value.read().await;

        let result = RefReadGuard::try_map(value, |value| {
            value.as_ref().map(|value| {
                // Safety: The expected type parameter is encoded and
                // maintained in the Ref type.
                unsafe { value.downcast_ref::<T>() }
            })
        });

        result.ok()
    }

    /// Load the synchronized variable by cloning the underlying value.
    ///
    /// Returns `None` if no value has been set.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_injector::Injector;
    ///
    /// #[derive(Clone)]
    /// struct Database;
    ///
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let injector = Injector::new();
    ///
    /// let database = injector.var::<Database>().await;
    ///
    /// assert!(database.load().await.is_none());
    /// injector.update(Database).await;
    /// assert!(database.load().await.is_some());
    /// # Ok(()) }
    /// ```
    pub async fn load(&self) -> Option<T> {
        let value = self.value.read().await;

        value.as_ref().map(|value| {
            // Safety: The expected type parameter is encoded and maintained
            // in the Ref type.
            unsafe { value.downcast_ref::<T>().clone() }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::Value;

    #[test]
    fn test_clone() {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        let count = Arc::new(AtomicUsize::new(0));

        // Each `Foo` increments the counter when dropped, so the counter
        // tracks how many copies have been dropped so far.
        let value = Value::new(Foo(count.clone()));
        assert_eq!(0, count.load(Ordering::SeqCst));
        // Dropping a clone must drop only the cloned inner value.
        drop(value.clone());
        assert_eq!(1, count.load(Ordering::SeqCst));
        drop(value);
        assert_eq!(2, count.load(Ordering::SeqCst));

        #[derive(Clone)]
        struct Foo(Arc<AtomicUsize>);

        impl Drop for Foo {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
    }
}
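The key scheme above (a `RawKey` made of the value's `TypeId`, the tag's `TypeId`, and the serialized tag, wrapped in a `Key<T>` that carries `T` only as a phantom parameter) can be illustrated outside the crate. The sketch below is a hypothetical, std-only approximation rather than the crate's code: `Store`, `insert` and `get` are invented names, the serde-based `hashkey::Key` is replaced by a `String` built with `ToString` (an assumption made to avoid dependencies), and the checked `downcast_ref` from `std::any` stands in for the crate's `unsafe` `Value::downcast_ref`.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::marker::PhantomData;

/// Untyped identity of a stored value, mirroring the three fields of `RawKey`.
/// The tag is a `String` here (an assumption), not a serialized hash key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RawKey {
    type_id: TypeId,
    tag_type_id: TypeId,
    tag: String,
}

/// Typed wrapper: `T` only appears in the phantom marker and in `type_id`.
struct Key<T: Any> {
    raw: RawKey,
    _marker: PhantomData<T>,
}

impl<T: Any> Key<T> {
    fn new<K: Any>(tag: String) -> Self {
        Key {
            raw: RawKey {
                type_id: TypeId::of::<T>(),
                tag_type_id: TypeId::of::<K>(),
                tag,
            },
            _marker: PhantomData,
        }
    }

    /// Untagged key, using `()` as the tag type like `Key::of`.
    fn of() -> Self {
        Self::new::<()>(String::new())
    }

    /// Tagged key; the tag's type is part of the identity.
    fn tagged<K: Any + ToString>(tag: K) -> Self {
        Self::new::<K>(tag.to_string())
    }
}

/// Hypothetical type-erased store keyed by `RawKey`.
#[derive(Default)]
struct Store {
    values: HashMap<RawKey, Box<dyn Any>>,
}

impl Store {
    fn insert<T: Any>(&mut self, key: &Key<T>, value: T) {
        self.values.insert(key.raw.clone(), Box::new(value));
    }

    fn get<T: Any>(&self, key: &Key<T>) -> Option<&T> {
        // Checked downcast: a type mismatch yields `None` rather than UB.
        self.values.get(&key.raw)?.downcast_ref::<T>()
    }
}
```

In this sketch a `&str` tag and a `String` tag with the same text map to different entries because their tag types differ, which is the same rule that makes `Key::<Foo>::tagged(Tag::One)` and `Key::<Foo>::tagged(Tag2::One)` unequal in the crate's doc example.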
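`wait_key` returns the current value immediately if one has been injected and otherwise keeps receiving from a `Stream` until an update carries a value, while `Ref::load` clones whatever is present without waiting. A rough synchronous analogue of that contract can be built from `std::sync::Mutex` and `Condvar`. This is a hypothetical sketch, not the crate's design: the `Slot` type and its methods are invented for illustration, it blocks threads instead of suspending tasks, and the crate itself uses an async `RwLock` and a receiver obtained through `tx.subscribe()`.

```rust
use std::sync::{Condvar, Mutex};

/// A single value slot plus a condition variable that wakes waiters on update.
/// Hypothetical stand-in for one storage entry of the injector.
struct Slot<T> {
    value: Mutex<Option<T>>,
    changed: Condvar,
}

impl<T: Clone> Slot<T> {
    fn new() -> Self {
        Slot {
            value: Mutex::new(None),
            changed: Condvar::new(),
        }
    }

    /// Store a value and wake every waiting thread (compare `Injector::update`).
    fn update(&self, value: T) {
        *self.value.lock().unwrap() = Some(value);
        self.changed.notify_all();
    }

    /// Clone the current value, if any, without waiting (compare `Ref::load`).
    fn load(&self) -> Option<T> {
        self.value.lock().unwrap().as_ref().cloned()
    }

    /// Block until a value is present and return a clone of it
    /// (compare `Injector::wait`). Returns at once if a value is already set.
    fn wait(&self) -> T {
        let guard = self.value.lock().unwrap();
        // `wait_while` re-checks the condition after every wakeup,
        // so spurious wakeups are handled.
        let guard = self
            .changed
            .wait_while(guard, |value| value.is_none())
            .unwrap();
        guard
            .as_ref()
            .cloned()
            .expect("wait_while only returns once a value is set")
    }
}
```

Because `update` calls `notify_all`, every blocked waiter wakes on an update, which loosely mirrors how each stream subscribed to a key sees every update.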