use crate::config::ProviderConfig;
use crate::error::LlmConnectorError;
use crate::protocols::{AliyunProtocol, AnthropicProtocol};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Builds protocol adapters for a family of LLM providers that share one wire protocol.
///
/// Implementations are registered in a `ProtocolFactoryRegistry`, which keys them by
/// [`protocol_name`](Self::protocol_name) and routes each name listed by
/// [`supported_providers`](Self::supported_providers) to this factory.
pub trait ProtocolFactory: Send + Sync {
    /// Name of the protocol this factory implements (e.g. `"openai"`).
    fn protocol_name(&self) -> &str;

    /// Provider names this factory can build adapters for.
    fn supported_providers(&self) -> Vec<&str>;

    /// Builds a type-erased adapter for `provider_name`, configured from `config`.
    ///
    /// The concrete adapter type is implementation-specific; presumably callers
    /// downcast the returned `Box<dyn Any + Send>` — confirm against call sites.
    ///
    /// # Errors
    /// Implementation-defined; the built-in factories return
    /// `LlmConnectorError::UnsupportedModel` for provider names they do not handle.
    fn create_adapter(
        &self,
        provider_name: &str,
        config: &ProviderConfig,
    ) -> Result<Box<dyn std::any::Any + Send>, LlmConnectorError>;

    /// Returns `true` when `provider_name` is one of
    /// [`supported_providers`](Self::supported_providers) (exact, case-sensitive match).
    fn supports_provider(&self, provider_name: &str) -> bool {
        let known = self.supported_providers();
        known.iter().any(|&candidate| candidate == provider_name)
    }
}
/// Factory for providers that expose an OpenAI-compatible API.
#[derive(Debug, Clone)]
pub struct OpenAIProtocolFactory;

impl OpenAIProtocolFactory {
    /// Provider names served by this factory, in the order reported by
    /// `supported_providers`. Must stay in sync with the `match` in `create_adapter`.
    const PROVIDERS: &'static [&'static str] = &[
        "deepseek",
        "zhipu",
        "moonshot",
        "volcengine",
        "tencent",
        "minimax",
        "stepfun",
        "longcat",
    ];
}

impl ProtocolFactory for OpenAIProtocolFactory {
    fn protocol_name(&self) -> &str {
        "openai"
    }

    fn supported_providers(&self) -> Vec<&str> {
        Self::PROVIDERS.to_vec()
    }

    /// Builds the preconfigured OpenAI-compatible adapter for `provider_name`.
    ///
    /// # Errors
    /// Returns `LlmConnectorError::UnsupportedModel` if `provider_name` is not one of
    /// the providers listed in `PROVIDERS`.
    fn create_adapter(
        &self,
        provider_name: &str,
        _config: &ProviderConfig,
    ) -> Result<Box<dyn std::any::Any + Send>, LlmConnectorError> {
        // NOTE(review): `_config` is never read, so settings such as `base_url` are not
        // applied here, unlike the Anthropic/Aliyun factories — confirm this is intended.
        let adapter = match provider_name {
            "deepseek" => crate::protocols::openai::deepseek(),
            "zhipu" => crate::protocols::openai::zhipu(),
            "moonshot" => crate::protocols::openai::moonshot(),
            "volcengine" => crate::protocols::openai::volcengine(),
            "tencent" => crate::protocols::openai::tencent(),
            "minimax" => crate::protocols::openai::minimax(),
            "stepfun" => crate::protocols::openai::stepfun(),
            "longcat" => crate::protocols::openai::longcat(),
            unknown => {
                return Err(LlmConnectorError::UnsupportedModel(format!(
                    "Unknown OpenAI-compatible provider: {}",
                    unknown
                )));
            }
        };
        Ok(Box::new(adapter))
    }
}
#[derive(Debug, Clone)]
pub struct AnthropicProtocolFactory;
impl ProtocolFactory for AnthropicProtocolFactory {
fn protocol_name(&self) -> &str {
"anthropic"
}
fn supported_providers(&self) -> Vec<&str> {
vec!["anthropic", "claude"]
}
fn create_adapter(
&self,
_provider_name: &str,
config: &ProviderConfig,
) -> Result<Box<dyn std::any::Any + Send>, LlmConnectorError> {
let adapter = AnthropicProtocol::new(config.base_url.as_deref());
Ok(Box::new(adapter))
}
}
#[derive(Debug, Clone)]
pub struct AliyunProtocolFactory;
impl ProtocolFactory for AliyunProtocolFactory {
fn protocol_name(&self) -> &str {
"aliyun"
}
fn supported_providers(&self) -> Vec<&str> {
vec!["aliyun", "dashscope", "qwen"]
}
fn create_adapter(
&self,
_provider_name: &str,
config: &ProviderConfig,
) -> Result<Box<dyn std::any::Any + Send>, LlmConnectorError> {
let adapter = AliyunProtocol::new(config.base_url.as_deref());
Ok(Box::new(adapter))
}
}
/// Thread-safe registry mapping protocol names to [`ProtocolFactory`] instances, and
/// provider names to the protocol that serves them.
///
/// Both maps live behind `Arc<RwLock<..>>`. Cloning the registry is therefore cheap, and
/// the clone shares (and observes mutations of) the same underlying state.
#[derive(Clone)]
pub struct ProtocolFactoryRegistry {
    /// Protocol name -> factory that builds adapters for that protocol.
    factories: Arc<RwLock<HashMap<String, Arc<dyn ProtocolFactory>>>>,
    /// Provider name -> protocol name (a key into `factories`).
    provider_to_protocol: Arc<RwLock<HashMap<String, String>>>,
}

impl ProtocolFactoryRegistry {
    /// Panic message used if a lock is poisoned. The critical sections in this impl do
    /// only `HashMap` bookkeeping and never call into factory code, so poisoning implies
    /// a thread panicked while holding one of these guards.
    const POISONED: &'static str = "ProtocolFactoryRegistry lock poisoned";

    /// Creates an empty registry with no factories registered.
    pub fn new() -> Self {
        Self {
            factories: Arc::new(RwLock::new(HashMap::new())),
            provider_to_protocol: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Creates a registry pre-populated with the built-in factories
    /// (see [`register_default_factories`](Self::register_default_factories)).
    pub fn with_defaults() -> Self {
        let registry = Self::new();
        registry.register_default_factories();
        registry
    }

    /// Registers the built-in OpenAI-compatible, Anthropic and Aliyun factories.
    pub fn register_default_factories(&self) {
        self.register(Arc::new(OpenAIProtocolFactory));
        self.register(Arc::new(AnthropicProtocolFactory));
        self.register(Arc::new(AliyunProtocolFactory));
    }

    /// Registers `factory` under its [`ProtocolFactory::protocol_name`]. Every provider in
    /// [`ProtocolFactory::supported_providers`] is then mapped to that protocol.
    ///
    /// If a factory with the same protocol name is already registered, it is replaced.
    /// Provider mappings from the previous registration are removed first, so a provider
    /// the new factory no longer claims stops resolving to this protocol. If another
    /// protocol already claimed one of the providers, that mapping is overwritten
    /// (last registration wins).
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn register(&self, factory: Arc<dyn ProtocolFactory>) {
        // Query the factory before taking any lock. These are calls into arbitrary trait
        // implementations; a panic there while a write guard is held would poison the lock
        // for every later caller.
        let protocol_name = factory.protocol_name().to_string();
        let providers: Vec<String> = factory
            .supported_providers()
            .into_iter()
            .map(str::to_string)
            .collect();

        // Keep the replaced factory (if any) alive past the guard, so its destructor does
        // not run while the write lock is held.
        let _previous = self
            .factories
            .write()
            .expect(Self::POISONED)
            .insert(protocol_name.clone(), factory);

        let mut provider_map = self.provider_to_protocol.write().expect(Self::POISONED);
        // Drop stale mappings left by an earlier registration under the same protocol name.
        provider_map.retain(|_, proto| *proto != protocol_name);
        for provider in providers {
            provider_map.insert(provider, protocol_name.clone());
        }
    }

    /// Returns the factory registered under `protocol_name`, if any.
    pub fn get_factory(&self, protocol_name: &str) -> Option<Arc<dyn ProtocolFactory>> {
        self.factories
            .read()
            .expect(Self::POISONED)
            .get(protocol_name)
            .cloned()
    }

    /// Returns the protocol name that `provider_name` is mapped to, if any.
    pub fn get_protocol_for_provider(&self, provider_name: &str) -> Option<String> {
        self.provider_to_protocol
            .read()
            .expect(Self::POISONED)
            .get(provider_name)
            .cloned()
    }

    /// Resolves `provider_name` to its protocol's factory and builds an adapter with
    /// `config`.
    ///
    /// # Errors
    /// - `LlmConnectorError::UnsupportedModel` if no protocol is mapped to `provider_name`.
    /// - `LlmConnectorError::ProviderError` if the mapped protocol has no registered factory.
    /// - Any error returned by the factory's `create_adapter`.
    pub fn create_for_provider(
        &self,
        provider_name: &str,
        config: &ProviderConfig,
    ) -> Result<Box<dyn std::any::Any + Send>, LlmConnectorError> {
        let protocol_name = self
            .get_protocol_for_provider(provider_name)
            .ok_or_else(|| {
                LlmConnectorError::UnsupportedModel(format!("Unknown provider: {}", provider_name))
            })?;
        let factory = self.get_factory(&protocol_name).ok_or_else(|| {
            LlmConnectorError::ProviderError(format!(
                "No factory registered for protocol: {}",
                protocol_name
            ))
        })?;
        // Both lookups returned owned values, so no lock is held while the factory runs.
        factory.create_adapter(provider_name, config)
    }

    /// Returns the names of all registered protocols, sorted ascending.
    pub fn list_protocols(&self) -> Vec<String> {
        let mut protocols: Vec<String> = self
            .factories
            .read()
            .expect(Self::POISONED)
            .keys()
            .cloned()
            .collect();
        // HashMap iteration order is unspecified; sort for deterministic output.
        protocols.sort_unstable();
        protocols
    }

    /// Returns the names of all mapped providers, sorted ascending.
    pub fn list_providers(&self) -> Vec<String> {
        let mut providers: Vec<String> = self
            .provider_to_protocol
            .read()
            .expect(Self::POISONED)
            .keys()
            .cloned()
            .collect();
        providers.sort_unstable();
        providers
    }

    /// Returns the providers currently mapped to `protocol_name`, sorted ascending.
    /// The result is empty if the protocol is unknown.
    pub fn get_providers_for_protocol(&self, protocol_name: &str) -> Vec<String> {
        let mut providers: Vec<String> = self
            .provider_to_protocol
            .read()
            .expect(Self::POISONED)
            .iter()
            .filter(|(_, proto)| proto.as_str() == protocol_name)
            .map(|(provider, _)| provider.clone())
            .collect();
        providers.sort_unstable();
        providers
    }
}
impl Default for ProtocolFactoryRegistry {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_factory() {
        let factory = OpenAIProtocolFactory;
        assert_eq!(factory.protocol_name(), "openai");
        assert!(factory.supports_provider("deepseek"));
        assert!(factory.supports_provider("zhipu"));
        assert!(!factory.supports_provider("claude"));
    }

    #[test]
    fn test_anthropic_factory() {
        let factory = AnthropicProtocolFactory;
        assert_eq!(factory.protocol_name(), "anthropic");
        assert!(factory.supports_provider("anthropic"));
        assert!(factory.supports_provider("claude"));
        assert!(!factory.supports_provider("deepseek"));
    }

    #[test]
    fn test_aliyun_factory() {
        let factory = AliyunProtocolFactory;
        assert_eq!(factory.protocol_name(), "aliyun");
        assert!(factory.supports_provider("aliyun"));
        assert!(factory.supports_provider("dashscope"));
        assert!(factory.supports_provider("qwen"));
        assert!(!factory.supports_provider("claude"));
    }

    #[test]
    fn test_registry() {
        let registry = ProtocolFactoryRegistry::with_defaults();
        assert!(registry.get_factory("openai").is_some());
        assert!(registry.get_factory("anthropic").is_some());
        assert!(registry.get_factory("aliyun").is_some());
        assert_eq!(
            registry.get_protocol_for_provider("deepseek"),
            Some("openai".to_string())
        );
        assert_eq!(
            registry.get_protocol_for_provider("claude"),
            Some("anthropic".to_string())
        );
        assert_eq!(
            registry.get_protocol_for_provider("qwen"),
            Some("aliyun".to_string())
        );
    }

    #[test]
    fn test_unknown_lookups() {
        let registry = ProtocolFactoryRegistry::with_defaults();
        assert!(registry.get_factory("nonexistent").is_none());
        assert_eq!(registry.get_protocol_for_provider("nonexistent"), None);
    }

    #[test]
    fn test_empty_registry() {
        let registry = ProtocolFactoryRegistry::new();
        assert!(registry.list_protocols().is_empty());
        assert!(registry.list_providers().is_empty());
        assert!(registry.get_factory("openai").is_none());
    }

    #[test]
    fn test_list_protocols() {
        let registry = ProtocolFactoryRegistry::with_defaults();
        let protocols = registry.list_protocols();
        assert!(protocols.contains(&"openai".to_string()));
        assert!(protocols.contains(&"anthropic".to_string()));
        assert!(protocols.contains(&"aliyun".to_string()));
    }

    #[test]
    fn test_list_providers() {
        let registry = ProtocolFactoryRegistry::with_defaults();
        let providers = registry.list_providers();
        assert!(providers.contains(&"deepseek".to_string()));
        assert!(providers.contains(&"claude".to_string()));
        assert!(providers.contains(&"qwen".to_string()));
    }

    #[test]
    fn test_get_providers_for_protocol() {
        let registry = ProtocolFactoryRegistry::with_defaults();
        let mut anthropic = registry.get_providers_for_protocol("anthropic");
        anthropic.sort();
        assert_eq!(anthropic, vec!["anthropic".to_string(), "claude".to_string()]);
        assert!(registry.get_providers_for_protocol("nonexistent").is_empty());
    }

    #[test]
    fn test_clone_shares_state() {
        let registry = ProtocolFactoryRegistry::new();
        let handle = registry.clone();
        handle.register(Arc::new(AnthropicProtocolFactory));
        assert!(registry.get_factory("anthropic").is_some());
        assert_eq!(
            registry.get_protocol_for_provider("claude"),
            Some("anthropic".to_string())
        );
    }
}