Skip to main content

modkit/plugins/
mod.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7use crate::gts::BaseModkitPluginV1;
8
9/// A resettable, allocation-friendly selector for GTS plugin instance IDs.
10///
11/// Uses a single-flight pattern to ensure that the resolve function is called
12/// at most once even under concurrent callers. The selected instance ID is
13/// cached as `Arc<str>` to avoid allocations on the happy path.
14pub struct GtsPluginSelector {
15    /// Cached selected instance ID (sync lock for fast access and sync reset).
16    cached: RwLock<Option<Arc<str>>>,
17    /// Mutex to ensure single-flight resolution.
18    resolve_lock: Mutex<()>,
19}
20
21impl Default for GtsPluginSelector {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl GtsPluginSelector {
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            cached: RwLock::new(None),
32            resolve_lock: Mutex::new(()),
33        }
34    }
35
36    /// Returns the cached instance ID, or resolves it using the provided function.
37    ///
38    /// Uses a single-flight pattern: even under concurrent callers, the resolve
39    /// function is called at most once. Returns `Arc<str>` to avoid allocations
40    /// on the happy path.
41    /// # Errors
42    ///
43    /// Returns `Err(E)` if the provided `resolve` future fails.
44    pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
45    where
46        F: FnOnce() -> Fut,
47        Fut: Future<Output = Result<String, E>>,
48    {
49        // Fast path: check if already cached (sync lock, no await)
50        {
51            let guard = self.cached.read();
52            if let Some(ref id) = *guard {
53                return Ok(Arc::clone(id));
54            }
55        }
56
57        // Slow path: acquire resolve lock for single-flight
58        let _resolve_guard = self.resolve_lock.lock().await;
59
60        // Re-check after acquiring resolve lock (another caller may have resolved)
61        {
62            let guard = self.cached.read();
63            if let Some(ref id) = *guard {
64                return Ok(Arc::clone(id));
65            }
66        }
67
68        // Resolve and cache
69        let id_string = resolve().await?;
70        let id: Arc<str> = id_string.into();
71
72        {
73            let mut guard = self.cached.write();
74            *guard = Some(Arc::clone(&id));
75        }
76
77        Ok(id)
78    }
79
80    /// Clears the cached selected instance ID.
81    ///
82    /// Returns `true` if there was a cached value, `false` otherwise.
83    pub async fn reset(&self) -> bool {
84        let _resolve_guard = self.resolve_lock.lock().await;
85        let mut guard = self.cached.write();
86        guard.take().is_some()
87    }
88}
89
90/// Error returned by [`choose_plugin_instance`].
91#[derive(Debug, thiserror::Error)]
92pub enum ChoosePluginError {
93    /// Failed to deserialize a plugin instance's content.
94    #[error("invalid plugin instance content for '{gts_id}': {reason}")]
95    InvalidPluginInstance {
96        /// GTS ID of the malformed instance.
97        gts_id: String,
98        /// Human-readable reason.
99        reason: String,
100    },
101
102    /// No plugin instance matched the requested vendor.
103    #[error("no plugin instances found for vendor '{vendor}'")]
104    PluginNotFound {
105        /// The vendor that was requested.
106        vendor: String,
107    },
108}
109
110/// Selects the best plugin instance for the given vendor.
111///
112/// Accepts an iterator of `(gts_id, content)` pairs — typically
113/// produced from `types_registry_sdk::GtsEntity`:
114///
115/// ```ignore
116/// choose_plugin_instance::<MyPluginSpecV1>(
117///     &self.vendor,
118///     instances.iter().map(|e| (e.gts_id.as_str(), &e.content)),
119/// )
120/// ```
121///
122/// Deserializes each entry as `BaseModkitPluginV1<P>`, filters by
123/// `vendor`, and returns the `gts_id` of the instance with the
124/// **lowest** priority value.
125///
126/// # Type Parameters
127///
128/// - `P` — The plugin-specific properties struct (e.g.
129///   `AuthNResolverPluginSpecV1`). Must be `DeserializeOwned`.
130///
131/// # Errors
132///
133/// - [`ChoosePluginError::InvalidPluginInstance`] if deserialization fails
134///   or the `content.id` doesn't match `gts_id`.
135/// - [`ChoosePluginError::PluginNotFound`] if no instance matches the vendor.
136pub fn choose_plugin_instance<'a, P>(
137    vendor: &str,
138    instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
139) -> Result<String, ChoosePluginError>
140where
141    P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
142{
143    let mut best: Option<(&str, i16)> = None;
144    let mut count: usize = 0;
145
146    for (gts_id, content_val) in instances {
147        count += 1;
148        let content: BaseModkitPluginV1<P> =
149            serde_json::from_value(content_val.clone()).map_err(|e| {
150                tracing::error!(
151                    gts_id = %gts_id,
152                    error = %e,
153                    "Failed to deserialize plugin instance content"
154                );
155                ChoosePluginError::InvalidPluginInstance {
156                    gts_id: gts_id.to_owned(),
157                    reason: e.to_string(),
158                }
159            })?;
160
161        if content.id != gts_id {
162            return Err(ChoosePluginError::InvalidPluginInstance {
163                gts_id: gts_id.to_owned(),
164                reason: format!(
165                    "content.id mismatch: expected {:?}, got {:?}",
166                    gts_id, content.id
167                ),
168            });
169        }
170
171        if content.vendor != vendor {
172            continue;
173        }
174
175        match &best {
176            None => best = Some((gts_id, content.priority)),
177            Some((_, cur_priority)) => {
178                if content.priority < *cur_priority {
179                    best = Some((gts_id, content.priority));
180                }
181            }
182        }
183    }
184
185    tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
186
187    best.map(|(gts_id, _)| gts_id.to_owned())
188        .ok_or_else(|| ChoosePluginError::PluginNotFound {
189            vendor: vendor.to_owned(),
190        })
191}
192
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_a.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_b.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// A failed resolution must not be cached: the next caller resolves again.
    #[tokio::test]
    async fn failed_resolve_is_not_cached() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_err = calls.clone();
        let first = selector
            .get_or_init(|| async move {
                calls_err.fetch_add(1, Ordering::SeqCst);
                Err::<String, &str>("resolve failed")
            })
            .await;
        assert_eq!(first.unwrap_err(), "resolve failed");

        let calls_ok = calls.clone();
        let second = selector
            .get_or_init(|| async move {
                calls_ok.fetch_add(1, Ordering::SeqCst);
                Ok::<_, &str>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();
        assert_eq!(
            &*second,
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
        // Both closures ran: the error from the first call left the cache empty.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// `reset` reports `false` when nothing has been cached yet.
    #[tokio::test]
    async fn reset_without_cached_value_returns_false() {
        let selector = GtsPluginSelector::default();
        assert!(!selector.reset().await);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        // Small delay to increase chance of concurrent access
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, std::convert::Infallible>(
                            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
                                .to_owned(),
                        )
                    })
                    .await
            }));
        }

        // Await each handle in a loop (no futures_util dependency)
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        // All results should be the same
        for id in &results {
            assert_eq!(
                &**id,
                "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
            );
        }

        // Resolve should have been called exactly once
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}