1use std::sync::OnceLock;
2
3use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
4
5use crate::{
6 provider::{FeatureProvider, ProviderMetadata},
7 Client, EvaluationContext, Hook, HookWrapper,
8};
9
10use super::{
11 global_evaluation_context::GlobalEvaluationContext, global_hooks::GlobalHooks,
12 provider_registry::ProviderRegistry,
13};
14
15static SINGLETON: OnceLock<RwLock<OpenFeature>> = OnceLock::new();
18
19fn get_singleton() -> &'static RwLock<OpenFeature> {
20 SINGLETON.get_or_init(|| RwLock::new(OpenFeature::default()))
21}
22
/// State of the OpenFeature API: the global evaluation context, the global
/// hooks, and the registry of default and named providers.
///
/// The process-wide instance is reached through `OpenFeature::singleton` /
/// `OpenFeature::singleton_mut`; `Default` builds an independent instance
/// (as the unit tests in this module do).
#[derive(Default)]
pub struct OpenFeature {
    /// Global (API-level) evaluation context; a clone is handed to each client.
    evaluation_context: GlobalEvaluationContext,
    /// Global hooks; a clone is handed to each client.
    hooks: GlobalHooks,

    /// Default and named providers. Clients receive a clone of the registry and
    /// observe providers registered after the client was created.
    provider_registry: ProviderRegistry,
}
32
33impl OpenFeature {
34 pub async fn singleton() -> RwLockReadGuard<'static, Self> {
36 get_singleton().read().await
37 }
38
39 pub async fn singleton_mut() -> RwLockWriteGuard<'static, Self> {
41 get_singleton().write().await
42 }
43
44 pub async fn set_evaluation_context(&mut self, evaluation_context: EvaluationContext) {
46 let mut context = self.evaluation_context.get_mut().await;
47
48 context.targeting_key = evaluation_context.targeting_key;
49 context.custom_fields = evaluation_context.custom_fields;
50 }
51
52 pub async fn set_provider<T: FeatureProvider>(&mut self, provider: T) {
54 self.provider_registry.set_default(provider).await;
55 }
56
57 pub async fn set_named_provider<T: FeatureProvider>(&mut self, name: &str, provider: T) {
59 self.provider_registry.set_named(name, provider).await;
60 }
61
62 pub async fn add_hook<T: Hook>(&mut self, hook: T) {
64 let mut lock = self.hooks.get_mut().await;
65 lock.push(HookWrapper::new(hook));
66 }
67
68 pub async fn provider_metadata(&self) -> ProviderMetadata {
70 self.provider_registry
71 .get_default()
72 .await
73 .get()
74 .metadata()
75 .clone()
76 }
77
78 pub async fn named_provider_metadata(&self, name: &str) -> Option<ProviderMetadata> {
80 self.provider_registry
81 .get_named(name)
82 .await
83 .map(|provider| provider.get().metadata().clone())
84 }
85
86 pub fn create_client(&self) -> Client {
88 Client::new(
89 String::default(),
90 self.evaluation_context.clone(),
91 self.hooks.clone(),
92 self.provider_registry.clone(),
93 )
94 }
95
96 pub fn create_named_client(&self, name: &str) -> Client {
99 Client::new(
100 name.to_string(),
101 self.evaluation_context.clone(),
102 self.hooks.clone(),
103 self.provider_registry.clone(),
104 )
105 }
106
107 pub async fn shutdown(&mut self) {
109 self.provider_registry.clear().await;
110 }
111}
112
#[cfg(test)]
mod tests {
    // Specification-conformance tests for the global OpenFeature API. Each
    // `#[spec]` attribute names the spec requirement the test covers.
    use std::sync::Arc;

    use super::*;
    use crate::{
        provider::{MockFeatureProvider, NoOpProvider, ResolutionDetails},
        EvaluationContextFieldValue,
    };
    use mockall::predicate;
    use spec::spec;

    #[spec(
        number = "1.1.1",
        text = "The API, and any state it maintains SHOULD exist as a global singleton, even in cases wherein multiple versions of the API are present at runtime."
    )]
    #[tokio::test]
    async fn singleton_multi_thread() {
        // Futures are lazy: the metadata reads must be awaited, otherwise the
        // reader tasks never actually access the singleton's provider registry.
        let reader1 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        let writer = tokio::spawn(async move {
            OpenFeature::singleton_mut()
                .await
                .set_provider(NoOpProvider::default())
                .await;
        });

        let reader2 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        // Surface panics from the spawned tasks instead of silently discarding them.
        reader1.await.expect("first reader task panicked");
        reader2.await.expect("second reader task panicked");
        writer.await.expect("writer task panicked");

        assert_eq!(
            "No-op Provider",
            OpenFeature::singleton()
                .await
                .provider_metadata()
                .await
                .name
        );
    }

    #[spec(
        number = "1.1.2.1",
        text = "The API MUST define a provider mutator, a function to set the default provider, which accepts an API-conformant provider implementation."
    )]
    #[tokio::test]
    async fn set_provider() {
        let mut api = OpenFeature::default();
        let client = api.create_client();

        assert!(client.get_int_value("some-key", None, None).await.is_err());

        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(provider).await;

        assert_eq!(
            client.get_int_value("some-key", None, None).await.unwrap(),
            200
        );
    }

    #[spec(
        number = "1.1.2.2",
        text = "The provider mutator function MUST invoke the initialize function on the newly registered provider before using it to resolve flag values."
    )]
    #[tokio::test]
    async fn set_provider_invoke_initialize() {
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {}).once();

        let mut api = OpenFeature::default();
        api.set_provider(provider).await;
    }

    #[spec(
        number = "1.1.2.3",
        text = "The provider mutator function MUST invoke the shutdown function on the previously registered provider once it's no longer being used to resolve flag values."
    )]
    #[test]
    fn invoke_shutdown_on_old_provider_checked_by_type_system() {}

    #[spec(
        number = "1.1.3",
        text = "The API MUST provide a function to bind a given provider to one or more client names. If the client-name already has a bound provider, it is overwritten with the new mapping."
    )]
    #[tokio::test]
    async fn set_named_provider() {
        let mut api = OpenFeature::default();

        let client = api.create_named_client("test");
        assert!(client.get_int_value("", None, None).await.is_err());

        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(30)));
        api.set_named_provider("test", provider).await;

        assert_eq!(client.get_int_value("", None, None).await, Ok(30));

        let new_client = api.create_named_client("test");
        assert_eq!(new_client.get_int_value("", None, None).await, Ok(30));
    }

    #[spec(
        number = "1.1.5",
        text = "The API MUST provide a function for retrieving the metadata field of the configured provider."
    )]
    #[tokio::test]
    async fn provider_metadata() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;
        api.set_named_provider("test", NoOpProvider::default())
            .await;

        assert_eq!(api.provider_metadata().await.name, "No-op Provider");
        assert_eq!(
            api.named_provider_metadata("test").await.unwrap().name,
            "No-op Provider"
        );
        assert!(api.named_provider_metadata("invalid").await.is_none());
    }

    #[spec(
        number = "1.1.6",
        text = "The API MUST provide a function for creating a client which accepts the following options:
            * name (optional): A logical string identifier for the client."
    )]
    #[tokio::test]
    async fn get_client() {
        let mut api = OpenFeature::default();

        let mut default_provider = MockFeatureProvider::new();
        default_provider.expect_initialize().returning(|_| {});
        default_provider.expect_hooks().return_const(vec![]);
        default_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        default_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(100)));

        let mut named_provider = MockFeatureProvider::new();
        named_provider.expect_initialize().returning(|_| {});
        named_provider.expect_hooks().return_const(vec![]);
        named_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        named_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(default_provider).await;
        api.set_named_provider("test", named_provider).await;

        let client = api.create_client();
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 100);

        let client = api.create_named_client("test");
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 200);

        let client = api.create_named_client("another");
        assert_eq!(client.get_int_value("test", None, None).await.unwrap(), 100);
    }

    #[spec(
        number = "1.1.7",
        text = "The client creation function MUST NOT throw, or otherwise abnormally terminate."
    )]
    #[test]
    fn get_client_not_throw_checked_by_type_system() {}

    #[spec(
        number = "1.1.8",
        text = "The API SHOULD provide functions to set a provider and wait for the initialize function to return or throw."
    )]
    #[tokio::test]
    async fn set_provider_should_block() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.set_named_provider("named", NoOpProvider::default())
            .await;
    }

    #[spec(
        number = "1.6.1",
        text = "The API MUST define a shutdown function which, when called, must call the respective shutdown function on the active provider."
    )]
    #[tokio::test]
    async fn shutdown() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.shutdown().await;
    }

    #[spec(
        number = "3.2.1.1",
        text = "The API, Client and invocation MUST have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.3",
        text = "Evaluation context MUST be merged in the order: API (global; lowest precedence) -> client -> invocation -> before hooks (highest precedence), with duplicate values being overwritten."
    )]
    #[tokio::test]
    async fn evaluation_context() {
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("global_targeting_key")
                        .with_custom_field("key", "global_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(100)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("client_targeting_key")
                        .with_custom_field("key", "client_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(200)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("invocation_targeting_key")
                        .with_custom_field("key", "invocation_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(300)));

        let mut api = OpenFeature::default();
        api.set_provider(provider).await;

        let global_evaluation_context = EvaluationContext::default()
            .with_targeting_key("global_targeting_key")
            .with_custom_field("key", "global_value");

        api.set_evaluation_context(global_evaluation_context).await;

        let mut client = api.create_client();

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 100);

        let client_evaluation_context = EvaluationContext::default()
            .with_targeting_key("client_targeting_key")
            .with_custom_field("key", "client_value");

        client.set_evaluation_context(client_evaluation_context);

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 200);

        let invocation_evaluation_context = EvaluationContext::default()
            .with_targeting_key("invocation_targeting_key")
            .with_custom_field("key", "invocation_value");

        assert_eq!(
            client
                .get_int_value("flag", Some(&invocation_evaluation_context), None)
                .await
                .unwrap(),
            300
        );
    }

    #[spec(
        number = "3.2.2.1",
        text = "The API MUST have a method for setting the global evaluation context."
    )]
    #[spec(
        number = "3.2.2.2",
        text = "The Client and invocation MUST NOT have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.4.1",
        text = "When the global evaluation context is set, the on context changed handler MUST run."
    )]
    #[test]
    fn static_context_not_applicable() {}

    #[derive(Clone, Default, Debug)]
    struct MyStruct {}

    #[tokio::test]
    async fn extended_example() {
        let mut api = OpenFeature::singleton_mut().await;

        api.set_provider(NoOpProvider::default()).await;

        let client = api.create_client();

        // The float value is arbitrary; avoid literals such as `3.14`, which
        // clippy's deny-by-default `approx_constant` lint rejects.
        let evaluation_context = EvaluationContext::default()
            .with_targeting_key("Targeting")
            .with_custom_field("bool_key", true)
            .with_custom_field("int_key", 100)
            .with_custom_field("float_key", 2.5)
            .with_custom_field("string_key", "Hello".to_string())
            .with_custom_field("datetime_key", time::OffsetDateTime::now_utc())
            .with_custom_field(
                "struct_key",
                EvaluationContextFieldValue::Struct(Arc::new(MyStruct::default())),
            )
            .with_custom_field("another_struct_key", Arc::new(MyStruct::default()))
            .with_custom_field(
                "yet_another_struct_key",
                EvaluationContextFieldValue::new_struct(MyStruct::default()),
            );

        let is_feature_enabled = client
            .get_bool_value("SomeFlagEnabled", Some(&evaluation_context), None)
            .await
            .unwrap_or(false);

        if is_feature_enabled {
            let _result = client
                .get_int_details("key", Some(&evaluation_context), None)
                .await;
        }
    }
}