1use lazy_static::lazy_static;
2use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::{
5 provider::{FeatureProvider, ProviderMetadata},
6 Client, EvaluationContext, Hook, HookWrapper,
7};
8
9use super::{
10 global_evaluation_context::GlobalEvaluationContext, global_hooks::GlobalHooks,
11 provider_registry::ProviderRegistry,
12};
13
lazy_static! {
    /// Process-wide [`OpenFeature`] instance, lazily built from `OpenFeature::default()` and
    /// guarded by an async `RwLock` so it can be shared between tasks.
    static ref SINGLETON: RwLock<OpenFeature> = RwLock::new(OpenFeature::default());
}
19
/// Entry point of the API: holds the global state that is handed to every client it creates.
///
/// A shared instance is reachable through `OpenFeature::singleton()` /
/// `OpenFeature::singleton_mut()`; independent instances can be built with `Default`.
#[derive(Default)]
pub struct OpenFeature {
    /// Global evaluation context; a clone of this handle is passed to each created client.
    evaluation_context: GlobalEvaluationContext,
    /// Global hooks; a clone of this handle is passed to each created client.
    hooks: GlobalHooks,

    /// Registry of the default provider and the providers bound to client names.
    provider_registry: ProviderRegistry,
}
29
30impl OpenFeature {
31 pub async fn singleton() -> RwLockReadGuard<'static, Self> {
33 SINGLETON.read().await
34 }
35
36 pub async fn singleton_mut() -> RwLockWriteGuard<'static, Self> {
38 SINGLETON.write().await
39 }
40
41 pub async fn set_evaluation_context(&mut self, evaluation_context: EvaluationContext) {
43 let mut context = self.evaluation_context.get_mut().await;
44
45 context.targeting_key = evaluation_context.targeting_key;
46 context.custom_fields = evaluation_context.custom_fields;
47 }
48
49 pub async fn set_provider<T: FeatureProvider>(&mut self, provider: T) {
51 self.provider_registry.set_default(provider).await;
52 }
53
54 pub async fn set_named_provider<T: FeatureProvider>(&mut self, name: &str, provider: T) {
56 self.provider_registry.set_named(name, provider).await;
57 }
58
59 pub async fn add_hook<T: Hook>(&mut self, hook: T) {
61 let mut lock = self.hooks.get_mut().await;
62 lock.push(HookWrapper::new(hook));
63 }
64
65 pub async fn provider_metadata(&self) -> ProviderMetadata {
67 self.provider_registry
68 .get_default()
69 .await
70 .get()
71 .metadata()
72 .clone()
73 }
74
75 pub async fn named_provider_metadata(&self, name: &str) -> Option<ProviderMetadata> {
77 self.provider_registry
78 .get_named(name)
79 .await
80 .map(|provider| provider.get().metadata().clone())
81 }
82
83 pub fn create_client(&self) -> Client {
85 Client::new(
86 String::default(),
87 self.evaluation_context.clone(),
88 self.hooks.clone(),
89 self.provider_registry.clone(),
90 )
91 }
92
93 pub fn create_named_client(&self, name: &str) -> Client {
96 Client::new(
97 name.to_string(),
98 self.evaluation_context.clone(),
99 self.hooks.clone(),
100 self.provider_registry.clone(),
101 )
102 }
103
104 pub async fn shutdown(&mut self) {
106 self.provider_registry.clear().await;
107 }
108}
109
/// Unit tests for [`OpenFeature`], annotated with the specification requirements they cover.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{
        provider::{MockFeatureProvider, NoOpProvider, ResolutionDetails},
        EvaluationContextFieldValue,
    };
    use mockall::predicate;
    use spec::spec;

    #[spec(
        number = "1.1.1",
        text = "The API, and any state it maintains SHOULD exist as a global singleton, even in cases wherein multiple versions of the API are present at runtime."
    )]
    #[tokio::test]
    async fn singleton_multi_thread() {
        // The metadata future must be awaited: an un-awaited future is dropped without ever
        // being polled, so the reader would never actually touch the provider registry.
        let reader1 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        let writer = tokio::spawn(async move {
            OpenFeature::singleton_mut()
                .await
                .set_provider(NoOpProvider::default())
                .await;
        });

        let reader2 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        // Surface panics from the spawned tasks instead of silently discarding the join
        // results, otherwise a failing task would not fail the test.
        reader1.await.expect("reader1 task panicked");
        reader2.await.expect("reader2 task panicked");
        writer.await.expect("writer task panicked");

        assert_eq!(
            "No-op Provider",
            OpenFeature::singleton()
                .await
                .provider_metadata()
                .await
                .name
        );
    }

    #[spec(
        number = "1.1.2.1",
        text = "The API MUST define a provider mutator, a function to set the default provider, which accepts an API-conformant provider implementation."
    )]
    #[tokio::test]
    async fn set_provider() {
        let mut api = OpenFeature::default();
        let client = api.create_client();

        // Before a provider is set, resolving a value through the client fails.
        assert!(client.get_int_value("some-key", None, None).await.is_err());

        // Register a mock provider that always resolves integers to 200.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(provider).await;

        // The client created earlier now resolves through the newly set provider.
        assert_eq!(
            client.get_int_value("some-key", None, None).await.unwrap(),
            200
        );
    }

    #[spec(
        number = "1.1.2.2",
        text = "The provider mutator function MUST invoke the initialize function on the newly registered provider before using it to resolve flag values."
    )]
    #[tokio::test]
    async fn set_provider_invoke_initialize() {
        // The mock verifies that `initialize` is called exactly once.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {}).once();

        let mut api = OpenFeature::default();
        api.set_provider(provider).await;
    }

    #[spec(
        number = "1.1.2.3",
        text = "The provider mutator function MUST invoke the shutdown function on the previously registered provider once it's no longer being used to resolve flag values."
    )]
    #[test]
    fn invoke_shutdown_on_old_provider_checked_by_type_system() {}

    #[spec(
        number = "1.1.3",
        text = "The API MUST provide a function to bind a given provider to one or more client names. If the client-name already has a bound provider, it is overwritten with the new mapping."
    )]
    #[tokio::test]
    async fn set_named_provider() {
        let mut api = OpenFeature::default();

        // Before a provider is bound to the name, resolution through the named client fails.
        let client = api.create_named_client("test");
        assert!(client.get_int_value("", None, None).await.is_err());

        // Bind a mock provider resolving integers to 30 to the name "test".
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(30)));
        api.set_named_provider("test", provider).await;

        // The client created before the binding picks up the new provider.
        assert_eq!(client.get_int_value("", None, None).await, Ok(30));

        // A client created after the binding resolves through the same provider.
        let new_client = api.create_named_client("test");
        assert_eq!(new_client.get_int_value("", None, None).await, Ok(30));
    }

    #[spec(
        number = "1.1.5",
        text = "The API MUST provide a function for retrieving the metadata field of the configured provider."
    )]
    #[tokio::test]
    async fn provider_metadata() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;
        api.set_named_provider("test", NoOpProvider::default())
            .await;

        assert_eq!(api.provider_metadata().await.name, "No-op Provider");
        assert_eq!(
            api.named_provider_metadata("test").await.unwrap().name,
            "No-op Provider"
        );
        // A name without a bound provider yields no metadata.
        assert!(api.named_provider_metadata("invalid").await.is_none());
    }

    #[spec(
        number = "1.1.6",
        text = "The API MUST provide a function for creating a client which accepts the following options:
 * name (optional): A logical string identifier for the client."
    )]
    #[tokio::test]
    async fn get_client() {
        let mut api = OpenFeature::default();

        // Default provider resolves integers to 100.
        let mut default_provider = MockFeatureProvider::new();
        default_provider.expect_initialize().returning(|_| {});
        default_provider.expect_hooks().return_const(vec![]);
        default_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        default_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(100)));

        // Provider bound to the name "test" resolves integers to 200.
        let mut named_provider = MockFeatureProvider::new();
        named_provider.expect_initialize().returning(|_| {});
        named_provider.expect_hooks().return_const(vec![]);
        named_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        named_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(default_provider).await;
        api.set_named_provider("test", named_provider).await;

        let client = api.create_client();
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 100);

        let client = api.create_named_client("test");
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 200);

        // A client whose name has no bound provider resolves through the default provider.
        let client = api.create_named_client("another");
        assert_eq!(client.get_int_value("test", None, None).await.unwrap(), 100);
    }

    #[spec(
        number = "1.1.7",
        text = "The client creation function MUST NOT throw, or otherwise abnormally terminate."
    )]
    #[test]
    fn get_client_not_throw_checked_by_type_system() {}

    #[spec(
        number = "1.1.8",
        text = "The API SHOULD provide functions to set a provider and wait for the initialize function to return or throw."
    )]
    #[tokio::test]
    async fn set_provider_should_block() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.set_named_provider("named", NoOpProvider::default())
            .await;
    }

    #[spec(
        number = "1.6.1",
        text = "The API MUST define a shutdown function which, when called, must call the respective shutdown function on the active provider."
    )]
    #[tokio::test]
    async fn shutdown() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.shutdown().await;
    }

    #[spec(
        number = "3.2.1.1",
        text = "The API, Client and invocation MUST have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.3",
        text = "Evaluation context MUST be merged in the order: API (global; lowest precedence) -> client -> invocation -> before hooks (highest precedence), with duplicate values being overwritten."
    )]
    #[tokio::test]
    async fn evaluation_context() {
        // The mock returns a different value depending on which evaluation context reaches
        // the provider, so each assertion below identifies the context that won the merge.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("global_targeting_key")
                        .with_custom_field("key", "global_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(100)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("client_targeting_key")
                        .with_custom_field("key", "client_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(200)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("invocation_targeting_key")
                        .with_custom_field("key", "invocation_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(300)));

        let mut api = OpenFeature::default();
        api.set_provider(provider).await;

        // Only the global context is set: the provider sees the global values.
        let global_evaluation_context = EvaluationContext::default()
            .with_targeting_key("global_targeting_key")
            .with_custom_field("key", "global_value");

        api.set_evaluation_context(global_evaluation_context).await;

        let mut client = api.create_client();

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 100);

        // The client context overrides the global one.
        let client_evaluation_context = EvaluationContext::default()
            .with_targeting_key("client_targeting_key")
            .with_custom_field("key", "client_value");

        client.set_evaluation_context(client_evaluation_context);

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 200);

        // The invocation context overrides both the client and the global ones.
        let invocation_evaluation_context = EvaluationContext::default()
            .with_targeting_key("invocation_targeting_key")
            .with_custom_field("key", "invocation_value");

        assert_eq!(
            client
                .get_int_value("flag", Some(&invocation_evaluation_context), None)
                .await
                .unwrap(),
            300
        );
    }

    #[spec(
        number = "3.2.2.1",
        text = "The API MUST have a method for setting the global evaluation context."
    )]
    #[spec(
        number = "3.2.2.2",
        text = "The Client and invocation MUST NOT have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.4.1",
        text = "When the global evaluation context is set, the on context changed handler MUST run."
    )]
    #[test]
    fn static_context_not_applicable() {}

    /// Empty placeholder type used as a struct-valued custom field in `extended_example`.
    #[derive(Clone, Default, Debug)]
    struct MyStruct {}

    #[tokio::test]
    async fn extended_example() {
        // Take the write guard on the global API instance.
        let mut api = OpenFeature::singleton_mut().await;

        // Set the default provider.
        api.set_provider(NoOpProvider::default()).await;

        // Create an unnamed client.
        let client = api.create_client();

        // Build an evaluation context exercising every supported custom field type,
        // including the three ways of supplying a struct value.
        let evaluation_context = EvaluationContext::default()
            .with_targeting_key("Targeting")
            .with_custom_field("bool_key", true)
            .with_custom_field("int_key", 100)
            .with_custom_field("float_key", 3.14)
            .with_custom_field("string_key", "Hello".to_string())
            .with_custom_field("datetime_key", time::OffsetDateTime::now_utc())
            .with_custom_field(
                "struct_key",
                EvaluationContextFieldValue::Struct(Arc::new(MyStruct::default())),
            )
            .with_custom_field("another_struct_key", Arc::new(MyStruct::default()))
            .with_custom_field(
                "yet_another_struct_key",
                EvaluationContextFieldValue::new_struct(MyStruct::default()),
            );

        // Resolve a boolean flag, falling back to `false` when resolution fails.
        let is_feature_enabled = client
            .get_bool_value("SomeFlagEnabled", Some(&evaluation_context), None)
            .await
            .unwrap_or(false);

        if is_feature_enabled {
            // Fetch the full evaluation details of another flag.
            let _result = client
                .get_int_details("key", Some(&evaluation_context), None)
                .await;
        }
    }
}