1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7use crate::gts::BaseModkitPluginV1;
8
9pub struct GtsPluginSelector {
15 cached: RwLock<Option<Arc<str>>>,
17 resolve_lock: Mutex<()>,
19}
20
21impl Default for GtsPluginSelector {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl GtsPluginSelector {
28 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 cached: RwLock::new(None),
32 resolve_lock: Mutex::new(()),
33 }
34 }
35
36 #[must_use]
41 pub fn pre_cached(value: String) -> Self {
42 Self {
43 cached: RwLock::new(Some(Arc::from(value))),
44 resolve_lock: Mutex::new(()),
45 }
46 }
47
48 pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
57 where
58 F: FnOnce() -> Fut,
59 Fut: Future<Output = Result<String, E>>,
60 {
61 {
63 let guard = self.cached.read();
64 if let Some(ref id) = *guard {
65 return Ok(Arc::clone(id));
66 }
67 }
68
69 let _resolve_guard = self.resolve_lock.lock().await;
71
72 {
74 let guard = self.cached.read();
75 if let Some(ref id) = *guard {
76 return Ok(Arc::clone(id));
77 }
78 }
79
80 let id_string = resolve().await?;
82 let id: Arc<str> = id_string.into();
83
84 {
85 let mut guard = self.cached.write();
86 *guard = Some(Arc::clone(&id));
87 }
88
89 Ok(id)
90 }
91
92 pub async fn reset(&self) -> bool {
96 let _resolve_guard = self.resolve_lock.lock().await;
97 let mut guard = self.cached.write();
98 guard.take().is_some()
99 }
100}
101
102#[derive(Debug, thiserror::Error)]
104pub enum ChoosePluginError {
105 #[error("invalid plugin instance content for '{gts_id}': {reason}")]
107 InvalidPluginInstance {
108 gts_id: String,
110 reason: String,
112 },
113
114 #[error("no plugin instances found for schema '{schema_id}', vendor '{vendor}'")]
116 PluginNotFound {
117 schema_id: String,
119 vendor: String,
121 },
122}
123
124pub fn choose_plugin_instance<'a, P>(
151 vendor: &str,
152 instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
153) -> Result<String, ChoosePluginError>
154where
155 P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
156{
157 let mut best: Option<(&str, i16)> = None;
158 let mut count: usize = 0;
159
160 for (gts_id, content_val) in instances {
161 count += 1;
162 let content: BaseModkitPluginV1<P> =
163 serde_json::from_value(content_val.clone()).map_err(|e| {
164 tracing::error!(
165 gts_id = %gts_id,
166 error = %e,
167 "Failed to deserialize plugin instance content"
168 );
169 ChoosePluginError::InvalidPluginInstance {
170 gts_id: gts_id.to_owned(),
171 reason: e.to_string(),
172 }
173 })?;
174
175 if content.id != gts_id {
176 return Err(ChoosePluginError::InvalidPluginInstance {
177 gts_id: gts_id.to_owned(),
178 reason: format!(
179 "content.id mismatch: expected {:?}, got {:?}",
180 gts_id, content.id
181 ),
182 });
183 }
184
185 if content.vendor != vendor {
186 continue;
187 }
188
189 match &best {
190 None => best = Some((gts_id, content.priority)),
191 Some((_, cur_priority)) => {
192 if content.priority < *cur_priority {
193 best = Some((gts_id, content.priority));
194 }
195 }
196 }
197 }
198
199 tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
200
201 best.map(|(gts_id, _)| gts_id.to_owned())
202 .ok_or_else(|| ChoosePluginError::PluginNotFound {
203 schema_id: P::SCHEMA_ID.to_owned(),
204 vendor: vendor.to_owned(),
205 })
206}
207
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_a.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_b.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        // Keep the resolver pending long enough for the other tasks to queue up.
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, std::convert::Infallible>(
                            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
                                .to_owned(),
                        )
                    })
                    .await
            }));
        }

        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        for id in &results {
            assert_eq!(
                &**id,
                "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
            );
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // A pre-cached selector must serve its value without ever invoking the resolver.
    #[tokio::test]
    async fn pre_cached_value_skips_resolve() {
        let selector = GtsPluginSelector::pre_cached(
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~pre.test._.plugin.v1".to_owned(),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~other.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        assert_eq!(
            &*id,
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~pre.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    // A failed resolution must propagate the error and leave the cache empty, so the
    // next call resolves again.
    #[tokio::test]
    async fn failed_resolve_is_not_cached() {
        let selector = GtsPluginSelector::new();

        let err = selector
            .get_or_init(|| async { Err::<String, _>("resolve failed") })
            .await
            .unwrap_err();
        assert_eq!(err, "resolve failed");

        let id = selector
            .get_or_init(|| async {
                Ok::<_, &str>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~retry.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();
        assert_eq!(
            &*id,
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~retry.test._.plugin.v1"
        );
        assert!(selector.reset().await);
    }

    // Resetting a selector that holds nothing reports that nothing was cleared.
    #[tokio::test]
    async fn reset_without_cached_value_returns_false() {
        let selector = GtsPluginSelector::new();
        assert!(!selector.reset().await);
    }
}