1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7pub struct GtsPluginSelector {
13 cached: RwLock<Option<Arc<str>>>,
15 resolve_lock: Mutex<()>,
17}
18
19impl Default for GtsPluginSelector {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl GtsPluginSelector {
26 #[must_use]
27 pub fn new() -> Self {
28 Self {
29 cached: RwLock::new(None),
30 resolve_lock: Mutex::new(()),
31 }
32 }
33
34 pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
43 where
44 F: FnOnce() -> Fut,
45 Fut: Future<Output = Result<String, E>>,
46 {
47 {
49 let guard = self.cached.read();
50 if let Some(ref id) = *guard {
51 return Ok(Arc::clone(id));
52 }
53 }
54
55 let _resolve_guard = self.resolve_lock.lock().await;
57
58 {
60 let guard = self.cached.read();
61 if let Some(ref id) = *guard {
62 return Ok(Arc::clone(id));
63 }
64 }
65
66 let id_string = resolve().await?;
68 let id: Arc<str> = id_string.into();
69
70 {
71 let mut guard = self.cached.write();
72 *guard = Some(Arc::clone(&id));
73 }
74
75 Ok(id)
76 }
77
78 pub async fn reset(&self) -> bool {
82 let _resolve_guard = self.resolve_lock.lock().await;
83 let mut guard = self.cached.write();
84 guard.take().is_some()
85 }
86}
87
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    const ID_A: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1";
    const ID_B: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1";
    const ID_CONCURRENT: &str =
        "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1";

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = Arc::clone(&calls);
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, Infallible>(ID_A.to_owned())
            })
            .await
            .unwrap();

        let calls_b = Arc::clone(&calls);
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, Infallible>(ID_B.to_owned())
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        // The first resolver's value is the one that stays cached.
        assert_eq!(&*id_b, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = Arc::clone(&calls);
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, Infallible>(ID_A.to_owned())
            })
            .await;
        assert_eq!(&*id_a.unwrap(), ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = Arc::clone(&calls);
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, Infallible>(ID_B.to_owned())
            })
            .await;
        assert_eq!(&*id_b.unwrap(), ID_B);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn reset_without_cached_id_returns_false() {
        let selector = GtsPluginSelector::default();
        assert!(!selector.reset().await);
    }

    #[tokio::test]
    async fn failed_resolve_is_not_cached() {
        let selector = GtsPluginSelector::new();

        let err = selector
            .get_or_init(|| async { Err::<String, _>("resolver failed") })
            .await
            .unwrap_err();
        assert_eq!(err, "resolver failed");
        assert!(
            !selector.reset().await,
            "a failed resolution must not populate the cache"
        );

        // A later call retries with its own resolver and succeeds.
        let id = selector
            .get_or_init(|| async { Ok::<_, &str>(ID_A.to_owned()) })
            .await
            .unwrap();
        assert_eq!(&*id, ID_A);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        // Keep the resolver pending long enough for the other
                        // tasks to queue up behind the resolve lock.
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, Infallible>(ID_CONCURRENT.to_owned())
                    })
                    .await
            }));
        }

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        for id in &results {
            assert_eq!(&**id, ID_CONCURRENT);
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}