1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use serde::de::DeserializeOwned;
6use tokio::sync::Mutex;
7
8use crate::gts::BaseModkitPluginV1;
9
10pub struct GtsPluginSelector {
16 cached: RwLock<Option<Arc<str>>>,
18 resolve_lock: Mutex<()>,
20}
21
22impl Default for GtsPluginSelector {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl GtsPluginSelector {
29 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 cached: RwLock::new(None),
33 resolve_lock: Mutex::new(()),
34 }
35 }
36
37 pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
46 where
47 F: FnOnce() -> Fut,
48 Fut: Future<Output = Result<String, E>>,
49 {
50 {
52 let guard = self.cached.read();
53 if let Some(ref id) = *guard {
54 return Ok(Arc::clone(id));
55 }
56 }
57
58 let _resolve_guard = self.resolve_lock.lock().await;
60
61 {
63 let guard = self.cached.read();
64 if let Some(ref id) = *guard {
65 return Ok(Arc::clone(id));
66 }
67 }
68
69 let id_string = resolve().await?;
71 let id: Arc<str> = id_string.into();
72
73 {
74 let mut guard = self.cached.write();
75 *guard = Some(Arc::clone(&id));
76 }
77
78 Ok(id)
79 }
80
81 pub async fn reset(&self) -> bool {
85 let _resolve_guard = self.resolve_lock.lock().await;
86 let mut guard = self.cached.write();
87 guard.take().is_some()
88 }
89}
90
91#[derive(Debug, thiserror::Error)]
93pub enum ChoosePluginError {
94 #[error("invalid plugin instance content for '{gts_id}': {reason}")]
96 InvalidPluginInstance {
97 gts_id: String,
99 reason: String,
101 },
102
103 #[error("no plugin instances found for vendor '{vendor}'")]
105 PluginNotFound {
106 vendor: String,
108 },
109}
110
111pub fn choose_plugin_instance<'a, P>(
138 vendor: &str,
139 instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
140) -> Result<String, ChoosePluginError>
141where
142 P: DeserializeOwned + gts::GtsSchema,
143{
144 let mut best: Option<(&str, i16)> = None;
145 let mut count: usize = 0;
146
147 for (gts_id, content_val) in instances {
148 count += 1;
149 let content: BaseModkitPluginV1<P> =
150 serde_json::from_value(content_val.clone()).map_err(|e| {
151 tracing::error!(
152 gts_id = %gts_id,
153 error = %e,
154 "Failed to deserialize plugin instance content"
155 );
156 ChoosePluginError::InvalidPluginInstance {
157 gts_id: gts_id.to_owned(),
158 reason: e.to_string(),
159 }
160 })?;
161
162 if content.id != gts_id {
163 return Err(ChoosePluginError::InvalidPluginInstance {
164 gts_id: gts_id.to_owned(),
165 reason: format!(
166 "content.id mismatch: expected {:?}, got {:?}",
167 gts_id, content.id
168 ),
169 });
170 }
171
172 if content.vendor != vendor {
173 continue;
174 }
175
176 match &best {
177 None => best = Some((gts_id, content.priority)),
178 Some((_, cur_priority)) => {
179 if content.priority < *cur_priority {
180 best = Some((gts_id, content.priority));
181 }
182 }
183 }
184 }
185
186 tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
187
188 best.map(|(gts_id, _)| gts_id.to_owned())
189 .ok_or_else(|| ChoosePluginError::PluginNotFound {
190 vendor: vendor.to_owned(),
191 })
192}
193
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        // The cache hands out clones of one shared allocation, not copies.
        assert!(Arc::ptr_eq(&id_a, &id_b));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_a.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_b.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn reset_on_empty_selector_returns_false() {
        let selector = GtsPluginSelector::new();
        assert!(!selector.reset().await);
    }

    #[tokio::test]
    async fn failed_resolve_is_not_cached() {
        let selector = GtsPluginSelector::new();

        let err = selector
            .get_or_init(|| async { Err::<String, _>("registry unavailable") })
            .await
            .unwrap_err();
        assert_eq!(err, "registry unavailable");

        // The failure must not poison the cache: the next call resolves again.
        let id = selector
            .get_or_init(|| async {
                Ok::<_, &str>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();
        assert_eq!(
            &*id,
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, std::convert::Infallible>(
                            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
                                .to_owned(),
                        )
                    })
                    .await
            }));
        }

        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        for id in &results {
            assert_eq!(
                &**id,
                "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
            );
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}