1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7use crate::gts::BaseModkitPluginV1;
8
9pub struct GtsPluginSelector {
15 cached: RwLock<Option<Arc<str>>>,
17 resolve_lock: Mutex<()>,
19}
20
21impl Default for GtsPluginSelector {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl GtsPluginSelector {
28 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 cached: RwLock::new(None),
32 resolve_lock: Mutex::new(()),
33 }
34 }
35
36 pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
45 where
46 F: FnOnce() -> Fut,
47 Fut: Future<Output = Result<String, E>>,
48 {
49 {
51 let guard = self.cached.read();
52 if let Some(ref id) = *guard {
53 return Ok(Arc::clone(id));
54 }
55 }
56
57 let _resolve_guard = self.resolve_lock.lock().await;
59
60 {
62 let guard = self.cached.read();
63 if let Some(ref id) = *guard {
64 return Ok(Arc::clone(id));
65 }
66 }
67
68 let id_string = resolve().await?;
70 let id: Arc<str> = id_string.into();
71
72 {
73 let mut guard = self.cached.write();
74 *guard = Some(Arc::clone(&id));
75 }
76
77 Ok(id)
78 }
79
80 pub async fn reset(&self) -> bool {
84 let _resolve_guard = self.resolve_lock.lock().await;
85 let mut guard = self.cached.write();
86 guard.take().is_some()
87 }
88}
89
90#[derive(Debug, thiserror::Error)]
92pub enum ChoosePluginError {
93 #[error("invalid plugin instance content for '{gts_id}': {reason}")]
95 InvalidPluginInstance {
96 gts_id: String,
98 reason: String,
100 },
101
102 #[error("no plugin instances found for schema '{schema_id}', vendor '{vendor}'")]
104 PluginNotFound {
105 schema_id: String,
107 vendor: String,
109 },
110}
111
112pub fn choose_plugin_instance<'a, P>(
139 vendor: &str,
140 instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
141) -> Result<String, ChoosePluginError>
142where
143 P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
144{
145 let mut best: Option<(&str, i16)> = None;
146 let mut count: usize = 0;
147
148 for (gts_id, content_val) in instances {
149 count += 1;
150 let content: BaseModkitPluginV1<P> =
151 serde_json::from_value(content_val.clone()).map_err(|e| {
152 tracing::error!(
153 gts_id = %gts_id,
154 error = %e,
155 "Failed to deserialize plugin instance content"
156 );
157 ChoosePluginError::InvalidPluginInstance {
158 gts_id: gts_id.to_owned(),
159 reason: e.to_string(),
160 }
161 })?;
162
163 if content.id != gts_id {
164 return Err(ChoosePluginError::InvalidPluginInstance {
165 gts_id: gts_id.to_owned(),
166 reason: format!(
167 "content.id mismatch: expected {:?}, got {:?}",
168 gts_id, content.id
169 ),
170 });
171 }
172
173 if content.vendor != vendor {
174 continue;
175 }
176
177 match &best {
178 None => best = Some((gts_id, content.priority)),
179 Some((_, cur_priority)) => {
180 if content.priority < *cur_priority {
181 best = Some((gts_id, content.priority));
182 }
183 }
184 }
185 }
186
187 tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
188
189 best.map(|(gts_id, _)| gts_id.to_owned())
190 .ok_or_else(|| ChoosePluginError::PluginNotFound {
191 schema_id: P::SCHEMA_ID.to_owned(),
192 vendor: vendor.to_owned(),
193 })
194}
195
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    const ID_A: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1";
    const ID_B: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1";
    const ID_CONCURRENT: &str =
        "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1";

    /// Calls `get_or_init` with a resolver that increments `calls` and yields `id`.
    async fn resolve_counting(
        selector: &GtsPluginSelector,
        calls: &AtomicUsize,
        id: &'static str,
    ) -> Arc<str> {
        selector
            .get_or_init(|| async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(id.to_owned())
            })
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = AtomicUsize::new(0);

        let first = resolve_counting(&selector, &calls, ID_A).await;
        let second = resolve_counting(&selector, &calls, ID_B).await;

        assert_eq!(first, second);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = AtomicUsize::new(0);

        let first = resolve_counting(&selector, &calls, ID_A).await;
        assert_eq!(&*first, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let second = resolve_counting(&selector, &calls, ID_B).await;
        assert_eq!(&*second, ID_B);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let selector = Arc::clone(&selector);
                let calls = Arc::clone(&calls);
                tokio::spawn(async move {
                    selector
                        .get_or_init(|| async {
                            // Keep the resolver in flight so the other tasks queue up behind it.
                            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                            calls.fetch_add(1, Ordering::SeqCst);
                            Ok::<_, std::convert::Infallible>(ID_CONCURRENT.to_owned())
                        })
                        .await
                })
            })
            .collect();

        for handle in handles {
            let id = handle.await.unwrap().unwrap();
            assert_eq!(&*id, ID_CONCURRENT);
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}