1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7use crate::gts::BaseModkitPluginV1;
8
9pub struct GtsPluginSelector {
15 cached: RwLock<Option<Arc<str>>>,
17 resolve_lock: Mutex<()>,
19}
20
21impl Default for GtsPluginSelector {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl GtsPluginSelector {
28 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 cached: RwLock::new(None),
32 resolve_lock: Mutex::new(()),
33 }
34 }
35
36 pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
45 where
46 F: FnOnce() -> Fut,
47 Fut: Future<Output = Result<String, E>>,
48 {
49 {
51 let guard = self.cached.read();
52 if let Some(ref id) = *guard {
53 return Ok(Arc::clone(id));
54 }
55 }
56
57 let _resolve_guard = self.resolve_lock.lock().await;
59
60 {
62 let guard = self.cached.read();
63 if let Some(ref id) = *guard {
64 return Ok(Arc::clone(id));
65 }
66 }
67
68 let id_string = resolve().await?;
70 let id: Arc<str> = id_string.into();
71
72 {
73 let mut guard = self.cached.write();
74 *guard = Some(Arc::clone(&id));
75 }
76
77 Ok(id)
78 }
79
80 pub async fn reset(&self) -> bool {
84 let _resolve_guard = self.resolve_lock.lock().await;
85 let mut guard = self.cached.write();
86 guard.take().is_some()
87 }
88}
89
90#[derive(Debug, thiserror::Error)]
92pub enum ChoosePluginError {
93 #[error("invalid plugin instance content for '{gts_id}': {reason}")]
95 InvalidPluginInstance {
96 gts_id: String,
98 reason: String,
100 },
101
102 #[error("no plugin instances found for vendor '{vendor}'")]
104 PluginNotFound {
105 vendor: String,
107 },
108}
109
110pub fn choose_plugin_instance<'a, P>(
137 vendor: &str,
138 instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
139) -> Result<String, ChoosePluginError>
140where
141 P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
142{
143 let mut best: Option<(&str, i16)> = None;
144 let mut count: usize = 0;
145
146 for (gts_id, content_val) in instances {
147 count += 1;
148 let content: BaseModkitPluginV1<P> =
149 serde_json::from_value(content_val.clone()).map_err(|e| {
150 tracing::error!(
151 gts_id = %gts_id,
152 error = %e,
153 "Failed to deserialize plugin instance content"
154 );
155 ChoosePluginError::InvalidPluginInstance {
156 gts_id: gts_id.to_owned(),
157 reason: e.to_string(),
158 }
159 })?;
160
161 if content.id != gts_id {
162 return Err(ChoosePluginError::InvalidPluginInstance {
163 gts_id: gts_id.to_owned(),
164 reason: format!(
165 "content.id mismatch: expected {:?}, got {:?}",
166 gts_id, content.id
167 ),
168 });
169 }
170
171 if content.vendor != vendor {
172 continue;
173 }
174
175 match &best {
176 None => best = Some((gts_id, content.priority)),
177 Some((_, cur_priority)) => {
178 if content.priority < *cur_priority {
179 best = Some((gts_id, content.priority));
180 }
181 }
182 }
183 }
184
185 tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
186
187 best.map(|(gts_id, _)| gts_id.to_owned())
188 .ok_or_else(|| ChoosePluginError::PluginNotFound {
189 vendor: vendor.to_owned(),
190 })
191}
192
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    const ID_A: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1";
    const ID_B: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1";
    const ID_CONCURRENT: &str =
        "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1";

    /// Resolver body shared by the sequential tests: records one invocation in
    /// `calls`, then yields `id`.
    async fn counted(calls: Arc<AtomicUsize>, id: &'static str) -> Result<String, Infallible> {
        calls.fetch_add(1, Ordering::SeqCst);
        Ok(id.to_owned())
    }

    /// A second `get_or_init` must return the first result without running its resolver.
    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let first = selector
            .get_or_init(|| counted(Arc::clone(&calls), ID_A))
            .await
            .unwrap();
        let second = selector
            .get_or_init(|| counted(Arc::clone(&calls), ID_B))
            .await
            .unwrap();

        assert_eq!(first, second);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// After `reset`, the next `get_or_init` must run its resolver again.
    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let first = selector
            .get_or_init(|| counted(Arc::clone(&calls), ID_A))
            .await
            .unwrap();
        assert_eq!(&*first, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        assert!(selector.reset().await);

        let second = selector
            .get_or_init(|| counted(Arc::clone(&calls), ID_B))
            .await
            .unwrap();
        assert_eq!(&*second, ID_B);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// Ten tasks racing on an empty selector must share a single resolution.
    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let selector = Arc::clone(&selector);
                let calls = Arc::clone(&calls);
                tokio::spawn(async move {
                    selector
                        .get_or_init(|| async move {
                            // Keep the resolver slow so the other tasks pile up on the lock.
                            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                            calls.fetch_add(1, Ordering::SeqCst);
                            Ok::<_, Infallible>(ID_CONCURRENT.to_owned())
                        })
                        .await
                })
            })
            .collect();

        for handle in handles {
            let id = handle.await.unwrap().unwrap();
            assert_eq!(&*id, ID_CONCURRENT);
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}