#![warn(missing_docs)]
#![warn(clippy::missing_errors_doc)]
#![warn(clippy::missing_panics_doc)]
#![warn(clippy::missing_safety_doc)]
#![warn(clippy::panic)]
#![warn(clippy::todo)]
#![warn(clippy::pedantic)]
#![warn(clippy::all)]
#![warn(clippy::empty_docs)]
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub mod error;
pub mod hook;
pub mod macros;
use tracing::{error, warn};
use self::error::{PluginError, PluginResult};
use self::hook::{ExtensionPoint, HookRegistry};
/// Identifier of a plugin: a string slice with `'static` lifetime.
///
/// [`PluginManager`] uses values of this type as the keys under which it stores loaded plugins.
pub type PluginID = &'static str;
/// Named wrapper type around a [`PluginID`].
///
/// A value is created from a [`PluginID`] through [`From`], and the wrapped ID is read back
/// with [`PluginIDOwned::id`].
///
#[cfg_attr(feature = "serde", doc = concat!(
r"# Serialization
When the `serde` feature is enabled, this type implements [`Serialize`]
and [`Deserialize`]. Note that deserialization involves a memory leak,
as the string is converted to a `&'static str` by leaking memory (to make sure it is always
existing in memory).
The leaking is using safe rust with [`String::leak`]."
))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
pub struct PluginIDOwned {
/// The wrapped plugin ID.
inner: &'static str,
}
impl PluginIDOwned {
    /// Returns the wrapped [`PluginID`].
    #[inline]
    #[must_use]
    pub fn id(&self) -> PluginID {
        // The field is a `Copy` reference, so it can be copied out of the borrowed wrapper.
        let Self { inner } = *self;
        inner
    }
}
impl From<PluginID> for PluginIDOwned {
fn from(value: PluginID) -> Self {
Self { inner: value }
}
}
impl From<&PluginIDOwned> for PluginID {
    /// Reads the wrapped ID out of a borrowed [`PluginIDOwned`].
    fn from(value: &PluginIDOwned) -> Self {
        value.inner
    }
}
impl From<PluginIDOwned> for PluginID {
fn from(value: PluginIDOwned) -> Self {
value.id()
}
}
impl std::fmt::Display for PluginIDOwned {
    /// Formats the wrapped ID exactly like the underlying string slice, so width, fill and
    /// precision given in the format string are handled by `str`'s own formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = self.inner;
        <str as std::fmt::Display>::fmt(text, f)
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for PluginIDOwned {
    /// Deserializes a string into a [`PluginIDOwned`].
    ///
    /// The visited string is copied into a fresh `String`, which is then leaked with
    /// [`String::leak`] to obtain the `&'static str` this type stores. The leaked memory
    /// is never handed back by this code.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that accepts a string slice and produces a [`PluginIDOwned`] from it.
        struct LeakingStrVisitor;

        impl serde::de::Visitor<'_> for LeakingStrVisitor {
            type Value = PluginIDOwned;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a string")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Copy the borrowed input, then leak the copy to get a `'static` slice.
                let leaked: &'static str = String::from(value).leak();
                Ok(PluginIDOwned { inner: leaked })
            }
        }

        deserializer.deserialize_str(LeakingStrVisitor)
    }
}
/// Interface of a plugin that can be handled by a [`PluginManager`].
///
/// Implementors have to be [`Any`] (and therefore `'static`), [`Send`], [`Sync`] and [`Debug`].
pub trait Plugin: Any + Send + Sync + Debug {
/// Returns the ID of this plugin.
///
/// [`PluginManager`] stores a loaded plugin under the ID returned here.
fn id(&self) -> PluginID;
/// Returns a description of this plugin.
fn description(&self) -> &str;
/// Reports whether this plugin is currently enabled.
///
/// [`PluginManager`] consults this when it collects enabled plugins and enabled hooks.
fn is_enabled(&self) -> bool;
/// Enables this plugin.
fn enable(&mut self);
/// Disables this plugin.
fn disable(&mut self);
/// Registers the hooks of this plugin in `registry`.
///
/// [`PluginManager::load_plugin`] calls this with its own registry before [`Plugin::on_load`].
///
/// # Errors
/// Returns a [`PluginError`] chosen by the implementation when registration fails.
fn register_hooks(&self, registry: &mut HookRegistry) -> PluginResult<()>;
/// Called by [`PluginManager::load_plugin`] after [`Plugin::register_hooks`] succeeded.
///
/// The default implementation does nothing and returns `Ok(())`.
///
/// # Errors
/// Returns a [`PluginError`] chosen by the implementation; the default never fails.
fn on_load(&mut self) -> PluginResult<()> {
Ok(())
}
/// Called by [`PluginManager::unload_plugin`] on a plugin it has removed from its storage.
///
/// The default implementation does nothing and returns `Ok(())`.
///
/// # Errors
/// Returns a [`PluginError`] chosen by the implementation; the default never fails.
fn on_unload(&mut self) -> PluginResult<()> {
Ok(())
}
}
/// Owns the loaded plugins together with the [`HookRegistry`] they register their hooks in.
#[derive(Debug, Default)]
pub struct PluginManager {
/// Loaded plugins, keyed by the ID each plugin returns from [`Plugin::id`].
plugins: HashMap<PluginID, Box<dyn Plugin>>,
/// Registry that is handed to [`Plugin::register_hooks`] when a plugin is loaded.
hook_registry: HookRegistry,
}
impl PluginManager {
#[must_use]
pub fn new() -> Self {
Self {
plugins: HashMap::new(),
hook_registry: HookRegistry::new(),
}
}
#[must_use]
pub fn with_registry(hook_registry: HookRegistry) -> Self {
Self {
plugins: HashMap::new(),
hook_registry,
}
}
#[must_use]
pub fn hook_registry(&self) -> &HookRegistry {
&self.hook_registry
}
#[must_use]
pub fn hook_registry_mut(&mut self) -> &mut HookRegistry {
&mut self.hook_registry
}
pub fn load_plugin(&mut self, mut plugin: Box<dyn Plugin>) -> PluginResult<()> {
let id = plugin.id();
if self.plugins.contains_key(id) {
return Err(error::PluginError::AlreadyLoaded(id.into()));
}
if let Err(e) = plugin.register_hooks(self.hook_registry_mut()) {
self.handle_error_during_load(&e, id);
return Err(e);
}
if let Err(e) = plugin.on_load() {
self.handle_error_during_load(&e, id);
return Err(e);
}
self.plugins.insert(id, plugin);
Ok(())
}
fn handle_error_during_load(&mut self, e: &PluginError, plugin_id: PluginID) {
error!("Could not register hooks of plugin {plugin_id}: {e}");
warn!("Trying to unload the plugin again... Will crash if this fails");
self.unload_plugin(plugin_id)
.expect("Could not unload bad plugin again");
}
pub fn unload_plugin(&mut self, id: PluginID) -> PluginResult<()> {
if let Some(mut plugin) = self.plugins.remove(id) {
plugin.on_unload()?;
self.hook_registry.deregister_hooks_for_plugin(id);
}
Ok(())
}
#[must_use]
pub fn get_plugin(&self, id: PluginID) -> Option<&dyn Plugin> {
self.plugins.get(id).map(std::convert::AsRef::as_ref)
}
#[must_use]
pub fn get_plugin_mut(&mut self, id: PluginID) -> Option<&mut dyn Plugin> {
self.plugins.get_mut(id).map(std::convert::AsMut::as_mut)
}
#[must_use]
pub fn plugin_ids(&self) -> Vec<PluginID> {
self.plugins.keys().copied().collect()
}
#[must_use]
pub fn plugins(&self) -> Vec<&dyn Plugin> {
self.plugins
.values()
.map(std::convert::AsRef::as_ref)
.collect()
}
#[must_use]
pub fn enabled_plugins(&self) -> Vec<&dyn Plugin> {
self.plugins
.values()
.filter(|p| p.is_enabled())
.map(std::convert::AsRef::as_ref)
.collect()
}
#[inline]
#[must_use]
pub fn plugin_is_enabled(&self, id: PluginID) -> Option<bool> {
Some(self.plugins.get(id)?.is_enabled())
}
pub fn enable_plugin(&mut self, id: PluginID) -> PluginResult<()> {
match self.plugins.get_mut(id) {
Some(plugin) => {
plugin.enable();
Ok(())
}
None => Err(error::PluginError::NotFound(id.into())),
}
}
pub fn disable_plugin(&mut self, id: PluginID) -> PluginResult<()> {
match self.plugins.get_mut(id) {
Some(plugin) => {
plugin.disable();
Ok(())
}
None => Err(error::PluginError::NotFound(id.into())),
}
}
#[must_use]
pub fn get_enabled_hooks_by_ep<E: ExtensionPoint>(
&self,
) -> Vec<(&hook::HookID, &hook::Hook<E>)> {
self.hook_registry()
.get_by_extension_point()
.into_iter()
.filter(|(id, _hook)| {
if let Some(plugin) = self.plugins.get(id.plugin_id) {
plugin.is_enabled()
} else {
false
}
})
.collect()
}
#[must_use]
pub fn get_enabled_hooks_by_ep_mut<E: ExtensionPoint>(
&mut self,
) -> Vec<(&hook::HookID, &mut hook::Hook<E>)> {
let enabled_ids: Vec<PluginID> = self
.plugins
.iter()
.filter_map(|(id, plug)| if plug.is_enabled() { Some(*id) } else { None })
.collect();
self.hook_registry_mut()
.get_by_extension_point_mut()
.into_iter()
.filter(|(id, _hook)| enabled_ids.contains(&id.plugin_id))
.collect()
}
}
// The test relies on the `Serialize`/`Deserialize` implementations of `PluginIDOwned`, which
// only exist with the `serde` feature. Without this gate a test build that has the feature
// disabled fails to compile.
#[cfg(all(test, feature = "serde"))]
mod test {
    use super::*;

    /// Serializes a wrapped ID to JSON and deserializes one back, checking that both
    /// directions use the plain JSON string representation.
    #[test]
    fn test_ser_dser_pluginid() {
        let some_id: PluginID = "foo";
        let oid = PluginIDOwned::from(some_id);
        let serial = serde_json::to_string(&oid).unwrap();
        assert_eq!(serial, r#""foo""#);

        let raw = r#""myid""#;
        let oid: PluginIDOwned = serde_json::from_str(raw).unwrap();
        let id = oid.id();
        let serial: String = serde_json::to_string(&id).unwrap();
        assert_eq!(raw, format!(r#""{id}""#));
        assert_eq!(serial, raw);
    }
}