use async_trait::async_trait;
use crate::error::Result;
/// A component whose start/stop lifecycle is driven by [`PluginManager`].
///
/// Implementors must be `Send + Sync`; the manager stores them as
/// `Box<dyn Plugin>`, so the trait has to stay usable as a trait object.
#[async_trait]
pub trait Plugin: Send + Sync {
    /// Name used in log messages and matched against other plugins'
    /// [`depends_on`](Plugin::depends_on) lists.
    fn name(&self) -> &str;

    /// Names of the plugins this one depends on; the manager starts those first.
    /// Names that match no registered plugin are not treated as errors.
    ///
    /// Defaults to an empty list (no dependencies).
    fn depends_on(&self) -> &[&str] {
        &[]
    }

    /// Starts the plugin. An error aborts `PluginManager::start_all`.
    async fn start(&self) -> Result<()>;

    /// Stops the plugin. `PluginManager::stop_all` logs an error here and keeps
    /// stopping the remaining plugins.
    async fn stop(&self) -> Result<()>;
}
use std::collections::{HashMap, HashSet, VecDeque};
use tracing::{info, error};
/// Owns the registered plugins and drives their start/stop lifecycle.
///
/// Build it with [`PluginManager::new`] followed by the chaining `add` /
/// `add_discovered` methods.
pub struct PluginManager {
    /// Plugins in the order they were registered.
    plugins: Vec<Box<dyn Plugin>>,
}
impl PluginManager {
pub fn new() -> Self {
Self { plugins: Vec::new() }
}
pub fn add<P: Plugin + 'static>(mut self, plugin: P) -> Self {
self.plugins.push(Box::new(plugin));
self
}
pub fn add_discovered(mut self, plugins: Vec<Box<dyn Plugin>>) -> Self {
self.plugins.extend(plugins);
self
}
pub async fn start_all(&self) -> Result<()> {
let ordered = self.topological_sort()?;
for plugin in ordered {
info!("启动插件: {}", plugin.name());
if let Err(e) = plugin.start().await {
error!("插件 {} 启动失败: {}", plugin.name(), e);
return Err(e);
}
}
Ok(())
}
pub async fn stop_all(&self) {
for plugin in self.plugins.iter().rev() {
info!("停止插件: {}", plugin.name());
if let Err(e) = plugin.stop().await {
error!("插件 {} 停止异常: {}", plugin.name(), e);
}
}
}
fn topological_sort(&self) -> Result<Vec<&Box<dyn Plugin>>> {
let name_to_idx: HashMap<&str, usize> = self.plugins
.iter()
.enumerate()
.map(|(i, p)| (p.name(), i))
.collect();
let count = self.plugins.len();
let mut in_degree = vec![0usize; count];
let mut adj = vec![Vec::new(); count];
for (i, plugin) in self.plugins.iter().enumerate() {
for dep in plugin.depends_on() {
if let Some(&j) = name_to_idx.get(dep) {
adj[j].push(i);
in_degree[i] += 1;
}
}
}
let mut queue: VecDeque<usize> = (0..count)
.filter(|&i| in_degree[i] == 0)
.collect();
let mut sorted = Vec::with_capacity(count);
while let Some(u) = queue.pop_front() {
sorted.push(&self.plugins[u]);
for &v in &adj[u] {
in_degree[v] -= 1;
if in_degree[v] == 0 {
queue.push_back(v);
}
}
}
if sorted.len() != count {
let cycle: Vec<&str> = self.plugins.iter()
.enumerate()
.filter(|(i, _)| in_degree[*i] > 0)
.map(|(_, p)| p.name())
.collect();
return Err(crate::error::Error::Config(
format!("插件循环依赖: {:?}", cycle)
));
}
Ok(sorted)
}
pub fn check_duplicate_names(&self) -> std::result::Result<(), String> {
let mut seen = HashSet::new();
for p in &self.plugins {
if !seen.insert(p.name()) {
return Err(format!("插件名重复: {}", p.name()));
}
}
Ok(())
}
}
impl Default for PluginManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use parking_lot::Mutex;
    use std::sync::Arc;

    /// Shared, ordered log of `start:<name>` / `stop:<name>` events.
    type EventLog = Arc<Mutex<Vec<String>>>;

    /// Stub plugin that only records its lifecycle calls into a shared log.
    struct TestPlugin {
        name: &'static str,
        deps: &'static [&'static str],
        order: EventLog,
    }

    impl TestPlugin {
        /// Builds a stub named `name`, depending on `deps`, writing into `log`.
        fn new(name: &'static str, deps: &'static [&'static str], log: &EventLog) -> Self {
            TestPlugin {
                name,
                deps,
                order: Arc::clone(log),
            }
        }
    }

    #[async_trait]
    impl Plugin for TestPlugin {
        fn name(&self) -> &str {
            self.name
        }

        fn depends_on(&self) -> &[&str] {
            self.deps
        }

        async fn start(&self) -> Result<()> {
            let entry = format!("start:{}", self.name);
            self.order.lock().push(entry);
            Ok(())
        }

        async fn stop(&self) -> Result<()> {
            let entry = format!("stop:{}", self.name);
            self.order.lock().push(entry);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_topological_start() {
        let log: EventLog = Arc::new(Mutex::new(Vec::new()));
        // Registered in reverse dependency order on purpose: c -> {a, b}, b -> a.
        let manager = PluginManager::new()
            .add(TestPlugin::new("c", &["a", "b"], &log))
            .add(TestPlugin::new("b", &["a"], &log))
            .add(TestPlugin::new("a", &[], &log));

        manager.start_all().await.unwrap();

        let events = log.lock();
        let expected = ["start:a", "start:b", "start:c"];
        for (pos, want) in expected.iter().enumerate() {
            assert_eq!(events[pos], *want);
        }
    }

    #[tokio::test]
    async fn test_cycle_detection() {
        let log: EventLog = Arc::new(Mutex::new(Vec::new()));
        // x and y depend on each other, so no valid start order exists.
        let manager = PluginManager::new()
            .add(TestPlugin::new("x", &["y"], &log))
            .add(TestPlugin::new("y", &["x"], &log));

        let outcome = manager.start_all().await;
        assert!(outcome.is_err());
    }
}