use super::types::{FieldAccessor, FieldMappings, JsonFieldAccessor, ModelConfig, RequestType};
use crate::error::LlmError;
use crate::traits::ProviderCapabilities;
use crate::types::HttpConfig;
use std::collections::HashMap;
/// Per-provider feature switches describing which optional request features a
/// backend accepts.
///
/// The derived `Default` sets every flag to `false` and leaves both collections
/// empty. These are the same values that
/// [`ProviderCompatibility::limited_compatibility`] produces.
///
/// NOTE(review): the code that acts on these flags is not in this file. The
/// per-field notes below are read from the field names; confirm them against the
/// request builders.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ProviderCompatibility {
/// Presumably: whether message content may be sent as an array of parts.
pub supports_array_content: bool,
/// Presumably: whether a `stream_options` request field is accepted.
pub supports_stream_options: bool,
/// Presumably: whether a `developer` message role is accepted.
pub supports_developer_role: bool,
/// Presumably: whether an `enable_thinking` request option is accepted.
pub supports_enable_thinking: bool,
/// Presumably: whether a `service_tier` request option is accepted.
pub supports_service_tier: bool,
/// Model names singled out for forced streaming (the `deepseek()` preset lists
/// `deepseek-reasoner`). Presumably these models are always requested with
/// streaming enabled; confirm at the call site.
pub force_streaming_models: Vec<String>,
/// Extra provider-specific boolean flags keyed by name. Every preset in this
/// file leaves it empty.
pub custom_flags: HashMap<String, bool>,
}
impl ProviderCompatibility {
pub fn openai_standard() -> Self {
Self {
supports_array_content: true,
supports_stream_options: true,
supports_developer_role: true,
supports_enable_thinking: true,
supports_service_tier: true,
force_streaming_models: vec![],
custom_flags: HashMap::new(),
}
}
pub fn deepseek() -> Self {
Self {
supports_array_content: false, supports_stream_options: true,
supports_developer_role: true,
supports_enable_thinking: false, supports_service_tier: false,
force_streaming_models: vec!["deepseek-reasoner".to_string()],
custom_flags: HashMap::new(),
}
}
pub fn limited_compatibility() -> Self {
Self {
supports_array_content: false,
supports_stream_options: false,
supports_developer_role: false,
supports_enable_thinking: false,
supports_service_tier: false,
force_streaming_models: vec![],
custom_flags: HashMap::new(),
}
}
}
/// Hooks that customise how requests are prepared for one particular provider.
///
/// Seven methods are required: `provider_id`, `transform_request_params`,
/// `get_field_mappings`, `get_model_config`, `capabilities`, `base_url` and
/// `clone_adapter`. Every other method has a default that either leaves its
/// input unchanged or reports the feature as unsupported.
pub trait ProviderAdapter: Send + Sync + std::fmt::Debug {
    /// Identifier of the provider.
    ///
    /// `AdapterRegistry::register` uses this string as the registration key.
    fn provider_id(&self) -> &'static str;

    /// Rewrites the outgoing request parameters in place for `model` and
    /// `request_type`.
    ///
    /// # Errors
    /// Returns an [`LlmError`] when the implementation rejects the request.
    fn transform_request_params(
        &self,
        params: &mut serde_json::Value,
        model: &str,
        request_type: RequestType,
    ) -> Result<(), LlmError>;

    /// Field mappings to use for `model`.
    fn get_field_mappings(&self, model: &str) -> FieldMappings;

    /// Accessor used to read fields through the mappings.
    ///
    /// The default is a boxed `JsonFieldAccessor`.
    fn get_field_accessor(&self) -> Box<dyn FieldAccessor> {
        let accessor = JsonFieldAccessor;
        Box::new(accessor)
    }

    /// Per-model configuration for `model`.
    fn get_model_config(&self, model: &str) -> ModelConfig;

    /// Extra HTTP headers for this provider. The default adds none.
    fn custom_headers(&self) -> reqwest::header::HeaderMap {
        let headers = reqwest::header::HeaderMap::new();
        headers
    }

    /// Capabilities advertised by this provider.
    fn capabilities(&self) -> ProviderCapabilities;

    /// Compatibility switches for this provider.
    ///
    /// The default is `ProviderCompatibility::openai_standard()`, i.e. every
    /// feature enabled.
    fn compatibility(&self) -> ProviderCompatibility {
        ProviderCompatibility::openai_standard()
    }

    /// Hook to adjust the HTTP configuration.
    ///
    /// The default returns `http_config` untouched.
    fn apply_http_config(&self, http_config: HttpConfig) -> HttpConfig {
        http_config
    }

    /// Checks whether `model` is acceptable for this provider.
    ///
    /// The default accepts every model name.
    ///
    /// # Errors
    /// Overrides return an [`LlmError`] for rejected models.
    fn validate_model(&self, model: &str) -> Result<(), LlmError> {
        // The default implementation does not inspect the name.
        let _ = model;
        Ok(())
    }

    /// Base URL of the provider's API.
    fn base_url(&self) -> &str;

    /// Returns a boxed copy of this adapter.
    ///
    /// This backs `Clone for Box<dyn ProviderAdapter>`.
    fn clone_adapter(&self) -> Box<dyn ProviderAdapter>;

    /// Whether image generation is supported. Defaults to `false`.
    fn supports_image_generation(&self) -> bool {
        false
    }

    /// Hook to adjust an image-generation request in place.
    ///
    /// The default leaves the request unchanged and succeeds.
    fn transform_image_request(
        &self,
        _request: &mut crate::types::ImageGenerationRequest,
    ) -> Result<(), LlmError> {
        Ok(())
    }

    /// Supported image sizes. The default is only `"1024x1024"`.
    fn get_supported_image_sizes(&self) -> Vec<String> {
        vec![String::from("1024x1024")]
    }

    /// Supported image response formats. The default is only `"url"`.
    fn get_supported_image_formats(&self) -> Vec<String> {
        vec![String::from("url")]
    }

    /// Whether image editing is supported. Defaults to `false`.
    fn supports_image_editing(&self) -> bool {
        false
    }

    /// Whether image variations are supported. Defaults to `false`.
    fn supports_image_variations(&self) -> bool {
        false
    }
}
/// Makes boxed adapters cloneable by delegating to
/// [`ProviderAdapter::clone_adapter`].
impl Clone for Box<dyn ProviderAdapter> {
    fn clone(&self) -> Self {
        // `self` is `&Box<dyn ProviderAdapter>`. Dereferencing twice reaches the
        // trait object, so the call dispatches to the concrete adapter's
        // `clone_adapter`.
        (**self).clone_adapter()
    }
}
#[derive(Debug, Default)]
pub struct AdapterRegistry {
adapters: std::collections::HashMap<String, Box<dyn ProviderAdapter>>,
}
impl AdapterRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `adapter` under its `provider_id()`.
    ///
    /// If an adapter with the same id is already registered, it is replaced and
    /// the previous adapter is dropped.
    pub fn register(&mut self, adapter: Box<dyn ProviderAdapter>) {
        let provider_id = adapter.provider_id().to_string();
        self.adapters.insert(provider_id, adapter);
    }

    /// Returns the adapter registered under `provider_id`, or `None` if there
    /// is none.
    pub fn get_adapter(&self, provider_id: &str) -> Option<&dyn ProviderAdapter> {
        self.adapters.get(provider_id).map(|a| a.as_ref())
    }

    /// Returns the ids of all registered providers, sorted in ascending order.
    ///
    /// `HashMap` iteration order is randomized per process and per map. Sorting
    /// keeps listings (logs, error messages, UIs) stable from run to run.
    pub fn list_providers(&self) -> Vec<String> {
        let mut ids: Vec<String> = self.adapters.keys().cloned().collect();
        ids.sort_unstable();
        ids
    }

    /// Returns `true` if an adapter is registered under `provider_id`.
    pub fn has_provider(&self, provider_id: &str) -> bool {
        self.adapters.contains_key(provider_id)
    }
}
/// Deep copy: every adapter is duplicated through the
/// `Clone for Box<dyn ProviderAdapter>` impl, which calls `clone_adapter`.
impl Clone for AdapterRegistry {
    fn clone(&self) -> Self {
        Self {
            adapters: self.adapters.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::openai_compatible::types::*;

    /// Minimal adapter that implements only the required methods, so every
    /// default method of `ProviderAdapter` is exercised as written.
    #[derive(Debug, Clone)]
    struct TestAdapter;

    impl ProviderAdapter for TestAdapter {
        fn provider_id(&self) -> &'static str {
            "test"
        }

        fn transform_request_params(
            &self,
            _params: &mut serde_json::Value,
            _model: &str,
            _request_type: RequestType,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn get_field_mappings(&self, _model: &str) -> FieldMappings {
            FieldMappings::default()
        }

        fn get_model_config(&self, _model: &str) -> ModelConfig {
            ModelConfig::default()
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn base_url(&self) -> &str {
            "https://api.test.com/v1"
        }

        fn clone_adapter(&self) -> Box<dyn ProviderAdapter> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_adapter_registry() {
        let mut registry = AdapterRegistry::new();
        assert_eq!(registry.list_providers().len(), 0);
        registry.register(Box::new(TestAdapter));
        assert_eq!(registry.list_providers().len(), 1);
        assert!(registry.has_provider("test"));
        let adapter = registry.get_adapter("test").unwrap();
        assert_eq!(adapter.provider_id(), "test");
    }

    #[test]
    fn test_adapter_clone() {
        let adapter: Box<dyn ProviderAdapter> = Box::new(TestAdapter);
        let cloned = adapter.clone();
        assert_eq!(adapter.provider_id(), cloned.provider_id());
    }

    #[test]
    fn test_unknown_provider_is_absent() {
        let registry = AdapterRegistry::new();
        assert!(!registry.has_provider("missing"));
        assert!(registry.get_adapter("missing").is_none());
    }

    #[test]
    fn test_register_same_id_replaces_entry() {
        let mut registry = AdapterRegistry::new();
        registry.register(Box::new(TestAdapter));
        registry.register(Box::new(TestAdapter));
        assert_eq!(registry.list_providers(), vec!["test".to_string()]);
    }

    #[test]
    fn test_registry_clone_keeps_adapters() {
        let mut original = AdapterRegistry::new();
        original.register(Box::new(TestAdapter));
        let copy = original.clone();
        assert!(copy.has_provider("test"));
        assert_eq!(
            copy.get_adapter("test").map(|a| a.provider_id()),
            Some("test")
        );
        assert_eq!(copy.list_providers(), original.list_providers());
    }

    #[test]
    fn test_compatibility_presets() {
        let openai = ProviderCompatibility::openai_standard();
        assert!(openai.supports_array_content);
        assert!(openai.supports_stream_options);
        assert!(openai.supports_developer_role);
        assert!(openai.supports_enable_thinking);
        assert!(openai.supports_service_tier);
        assert!(openai.force_streaming_models.is_empty());
        assert!(openai.custom_flags.is_empty());

        let deepseek = ProviderCompatibility::deepseek();
        assert!(!deepseek.supports_array_content);
        assert!(deepseek.supports_stream_options);
        assert!(deepseek.supports_developer_role);
        assert!(!deepseek.supports_enable_thinking);
        assert!(!deepseek.supports_service_tier);
        assert_eq!(
            deepseek.force_streaming_models,
            vec!["deepseek-reasoner".to_string()]
        );

        let limited = ProviderCompatibility::limited_compatibility();
        assert!(!limited.supports_array_content);
        assert!(!limited.supports_stream_options);
        assert!(!limited.supports_developer_role);
        assert!(!limited.supports_enable_thinking);
        assert!(!limited.supports_service_tier);
        assert!(limited.force_streaming_models.is_empty());
        assert!(limited.custom_flags.is_empty());
    }

    #[test]
    fn test_default_trait_methods() {
        let adapter = TestAdapter;
        assert!(adapter.validate_model("any-model").is_ok());
        assert!(adapter.custom_headers().is_empty());
        assert!(!adapter.supports_image_generation());
        assert!(!adapter.supports_image_editing());
        assert!(!adapter.supports_image_variations());
        assert_eq!(
            adapter.get_supported_image_sizes(),
            vec!["1024x1024".to_string()]
        );
        assert_eq!(
            adapter.get_supported_image_formats(),
            vec!["url".to_string()]
        );
    }
}