use std::future::Future;
use std::sync::Arc;
use parking_lot::RwLock;
use tokio::sync::Mutex;
use crate::gts::BaseModkitPluginV1;
pub struct GtsPluginSelector {
cached: RwLock<Option<Arc<str>>>,
resolve_lock: Mutex<()>,
}
impl Default for GtsPluginSelector {
fn default() -> Self {
Self::new()
}
}
impl GtsPluginSelector {
#[must_use]
pub fn new() -> Self {
Self {
cached: RwLock::new(None),
resolve_lock: Mutex::new(()),
}
}
#[must_use]
pub fn pre_cached(value: String) -> Self {
Self {
cached: RwLock::new(Some(Arc::from(value))),
resolve_lock: Mutex::new(()),
}
}
pub async fn get_or_init<F, Fut, E>(&self, resolve: F) -> Result<Arc<str>, E>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<String, E>>,
{
{
let guard = self.cached.read();
if let Some(ref id) = *guard {
return Ok(Arc::clone(id));
}
}
let _resolve_guard = self.resolve_lock.lock().await;
{
let guard = self.cached.read();
if let Some(ref id) = *guard {
return Ok(Arc::clone(id));
}
}
let id_string = resolve().await?;
let id: Arc<str> = id_string.into();
{
let mut guard = self.cached.write();
*guard = Some(Arc::clone(&id));
}
Ok(id)
}
pub async fn reset(&self) -> bool {
let _resolve_guard = self.resolve_lock.lock().await;
let mut guard = self.cached.write();
guard.take().is_some()
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChoosePluginError {
#[error("invalid plugin instance content for '{gts_id}': {reason}")]
InvalidPluginInstance {
gts_id: String,
reason: String,
},
#[error("no plugin instances found for schema '{schema_id}', vendor '{vendor}'")]
PluginNotFound {
schema_id: String,
vendor: String,
},
}
pub fn choose_plugin_instance<'a, P>(
vendor: &str,
instances: impl IntoIterator<Item = (&'a str, &'a serde_json::Value)>,
) -> Result<String, ChoosePluginError>
where
P: for<'de> gts::GtsDeserialize<'de> + gts::GtsSchema,
{
let mut best: Option<(&str, i16)> = None;
let mut count: usize = 0;
for (gts_id, content_val) in instances {
count += 1;
let content: BaseModkitPluginV1<P> =
serde_json::from_value(content_val.clone()).map_err(|e| {
tracing::error!(
gts_id = %gts_id,
error = %e,
"Failed to deserialize plugin instance content"
);
ChoosePluginError::InvalidPluginInstance {
gts_id: gts_id.to_owned(),
reason: e.to_string(),
}
})?;
if content.id != gts_id {
return Err(ChoosePluginError::InvalidPluginInstance {
gts_id: gts_id.to_owned(),
reason: format!(
"content.id mismatch: expected {:?}, got {:?}",
gts_id, content.id
),
});
}
if content.vendor != vendor {
continue;
}
match &best {
None => best = Some((gts_id, content.priority)),
Some((_, cur_priority)) => {
if content.priority < *cur_priority {
best = Some((gts_id, content.priority));
}
}
}
}
tracing::debug!(vendor, instance_count = count, "choose_plugin_instance");
best.map(|(gts_id, _)| gts_id.to_owned())
.ok_or_else(|| ChoosePluginError::PluginNotFound {
schema_id: P::SCHEMA_ID.to_owned(),
vendor: vendor.to_owned(),
})
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[tokio::test]
async fn resolve_called_once_returns_same_str() {
let selector = GtsPluginSelector::new();
let calls = Arc::new(AtomicUsize::new(0));
let calls_a = calls.clone();
let id_a = selector
.get_or_init(|| async move {
calls_a.fetch_add(1, Ordering::SeqCst);
Ok::<_, std::convert::Infallible>(
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
.to_owned(),
)
})
.await
.unwrap();
let calls_b = calls.clone();
let id_b = selector
.get_or_init(|| async move {
calls_b.fetch_add(1, Ordering::SeqCst);
Ok::<_, std::convert::Infallible>(
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
.to_owned(),
)
})
.await
.unwrap();
assert_eq!(id_a, id_b);
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn reset_triggers_reselection() {
let selector = GtsPluginSelector::new();
let calls = Arc::new(AtomicUsize::new(0));
let calls_a = calls.clone();
let id_a = selector
.get_or_init(|| async move {
calls_a.fetch_add(1, Ordering::SeqCst);
Ok::<_, std::convert::Infallible>(
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
.to_owned(),
)
})
.await;
assert_eq!(
&*id_a.unwrap(),
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~a.test._.plugin.v1"
);
assert_eq!(calls.load(Ordering::SeqCst), 1);
assert!(selector.reset().await);
let calls_b = calls.clone();
let id_b = selector
.get_or_init(|| async move {
calls_b.fetch_add(1, Ordering::SeqCst);
Ok::<_, std::convert::Infallible>(
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
.to_owned(),
)
})
.await;
assert_eq!(
&*id_b.unwrap(),
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~b.test._.plugin.v1"
);
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn concurrent_get_or_init_resolves_once() {
let selector = Arc::new(GtsPluginSelector::new());
let calls = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
for _ in 0..10 {
let selector = Arc::clone(&selector);
let calls = Arc::clone(&calls);
handles.push(tokio::spawn(async move {
selector
.get_or_init(|| async {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
calls.fetch_add(1, Ordering::SeqCst);
Ok::<_, std::convert::Infallible>(
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
.to_owned(),
)
})
.await
}));
}
let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap().unwrap());
}
for id in &results {
assert_eq!(
&**id,
"gts.x.core.modkit.plugin.v1~x.core.test.plugin.v1~concurrent.test._.plugin.v1"
);
}
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
}