use crate::agents::AgentAdapter;
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use std::sync::RwLock;
/// Map guarded by [`AdapterRegistry`]'s lock: registration name -> adapter.
type AdapterMap = HashMap<String, Box<dyn AgentAdapter>>;

/// Thread-safe registry of [`AgentAdapter`] implementations keyed by name.
///
/// [`AdapterRegistry::register`] only accepts an adapter under the exact name
/// its `name()` method reports, so registration keys and adapter names agree.
pub struct AdapterRegistry {
    adapters: RwLock<AdapterMap>,
}

impl AdapterRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            adapters: RwLock::new(HashMap::new()),
        }
    }

    /// Acquires the shared lock, recovering the map if the lock is poisoned.
    ///
    /// Writers (`register` / `unregister`) only perform plain `HashMap`
    /// lookups, inserts and removals while holding the lock; adapter code such
    /// as `name()` runs before the lock is taken. A poisoned lock therefore
    /// cannot hide a half-updated map, and recovering is more useful than
    /// panicking or pretending the registry is empty.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, AdapterMap> {
        self.adapters
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the exclusive lock, recovering the map if the lock is poisoned
    /// (see `read_map` for why recovery is safe here).
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, AdapterMap> {
        self.adapters
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `adapter` under `name`.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` differs from `adapter.name()`, or if an
    /// adapter is already registered under `name` (the existing entry is kept
    /// and `adapter` is dropped).
    pub fn register(&self, name: &str, adapter: Box<dyn AgentAdapter>) -> Result<()> {
        // Validate before locking so that no adapter code runs under the write
        // lock (a panic there would poison it).
        let actual_name = adapter.name();
        if actual_name != name {
            return Err(anyhow!(
                "适配器名称不匹配:注册名称为 '{}',但适配器实际名称为 '{}'",
                name,
                actual_name
            ));
        }

        let mut adapters = self.write_map();
        if adapters.contains_key(name) {
            return Err(anyhow!(
                "适配器名称 '{}' 已存在。请使用不同的名称或先注销现有适配器。",
                name
            ));
        }
        adapters.insert(name.to_string(), adapter);
        Ok(())
    }

    /// Removes and returns the adapter registered under `name`, or `None` if
    /// there is none.
    pub fn unregister(&self, name: &str) -> Option<Box<dyn AgentAdapter>> {
        self.write_map().remove(name)
    }

    /// Always returns `None`.
    ///
    /// The registry owns its adapters as `Box<dyn AgentAdapter>` and this
    /// module has no way to clone one out of the map, so there is nothing this
    /// method could hand back. It is kept for API compatibility; use
    /// [`AdapterRegistry::with_adapter`] to access a registered adapter.
    pub fn get(&self, _name: &str) -> Option<Box<dyn AgentAdapter>> {
        None
    }

    /// Runs `f` on the adapter registered under `name` and returns its result,
    /// or `None` if no adapter is registered under that name.
    ///
    /// The read lock is held while `f` runs, so `f` must not call `register`
    /// or `unregister` on this registry.
    pub fn with_adapter<F, R>(&self, name: &str, f: F) -> Option<R>
    where
        F: FnOnce(&dyn AgentAdapter) -> R,
    {
        let adapters = self.read_map();
        adapters.get(name).map(|adapter| f(adapter.as_ref()))
    }

    /// Lists every registered adapter with its installation status, sorted by
    /// registration name.
    ///
    /// `is_installed` is the adapter's `detect()` result; a detection error is
    /// reported as `false`.
    pub fn list_adapters(&self) -> Vec<AdapterInfo> {
        let mut infos: Vec<AdapterInfo> = self
            .read_map()
            .iter()
            .map(|(name, adapter)| AdapterInfo {
                name: name.clone(),
                is_installed: adapter.detect().unwrap_or(false),
            })
            .collect();
        // HashMap iteration order is unspecified; sort for stable output.
        infos.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        infos
    }

    /// Returns `true` if an adapter is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.read_map().contains_key(name)
    }

    /// Returns the number of registered adapters.
    pub fn count(&self) -> usize {
        self.read_map().len()
    }

    /// Returns an iterator over the registered names, in sorted order.
    ///
    /// The names are snapshotted up front, but the iterator keeps the read
    /// lock until it is dropped. Calling `register` or `unregister` on this
    /// registry from the same thread while it is alive will deadlock or panic.
    pub fn iter(&self) -> Iter<'_> {
        let guard = self.read_map();
        let mut keys: Vec<String> = guard.keys().cloned().collect();
        keys.sort_unstable();
        Iter {
            keys,
            index: 0,
            _guard: guard,
        }
    }

    /// Calls `f` on every registered adapter, in unspecified order, while
    /// holding the read lock.
    pub fn for_each_adapter<F>(&self, mut f: F)
    where
        F: FnMut(&dyn AgentAdapter),
    {
        for adapter in self.read_map().values() {
            f(adapter.as_ref());
        }
    }

    /// Validates every registered adapter. Results are sorted by adapter name.
    pub fn validate_all(&self) -> Vec<ValidationResult> {
        let mut results: Vec<ValidationResult> = self
            .read_map()
            .iter()
            .map(|(name, adapter)| Self::validate_adapter(name, adapter.as_ref()))
            .collect();
        results.sort_unstable_by(|a, b| a.adapter_name.cmp(&b.adapter_name));
        results
    }

    /// Checks a single adapter that is registered under `name`.
    ///
    /// Errors, which make the result invalid:
    /// - the adapter's `name()` differs from `name`;
    /// - `config_path()` fails.
    ///
    /// Warnings:
    /// - `config_path()` returns an empty path;
    /// - `detect()` fails.
    fn validate_adapter(name: &str, adapter: &dyn AgentAdapter) -> ValidationResult {
        let mut errors = Vec::new();
        let mut warnings = Vec::new();

        let actual_name = adapter.name();
        if actual_name != name {
            errors.push(format!(
                "名称不匹配:注册名称为 '{}',适配器名称为 '{}'",
                name, actual_name
            ));
        }

        match adapter.config_path() {
            Ok(path) if path.as_os_str().is_empty() => {
                warnings.push("配置文件路径为空".to_string());
            }
            Ok(_) => {}
            Err(e) => errors.push(format!("无法获取配置文件路径: {}", e)),
        }

        if let Err(e) = adapter.detect() {
            warnings.push(format!("检测方法执行失败: {}", e));
        }

        ValidationResult {
            adapter_name: name.to_string(),
            is_valid: errors.is_empty(),
            errors,
            warnings,
        }
    }
}

impl Default for AdapterRegistry {
    /// Equivalent to [`AdapterRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
/// Summary of one registered adapter, as produced by
/// `AdapterRegistry::list_adapters`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AdapterInfo {
    /// Name the adapter is registered under.
    pub name: String,
    /// Whether the adapter's `detect()` reported it as installed; a detection
    /// error counts as `false`.
    pub is_installed: bool,
}
/// Outcome of validating a single registered adapter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Name the adapter is registered under.
    pub adapter_name: String,
    /// `true` when `errors` is empty; warnings do not affect validity.
    pub is_valid: bool,
    /// Problems that make the adapter invalid (e.g. a name mismatch or a
    /// failing config-path lookup).
    pub errors: Vec<String>,
    /// Non-fatal issues (e.g. an empty config path or a failing detection).
    pub warnings: Vec<String>,
}
static GLOBAL_REGISTRY: std::sync::OnceLock<AdapterRegistry> = std::sync::OnceLock::new();
pub fn global_registry() -> &'static AdapterRegistry {
GLOBAL_REGISTRY.get_or_init(|| {
let registry = AdapterRegistry::new();
let _ = registry.register(
"claude-code",
Box::new(crate::agents::claude_code::ClaudeCodeAdapter::new()),
);
let _ = registry.register("codex", Box::new(crate::agents::codex::CodexAdapter::new()));
let _ = registry.register(
"gemini-cli",
Box::new(crate::agents::gemini::GeminiAdapter::new()),
);
let _ = registry.register(
"opencode",
Box::new(crate::agents::opencode::OpenCodeAdapter::new()),
);
let _ = registry.register("qwen", Box::new(crate::agents::qwen::QwenAdapter::new()));
let _ = registry.register("grok", Box::new(crate::agents::grok::GrokAdapter::new()));
registry
})
}
/// Iterator over a snapshot of registered adapter names.
///
/// It holds the registry's read lock for as long as it is alive, so the
/// registry cannot be modified until the iterator is dropped.
pub struct Iter<'a> {
    keys: Vec<String>,
    index: usize,
    _guard: std::sync::RwLockReadGuard<'a, HashMap<String, Box<dyn AgentAdapter>>>,
}

impl Iterator for Iter<'_> {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        // Each key is yielded exactly once, so move it out instead of cloning.
        let key = std::mem::take(self.keys.get_mut(self.index)?);
        self.index += 1;
        Some(key)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.keys.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Iter<'_> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::claude_code::ClaudeCodeAdapter;

    // `register` rejects any key that differs from the adapter's own `name()`.
    // Assumes `ClaudeCodeAdapter::name()` is "claude-code" (the key that
    // `global_registry` uses for it) — TODO confirm if these tests fail.
    const CLAUDE_NAME: &str = "claude-code";

    fn claude() -> Box<dyn AgentAdapter> {
        Box::new(ClaudeCodeAdapter::new())
    }

    #[test]
    fn test_registry_new() {
        let registry = AdapterRegistry::new();
        assert_eq!(registry.count(), 0);
        assert!(!registry.contains(CLAUDE_NAME));
    }

    #[test]
    fn test_register_adapter() {
        let registry = AdapterRegistry::new();
        let result = registry.register(CLAUDE_NAME, claude());
        assert!(result.is_ok());
        assert_eq!(registry.count(), 1);
        assert!(registry.contains(CLAUDE_NAME));
    }

    #[test]
    fn test_register_name_mismatch() {
        let registry = AdapterRegistry::new();
        let err = registry.register("test-claude", claude()).unwrap_err();
        assert!(err.to_string().contains("不匹配"));
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_register_duplicate() {
        let registry = AdapterRegistry::new();
        registry.register(CLAUDE_NAME, claude()).unwrap();
        let result = registry.register(CLAUDE_NAME, claude());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("已存在"));
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_unregister_adapter() {
        let registry = AdapterRegistry::new();
        registry.register(CLAUDE_NAME, claude()).unwrap();
        assert_eq!(registry.count(), 1);
        let removed = registry.unregister(CLAUDE_NAME);
        assert!(removed.is_some());
        assert_eq!(registry.count(), 0);
        assert!(registry.unregister(CLAUDE_NAME).is_none());
    }

    #[test]
    fn test_iter_yields_registered_names() {
        let registry = AdapterRegistry::new();
        registry.register(CLAUDE_NAME, claude()).unwrap();
        let names: Vec<String> = registry.iter().collect();
        assert_eq!(names, vec![CLAUDE_NAME.to_string()]);
    }

    #[test]
    fn test_list_adapters() {
        let registry = AdapterRegistry::new();
        registry.register(CLAUDE_NAME, claude()).unwrap();
        let adapters = registry.list_adapters();
        assert_eq!(adapters.len(), 1);
        assert_eq!(adapters[0].name, CLAUDE_NAME);
    }

    #[test]
    #[ignore = "is_valid may depend on config_path() succeeding in the host environment"]
    fn test_validate_adapter() {
        let registry = AdapterRegistry::new();
        registry.register(CLAUDE_NAME, claude()).unwrap();
        let results = registry.validate_all();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].adapter_name, CLAUDE_NAME);
        assert!(results[0].is_valid);
    }
}