Skip to main content

aster/plugins/
manager.rs

1//! 插件管理器
2//!
3//! 负责插件的发现、加载、卸载、依赖管理等
4
5use super::registry::PluginRegistry;
6use super::types::*;
7use super::version::VersionChecker;
8use std::collections::{HashMap, HashSet};
9use std::path::PathBuf;
10use std::sync::{Arc, RwLock};
11use std::time::{SystemTime, UNIX_EPOCH};
12use tokio::sync::broadcast;
13
/// Lifecycle notifications broadcast by the plugin manager to every subscriber.
#[derive(Clone, Debug)]
pub enum PluginEvent {
    /// A plugin finished loading; carries the plugin name.
    Loaded(String),
    /// A plugin was unloaded; carries the plugin name.
    Unloaded(String),
    /// A plugin was unloaded and then loaded again; carries the plugin name.
    Reloaded(String),
    /// A plugin-related failure. Presumably `(plugin name, error message)` —
    /// nothing in this module constructs it yet, so confirm with the emitter.
    Error(String, String),
}
22
23/// 插件管理器
24pub struct PluginManager {
25    /// 插件状态
26    plugin_states: Arc<RwLock<HashMap<String, PluginState>>>,
27    /// 插件配置
28    plugin_configs: Arc<RwLock<HashMap<String, PluginConfig>>>,
29    /// 插件目录
30    plugin_dirs: Vec<PathBuf>,
31    /// 配置目录
32    config_dir: PathBuf,
33    /// Aster 版本
34    aster_version: String,
35    /// 注册表
36    registry: Arc<PluginRegistry>,
37    /// 事件发送器
38    event_tx: broadcast::Sender<PluginEvent>,
39}
40
41impl PluginManager {
42    /// 创建新的插件管理器
43    pub fn new(aster_version: &str) -> Self {
44        let config_dir = dirs::home_dir()
45            .unwrap_or_else(|| PathBuf::from("~"))
46            .join(".aster");
47
48        let plugin_dirs = vec![
49            config_dir.join("plugins"),
50            std::env::current_dir()
51                .unwrap_or_default()
52                .join(".aster")
53                .join("plugins"),
54        ];
55
56        let (event_tx, _) = broadcast::channel(100);
57
58        Self {
59            plugin_states: Arc::new(RwLock::new(HashMap::new())),
60            plugin_configs: Arc::new(RwLock::new(HashMap::new())),
61            plugin_dirs,
62            config_dir,
63            aster_version: aster_version.to_string(),
64            registry: Arc::new(PluginRegistry::new()),
65            event_tx,
66        }
67    }
68
69    /// 订阅事件
70    pub fn subscribe(&self) -> broadcast::Receiver<PluginEvent> {
71        self.event_tx.subscribe()
72    }
73
74    /// 获取注册表
75    pub fn registry(&self) -> Arc<PluginRegistry> {
76        Arc::clone(&self.registry)
77    }
78
79    /// 添加插件目录
80    pub fn add_plugin_dir(&mut self, dir: PathBuf) {
81        if !self.plugin_dirs.contains(&dir) {
82            self.plugin_dirs.push(dir);
83        }
84    }
85
86    /// 发现所有插件
87    pub async fn discover(&self) -> Vec<PluginState> {
88        let mut discovered = Vec::new();
89
90        for dir in &self.plugin_dirs {
91            if !dir.exists() {
92                continue;
93            }
94
95            let entries = match tokio::fs::read_dir(dir).await {
96                Ok(e) => e,
97                Err(_) => continue,
98            };
99
100            let mut entries = entries;
101            while let Ok(Some(entry)) = entries.next_entry().await {
102                let path = entry.path();
103                if !path.is_dir() {
104                    continue;
105                }
106
107                let package_path = path.join("package.json");
108                if !package_path.exists() {
109                    continue;
110                }
111
112                if let Ok(content) = tokio::fs::read_to_string(&package_path).await {
113                    if let Ok(metadata) = serde_json::from_str::<PluginMetadata>(&content) {
114                        let state = PluginState {
115                            metadata: metadata.clone(),
116                            path: path.clone(),
117                            enabled: true,
118                            loaded: false,
119                            initialized: false,
120                            activated: false,
121                            error: None,
122                            load_time: None,
123                            dependencies: Vec::new(),
124                            dependents: Vec::new(),
125                        };
126
127                        if let Ok(mut states) = self.plugin_states.write() {
128                            states.insert(metadata.name.clone(), state.clone());
129                        }
130                        discovered.push(state);
131                    }
132                }
133            }
134        }
135
136        // 解析依赖关系
137        self.resolve_dependencies();
138
139        discovered
140    }
141
142    /// 解析插件依赖关系
143    fn resolve_dependencies(&self) {
144        let mut states = match self.plugin_states.write() {
145            Ok(s) => s,
146            Err(_) => return,
147        };
148
149        // 收集所有插件名
150        let plugin_names: HashSet<String> = states.keys().cloned().collect();
151
152        // 解析依赖
153        for state in states.values_mut() {
154            state.dependencies.clear();
155            state.dependents.clear();
156
157            if let Some(deps) = &state.metadata.dependencies {
158                for dep_name in deps.keys() {
159                    if plugin_names.contains(dep_name) {
160                        state.dependencies.push(dep_name.clone());
161                    }
162                }
163            }
164        }
165
166        // 构建反向依赖
167        let deps_map: HashMap<String, Vec<String>> = states
168            .iter()
169            .map(|(name, state)| (name.clone(), state.dependencies.clone()))
170            .collect();
171
172        for (name, deps) in deps_map {
173            for dep_name in deps {
174                if let Some(dep_state) = states.get_mut(&dep_name) {
175                    if !dep_state.dependents.contains(&name) {
176                        dep_state.dependents.push(name.clone());
177                    }
178                }
179            }
180        }
181    }
182
183    /// 检查引擎兼容性
184    fn check_engine_compatibility(&self, metadata: &PluginMetadata) -> bool {
185        if let Some(engines) = &metadata.engines {
186            if let Some(aster_req) = &engines.aster {
187                if !VersionChecker::satisfies(&self.aster_version, aster_req) {
188                    return false;
189                }
190            }
191        }
192        true
193    }
194
195    /// 检查依赖是否满足
196    fn check_dependencies(&self, name: &str) -> Result<(), String> {
197        let states = self.plugin_states.read().map_err(|e| e.to_string())?;
198
199        let state = states
200            .get(name)
201            .ok_or_else(|| format!("Plugin not found: {}", name))?;
202
203        if let Some(deps) = &state.metadata.dependencies {
204            for (dep_name, version_range) in deps {
205                let dep_state = states.get(dep_name);
206
207                match dep_state {
208                    None => {
209                        return Err(format!(
210                            "Dependency not found: {}@{}",
211                            dep_name, version_range
212                        ));
213                    }
214                    Some(dep) if !dep.loaded => {
215                        return Err(format!(
216                            "Dependency not loaded: {}@{}",
217                            dep_name, version_range
218                        ));
219                    }
220                    Some(dep) => {
221                        if !VersionChecker::satisfies(&dep.metadata.version, version_range) {
222                            return Err(format!(
223                                "Dependency version mismatch: {} requires {}@{}, found {}",
224                                name, dep_name, version_range, dep.metadata.version
225                            ));
226                        }
227                    }
228                }
229            }
230        }
231
232        Ok(())
233    }
234
235    /// 加载插件
236    pub async fn load(&self, name: &str) -> Result<(), String> {
237        // 获取插件状态
238        let state = {
239            let states = self.plugin_states.read().map_err(|e| e.to_string())?;
240            states
241                .get(name)
242                .cloned()
243                .ok_or_else(|| format!("Plugin not found: {}", name))?
244        };
245
246        if state.loaded {
247            return Ok(());
248        }
249
250        // 检查引擎兼容性
251        if !self.check_engine_compatibility(&state.metadata) {
252            return Err(format!(
253                "Plugin {} is not compatible with Aster {}",
254                name, self.aster_version
255            ));
256        }
257
258        // 先加载依赖
259        for dep_name in &state.dependencies {
260            Box::pin(self.load(dep_name)).await?;
261        }
262
263        // 检查依赖版本
264        self.check_dependencies(name)?;
265
266        // 更新状态
267        let now = SystemTime::now()
268            .duration_since(UNIX_EPOCH)
269            .map(|d| d.as_secs())
270            .unwrap_or(0);
271
272        {
273            let mut states = self.plugin_states.write().map_err(|e| e.to_string())?;
274            if let Some(s) = states.get_mut(name) {
275                s.loaded = true;
276                s.initialized = true;
277                s.activated = true;
278                s.load_time = Some(now);
279                s.error = None;
280            }
281        }
282
283        let _ = self.event_tx.send(PluginEvent::Loaded(name.to_string()));
284        Ok(())
285    }
286
287    /// 卸载插件
288    pub async fn unload(&self, name: &str, force: bool) -> Result<(), String> {
289        let state = {
290            let states = self.plugin_states.read().map_err(|e| e.to_string())?;
291            states
292                .get(name)
293                .cloned()
294                .ok_or_else(|| format!("Plugin not found: {}", name))?
295        };
296
297        if !state.loaded {
298            return Ok(());
299        }
300
301        // 检查是否有其他插件依赖此插件
302        if !force && !state.dependents.is_empty() {
303            let loaded_dependents: Vec<_> = {
304                let states = self.plugin_states.read().map_err(|e| e.to_string())?;
305                state
306                    .dependents
307                    .iter()
308                    .filter(|dep| states.get(*dep).map(|s| s.loaded).unwrap_or(false))
309                    .cloned()
310                    .collect()
311            };
312
313            if !loaded_dependents.is_empty() {
314                return Err(format!(
315                    "Cannot unload {}: required by {}",
316                    name,
317                    loaded_dependents.join(", ")
318                ));
319            }
320        }
321
322        // 清理注册表
323        self.registry.clear_plugin(name);
324
325        // 更新状态
326        {
327            let mut states = self.plugin_states.write().map_err(|e| e.to_string())?;
328            if let Some(s) = states.get_mut(name) {
329                s.loaded = false;
330                s.initialized = false;
331                s.activated = false;
332            }
333        }
334
335        let _ = self.event_tx.send(PluginEvent::Unloaded(name.to_string()));
336        Ok(())
337    }
338
339    /// 重载插件
340    pub async fn reload(&self, name: &str) -> Result<(), String> {
341        self.unload(name, false).await?;
342        self.load(name).await?;
343        let _ = self.event_tx.send(PluginEvent::Reloaded(name.to_string()));
344        Ok(())
345    }
346
347    /// 按拓扑顺序加载所有插件
348    pub async fn load_all(&self) -> Result<(), String> {
349        let names: Vec<String> = {
350            let states = self.plugin_states.read().map_err(|e| e.to_string())?;
351            states
352                .iter()
353                .filter(|(_, s)| s.enabled)
354                .map(|(name, _)| name.clone())
355                .collect()
356        };
357
358        // 拓扑排序加载
359        let mut loaded = HashSet::new();
360        let mut loading = HashSet::new();
361
362        for name in names {
363            Box::pin(self.load_with_deps(&name, &mut loaded, &mut loading)).await?;
364        }
365
366        Ok(())
367    }
368
369    /// 带依赖检查的加载
370    async fn load_with_deps(
371        &self,
372        name: &str,
373        loaded: &mut HashSet<String>,
374        loading: &mut HashSet<String>,
375    ) -> Result<(), String> {
376        if loaded.contains(name) {
377            return Ok(());
378        }
379
380        if loading.contains(name) {
381            return Err(format!("Circular dependency detected: {}", name));
382        }
383
384        loading.insert(name.to_string());
385
386        // 获取依赖
387        let deps = {
388            let states = self.plugin_states.read().map_err(|e| e.to_string())?;
389            states
390                .get(name)
391                .map(|s| s.dependencies.clone())
392                .unwrap_or_default()
393        };
394
395        // 先加载依赖
396        for dep in deps {
397            Box::pin(self.load_with_deps(&dep, loaded, loading)).await?;
398        }
399
400        // 加载自己
401        self.load(name).await?;
402        loaded.insert(name.to_string());
403        loading.remove(name);
404
405        Ok(())
406    }
407
408    /// 卸载所有插件(反向拓扑顺序)
409    pub async fn unload_all(&self) -> Result<(), String> {
410        let names: Vec<String> = {
411            let states = self.plugin_states.read().map_err(|e| e.to_string())?;
412            states
413                .iter()
414                .filter(|(_, s)| s.loaded)
415                .map(|(name, _)| name.clone())
416                .collect()
417        };
418
419        for name in names {
420            self.unload(&name, true).await?;
421        }
422
423        Ok(())
424    }
425
426    /// 获取插件状态
427    pub fn get_plugin_state(&self, name: &str) -> Option<PluginState> {
428        self.plugin_states.read().ok()?.get(name).cloned()
429    }
430
431    /// 获取所有插件状态
432    pub fn get_plugin_states(&self) -> Vec<PluginState> {
433        self.plugin_states
434            .read()
435            .map(|s| s.values().cloned().collect())
436            .unwrap_or_default()
437    }
438
439    /// 设置插件启用状态
440    pub fn set_enabled(&self, name: &str, enabled: bool) -> bool {
441        if let Ok(mut states) = self.plugin_states.write() {
442            if let Some(state) = states.get_mut(name) {
443                state.enabled = enabled;
444                return true;
445            }
446        }
447        false
448    }
449
450    /// 获取已加载的插件数量
451    pub fn loaded_count(&self) -> usize {
452        self.plugin_states
453            .read()
454            .map(|s| s.values().filter(|p| p.loaded).count())
455            .unwrap_or(0)
456    }
457
458    /// 获取已启用的插件数量
459    pub fn enabled_count(&self) -> usize {
460        self.plugin_states
461            .read()
462            .map(|s| s.values().filter(|p| p.enabled).count())
463            .unwrap_or(0)
464    }
465
466    /// 获取插件的工具
467    pub fn get_plugin_tools(&self, name: &str) -> Vec<super::registry::ToolDefinition> {
468        self.registry
469            .tools
470            .read()
471            .ok()
472            .and_then(|t| t.get(name).cloned())
473            .unwrap_or_default()
474    }
475
476    /// 获取插件的命令
477    pub fn get_plugin_commands(&self, name: &str) -> Vec<CommandDefinition> {
478        self.registry
479            .commands
480            .read()
481            .ok()
482            .and_then(|c| c.get(name).cloned())
483            .unwrap_or_default()
484    }
485
486    /// 获取插件的技能
487    pub fn get_plugin_skills(&self, name: &str) -> Vec<SkillDefinition> {
488        self.registry
489            .skills
490            .read()
491            .ok()
492            .and_then(|s| s.get(name).cloned())
493            .unwrap_or_default()
494    }
495
496    /// 获取插件的钩子
497    pub fn get_plugin_hooks(&self, name: &str) -> Vec<HookDefinition> {
498        self.registry
499            .hooks
500            .read()
501            .ok()
502            .and_then(|h| h.get(name).cloned())
503            .unwrap_or_default()
504    }
505}
506
507impl Default for PluginManager {
508    fn default() -> Self {
509        Self::new("0.1.0")
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plugin_manager_new() {
        let manager = PluginManager::new("1.0.0");
        assert_eq!(manager.loaded_count(), 0);
        assert_eq!(manager.enabled_count(), 0);
    }

    #[test]
    fn test_plugin_manager_default() {
        let manager = PluginManager::default();
        assert_eq!(manager.aster_version, "0.1.0");
    }

    #[test]
    fn test_add_plugin_dir() {
        let mut manager = PluginManager::new("1.0.0");
        let custom_dir = PathBuf::from("/custom/plugins");

        manager.add_plugin_dir(custom_dir.clone());
        assert!(manager.plugin_dirs.contains(&custom_dir));

        // Adding the same directory again must not duplicate it.
        manager.add_plugin_dir(custom_dir.clone());
        assert_eq!(
            manager
                .plugin_dirs
                .iter()
                .filter(|p| **p == custom_dir)
                .count(),
            1
        );
    }

    #[test]
    fn test_get_registry() {
        let manager = PluginManager::new("1.0.0");
        let registry = manager.registry();

        // Every call must hand out the same shared registry.
        let registry2 = manager.registry();
        assert!(Arc::ptr_eq(&registry, &registry2));
    }

    #[test]
    fn test_subscribe_events() {
        let manager = PluginManager::new("1.0.0");
        let mut rx = manager.subscribe();

        manager
            .event_tx
            .send(PluginEvent::Loaded("test".to_string()))
            .expect("a receiver is subscribed, so send must succeed");

        // The event was sent after subscribing, so it must already be buffered.
        // A missing event is a failure, not something to skip silently.
        match rx.try_recv() {
            Ok(PluginEvent::Loaded(name)) => assert_eq!(name, "test"),
            other => panic!("expected a Loaded event, got {:?}", other),
        }
    }

    #[test]
    fn test_get_plugin_state_not_found() {
        let manager = PluginManager::new("1.0.0");
        assert!(manager.get_plugin_state("nonexistent").is_none());
    }

    #[test]
    fn test_get_plugin_states_empty() {
        let manager = PluginManager::new("1.0.0");
        assert!(manager.get_plugin_states().is_empty());
    }

    #[test]
    fn test_set_enabled() {
        let manager = PluginManager::new("1.0.0");

        // An unknown plugin cannot be toggled.
        assert!(!manager.set_enabled("nonexistent", true));
    }

    #[test]
    fn test_get_plugin_tools_empty() {
        let manager = PluginManager::new("1.0.0");
        assert!(manager.get_plugin_tools("test").is_empty());
    }

    #[test]
    fn test_get_plugin_commands_empty() {
        let manager = PluginManager::new("1.0.0");
        assert!(manager.get_plugin_commands("test").is_empty());
    }

    #[test]
    fn test_get_plugin_skills_empty() {
        let manager = PluginManager::new("1.0.0");
        assert!(manager.get_plugin_skills("test").is_empty());
    }

    #[test]
    fn test_get_plugin_hooks_empty() {
        let manager = PluginManager::new("1.0.0");
        assert!(manager.get_plugin_hooks("test").is_empty());
    }

    #[tokio::test]
    async fn test_discover_empty_dirs() {
        let mut manager = PluginManager::new("1.0.0");
        // Point discovery at a directory that does not exist, so the result
        // does not depend on the developer's home or working directory.
        manager.plugin_dirs = vec![std::env::temp_dir().join(format!(
            "aster-plugin-manager-missing-{}",
            std::process::id()
        ))];

        assert!(manager.discover().await.is_empty());
        assert!(manager.get_plugin_states().is_empty());
    }

    #[tokio::test]
    async fn test_discover_skips_dirs_without_manifest() {
        let root = std::env::temp_dir().join(format!(
            "aster-plugin-manager-no-manifest-{}",
            std::process::id()
        ));
        std::fs::create_dir_all(root.join("not-a-plugin")).expect("create temp plugin dir");

        let mut manager = PluginManager::new("1.0.0");
        manager.plugin_dirs = vec![root.clone()];
        let discovered = manager.discover().await;

        let _ = std::fs::remove_dir_all(&root);
        assert!(discovered.is_empty());
        assert!(manager.get_plugin_states().is_empty());
    }

    #[tokio::test]
    async fn test_load_nonexistent_plugin() {
        let manager = PluginManager::new("1.0.0");
        let result = manager.load("nonexistent").await;
        assert_eq!(result.unwrap_err(), "Plugin not found: nonexistent");
    }

    #[tokio::test]
    async fn test_unload_nonexistent_plugin() {
        let manager = PluginManager::new("1.0.0");
        let result = manager.unload("nonexistent", false).await;
        assert_eq!(result.unwrap_err(), "Plugin not found: nonexistent");
    }

    #[tokio::test]
    async fn test_load_all_empty() {
        let manager = PluginManager::new("1.0.0");
        let result = manager.load_all().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_unload_all_empty() {
        let manager = PluginManager::new("1.0.0");
        let result = manager.unload_all().await;
        assert!(result.is_ok());
    }
}