atrium_common/
resolver.rs

1mod cached;
2mod throttled;
3
4pub use self::cached::CachedResolver;
5pub use self::throttled::ThrottledResolver;
6use std::future::Future;
7
8#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
9pub trait Resolver {
10    type Input: ?Sized;
11    type Output;
12    type Error;
13
14    fn resolve(
15        &self,
16        input: &Self::Input,
17    ) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>;
18}
19
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::cached::r#impl::{Cache, CacheImpl};
    use crate::types::cached::{CacheConfig, Cacheable};
    use crate::types::throttled::Throttleable;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::RwLock;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Per-key counter of how many times `MockResolver` actually ran.
    type Counts = Arc<RwLock<HashMap<String, usize>>>;

    /// Lookup sequence shared by the tests: `(key, expected value)`.
    /// `None` means the mock has no entry for the key, so resolution must fail with [`Error`].
    const INPUTS: [(&str, Option<&str>); 7] = [
        ("k1", Some("v1")),
        ("k2", Some("v2")),
        ("k2", Some("v2")),
        ("k1", Some("v1")),
        ("k3", None),
        ("k1", Some("v1")),
        ("k3", None),
    ];

    #[cfg(not(target_arch = "wasm32"))]
    async fn sleep(duration: Duration) {
        tokio::time::sleep(duration).await;
    }

    #[cfg(target_arch = "wasm32")]
    async fn sleep(duration: Duration) {
        gloo_timers::future::sleep(duration).await;
    }

    /// Error returned by `MockResolver` for keys it does not know.
    #[derive(Debug, PartialEq)]
    struct Error;

    type Result<T> = core::result::Result<T, Error>;

    /// Test double: resolves from a fixed map after a short delay and records every call.
    struct MockResolver {
        data: HashMap<String, String>,
        counts: Counts,
    }

    impl Resolver for MockResolver {
        type Input = String;
        type Output = String;
        type Error = Error;

        async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
            sleep(Duration::from_millis(10)).await;
            // Count every real resolution (hits and misses) so wrappers' effects are observable.
            *self.counts.write().await.entry(input.clone()).or_default() += 1;
            self.data.get(input).cloned().ok_or(Error)
        }
    }

    /// Builds a `MockResolver` that knows `k1 -> v1` and `k2 -> v2`.
    fn mock_resolver(counts: Counts) -> MockResolver {
        MockResolver {
            data: [
                (String::from("k1"), String::from("v1")),
                (String::from("k2"), String::from("v2")),
            ]
            .into_iter()
            .collect(),
            counts,
        }
    }

    /// Asserts `result` matches `expected`: `Some(v)` requires success with `v`,
    /// `None` requires failure with [`Error`].
    fn assert_outcome(result: Result<String>, expected: Option<&str>) {
        match expected {
            Some(value) => assert_eq!(result.expect("failed to resolve"), value),
            None => assert_eq!(result.expect_err("successfully resolved"), Error),
        }
    }

    /// Builds an expected per-key call-count map from `(key, count)` pairs.
    fn counts_of(entries: &[(&str, usize)]) -> HashMap<String, usize> {
        entries.iter().map(|&(key, count)| (key.to_string(), count)).collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_no_cached() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone());
        for (input, expected) in INPUTS {
            assert_outcome(resolver.resolve(&input.to_string()).await, expected);
        }
        // Without a wrapper, every lookup reaches the mock.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 3), ("k2", 2), ("k3", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_cached() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default()));
        for (input, expected) in INPUTS {
            assert_outcome(resolver.resolve(&input.to_string()).await, expected);
        }
        // Successful keys are resolved once; the failing key `k3` is resolved on every request.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 1), ("k2", 1), ("k3", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_cached_with_max_capacity() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone())
            .cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() }));
        for (input, expected) in INPUTS {
            assert_outcome(resolver.resolve(&input.to_string()).await, expected);
        }
        // With room for a single entry, `k1` is expected to be resolved a second time.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 2), ("k2", 1), ("k3", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_cached_with_time_to_live() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let ttl = Duration::from_millis(10);
        let resolver = mock_resolver(counts.clone())
            .cached(CacheImpl::new(CacheConfig { time_to_live: Some(ttl), ..Default::default() }));
        for _ in 0..10 {
            let result = resolver.resolve(&String::from("k1")).await;
            assert_eq!(result.expect("failed to resolve"), "v1");
        }
        // Sleep well past the TTL: waking exactly at the deadline leaves no margin for
        // timer/clock granularity, so the entry could still be live and the count stay at 1.
        sleep(ttl * 2).await;
        for _ in 0..10 {
            let result = resolver.resolve(&String::from("k1")).await;
            assert_eq!(result.expect("failed to resolve"), "v1");
        }
        assert_eq!(*counts.read().await, counts_of(&[("k1", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_throttled() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = Arc::new(mock_resolver(counts.clone()).throttled());

        // Issue every lookup concurrently on this task.
        let mut pending = Vec::new();
        for (input, expected) in INPUTS {
            let resolver = resolver.clone();
            pending.push(async move { (resolver.resolve(&input.to_string()).await, expected) });
        }
        for (result, expected) in futures::future::join_all(pending).await {
            // The throttled resolver's success value is an `Option`; `None` counts as a failure.
            assert_outcome(result.and_then(|opt| opt.ok_or(Error)), expected);
        }
        // Concurrent requests for the same key are expected to share a single resolution.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 1), ("k2", 1), ("k3", 1)]));
    }
}