open_feature/api/
api.rs

1use lazy_static::lazy_static;
2use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::{
5    provider::{FeatureProvider, ProviderMetadata},
6    Client, EvaluationContext, Hook, HookWrapper,
7};
8
9use super::{
10    global_evaluation_context::GlobalEvaluationContext, global_hooks::GlobalHooks,
11    provider_registry::ProviderRegistry,
12};
13
14lazy_static! {
15    /// The singleton instance of [`OpenFeature`] struct.
16    /// The client should always use this instance to access OpenFeature APIs.
17    static ref SINGLETON: RwLock<OpenFeature> = RwLock::new(OpenFeature::default());
18}
19
20/// THE struct of the OpenFeature API.
21/// Access it via the [`SINGLETON`] instance.
22#[derive(Default)]
23pub struct OpenFeature {
24    evaluation_context: GlobalEvaluationContext,
25    hooks: GlobalHooks,
26
27    provider_registry: ProviderRegistry,
28}
29
30impl OpenFeature {
31    /// Get the singleton of [`OpenFeature`].
32    pub async fn singleton() -> RwLockReadGuard<'static, Self> {
33        SINGLETON.read().await
34    }
35
36    /// Get a mutable singleton of [`OpenFeature`].
37    pub async fn singleton_mut() -> RwLockWriteGuard<'static, Self> {
38        SINGLETON.write().await
39    }
40
41    /// Set the global evaluation context.
42    pub async fn set_evaluation_context(&mut self, evaluation_context: EvaluationContext) {
43        let mut context = self.evaluation_context.get_mut().await;
44
45        context.targeting_key = evaluation_context.targeting_key;
46        context.custom_fields = evaluation_context.custom_fields;
47    }
48
49    /// Set the default provider.
50    pub async fn set_provider<T: FeatureProvider>(&mut self, provider: T) {
51        self.provider_registry.set_default(provider).await;
52    }
53
54    /// Bind the given `provider` to the corresponding `name`.
55    pub async fn set_named_provider<T: FeatureProvider>(&mut self, name: &str, provider: T) {
56        self.provider_registry.set_named(name, provider).await;
57    }
58
59    /// Add a new hook to the global list of hooks.
60    pub async fn add_hook<T: Hook>(&mut self, hook: T) {
61        let mut lock = self.hooks.get_mut().await;
62        lock.push(HookWrapper::new(hook));
63    }
64
65    /// Return the metadata of default (unnamed) provider.
66    pub async fn provider_metadata(&self) -> ProviderMetadata {
67        self.provider_registry
68            .get_default()
69            .await
70            .get()
71            .metadata()
72            .clone()
73    }
74
75    /// Return the metadata of named provider (a provider bound to clients with this name).
76    pub async fn named_provider_metadata(&self, name: &str) -> Option<ProviderMetadata> {
77        self.provider_registry
78            .get_named(name)
79            .await
80            .map(|provider| provider.get().metadata().clone())
81    }
82
83    /// Create a new client with default name.
84    pub fn create_client(&self) -> Client {
85        Client::new(
86            String::default(),
87            self.evaluation_context.clone(),
88            self.hooks.clone(),
89            self.provider_registry.clone(),
90        )
91    }
92
93    /// Create a new client with specific `name`.
94    /// It will use the provider bound to this name, if any.
95    pub fn create_named_client(&self, name: &str) -> Client {
96        Client::new(
97            name.to_string(),
98            self.evaluation_context.clone(),
99            self.hooks.clone(),
100            self.provider_registry.clone(),
101        )
102    }
103
104    /// Drops all the registered providers.
105    pub async fn shutdown(&mut self) {
106        self.provider_registry.clear().await;
107    }
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{
        provider::{MockFeatureProvider, NoOpProvider, ResolutionDetails},
        EvaluationContextFieldValue,
    };
    use mockall::predicate;
    use spec::spec;

    #[spec(
        number = "1.1.1",
        text = "The API, and any state it maintains SHOULD exist as a global singleton, even in cases wherein multiple versions of the API are present at runtime."
    )]
    #[tokio::test]
    async fn singleton_multi_thread() {
        // `provider_metadata` is async: it must be awaited, otherwise the future is
        // dropped unpolled and the reader never actually reads the singleton.
        let reader1 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        let writer = tokio::spawn(async move {
            OpenFeature::singleton_mut()
                .await
                .set_provider(NoOpProvider::default())
                .await;
        });

        let reader2 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        // Surface panics from the spawned tasks instead of silently discarding them.
        reader1.await.expect("reader1 task panicked");
        reader2.await.expect("reader2 task panicked");
        writer.await.expect("writer task panicked");

        assert_eq!(
            "No-op Provider",
            OpenFeature::singleton()
                .await
                .provider_metadata()
                .await
                .name
        );
    }

    #[spec(
        number = "1.1.2.1",
        text = "The API MUST define a provider mutator, a function to set the default provider, which accepts an API-conformant provider implementation."
    )]
    #[tokio::test]
    async fn set_provider() {
        let mut api = OpenFeature::default();
        let client = api.create_client();

        assert!(client.get_int_value("some-key", None, None).await.is_err());

        // Set the new provider and ensure the value comes from it.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(provider).await;

        assert_eq!(
            client.get_int_value("some-key", None, None).await.unwrap(),
            200
        );
    }

    #[spec(
        number = "1.1.2.2",
        text = "The provider mutator function MUST invoke the initialize function on the newly registered provider before using it to resolve flag values."
    )]
    #[tokio::test]
    async fn set_provider_invoke_initialize() {
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {}).once();

        let mut api = OpenFeature::default();
        api.set_provider(provider).await;
    }

    #[spec(
        number = "1.1.2.3",
        text = "The provider mutator function MUST invoke the shutdown function on the previously registered provider once it's no longer being used to resolve flag values."
    )]
    #[test]
    fn invoke_shutdown_on_old_provider_checked_by_type_system() {}

    #[spec(
        number = "1.1.3",
        text = "The API MUST provide a function to bind a given provider to one or more client names. If the client-name already has a bound provider, it is overwritten with the new mapping."
    )]
    #[tokio::test]
    async fn set_named_provider() {
        let mut api = OpenFeature::default();

        // Ensure the No-op provider is used.
        let client = api.create_named_client("test");
        assert!(client.get_int_value("", None, None).await.is_err());

        // Bind provider to the same name.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(30)));
        api.set_named_provider("test", provider).await;

        // Ensure the new provider is used for existing clients.
        assert_eq!(client.get_int_value("", None, None).await, Ok(30));

        // Create a new client and ensure new provider is used.
        let new_client = api.create_named_client("test");
        assert_eq!(new_client.get_int_value("", None, None).await, Ok(30));
    }

    #[spec(
        number = "1.1.5",
        text = "The API MUST provide a function for retrieving the metadata field of the configured provider."
    )]
    #[tokio::test]
    async fn provider_metadata() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;
        api.set_named_provider("test", NoOpProvider::default())
            .await;

        assert_eq!(api.provider_metadata().await.name, "No-op Provider");
        assert_eq!(
            api.named_provider_metadata("test").await.unwrap().name,
            "No-op Provider"
        );
        assert!(api.named_provider_metadata("invalid").await.is_none());
    }

    #[spec(
        number = "1.1.6",
        text = "The API MUST provide a function for creating a client which accepts the following options:
        * name (optional): A logical string identifier for the client."
    )]
    #[tokio::test]
    async fn get_client() {
        let mut api = OpenFeature::default();

        let mut default_provider = MockFeatureProvider::new();
        default_provider.expect_initialize().returning(|_| {});
        default_provider.expect_hooks().return_const(vec![]);
        default_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        default_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(100)));

        let mut named_provider = MockFeatureProvider::new();
        named_provider.expect_initialize().returning(|_| {});
        named_provider.expect_hooks().return_const(vec![]);
        named_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        named_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(default_provider).await;
        api.set_named_provider("test", named_provider).await;

        let client = api.create_client();
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 100);

        let client = api.create_named_client("test");
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 200);

        // A name without a bound provider falls back to the default provider.
        let client = api.create_named_client("another");
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 100);
    }

    #[spec(
        number = "1.1.7",
        text = "The client creation function MUST NOT throw, or otherwise abnormally terminate."
    )]
    #[test]
    fn get_client_not_throw_checked_by_type_system() {}

    #[spec(
        number = "1.1.8",
        text = "The API SHOULD provide functions to set a provider and wait for the initialize function to return or throw."
    )]
    #[tokio::test]
    async fn set_provider_should_block() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.set_named_provider("named", NoOpProvider::default())
            .await;
    }

    #[spec(
        number = "1.6.1",
        text = "The API MUST define a shutdown function which, when called, must call the respective shutdown function on the active provider."
    )]
    #[tokio::test]
    async fn shutdown() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.shutdown().await;
    }

    #[spec(
        number = "3.2.1.1",
        text = "The API, Client and invocation MUST have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.3",
        text = "Evaluation context MUST be merged in the order: API (global; lowest precedence) -> client -> invocation -> before hooks (highest precedence), with duplicate values being overwritten."
    )]
    #[tokio::test]
    async fn evaluation_context() {
        // Setup expectations for different evaluation contexts.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("global_targeting_key")
                        .with_custom_field("key", "global_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(100)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("client_targeting_key")
                        .with_custom_field("key", "client_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(200)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("invocation_targeting_key")
                        .with_custom_field("key", "invocation_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(300)));

        // Register the provider.
        let mut api = OpenFeature::default();
        api.set_provider(provider).await;

        // Set global client context and ensure its values are picked up.
        let global_evaluation_context = EvaluationContext::default()
            .with_targeting_key("global_targeting_key")
            .with_custom_field("key", "global_value");

        api.set_evaluation_context(global_evaluation_context).await;

        let mut client = api.create_client();

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 100);

        // Set client evaluation context and ensure its values overwrite the global ones.
        let client_evaluation_context = EvaluationContext::default()
            .with_targeting_key("client_targeting_key")
            .with_custom_field("key", "client_value");

        client.set_evaluation_context(client_evaluation_context);

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 200);

        // Use invocation level evaluation context and ensure its values are used.
        let invocation_evaluation_context = EvaluationContext::default()
            .with_targeting_key("invocation_targeting_key")
            .with_custom_field("key", "invocation_value");

        assert_eq!(
            client
                .get_int_value("flag", Some(&invocation_evaluation_context), None)
                .await
                .unwrap(),
            300
        );
    }

    #[spec(
        number = "3.2.2.1",
        text = "The API MUST have a method for setting the global evaluation context."
    )]
    #[spec(
        number = "3.2.2.2",
        text = "The Client and invocation MUST NOT have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.4.1",
        text = "When the global evaluation context is set, the on context changed handler MUST run."
    )]
    #[test]
    fn static_context_not_applicable() {}

    #[derive(Clone, Default, Debug)]
    struct MyStruct {}

    #[tokio::test]
    async fn extended_example() {
        // Acquire an OpenFeature API instance.
        let mut api = OpenFeature::singleton_mut().await;

        // Set the default (unnamed) provider.
        api.set_provider(NoOpProvider::default()).await;

        // Create an unnamed client.
        let client = api.create_client();

        // Create an evaluation context.
        // It supports types mentioned in the specification.
        // (`1.5` rather than `3.14`: clippy's deny-by-default `approx_constant` rejects
        // literals that approximate PI.)
        let evaluation_context = EvaluationContext::default()
            .with_targeting_key("Targeting")
            .with_custom_field("bool_key", true)
            .with_custom_field("int_key", 100)
            .with_custom_field("float_key", 1.5)
            .with_custom_field("string_key", "Hello".to_string())
            .with_custom_field("datetime_key", time::OffsetDateTime::now_utc())
            .with_custom_field(
                "struct_key",
                EvaluationContextFieldValue::Struct(Arc::new(MyStruct::default())),
            )
            .with_custom_field("another_struct_key", Arc::new(MyStruct::default()))
            .with_custom_field(
                "yet_another_struct_key",
                EvaluationContextFieldValue::new_struct(MyStruct::default()),
            );

        // This function returns a `Result`.
        // You can process it with functions provided by std.
        let is_feature_enabled = client
            .get_bool_value("SomeFlagEnabled", Some(&evaluation_context), None)
            .await
            .unwrap_or(false);

        if is_feature_enabled {
            // Let's get evaluation details.
            let _result = client
                .get_int_details("key", Some(&evaluation_context), None)
                .await;
        }
    }
}