1use std::fmt;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::panic::Location;
7use std::sync::{Arc, OnceLock};
8
9use http::Extensions;
10use thiserror::Error;
11use tracing::warn;
12
13#[derive(Debug)]
15pub struct Dep<T: 'static + ?Sized>(DepInner<T>);
16
17#[derive(Debug, strum::IntoStaticStr)]
19enum DepInner<T: 'static + ?Sized> {
20 Arc(Arc<T>),
21 LazyArc(OnceLock<Arc<T>>),
22}
23
24impl<T: Sized> Dep<T> {
25 pub fn new(val: T) -> Self {
27 Self::new_arc(Arc::new(val))
28 }
29}
30
31impl<T: ?Sized> Dep<T> {
32 pub fn new_arc(arc: Arc<T>) -> Self {
33 Self(DepInner::Arc(arc))
34 }
35
36 pub fn lazy() -> Self {
37 Self(DepInner::LazyArc(OnceLock::new()))
38 }
39
40 pub fn try_as_ref(this: &Self) -> Result<&T, AsRefError<T>> {
41 match &this.0 {
42 DepInner::Arc(arc) => Ok(arc),
43 DepInner::LazyArc(cell) => cell.get().map(Arc::as_ref).ok_or_else(AsRefError::new),
44 }
45 }
46
47 #[track_caller]
48 pub fn bind(src: &Self, dst: &Self) {
49 if let Err(err) = Self::try_bind(src, dst) {
50 handle_bind_error::<T>(err)
51 }
52 }
53
54 pub fn try_bind(src: &Self, dst: &Self) -> Result<(), BindError> {
55 use BindError::*;
56 match (&src.0, &dst.0) {
57 (DepInner::LazyArc(src_cell), DepInner::LazyArc(dst_cell)) => {
58 let src_arc = src_cell.get().ok_or(UninitializedSourceCell)?.clone();
59 dst_cell
60 .set(src_arc)
61 .map_err(|_| InitializedDestinationCell)?;
62 }
63 (DepInner::Arc(src_arc), DepInner::LazyArc(dst_cell)) => {
64 dst_cell
65 .set(src_arc.clone())
66 .map_err(|_| InitializedDestinationCell)?;
67 }
68 _ => {
69 return Err(IncompatibleVariant {
70 src: From::from(&src.0),
71 dst: From::from(&dst.0),
72 })
73 }
74 }
75 Ok(())
76 }
77
78 pub fn is_initialized(this: &Self) -> bool {
80 match &this.0 {
81 DepInner::Arc(..) => true,
82 DepInner::LazyArc(cell) => cell.get().is_some(),
83 }
84 }
85
86 pub fn assert_initialized(this: &Self) {
87 assert!(Self::is_initialized(this), "cell is uninitialized")
88 }
89
90 pub fn as_arc(this: &Self) -> Option<&Arc<T>> {
91 let arc = match &this.0 {
92 DepInner::Arc(arc) => arc,
93 DepInner::LazyArc(cell) => cell.get()?,
94 };
95 Some(arc)
96 }
97
98 pub fn try_with<F, R>(&self, f: F) -> Result<R, ()>
99 where
100 F: FnOnce(&T) -> R,
101 {
102 Dep::try_as_ref(self).map(f).map_err(|_| ())
103 }
104}
105
106#[track_caller]
107fn handle_bind_error<T: ?Sized>(err: BindError) {
108 match err {
109 BindError::InitializedDestinationCell => {
110 let caller = Location::caller();
111 warn!(
112 "Bind already initialized instance of {} at {file}:{line}",
113 std::any::type_name::<T>(),
114 file = caller.file(),
115 line = caller.line(),
116 )
117 }
118 err => {
119 panic!("BindError: {}", err);
120 }
121 }
122}
123
124impl<T: ?Sized> Clone for DepInner<T> {
125 fn clone(&self) -> Self {
126 match self {
127 DepInner::Arc(arc) => DepInner::Arc(arc.clone()),
128 DepInner::LazyArc(cell) => DepInner::LazyArc(cell.clone()),
129 }
130 }
131}
132
133impl<T: ?Sized> Clone for Dep<T> {
134 fn clone(&self) -> Self {
135 Self(self.0.clone())
136 }
137}
138
139impl<T> From<T> for Dep<T> {
140 fn from(val: T) -> Self {
141 Self(DepInner::Arc(Arc::new(val)))
142 }
143}
144
145impl<T: ?Sized> From<Arc<T>> for Dep<T> {
146 fn from(val: Arc<T>) -> Self {
147 Self(DepInner::Arc(val))
148 }
149}
150
151impl<T: ?Sized> Deref for Dep<T> {
152 type Target = T;
153
154 fn deref(&self) -> &Self::Target {
155 Dep::try_as_ref(self).expect("initialized dependency")
156 }
157}
158
159#[derive(Error)]
160pub enum BindError {
161 #[error("destination cell is already initialized")]
162 InitializedDestinationCell,
163 #[error("source cell is uninitialized")]
164 UninitializedSourceCell,
165 #[error("incompatible variant, src variant: {src}, dst variant: {dst}")]
166 IncompatibleVariant {
167 src: &'static str,
168 dst: &'static str,
169 },
170}
171
172impl fmt::Debug for BindError {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 fmt::Display::fmt(self, f)
175 }
176}
177
178#[derive(Error)]
179#[error("Dependency of type {} is uninitialized", std::any::type_name::<T>())]
180pub struct AsRefError<T: ?Sized>(PhantomData<T>);
181
182impl<T: ?Sized> AsRefError<T> {
183 fn new() -> Self {
184 Self(PhantomData)
185 }
186}
187
188impl<T: ?Sized> fmt::Debug for AsRefError<T> {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 fmt::Display::fmt(self, f)
191 }
192}
193
194pub trait BindDep {
196 fn bind_dep(&self, map: &TypeMap);
197}
198
199#[derive(Default)]
201pub struct TypeMap(Extensions);
202
203impl TypeMap {
204 pub fn new() -> Self {
205 Default::default()
206 }
207
208 pub fn get_instance<T: Send + Sync + 'static>(&self) -> &T {
212 self.0.get().unwrap_or_else(|| {
213 panic!(
214 r##"Not found type: "{}" in TypeMap"##,
215 std::any::type_name::<T>()
216 );
217 })
218 }
219
220 #[track_caller]
221 pub fn bind_instance<T: Send + Sync + 'static>(&self, target: &Dep<T>) {
222 let source: &Dep<T> = self.get_instance();
223 if let Err(err) = Dep::try_bind(source, target) {
224 handle_bind_error::<T>(err);
225 }
226 }
227
228 pub fn extensions(&self) -> &Extensions {
230 &self.0
231 }
232}
233
234impl From<Extensions> for TypeMap {
235 fn from(ext: Extensions) -> Self {
236 Self(ext)
237 }
238}
239
240impl Deref for TypeMap {
241 type Target = Extensions;
242
243 fn deref(&self) -> &Self::Target {
244 &self.0
245 }
246}
247
248impl DerefMut for TypeMap {
249 fn deref_mut(&mut self) -> &mut Self::Target {
250 &mut self.0
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserting on a lazy dependency that was never bound must panic.
    #[test]
    #[should_panic]
    fn test_assert_initialized_lazy_arc() {
        let unbound: Dep<()> = Dep::lazy();
        Dep::assert_initialized(&unbound);
    }

    /// `Foo` and `Bar` refer to each other: `Foo` starts with a lazy `bar`
    /// that is bound from the map once `Bar` (holding `foo`) exists.
    #[test]
    fn test_cyclic_dependency() {
        struct Foo {
            bar: Dep<Bar>,
        }

        struct Bar {
            foo: Dep<Foo>,
        }

        impl BindDep for Foo {
            fn bind_dep(&self, map: &TypeMap) {
                map.bind_instance(&self.bar);
            }
        }

        let foo = Dep::new(Foo { bar: Dep::lazy() });
        let bar = Dep::new(Bar {
            foo: Dep::clone(&foo),
        });

        let mut registry = TypeMap::new();
        registry.insert(Dep::clone(&bar));
        foo.bind_dep(&registry);

        Dep::assert_initialized(&foo.bar);
        Dep::assert_initialized(&bar.foo);
    }
}