use std::{
ffi::OsStr,
path::{Path, PathBuf},
sync::{Arc, Mutex},
};
use anyhow::Error;
use async_trait::async_trait;
use channel_plugin::plugin_actor::{PluginHandle, spawn_rpc_plugin};
use dashmap::DashMap;
use tracing::{error, info, warn};
use crate::watcher::{DirectoryWatcher, WatchedType};
#[async_trait]
pub trait PluginEventHandler: Send + Sync + 'static {
async fn plugin_added_or_reloaded(&self, name: &str, plugin: PluginHandle)
-> Result<(), Error>;
async fn plugin_removed(&self, name: &str) -> Result<(), Error>;
}
/// Tracks the plugins loaded from a single directory and fans load/unload
/// events out to registered [`PluginEventHandler`]s.
pub struct PluginWatcher {
    /// Directory whose direct children are considered plugin candidates.
    dir: PathBuf,
    /// Loaded plugins, keyed by the name each plugin handle reports.
    pub plugins: DashMap<String, PluginHandle>,
    /// Handlers to notify when a plugin is added, reloaded or removed.
    subscribers: Mutex<Vec<Arc<dyn PluginEventHandler>>>,
    /// Lossy string form of a plugin's file path -> the plugin's name, so a
    /// removal event (which only carries a path) can be mapped to a plugin.
    path_to_name: DashMap<String, String>,
    /// Directory watcher attached via `set_watcher`; `None` until then.
    watcher: Option<DirectoryWatcher>,
}
impl PluginWatcher {
    /// Creates a watcher for `dir` with no plugins, no subscribers and no
    /// attached [`DirectoryWatcher`].
    ///
    /// Nothing is scanned or loaded here; plugins only appear once
    /// `on_create_or_modify` is invoked for a file.
    pub async fn new(dir: PathBuf) -> Self {
        PluginWatcher {
            dir,
            plugins: DashMap::new(),
            subscribers: Mutex::new(Vec::new()),
            path_to_name: DashMap::new(),
            watcher: None,
        }
    }

    /// Starts a [`DirectoryWatcher`] on this watcher's directory, with `self`
    /// as the event receiver.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `DirectoryWatcher::new` reports.
    pub async fn watch(self: Arc<Self>) -> Result<DirectoryWatcher, Error> {
        let dir = self.dir.clone();
        // `self` is already an owned `Arc`, so it can be coerced directly
        // instead of bumping the reference count a second time.
        let watch_me: Arc<dyn WatchedType> = self;
        // NOTE(review): this extension list ("exe", "") differs from the one
        // accepted by `is_relevant` ("exe", "sh", none) — confirm intent.
        DirectoryWatcher::new(dir, watch_me, &["exe", ""], true).await
    }

    /// Stores the directory watcher so that `shutdown` can stop it later.
    pub fn set_watcher(&mut self, watcher: DirectoryWatcher) {
        self.watcher = Some(watcher);
    }

    /// Shuts down the attached directory watcher, if one was set.
    pub async fn shutdown(&self) {
        if let Some(watcher) = self.watcher.clone() {
            watcher.shutdown();
        }
    }

    /// Returns a clone of the handle for the plugin called `name`, or `None`
    /// when no such plugin is loaded.
    pub fn get(&self, name: &str) -> Option<PluginHandle> {
        self.plugins.get(name).map(|entry| entry.value().clone())
    }

    /// Registers `handler` for future plugin events.
    ///
    /// When `notify` is `true` the handler is additionally told about every
    /// plugin that is loaded at the time of the call. Failures reported by the
    /// handler are logged and otherwise ignored.
    pub async fn subscribe(&self, handler: Arc<dyn PluginEventHandler>, notify: bool) {
        // The guard is a temporary of this statement, so the mutex is
        // released before any `.await` below.
        self.subscribers
            .lock()
            .expect("subscriber list mutex poisoned")
            .push(Arc::clone(&handler));

        if !notify {
            return;
        }

        // Copy the entries out first: awaiting while iterating would keep a
        // DashMap shard lock held across the await point, which can deadlock
        // if the handler (or any task it waits on) touches `plugins`.
        let loaded: Vec<(String, PluginHandle)> = self
            .plugins
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().clone()))
            .collect();

        for (name, plugin) in loaded {
            if let Err(err) = handler.plugin_added_or_reloaded(&name, plugin).await {
                warn!("Could not load plugin {}: {:?}", name, err);
            }
        }
    }

    /// Clones the current subscriber list so handlers can be awaited without
    /// holding the (non-async) mutex.
    fn subscribers_snapshot(&self) -> Vec<Arc<dyn PluginEventHandler>> {
        self.subscribers
            .lock()
            .expect("subscriber list mutex poisoned")
            .clone()
    }

    /// Tells every subscriber that `plugin` was loaded or reloaded under
    /// `name`, logging the outcome per subscriber.
    async fn notify_add_or_reload(&self, name: &str, plugin: &PluginHandle) {
        for sub in self.subscribers_snapshot() {
            match sub.plugin_added_or_reloaded(name, plugin.clone()).await {
                Ok(()) => {
                    info!("Loaded plugin: {}", name);
                }
                Err(err) => {
                    // Include the cause; a bare "could not reload" is not
                    // actionable.
                    warn!("Could not reload plugin {}: {:?}", name, err);
                }
            }
        }
    }

    /// Tells every subscriber that the plugin called `name` was removed,
    /// logging the outcome per subscriber.
    async fn notify_removal(&self, name: &str) {
        for sub in self.subscribers_snapshot() {
            match sub.plugin_removed(name).await {
                Ok(()) => {
                    info!("Removed plugin: {}", name);
                }
                Err(err) => {
                    warn!("Could not remove plugin {}: {:?}", name, err);
                }
            }
        }
    }
}
#[async_trait]
impl crate::watcher::WatchedType for PluginWatcher {
    /// Returns `true` for paths that sit directly inside the watched
    /// directory and whose extension is `exe`, `sh`, or absent.
    ///
    /// An extension that is not valid UTF-8 is treated like an absent one.
    fn is_relevant(&self, path: &Path) -> bool {
        // NOTE(review): `watch` passes the extension list ["exe", ""] to the
        // directory watcher while this accepts "sh" as well — confirm which
        // list is intended.
        let in_watched_dir = path.parent() == Some(self.dir.as_path());
        in_watched_dir
            && match path.extension().and_then(OsStr::to_str) {
                Some(ext) => matches!(ext, "exe" | "sh"),
                None => true,
            }
    }

    /// Spawns the plugin at `path`, registers it under the name it reports
    /// and notifies all subscribers.
    ///
    /// A plugin that fails to spawn is logged and skipped; this method still
    /// returns `Ok(())` in that case.
    async fn on_create_or_modify(&self, path: &Path) -> anyhow::Result<()> {
        let handle = match spawn_rpc_plugin(path).await {
            Ok(handle) => handle,
            Err(err) => {
                error!("Could not load {:?} because {:?}", path, err);
                return Ok(());
            }
        };

        let name = handle.name();
        self.plugins.insert(name.to_string(), handle.clone());
        // Remember which file produced this plugin so that a later removal
        // event, which only carries the path, can find it again.
        self.path_to_name
            .insert(path.to_string_lossy().into_owned(), name.to_string());
        self.notify_add_or_reload(&name, &handle).await;
        Ok(())
    }

    /// Unloads the plugin that was loaded from `path`, if any, and notifies
    /// all subscribers. Paths that were never loaded are ignored.
    async fn on_remove(&self, path: &Path) -> anyhow::Result<()> {
        let key = path.to_string_lossy().into_owned();
        // `remove` returns the owned entry, which fixes two problems of a
        // plain `get`: the path -> name mapping no longer goes stale, and no
        // DashMap guard is kept alive across the `.await` below.
        let removed = self.path_to_name.remove(&key);
        if let Some((_, plugin_name)) = removed {
            self.plugins.remove(&plugin_name);
            info!("Unloading plugin `{}`", plugin_name);
            self.notify_removal(&plugin_name).await;
        }
        Ok(())
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::watcher::WatchedType;
    use channel_plugin::plugin_test_util::spawn_mock_handle;
    use std::{
        fs::{self, File},
        path::PathBuf,
    };
    use tempfile::TempDir;

    /// The mock handle reports `mock` both as its name and in its
    /// capabilities.
    #[tokio::test]
    async fn test_mock_channel() {
        let (_mock, handle) = spawn_mock_handle().await;

        assert_eq!(handle.name(), "mock");

        let capabilities = handle.capabilities();
        assert_eq!(capabilities.name, "mock");
    }

    /// Only extension-less / executable files directly inside the watched
    /// directory are relevant.
    #[tokio::test(flavor = "current_thread")]
    async fn is_relevant_only_dylibs_in_dir() {
        let tmp = TempDir::new().unwrap();
        let plugin_dir = tmp.path().to_path_buf();
        let watcher = PluginWatcher::new(plugin_dir.clone()).await;

        let cases = [
            // No extension, inside the watched directory.
            (plugin_dir.join("plugin1"), true),
            // Inside the directory but with an unsupported extension.
            (plugin_dir.join("not_a_plugin.txt"), false),
            // Outside the watched directory altogether.
            (PathBuf::from("/other/plugin2.dll"), false),
        ];

        for (candidate, expected) in cases.iter() {
            assert_eq!(watcher.is_relevant(candidate), *expected);
        }
    }

    /// Constructing a watcher does not load anything from the directory.
    #[tokio::test(flavor = "current_thread")]
    async fn new_skips_invalid_plugins() {
        let tmp = TempDir::new().unwrap();
        let plugin_dir = tmp.path().to_path_buf();

        let dll_path = plugin_dir.join("a.dll");
        File::create(&dll_path).unwrap();
        let txt_path = plugin_dir.join("b.txt");
        File::create(&txt_path).unwrap();

        let watcher = PluginWatcher::new(plugin_dir).await;

        assert!(watcher.plugins.is_empty());
    }

    /// An empty file is not a valid plugin: the call succeeds but nothing is
    /// registered.
    #[tokio::test(flavor = "current_thread")]
    async fn on_create_or_modify_loads_new_plugin() {
        let tmp = TempDir::new().unwrap();
        let plugin_dir = tmp.path().to_path_buf();
        let watcher = PluginWatcher::new(plugin_dir.clone()).await;

        let exe_path = plugin_dir.join("new.exe");
        File::create(&exe_path).unwrap();

        watcher.on_create_or_modify(&exe_path).await.unwrap();

        assert!(watcher.plugins.is_empty());
    }

    /// Removing an unknown path is a no-op; removing a known path unloads the
    /// plugin registered for it.
    #[tokio::test(flavor = "current_thread")]
    async fn on_remove_unloads_plugin_safely() {
        let tmp = TempDir::new().unwrap();
        let plugin_dir = tmp.path().to_path_buf();
        let watcher = PluginWatcher::new(plugin_dir.clone()).await;

        let dummy_path = plugin_dir.join("dummy.exe");
        File::create(&dummy_path).unwrap();

        // Register a mock plugin by hand; the mock itself goes out of scope
        // at the end of this block, before any removal happens.
        {
            let (_mock, mock_handle) = spawn_mock_handle().await;
            watcher.plugins.insert("dummy".into(), mock_handle);
            watcher.path_to_name.insert(
                dummy_path.to_string_lossy().into_owned(),
                "dummy".to_string(),
            );
        }

        // A path that was never loaded must be ignored without error.
        let unknown_path = plugin_dir.join("unknown.exe");
        watcher.on_remove(&unknown_path).await.unwrap();

        // The registered path unloads the plugin.
        let known_path = plugin_dir.join("dummy.exe");
        let _ = fs::File::create(&known_path);
        watcher.on_remove(&known_path).await.unwrap();

        assert!(watcher.plugins.is_empty());
    }
}