use std::collections::HashMap;
use clust::messages::StopSequence;
use crate::{
    error::GatewayError,
    models::ModelMetadata,
    types::{
        credentials::{ApiKeyCredentials, Credentials},
        engine::{
            AnthropicModelParams, BedrockModelParams, ClaudeModel, CompletionEngineParams,
            EmbeddingsEngineParams, ExecutionOptions, GeminiModelParams,
            ImageGenerationEngineParams, OpenAiModelParams,
        },
        gateway::{
            ChatCompletionRequest, CreateEmbeddingRequest, CreateImageRequest,
            ProviderSpecificRequest,
        },
        provider::{BedrockProvider, InferenceModelProvider},
    },
};
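/// Stateless factory that maps gateway requests plus resolved model metadata
/// onto provider-specific engine parameters.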
pub struct Provider;
impl Provider {
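    /// Builds [`CompletionEngineParams`] for the provider hosting `model`,
    /// translating the OpenAI-style [`ChatCompletionRequest`] into that
    /// provider's parameter types.
    ///
    /// A sketch of a call site (ignored doctest: `model`, `request`, and
    /// `api_key` are assumed to come from the gateway's routing layer and are
    /// not constructed here):
    ///
    /// ```ignore
    /// let engine = Provider::get_completion_engine_for_model(
    ///     &model,
    ///     &request,
    ///     Some(Credentials::ApiKey(api_key)),
    ///     None, // provider-specific extras (e.g. Anthropic top_k / thinking)
    ///     None, // execution options; defaults apply when absent
    /// )?;
    /// ```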
pub fn get_completion_engine_for_model(
model: &ModelMetadata,
request: &ChatCompletionRequest,
credentials: Option<Credentials>,
provider_specific: Option<&ProviderSpecificRequest>,
execution_options: Option<ExecutionOptions>,
) -> Result<CompletionEngineParams, GatewayError> {
match model.inference_provider.provider {
InferenceModelProvider::OpenAI | InferenceModelProvider::Proxy(_) => {
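                // OpenAI and proxied providers share the OpenAI parameter
                // shape; build it once and decide on endpoint handling below.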
let params = OpenAiModelParams {
model: Some(model.inference_provider.model_name.clone()),
frequency_penalty: request.frequency_penalty,
logit_bias: request.logit_bias.clone(),
logprobs: None,
top_logprobs: None,
max_tokens: request.max_tokens,
presence_penalty: request.presence_penalty,
seed: request.seed,
stop: request.stop.clone(),
temperature: request.temperature,
top_p: request.top_p,
user: request.user.clone(),
response_format: request.response_format.clone(),
prompt_cache_key: request.prompt_cache_key.clone(),
};
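                // `ApiKeyWithEndpoint` carries a custom base URL (e.g. an
                // Azure deployment); capture it for the match below.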
let mut custom_endpoint = None;
let api_key_credentials = credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key),
Credentials::ApiKeyWithEndpoint {
api_key: key,
endpoint,
} => {
custom_endpoint = Some(endpoint);
Some(ApiKeyCredentials { api_key: key })
}
_ => None,
});
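                // Re-match to split plain OpenAI from proxies: the "azure"
                // proxy keeps the OpenAI wire format but talks to the
                // captured custom endpoint.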
match &model.inference_provider.provider {
InferenceModelProvider::OpenAI => Ok(CompletionEngineParams::OpenAi {
params,
execution_options: execution_options.unwrap_or_default(),
credentials: api_key_credentials,
endpoint: None,
}),
InferenceModelProvider::Proxy(proxy_provider) => {
if proxy_provider == "azure" {
Ok(CompletionEngineParams::OpenAi {
params,
execution_options: execution_options.unwrap_or_default(),
credentials: api_key_credentials,
endpoint: custom_endpoint,
})
} else {
Ok(CompletionEngineParams::Proxy {
params,
execution_options: execution_options.unwrap_or_default(),
credentials: api_key_credentials,
})
}
}
                    _ => unreachable!("outer match already narrowed the provider to OpenAI or Proxy"),
}
}
InferenceModelProvider::Bedrock => {
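                // Bedrock fronts several model vendors; the model's publisher
                // determines which request/response dialect to use.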
let aws_creds = credentials.and_then(|cred| cred.to_bedrock_credentials());
let provider = match model.model_provider.as_str() {
"cohere" => BedrockProvider::Cohere,
"meta" => BedrockProvider::Meta,
"mistral" => BedrockProvider::Mistral,
p => BedrockProvider::Other(p.to_string()),
};
Ok(CompletionEngineParams::Bedrock {
credentials: aws_creds,
execution_options: execution_options.unwrap_or_default(),
params: BedrockModelParams {
model_id: Some(model.inference_provider.model_name.clone()),
max_tokens: request.max_tokens.map(|x| x as i32),
temperature: request.temperature,
top_p: request.top_p,
stop_sequences: request.stop.clone(),
additional_parameters: HashMap::new(),
},
provider,
})
}
InferenceModelProvider::Anthropic => {
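                // `clust` validates sampling parameters through wrapper
                // types, so each conversion below is fallible.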
let api_key_credentials = credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key),
_ => None,
});
                let model_name = get_anthropic_model(&model.inference_provider.model_name);
                let claude_model =
                    serde_json::from_str::<ClaudeModel>(&format!("\"{model_name}\""))?;
                Ok(CompletionEngineParams::Anthropic {
                    credentials: api_key_credentials,
                    execution_options: execution_options.unwrap_or_default(),
                    params: AnthropicModelParams {
                        model: Some(claude_model.clone()),
                        max_tokens: request
                            .max_tokens
                            .map(|x| clust::messages::MaxTokens::new(x, claude_model.model))
                            .transpose()?,
                        stop_sequences: request
                            .stop
                            .as_ref()
                            .map(|s| s.iter().map(StopSequence::new).collect()),
                        stream: None,
                        temperature: request
                            .temperature
                            .map(clust::messages::Temperature::new)
                            .transpose()?,
                        top_p: request
                            .top_p
                            .map(clust::messages::TopP::new)
                            .transpose()?,
top_k: provider_specific
.and_then(|ps| ps.top_k.map(clust::messages::TopK::new)),
thinking: provider_specific.and_then(|ps| {
ps.thinking
.as_ref()
.map(|thinking| clust::messages::Thinking {
r#type: thinking.r#type.clone(),
budget_tokens: thinking.budget_tokens,
})
}),
},
})
}
InferenceModelProvider::Gemini => {
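                // Gemini's generation config maps nearly one-to-one onto the
                // OpenAI-style request fields.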
let api_key_credentials = credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key),
_ => None,
});
Ok(CompletionEngineParams::Gemini {
credentials: api_key_credentials,
execution_options: execution_options.unwrap_or_default(),
params: GeminiModelParams {
model: Some(model.inference_provider.model_name.clone()),
max_output_tokens: request.max_tokens.map(|x| x as i32),
temperature: request.temperature,
top_p: request.top_p,
stop_sequences: request.stop.clone(),
candidate_count: request.n,
presence_penalty: request.presence_penalty,
frequency_penalty: request.frequency_penalty,
seed: request.seed,
response_logprobs: None,
logprobs: None,
top_k: None,
response_format: request.response_format.clone(),
},
})
}
            InferenceModelProvider::VertexAI => {
                unimplemented!("chat completions are not implemented for VertexAI")
            }
}
}
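    /// Builds [`ImageGenerationEngineParams`] for image generation. Only
    /// OpenAI (optionally behind a custom endpoint) and proxied providers are
    /// supported; all other providers yield
    /// [`GatewayError::UnsupportedProvider`].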
pub fn get_image_engine_for_model(
model: &ModelMetadata,
request: &CreateImageRequest,
credentials: Option<&Credentials>,
) -> Result<ImageGenerationEngineParams, GatewayError> {
match model.inference_provider.provider {
InferenceModelProvider::OpenAI => {
let mut custom_endpoint = None;
Ok(ImageGenerationEngineParams::OpenAi {
credentials: credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key.clone()),
Credentials::ApiKeyWithEndpoint { api_key, endpoint } => {
custom_endpoint = Some(endpoint.clone());
Some(ApiKeyCredentials {
api_key: api_key.clone(),
})
}
_ => None,
}),
model_name: request.model.clone(),
endpoint: custom_endpoint,
})
}
InferenceModelProvider::Proxy(_) => Ok(ImageGenerationEngineParams::LangdbOpen {
credentials: credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key.clone()),
_ => None,
}),
model_name: request.model.clone(),
}),
InferenceModelProvider::VertexAI
| InferenceModelProvider::Anthropic
| InferenceModelProvider::Gemini
| InferenceModelProvider::Bedrock => Err(GatewayError::UnsupportedProvider(
model.inference_provider.provider.to_string(),
)),
}
}
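    /// Builds [`EmbeddingsEngineParams`] for embeddings. OpenAI-compatible,
    /// Gemini, and Bedrock providers are supported; VertexAI and Anthropic
    /// yield [`GatewayError::UnsupportedProvider`].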
pub fn get_embeddings_engine_for_model(
model: &ModelMetadata,
request: &CreateEmbeddingRequest,
credentials: Option<&Credentials>,
) -> Result<EmbeddingsEngineParams, GatewayError> {
match model.inference_provider.provider {
InferenceModelProvider::OpenAI | InferenceModelProvider::Proxy(_) => {
let mut custom_endpoint = None;
Ok(EmbeddingsEngineParams::OpenAi {
credentials: credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key.clone()),
Credentials::ApiKeyWithEndpoint { api_key, endpoint } => {
custom_endpoint = Some(endpoint.clone());
Some(ApiKeyCredentials {
api_key: api_key.clone(),
})
}
_ => None,
}),
model_name: request.model.clone(),
endpoint: custom_endpoint,
})
}
InferenceModelProvider::Gemini => Ok(EmbeddingsEngineParams::Gemini {
credentials: credentials.and_then(|cred| match cred {
Credentials::ApiKey(key) => Some(key.clone()),
_ => None,
}),
model_name: request.model.clone(),
}),
InferenceModelProvider::Bedrock => Ok(EmbeddingsEngineParams::Bedrock {
credentials: credentials.and_then(|cred| match cred {
Credentials::Aws(cred) => Some(cred.clone()),
_ => None,
}),
model_name: request.model.clone(),
}),
InferenceModelProvider::VertexAI | InferenceModelProvider::Anthropic => Err(
GatewayError::UnsupportedProvider(model.inference_provider.provider.to_string()),
),
}
}
}
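/// Resolves undated Anthropic model aliases to the pinned, dated model IDs the
/// API expects; any other name passes through unchanged.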
fn get_anthropic_model(model_name: &str) -> &str {
match model_name {
"claude-3-opus" => "claude-3-opus-20240229",
"claude-3-sonnet" => "claude-3-sonnet-20240229",
"claude-3-haiku" => "claude-3-haiku-20240307",
"claude-3-5-sonnet" => "claude-3-5-sonnet-20240620",
n => n,
}
}
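
// A minimal sanity check for the alias table above: every undated alias must
// resolve to the dated ID hard-coded in `get_anthropic_model`, and names that
// are already fully qualified must pass through unchanged.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn anthropic_aliases_resolve_to_dated_model_ids() {
        assert_eq!(get_anthropic_model("claude-3-opus"), "claude-3-opus-20240229");
        assert_eq!(
            get_anthropic_model("claude-3-5-sonnet"),
            "claude-3-5-sonnet-20240620"
        );
        // Unknown or already-dated names are returned verbatim.
        assert_eq!(
            get_anthropic_model("claude-3-haiku-20240307"),
            "claude-3-haiku-20240307"
        );
    }
}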