Skip to main content

ntex_service/
cfg.rs

1//! Shared configuration for services
2#![allow(
3    clippy::should_implement_trait,
4    clippy::new_ret_no_self,
5    clippy::missing_panics_doc
6)]
7use std::any::{Any, TypeId};
8use std::cell::{RefCell, UnsafeCell};
9use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
10use std::{fmt, hash::Hash, hash::Hasher, marker::PhantomData, mem, ops, ptr, rc};
11
12type Key = (usize, TypeId);
13type HashMap<K, V> = std::collections::HashMap<K, V, foldhash::fast::RandomState>;
14
15thread_local! {
16    static DEFAULT_CFG: Arc<Storage> = {
17        let mut st = Arc::new(Storage::new("--", "", false, CfgContext(ptr::null())));
18        let p = Arc::as_ptr(&st);
19        Arc::get_mut(&mut st).unwrap().ctx.update(p);
20        st
21    };
22    static MAPPING: RefCell<HashMap<Key, Arc<dyn Any + Send + Sync>>> = {
23        RefCell::new(HashMap::default())
24    };
25}
26static IDX: AtomicUsize = AtomicUsize::new(0);
27const KIND_ARC: usize = 1;
28const KIND_UNMASK: usize = !KIND_ARC;
29
30pub trait Configuration: Default + Send + Sync + fmt::Debug + 'static {
31    const NAME: &'static str;
32
33    fn ctx(&self) -> &CfgContext;
34
35    fn set_ctx(&mut self, ctx: CfgContext);
36}
37
38#[derive(Debug)]
39struct Storage {
40    id: usize,
41    tag: &'static str,
42    service: &'static str,
43    ctx: CfgContext,
44    building: bool,
45    data: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
46}
47
48impl Storage {
49    fn new(
50        tag: &'static str,
51        service: &'static str,
52        building: bool,
53        ctx: CfgContext,
54    ) -> Self {
55        let id = IDX.fetch_add(1, Ordering::SeqCst);
56        Storage {
57            id,
58            ctx,
59            tag,
60            service,
61            building,
62            data: HashMap::default(),
63        }
64    }
65}
66
67#[derive(Debug)]
68pub struct CfgContext(*const Storage);
69
70unsafe impl Send for CfgContext {}
71unsafe impl Sync for CfgContext {}
72
73impl CfgContext {
74    fn update(&mut self, new_p: *const Storage) {
75        self.0 = new_p;
76    }
77
78    /// Unique id of the context.
79    pub fn id(&self) -> usize {
80        self.get_ref().id
81    }
82
83    #[inline]
84    /// Context tag.
85    pub fn tag(&self) -> &'static str {
86        self.get_ref().tag
87    }
88
89    /// Service name.
90    pub fn service(&self) -> &'static str {
91        self.get_ref().service
92    }
93
94    /// Get a reference to a configuration.
95    pub fn get<T>(&self) -> Cfg<T>
96    where
97        T: Configuration,
98    {
99        let inner: Arc<Storage> = unsafe { Arc::from_raw(self.0) };
100        let cfg = get(&inner);
101        mem::forget(inner);
102        cfg
103    }
104
105    /// Get a shared configuration.
106    pub fn shared(&self) -> SharedCfg {
107        let inner: Arc<Storage> = unsafe { Arc::from_raw(self.0) };
108        let shared = SharedCfg(inner.clone());
109        mem::forget(inner);
110        shared
111    }
112
113    fn get_ref(&self) -> &Storage {
114        unsafe { self.0.as_ref().unwrap() }
115    }
116}
117
118impl Default for CfgContext {
119    #[inline]
120    fn default() -> Self {
121        CfgContext(DEFAULT_CFG.with(Arc::as_ptr))
122    }
123}
124
125#[derive(Debug)]
126pub struct Cfg<T: Configuration>(UnsafeCell<*const T>, PhantomData<rc::Rc<T>>);
127
128impl<T: Configuration> Cfg<T> {
129    fn new(ptr: *const T) -> Self {
130        Self(UnsafeCell::new(ptr), PhantomData)
131    }
132
133    #[inline]
134    /// Unique id of the configuration.
135    pub fn id(&self) -> usize {
136        self.get_ref().ctx().id()
137    }
138
139    #[inline]
140    /// Context tag.
141    pub fn tag(&self) -> &'static str {
142        self.get_ref().ctx().tag()
143    }
144
145    /// Service name.
146    pub fn service(&self) -> &'static str {
147        self.get_ref().ctx().service()
148    }
149
150    /// Get a shared configuration.
151    pub fn shared(&self) -> SharedCfg {
152        self.get_ref().ctx().shared()
153    }
154
155    fn get_ref(&self) -> &T {
156        unsafe {
157            (*self.0.get())
158                .map_addr(|addr| addr & KIND_UNMASK)
159                .as_ref()
160                .unwrap()
161        }
162    }
163
164    #[allow(clippy::needless_pass_by_value)]
165    /// Replaces the inner value.
166    ///
167    /// # Safety
168    ///
169    /// The caller must guarantee that no references to the inner `T` value
170    /// exist at the time this function is called.
171    pub unsafe fn replace(&self, cfg: Cfg<T>) {
172        unsafe {
173            ptr::swap(self.0.get(), cfg.0.get());
174        }
175    }
176
177    #[doc(hidden)]
178    #[deprecated(since = "4.5.0")]
179    #[must_use]
180    pub fn into_static(&self) -> Cfg<T> {
181        self.ctx().get()
182    }
183}
184
185impl<T: Configuration> Drop for Cfg<T> {
186    fn drop(&mut self) {
187        unsafe {
188            let addr = (*self.0.get()).map_addr(|addr| addr & KIND_UNMASK);
189            Arc::decrement_strong_count(addr.as_ref().unwrap().ctx().0);
190
191            if ((*self.0.get()).addr() & KIND_ARC) != 0 {
192                Arc::from_raw(addr);
193            }
194        }
195    }
196}
197
198impl<T: Configuration> Clone for Cfg<T> {
199    #[inline]
200    fn clone(&self) -> Self {
201        self.ctx().get()
202    }
203}
204
205impl<'a, T: Configuration> From<&'a T> for Cfg<T> {
206    #[inline]
207    fn from(cfg: &'a T) -> Self {
208        cfg.ctx().get()
209    }
210}
211
212impl<T: Configuration> ops::Deref for Cfg<T> {
213    type Target = T;
214
215    #[inline]
216    fn deref(&self) -> &T {
217        self.get_ref()
218    }
219}
220
221impl<T: Configuration> Default for Cfg<T> {
222    #[inline]
223    fn default() -> Self {
224        SharedCfg::default().get()
225    }
226}
227
228#[derive(Clone, Debug)]
229/// Shared configuration
230pub struct SharedCfg(Arc<Storage>);
231
232#[derive(Debug)]
233pub struct SharedCfgBuilder {
234    ctx: CfgContext,
235    storage: Arc<Storage>,
236}
237
238impl Eq for SharedCfg {}
239
240impl PartialEq for SharedCfg {
241    fn eq(&self, other: &Self) -> bool {
242        ptr::from_ref(self.0.as_ref()) == ptr::from_ref(other.0.as_ref())
243    }
244}
245
246impl Hash for SharedCfg {
247    fn hash<H: Hasher>(&self, state: &mut H) {
248        self.0.id.hash(state);
249    }
250}
251
252impl SharedCfg {
253    /// Construct new configuration
254    pub fn new(tag: &'static str) -> SharedCfgBuilder {
255        SharedCfgBuilder::new(tag)
256    }
257
258    #[inline]
259    /// Get unique shared cfg id
260    pub fn id(&self) -> usize {
261        self.0.id
262    }
263
264    #[inline]
265    /// Get tag.
266    pub fn tag(&self) -> &'static str {
267        self.0.tag
268    }
269
270    /// Service name.
271    pub fn service(&self) -> &'static str {
272        self.0.service
273    }
274
275    /// Get a reference to a previously inserted on configuration.
276    ///
277    /// # Panics
278    ///
279    /// if shared config is in building stage
280    pub fn get<T>(&self) -> Cfg<T>
281    where
282        T: Configuration,
283    {
284        get(&self.0)
285    }
286}
287
288impl Default for SharedCfg {
289    #[inline]
290    fn default() -> Self {
291        Self(DEFAULT_CFG.with(Clone::clone))
292    }
293}
294
295impl<T: Configuration> From<SharedCfg> for Cfg<T> {
296    #[inline]
297    fn from(cfg: SharedCfg) -> Self {
298        cfg.get()
299    }
300}
301
302impl SharedCfgBuilder {
303    fn new(tag: &'static str) -> SharedCfgBuilder {
304        let mut storage = Arc::new(Storage::new(tag, tag, true, CfgContext::default()));
305        let ctx = CfgContext(Arc::as_ptr(&storage));
306        Arc::get_mut(&mut storage).unwrap().ctx.update(ctx.0);
307
308        SharedCfgBuilder { ctx, storage }
309    }
310
311    #[must_use]
312    /// Set service name.
313    pub fn service(mut self, name: &'static str) -> Self {
314        Arc::get_mut(&mut self.storage).unwrap().service = name;
315        self
316    }
317
318    #[must_use]
319    /// Insert a type into this configuration.
320    ///
321    /// If a config of this type already existed, it will
322    /// be replaced.
323    pub fn add<T: Configuration>(mut self, mut val: T) -> Self {
324        val.set_ctx(CfgContext(self.ctx.0));
325        Arc::get_mut(&mut self.storage)
326            .unwrap()
327            .data
328            .insert(TypeId::of::<T>(), Box::new(val));
329        self
330    }
331
332    #[must_use]
333    /// Build `SharedCfg` instance.
334    pub fn build(self) -> SharedCfg {
335        self.into()
336    }
337}
338
339impl From<SharedCfgBuilder> for SharedCfg {
340    fn from(mut cfg: SharedCfgBuilder) -> SharedCfg {
341        let st = Arc::get_mut(&mut cfg.storage).unwrap();
342        st.building = false;
343        SharedCfg(cfg.storage)
344    }
345}
346
347fn get<T>(st: &Arc<Storage>) -> Cfg<T>
348where
349    T: Configuration,
350{
351    assert!(
352        !st.building,
353        "{}: Cannot access shared config while building",
354        st.tag
355    );
356
357    // increase arc refs for storage instead of actual item
358    // CfgContext and Cfg::shared() relayes on Arc<Storage>
359    mem::forget(st.clone());
360
361    let tp = TypeId::of::<T>();
362    if let Some(arc) = st.data.get(&tp) {
363        Cfg::new(arc.as_ref().downcast_ref::<T>().unwrap())
364    } else {
365        MAPPING.with(|store| {
366            let key = (st.id, tp);
367            if let Some(arc) = store.borrow().get(&key) {
368                Cfg::new(
369                    Arc::into_raw(arc.clone())
370                        .cast::<T>()
371                        .map_addr(|addr| addr ^ KIND_ARC),
372                )
373            } else {
374                log::info!(
375                    "{}: Configuration {:?} does not exist, using default",
376                    st.tag,
377                    T::NAME
378                );
379                let mut val = T::default();
380                val.set_ctx(CfgContext(st.ctx.0));
381                let arc = Arc::new(val);
382                store.borrow_mut().insert(key, arc.clone());
383                Cfg::new(
384                    Arc::into_raw(arc)
385                        .cast::<T>()
386                        .map_addr(|addr| addr ^ KIND_ARC),
387                )
388            }
389        })
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Looking up a configuration while its storage is still owned by a
    /// builder must trip the "building" assertion. The expected message is
    /// pinned so the test cannot pass through the unrelated `panic!()` in
    /// `TestCfg::default()`.
    #[test]
    #[should_panic(expected = "Cannot access shared config while building")]
    fn access_cfg_in_building_state() {
        #[derive(Debug)]
        struct TestCfg {
            config: CfgContext,
        }
        impl TestCfg {
            fn new() -> Self {
                Self {
                    config: CfgContext::default(),
                }
            }
        }
        impl Default for TestCfg {
            fn default() -> Self {
                panic!()
            }
        }
        impl Configuration for TestCfg {
            const NAME: &'static str = "TEST";
            fn ctx(&self) -> &CfgContext {
                &self.config
            }
            fn set_ctx(&mut self, ctx: CfgContext) {
                // Called from `SharedCfgBuilder::add`, i.e. while building.
                let _ = ctx.shared().get::<TestCfg>();
                self.config = ctx;
            }
        }
        let _ = TestCfg::new().ctx();
        let _ = SharedCfg::new("TEST").add(TestCfg::new());
    }

    /// Exercises explicit and default configurations, handle cloning, context
    /// lookups and `Cfg::replace`.
    #[test]
    fn shared_cfg() {
        #[derive(Default, Debug)]
        struct TestCfg {
            config: CfgContext,
        }
        impl Configuration for TestCfg {
            const NAME: &'static str = "TEST";
            fn ctx(&self) -> &CfgContext {
                &self.config
            }
            fn set_ctx(&mut self, ctx: CfgContext) {
                self.config = ctx;
            }
        }

        // Explicitly added configuration.
        let cfg: SharedCfg = SharedCfg::new("TEST")
            .add(TestCfg::default())
            .service("SVC")
            .into();

        assert_eq!(cfg.tag(), "TEST");
        assert_eq!(cfg.service(), "SVC");
        let t = cfg.get::<TestCfg>();
        assert_eq!(t.tag(), "TEST");
        assert_eq!(t.service(), "SVC");
        assert_eq!(t.shared(), cfg);

        // Default shared configuration.
        let t: Cfg<TestCfg> = Cfg::default();
        assert_eq!(t.tag(), "--");
        assert_eq!(t.service(), "");
        assert_eq!(t.ctx().id(), t.id());

        let t: Cfg<TestCfg> = t.ctx().get();
        assert_eq!(t.tag(), "--");
        assert_eq!(t.ctx().id(), t.id());

        // Missing configuration falls back to a default instance; the handle
        // keeps the storage usable after the `SharedCfg` is dropped.
        let cfg = SharedCfg::new("TEST2").build();
        let t = cfg.get::<TestCfg>();
        assert_eq!(t.tag(), "TEST2");
        assert_eq!(t.id(), cfg.id());
        drop(cfg);

        let cfg2 = t.shared();
        let t2 = cfg2.get::<TestCfg>();
        assert_eq!(t2.tag(), "TEST2");
        assert_eq!(t2.id(), cfg2.id());
        unsafe { t2.replace(SharedCfg::from(SharedCfg::new("TEST3")).get::<TestCfg>()) };

        let cfg2 = t2.shared();
        let t3 = cfg2.get::<TestCfg>();
        assert_eq!(t3.tag(), "TEST3");
        assert_eq!(t3.id(), cfg2.id());

        // Clones and context lookups resolve to the same storage.
        let t = SharedCfg::from(SharedCfg::new("TEST4").add(TestCfg::default()))
            .get::<TestCfg>();
        let cfg = t.shared();
        assert_eq!(t.id(), cfg.id());
        let t2 = t.clone();
        assert_eq!(t2.id(), cfg.id());
        assert_eq!(t2.tag(), "TEST4");

        let t3 = t.ctx().get::<TestCfg>();
        assert_eq!(t3.id(), cfg.id());
        assert_eq!(t3.tag(), "TEST4");
    }
}