Skip to main content

modkit/plugins/
mod.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7use crate::gts::BaseModkitPluginV1;
8
9/// A resettable, allocation-friendly selector for GTS plugin instance IDs.
10///
11/// Uses a single-flight pattern to ensure that the resolve function is called
12/// at most once even under concurrent callers. The selected instance ID is
13/// cached as `Arc<str>` to avoid allocations on the happy path.
14pub struct GtsPluginSelector {
15    /// Cached selected instance ID (sync lock for fast access and sync reset).
16    cached: RwLock<Option<Arc<str>>>,
17    /// Mutex to ensure single-flight resolution.
18    resolve_lock: Mutex<()>,
19}
20
21impl Default for GtsPluginSelector {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl GtsPluginSelector {
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            cached: RwLock::new(None),
32            resolve_lock: Mutex::new(()),
33        }
34    }
35
36    /// Returns the cached instance ID, or resolves it using the provided function.
37    ///
38    /// Uses a single-flight pattern: even under concurrent callers, the resolve
39    /// function is called at most once. Returns `Arc<str>` to avoid allocations
40    /// on the happy path.
41    /// # Errors
42    ///
43    /// Returns `Err(E)` if the provided `resolve` future fails.
44    pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
45    where
46        F: FnOnce() -> Fut,
47        Fut: Future<Output = Result<String, E>>,
48    {
49        // Fast path: check if already cached (sync lock, no await)
50        {
51            let guard = self.cached.read();
52            if let Some(ref id) = *guard {
53                return Ok(Arc::clone(id));
54            }
55        }
56
57        // Slow path: acquire resolve lock for single-flight
58        let _resolve_guard = self.resolve_lock.lock().await;
59
60        // Re-check after acquiring resolve lock (another caller may have resolved)
61        {
62            let guard = self.cached.read();
63            if let Some(ref id) = *guard {
64                return Ok(Arc::clone(id));
65            }
66        }
67
68        // Resolve and cache
69        let id_string = resolve().await?;
70        let id: Arc<str> = id_string.into();
71
72        {
73            let mut guard = self.cached.write();
74            *guard = Some(Arc::clone(&id));
75        }
76
77        Ok(id)
78    }
79
80    /// Clears the cached selected instance ID.
81    ///
82    /// Returns `true` if there was a cached value, `false` otherwise.
83    pub async fn reset(&self) -> bool {
84        let _resolve_guard = self.resolve_lock.lock().await;
85        let mut guard = self.cached.write();
86        guard.take().is_some()
87    }
88}
89
90/// Error returned by [`choose_plugin_instance`].
91#[derive(Debug, thiserror::Error)]
92pub enum ChoosePluginError {
93    /// Failed to deserialize a plugin instance's content.
94    #[error("invalid plugin instance content for '{gts_id}': {reason}")]
95    InvalidPluginInstance {
96        /// GTS ID of the malformed instance.
97        gts_id: String,
98        /// Human-readable reason.
99        reason: String,
100    },
101
102    /// No plugin instance matched the requested vendor.
103    #[error("no plugin instances found for schema '{schema_id}', vendor '{vendor}'")]
104    PluginNotFound {
105        /// GTS schema ID of the plugin type being resolved.
106        schema_id: String,
107        /// The vendor that was requested.
108        vendor: String,
109    },
110}
111
112/// Selects the best plugin instance for the given vendor.
113///
114/// Accepts an iterator of `(gts_id, content)` pairs — typically
115/// produced from `types_registry_sdk::GtsEntity`:
116///
117/// ```ignore
118/// choose_plugin_instance::<MyPluginSpecV1>(
119///     &self.vendor,
120///     instances.iter().map(|e| (e.gts_id.as_str(), &e.content)),
121/// )
122/// ```
123///
124/// Deserializes each entry as `BaseModkitPluginV1<P>`, filters by
125/// `vendor`, and returns the `gts_id` of the instance with the
126/// **lowest** priority value.
127///
128/// # Type Parameters
129///
130/// - `P` — The plugin-specific properties struct (e.g.
131///   `AuthNResolverPluginSpecV1`). Must be `DeserializeOwned`.
132///
133/// # Errors
134///
135/// - [`ChoosePluginError::InvalidPluginInstance`] if deserialization fails
136///   or the `content.id` doesn't match `gts_id`.
137/// - [`ChoosePluginError::PluginNotFound`] if no instance matches the vendor.
138pub fn choose_plugin_instance<'a, P>(
139    vendor: &str,
140    instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
141) -> Result<String, ChoosePluginError>
142where
143    P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
144{
145    let mut best: Option<(&str, i16)> = None;
146    let mut count: usize = 0;
147
148    for (gts_id, content_val) in instances {
149        count += 1;
150        let content: BaseModkitPluginV1<P> =
151            serde_json::from_value(content_val.clone()).map_err(|e| {
152                tracing::error!(
153                    gts_id = %gts_id,
154                    error = %e,
155                    "Failed to deserialize plugin instance content"
156                );
157                ChoosePluginError::InvalidPluginInstance {
158                    gts_id: gts_id.to_owned(),
159                    reason: e.to_string(),
160                }
161            })?;
162
163        if content.id != gts_id {
164            return Err(ChoosePluginError::InvalidPluginInstance {
165                gts_id: gts_id.to_owned(),
166                reason: format!(
167                    "content.id mismatch: expected {:?}, got {:?}",
168                    gts_id, content.id
169                ),
170            });
171        }
172
173        if content.vendor != vendor {
174            continue;
175        }
176
177        match &best {
178            None => best = Some((gts_id, content.priority)),
179            Some((_, cur_priority)) => {
180                if content.priority < *cur_priority {
181                    best = Some((gts_id, content.priority));
182                }
183            }
184        }
185    }
186
187    tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
188
189    best.map(|(gts_id, _)| gts_id.to_owned())
190        .ok_or_else(|| ChoosePluginError::PluginNotFound {
191            schema_id: P::SCHEMA_ID.to_owned(),
192            vendor: vendor.to_owned(),
193        })
194}
195
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    const ID_A: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1";
    const ID_B: &str = "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1";
    const ID_CONCURRENT: &str =
        "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1";

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(ID_A.to_owned())
            })
            .await
            .unwrap();

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(ID_B.to_owned())
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(ID_A.to_owned())
            })
            .await;
        assert_eq!(&*id_a.unwrap(), ID_A);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(ID_B.to_owned())
            })
            .await;
        assert_eq!(&*id_b.unwrap(), ID_B);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn reset_on_empty_selector_returns_false() {
        let selector = GtsPluginSelector::new();
        assert!(!selector.reset().await);

        // `Default` must also start empty.
        let selector = GtsPluginSelector::default();
        assert!(!selector.reset().await);
    }

    #[tokio::test]
    async fn failed_resolve_is_not_cached() {
        let selector = GtsPluginSelector::new();

        let err = selector
            .get_or_init(|| async { Err::<String, _>("boom") })
            .await
            .unwrap_err();
        assert_eq!(err, "boom");

        // The failure left nothing cached, so the next resolver runs and wins.
        let id = selector
            .get_or_init(|| async { Ok::<_, &str>(ID_A.to_owned()) })
            .await
            .unwrap();
        assert_eq!(&*id, ID_A);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        // Small delay to increase chance of concurrent access
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, std::convert::Infallible>(ID_CONCURRENT.to_owned())
                    })
                    .await
            }));
        }

        // Await each handle in a loop (no futures_util dependency)
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        // All results should be the same
        for id in &results {
            assert_eq!(&**id, ID_CONCURRENT);
        }

        // Resolve should have been called exactly once
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}