use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use anyhow::Result;
use super::dependency::DependencyManager;
use super::plugin::Plugin;
/// Registry of plugins produced by [`PluginManagerBuilder::build`] (or empty via
/// `PluginManager::new`).
///
/// Every field is behind an `Arc`, so cloning a manager only bumps reference
/// counts; clones share the same maps, lists, and flag.
#[derive(Debug, Clone)]
pub struct PluginManager {
/// Registered plugins keyed by the `name` field of their metadata.
plugins: Arc<HashMap<String, Arc<Plugin>>>,
/// The same plugins, in the order returned by the dependency manager's
/// topological sort during `build`.
sorted_plugins: Arc<Vec<Arc<Plugin>>>,
/// Initialization flag; both constructors in this file set it to `true`.
initialized: Arc<AtomicBool>,
}
/// Accumulates plugins and their dependency declarations, then validates them
/// and produces a [`PluginManager`] via [`PluginManagerBuilder::build`].
pub struct PluginManagerBuilder {
/// Registered plugins keyed by the `name` field of their metadata.
plugins: HashMap<String, Arc<Plugin>>,
/// Dependency graph fed by `register_plugin` and queried by `build` for
/// cycle, missing-dependency, and topological-order checks.
dependency_manager: DependencyManager,
}
impl PluginManagerBuilder {
    /// Creates an empty builder: no plugins and a fresh [`DependencyManager`].
    pub fn new() -> Self {
        Self {
            plugins: HashMap::new(),
            dependency_manager: DependencyManager::new(),
        }
    }

    /// Registers `plugin` and records its declared dependencies with the
    /// dependency manager.
    ///
    /// # Errors
    ///
    /// - A plugin with the same name is already registered.
    /// - `DependencyManager::add_dependency` rejects one of the plugin's
    ///   dependencies.
    ///
    /// In the second case the plugin is not stored in this builder. However,
    /// this method does not roll back the node added by `add_plugin`, nor any
    /// dependency edges added before the failing one.
    /// [`build`](Self::build) detects such leftover names and refuses to build.
    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(self, plugin), fields(
        crate_name = "state",
        plugin_name = %plugin.spec.tr.metadata().name
    )))]
    pub fn register_plugin(&mut self, plugin: Arc<Plugin>) -> Result<()> {
        // Read the metadata once; it provides both the name and the dependencies.
        let metadata = plugin.spec.tr.metadata();
        let plugin_name = metadata.name.clone();
        if self.plugins.contains_key(&plugin_name) {
            return Err(anyhow::anyhow!("插件 '{}' 已存在", plugin_name));
        }
        self.dependency_manager.add_plugin(&metadata.name);
        for dep in &metadata.dependencies {
            self.dependency_manager.add_dependency(&metadata.name, dep)?;
        }
        self.plugins.insert(plugin_name.clone(), plugin);
        tracing::debug!("插件 '{}' 注册成功", plugin_name);
        Ok(())
    }

    /// Validates the registered plugins and produces a [`PluginManager`] whose
    /// sorted list follows the dependency manager's topological order.
    ///
    /// The checks run in this order:
    /// 1. circular dependencies;
    /// 2. missing dependencies;
    /// 3. declared conflicts between registered plugins;
    /// 4. the topological order must name exactly the registered plugins.
    ///
    /// # Errors
    ///
    /// Returns an error when any of the checks above fails, or when
    /// `DependencyManager::get_topological_order` fails.
    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(self), fields(
        crate_name = "state",
        plugin_count = self.plugins.len()
    )))]
    pub fn build(self) -> Result<PluginManager> {
        if self.dependency_manager.has_circular_dependencies() {
            let report = self.dependency_manager.get_circular_dependency_report();
            return Err(anyhow::anyhow!(
                "检测到循环依赖: {}",
                report.to_string()
            ));
        }

        let missing_report = self.dependency_manager.check_missing_dependencies();
        if missing_report.has_missing_dependencies {
            return Err(anyhow::anyhow!(
                "检测到缺失依赖: {}",
                missing_report.to_string()
            ));
        }

        // The plugin map already answers membership queries, so there is no
        // need to copy its keys into a separate set.
        for (name, plugin) in &self.plugins {
            let metadata = plugin.spec.tr.metadata();
            for conflict in &metadata.conflicts {
                if self.plugins.contains_key(conflict) {
                    return Err(anyhow::anyhow!(
                        "插件 '{}' 与插件 '{}' 冲突",
                        name,
                        conflict
                    ));
                }
            }
        }

        let plugin_order = self.dependency_manager.get_topological_order()?;
        let mut sorted_plugins: Vec<Arc<Plugin>> = Vec::with_capacity(self.plugins.len());
        let mut placed: HashSet<&str> = HashSet::with_capacity(self.plugins.len());
        for name in plugin_order.iter() {
            // An ordered name with no plugin is a leftover from a failed
            // `register_plugin`. Previously it was silently dropped. Plugins
            // depending on it had passed the missing-dependency check and
            // would run without it.
            let plugin = self
                .plugins
                .get(name)
                .ok_or_else(|| anyhow::anyhow!("依赖排序中的插件 '{}' 未注册", name))?;
            if placed.insert(name.as_str()) {
                sorted_plugins.push(Arc::clone(plugin));
            }
        }

        // Every registered plugin must appear in the order; otherwise it
        // would be missing from `sorted_plugins` without any signal.
        if let Some(unordered) = self
            .plugins
            .keys()
            .find(|name| !placed.contains(name.as_str()))
        {
            return Err(anyhow::anyhow!(
                "插件 '{}' 未出现在依赖排序结果中",
                unordered
            ));
        }

        tracing::info!(
            "插件管理器构建完成,共注册 {} 个插件",
            self.plugins.len()
        );
        Ok(PluginManager {
            plugins: Arc::new(self.plugins),
            sorted_plugins: Arc::new(sorted_plugins),
            initialized: Arc::new(AtomicBool::new(true)),
        })
    }
}
impl Default for PluginManagerBuilder {
fn default() -> Self {
Self::new()
}
}
impl PluginManager {
pub fn new() -> Self {
Self {
plugins: Arc::new(HashMap::new()),
sorted_plugins: Arc::new(Vec::new()),
initialized: Arc::new(AtomicBool::new(true)),
}
}
#[inline]
pub async fn get_sorted_plugins(&self) -> Vec<Arc<Plugin>> {
self.sorted_plugins.as_ref().clone()
}
#[inline]
pub fn get_sorted_plugins_sync(&self) -> &[Arc<Plugin>] {
self.sorted_plugins.as_ref()
}
#[inline]
pub async fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Acquire)
}
#[inline]
pub fn is_initialized_sync(&self) -> bool {
self.initialized.load(Ordering::Acquire)
}
#[inline]
pub fn plugin_count(&self) -> usize {
self.plugins.len()
}
#[inline]
pub fn get_plugin(
&self,
name: &str,
) -> Option<&Arc<Plugin>> {
self.plugins.get(name)
}
#[inline]
pub fn has_plugin(
&self,
name: &str,
) -> bool {
self.plugins.contains_key(name)
}
}
impl Default for PluginManager {
fn default() -> Self {
Self::new()
}
}