Skip to main content

modkit/plugins/
mod.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use parking_lot::RwLock;
5use tokio::sync::Mutex;
6
7/// A resettable, allocation-friendly selector for GTS plugin instance IDs.
8///
9/// Uses a single-flight pattern to ensure that the resolve function is called
10/// at most once even under concurrent callers. The selected instance ID is
11/// cached as `Arc<str>` to avoid allocations on the happy path.
12pub struct GtsPluginSelector {
13    /// Cached selected instance ID (sync lock for fast access and sync reset).
14    cached: RwLock<Option<Arc<str>>>,
15    /// Mutex to ensure single-flight resolution.
16    resolve_lock: Mutex<()>,
17}
18
19impl Default for GtsPluginSelector {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl GtsPluginSelector {
26    #[must_use]
27    pub fn new() -> Self {
28        Self {
29            cached: RwLock::new(None),
30            resolve_lock: Mutex::new(()),
31        }
32    }
33
34    /// Returns the cached instance ID, or resolves it using the provided function.
35    ///
36    /// Uses a single-flight pattern: even under concurrent callers, the resolve
37    /// function is called at most once. Returns `Arc<str>` to avoid allocations
38    /// on the happy path.
39    /// # Errors
40    ///
41    /// Returns `Err(E)` if the provided `resolve` future fails.
42    pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
43    where
44        F: FnOnce() -> Fut,
45        Fut: Future<Output = Result<String, E>>,
46    {
47        // Fast path: check if already cached (sync lock, no await)
48        {
49            let guard = self.cached.read();
50            if let Some(ref id) = *guard {
51                return Ok(Arc::clone(id));
52            }
53        }
54
55        // Slow path: acquire resolve lock for single-flight
56        let _resolve_guard = self.resolve_lock.lock().await;
57
58        // Re-check after acquiring resolve lock (another caller may have resolved)
59        {
60            let guard = self.cached.read();
61            if let Some(ref id) = *guard {
62                return Ok(Arc::clone(id));
63            }
64        }
65
66        // Resolve and cache
67        let id_string = resolve().await?;
68        let id: Arc<str> = id_string.into();
69
70        {
71            let mut guard = self.cached.write();
72            *guard = Some(Arc::clone(&id));
73        }
74
75        Ok(id)
76    }
77
78    /// Clears the cached selected instance ID.
79    ///
80    /// Returns `true` if there was a cached value, `false` otherwise.
81    pub async fn reset(&self) -> bool {
82        let _resolve_guard = self.resolve_lock.lock().await;
83        let mut guard = self.cached.write();
84        guard.take().is_some()
85    }
86}
87
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn resolve_called_once_returns_same_str() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();

        assert_eq!(id_a, id_b);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn reset_triggers_reselection() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let id_a = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_a.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(selector.reset().await);

        let calls_b = calls.clone();
        let id_b = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::convert::Infallible>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await;
        assert_eq!(
            &*id_b.unwrap(),
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn reset_on_empty_selector_returns_false() {
        let selector = GtsPluginSelector::default();
        assert!(!selector.reset().await);
    }

    #[tokio::test]
    async fn failed_resolve_is_not_cached_and_is_retried() {
        let selector = GtsPluginSelector::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let calls_a = calls.clone();
        let err = selector
            .get_or_init(|| async move {
                calls_a.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("resolver unavailable")
            })
            .await
            .unwrap_err();
        assert_eq!(err, "resolver unavailable");
        // The failure must not have populated the cache.
        assert!(!selector.reset().await);

        let calls_b = calls.clone();
        let id = selector
            .get_or_init(|| async move {
                calls_b.fetch_add(1, Ordering::SeqCst);
                Ok::<_, &str>(
                    "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~retry.test._.plugin.v1"
                        .to_owned(),
                )
            })
            .await
            .unwrap();
        assert_eq!(
            &*id,
            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~retry.test._.plugin.v1"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn concurrent_get_or_init_resolves_once() {
        let selector = Arc::new(GtsPluginSelector::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let selector = Arc::clone(&selector);
            let calls = Arc::clone(&calls);
            handles.push(tokio::spawn(async move {
                selector
                    .get_or_init(|| async {
                        // Small delay to increase chance of concurrent access
                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                        calls.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, std::convert::Infallible>(
                            "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
                                .to_owned(),
                        )
                    })
                    .await
            }));
        }

        // Await each handle in a loop (no futures_util dependency)
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap().unwrap());
        }

        // All results should be the same
        for id in &results {
            assert_eq!(
                &**id,
                "gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
            );
        }

        // Resolve should have been called exactly once
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}