use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use super::types::*;
use crate::core::providers::ProviderType;
use crate::core::providers::unified_provider::ProviderError;
/// Default implementation of the transform engine.
///
/// Holds per-provider transform pipelines and a registry of model-name
/// mappings, and converts OpenAI-style requests/responses to and from
/// provider-specific formats.
pub struct DefaultTransformEngine {
    // Ordered transform chains, keyed by the target provider.
    pipelines: HashMap<ProviderType, TransformPipeline>,
    // Model alias registry, keyed by the OpenAI-style model name
    // (e.g. "claude-3-sonnet", "gemini-pro").
    model_mappings: HashMap<String, ModelMapping>,
}
/// Describes how an OpenAI-style model name maps onto a concrete provider
/// model, including which parameters need to be renamed in transit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMapping {
    /// Exact model identifier understood by the provider
    /// (e.g. "claude-3-sonnet-20240229").
    pub provider_model: String,
    /// Closest OpenAI model, used for capability/cost comparison.
    pub openai_equivalent: String,
    /// Feature tags such as "chat" or "vision".
    pub capabilities: Vec<String>,
    /// OpenAI parameter name -> provider parameter name
    /// (e.g. "max_tokens" -> "maxOutputTokens" for Google).
    pub parameter_mappings: HashMap<String, String>,
}
impl Default for DefaultTransformEngine {
fn default() -> Self {
Self::new()
}
}
impl DefaultTransformEngine {
    /// Create an engine pre-loaded with the default model mappings and the
    /// default per-provider transform pipelines.
    pub fn new() -> Self {
        let mut engine = Self {
            pipelines: HashMap::new(),
            model_mappings: HashMap::new(),
        };
        engine.init_default_mappings();
        engine.init_default_pipelines();
        engine
    }

    /// Register the built-in model aliases (OpenAI-style name -> provider
    /// model plus parameter renames).
    fn init_default_mappings(&mut self) {
        self.model_mappings.insert(
            "claude-3-sonnet".to_string(),
            ModelMapping {
                provider_model: "claude-3-sonnet-20240229".to_string(),
                openai_equivalent: "gpt-4".to_string(),
                capabilities: vec!["chat".to_string(), "vision".to_string()],
                parameter_mappings: HashMap::from([
                    ("max_tokens".to_string(), "max_tokens".to_string()),
                    ("temperature".to_string(), "temperature".to_string()),
                ]),
            },
        );
        self.model_mappings.insert(
            "gemini-pro".to_string(),
            ModelMapping {
                provider_model: "gemini-1.0-pro".to_string(),
                openai_equivalent: "gpt-3.5-turbo".to_string(),
                capabilities: vec!["chat".to_string()],
                parameter_mappings: HashMap::from([
                    ("max_tokens".to_string(), "maxOutputTokens".to_string()),
                    ("temperature".to_string(), "temperature".to_string()),
                ]),
            },
        );
    }

    /// Install the default request/response transform pipelines for each
    /// supported provider. Transforms run in registration order on requests
    /// and in reverse order on responses.
    fn init_default_pipelines(&mut self) {
        let anthropic_pipeline = TransformPipeline {
            transforms: vec![
                Box::new(AnthropicMessageTransform::new()),
                Box::new(AnthropicParameterTransform::new()),
            ],
        };
        self.pipelines
            .insert(ProviderType::Anthropic, anthropic_pipeline);
        let vertexai_pipeline = TransformPipeline {
            transforms: vec![
                Box::new(GoogleMessageTransform::new()),
                Box::new(GoogleParameterTransform::new()),
            ],
        };
        self.pipelines
            .insert(ProviderType::VertexAI, vertexai_pipeline);
    }

    /// Resolve an incoming model name to the provider-side model identifier.
    ///
    /// Resolution order:
    /// 1. An explicit entry in `model_mappings`, honored only when its
    ///    `provider_model` belongs to the requested provider's family so a
    ///    cross-provider alias cannot misroute a request. (Bug fix: the
    ///    registry was previously populated but never consulted, so aliases
    ///    like "claude-3-sonnet" never resolved to the dated provider model.)
    /// 2. Pass-through when the name already matches the provider's family
    ///    ("claude*" for Anthropic, "gemini*" for Vertex AI).
    /// 3. A provider-specific default model.
    pub(crate) fn map_model_name(&self, model: &str, provider_type: &ProviderType) -> String {
        let mapped = self
            .model_mappings
            .get(model)
            .map(|m| m.provider_model.as_str());
        match provider_type {
            ProviderType::Anthropic => {
                if let Some(m) = mapped.filter(|m| m.starts_with("claude")) {
                    m.to_string()
                } else if model.starts_with("claude") {
                    model.to_string()
                } else {
                    "claude-3-sonnet-20240229".to_string()
                }
            }
            ProviderType::VertexAI => {
                if let Some(m) = mapped.filter(|m| m.starts_with("gemini")) {
                    m.to_string()
                } else if model.starts_with("gemini") {
                    model.to_string()
                } else {
                    "gemini-1.0-pro".to_string()
                }
            }
            // Other providers (e.g. OpenAI-compatible) take the name verbatim.
            _ => model.to_string(),
        }
    }
}
#[async_trait]
impl TransformEngine for DefaultTransformEngine {
    /// Transform an OpenAI-style chat request into a provider-specific HTTP
    /// request, recording each applied transform's name in the metadata.
    ///
    /// # Errors
    /// Returns `ProviderError::Serialization` if the request cannot be
    /// serialized to JSON, or any error produced by a pipeline transform.
    async fn transform_chat_request(
        &self,
        request: &TransformChatRequest,
        provider_type: &ProviderType,
        provider_config: &HashMap<String, Value>,
    ) -> ProviderResult<TransformResult<ProviderRequest>> {
        // Shared context handed to every transform in the pipeline.
        let context = TransformContext {
            provider_type: provider_type.clone(),
            original_model: request.model.clone(),
            target_model: self.map_model_name(&request.model, provider_type),
            config: provider_config.clone(),
            metadata: HashMap::new(),
        };
        let mut transformations = Vec::new();
        // No warnings are produced on this path yet; kept so the metadata
        // shape matches the other transform methods.
        let warnings = Vec::new();
        let mut request_json =
            serde_json::to_value(request).map_err(|e| ProviderError::Serialization {
                provider: "transform",
                message: format!("Serialization error: {}", e),
            })?;
        // Run the provider's transforms in registration order; providers with
        // no registered pipeline pass the raw JSON through untouched.
        if let Some(pipeline) = self.pipelines.get(provider_type) {
            for transform in &pipeline.transforms {
                transformations.push(transform.name().to_string());
                request_json = transform.transform_request(request_json, &context).await?;
            }
        }
        // Wrap the transformed JSON in the provider-specific HTTP envelope.
        let provider_request = match provider_type {
            ProviderType::Anthropic => self.build_anthropic_request(request_json, &context).await?,
            ProviderType::VertexAI => self.build_vertexai_request(request_json, &context).await?,
            _ => {
                self.build_openai_compatible_request(request_json, &context)
                    .await?
            }
        };
        Ok(TransformResult {
            data: provider_request,
            metadata: TransformMetadata {
                provider_type: provider_type.clone(),
                original_model: request.model.clone(),
                transformed_model: context.target_model,
                transformations_applied: transformations,
                warnings,
                cost_estimate: None,
            },
        })
    }

    /// Convert a raw provider response body back into an OpenAI-style
    /// `ChatResponse`, applying the pipeline transforms in reverse order.
    ///
    /// # Errors
    /// Returns `ProviderError::Serialization` if the (transformed) body does
    /// not deserialize into `ChatResponse`, or any pipeline transform error.
    async fn transform_chat_response(
        &self,
        response: &ProviderResponse,
        provider_type: &ProviderType,
        original_request: &TransformChatRequest,
    ) -> ProviderResult<TransformResult<ChatResponse>> {
        // Rebuild the same context used on the request path; the provider
        // config is not available here, so it is left empty.
        let context = TransformContext {
            provider_type: provider_type.clone(),
            original_model: original_request.model.clone(),
            target_model: self.map_model_name(&original_request.model, provider_type),
            config: HashMap::new(),
            metadata: HashMap::new(),
        };
        let mut transformations = Vec::new();
        let mut response_json = response.body.clone();
        // Apply the pipeline in REVERSE so response transforms undo the
        // request transforms symmetrically (last applied, first reversed).
        if let Some(pipeline) = self.pipelines.get(provider_type) {
            for transform in pipeline.transforms.iter().rev() {
                transformations.push(format!("reverse_{}", transform.name()));
                response_json = transform
                    .transform_response(response_json, &context)
                    .await?;
            }
        }
        let chat_response: ChatResponse =
            serde_json::from_value(response_json).map_err(|e| ProviderError::Serialization {
                provider: "transform",
                message: format!("Deserialization error: {}", e),
            })?;
        Ok(TransformResult {
            data: chat_response,
            metadata: TransformMetadata {
                provider_type: provider_type.clone(),
                original_model: original_request.model.clone(),
                transformed_model: context.target_model,
                transformations_applied: transformations,
                warnings: Vec::new(),
                cost_estimate: None,
            },
        })
    }

    /// Transform an embedding request. Embeddings currently bypass the
    /// transform pipelines and always use the OpenAI-compatible envelope.
    ///
    /// # Errors
    /// Returns `ProviderError::Serialization` if the request cannot be
    /// serialized to JSON.
    async fn transform_embedding_request(
        &self,
        request: &EmbeddingRequest,
        provider_type: &ProviderType,
        provider_config: &HashMap<String, Value>,
    ) -> ProviderResult<TransformResult<ProviderRequest>> {
        let context = TransformContext {
            provider_type: provider_type.clone(),
            original_model: request.model.clone(),
            target_model: self.map_model_name(&request.model, provider_type),
            config: provider_config.clone(),
            metadata: HashMap::new(),
        };
        let request_json =
            serde_json::to_value(request).map_err(|e| ProviderError::Serialization {
                provider: "transform",
                message: format!("Serialization error: {}", e),
            })?;
        let provider_request = self
            .build_openai_compatible_request(request_json, &context)
            .await?;
        Ok(TransformResult {
            data: provider_request,
            metadata: TransformMetadata {
                provider_type: provider_type.clone(),
                original_model: request.model.clone(),
                transformed_model: context.target_model,
                transformations_applied: vec!["embedding_transform".to_string()],
                warnings: Vec::new(),
                cost_estimate: None,
            },
        })
    }

    /// Deserialize an embedding response directly from the provider body;
    /// no reverse pipeline is applied (assumes an OpenAI-compatible shape).
    ///
    /// # Errors
    /// Returns `ProviderError::Serialization` if the body does not
    /// deserialize into `EmbeddingResponse`.
    async fn transform_embedding_response(
        &self,
        response: &ProviderResponse,
        provider_type: &ProviderType,
        original_request: &EmbeddingRequest,
    ) -> ProviderResult<TransformResult<EmbeddingResponse>> {
        let embedding_response: EmbeddingResponse = serde_json::from_value(response.body.clone())
            .map_err(|e| ProviderError::Serialization {
                provider: "transform",
                message: format!("Deserialization error: {}", e),
            })?;
        Ok(TransformResult {
            data: embedding_response,
            metadata: TransformMetadata {
                provider_type: provider_type.clone(),
                original_model: original_request.model.clone(),
                transformed_model: self.map_model_name(&original_request.model, provider_type),
                transformations_applied: vec!["embedding_response_transform".to_string()],
                warnings: Vec::new(),
                cost_estimate: None,
            },
        })
    }

    /// Names of the transforms registered for `provider_type`, in request
    /// (registration) order; empty when no pipeline is registered.
    fn get_supported_transformations(&self, provider_type: &ProviderType) -> Vec<String> {
        self.pipelines
            .get(provider_type)
            .map(|pipeline| {
                pipeline
                    .transforms
                    .iter()
                    .map(|t| t.name().to_string())
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Report known incompatibilities between the request and the target
    /// provider as human-readable strings (empty = no known issues).
    /// This only inspects the request; it does not modify or reject it.
    async fn validate_request_compatibility(
        &self,
        request: &TransformChatRequest,
        provider_type: &ProviderType,
    ) -> ProviderResult<Vec<String>> {
        let mut issues = Vec::new();
        match provider_type {
            ProviderType::Anthropic => {
                // Anthropic uses "tools", not the legacy "functions" field.
                if request.functions.is_some() {
                    issues.push(
                        "Functions are not supported by Anthropic, use tools instead".to_string(),
                    );
                }
                if request.logit_bias.is_some() {
                    issues.push("Logit bias is not supported by Anthropic".to_string());
                }
            }
            ProviderType::VertexAI => {
                if request.functions.is_some() || request.tools.is_some() {
                    issues.push("Function calling support limited in Vertex AI models".to_string());
                }
            }
            // Other providers: assume OpenAI-compatible, no known issues.
            _ => {}
        }
        Ok(issues)
    }
}
impl DefaultTransformEngine {
    /// Build the HTTP envelope for Anthropic's Messages API.
    ///
    /// Bug fix: the transformed request payload was previously discarded and
    /// an empty JSON object was sent as the body; the payload is now passed
    /// through, matching `build_openai_compatible_request`.
    async fn build_anthropic_request(
        &self,
        request: Value,
        _context: &TransformContext,
    ) -> ProviderResult<ProviderRequest> {
        Ok(ProviderRequest {
            endpoint: "/v1/messages".to_string(),
            method: "POST".to_string(),
            headers: HashMap::from([
                ("Content-Type".to_string(), "application/json".to_string()),
                // Required Anthropic API versioning header.
                ("anthropic-version".to_string(), "2023-06-01".to_string()),
            ]),
            body: request,
            query_params: HashMap::new(),
        })
    }

    /// Build the HTTP envelope for Vertex AI's generateContent endpoint,
    /// using the resolved target model in the path.
    ///
    /// Bug fix: as above, the transformed payload is now sent instead of an
    /// empty JSON object.
    async fn build_vertexai_request(
        &self,
        request: Value,
        context: &TransformContext,
    ) -> ProviderResult<ProviderRequest> {
        Ok(ProviderRequest {
            endpoint: format!("/v1/models/{}:generateContent", context.target_model),
            method: "POST".to_string(),
            headers: HashMap::from([("Content-Type".to_string(), "application/json".to_string())]),
            body: request,
            query_params: HashMap::new(),
        })
    }

    /// Build the HTTP envelope for an OpenAI-compatible chat completion
    /// endpoint; the payload is forwarded as-is.
    async fn build_openai_compatible_request(
        &self,
        request: Value,
        _context: &TransformContext,
    ) -> ProviderResult<ProviderRequest> {
        Ok(ProviderRequest {
            endpoint: "/v1/chat/completions".to_string(),
            method: "POST".to_string(),
            headers: HashMap::from([("Content-Type".to_string(), "application/json".to_string())]),
            body: request,
            query_params: HashMap::new(),
        })
    }
}
/// Shapes OpenAI-style messages for Anthropic. Currently a no-op
/// pass-through (see its `Transform` impl below).
#[derive(Default)]
pub struct AnthropicMessageTransform;
/// Renames/adjusts request parameters for Anthropic. Currently a no-op.
#[derive(Default)]
pub struct AnthropicParameterTransform;
/// Shapes OpenAI-style messages for Google/Vertex AI. Currently a no-op.
#[derive(Default)]
pub struct GoogleMessageTransform;
/// Renames/adjusts request parameters for Google/Vertex AI. Currently a no-op.
#[derive(Default)]
pub struct GoogleParameterTransform;
impl AnthropicMessageTransform {
pub fn new() -> Self {
Self
}
}
impl AnthropicParameterTransform {
pub fn new() -> Self {
Self
}
}
impl GoogleMessageTransform {
pub fn new() -> Self {
Self
}
}
impl GoogleParameterTransform {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl Transform for AnthropicMessageTransform {
    fn name(&self) -> &str {
        "anthropic_message_transform"
    }

    /// Identity pass-through: Anthropic message shaping is not implemented yet.
    async fn transform_request(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }

    /// Identity pass-through for the response direction.
    async fn transform_response(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }
}
#[async_trait]
impl Transform for AnthropicParameterTransform {
    fn name(&self) -> &str {
        "anthropic_parameter_transform"
    }

    /// Identity pass-through: Anthropic parameter mapping is not implemented yet.
    async fn transform_request(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }

    /// Identity pass-through for the response direction.
    async fn transform_response(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }
}
#[async_trait]
impl Transform for GoogleMessageTransform {
    fn name(&self) -> &str {
        "google_message_transform"
    }

    /// Identity pass-through: Google message shaping is not implemented yet.
    async fn transform_request(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }

    /// Identity pass-through for the response direction.
    async fn transform_response(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }
}
#[async_trait]
impl Transform for GoogleParameterTransform {
    fn name(&self) -> &str {
        "google_parameter_transform"
    }

    /// Identity pass-through: Google parameter mapping is not implemented yet.
    async fn transform_request(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }

    /// Identity pass-through for the response direction.
    async fn transform_response(
        &self,
        payload: Value,
        _context: &TransformContext,
    ) -> ProviderResult<Value> {
        Ok(payload)
    }
}