Skip to main content

open_feature/api/
api.rs

1use std::sync::OnceLock;
2
3use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
4
5use crate::{
6    provider::{FeatureProvider, ProviderMetadata},
7    Client, EvaluationContext, Hook, HookWrapper,
8};
9
10use super::{
11    global_evaluation_context::GlobalEvaluationContext, global_hooks::GlobalHooks,
12    provider_registry::ProviderRegistry,
13};
14
15/// The singleton instance of [`OpenFeature`] struct.
16/// The client should always use this instance to access OpenFeature APIs.
17static SINGLETON: OnceLock<RwLock<OpenFeature>> = OnceLock::new();
18
19fn get_singleton() -> &'static RwLock<OpenFeature> {
20    SINGLETON.get_or_init(|| RwLock::new(OpenFeature::default()))
21}
22
23/// THE struct of the OpenFeature API.
24/// Access it via [`OpenFeature::singleton()`] or [`OpenFeature::singleton_mut()`].
25#[derive(Default)]
26pub struct OpenFeature {
27    evaluation_context: GlobalEvaluationContext,
28    hooks: GlobalHooks,
29
30    provider_registry: ProviderRegistry,
31}
32
33impl OpenFeature {
34    /// Get the singleton of [`OpenFeature`].
35    pub async fn singleton() -> RwLockReadGuard<'static, Self> {
36        get_singleton().read().await
37    }
38
39    /// Get a mutable singleton of [`OpenFeature`].
40    pub async fn singleton_mut() -> RwLockWriteGuard<'static, Self> {
41        get_singleton().write().await
42    }
43
44    /// Set the global evaluation context.
45    pub async fn set_evaluation_context(&mut self, evaluation_context: EvaluationContext) {
46        let mut context = self.evaluation_context.get_mut().await;
47
48        context.targeting_key = evaluation_context.targeting_key;
49        context.custom_fields = evaluation_context.custom_fields;
50    }
51
52    /// Set the default provider.
53    pub async fn set_provider<T: FeatureProvider>(&mut self, provider: T) {
54        self.provider_registry.set_default(provider).await;
55    }
56
57    /// Bind the given `provider` to the corresponding `name`.
58    pub async fn set_named_provider<T: FeatureProvider>(&mut self, name: &str, provider: T) {
59        self.provider_registry.set_named(name, provider).await;
60    }
61
62    /// Add a new hook to the global list of hooks.
63    pub async fn add_hook<T: Hook>(&mut self, hook: T) {
64        let mut lock = self.hooks.get_mut().await;
65        lock.push(HookWrapper::new(hook));
66    }
67
68    /// Return the metadata of default (unnamed) provider.
69    pub async fn provider_metadata(&self) -> ProviderMetadata {
70        self.provider_registry
71            .get_default()
72            .await
73            .get()
74            .metadata()
75            .clone()
76    }
77
78    /// Return the metadata of named provider (a provider bound to clients with this name).
79    pub async fn named_provider_metadata(&self, name: &str) -> Option<ProviderMetadata> {
80        self.provider_registry
81            .get_named(name)
82            .await
83            .map(|provider| provider.get().metadata().clone())
84    }
85
86    /// Create a new client with default name.
87    pub fn create_client(&self) -> Client {
88        Client::new(
89            String::default(),
90            self.evaluation_context.clone(),
91            self.hooks.clone(),
92            self.provider_registry.clone(),
93        )
94    }
95
96    /// Create a new client with specific `name`.
97    /// It will use the provider bound to this name, if any.
98    pub fn create_named_client(&self, name: &str) -> Client {
99        Client::new(
100            name.to_string(),
101            self.evaluation_context.clone(),
102            self.hooks.clone(),
103            self.provider_registry.clone(),
104        )
105    }
106
107    /// Drops all the registered providers.
108    pub async fn shutdown(&mut self) {
109        self.provider_registry.clear().await;
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::{
        provider::{MockFeatureProvider, NoOpProvider, ResolutionDetails},
        EvaluationContextFieldValue,
    };
    use mockall::predicate;
    use spec::spec;

    #[spec(
        number = "1.1.1",
        text = "The API, and any state it maintains SHOULD exist as a global singleton, even in cases wherein multiple versions of the API are present at runtime."
    )]
    #[tokio::test]
    async fn singleton_multi_thread() {
        // The metadata futures must be awaited: `let _ = fut;` drops a future
        // unpolled, so the readers would never touch the singleton concurrently.
        let reader1 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        let writer = tokio::spawn(async move {
            OpenFeature::singleton_mut()
                .await
                .set_provider(NoOpProvider::default())
                .await;
        });

        let reader2 = tokio::spawn(async move {
            let _ = OpenFeature::singleton().await.provider_metadata().await;
        });

        // Surface panics from the spawned tasks instead of discarding the JoinErrors.
        reader1.await.expect("reader1 task panicked");
        reader2.await.expect("reader2 task panicked");
        writer.await.expect("writer task panicked");

        assert_eq!(
            "No-op Provider",
            OpenFeature::singleton()
                .await
                .provider_metadata()
                .await
                .name
        );
    }

    #[spec(
        number = "1.1.2.1",
        text = "The API MUST define a provider mutator, a function to set the default provider, which accepts an API-conformant provider implementation."
    )]
    #[tokio::test]
    async fn set_provider() {
        let mut api = OpenFeature::default();
        let client = api.create_client();

        assert!(client.get_int_value("some-key", None, None).await.is_err());

        // Set the new provider and ensure the value comes from it.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(provider).await;

        assert_eq!(
            client.get_int_value("some-key", None, None).await.unwrap(),
            200
        );
    }

    #[spec(
        number = "1.1.2.2",
        text = "The provider mutator function MUST invoke the initialize function on the newly registered provider before using it to resolve flag values."
    )]
    #[tokio::test]
    async fn set_provider_invoke_initialize() {
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {}).once();

        let mut api = OpenFeature::default();
        api.set_provider(provider).await;
    }

    #[spec(
        number = "1.1.2.3",
        text = "The provider mutator function MUST invoke the shutdown function on the previously registered provider once it's no longer being used to resolve flag values."
    )]
    #[test]
    fn invoke_shutdown_on_old_provider_checked_by_type_system() {}

    #[spec(
        number = "1.1.3",
        text = "The API MUST provide a function to bind a given provider to one or more client names. If the client-name already has a bound provider, it is overwritten with the new mapping."
    )]
    #[tokio::test]
    async fn set_named_provider() {
        let mut api = OpenFeature::default();

        // Ensure the No-op provider is used.
        let client = api.create_named_client("test");
        assert!(client.get_int_value("", None, None).await.is_err());

        // Bind provider to the same name.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(30)));
        api.set_named_provider("test", provider).await;

        // Ensure the new provider is used for existing clients.
        assert_eq!(client.get_int_value("", None, None).await, Ok(30));

        // Create a new client and ensure new provider is used.
        let new_client = api.create_named_client("test");
        assert_eq!(new_client.get_int_value("", None, None).await, Ok(30));
    }

    #[spec(
        number = "1.1.5",
        text = "The API MUST provide a function for retrieving the metadata field of the configured provider."
    )]
    #[tokio::test]
    async fn provider_metadata() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;
        api.set_named_provider("test", NoOpProvider::default()).await;

        assert_eq!(api.provider_metadata().await.name, "No-op Provider");
        assert_eq!(
            api.named_provider_metadata("test").await.unwrap().name,
            "No-op Provider"
        );
        assert!(api.named_provider_metadata("invalid").await.is_none());
    }

    #[spec(
        number = "1.1.6",
        text = "The API MUST provide a function for creating a client which accepts the following options:
        * name (optional): A logical string identifier for the client."
    )]
    #[tokio::test]
    async fn get_client() {
        let mut api = OpenFeature::default();

        let mut default_provider = MockFeatureProvider::new();
        default_provider.expect_initialize().returning(|_| {});
        default_provider.expect_hooks().return_const(vec![]);
        default_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        default_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(100)));

        let mut named_provider = MockFeatureProvider::new();
        named_provider.expect_initialize().returning(|_| {});
        named_provider.expect_hooks().return_const(vec![]);
        named_provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());
        named_provider
            .expect_resolve_int_value()
            .return_const(Ok(ResolutionDetails::new(200)));

        api.set_provider(default_provider).await;
        api.set_named_provider("test", named_provider).await;

        let client = api.create_client();
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 100);

        let client = api.create_named_client("test");
        assert_eq!(client.get_int_value("key", None, None).await.unwrap(), 200);

        let client = api.create_named_client("another");
        assert_eq!(client.get_int_value("test", None, None).await.unwrap(), 100);
    }

    #[spec(
        number = "1.1.7",
        text = "The client creation function MUST NOT throw, or otherwise abnormally terminate."
    )]
    #[test]
    fn get_client_not_throw_checked_by_type_system() {}

    #[spec(
        number = "1.1.8",
        text = "The API SHOULD provide functions to set a provider and wait for the initialize function to return or throw."
    )]
    #[tokio::test]
    async fn set_provider_should_block() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.set_named_provider("named", NoOpProvider::default()).await;
    }

    #[spec(
        number = "1.6.1",
        text = "The API MUST define a shutdown function which, when called, must call the respective shutdown function on the active provider."
    )]
    #[tokio::test]
    async fn shutdown() {
        let mut api = OpenFeature::default();
        api.set_provider(NoOpProvider::default()).await;

        api.shutdown().await;
    }

    #[spec(
        number = "3.2.1.1",
        text = "The API, Client and invocation MUST have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.3",
        text = "Evaluation context MUST be merged in the order: API (global; lowest precedence) -> client -> invocation -> before hooks (highest precedence), with duplicate values being overwritten."
    )]
    #[tokio::test]
    async fn evaluation_context() {
        // Setup expectations for different evaluation contexts.
        let mut provider = MockFeatureProvider::new();
        provider.expect_initialize().returning(|_| {});
        provider.expect_hooks().return_const(vec![]);
        provider
            .expect_metadata()
            .return_const(ProviderMetadata::default());

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("global_targeting_key")
                        .with_custom_field("key", "global_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(100)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("client_targeting_key")
                        .with_custom_field("key", "client_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(200)));

        provider
            .expect_resolve_int_value()
            .with(
                predicate::eq("flag"),
                predicate::eq(
                    EvaluationContext::default()
                        .with_targeting_key("invocation_targeting_key")
                        .with_custom_field("key", "invocation_value"),
                ),
            )
            .return_const(Ok(ResolutionDetails::new(300)));

        // Register the provider.
        let mut api = OpenFeature::default();
        api.set_provider(provider).await;

        // Set global client context and ensure its values are picked up.
        let global_evaluation_context = EvaluationContext::default()
            .with_targeting_key("global_targeting_key")
            .with_custom_field("key", "global_value");

        api.set_evaluation_context(global_evaluation_context).await;

        let mut client = api.create_client();

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 100);

        // Set client evaluation context and ensure its values overwrite the global ones.
        let client_evaluation_context = EvaluationContext::default()
            .with_targeting_key("client_targeting_key")
            .with_custom_field("key", "client_value");

        client.set_evaluation_context(client_evaluation_context);

        assert_eq!(client.get_int_value("flag", None, None).await.unwrap(), 200);

        // Use invocation level evaluation context and ensure its values are used.
        let invocation_evaluation_context = EvaluationContext::default()
            .with_targeting_key("invocation_targeting_key")
            .with_custom_field("key", "invocation_value");

        assert_eq!(
            client
                .get_int_value("flag", Some(&invocation_evaluation_context), None)
                .await
                .unwrap(),
            300
        );
    }

    #[spec(
        number = "3.2.2.1",
        text = "The API MUST have a method for setting the global evaluation context."
    )]
    #[spec(
        number = "3.2.2.2",
        text = "The Client and invocation MUST NOT have a method for supplying evaluation context."
    )]
    #[spec(
        number = "3.2.4.1",
        text = "When the global evaluation context is set, the on context changed handler MUST run."
    )]
    #[test]
    fn static_context_not_applicable() {}

    #[derive(Clone, Default, Debug)]
    struct MyStruct {}

    #[tokio::test]
    async fn extended_example() {
        // Acquire an OpenFeature API instance.
        let mut api = OpenFeature::singleton_mut().await;

        // Set the default (unnamed) provider.
        api.set_provider(NoOpProvider::default()).await;

        // Create an unnamed client.
        let client = api.create_client();

        // Create an evaluation context.
        // It supports types mentioned in the specification.
        let evaluation_context = EvaluationContext::default()
            .with_targeting_key("Targeting")
            .with_custom_field("bool_key", true)
            .with_custom_field("int_key", 100)
            .with_custom_field("float_key", 3.14)
            .with_custom_field("string_key", "Hello".to_string())
            .with_custom_field("datetime_key", time::OffsetDateTime::now_utc())
            .with_custom_field(
                "struct_key",
                EvaluationContextFieldValue::Struct(Arc::new(MyStruct::default())),
            )
            .with_custom_field("another_struct_key", Arc::new(MyStruct::default()))
            .with_custom_field(
                "yet_another_struct_key",
                EvaluationContextFieldValue::new_struct(MyStruct::default()),
            );

        // This function returns a `Result`.
        // You can process it with functions provided by std.
        let is_feature_enabled = client
            .get_bool_value("SomeFlagEnabled", Some(&evaluation_context), None)
            .await
            .unwrap_or(false);

        if is_feature_enabled {
            // Let's get evaluation details.
            let _result = client
                .get_int_details("key", Some(&evaluation_context), None)
                .await;
        }
    }
}
487}