use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{LazyLock, Mutex};
use crate::{HashSet, PluginContext};
/// Process-wide set of the `TypeId`s recorded by `add_data_plugin_to_registry`.
///
/// The set is created empty the first time the static is touched (`LazyLock`), and every
/// access in this file first takes the `Mutex` and then borrows the inner `RefCell`.
// NOTE(review): the `Mutex` already grants exclusive access to its contents, so the inner
// `RefCell` looks redundant — confirm nothing relies on it before simplifying the type.
static DATA_PLUGINS: LazyLock<Mutex<RefCell<HashSet<TypeId>>>> =
LazyLock::new(|| Mutex::new(RefCell::new(HashSet::default())));
/// Records the data plugin type `T` in the global plugin registry.
///
/// The registry is a set keyed by `TypeId`, so registering the same type more than once
/// leaves a single entry.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn add_data_plugin_to_registry<T: DataPlugin>() {
    let plugin_type = TypeId::of::<T>();
    let registry = DATA_PLUGINS.lock().unwrap();
    registry.borrow_mut().insert(plugin_type);
}
/// Returns a snapshot of the `TypeId`s currently held in the global plugin registry.
///
/// The ids are copied out of the underlying set, so the returned vector is independent of
/// later registrations. The order of the ids is whatever the set's iterator yields.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn get_data_plugin_ids() -> Vec<TypeId> {
    let registry = DATA_PLUGINS.lock().unwrap();
    let registered = registry.borrow();
    registered.iter().copied().collect()
}
/// Returns how many distinct data plugin types are currently in the global registry.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn get_data_plugin_count() -> usize {
    let registry = DATA_PLUGINS.lock().unwrap();
    let registered = registry.borrow();
    registered.len()
}
/// The next index that has not yet been handed out by `initialize_data_plugin_index`.
static NEXT_DATA_PLUGIN_INDEX: Mutex<usize> = Mutex::new(0);

/// Ensures `plugin_index` holds an assigned index and returns that index.
///
/// `usize::MAX` is treated as the "not yet assigned" marker: if `plugin_index` still holds
/// it, the next value of the global counter is stored there, the counter is advanced by
/// one, and the stored value is returned. If `plugin_index` already holds any other value,
/// that value is returned and the counter is left untouched.
///
/// The global counter's mutex is held for the whole call, so the counter only advances
/// when the compare-exchange actually succeeded.
///
/// # Panics
///
/// Panics if the counter mutex has been poisoned.
pub fn initialize_data_plugin_index(plugin_index: &AtomicUsize) -> usize {
    let mut next_free = NEXT_DATA_PLUGIN_INDEX.lock().unwrap();
    let proposed = *next_free;

    let exchange = plugin_index.compare_exchange(
        usize::MAX,
        proposed,
        Ordering::AcqRel,
        Ordering::Acquire,
    );

    // The cell was already populated: report what is there and consume no new index.
    if let Err(already_assigned) = exchange {
        return already_assigned;
    }

    *next_free += 1;
    proposed
}
/// Describes a data plugin: a marker type with an associated data container.
///
/// The tests in this module expect `Context::get_data` to panic with a message ending in
/// "You must use the `define_data_plugin!` macro to create a data plugin." when a
/// hand-written implementation reports a bogus index, so implementations are expected to
/// come from that macro rather than be written by hand.
pub trait DataPlugin: Any {
/// The type of the data this plugin provides.
type DataContainer;
/// Builds the plugin's data container, given a reference to a plugin context.
// NOTE(review): when and how often this is invoked is decided by the caller, which is
// not visible in this file — confirm against the context implementation.
fn init<C: PluginContext>(context: &C) -> Self::DataContainer;
/// Returns the index used to locate this plugin's data within a context.
// NOTE(review): presumably obtained via `initialize_data_plugin_index` in
// macro-generated impls — confirm against `define_data_plugin!`.
fn index_within_context() -> usize;
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::thread;

    use super::*;
    use crate::{define_data_plugin, Context};

    /// A hand-written plugin reporting index 1000 must be rejected with the
    /// "No data plugin found" panic when its data is requested.
    #[test]
    #[should_panic(
        expected = "No data plugin found with index = 1000. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_index_oob() {
        struct MyDataPlugin;

        impl DataPlugin for MyDataPlugin {
            type DataContainer = Vec<u32>;

            fn init<C: PluginContext>(_context: &C) -> Self::DataContainer {
                Vec::new()
            }

            fn index_within_context() -> usize {
                1000
            }
        }

        let ctx = Context::new();
        let data = ctx.get_data(MyDataPlugin);
        println!("{}", data.len());
    }

    // A plugin created through the macro; the next test borrows its index.
    define_data_plugin!(LegitDataPlugin, Vec<u32>, vec![]);

    /// A hand-written plugin that reuses the index of a macro-defined plugin, but with a
    /// different container type, must be rejected with the "TypeID does not match" panic.
    #[test]
    #[should_panic(
        expected = "TypeID does not match data plugin type. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_wrong_type() {
        struct MyOtherDataPlugin;

        impl DataPlugin for MyOtherDataPlugin {
            type DataContainer = Vec<u8>;

            fn init<C: PluginContext>(_context: &C) -> Self::DataContainer {
                Vec::new()
            }

            fn index_within_context() -> usize {
                LegitDataPlugin::index_within_context()
            }
        }

        let ctx = Context::new();
        // Request the legitimate plugin's data first, then request the impostor's.
        let _ = ctx.get_data(LegitDataPlugin);
        let data = ctx.get_data(MyOtherDataPlugin);
        println!("{}", data.len());
    }

    /// Twenty threads, each with its own `Context`, are released together by a barrier and
    /// each requests the data of one of four freshly defined plugins. The test passes when
    /// no thread panics.
    #[test]
    fn test_multithreaded_plugin_init() {
        struct DataPluginContainerA;
        define_data_plugin!(DataPluginA, DataPluginContainerA, DataPluginContainerA);
        struct DataPluginContainerB;
        define_data_plugin!(DataPluginB, DataPluginContainerB, DataPluginContainerB);
        struct DataPluginContainerC;
        define_data_plugin!(DataPluginC, DataPluginContainerC, DataPluginContainerC);
        struct DataPluginContainerD;
        define_data_plugin!(DataPluginD, DataPluginContainerD, DataPluginContainerD);

        // One accessor per plugin; threads pick among them round-robin.
        let accessors: Vec<&(dyn Fn(&Context) + Send + Sync)> = vec![
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginA);
            },
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginB);
            },
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginC);
            },
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginD);
            },
        ];

        let thread_count = 20;
        let start_gate = Arc::new(Barrier::new(thread_count));

        // `collect` spawns every worker before any join, which the barrier requires.
        let workers: Vec<_> = (0..thread_count)
            .map(|worker_id| {
                let start_gate = Arc::clone(&start_gate);
                let touch_plugin = accessors[worker_id % accessors.len()];
                thread::spawn(move || {
                    let context = Context::new();
                    start_gate.wait();
                    touch_plugin(&context);
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}