async_injector/lib.rs
1//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/async--injector-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/async-injector)
2//! [<img alt="crates.io" src="https://img.shields.io/crates/v/async-injector.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/async-injector)
3//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-async--injector-66c2a5?style=for-the-badge&logoColor=white&logo=" height="20">](https://docs.rs/async-injector)
4//!
5//! Asynchronous dependency injection for Rust.
6//!
7//! This library provides the glue which allows for building robust decoupled
8//! applications that can be reconfigured dynamically while they are running.
9//!
10//! For a real world example of how this is used, see [`OxidizeBot`] for which
11//! it was written.
12//!
13//! <br>
14//!
15//! ## Usage
16//!
17//! Add `async-injector` to your `Cargo.toml`.
18//!
19//! ```toml
20//! [dependencies]
21//! async-injector = "0.19.4"
22//! ```
23//!
24//! <br>
25//!
26//! ## Example
27//!
28//! In the following we'll showcase the injection of a *fake* `Database`. The
29//! idea here would be that if something about the database connection changes,
30//! a new instance of `Database` would be created and cause the application to
31//! reconfigure itself.
32//!
33//! ```rust
34//! use async_injector::{Key, Injector, Provider};
35//!
36//! #[derive(Debug, Clone)]
37//! struct Database;
38//!
39//! #[derive(Debug, Provider)]
40//! struct Service {
41//! #[dependency]
42//! database: Database,
43//! }
44//!
45//! async fn service(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
46//! let mut provider = Service::provider(&injector).await?;
47//!
48//! let Service { database } = provider.wait().await;
49//! println!("Service got initial database {database:?}!");
50//!
51//! let Service { database } = provider.wait().await;
52//! println!("Service got new database {database:?}!");
53//!
54//! Ok(())
55//! }
56//! ```
57//!
58//! > **Note:** This is available as the `database` example:
59//! > ```sh
60//! > cargo run --example database
61//! > ```
62//!
63//! The [`Injector`] above provides a structured broadcasting system that allows
64//! for configuration updates to be cleanly integrated into asynchronous
65//! contexts. The update itself is triggered by some other component that is
66//! responsible for constructing the `Database` instance.
67//!
//! Building up the components of your application like this means that it can
//! be reconfigured without restarting it, providing a much richer user
//! experience.
71//!
72//! <br>
73//!
74//! ## Injecting multiple things of the same type
75//!
76//! In the previous section you might've noticed that the injected value was
77//! solely discriminated by its type: `Database`. In this example we'll show how
78//! [`Key`] can be used to *tag* values of the same type with different names to
79//! discriminate them. This can be useful when dealing with overly generic types
80//! like [`String`].
81//!
82//! The tag used must be serializable with [`serde`]. It must also not use any
83//! components which [cannot be hashed], like `f32` and `f64`.
84//!
85//! <br>
86//!
87//! ### A simple greeter
88//!
//! The following example showcases the use of `Key` to inject two different
//! values into an asynchronous `greeter`.
91//!
92//! ```rust,no_run
93//! use async_injector::{Key, Injector};
94//!
95//! async fn greeter(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
96//! let name = Key::<String>::tagged("name")?;
97//! let fun = Key::<String>::tagged("fun")?;
98//!
99//! let (mut name_stream, mut name) = injector.stream_key(name).await;
100//! let (mut fun_stream, mut fun) = injector.stream_key(fun).await;
101//!
102//! loop {
103//! tokio::select! {
104//! update = name_stream.recv() => {
105//! name = update;
106//! }
107//! update = fun_stream.recv() => {
108//! fun = update;
109//! }
110//! }
111//!
112//! let (Some(name), Some(fun)) = (&name, &fun) else {
113//! continue;
114//! };
115//!
116//! println!("Hi {name}! I see you do \"{fun}\" for fun!");
117//! return Ok(());
118//! }
119//! }
120//! ```
121//!
122//! > **Note:** you can run this using:
123//! > ```sh
124//! > cargo run --example greeter
125//! > ```
126//!
127//! The loop above can be implemented more easily using the [`Provider`] derive,
128//! so let's do that.
129//!
130//! ```rust,no_run
131//! use async_injector::{Injector, Provider};
132//!
133//! #[derive(Provider)]
134//! struct Dependencies {
135//! #[dependency(tag = "name")]
136//! name: String,
137//! #[dependency(tag = "fun")]
138//! fun: String,
139//! }
140//!
141//! async fn greeter(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
142//! let mut provider = Dependencies::provider(&injector).await?;
143//! let Dependencies { name, fun } = provider.wait().await;
144//! println!("Hi {name}! I see you do \"{fun}\" for fun!");
145//! Ok(())
146//! }
147//! ```
148//!
149//! > **Note:** you can run this using:
150//! > ```sh
151//! > cargo run --example greeter_provider
152//! > ```
153//!
154//! <br>
155//!
156//! ## The `Provider` derive
157//!
158//! The [`Provider`] derive can be used to conveniently implement the mechanism
159//! necessary to wait for a specific set of dependencies to become available.
160//!
161//! It builds a companion structure next to the type being provided called
162//! `<name>Provider` which in turn implements the following set of methods:
163//!
164//! ```rust,no_run
165//! use async_injector::{Error, Injector};
166//!
167//! # struct Dependencies {}
168//! impl Dependencies {
169//! /// Construct a new provider.
170//! async fn provider(injector: &Injector) -> Result<DependenciesProvider, Error>
171//! # { todo!() }
172//! }
173//!
174//! struct DependenciesProvider {
175//! /* private fields */
176//! }
177//!
178//! impl DependenciesProvider {
179//! /// Try to construct the current value. Returns [None] unless all
180//! /// required dependencies are available.
181//! fn build(&mut self) -> Option<Dependencies>
182//! # { todo!() }
183//!
184//! /// Wait until we can successfully build the complete provided
185//! /// value.
186//! async fn wait(&mut self) -> Dependencies
187//! # { todo!() }
188//!
//!     /// Wait until the provided value has changed. Either some
//!     /// dependencies are no longer available, at which point it returns
//!     /// `None`, or all dependencies are available, after which we return
//!     /// the built value.
193//! async fn wait_for_update(&mut self) -> Option<Dependencies>
194//! # { todo!() }
195//! }
196//! ```
197//!
198//! <br>
199//!
200//! ### Fixed arguments to `Provider`
201//!
202//! Any arguments which do not have the `#[dependency]` attribute are known as
203//! "fixed" arguments. These must be passed in when calling the `provider`
204//! constructor. They can also be used during tag construction.
205//!
206//! ```rust,no_run
207//! use async_injector::{Injector, Key, Provider};
208//!
209//! #[derive(Provider)]
210//! struct Dependencies {
211//! name_tag: &'static str,
212//! #[dependency(tag = name_tag)]
213//! name: String,
214//! }
215//!
216//! async fn greeter(injector: Injector) -> Result<(), Box<dyn std::error::Error>> {
217//! let mut provider = Dependencies::provider(&injector, "name").await?;
218//! let Dependencies { name, .. } = provider.wait().await;
219//! println!("Hi {name}!");
220//! Ok(())
221//! }
222//! ```
223//!
224//! [`OxidizeBot`]: https://github.com/udoprog/OxidizeBot
225//! [cannot be hashed]: https://internals.rust-lang.org/t/f32-f64-should-implement-hash/5436
226//! [`Injector`]: https://docs.rs/async-injector/0/async_injector/struct.Injector.html
227//! [`Key`]: https://docs.rs/async-injector/0/async_injector/struct.Key.html
228//! [`Provider`]: https://docs.rs/async-injector/0/async_injector/derive.Provider.html
229//! [`serde`]: https://serde.rs
230//! [`Stream`]: https://docs.rs/futures-core/0/futures_core/stream/trait.Stream.html
231//! [`String`]: https://doc.rust-lang.org/std/string/struct.String.html
232
233#![deny(missing_docs)]
234
235use hashbrown::HashMap;
236use serde_hashkey as hashkey;
237use std::any::{Any, TypeId};
238use std::cmp;
239use std::error;
240use std::fmt;
241use std::hash;
242use std::marker;
243use std::mem;
244use std::ptr;
245use std::sync::Arc;
246use tokio::sync::{broadcast, RwLock};
247
248/// The read guard produced by [Ref::read].
249pub type RefReadGuard<'a, T> = tokio::sync::RwLockReadGuard<'a, T>;
250
251/// re-exports for the Provider derive.
252#[doc(hidden)]
253pub mod derive {
254 pub use tokio::select;
255}
256
257/// Helper derive to implement a "provider".
258///
259/// The `Provider` derive can only be used on structs. Each field designates a
260/// value that must either be injected, or provided during construction.
261///
262/// ```rust
263/// use async_injector::Provider;
264/// use serde::Serialize;
265///
266/// #[derive(Serialize)]
267/// enum Tag {
268/// Table,
269/// Url,
270/// }
271///
272/// #[derive(Provider)]
273/// struct Deps {
274/// fixed: String,
275/// #[dependency(optional, tag = Tag::Table)]
276/// table: Option<String>,
277/// #[dependency(tag = Tag::Url)]
278/// url: String,
279/// #[dependency]
280/// connection_limit: u32,
281/// }
282/// ```
283///
284/// This generates another struct named `DepsProvider`, with the following api:
285///
286/// ```rust,no_run
287/// use async_injector::{Error, Injector};
288///
289/// # struct Deps {}
290/// impl Deps {
291/// /// Construct a new provider.
292/// async fn provider(injector: &Injector, fixed: String) -> Result<DepsProvider, Error>
293/// # { todo!() }
294/// }
295///
296/// struct DepsProvider {
297/// /* private fields */
298/// }
299///
300/// impl DepsProvider {
301/// /// Try to construct the current value. Returns [None] unless all
302/// /// required dependencies are available.
303/// fn build(&mut self) -> Option<Deps>
304/// # { todo!() }
305///
306/// /// Wait until we can successfully build the complete provided
307/// /// value.
308/// async fn wait(&mut self) -> Deps
309/// # { todo!() }
310///
///     /// Wait until the provided value has changed. Either some
///     /// dependencies are no longer available, at which point it returns
///     /// `None`, or all dependencies are available, after which we return
///     /// the built value.
315/// async fn wait_for_update(&mut self) -> Option<Deps>
316/// # { todo!() }
317/// }
318/// ```
319///
320/// The `provider` associated function takes the reference to an injector as its
321/// first argument and any fields which are not marked as a `#[dependency]`.
322/// These are called fixed fields.
323///
324/// <br>
325///
326/// # The `#[dependency]` field attribute
327///
328/// The `#[dependency]` attribute can be used to mark fields which need to be
329/// injected. It takes an optional `#[dependency(tag = ...)]`, which allows you
330/// to specify the tag to use when constructing the injected [Key].
331///
332/// ```rust
333/// use async_injector::Provider;
334/// use serde::Serialize;
335///
336/// #[derive(Serialize)]
337/// enum Tag {
338/// First,
339/// }
340///
341/// #[derive(Provider)]
342/// struct Params {
343/// #[dependency(tag = Tag::First)]
344/// tagged: String,
345/// #[dependency]
346/// number: u32,
347/// }
348/// ```
349///
350/// Optional fields use the [Option] type and must be marked with the `optional`
351/// meta attribute.
352///
353/// ```rust
354/// use async_injector::Provider;
355///
356/// #[derive(Provider)]
357/// struct Params {
358/// #[dependency(optional)]
359/// table: Option<String>,
360/// }
361/// ```
362///
363/// [Key]: https://docs.rs/async-injector/0/async_injector/struct.Key.html
364#[doc(inline)]
365pub use async_injector_derive::Provider;
366
367/// Errors that can be raised by various functions in the [`Injector`].
368#[derive(Debug)]
369pub enum Error {
370 /// Error when serializing key.
371 SerializationError(hashkey::Error),
372}
373
374impl fmt::Display for Error {
375 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
376 match *self {
377 Self::SerializationError(..) => "serialization error".fmt(fmt),
378 }
379 }
380}
381
382impl error::Error for Error {
383 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
384 match self {
385 Self::SerializationError(e) => Some(e),
386 }
387 }
388}
389
390impl From<hashkey::Error> for Error {
391 fn from(value: hashkey::Error) -> Self {
392 Error::SerializationError(value)
393 }
394}
395
396/// A stream of updates to a value in the [`Injector`].
397///
398/// This is created using [Injector::stream] or [Injector::stream_key] and can
399/// be used to make sure that an asynchronous process has access to the most
400/// up-to-date value from the injector.
401pub struct Stream<T> {
402 rx: broadcast::Receiver<Option<Value>>,
403 marker: marker::PhantomData<T>,
404}
405
406impl<T> Stream<T> {
407 /// Receive the next injected element from the stream.
408 pub async fn recv(&mut self) -> Option<T> {
409 let value = loop {
410 let value = self.rx.recv().await;
411
412 match value {
413 Ok(value) => break value,
414 Err(broadcast::error::RecvError::Lagged { .. }) => continue,
415 _ => return None,
416 };
417 };
418
419 let value = match value {
420 Some(value) => value,
421 _ => return None,
422 };
423
424 // Safety: The expected type parameter is encoded and maintained in the
425 // Stream<T> type.
426 Some(unsafe { value.downcast::<T>() })
427 }
428}
429
/// An opaque value holder, similar to Any, but can be cloned and relies
/// entirely on external type information.
///
/// `data` always originates from `Box::into_raw` for some concrete type `T`,
/// and the two function pointers are instantiated for that same `T`. The
/// holder owns the boxed value and frees it when dropped.
struct Value {
    data: *const (),
    // clone function, to use when cloning the value.
    value_clone_fn: unsafe fn(*const ()) -> *const (),
    // drop function, to use when dropping the value.
    value_drop_fn: unsafe fn(*const ()),
}

impl Value {
    /// Construct a new opaque value by moving `data` onto the heap and
    /// remembering how to clone and drop it.
    pub(crate) fn new<T>(data: T) -> Self
    where
        T: 'static + Clone + Send + Sync,
    {
        /// Clone implementation for a given value.
        ///
        /// `data` must point to a live `T` produced by `Box::into_raw`. The
        /// returned pointer is a fresh, independently owned box.
        unsafe fn value_clone_fn<T>(data: *const ()) -> *const ()
        where
            T: Clone,
        {
            let data = T::clone(&*(data as *const T));
            Box::into_raw(Box::new(data)) as *const ()
        }

        /// Drop implementation for a given value.
        ///
        /// `value` must point to a live `T` produced by `Box::into_raw` and
        /// must not be used afterwards.
        unsafe fn value_drop_fn<T>(value: *const ()) {
            // Reconstitute the box so that the destructor of `T` runs *and*
            // the heap allocation is released. Only dropping the pointee in
            // place would leak the allocation made by `Box::new`.
            drop(Box::from_raw(value as *const T as *mut T));
        }

        Self {
            data: Box::into_raw(Box::new(data)) as *const (),
            value_clone_fn: value_clone_fn::<T>,
            value_drop_fn: value_drop_fn::<T>,
        }
    }

    /// Downcast the given value reference.
    ///
    /// # Safety
    ///
    /// Assumes that we know the type of the underlying value.
    pub(crate) unsafe fn downcast_ref<T>(&self) -> &T {
        &*(self.data as *const T)
    }

    /// Downcast the given value to a mutable reference.
    ///
    /// # Safety
    ///
    /// Assumes that we know the type of the underlying value.
    pub(crate) unsafe fn downcast_mut<T>(&mut self) -> &mut T {
        &mut *(self.data as *const T as *mut T)
    }

    /// Downcast the given value, taking ownership of the contained `T`.
    ///
    /// # Safety
    ///
    /// Assumes that we know the correct, underlying type of the value.
    pub(crate) unsafe fn downcast<T>(self) -> T {
        let value = Box::from_raw(self.data as *const T as *mut T);
        // Ownership of the allocation has moved into `value`; skip our own
        // destructor so that it is not freed twice.
        mem::forget(self);
        *value
    }
}

// SAFETY: Send + Sync bound is enforced in all constructors of `Value`.
unsafe impl Send for Value {}
// SAFETY: See the `Send` implementation above.
unsafe impl Sync for Value {}

impl Clone for Value {
    fn clone(&self) -> Self {
        // SAFETY: `data` points to a live value of the type which
        // `value_clone_fn` was instantiated for in `Value::new`.
        let data = unsafe { (self.value_clone_fn)(self.data) };

        Self {
            data,
            value_clone_fn: self.value_clone_fn,
            value_drop_fn: self.value_drop_fn,
        }
    }
}

impl Drop for Value {
    fn drop(&mut self) {
        // SAFETY: `data` points to a live value of the type which
        // `value_drop_fn` was instantiated for in `Value::new`, and it is not
        // touched again after this call.
        unsafe {
            (self.value_drop_fn)(self.data);
        }
    }
}
520
521struct Storage {
522 value: Arc<RwLock<Option<Value>>>,
523 tx: broadcast::Sender<Option<Value>>,
524}
525
526impl Default for Storage {
527 fn default() -> Self {
528 let (tx, _) = broadcast::channel(1);
529 Self {
530 value: Arc::new(RwLock::new(None)),
531 tx,
532 }
533 }
534}
535
536struct Inner {
537 storage: RwLock<HashMap<RawKey, Storage>>,
538}
539
540/// An injector of dependencies.
541///
542/// Injectors are defined in hierarchies where an injector is either the root
543/// injector as created using [Injector::new].
544#[derive(Clone)]
545pub struct Injector {
546 inner: Arc<Inner>,
547}
548
549impl Injector {
550 /// Construct and use an [`Injector`].
551 ///
552 /// # Example
553 ///
554 /// ```
555 /// use async_injector::Injector;
556 ///
557 /// # #[tokio::main] async fn main() {
558 /// let injector = Injector::new();
559 ///
560 /// assert_eq!(None, injector.get::<u32>().await);
561 /// injector.update(1u32).await;
562 /// assert_eq!(Some(1u32), injector.get::<u32>().await);
563 /// assert!(injector.clear::<u32>().await.is_some());
564 /// assert_eq!(None, injector.get::<u32>().await);
565 /// # }
566 /// ```
567 ///
    /// Example using a synchronized reference obtained through [`Injector::var`].
569 ///
570 /// ```
571 /// use async_injector::Injector;
572 ///
573 /// #[derive(Clone)]
574 /// struct Database;
575 ///
576 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
577 /// let injector = Injector::new();
578 ///
579 /// let database = injector.var::<Database>().await;
580 ///
581 /// assert!(database.read().await.is_none());
582 /// injector.update(Database).await;
583 /// assert!(database.read().await.is_some());
584 /// # Ok(()) }
585 /// ```
586 pub fn new() -> Self {
587 Self::default()
588 }
589
    /// Clear the value of the given type.
591 ///
592 /// This will cause the clear to be propagated to all streams set up using
593 /// [`stream`]. And for future calls to [`get`] to return the updated value.
594 ///
595 /// [`stream`]: Injector::stream
596 /// [`get`]: Injector::get
597 ///
598 /// # Examples
599 ///
600 /// ```
601 /// use async_injector::Injector;
602 ///
603 /// # #[tokio::main] async fn main() {
604 /// let injector = Injector::new();
605 ///
606 /// assert_eq!(None, injector.get::<u32>().await);
607 /// injector.update(1u32).await;
608 /// assert_eq!(Some(1u32), injector.get::<u32>().await);
609 /// assert!(injector.clear::<u32>().await.is_some());
610 /// assert_eq!(None, injector.get::<u32>().await);
611 /// # }
612 /// ```
613 pub async fn clear<T>(&self) -> Option<T>
614 where
615 T: Clone + Any + Send + Sync,
616 {
617 self.clear_key(Key::<T>::of()).await
618 }
619
620 /// Clear the given value with the given key.
621 ///
622 /// This will cause the clear to be propagated to all streams set up using
623 /// [`stream`]. And for future calls to [`get`] to return the updated value.
624 ///
625 /// [`stream`]: Injector::stream
626 /// [`get`]: Injector::get
627 ///
628 /// # Examples
629 ///
630 /// ```
631 /// use async_injector::{Key, Injector};
632 ///
633 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
634 /// let injector = Injector::new();
635 /// let k = Key::<u32>::tagged("first")?;
636 ///
637 /// assert_eq!(None, injector.get_key(&k).await);
638 /// injector.update_key(&k, 1u32).await;
639 /// assert_eq!(Some(1u32), injector.get_key(&k).await);
640 /// assert!(injector.clear_key(&k).await.is_some());
641 /// assert_eq!(None, injector.get_key(&k).await);
642 /// # Ok(()) }
643 /// ```
644 pub async fn clear_key<T>(&self, key: impl AsRef<Key<T>>) -> Option<T>
645 where
646 T: Clone + Any + Send + Sync,
647 {
648 let key = key.as_ref().as_raw_key();
649
650 let storage = self.inner.storage.read().await;
651 let storage = storage.get(key)?;
652 let value = storage.value.write().await.take()?;
653 let _ = storage.tx.send(None);
654 Some(unsafe { value.downcast() })
655 }
656
657 /// Set the given value and notify any subscribers.
658 ///
659 /// This will cause the update to be propagated to all streams set up using
660 /// [`stream`]. And for future calls to [`get`] to return the updated value.
661 ///
662 /// [`stream`]: Injector::stream
663 /// [`get`]: Injector::get
664 ///
665 /// # Examples
666 ///
667 /// ```
668 /// use async_injector::Injector;
669 ///
670 /// # #[tokio::main]
671 /// # async fn main() {
672 /// let injector = Injector::new();
673 ///
674 /// assert_eq!(None, injector.get::<u32>().await);
675 /// injector.update(1u32).await;
676 /// assert_eq!(Some(1u32), injector.get::<u32>().await);
677 /// # }
678 /// ```
679 pub async fn update<T>(&self, value: T) -> Option<T>
680 where
681 T: Clone + Any + Send + Sync,
682 {
683 self.update_key(Key::<T>::of(), value).await
684 }
685
686 /// Update the value associated with the given key.
687 ///
688 /// This will cause the update to be propagated to all streams set up using
689 /// [`stream`]. And for future calls to [`get`] to return the updated value.
690 ///
691 /// [`stream`]: Injector::stream
692 /// [`get`]: Injector::get
693 ///
694 /// # Examples
695 ///
696 /// ```
697 /// use async_injector::{Key, Injector};
698 ///
699 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
700 /// let injector = Injector::new();
701 /// let k = Key::<u32>::tagged("first")?;
702 ///
703 /// assert_eq!(None, injector.get_key(&k).await);
704 /// injector.update_key(&k, 1u32).await;
705 /// assert_eq!(Some(1u32), injector.get_key(&k).await);
706 /// # Ok(()) }
707 /// ```
708 pub async fn update_key<T>(&self, key: impl AsRef<Key<T>>, value: T) -> Option<T>
709 where
710 T: Clone + Any + Send + Sync,
711 {
712 let key = key.as_ref().as_raw_key();
713 let value = Value::new(value);
714
715 let mut storage = self.inner.storage.write().await;
716 let storage = storage.entry(key.clone()).or_default();
717 let _ = storage.tx.send(Some(value.clone()));
718 let old = storage.value.write().await.replace(value)?;
719 Some(unsafe { old.downcast() })
720 }
721
722 /// Test if a given value exists by type.
723 ///
724 /// # Examples
725 ///
726 /// ```
727 /// use async_injector::{Key, Injector};
728 ///
729 /// # #[tokio::main] async fn main() {
730 /// let injector = Injector::new();
731 ///
732 /// assert_eq!(false, injector.exists::<u32>().await);
733 /// injector.update(1u32).await;
734 /// assert_eq!(true, injector.exists::<u32>().await);
735 /// # }
736 /// ```
737 pub async fn exists<T>(&self) -> bool
738 where
739 T: Clone + Any + Send + Sync,
740 {
741 self.exists_key(Key::<T>::of()).await
742 }
743
744 /// Test if a given value exists by key.
745 ///
746 /// # Examples
747 ///
748 /// ```
749 /// use async_injector::{Key, Injector};
750 ///
751 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
752 /// let injector = Injector::new();
753 /// let k = Key::<u32>::tagged("first")?;
754 ///
755 /// assert_eq!(false, injector.exists_key(&k).await);
756 /// injector.update_key(&k, 1u32).await;
757 /// assert_eq!(true, injector.exists_key(&k).await);
758 /// # Ok(()) }
759 /// ```
760 pub async fn exists_key<T>(&self, key: impl AsRef<Key<T>>) -> bool
761 where
762 T: Clone + Any + Send + Sync,
763 {
764 let key = key.as_ref().as_raw_key();
765 let storage = self.inner.storage.read().await;
766
767 if let Some(s) = storage.get(key) {
768 s.value.read().await.is_some()
769 } else {
770 false
771 }
772 }
773
774 /// Mutate the given value by type.
775 ///
776 /// # Examples
777 ///
778 /// ```
779 /// use async_injector::Injector;
780 ///
781 /// # #[tokio::main] async fn main() {
782 /// let injector = Injector::new();
783 ///
784 /// injector.update(1u32).await;
785 ///
786 /// let old = injector.mutate(|value: &mut u32| {
787 /// let old = *value;
788 /// *value += 1;
789 /// old
790 /// }).await;
791 ///
792 /// assert_eq!(Some(1u32), old);
793 /// # }
794 /// ```
795 pub async fn mutate<T, M, R>(&self, mutator: M) -> Option<R>
796 where
797 T: Clone + Any + Send + Sync,
798 M: FnMut(&mut T) -> R,
799 {
800 self.mutate_key(Key::<T>::of(), mutator).await
801 }
802
803 /// Mutate the given value by key.
804 ///
805 /// # Examples
806 ///
807 /// ```
808 /// use async_injector::{Key, Injector};
809 ///
810 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
811 /// let injector = Injector::new();
812 /// let k = Key::<u32>::tagged("first")?;
813 ///
814 /// injector.update_key(&k, 1u32).await;
815 ///
816 /// let old = injector.mutate_key(&k, |value| {
817 /// let old = *value;
818 /// *value += 1;
819 /// old
820 /// }).await;
821 ///
822 /// assert_eq!(Some(1u32), old);
823 /// # Ok(()) }
824 /// ```
825 pub async fn mutate_key<T, M, R>(&self, key: impl AsRef<Key<T>>, mut mutator: M) -> Option<R>
826 where
827 T: Clone + Any + Send + Sync,
828 M: FnMut(&mut T) -> R,
829 {
830 let key = key.as_ref().as_raw_key();
831 let storage = self.inner.storage.read().await;
832
833 let storage = match storage.get(key) {
834 Some(s) => s,
835 None => return None,
836 };
837
838 let mut value = storage.value.write().await;
839
840 if let Some(value) = &mut *value {
841 let output = mutator(unsafe { value.downcast_mut() });
842 let value = value.clone();
843 let _ = storage.tx.send(Some(value));
844 return Some(output);
845 }
846
847 None
848 }
849
850 /// Get a value from the injector.
851 ///
852 /// # Examples
853 ///
854 /// ```
855 /// use async_injector::Injector;
856 ///
857 /// # #[tokio::main] async fn main() {
858 /// let injector = Injector::new();
859 ///
860 /// assert_eq!(None, injector.get::<u32>().await);
861 /// injector.update(1u32).await;
862 /// assert_eq!(Some(1u32), injector.get::<u32>().await);
863 /// # }
864 /// ```
865 pub async fn get<T>(&self) -> Option<T>
866 where
867 T: Clone + Any + Send + Sync,
868 {
869 self.get_key(Key::<T>::of()).await
870 }
871
872 /// Get a value from the injector with the given key.
873 ///
874 /// # Examples
875 ///
876 /// ```
877 /// use async_injector::{Key, Injector};
878 ///
879 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
880 /// let k1 = Key::<u32>::tagged("first")?;
881 /// let k2 = Key::<u32>::tagged("second")?;
882 ///
883 /// let injector = Injector::new();
884 ///
885 /// assert_eq!(None, injector.get_key(&k1).await);
886 /// assert_eq!(None, injector.get_key(&k2).await);
887 ///
888 /// injector.update_key(&k1, 1u32).await;
889 ///
890 /// assert_eq!(Some(1u32), injector.get_key(&k1).await);
891 /// assert_eq!(None, injector.get_key(&k2).await);
892 /// # Ok(()) }
893 /// ```
894 pub async fn get_key<T>(&self, key: impl AsRef<Key<T>>) -> Option<T>
895 where
896 T: Clone + Any + Send + Sync,
897 {
898 let key = key.as_ref().as_raw_key();
899 let storage = self.inner.storage.read().await;
900
901 let storage = match storage.get(key) {
902 Some(storage) => storage,
903 None => return None,
904 };
905
906 let value = storage.value.read().await;
907
908 value.as_ref().map(|value| {
909 // Safety: Ref<T> instances can only be produced by checked fns.
910 unsafe { value.downcast_ref::<T>().clone() }
911 })
912 }
913
914 /// Wait for a value to become available.
915 ///
916 /// Note that this could potentially wait forever if the value is never
917 /// injected.
918 ///
919 /// # Examples
920 ///
921 /// ```
922 /// use async_injector::Injector;
923 ///
924 /// # #[tokio::main]
925 /// # async fn main() {
926 /// let injector = Injector::new();
927 ///
928 /// injector.update(1u32).await;
929 /// assert_eq!(1u32, injector.wait::<u32>().await);
930 /// # }
931 /// ```
932 #[inline]
933 pub async fn wait<T>(&self) -> T
934 where
935 T: Clone + Any + Send + Sync,
936 {
937 self.wait_key(Key::<T>::of()).await
938 }
939
940 /// Wait for a value associated with the given key to become available.
941 ///
942 /// Note that this could potentially wait forever if the value is never
943 /// injected.
944 ///
945 /// # Examples
946 ///
947 /// ```
948 /// use async_injector::{Key, Injector};
949 ///
950 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
951 /// let injector = Injector::new();
952 /// let tag = Key::<u32>::tagged("first")?;
953 ///
954 /// injector.update_key(&tag, 1u32).await;
955 /// assert_eq!(1u32, injector.wait_key(tag).await);
956 /// # Ok(()) }
957 /// ```
958 pub async fn wait_key<T>(&self, key: impl AsRef<Key<T>>) -> T
959 where
960 T: Clone + Any + Send + Sync,
961 {
962 let (mut stream, value) = self.stream_key(key).await;
963
964 if let Some(value) = value {
965 return value;
966 }
967
968 loop {
969 if let Some(value) = stream.recv().await {
970 return value;
971 }
972 }
973 }
974
975 /// Get an existing value and setup a stream for updates at the same time.
976 ///
977 /// # Examples
978 ///
979 /// ```
980 /// use async_injector::Injector;
981 ///
982 /// #[derive(Debug, Clone, PartialEq, Eq)]
983 /// struct Database;
984 ///
985 /// # #[tokio::main] async fn main() {
986 /// let injector = Injector::new();
987 /// let (mut database_stream, mut database) = injector.stream::<Database>().await;
988 ///
989 /// // Update the key somewhere else.
990 /// tokio::spawn({
991 /// let injector = injector.clone();
992 ///
993 /// async move {
994 /// injector.update(Database).await;
995 /// }
996 /// });
997 ///
998 /// let database = loop {
999 /// if let Some(update) = database_stream.recv().await {
1000 /// break update;
1001 /// }
1002 /// };
1003 ///
1004 /// assert_eq!(database, Database);
1005 /// # }
1006 /// ```
1007 pub async fn stream<T>(&self) -> (Stream<T>, Option<T>)
1008 where
1009 T: Clone + Any + Send + Sync,
1010 {
1011 self.stream_key(Key::<T>::of()).await
1012 }
1013
1014 /// Get an existing value and setup a stream for updates at the same time.
1015 ///
1016 /// # Examples
1017 ///
1018 /// ```
1019 /// use async_injector::{Key, Injector};
1020 ///
1021 /// #[derive(Debug, Clone, PartialEq, Eq)]
1022 /// struct Database;
1023 ///
1024 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
1025 /// let injector = Injector::new();
1026 /// let db = Key::<Database>::tagged("first")?;
1027 /// let (mut database_stream, mut database) = injector.stream_key(&db).await;
1028 ///
1029 /// // Update the key somewhere else.
1030 /// tokio::spawn({
1031 /// let db = db.clone();
1032 /// let injector = injector.clone();
1033 ///
1034 /// async move {
1035 /// injector.update_key(&db, Database).await;
1036 /// }
1037 /// });
1038 ///
1039 /// let database = loop {
1040 /// if let Some(update) = database_stream.recv().await {
1041 /// break update;
1042 /// }
1043 /// };
1044 ///
1045 /// assert_eq!(database, Database);
1046 /// # Ok(()) }
1047 /// ```
1048 pub async fn stream_key<T>(&self, key: impl AsRef<Key<T>>) -> (Stream<T>, Option<T>)
1049 where
1050 T: Clone + Any + Send + Sync,
1051 {
1052 let key = key.as_ref().as_raw_key();
1053
1054 let mut storage = self.inner.storage.write().await;
1055 let storage = storage.entry(key.clone()).or_default();
1056
1057 let rx = storage.tx.subscribe();
1058
1059 let value = storage.value.read().await;
1060
1061 let value = value.as_ref().map(|value| {
1062 // Safety: The expected type parameter is encoded and maintained
1063 // in the Stream type.
1064 unsafe { value.downcast_ref::<T>().clone() }
1065 });
1066
1067 let stream = Stream {
1068 rx,
1069 marker: marker::PhantomData,
1070 };
1071
1072 (stream, value)
1073 }
1074
1075 /// Get a synchronized reference for the given configuration key.
1076 ///
1077 /// # Examples
1078 ///
1079 /// ```
1080 /// use async_injector::Injector;
1081 ///
1082 /// #[derive(Clone)]
1083 /// struct Database;
1084 ///
1085 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
1086 /// let injector = Injector::new();
1087 ///
1088 /// let database = injector.var::<Database>().await;
1089 ///
1090 /// assert!(database.read().await.is_none());
1091 /// injector.update(Database).await;
1092 /// assert!(database.read().await.is_some());
1093 /// # Ok(()) }
1094 /// ```
1095 pub async fn var<T>(&self) -> Ref<T>
1096 where
1097 T: Clone + Any + Send + Sync + Unpin,
1098 {
1099 self.var_key(Key::<T>::of()).await
1100 }
1101
1102 /// Get a synchronized reference for the given configuration key.
1103 ///
1104 /// # Examples
1105 ///
1106 /// ```
1107 /// use async_injector::{Key, Injector};
1108 /// use std::error::Error;
1109 ///
1110 /// #[derive(Clone)]
1111 /// struct Database;
1112 ///
1113 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn Error>> {
1114 /// let injector = Injector::new();
1115 /// let db = Key::<Database>::tagged("first")?;
1116 ///
1117 /// let database = injector.var_key(&db).await;
1118 ///
1119 /// assert!(database.read().await.is_none());
1120 /// injector.update_key(&db, Database).await;
1121 /// assert!(database.read().await.is_some());
1122 /// # Ok(()) }
1123 /// ```
1124 pub async fn var_key<T>(&self, key: impl AsRef<Key<T>>) -> Ref<T>
1125 where
1126 T: Clone + Any + Send + Sync + Unpin,
1127 {
1128 let key = key.as_ref().as_raw_key();
1129
1130 let mut storage = self.inner.storage.write().await;
1131 let storage = storage.entry(key.clone()).or_default();
1132
1133 Ref {
1134 value: storage.value.clone(),
1135 _m: marker::PhantomData,
1136 }
1137 }
1138}
1139
1140impl Default for Injector {
1141 fn default() -> Self {
1142 Self {
1143 inner: Arc::new(Inner {
1144 storage: Default::default(),
1145 }),
1146 }
1147 }
1148}
1149
1150#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)]
1151struct RawKey {
1152 type_id: TypeId,
1153 tag_type_id: TypeId,
1154 tag: hashkey::Key,
1155}
1156
1157impl RawKey {
1158 /// Construct a new raw key.
1159 fn new<T, K>(tag: hashkey::Key) -> Self
1160 where
1161 T: Any,
1162 K: Any,
1163 {
1164 Self {
1165 type_id: TypeId::of::<T>(),
1166 tag_type_id: TypeId::of::<K>(),
1167 tag,
1168 }
1169 }
1170}
1171
1172/// A key used to discriminate a value in the [`Injector`].
1173#[derive(Clone)]
1174pub struct Key<T>
1175where
1176 T: Any,
1177{
1178 raw_key: RawKey,
1179 _marker: marker::PhantomData<T>,
1180}
1181
1182impl<T> fmt::Debug for Key<T>
1183where
1184 T: Any,
1185{
1186 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1187 fmt::Debug::fmt(&self.raw_key, fmt)
1188 }
1189}
1190
1191impl<T> cmp::PartialEq for Key<T>
1192where
1193 T: Any,
1194{
1195 fn eq(&self, other: &Self) -> bool {
1196 self.as_raw_key().eq(other.as_raw_key())
1197 }
1198}
1199
1200impl<T> cmp::Eq for Key<T> where T: Any {}
1201
1202impl<T> cmp::PartialOrd for Key<T>
1203where
1204 T: Any,
1205{
1206 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
1207 Some(self.cmp(other))
1208 }
1209}
1210
1211impl<T> cmp::Ord for Key<T>
1212where
1213 T: Any,
1214{
1215 fn cmp(&self, other: &Self) -> cmp::Ordering {
1216 self.as_raw_key().cmp(other.as_raw_key())
1217 }
1218}
1219
1220impl<T> hash::Hash for Key<T>
1221where
1222 T: Any,
1223{
1224 fn hash<H>(&self, state: &mut H)
1225 where
1226 H: hash::Hasher,
1227 {
1228 self.as_raw_key().hash(state);
1229 }
1230}
1231
1232impl<T> Key<T>
1233where
1234 T: Any,
1235{
1236 /// Construct a new key without a tag.
1237 ///
1238 /// # Examples
1239 ///
1240 /// ```
1241 /// use async_injector::Key;
1242 ///
1243 /// struct Foo;
1244 ///
1245 /// assert_eq!(Key::<Foo>::of(), Key::<Foo>::of());
1246 /// ```
1247 pub fn of() -> Self {
1248 Self {
1249 raw_key: RawKey::new::<T, ()>(hashkey::Key::Unit),
1250 _marker: marker::PhantomData,
1251 }
1252 }
1253
1254 /// Construct a new key.
1255 ///
1256 /// # Examples
1257 ///
1258 /// ```
1259 /// use serde::Serialize;
1260 /// use async_injector::Key;
1261 ///
1262 /// struct Foo;
1263 ///
1264 /// #[derive(Serialize)]
1265 /// enum Tag {
1266 /// One,
1267 /// Two,
1268 /// }
1269 ///
1270 /// #[derive(Serialize)]
1271 /// enum Tag2 {
1272 /// One,
1273 /// Two,
1274 /// }
1275 ///
1276 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1277 /// assert_eq!(Key::<Foo>::tagged(Tag::One)?, Key::<Foo>::tagged(Tag::One)?);
1278 /// assert_ne!(Key::<Foo>::tagged(Tag::One)?, Key::<Foo>::tagged(Tag::Two)?);
1279 /// assert_ne!(Key::<Foo>::tagged(Tag::One)?, Key::<Foo>::tagged(Tag2::One)?);
1280 /// # Ok(()) }
1281 /// ```
1282 pub fn tagged<K>(tag: K) -> Result<Self, Error>
1283 where
1284 K: Any + serde::Serialize,
1285 {
1286 let tag = hashkey::to_key(&tag)?;
1287
1288 Ok(Self {
1289 raw_key: RawKey::new::<T, K>(tag),
1290 _marker: marker::PhantomData,
1291 })
1292 }
1293
1294 /// Convert into a raw key.
1295 fn as_raw_key(&self) -> &RawKey {
1296 &self.raw_key
1297 }
1298}
1299
1300impl<T> AsRef<Key<T>> for Key<T>
1301where
1302 T: 'static,
1303{
1304 fn as_ref(&self) -> &Self {
1305 self
1306 }
1307}
1308
1309/// A variable allowing for the synchronized reading of avalue in the
1310/// [`Injector`].
1311///
1312/// This can be created through [Injector::var] or [Injector::var_key].
1313#[derive(Clone)]
1314pub struct Ref<T>
1315where
1316 T: Clone + Any + Send + Sync,
1317{
1318 value: Arc<RwLock<Option<Value>>>,
1319 _m: marker::PhantomData<T>,
1320}
1321
1322impl<T> Ref<T>
1323where
1324 T: Clone + Any + Send + Sync,
1325{
1326 /// Read the synchronized variable.
1327 ///
1328 /// # Examples
1329 ///
1330 /// ```
1331 /// use async_injector::Injector;
1332 ///
1333 /// #[derive(Clone)]
1334 /// struct Database;
1335 ///
1336 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
1337 /// let injector = Injector::new();
1338 ///
1339 /// let database = injector.var::<Database>().await;
1340 ///
1341 /// assert!(database.read().await.is_none());
1342 /// injector.update(Database).await;
1343 /// assert!(database.read().await.is_some());
1344 /// # Ok(()) }
1345 /// ```
1346 pub async fn read(&self) -> Option<RefReadGuard<'_, T>> {
1347 let value = self.value.read().await;
1348
1349 let result = RefReadGuard::try_map(value, |value| {
1350 value.as_ref().map(|value| {
1351 // Safety: The expected type parameter is encoded and
1352 // maintained in the Stream type.
1353 unsafe { value.downcast_ref::<T>() }
1354 })
1355 });
1356
1357 result.ok()
1358 }
1359
1360 /// Load the synchronized variable. This clones the underlying value if it
1361 /// has been set.
1362 ///
1363 /// # Examples
1364 ///
1365 /// ```
1366 /// use async_injector::Injector;
1367 ///
1368 /// #[derive(Clone)]
1369 /// struct Database;
1370 ///
1371 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
1372 /// let injector = Injector::new();
1373 ///
1374 /// let database = injector.var::<Database>().await;
1375 ///
1376 /// assert!(database.load().await.is_none());
1377 /// injector.update(Database).await;
1378 /// assert!(database.load().await.is_some());
1379 /// # Ok(()) }
1380 /// ```
1381 pub async fn load(&self) -> Option<T> {
1382 let value = self.value.read().await;
1383
1384 value.as_ref().map(|value| {
1385 // Safety: The expected type parameter is encoded and maintained
1386 // in the Stream type.
1387 unsafe { value.downcast_ref::<T>().clone() }
1388 })
1389 }
1390}
1391
#[cfg(test)]
mod tests {
    use super::Value;

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Bumps the shared counter every time an instance is dropped.
    #[derive(Clone)]
    struct DropCounter(Arc<AtomicUsize>);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Cloning a `Value` must clone the wrapped object, and dropping each
    /// `Value` must drop exactly the object it wraps.
    #[test]
    fn test_clone() {
        let drops = Arc::new(AtomicUsize::new(0));
        let dropped = || drops.load(Ordering::SeqCst);

        let value = Value::new(DropCounter(drops.clone()));
        assert_eq!(0, dropped());

        // Dropping the clone drops one wrapped object.
        drop(value.clone());
        assert_eq!(1, dropped());

        // Dropping the original drops the other.
        drop(value);
        assert_eq!(2, dropped());
    }
}