Skip to main content

modkit/plugins/
mod.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use serde::de::DeserializeOwned;
6use tokio::sync::Mutex;
7
8use crate::gts::BaseModkitPluginV1;
9
10/// A resettable, allocation-friendly selector for GTS plugin instance IDs.
11///
12/// Uses a single-flight pattern to ensure that the resolve function is called
13/// at most once even under concurrent callers. The selected instance ID is
14/// cached as `Arc<str>` to avoid allocations on the happy path.
15pub struct GtsPluginSelector {
16    /// Cached selected instance ID (sync lock for fast access and sync reset).
17    cached: RwLock<Option<Arc<str>>>,
18    /// Mutex to ensure single-flight resolution.
19    resolve_lock: Mutex<()>,
20}
21
22impl Default for GtsPluginSelector {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl GtsPluginSelector {
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            cached: RwLock::new(None),
33            resolve_lock: Mutex::new(()),
34        }
35    }
36
37    /// Returns the cached instance ID, or resolves it using the provided function.
38    ///
39    /// Uses a single-flight pattern: even under concurrent callers, the resolve
40    /// function is called at most once. Returns `Arc<str>` to avoid allocations
41    /// on the happy path.
42    /// # Errors
43    ///
44    /// Returns `Err(E)` if the provided `resolve` future fails.
45    pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
46    where
47        F: FnOnce() -> Fut,
48        Fut: Future<Output = Result<String, E>>,
49    {
50        // Fast path: check if already cached (sync lock, no await)
51        {
52            let guard = self.cached.read();
53            if let Some(ref id) = *guard {
54                return Ok(Arc::clone(id));
55            }
56        }
57
58        // Slow path: acquire resolve lock for single-flight
59        let _resolve_guard = self.resolve_lock.lock().await;
60
61        // Re-check after acquiring resolve lock (another caller may have resolved)
62        {
63            let guard = self.cached.read();
64            if let Some(ref id) = *guard {
65                return Ok(Arc::clone(id));
66            }
67        }
68
69        // Resolve and cache
70        let id_string = resolve().await?;
71        let id: Arc<str> = id_string.into();
72
73        {
74            let mut guard = self.cached.write();
75            *guard = Some(Arc::clone(&id));
76        }
77
78        Ok(id)
79    }
80
81    /// Clears the cached selected instance ID.
82    ///
83    /// Returns `true` if there was a cached value, `false` otherwise.
84    pub async fn reset(&self) -> bool {
85        let _resolve_guard = self.resolve_lock.lock().await;
86        let mut guard = self.cached.write();
87        guard.take().is_some()
88    }
89}
90
91/// Error returned by [`choose_plugin_instance`].
92#[derive(Debug, thiserror::Error)]
93pub enum ChoosePluginError {
94    /// Failed to deserialize a plugin instance's content.
95    #[error("invalid plugin instance content for '{gts_id}': {reason}")]
96    InvalidPluginInstance {
97        /// GTS ID of the malformed instance.
98        gts_id: String,
99        /// Human-readable reason.
100        reason: String,
101    },
102
103    /// No plugin instance matched the requested vendor.
104    #[error("no plugin instances found for vendor '{vendor}'")]
105    PluginNotFound {
106        /// The vendor that was requested.
107        vendor: String,
108    },
109}
110
111/// Selects the best plugin instance for the given vendor.
112///
113/// Accepts an iterator of `(gts_id, content)` pairs — typically
114/// produced from `types_registry_sdk::GtsEntity`:
115///
116/// ```ignore
117/// choose_plugin_instance::<MyPluginSpecV1>(
118///     &self.vendor,
119///     instances.iter().map(|e| (e.gts_id.as_str(), &e.content)),
120/// )
121/// ```
122///
123/// Deserializes each entry as `BaseModkitPluginV1<P>`, filters by
124/// `vendor`, and returns the `gts_id` of the instance with the
125/// **lowest** priority value.
126///
127/// # Type Parameters
128///
129/// - `P` — The plugin-specific properties struct (e.g.
130///   `AuthNResolverPluginSpecV1`). Must be `DeserializeOwned`.
131///
132/// # Errors
133///
134/// - [`ChoosePluginError::InvalidPluginInstance`] if deserialization fails
135///   or the `content.id` doesn't match `gts_id`.
136/// - [`ChoosePluginError::PluginNotFound`] if no instance matches the vendor.
137pub fn choose_plugin_instance<'a, P>(
138    vendor: &str,
139    instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
140) -> Result<String, ChoosePluginError>
141where
142    P: DeserializeOwned + gts::GtsSchema,
143{
144    let mut best: Option<(&str, i16)> = None;
145    let mut count: usize = 0;
146
147    for (gts_id, content_val) in instances {
148        count += 1;
149        let content: BaseModkitPluginV1<P> =
150            serde_json::from_value(content_val.clone()).map_err(|e| {
151                tracing::error!(
152                    gts_id = %gts_id,
153                    error = %e,
154                    "Failed to deserialize plugin instance content"
155                );
156                ChoosePluginError::InvalidPluginInstance {
157                    gts_id: gts_id.to_owned(),
158                    reason: e.to_string(),
159                }
160            })?;
161
162        if content.id != gts_id {
163            return Err(ChoosePluginError::InvalidPluginInstance {
164                gts_id: gts_id.to_owned(),
165                reason: format!(
166                    "content.id mismatch: expected {:?}, got {:?}",
167                    gts_id, content.id
168                ),
169            });
170        }
171
172        if content.vendor != vendor {
173            continue;
174        }
175
176        match &best {
177            None => best = Some((gts_id, content.priority)),
178            Some((_, cur_priority)) => {
179                if content.priority < *cur_priority {
180                    best = Some((gts_id, content.priority));
181                }
182            }
183        }
184    }
185
186    tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
187
188    best.map(|(gts_id, _)| gts_id.to_owned())
189        .ok_or_else(|| ChoosePluginError::PluginNotFound {
190            vendor: vendor.to_owned(),
191        })
192}
193
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_a.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_b.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        // Small delay to increase chance of concurrent access
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, std::convert::Infallible>(
                            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
                                .to_owned(),
                        )
                    })
                    .await
            }));
        }

        // Await each handle in a loop (no futures_util dependency)
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        // All results should be the same
        for id in &results {
            assert_eq!(
                &**id,
                "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
            );
        }

        // Resolve should have been called exactly once
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// A failed resolution must not be cached: the error reaches the caller
    /// and the next call runs its own resolver.
    #[tokio::test]
    async fn failed_resolve_is_not_cached_and_next_call_retries() {
        let selector = GtsPluginSelector::new();

        let first = selector
            .get_or_init(|| async { Err::<String, _>("resolve failed") })
            .await;
        assert_eq!(first.unwrap_err(), "resolve failed");

        let second = selector
            .get_or_init(|| async {
                Ok::<_, &str>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*second.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
    }

    /// `reset` reports `false` when nothing was cached.
    #[tokio::test]
    async fn reset_on_empty_selector_returns_false() {
        let selector = GtsPluginSelector::new();
        assert!(!selector.reset().await);
    }
}