featureflag_test/
lib.rs

1//! Test utilities for the [`featureflag`] crate.
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
use std::{collections::HashMap, ops::Deref, sync::{Arc, RwLock}};

use featureflag::{Context, Evaluator, context::ContextRef, fields::Fields};

pub use featureflag_test_macros::*;
9
10/// A test evaluator that allows setting features for testing purposes.
11pub struct TestEvaluator {
12    features: RwLock<HashMap<String, Box<dyn TestFeature>>>,
13}
14
15impl TestEvaluator {
16    /// Create a new `TestEvaluator`.
17    pub fn new() -> TestEvaluator {
18        TestEvaluator {
19            features: RwLock::new(HashMap::new()),
20        }
21    }
22
23    /// Set the state of a feature.
24    ///
25    /// The feature can be set to any value that implements `TestFeature`, which
26    /// allows for complex logic to determine if a feature is enabled. `TestFeature`
27    /// is automatically implemented for `bool`, `Option<bool>` and
28    /// `Fn(&Context) -> impl TestFeature`.
29    pub fn set_feature<T: TestFeature>(&self, feature: &str, enabled: T) {
30        self.features
31            .write()
32            .unwrap()
33            .insert(feature.to_string(), Box::new(enabled));
34    }
35
36    /// Unset a feature.
37    pub fn clear_feature(&self, feature: &str) {
38        self.features.write().unwrap().remove(feature);
39    }
40}
41
42impl Default for TestEvaluator {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl Evaluator for TestEvaluator {
49    fn is_enabled(&self, feature: &str, _context: &crate::Context) -> Option<bool> {
50        self.features
51            .read()
52            .unwrap()
53            .get(feature)
54            .and_then(|f| f.is_enabled(_context))
55    }
56
57    fn on_new_context(&self, mut context: ContextRef<'_>, fields: Fields<'_>) {
58        let fields = TestFields::new(fields);
59        context.extensions_mut().insert(fields);
60    }
61}
62
63/// A trait for types that can determine if a feature is enabled.
64pub trait TestFeature: Send + Sync + 'static {
65    /// Check if the feature is enabled.
66    fn is_enabled(&self, context: &Context) -> Option<bool>;
67}
68
69impl<T: TestFeature> TestFeature for Option<T> {
70    fn is_enabled(&self, context: &Context) -> Option<bool> {
71        self.as_ref()?.is_enabled(context)
72    }
73}
74
75impl TestFeature for bool {
76    fn is_enabled(&self, _context: &Context) -> Option<bool> {
77        Some(*self)
78    }
79}
80
81impl<F, O> TestFeature for F
82where
83    F: Fn(&Context) -> O + Send + Sync + 'static,
84    O: TestFeature,
85{
86    fn is_enabled(&self, context: &Context) -> Option<bool> {
87        self(context).is_enabled(context)
88    }
89}
90
91/// Extension type for [`Context`] that allows access to fields set on a context
92/// when using the [`TestEvaluator`].
93///
94/// This type is not intended to be used directly. Instead, use [`TestContextExt::test_fields`]
95/// to access the fields.
96struct TestFields {
97    fields: Fields<'static>,
98}
99
100impl TestFields {
101    fn new(fields: Fields<'_>) -> TestFields {
102        // very leaky!
103
104        let fields = fields
105            .pairs()
106            .map(|(k, v)| (&*k.to_string().leak(), v.to_static()))
107            .collect::<Vec<_>>()
108            .leak();
109
110        TestFields {
111            fields: Fields::new(fields),
112        }
113    }
114}
115
116impl Deref for TestFields {
117    type Target = Fields<'static>;
118
119    fn deref(&self) -> &Self::Target {
120        &self.fields
121    }
122}
123
124/// Extension trait for [`Context`] that provides access to the fields set on
125/// the context when using the [`TestEvaluator`].
126pub trait TestContextExt {
127    /// Get the fields set on the context.
128    ///
129    /// This method will only work with contexts that have been created when
130    /// using [`TestEvaluator`].
131    fn test_fields(&self) -> Option<&Fields<'_>>;
132}
133
134impl TestContextExt for Context {
135    fn test_fields(&self) -> Option<&Fields<'_>> {
136        self.extensions()
137            .get::<TestFields>()
138            .map(|fields| fields.deref())
139    }
140}