thruster_jab/
lib.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4};
5
/// Registers a value in a [`JabDI`]-like container (anything with a `put` method).
///
/// Every value is boxed before it is stored, so the registration key is always
/// a `Box<...>` type; [`fetch!`] uses the same convention for lookups. With
/// [`JabDI`], registering under an already-used key replaces the earlier value.
///
/// Forms:
/// - `provide!(di, dyn Trait, value)` stores `value` as
///   `Box<dyn Trait + Send + Sync>` (`Send + Sync` is added because `put`
///   requires it). `Trait` may be any path, including module-qualified
///   (`dyn my_mod::Trait`) and generic (`dyn Trait<u8>`) traits.
/// - `provide!(di, Type, value)` stores `value` as `Box<Type>`.
/// - `provide!(di, value)` stores `value` as `Box<T>`, where `T` is the type of
///   `value`.
///
/// The expansion is a block evaluating to `()`, so it can be used in both
/// statement and expression position.
///
/// NOTE(review): the single-value form is only tried after the `Type, value`
/// form fails, so a value expression that also starts like a type with
/// parenthesised arguments (e.g. `A(0)`) may be rejected by the macro parser;
/// bind such a value to a local first.
#[macro_export]
macro_rules! provide {
    ($jab_state:expr, dyn $trait:path, $value:expr) => {{
        let boxed: ::std::boxed::Box<dyn $trait + Send + Sync> =
            ::std::boxed::Box::new($value);

        $jab_state.put(boxed);
    }};
    ($jab_state:expr, $trait:ty, $value:expr) => {{
        let boxed: ::std::boxed::Box<$trait> = ::std::boxed::Box::new($value);

        $jab_state.put(boxed);
    }};
    ($jab_state:expr, $value:expr) => {{
        $jab_state.put(::std::boxed::Box::new($value));
    }};
}

/// Looks up a value previously registered with [`provide!`].
///
/// - `fetch!(di, dyn Trait)` returns `&Box<dyn Trait + Send + Sync>`. `Trait`
///   may be any path, so `dyn Trait` and `dyn some_mod::Trait` name the same
///   key when they refer to the same trait.
/// - `fetch!(di, Type)` returns `&Box<Type>`.
///
/// With [`JabDI`], this panics if nothing was registered under that key (see
/// `JabDI::get`).
#[macro_export]
macro_rules! fetch {
    ($jab_state:expr, dyn $trait:path) => {
        $jab_state.get::<::std::boxed::Box<dyn $trait + Send + Sync>>()
    };
    ($jab_state:expr, $trait:ty) => {
        $jab_state.get::<::std::boxed::Box<$trait>>()
    };
}
32
33trait JabStateWithDI {
34    fn get_mut<'a>(&'a mut self) -> &'a mut JabDI;
35    fn get<'a>(&'a self) -> &'a JabDI;
36}
37
/// Dependency container that holds at most one value per registered type.
///
/// Values are stored type-erased and must be `Send + Sync`; they are looked up
/// again by the `TypeId` of the type they were registered under.
#[derive(Default, Debug)]
pub struct JabDI {
    /// Type-erased registered values, keyed by the `TypeId` of their key type.
    dep_map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
42
43impl JabDI {
44    pub fn put<T: 'static + Send + Sync>(&mut self, val: T) {
45        self.dep_map
46            .insert(TypeId::of::<T>(), Box::new(Box::new(val)));
47    }
48
49    pub fn get<T: 'static + ?Sized>(&self) -> &T {
50        if let Some(v) = self.try_get() {
51            v
52        } else {
53            panic!("Could not find requested type");
54        }
55    }
56
57    pub fn try_get<T: 'static + ?Sized>(&self) -> Option<&T> {
58        if let Some(dep) = self.dep_map.get(&TypeId::of::<T>()) {
59            if let Some(val) = dep.downcast_ref::<Box<T>>() {
60                return Some(val);
61            }
62        }
63
64        None
65    }
66
67    pub fn get_mut<T: 'static>(&mut self) -> &mut T {
68        if let Some(v) = self.try_get_mut() {
69            v
70        } else {
71            panic!("Could not find requested type");
72        }
73    }
74
75    pub fn try_get_mut<T: 'static>(&mut self) -> Option<&mut T> {
76        if let Some(dep) = self.dep_map.get_mut(&TypeId::of::<T>()) {
77            if let Some(val) = Box::new(dep).downcast_mut::<T>() {
78                return Some(val);
79            }
80        }
81
82        None
83    }
84}
85
#[cfg(test)]
mod tests {
    use crate::JabDI;

    #[derive(Debug, PartialEq)]
    struct A(i32);

    #[derive(Debug, PartialEq)]
    struct B(i32);

    trait C {
        fn valc(&self) -> i32;
    }
    trait D {
        fn vald(&self) -> i32;
    }

    impl C for A {
        fn valc(&self) -> i32 {
            self.0
        }
    }

    impl D for B {
        fn vald(&self) -> i32 {
            self.0
        }
    }

    #[test]
    fn test_get_struct() {
        let mut jab = JabDI::default();

        let a = A(0);
        let b = B(1);

        provide!(jab, a);
        provide!(jab, b);

        assert_eq!(
            0,
            fetch!(jab, A).0,
            "it should correctly find struct A for struct A"
        );

        assert_eq!(
            1,
            fetch!(jab, B).0,
            "it should correctly find struct B for struct B"
        );
    }

    #[test]
    fn test_get_trait() {
        let mut jab = JabDI::default();

        let a = A(0);
        let b = B(1);

        provide!(jab, dyn C, a);
        provide!(jab, dyn D, b);

        assert_eq!(
            0,
            fetch!(jab, dyn C).valc(),
            "it should correctly find struct A for trait C"
        );

        assert_eq!(
            1,
            fetch!(jab, dyn D).vald(),
            "it should correctly find struct B for trait D"
        );
    }

    /// The explicit `Type, value` form of `provide!` pairs with `fetch!(_, Type)`.
    #[test]
    fn test_get_explicit_type() {
        let mut jab = JabDI::default();

        provide!(jab, u8, 3u8);

        assert_eq!(
            3,
            **fetch!(jab, u8),
            "it should find a value provided under an explicit type"
        );
    }

    /// Providing the same key twice keeps only the most recent value.
    #[test]
    fn test_provide_replaces_previous_value() {
        let mut jab = JabDI::default();

        let first = A(0);
        let second = A(1);

        provide!(jab, first);
        provide!(jab, second);

        assert_eq!(
            1,
            fetch!(jab, A).0,
            "it should keep the most recently provided value"
        );
    }

    /// `try_get` reports a missing registration as `None` instead of panicking.
    #[test]
    fn test_try_get_missing() {
        let jab = JabDI::default();

        assert!(
            jab.try_get::<Box<A>>().is_none(),
            "it should return None for a type that was never provided"
        );
    }

    /// `fetch!` (via `JabDI::get`) panics when nothing was provided.
    #[test]
    #[should_panic(expected = "Could not find requested type")]
    fn test_get_missing_panics() {
        let jab = JabDI::default();

        let _ = fetch!(jab, A);
    }
}