atrium_common/
resolver.rs

1mod cached;
2mod throttled;
3
4pub use self::cached::CachedResolver;
5pub use self::throttled::ThrottledResolver;
6use std::future::Future;
7
8#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
9pub trait Resolver {
10    type Input: ?Sized;
11    type Output;
12    type Error;
13
14    fn resolve(
15        &self,
16        input: &Self::Input,
17    ) -> impl Future<Output = core::result::Result<Self::Output, Self::Error>>;
18}
19
/// Tests for [`Resolver`] implementations: a plain mock, the mock wrapped via
/// `.cached(..)` (default, capacity-limited and TTL-limited caches), and the mock
/// wrapped via `.throttled()`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::cached::r#impl::{Cache, CacheImpl};
    use crate::types::cached::{CacheConfig, Cacheable};
    use crate::types::throttled::Throttleable;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::RwLock;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Lookup sequence shared by the tests as `(input, expected)` pairs. `None`
    /// means the mock has no entry for the key, so resolution must fail. Keys
    /// repeat so that caching and throttling show up in the per-key call counts.
    const CASES: [(&str, Option<&str>); 7] = [
        ("k1", Some("v1")),
        ("k2", Some("v2")),
        ("k2", Some("v2")),
        ("k1", Some("v1")),
        ("k3", None),
        ("k1", Some("v1")),
        ("k3", None),
    ];

    #[cfg(not(target_arch = "wasm32"))]
    async fn sleep(duration: Duration) {
        tokio::time::sleep(duration).await;
    }

    #[cfg(target_arch = "wasm32")]
    async fn sleep(duration: Duration) {
        gloo_timers::future::sleep(duration).await;
    }

    /// The only error the mock resolver produces: the key has no entry.
    #[derive(Debug, PartialEq)]
    struct Error;

    type Result<T> = core::result::Result<T, Error>;

    /// Per-key count of how many times `MockResolver::resolve` actually ran.
    type Counts = Arc<RwLock<HashMap<String, usize>>>;

    /// In-memory resolver that records every call it receives in `counts`.
    struct MockResolver {
        data: HashMap<String, String>,
        counts: Counts,
    }

    impl Resolver for MockResolver {
        type Input = String;
        type Output = String;
        type Error = Error;

        async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
            // Artificial latency so concurrent callers overlap and cache hits are
            // distinguishable from real lookups.
            sleep(Duration::from_millis(10)).await;
            *self.counts.write().await.entry(input.clone()).or_default() += 1;
            self.data.get(input).cloned().ok_or(Error)
        }
    }

    /// Creates a mock that knows `k1 -> v1` and `k2 -> v2`; any other key fails.
    fn mock_resolver(counts: Counts) -> MockResolver {
        let data = [("k1", "v1"), ("k2", "v2")]
            .into_iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();
        MockResolver { data, counts }
    }

    /// Builds an expected call-count map from `(key, count)` pairs.
    fn counts_of(pairs: &[(&str, usize)]) -> HashMap<String, usize> {
        pairs.iter().map(|&(key, count)| (key.to_string(), count)).collect()
    }

    /// Resolves every entry of [`CASES`] one at a time and checks each outcome.
    async fn resolve_cases_sequentially<R>(resolver: &R)
    where
        R: Resolver<Input = String, Error = Error>,
        R::Output: PartialEq<&'static str> + std::fmt::Debug,
    {
        for (input, expected) in CASES {
            let result = resolver.resolve(&input.to_string()).await;
            match expected {
                Some(value) => assert_eq!(result.expect("failed to resolve"), value),
                None => assert_eq!(result.expect_err("successfully resolved"), Error),
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_no_cached() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone());
        resolve_cases_sequentially(&resolver).await;
        // Without a cache, every lookup reaches the mock.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 3), ("k2", 2), ("k3", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_cached() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone()).cached(CacheImpl::new(CacheConfig::default()));
        resolve_cases_sequentially(&resolver).await;
        // Successful lookups are served from the cache; the failing `k3` is not.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 1), ("k2", 1), ("k3", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_cached_with_max_capacity() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone())
            .cached(CacheImpl::new(CacheConfig { max_capacity: Some(1), ..Default::default() }));
        resolve_cases_sequentially(&resolver).await;
        // A single-entry cache must evict, so `k1` is resolved a second time.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 2), ("k2", 1), ("k3", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_cached_with_time_to_live() {
        let ttl = Duration::from_millis(10);
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = mock_resolver(counts.clone())
            .cached(CacheImpl::new(CacheConfig { time_to_live: Some(ttl), ..Default::default() }));
        let key = String::from("k1");
        for _ in 0..10 {
            assert_eq!(resolver.resolve(&key).await.expect("failed to resolve"), "v1");
        }
        // Wait well past the TTL. Sleeping for exactly `ttl` leaves the entry on
        // the expiry boundary, where the outcome depends on timer/clock granularity.
        sleep(ttl * 2).await;
        for _ in 0..10 {
            assert_eq!(resolver.resolve(&key).await.expect("failed to resolve"), "v1");
        }
        // One real lookup before expiry and one after.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 2)]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn test_throttled() {
        let counts = Arc::new(RwLock::new(HashMap::new()));
        let resolver = Arc::new(mock_resolver(counts.clone()).throttled());

        // Issue all lookups concurrently; `join_all` yields results in input order.
        let outcomes = futures::future::join_all(CASES.into_iter().map(|(input, expected)| {
            let resolver = Arc::clone(&resolver);
            async move { (resolver.resolve(&input.to_string()).await, expected) }
        }))
        .await;

        for (result, expected) in outcomes {
            // The throttled resolver wraps its output in an `Option`; fold `None`
            // into the error case so both outcomes are checked the same way.
            let result = result.and_then(|opt| opt.ok_or(Error));
            match expected {
                Some(value) => assert_eq!(result.expect("failed to resolve"), value),
                None => assert_eq!(result.expect_err("successfully resolved"), Error),
            }
        }
        // Concurrent duplicate keys reach the mock only once each.
        assert_eq!(*counts.read().await, counts_of(&[("k1", 1), ("k2", 1), ("k3", 1)]));
    }
}