// modkit/plugins/mod.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7use crate::gts::BaseModkitPluginV1;
8
9/// A resettable, allocation-friendly selector for GTS plugin instance IDs.
10///
11/// Uses a single-flight pattern to ensure that the resolve function is called
12/// at most once even under concurrent callers. The selected instance ID is
13/// cached as `Arc<str>` to avoid allocations on the happy path.
14pub struct GtsPluginSelector {
15    /// Cached selected instance ID (sync lock for fast access and sync reset).
16    cached: RwLock<Option<Arc<str>>>,
17    /// Mutex to ensure single-flight resolution.
18    resolve_lock: Mutex<()>,
19}
20
21impl Default for GtsPluginSelector {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl GtsPluginSelector {
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            cached: RwLock::new(None),
32            resolve_lock: Mutex::new(()),
33        }
34    }
35
36    /// Create a selector with `value` already cached, skipping resolution entirely.
37    ///
38    /// Useful in tests to pre-warm the selector with a known instance ID or
39    /// an empty-string sentinel (meaning "no plugin configured").
40    #[must_use]
41    pub fn pre_cached(value: String) -> Self {
42        Self {
43            cached: RwLock::new(Some(Arc::from(value))),
44            resolve_lock: Mutex::new(()),
45        }
46    }
47
48    /// Returns the cached instance ID, or resolves it using the provided function.
49    ///
50    /// Uses a single-flight pattern: even under concurrent callers, the resolve
51    /// function is called at most once. Returns `Arc<str>` to avoid allocations
52    /// on the happy path.
53    /// # Errors
54    ///
55    /// Returns `Err(E)` if the provided `resolve` future fails.
56    pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
57    where
58        F: FnOnce() -> Fut,
59        Fut: Future<Output = Result<String, E>>,
60    {
61        // Fast path: check if already cached (sync lock, no await)
62        {
63            let guard = self.cached.read();
64            if let Some(ref id) = *guard {
65                return Ok(Arc::clone(id));
66            }
67        }
68
69        // Slow path: acquire resolve lock for single-flight
70        let _resolve_guard = self.resolve_lock.lock().await;
71
72        // Re-check after acquiring resolve lock (another caller may have resolved)
73        {
74            let guard = self.cached.read();
75            if let Some(ref id) = *guard {
76                return Ok(Arc::clone(id));
77            }
78        }
79
80        // Resolve and cache
81        let id_string = resolve().await?;
82        let id: Arc<str> = id_string.into();
83
84        {
85            let mut guard = self.cached.write();
86            *guard = Some(Arc::clone(&id));
87        }
88
89        Ok(id)
90    }
91
92    /// Clears the cached selected instance ID.
93    ///
94    /// Returns `true` if there was a cached value, `false` otherwise.
95    pub async fn reset(&self) -> bool {
96        let _resolve_guard = self.resolve_lock.lock().await;
97        let mut guard = self.cached.write();
98        guard.take().is_some()
99    }
100}
101
102/// Error returned by [`choose_plugin_instance`].
103#[derive(Debug, thiserror::Error)]
104pub enum ChoosePluginError {
105    /// Failed to deserialize a plugin instance's content.
106    #[error("invalid plugin instance content for '{gts_id}': {reason}")]
107    InvalidPluginInstance {
108        /// GTS ID of the malformed instance.
109        gts_id: String,
110        /// Human-readable reason.
111        reason: String,
112    },
113
114    /// No plugin instance matched the requested vendor.
115    #[error("no plugin instances found for schema '{schema_id}', vendor '{vendor}'")]
116    PluginNotFound {
117        /// GTS schema ID of the plugin type being resolved.
118        schema_id: String,
119        /// The vendor that was requested.
120        vendor: String,
121    },
122}
123
124/// Selects the best plugin instance for the given vendor.
125///
126/// Accepts an iterator of `(gts_id, content)` pairs — typically
127/// produced from `types_registry_sdk::GtsEntity`:
128///
129/// ```ignore
130/// choose_plugin_instance::<MyPluginSpecV1>(
131///     &self.vendor,
132///     instances.iter().map(|e| (e.gts_id.as_str(), &e.content)),
133/// )
134/// ```
135///
136/// Deserializes each entry as `BaseModkitPluginV1<P>`, filters by
137/// `vendor`, and returns the `gts_id` of the instance with the
138/// **lowest** priority value.
139///
140/// # Type Parameters
141///
142/// - `P` — The plugin-specific properties struct (e.g.
143///   `AuthNResolverPluginSpecV1`). Must be `DeserializeOwned`.
144///
145/// # Errors
146///
147/// - [`ChoosePluginError::InvalidPluginInstance`] if deserialization fails
148///   or the `content.id` doesn't match `gts_id`.
149/// - [`ChoosePluginError::PluginNotFound`] if no instance matches the vendor.
150pub fn choose_plugin_instance<'a, P>(
151    vendor: &str,
152    instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
153) -> Result<String, ChoosePluginError>
154where
155    P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
156{
157    let mut best: Option<(&str, i16)> = None;
158    let mut count: usize = 0;
159
160    for (gts_id, content_val) in instances {
161        count += 1;
162        let content: BaseModkitPluginV1<P> =
163            serde_json::from_value(content_val.clone()).map_err(|e| {
164                tracing::error!(
165                    gts_id = %gts_id,
166                    error = %e,
167                    "Failed to deserialize plugin instance content"
168                );
169                ChoosePluginError::InvalidPluginInstance {
170                    gts_id: gts_id.to_owned(),
171                    reason: e.to_string(),
172                }
173            })?;
174
175        if content.id != gts_id {
176            return Err(ChoosePluginError::InvalidPluginInstance {
177                gts_id: gts_id.to_owned(),
178                reason: format!(
179                    "content.id mismatch: expected {:?}, got {:?}",
180                    gts_id, content.id
181                ),
182            });
183        }
184
185        if content.vendor != vendor {
186            continue;
187        }
188
189        match &best {
190            None => best = Some((gts_id, content.priority)),
191            Some((_, cur_priority)) => {
192                if content.priority < *cur_priority {
193                    best = Some((gts_id, content.priority));
194                }
195            }
196        }
197    }
198
199    tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
200
201    best.map(|(gts_id, _)| gts_id.to_owned())
202        .ok_or_else(|| ChoosePluginError::PluginNotFound {
203            schema_id: P::SCHEMA_ID.to_owned(),
204            vendor: vendor.to_owned(),
205        })
206}
207
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    const ID_A: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1";
    const ID_B: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1";
    const ID_CONCURRENT: &str =
        "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1";

    /// Resolves through `selector` with a closure that yields `id`.
    ///
    /// `calls` is incremented every time the closure actually runs.
    async fn resolve_counted(
        selector: &GtsPluginSelector,
        calls: &AtomicUsize,
        id: &str,
    ) -> Arc<str> {
        selector
            .get_or_init(|| async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(id.to_owned())
            })
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = AtomicUsize::new(0);

        let id_a = resolve_counted(&selector, &calls, ID_A).await;
        let id_b = resolve_counted(&selector, &calls, ID_B).await;

        assert_eq!(id_a, id_b);
        assert_eq!(&*id_a, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = AtomicUsize::new(0);

        assert_eq!(&*resolve_counted(&selector, &calls, ID_A).await, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        assert_eq!(&*resolve_counted(&selector, &calls, ID_B).await, ID_B);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn reset_without_cached_value_returns_false() {
        let selector = GtsPluginSelector::new();
        assert!(!selector.reset().await);
    }

    #[tokio::test]
    async fn pre_cached_value_skips_resolution() {
        let selector = GtsPluginSelector::pre_cached(ID_A.to_owned());
        let calls = AtomicUsize::new(0);

        assert_eq!(&*resolve_counted(&selector, &calls, ID_B).await, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn failed_resolution_is_not_cached() {
        let selector = GtsPluginSelector::new();

        let err = selector
            .get_or_init(|| async { Err::<String, _>("registry unavailable") })
            .await
            .unwrap_err();
        assert_eq!(err, "registry unavailable");
        assert!(
            !selector.reset().await,
            "a failed resolution must not be cached"
        );

        let calls = AtomicUsize::new(0);
        assert_eq!(&*resolve_counted(&selector, &calls, ID_A).await, ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let selector = Arc::clone(&selector);
                let calls = Arc::clone(&calls);
                tokio::spawn(async move {
                    selector
                        .get_or_init(|| async {
                            // Small delay to widen the window for concurrent callers.
                            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                            calls.fetch_add(1, Ordering::SeqCst);
                            Ok::<_, std::convert::Infallible>(ID_CONCURRENT.to_owned())
                        })
                        .await
                })
            })
            .collect();

        // Await each handle in turn (avoids a futures_util dependency).
        for handle in handles {
            let id = handle.await.unwrap().unwrap();
            assert_eq!(&*id, ID_CONCURRENT);
        }

        // Resolve should have been called exactly once.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}