beid_api/
di.rs

1//! Inversion of control.
2
3use std::fmt;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::panic::Location;
7use std::sync::{Arc, OnceLock};
8
9use http::Extensions;
10use thiserror::Error;
11use tracing::warn;
12
/// Wrapper type for managing a dependency.
///
/// A `Dep<T>` is either already resolved (built with [`Dep::new`],
/// [`Dep::new_arc`] or a `From` conversion) or a lazy slot built with
/// [`Dep::lazy`] that is filled later via [`Dep::bind`] / [`Dep::try_bind`].
/// Dereferencing a lazy slot that was never bound panics.
#[derive(Debug)]
pub struct Dep<T: 'static + ?Sized>(DepInner<T>);
16
// Implementation detail of `Dep`.
//
// `IntoStaticStr` gives each variant a `&'static str` label, which
// `BindError::IncompatibleVariant` reports for the source and destination.
#[derive(Debug, strum::IntoStaticStr)]
enum DepInner<T: 'static + ?Sized> {
    // Resolved at construction time.
    Arc(Arc<T>),
    // Empty until a bind stores the shared `Arc` in the cell.
    LazyArc(OnceLock<Arc<T>>),
}
23
24impl<T: Sized> Dep<T> {
25    /// Create a new dependency.
26    pub fn new(val: T) -> Self {
27        Self::new_arc(Arc::new(val))
28    }
29}
30
31impl<T: ?Sized> Dep<T> {
32    pub fn new_arc(arc: Arc<T>) -> Self {
33        Self(DepInner::Arc(arc))
34    }
35
36    pub fn lazy() -> Self {
37        Self(DepInner::LazyArc(OnceLock::new()))
38    }
39
40    pub fn try_as_ref(this: &Self) -> Result<&T, AsRefError<T>> {
41        match &this.0 {
42            DepInner::Arc(arc) => Ok(arc),
43            DepInner::LazyArc(cell) => cell.get().map(Arc::as_ref).ok_or_else(AsRefError::new),
44        }
45    }
46
47    #[track_caller]
48    pub fn bind(src: &Self, dst: &Self) {
49        if let Err(err) = Self::try_bind(src, dst) {
50            handle_bind_error::<T>(err)
51        }
52    }
53
54    pub fn try_bind(src: &Self, dst: &Self) -> Result<(), BindError> {
55        use BindError::*;
56        match (&src.0, &dst.0) {
57            (DepInner::LazyArc(src_cell), DepInner::LazyArc(dst_cell)) => {
58                let src_arc = src_cell.get().ok_or(UninitializedSourceCell)?.clone();
59                dst_cell
60                    .set(src_arc)
61                    .map_err(|_| InitializedDestinationCell)?;
62            }
63            (DepInner::Arc(src_arc), DepInner::LazyArc(dst_cell)) => {
64                dst_cell
65                    .set(src_arc.clone())
66                    .map_err(|_| InitializedDestinationCell)?;
67            }
68            _ => {
69                return Err(IncompatibleVariant {
70                    src: From::from(&src.0),
71                    dst: From::from(&dst.0),
72                })
73            }
74        }
75        Ok(())
76    }
77
78    /// Returns `true` if `this` is initialized.
79    pub fn is_initialized(this: &Self) -> bool {
80        match &this.0 {
81            DepInner::Arc(..) => true,
82            DepInner::LazyArc(cell) => cell.get().is_some(),
83        }
84    }
85
86    pub fn assert_initialized(this: &Self) {
87        assert!(Self::is_initialized(this), "cell is uninitialized")
88    }
89
90    pub fn as_arc(this: &Self) -> Option<&Arc<T>> {
91        let arc = match &this.0 {
92            DepInner::Arc(arc) => arc,
93            DepInner::LazyArc(cell) => cell.get()?,
94        };
95        Some(arc)
96    }
97
98    pub fn try_with<F, R>(&self, f: F) -> Result<R, ()>
99    where
100        F: FnOnce(&T) -> R,
101    {
102        Dep::try_as_ref(self).map(f).map_err(|_| ())
103    }
104}
105
106#[track_caller]
107fn handle_bind_error<T: ?Sized>(err: BindError) {
108    match err {
109        BindError::InitializedDestinationCell => {
110            let caller = Location::caller();
111            warn!(
112                "Bind already initialized instance of {} at {file}:{line}",
113                std::any::type_name::<T>(),
114                file = caller.file(),
115                line = caller.line(),
116            )
117        }
118        err => {
119            panic!("BindError: {}", err);
120        }
121    }
122}
123
124impl<T: ?Sized> Clone for DepInner<T> {
125    fn clone(&self) -> Self {
126        match self {
127            DepInner::Arc(arc) => DepInner::Arc(arc.clone()),
128            DepInner::LazyArc(cell) => DepInner::LazyArc(cell.clone()),
129        }
130    }
131}
132
133impl<T: ?Sized> Clone for Dep<T> {
134    fn clone(&self) -> Self {
135        Self(self.0.clone())
136    }
137}
138
139impl<T> From<T> for Dep<T> {
140    fn from(val: T) -> Self {
141        Self(DepInner::Arc(Arc::new(val)))
142    }
143}
144
145impl<T: ?Sized> From<Arc<T>> for Dep<T> {
146    fn from(val: Arc<T>) -> Self {
147        Self(DepInner::Arc(val))
148    }
149}
150
151impl<T: ?Sized> Deref for Dep<T> {
152    type Target = T;
153
154    fn deref(&self) -> &Self::Target {
155        Dep::try_as_ref(self).expect("initialized dependency")
156    }
157}
158
159#[derive(Error)]
160pub enum BindError {
161    #[error("destination cell is already initialized")]
162    InitializedDestinationCell,
163    #[error("source cell is uninitialized")]
164    UninitializedSourceCell,
165    #[error("incompatible variant, src variant: {src}, dst variant: {dst}")]
166    IncompatibleVariant {
167        src: &'static str,
168        dst: &'static str,
169    },
170}
171
172impl fmt::Debug for BindError {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        fmt::Display::fmt(self, f)
175    }
176}
177
/// Error returned by [`Dep::try_as_ref`] when a lazy dependency has not been
/// bound yet. Its message includes the type name of `T`.
///
/// NOTE(review): `PhantomData<T>` makes this error's `Send`/`Sync` follow `T`.
#[derive(Error)]
#[error("Dependency of type {} is uninitialized", std::any::type_name::<T>())]
pub struct AsRefError<T: ?Sized>(PhantomData<T>);
181
182impl<T: ?Sized> AsRefError<T> {
183    fn new() -> Self {
184        Self(PhantomData)
185    }
186}
187
188impl<T: ?Sized> fmt::Debug for AsRefError<T> {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        fmt::Display::fmt(self, f)
191    }
192}
193
/// Support late dependency binding at runtime.
///
/// Implementors fill their lazy [`Dep`] fields from instances stored in a
/// [`TypeMap`], for example with [`TypeMap::bind_instance`].
pub trait BindDep {
    /// Bind this value's lazy dependencies using instances found in `map`.
    fn bind_dep(&self, map: &TypeMap);
}
198
199/// A type map of dependencies.
200#[derive(Default)]
201pub struct TypeMap(Extensions);
202
203impl TypeMap {
204    pub fn new() -> Self {
205        Default::default()
206    }
207
208    /// Get a reference to a type previously inserted on this Map.
209    ///
210    /// panic if an instance of type doesn't exist.
211    pub fn get_instance<T: Send + Sync + 'static>(&self) -> &T {
212        self.0.get().unwrap_or_else(|| {
213            panic!(
214                r##"Not found type: "{}" in TypeMap"##,
215                std::any::type_name::<T>()
216            );
217        })
218    }
219
220    #[track_caller]
221    pub fn bind_instance<T: Send + Sync + 'static>(&self, target: &Dep<T>) {
222        let source: &Dep<T> = self.get_instance();
223        if let Err(err) = Dep::try_bind(source, target) {
224            handle_bind_error::<T>(err);
225        }
226    }
227
228    /// Get a reference to inner extensions.
229    pub fn extensions(&self) -> &Extensions {
230        &self.0
231    }
232}
233
234impl From<Extensions> for TypeMap {
235    fn from(ext: Extensions) -> Self {
236        Self(ext)
237    }
238}
239
240impl Deref for TypeMap {
241    type Target = Extensions;
242
243    fn deref(&self) -> &Self::Target {
244        &self.0
245    }
246}
247
248impl DerefMut for TypeMap {
249    fn deref_mut(&mut self) -> &mut Self::Target {
250        &mut self.0
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic(expected = "cell is uninitialized")]
    fn test_assert_initialized_lazy_arc() {
        let a = Dep::<()>::lazy();
        Dep::assert_initialized(&a);
    }

    #[test]
    fn test_cyclic_dependency() {
        struct Foo {
            bar: Dep<Bar>,
        }

        impl BindDep for Foo {
            fn bind_dep(&self, map: &TypeMap) {
                map.bind_instance(&self.bar);
            }
        }

        struct Bar {
            foo: Dep<Foo>,
        }

        let foo = Dep::new(Foo { bar: Dep::lazy() });
        let bar = Dep::new(Bar { foo: foo.clone() });

        let mut map = TypeMap::new();
        map.insert(bar.clone());
        foo.bind_dep(&map);
        Dep::assert_initialized(&foo.bar);
        Dep::assert_initialized(&bar.foo);
    }

    // Unbound lazy slots are reported as errors rather than panicking.
    #[test]
    fn test_try_as_ref_unbound_lazy() {
        let dep = Dep::<u32>::lazy();
        assert!(Dep::try_as_ref(&dep).is_err());
        assert!(Dep::as_arc(&dep).is_none());
        assert!(!Dep::is_initialized(&dep));
    }

    #[test]
    fn test_try_bind_uninitialized_source() {
        let src = Dep::<u32>::lazy();
        let dst = Dep::<u32>::lazy();
        assert!(matches!(
            Dep::try_bind(&src, &dst),
            Err(BindError::UninitializedSourceCell)
        ));
        assert!(!Dep::is_initialized(&dst));
    }

    // A second bind must fail and leave the first value in place.
    #[test]
    fn test_try_bind_initialized_destination_keeps_first_value() {
        let src = Dep::new(1u32);
        let dst = Dep::<u32>::lazy();
        assert!(Dep::try_bind(&src, &dst).is_ok());
        assert_eq!(*dst, 1);

        let other = Dep::new(2u32);
        assert!(matches!(
            Dep::try_bind(&other, &dst),
            Err(BindError::InitializedDestinationCell)
        ));
        assert_eq!(*dst, 1);
    }

    #[test]
    fn test_try_bind_incompatible_variant() {
        let src = Dep::new(1u32);
        let dst = Dep::new(2u32);
        match Dep::try_bind(&src, &dst) {
            Err(BindError::IncompatibleVariant {
                src: src_kind,
                dst: dst_kind,
            }) => {
                assert_eq!(src_kind, "Arc");
                assert_eq!(dst_kind, "Arc");
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }
}