1#![cfg_attr(docsrs, feature(doc_cfg))]
3
use std::{
    collections::HashMap,
    ops::Deref,
    sync::{Arc, RwLock},
};
5
6use featureflag::{Context, Evaluator, context::ContextRef, fields::Fields};
7
8pub use featureflag_test_macros::*;
9
10pub struct TestEvaluator {
12 features: RwLock<HashMap<String, Box<dyn TestFeature>>>,
13}
14
15impl TestEvaluator {
16 pub fn new() -> TestEvaluator {
18 TestEvaluator {
19 features: RwLock::new(HashMap::new()),
20 }
21 }
22
23 pub fn set_feature<T: TestFeature>(&self, feature: &str, enabled: T) {
30 self.features
31 .write()
32 .unwrap()
33 .insert(feature.to_string(), Box::new(enabled));
34 }
35
36 pub fn clear_feature(&self, feature: &str) {
38 self.features.write().unwrap().remove(feature);
39 }
40}
41
42impl Default for TestEvaluator {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl Evaluator for TestEvaluator {
49 fn is_enabled(&self, feature: &str, _context: &crate::Context) -> Option<bool> {
50 self.features
51 .read()
52 .unwrap()
53 .get(feature)
54 .and_then(|f| f.is_enabled(_context))
55 }
56
57 fn on_new_context(&self, mut context: ContextRef<'_>, fields: Fields<'_>) {
58 let fields = TestFields::new(fields);
59 context.extensions_mut().insert(fields);
60 }
61}
62
63pub trait TestFeature: Send + Sync + 'static {
65 fn is_enabled(&self, context: &Context) -> Option<bool>;
67}
68
69impl<T: TestFeature> TestFeature for Option<T> {
70 fn is_enabled(&self, context: &Context) -> Option<bool> {
71 self.as_ref()?.is_enabled(context)
72 }
73}
74
75impl TestFeature for bool {
76 fn is_enabled(&self, _context: &Context) -> Option<bool> {
77 Some(*self)
78 }
79}
80
81impl<F, O> TestFeature for F
82where
83 F: Fn(&Context) -> O + Send + Sync + 'static,
84 O: TestFeature,
85{
86 fn is_enabled(&self, context: &Context) -> Option<bool> {
87 self(context).is_enabled(context)
88 }
89}
90
91struct TestFields {
97 fields: Fields<'static>,
98}
99
100impl TestFields {
101 fn new(fields: Fields<'_>) -> TestFields {
102 let fields = fields
105 .pairs()
106 .map(|(k, v)| (&*k.to_string().leak(), v.to_static()))
107 .collect::<Vec<_>>()
108 .leak();
109
110 TestFields {
111 fields: Fields::new(fields),
112 }
113 }
114}
115
116impl Deref for TestFields {
117 type Target = Fields<'static>;
118
119 fn deref(&self) -> &Self::Target {
120 &self.fields
121 }
122}
123
124pub trait TestContextExt {
127 fn test_fields(&self) -> Option<&Fields<'_>>;
132}
133
134impl TestContextExt for Context {
135 fn test_fields(&self) -> Option<&Fields<'_>> {
136 self.extensions()
137 .get::<TestFields>()
138 .map(|fields| fields.deref())
139 }
140}