Skip to main content

hive_router_plan_executor/plugins/
plugin_context.rs

1use std::{
2    any::{Any, TypeId},
3    ops::{Deref, DerefMut},
4    sync::Arc,
5};
6
7use dashmap::{
8    mapref::one::{Ref, RefMut},
9    DashMap,
10};
11use http::Uri;
12use ntex::router::Path;
13use ntex::{http::HeaderMap, web::HttpRequest};
14
15use crate::plugin_trait::RouterPluginBoxed;
16
17pub struct RouterHttpRequest<'req> {
18    pub uri: &'req Uri,
19    pub method: &'req http::Method,
20    pub version: http::Version,
21    pub headers: &'req HeaderMap,
22    pub path: &'req str,
23    pub query_string: &'req str,
24    pub match_info: &'req Path<Uri>,
25}
26
27impl<'a> From<&'a HttpRequest> for RouterHttpRequest<'a> {
28    fn from(req: &'a HttpRequest) -> Self {
29        RouterHttpRequest {
30            uri: req.uri(),
31            method: req.method(),
32            version: req.version(),
33            headers: req.headers(),
34            match_info: req.match_info(),
35            query_string: req.query_string(),
36            path: req.path(),
37        }
38    }
39}
40
/// Type-keyed storage that plugins use to share values with one another.
///
/// Each Rust type `T` has at most one slot, keyed by `TypeId::of::<T>()`.
/// Values are stored type-erased (`Box<dyn Any + Send + Sync>`) and recovered
/// by downcasting back to `T`. Because the storage is a `DashMap`, every
/// accessor takes `&self`.
#[derive(Default)]
pub struct PluginContext {
    // Invariant relied on by the entry guards' downcasts: the value stored
    // under `TypeId::of::<T>()` is a boxed `T` (established by `insert`).
    inner: DashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
45
/// Read guard for a `T` stored in a [`PluginContext`], returned by
/// [`PluginContext::get_ref`].
///
/// Dereferences to `T`. While it is alive it holds a DashMap `Ref` into the
/// context. DashMap documents that mutating the map while holding such a
/// reference may deadlock, so drop the guard before calling `insert` or
/// `get_mut` on the same context.
pub struct PluginContextRefEntry<'a, T> {
    /// Underlying DashMap guard holding the type-erased value.
    pub entry: Ref<'a, TypeId, Box<dyn Any + Send + Sync>>,
    // Ties the guard to the concrete type `T` it is downcast to on access.
    phantom: std::marker::PhantomData<T>,
}
50
51impl<'a, T: Any + Send + Sync> AsRef<T> for PluginContextRefEntry<'a, T> {
52    fn as_ref(&self) -> &T {
53        let boxed_any = self.entry.value();
54        boxed_any
55            .downcast_ref::<T>()
56            .expect("type mismatch in PluginContextRefEntry")
57    }
58}
59
60impl<'a, T: Any + Send + Sync> Deref for PluginContextRefEntry<'a, T> {
61    type Target = T;
62    fn deref(&self) -> &Self::Target {
63        self.as_ref()
64    }
65}
66
/// Write guard for a `T` stored in a [`PluginContext`], returned by
/// [`PluginContext::get_mut`].
///
/// Dereferences (mutably) to `T`. While it is alive it holds a DashMap
/// `RefMut` into the context. DashMap documents that accessing the map again
/// while holding such a reference may deadlock, so keep the guard's scope short.
///
/// NOTE(review): `entry` is public, so a caller can overwrite the boxed value
/// with one of a different type. Later downcasts to `T` (`as_ref`/`as_mut`)
/// would then panic.
pub struct PluginContextMutEntry<'a, T> {
    /// Underlying DashMap guard holding the type-erased value.
    pub entry: RefMut<'a, TypeId, Box<dyn Any + Send + Sync>>,
    // Ties the guard to the concrete type `T` it is downcast to on access.
    phantom: std::marker::PhantomData<T>,
}
71
72impl<'a, T: Any + Send + Sync> AsRef<T> for PluginContextMutEntry<'a, T> {
73    fn as_ref(&self) -> &T {
74        let boxed_any = self.entry.value();
75        boxed_any
76            .downcast_ref::<T>()
77            .expect("type mismatch in PluginContextMutEntry")
78    }
79}
80
81impl<'a, T: Any + Send + Sync> Deref for PluginContextMutEntry<'a, T> {
82    type Target = T;
83    fn deref(&self) -> &Self::Target {
84        self.as_ref()
85    }
86}
87
88impl<'a, T: Any + Send + Sync> AsMut<T> for PluginContextMutEntry<'a, T> {
89    fn as_mut(&mut self) -> &mut T {
90        let boxed_any = self.entry.value_mut();
91        boxed_any
92            .downcast_mut::<T>()
93            .expect("type mismatch in PluginContextMutEntry")
94    }
95}
96
97impl<'a, T: Any + Send + Sync> DerefMut for PluginContextMutEntry<'a, T> {
98    fn deref_mut(&mut self) -> &mut T {
99        self.as_mut()
100    }
101}
102
103impl PluginContext {
104    /// Check if the context contains an entry of type T.
105    ///
106    /// This can be used by plugins to check for the presence of other plugins' context entries before trying to access them.
107    ///
108    /// Example:
109    /// ```
110    /// struct ContextData {
111    ///     pub greetings: String
112    /// }
113    ///
114    /// if payload.context.contains::<ContextData>() {
115    ///     /// safe to access ContextData entry
116    /// }
117    /// ```
118    pub fn contains<T: Any + Send + Sync>(&self) -> bool {
119        let type_id = TypeId::of::<T>();
120        self.inner.contains_key(&type_id)
121    }
122    /// Insert a value of type T into the context.
123    /// If an entry of that type already exists, it will be replaced and the old value will be returned.
124    ///
125    /// Example:
126    /// ```
127    /// struct ContextData {
128    ///     pub greetings: String
129    /// }
130    ///
131    /// let context_data = ContextData {
132    ///     greetings: "Hello from context!".to_string()
133    /// };
134    ///
135    /// payload.context.insert(context_data);
136    ///
137    /// ```
138    pub fn insert<T: Any + Send + Sync>(&self, value: T) -> Option<Box<T>> {
139        let type_id = TypeId::of::<T>();
140        self.inner
141            .insert(type_id, Box::new(value))
142            .and_then(|boxed_any| boxed_any.downcast::<T>().ok())
143    }
144    /// Get an immutable reference to the entry of type T in the context, if it exists.
145    /// If no entry of that type exists, None is returned.
146    ///
147    /// Example:
148    /// ```
149    /// struct ContextData {
150    ///     pub greetings: String
151    /// }
152    ///
153    /// let context_data = ContextData {
154    ///     greetings: "Hello from context!".to_string()
155    /// };
156    ///
157    /// payload.context.insert(context_data);
158    ///
159    /// let context_data_entry = payload.context.get_ref::<ContextData>();
160    /// if let Some(ref context_data) = context_data_entry {
161    ///    println!("{}", context_data.greetings); // prints "Hello from context!"
162    /// }
163    /// ```
164    pub fn get_ref<'a, T: Any + Send + Sync>(&'a self) -> Option<PluginContextRefEntry<'a, T>> {
165        let type_id = TypeId::of::<T>();
166        self.inner.get(&type_id).map(|entry| PluginContextRefEntry {
167            entry,
168            phantom: std::marker::PhantomData,
169        })
170    }
171    /// Get a mutable reference to the entry of type T in the context, if it exists.
172    /// If no entry of that type exists, None is returned.
173    ///
174    /// Example:
175    /// ```
176    /// struct ContextData {
177    ///   pub greetings: String
178    /// }
179    ///
180    /// let context_data = ContextData {
181    ///   greetings: "Hello from context!".to_string()
182    /// };
183    ///
184    /// payload.context.insert(context_data);
185    ///
186    /// if let Some(mut context_data_entry) = payload.context.get_mut::<ContextData>() {
187    ///    context_data_entry.greetings = "Hello from mutable reference!".to_string();
188    /// }
189    /// ```
190    pub fn get_mut<'a, T: Any + Send + Sync>(&'a self) -> Option<PluginContextMutEntry<'a, T>> {
191        let type_id = TypeId::of::<T>();
192        self.inner
193            .get_mut(&type_id)
194            .map(|entry| PluginContextMutEntry {
195                entry,
196                phantom: std::marker::PhantomData,
197            })
198    }
199}
200
/// State handed to router plugins while they process a request.
pub struct PluginRequestState<'req> {
    /// The router plugins, shared via `Arc`.
    // NOTE(review): presumably the full configured plugin list — confirm at the
    // construction site.
    pub plugins: Arc<Vec<RouterPluginBoxed>>,
    /// Borrowed view of the incoming HTTP request.
    pub router_http_request: RouterHttpRequest<'req>,
    /// Type-keyed context through which plugins share values (see `PluginContext`).
    pub context: Arc<PluginContext>,
}
206
#[cfg(test)]
mod tests {
    use super::PluginContext;

    // Shared fixture type stored in the context by every test below.
    struct TestCtx {
        pub value: u32,
    }

    #[test]
    fn inserts_and_gets_immut_ref() {
        let ctx = PluginContext::default();
        ctx.insert(TestCtx { value: 42 });

        let ctx_ref: &TestCtx = &ctx.get_ref().unwrap();
        assert_eq!(ctx_ref.value, 42);
    }

    #[test]
    fn inserts_and_mutates_with_mut_ref() {
        let ctx = PluginContext::default();
        ctx.insert(TestCtx { value: 42 });

        // Scope the write guard so it is dropped before the read below;
        // holding it across another access could deadlock.
        {
            let ctx_mut: &mut TestCtx = &mut ctx.get_mut().unwrap();
            ctx_mut.value = 100;
        }

        let ctx_ref: &TestCtx = &ctx.get_ref().unwrap();
        assert_eq!(ctx_ref.value, 100);
    }

    #[test]
    fn insert_returns_previous_value_of_same_type() {
        let ctx = PluginContext::default();
        assert!(ctx.insert(TestCtx { value: 1 }).is_none());

        let previous = ctx
            .insert(TestCtx { value: 2 })
            .expect("second insert should return the replaced value");
        assert_eq!(previous.value, 1);

        let current = ctx.get_ref::<TestCtx>().unwrap();
        assert_eq!(current.value, 2);
    }

    #[test]
    fn contains_tracks_inserted_types_only() {
        let ctx = PluginContext::default();
        assert!(!ctx.contains::<TestCtx>());

        ctx.insert(TestCtx { value: 7 });
        assert!(ctx.contains::<TestCtx>());
        assert!(!ctx.contains::<u64>());
    }

    #[test]
    fn missing_entries_yield_none() {
        let ctx = PluginContext::default();
        assert!(ctx.get_ref::<TestCtx>().is_none());
        assert!(ctx.get_mut::<TestCtx>().is_none());
    }
}