1use std::fmt;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::panic::Location;
7use std::sync::{Arc, OnceLock};
8
9use http::Extensions;
10use thiserror::Error;
11use tracing::warn;
12
13#[derive(Debug)]
15pub struct Dep<T: 'static + ?Sized>(DepInner<T>);
16
17#[derive(Debug, strum::IntoStaticStr)]
19enum DepInner<T: 'static + ?Sized> {
20 Arc(Arc<T>),
21 LazyArc(OnceLock<Arc<T>>),
22}
23
24impl<T: Sized> Dep<T> {
25 pub fn new(val: T) -> Self {
27 Self::new_arc(Arc::new(val))
28 }
29}
30
31impl<T: ?Sized> Dep<T> {
32 pub fn new_arc(arc: Arc<T>) -> Self {
33 Self(DepInner::Arc(arc))
34 }
35
36 pub fn lazy() -> Self {
37 Self(DepInner::LazyArc(OnceLock::new()))
38 }
39
40 pub fn try_as_ref(this: &Self) -> Result<&T, AsRefError<T>> {
41 match &this.0 {
42 DepInner::Arc(arc) => Ok(arc),
43 DepInner::LazyArc(cell) => cell.get().map(Arc::as_ref).ok_or_else(AsRefError::new),
44 }
45 }
46
47 #[track_caller]
48 pub fn bind(src: &Self, dst: &Self) {
49 if let Err(err) = Self::try_bind(src, dst) {
50 handle_bind_error::<T>(err)
51 }
52 }
53
54 pub fn try_bind(src: &Self, dst: &Self) -> Result<(), BindError> {
55 use BindError::*;
56 match (&src.0, &dst.0) {
57 (DepInner::LazyArc(src_cell), DepInner::LazyArc(dst_cell)) => {
58 let src_arc = src_cell.get().ok_or(UninitializedSourceCell)?.clone();
59 dst_cell
60 .set(src_arc)
61 .map_err(|_| InitializedDestinationCell)?;
62 }
63 (DepInner::Arc(src_arc), DepInner::LazyArc(dst_cell)) => {
64 dst_cell
65 .set(src_arc.clone())
66 .map_err(|_| InitializedDestinationCell)?;
67 }
68 _ => {
69 return Err(IncompatibleVariant {
70 src: From::from(&src.0),
71 dst: From::from(&dst.0),
72 })
73 }
74 }
75 Ok(())
76 }
77
78 pub fn is_initialized(this: &Self) -> bool {
80 match &this.0 {
81 DepInner::Arc(..) => true,
82 DepInner::LazyArc(cell) => cell.get().is_some(),
83 }
84 }
85
86 pub fn assert_initialized(this: &Self) {
87 assert!(Self::is_initialized(this), "cell is uninitialized")
88 }
89
90 pub fn as_arc(this: &Self) -> Option<&Arc<T>> {
91 let arc = match &this.0 {
92 DepInner::Arc(arc) => arc,
93 DepInner::LazyArc(cell) => cell.get()?,
94 };
95 Some(arc)
96 }
97
98 #[allow(clippy::result_unit_err)]
99 pub fn try_with<F, R>(&self, f: F) -> Result<R, ()>
100 where
101 F: FnOnce(&T) -> R,
102 {
103 Dep::try_as_ref(self).map(f).map_err(|_| ())
104 }
105}
106
107#[track_caller]
108fn handle_bind_error<T: ?Sized>(err: BindError) {
109 match err {
110 BindError::InitializedDestinationCell => {
111 let caller = Location::caller();
112 warn!(
113 "Bind already initialized instance of {} at {file}:{line}",
114 std::any::type_name::<T>(),
115 file = caller.file(),
116 line = caller.line(),
117 )
118 }
119 err => {
120 panic!("BindError: {}", err);
121 }
122 }
123}
124
125impl<T: ?Sized> Clone for DepInner<T> {
126 fn clone(&self) -> Self {
127 match self {
128 DepInner::Arc(arc) => DepInner::Arc(arc.clone()),
129 DepInner::LazyArc(cell) => DepInner::LazyArc(cell.clone()),
130 }
131 }
132}
133
134impl<T: ?Sized> Clone for Dep<T> {
135 fn clone(&self) -> Self {
136 Self(self.0.clone())
137 }
138}
139
140impl<T> From<T> for Dep<T> {
141 fn from(val: T) -> Self {
142 Self(DepInner::Arc(Arc::new(val)))
143 }
144}
145
146impl<T: ?Sized> From<Arc<T>> for Dep<T> {
147 fn from(val: Arc<T>) -> Self {
148 Self(DepInner::Arc(val))
149 }
150}
151
152impl<T: ?Sized> Deref for Dep<T> {
153 type Target = T;
154
155 fn deref(&self) -> &Self::Target {
156 Dep::try_as_ref(self).expect("initialized dependency")
157 }
158}
159
160#[derive(Error)]
161pub enum BindError {
162 #[error("destination cell is already initialized")]
163 InitializedDestinationCell,
164 #[error("source cell is uninitialized")]
165 UninitializedSourceCell,
166 #[error("incompatible variant, src variant: {src}, dst variant: {dst}")]
167 IncompatibleVariant {
168 src: &'static str,
169 dst: &'static str,
170 },
171}
172
173impl fmt::Debug for BindError {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 fmt::Display::fmt(self, f)
176 }
177}
178
179#[derive(Error)]
180#[error("Dependency of type {} is uninitialized", std::any::type_name::<T>())]
181pub struct AsRefError<T: ?Sized>(PhantomData<T>);
182
183impl<T: ?Sized> AsRefError<T> {
184 fn new() -> Self {
185 Self(PhantomData)
186 }
187}
188
189impl<T: ?Sized> fmt::Debug for AsRefError<T> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 fmt::Display::fmt(self, f)
192 }
193}
194
195pub trait BindDep {
197 fn bind_dep(&self, map: &TypeMap);
198}
199
200#[derive(Default)]
202pub struct TypeMap(Extensions);
203
204impl TypeMap {
205 pub fn new() -> Self {
206 Default::default()
207 }
208
209 pub fn get_instance<T: Send + Sync + 'static>(&self) -> &T {
213 self.0.get().unwrap_or_else(|| {
214 panic!(
215 r##"Not found type: "{}" in TypeMap"##,
216 std::any::type_name::<T>()
217 );
218 })
219 }
220
221 #[track_caller]
222 pub fn bind_instance<T: Send + Sync + 'static>(&self, target: &Dep<T>) {
223 let source: &Dep<T> = self.get_instance();
224 if let Err(err) = Dep::try_bind(source, target) {
225 handle_bind_error::<T>(err);
226 }
227 }
228
229 pub fn extensions(&self) -> &Extensions {
231 &self.0
232 }
233}
234
235impl From<Extensions> for TypeMap {
236 fn from(ext: Extensions) -> Self {
237 Self(ext)
238 }
239}
240
241impl Deref for TypeMap {
242 type Target = Extensions;
243
244 fn deref(&self) -> &Self::Target {
245 &self.0
246 }
247}
248
249impl DerefMut for TypeMap {
250 fn deref_mut(&mut self) -> &mut Self::Target {
251 &mut self.0
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic]
    fn test_assert_initialized_lazy_arc() {
        // A fresh placeholder has never been bound, so the assertion must fire.
        let unbound = Dep::<()>::lazy();
        Dep::assert_initialized(&unbound);
    }

    #[test]
    fn test_cyclic_dependency() {
        // `Foo` and `Bar` refer to each other. `Bar` is built with an eager
        // handle to `Foo`, while `Foo` starts with a lazy slot for `Bar`. That
        // slot is filled from the registry afterwards.
        struct Foo {
            bar: Dep<Bar>,
        }

        struct Bar {
            foo: Dep<Foo>,
        }

        impl BindDep for Foo {
            fn bind_dep(&self, map: &TypeMap) {
                map.bind_instance(&self.bar);
            }
        }

        let foo = Dep::new(Foo { bar: Dep::lazy() });
        let bar = Dep::new(Bar { foo: Dep::clone(&foo) });

        let mut registry = TypeMap::new();
        registry.insert(Dep::clone(&bar));
        foo.bind_dep(&registry);

        Dep::assert_initialized(&foo.bar);
        Dep::assert_initialized(&bar.foo);
    }
}