use super::adapter::ProviderAdapter;
use super::types::{FieldAccessor, FieldMappings, JsonFieldAccessor};
use crate::error::LlmError;
use crate::traits::ProviderCapabilities;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Declarative description of an OpenAI-compatible provider.
///
/// Turned into a [`ConfigurableAdapter`] by [`ProviderRegistry::create_adapter`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ProviderConfig {
    /// Identifier the provider is registered under and reported as by the adapter.
    pub id: String,
    /// Provider name; not read by the adapter or registry in this module.
    pub name: String,
    /// Base URL returned by the adapter's `base_url`.
    pub base_url: String,
    /// Names of the response fields the adapter exposes through `get_field_mappings`.
    pub field_mappings: ProviderFieldMappings,
    /// Capability flags. The adapter recognizes `"tools"`, `"vision"`, `"embedding"` and
    /// `"image_generation"` (exact, case-sensitive matches); other entries are ignored here.
    pub capabilities: Vec<String>,
    /// Optional default model; not read by the adapter or registry in this module.
    pub default_model: Option<String>,
    /// Enables `supports_thinking` in the model config and the `"reasoning"` custom feature.
    pub supports_reasoning: bool,
}
/// Configurable JSON field names used when reading a provider's messages.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ProviderFieldMappings {
    /// Candidate field names for reasoning/"thinking" output.
    pub thinking_fields: Vec<String>,
    /// Field name holding the message content.
    pub content_field: String,
    /// Field name holding tool calls.
    pub tool_calls_field: String,
    /// Field name holding the message role.
    pub role_field: String,
}
impl Default for ProviderFieldMappings {
fn default() -> Self {
Self {
thinking_fields: vec!["thinking".to_string()],
content_field: "content".to_string(),
tool_calls_field: "tool_calls".to_string(),
role_field: "role".to_string(),
}
}
}
/// [`ProviderAdapter`] whose behavior is driven entirely by a [`ProviderConfig`].
#[derive(Clone, Debug)]
pub struct ConfigurableAdapter {
    /// Configuration consulted by every trait method.
    config: ProviderConfig,
}
impl ConfigurableAdapter {
pub fn new(config: ProviderConfig) -> Self {
Self { config }
}
}
impl ProviderAdapter for ConfigurableAdapter {
fn provider_id(&self) -> &'static str {
Box::leak(self.config.id.clone().into_boxed_str())
}
fn transform_request_params(
&self,
_params: &mut serde_json::Value,
_model: &str,
_request_type: super::types::RequestType,
) -> Result<(), LlmError> {
Ok(())
}
fn get_field_mappings(&self, _model: &str) -> FieldMappings {
let config_mappings = &self.config.field_mappings;
FieldMappings {
thinking_fields: config_mappings
.thinking_fields
.iter()
.map(|s| Box::leak(s.clone().into_boxed_str()) as &'static str)
.collect(),
content_field: Box::leak(config_mappings.content_field.clone().into_boxed_str()),
tool_calls_field: Box::leak(config_mappings.tool_calls_field.clone().into_boxed_str()),
role_field: Box::leak(config_mappings.role_field.clone().into_boxed_str()),
}
}
fn get_model_config(&self, _model: &str) -> super::types::ModelConfig {
super::types::ModelConfig {
supports_thinking: self.config.supports_reasoning,
..Default::default()
}
}
fn get_field_accessor(&self) -> Box<dyn FieldAccessor> {
Box::new(JsonFieldAccessor)
}
fn capabilities(&self) -> ProviderCapabilities {
let mut caps = ProviderCapabilities::new().with_chat().with_streaming();
if self.config.capabilities.contains(&"tools".to_string()) {
caps = caps.with_tools();
}
if self.config.capabilities.contains(&"vision".to_string()) {
caps = caps.with_vision();
}
if self.config.capabilities.contains(&"embedding".to_string()) {
caps = caps.with_embedding();
}
if self.config.supports_reasoning {
caps = caps.with_custom_feature("reasoning", true);
}
caps
}
fn base_url(&self) -> &str {
&self.config.base_url
}
fn clone_adapter(&self) -> Box<dyn ProviderAdapter> {
Box::new(self.clone())
}
fn supports_image_generation(&self) -> bool {
self.config
.capabilities
.contains(&"image_generation".to_string())
}
fn transform_image_request(
&self,
_request: &mut crate::types::ImageGenerationRequest,
) -> Result<(), LlmError> {
Ok(())
}
fn get_supported_image_sizes(&self) -> Vec<String> {
vec![
"256x256".to_string(),
"512x512".to_string(),
"1024x1024".to_string(),
"1024x1792".to_string(),
"1792x1024".to_string(),
]
}
fn get_supported_image_formats(&self) -> Vec<String> {
vec!["url".to_string(), "b64_json".to_string()]
}
fn supports_image_editing(&self) -> bool {
self.supports_image_generation()
}
fn supports_image_variations(&self) -> bool {
self.supports_image_generation()
}
}
/// Table of provider configurations keyed by provider id.
///
/// Derives `Debug` so the registry (and the global `PROVIDER_REGISTRY` mutex around it) can be
/// inspected in logs and test failures; every field already implements `Debug`.
#[derive(Debug)]
pub struct ProviderRegistry {
    /// Registered configurations, keyed by the id they were registered under.
    providers: HashMap<String, ProviderConfig>,
}
impl ProviderRegistry {
    /// Creates a registry pre-populated with the built-in OpenAI-compatible providers.
    pub fn new() -> Self {
        let mut registry = Self {
            providers: HashMap::new(),
        };
        registry.register_builtin_providers();
        registry
    }

    /// Inserts every `(id, config)` pair from `get_builtin_providers`, keyed by the returned id.
    fn register_builtin_providers(&mut self) {
        self.providers
            .extend(crate::providers::openai_compatible::config::get_builtin_providers());
    }

    /// Looks up the configuration registered under `id`.
    pub fn get_provider(&self, id: &str) -> Option<&ProviderConfig> {
        self.providers.get(id)
    }

    /// Builds a [`ConfigurableAdapter`] from a clone of the configuration for `provider_id`.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ConfigurationError`] when no provider is registered under
    /// `provider_id`.
    pub fn create_adapter(&self, provider_id: &str) -> Result<Arc<dyn ProviderAdapter>, LlmError> {
        let config = self.get_provider(provider_id).ok_or_else(|| {
            LlmError::ConfigurationError(format!("Unknown provider: {}", provider_id))
        })?;
        Ok(Arc::new(ConfigurableAdapter::new(config.clone())))
    }

    /// Registers `config` under `config.id`, replacing any existing entry with that id.
    pub fn register_provider(&mut self, config: ProviderConfig) {
        self.providers.insert(config.id.clone(), config);
    }

    /// Returns the ids of all registered providers, sorted in ascending order.
    pub fn list_providers(&self) -> Vec<&str> {
        let mut ids: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        // HashMap iteration order is unspecified and varies between runs; sort for a stable listing.
        ids.sort_unstable();
        ids
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
lazy_static::lazy_static! {
    /// Process-wide provider registry, initialized on first access with the built-in
    /// providers and guarded by a mutex.
    pub static ref PROVIDER_REGISTRY: std::sync::Mutex<ProviderRegistry> =
        std::sync::Mutex::new(ProviderRegistry::default());
}
/// Creates an adapter for `provider_id` from the global [`PROVIDER_REGISTRY`].
///
/// # Errors
///
/// Returns [`LlmError::ConfigurationError`] if the registry mutex is poisoned or no provider is
/// registered under `provider_id`.
pub fn get_provider_adapter(provider_id: &str) -> Result<Arc<dyn ProviderAdapter>, LlmError> {
    let registry = PROVIDER_REGISTRY.lock().map_err(|_| {
        LlmError::ConfigurationError("Failed to lock provider registry".to_string())
    })?;
    registry.create_adapter(provider_id)
}