pub mod ai21;
pub mod aleph_alpha;
pub mod anthropic;
pub mod anyscale;
pub mod aws_bedrock;
pub mod azure;
pub mod cohere;
pub mod deepinfra;
pub mod fireworks;
pub mod google;
pub mod google_vertex;
pub mod groq;
pub mod huggingface;
pub mod mistral;
pub mod ollama;
pub mod openai;
pub mod perplexity;
pub mod replicate;
pub mod stability_ai;
pub mod together_ai;
pub mod voyage;
pub mod writer;
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::{GatewayError, Result};
use async_trait::async_trait;
use std::fmt::Debug;
use std::sync::Arc;
pub use ai21::AI21Provider;
pub use aleph_alpha::AlephAlphaProvider;
pub use anthropic::AnthropicProvider;
pub use anyscale::AnyscaleProvider;
pub use aws_bedrock::AWSBedrockProvider;
pub use azure::AzureProvider;
pub use cohere::CohereProvider;
pub use deepinfra::DeepInfraProvider;
pub use fireworks::FireworksProvider;
pub use google::GoogleProvider;
pub use google_vertex::GoogleVertexProvider;
pub use groq::GroqProvider;
pub use huggingface::HuggingFaceProvider;
pub use mistral::MistralProvider;
pub use ollama::OllamaProvider;
pub use openai::OpenAIProvider;
pub use perplexity::PerplexityProvider;
pub use replicate::ReplicateProvider;
pub use stability_ai::StabilityAIProvider;
pub use together_ai::TogetherAIProvider;
pub use voyage::VoyageProvider;
pub use writer::WriterProvider;
/// Identifies the upstream provider family; providers without a dedicated
/// variant are carried as `Custom(name)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderType {
OpenAI,
Anthropic,
Azure,
Google,
Cohere,
HuggingFace,
Ollama,
Custom(String),
}
/// Parses a provider-type string (case-insensitive). The built-in providers
/// get dedicated variants; everything else, including the extended providers
/// listed below, becomes `Custom` with a normalized name.
impl From<&str> for ProviderType {
fn from(s: &str) -> Self {
match s.to_lowercase().as_str() {
"openai" => ProviderType::OpenAI,
"anthropic" => ProviderType::Anthropic,
"azure" => ProviderType::Azure,
"google" => ProviderType::Google,
"cohere" => ProviderType::Cohere,
"huggingface" => ProviderType::HuggingFace,
"ollama" => ProviderType::Ollama,
"aws_bedrock" | "bedrock" => ProviderType::Custom("aws_bedrock".to_string()),
"google_vertex" | "vertex" => ProviderType::Custom("google_vertex".to_string()),
"mistral" => ProviderType::Custom("mistral".to_string()),
"together_ai" | "together" => ProviderType::Custom("together_ai".to_string()),
"perplexity" => ProviderType::Custom("perplexity".to_string()),
"replicate" => ProviderType::Custom("replicate".to_string()),
"fireworks" => ProviderType::Custom("fireworks".to_string()),
"groq" => ProviderType::Custom("groq".to_string()),
"anyscale" => ProviderType::Custom("anyscale".to_string()),
"deepinfra" => ProviderType::Custom("deepinfra".to_string()),
"ai21" => ProviderType::Custom("ai21".to_string()),
"aleph_alpha" => ProviderType::Custom("aleph_alpha".to_string()),
"voyage" => ProviderType::Custom("voyage".to_string()),
"writer" => ProviderType::Custom("writer".to_string()),
"stability_ai" | "stability" => ProviderType::Custom("stability_ai".to_string()),
_ => ProviderType::Custom(s.to_string()),
}
}
}
/// Errors surfaced by provider implementations, typically mapped from
/// upstream HTTP status codes and transport failures.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    #[error("Authentication failed: {0}")]
    Authentication(String),
    #[error("Rate limit exceeded: {0}")]
    RateLimit(String),
    /// Duplicate of [`ProviderError::RateLimit`] retained so existing call
    /// sites keep compiling; both convert to `GatewayError::RateLimit`.
    #[error("Rate limited: {0}")]
    RateLimited(String),
#[error("Model not found: {0}")]
ModelNotFound(String),
#[error("Invalid request: {0}")]
InvalidRequest(String),
#[error("Provider unavailable: {0}")]
Unavailable(String),
#[error("Network error: {0}")]
Network(String),
#[error("Parsing error: {0}")]
Parsing(String),
#[error("Timeout: {0}")]
Timeout(String),
#[error("Other error: {0}")]
Other(String),
#[error("Unknown error: {0}")]
Unknown(String),
}
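/// Core abstraction implemented by every upstream LLM provider.
///
/// A minimal implementation skeleton, shown for orientation only
/// (`MyProvider` is hypothetical and the elided methods are stubs):
///
/// ```ignore
/// #[derive(Debug)]
/// struct MyProvider {
///     base: BaseProvider,
/// }
///
/// #[async_trait]
/// impl Provider for MyProvider {
///     fn name(&self) -> &str {
///         &self.base.name
///     }
///     fn provider_type(&self) -> ProviderType {
///         ProviderType::Custom("my_provider".to_string())
///     }
///     async fn supports_model(&self, model: &str) -> bool {
///         self.base.is_model_supported(model)
///     }
///     // ...the remaining methods follow the signatures below.
/// }
/// ```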
#[async_trait]
pub trait Provider: Send + Sync + Debug {
fn name(&self) -> &str;
fn provider_type(&self) -> ProviderType;
async fn supports_model(&self, model: &str) -> bool;
async fn supports_images(&self) -> bool;
async fn supports_embeddings(&self) -> bool;
async fn supports_streaming(&self) -> bool;
async fn list_models(&self) -> Result<Vec<Model>>;
async fn get_model(&self, model_id: &str) -> Result<Option<Model>> {
let models = self.list_models().await?;
Ok(models.into_iter().find(|m| m.id == model_id))
}
async fn health_check(&self) -> Result<()>;
async fn chat_completion(
&self,
request: ChatCompletionRequest,
context: RequestContext,
) -> Result<ChatCompletionResponse>;
async fn chat_completion_stream(
&self,
_request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin + 'static>> {
Err(crate::utils::error::GatewayError::invalid_request(
"Streaming not supported by this provider",
))
}
async fn completion(
&self,
request: CompletionRequest,
context: RequestContext,
) -> Result<CompletionResponse>;
async fn embedding(
&self,
request: EmbeddingRequest,
context: RequestContext,
) -> Result<EmbeddingResponse>;
async fn image_generation(
&self,
request: ImageGenerationRequest,
context: RequestContext,
) -> Result<ImageGenerationResponse>;
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing>;
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64>;
}
/// Per-model pricing snapshot, expressed as cost per 1,000 tokens.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelPricing {
pub model: String,
pub input_cost_per_1k: f64,
pub output_cost_per_1k: f64,
pub currency: String,
pub updated_at: chrono::DateTime<chrono::Utc>,
}
/// Shared plumbing (HTTP client, credentials, base URL) reused by the
/// concrete provider implementations.
#[derive(Debug, Clone)]
pub struct BaseProvider {
pub name: String,
pub config: ProviderConfig,
pub client: reqwest::Client,
pub base_url: String,
pub api_key: String,
}
impl BaseProvider {
pub fn new(config: &ProviderConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(config.timeout))
.build()
.map_err(|e| GatewayError::internal(format!("Failed to create HTTP client: {}", e)))?;
Ok(Self {
name: config.name.clone(),
config: config.clone(),
client,
base_url: config.base_url.clone().unwrap_or_default(),
api_key: config.api_key.clone(),
})
}
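    /// Sends an authenticated JSON request to `{base_url}/{endpoint}` and maps
    /// non-2xx statuses onto [`ProviderError`] variants.
    ///
    /// A hedged usage sketch (`list_models_raw` is a hypothetical helper):
    ///
    /// ```ignore
    /// async fn list_models_raw(base: &BaseProvider) -> Result<serde_json::Value> {
    ///     let resp = base
    ///         .make_request(reqwest::Method::GET, "v1/models", None)
    ///         .await?;
    ///     base.parse_json_response(resp).await
    /// }
    /// ```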
pub async fn make_request(
&self,
method: reqwest::Method,
endpoint: &str,
body: Option<serde_json::Value>,
) -> Result<reqwest::Response> {
let url = format!(
"{}/{}",
self.base_url.trim_end_matches('/'),
endpoint.trim_start_matches('/')
);
let mut request = self.client.request(method, &url);
if !self.api_key.is_empty() {
request = request.header("Authorization", format!("Bearer {}", self.api_key));
}
if let Some(body) = body {
request = request
.header("Content-Type", "application/json")
.json(&body);
}
let response = request
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
                400 => ProviderError::InvalidRequest(error_text),
                401 => ProviderError::Authentication(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                429 => ProviderError::RateLimit(error_text),
                503 => ProviderError::Unavailable(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
Ok(response)
}
pub async fn parse_json_response<T>(&self, response: reqwest::Response) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
let text = response
.text()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
serde_json::from_str(&text)
.map_err(|e| ProviderError::Parsing(format!("Failed to parse JSON: {}", e)).into())
}
pub fn get_supported_models(&self) -> &[String] {
&self.config.models
}
    /// A model is supported if it equals a configured entry exactly, or if a
    /// configured entry is a prefix of it (family matching, e.g. a configured
    /// "gpt-4" covering "gpt-4-0613").
    pub fn is_model_supported(&self, model: &str) -> bool {
self.config
.models
.iter()
.any(|m| m == model || model.starts_with(m))
}
}
/// Constructs boxed [`Provider`] instances from configuration.
pub struct ProviderFactory;
impl ProviderFactory {
    /// Instantiates the provider named by `config.provider_type`; unknown
    /// types are rejected rather than silently treated as custom.
    pub async fn create_provider(config: &ProviderConfig) -> Result<Box<dyn Provider>> {
match config.provider_type.as_str() {
"openai" => Ok(Box::new(OpenAIProvider::new(config).await?)),
"anthropic" => Ok(Box::new(AnthropicProvider::new(config).await?)),
"azure" => Ok(Box::new(AzureProvider::new(config).await?)),
"google" => Ok(Box::new(GoogleProvider::new(config).await?)),
"cohere" => Ok(Box::new(CohereProvider::new(config).await?)),
"huggingface" => Ok(Box::new(HuggingFaceProvider::new(config).await?)),
"ollama" => Ok(Box::new(OllamaProvider::new(config).await?)),
"aws_bedrock" | "bedrock" => Ok(Box::new(AWSBedrockProvider::new(config).await?)),
"google_vertex" | "vertex" => Ok(Box::new(GoogleVertexProvider::new(config).await?)),
"mistral" => Ok(Box::new(MistralProvider::new(config).await?)),
"together_ai" | "together" => Ok(Box::new(TogetherAIProvider::new(config).await?)),
"perplexity" => Ok(Box::new(PerplexityProvider::new(config).await?)),
"replicate" => Ok(Box::new(ReplicateProvider::new(config).await?)),
"fireworks" => Ok(Box::new(FireworksProvider::new(config).await?)),
"groq" => Ok(Box::new(GroqProvider::new(config).await?)),
"anyscale" => Ok(Box::new(AnyscaleProvider::new(config).await?)),
"deepinfra" => Ok(Box::new(DeepInfraProvider::new(config).await?)),
"ai21" => Ok(Box::new(AI21Provider::new(config).await?)),
"aleph_alpha" => Ok(Box::new(AlephAlphaProvider::new(config).await?)),
"voyage" => Ok(Box::new(VoyageProvider::new(config).await?)),
"writer" => Ok(Box::new(WriterProvider::new(config).await?)),
"stability_ai" | "stability" => Ok(Box::new(StabilityAIProvider::new(config).await?)),
_ => Err(GatewayError::bad_request(format!(
"Unsupported provider: {}",
config.provider_type
))),
}
}
    /// Canonical `provider_type` strings accepted by [`Self::create_provider`]
    /// (short aliases such as "bedrock", "vertex", "together", and "stability"
    /// are accepted there as well). Keep this list in sync with the match above.
    pub fn supported_types() -> Vec<&'static str> {
vec![
"openai",
"anthropic",
"azure",
"google",
"cohere",
"huggingface",
"ollama",
"aws_bedrock",
"google_vertex",
"mistral",
"together_ai",
"perplexity",
"replicate",
"fireworks",
"groq",
"anyscale",
"deepinfra",
"ai21",
"aleph_alpha",
"voyage",
"writer",
"stability_ai",
]
}
}
pub mod utils {
use super::*;
    /// Maps provider-level failures onto the gateway's error taxonomy.
    pub fn provider_error_to_gateway_error(error: ProviderError) -> GatewayError {
        match error {
            ProviderError::Authentication(msg) => GatewayError::Unauthorized(msg),
            ProviderError::RateLimit(msg) | ProviderError::RateLimited(msg) => {
                GatewayError::RateLimit(msg)
            }
            ProviderError::ModelNotFound(msg) => GatewayError::NotFound(msg),
            ProviderError::InvalidRequest(msg) => GatewayError::Validation(msg),
            ProviderError::Unavailable(msg) => GatewayError::ServiceUnavailable(msg),
            ProviderError::Network(msg) | ProviderError::Timeout(msg) => {
                GatewayError::Network(msg)
            }
            ProviderError::Parsing(msg) => GatewayError::Parsing(msg),
            ProviderError::Other(msg) | ProviderError::Unknown(msg) => {
                GatewayError::Internal(msg)
            }
        }
    }
    /// Returns the final path segment, e.g. "openai/gpt-4" -> "gpt-4".
    pub fn extract_model_name(full_model: &str) -> &str {
        full_model.split('/').next_back().unwrap_or(full_model)
    }
    /// Prefixes `model` with `provider_prefix` unless the prefix is already
    /// present as the leading path segment (so "openai-compat/x" is not
    /// mistaken for the "openai" prefix).
    pub fn normalize_model_name(model: &str, provider_prefix: &str) -> String {
        let prefix = format!("{}/", provider_prefix);
        if model.starts_with(&prefix) {
            model.to_string()
        } else {
            format!("{}{}", prefix, model)
        }
    }
    /// Rough token estimate using the common ~4-characters-per-token
    /// heuristic for English text; suitable for budgeting, not for billing.
    pub fn estimate_token_count(text: &str) -> u32 {
        (text.len() as f64 / 4.0).ceil() as u32
    }
    /// Converts OpenAI-style chat messages into the wire format expected by
    /// `provider_type`; providers without a special case fall through to the
    /// OpenAI shape unchanged.
    pub fn convert_messages_format(
        messages: &[ChatMessage],
        provider_type: &ProviderType,
    ) -> Result<serde_json::Value> {
match provider_type {
ProviderType::OpenAI | ProviderType::Azure => Ok(serde_json::to_value(messages)?),
            ProviderType::Anthropic => {
                // Anthropic's Messages API takes the system prompt as a
                // top-level `system` field rather than as a message, so pull
                // it out of the conversation (the last system message wins).
                let mut anthropic_messages = Vec::new();
                let mut system_message = None;
for message in messages {
match message.role {
MessageRole::System => {
if let Some(MessageContent::Text(text)) = &message.content {
system_message = Some(text.clone());
}
}
MessageRole::User | MessageRole::Assistant => {
if let Some(content) = &message.content {
anthropic_messages.push(serde_json::json!({
"role": message.role.to_string().to_lowercase(),
"content": content
}));
}
}
                        // Any other roles are dropped in this conversion.
                        _ => {}
                    }
}
let mut result = serde_json::json!({
"messages": anthropic_messages
});
if let Some(system) = system_message {
result["system"] = serde_json::Value::String(system);
}
Ok(result)
}
            // Default: most providers accept the OpenAI message shape as-is.
            _ => Ok(serde_json::to_value(messages)?),
}
}
}
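/// Convenience wrapper over [`ProviderFactory::create_provider`] that returns
/// an [`Arc`], the shape routing layers typically share across tasks.
///
/// A hedged sketch (`load_provider_config` is hypothetical; any source of a
/// `ProviderConfig` works):
///
/// ```ignore
/// let config = load_provider_config("openai")?;
/// let provider = create_provider(config).await?;
/// assert!(provider.supports_model("gpt-4").await);
/// ```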
pub async fn create_provider(
config: crate::config::ProviderConfig,
) -> crate::utils::error::Result<std::sync::Arc<dyn Provider>> {
let provider = ProviderFactory::create_provider(&config).await?;
Ok(std::sync::Arc::from(provider))
}
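/// Holds one instantiated provider per configuration entry; name lookups rely
/// on `configs` and `providers` staying index-aligned.
///
/// A hedged usage sketch (loading `configs` is elided and hypothetical):
///
/// ```ignore
/// let pool = ProviderPool::new(&configs).await?;
/// if let Some(p) = pool.get_provider("openai") {
///     p.health_check().await?;
/// }
/// let health = pool.health_check().await;
/// println!("{}/{} providers healthy", health.healthy_count, health.total_count);
/// ```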
pub struct ProviderPool {
    providers: Vec<Arc<dyn Provider>>,
    configs: Vec<crate::config::ProviderConfig>,
}
impl ProviderPool {
pub async fn new(
configs: &[crate::config::ProviderConfig],
) -> crate::utils::error::Result<Self> {
let mut providers = Vec::new();
for config in configs {
let provider = ProviderFactory::create_provider(config).await?;
providers.push(Arc::from(provider));
}
        Ok(Self {
            providers,
            configs: configs.to_vec(),
        })
}
    /// Looks up a provider by its configured name.
    pub fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
        self.configs
            .iter()
            .position(|c| c.name == name)
            .map(|i| self.providers[i].clone())
    }
pub fn get_all_providers(&self) -> &[Arc<dyn Provider>] {
&self.providers
}
    /// Runs each provider's health check sequentially and tallies the results.
    pub async fn health_check(&self) -> ProviderPoolHealth {
let mut healthy_count = 0;
let total_count = self.providers.len();
for provider in &self.providers {
if provider.health_check().await.is_ok() {
healthy_count += 1;
}
}
ProviderPoolHealth {
healthy_count,
total_count,
}
}
}
/// Aggregate result of [`ProviderPool::health_check`].
#[derive(Debug, Clone)]
pub struct ProviderPoolHealth {
pub healthy_count: usize,
pub total_count: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_type_from_str() {
assert_eq!(ProviderType::from("openai"), ProviderType::OpenAI);
assert_eq!(ProviderType::from("anthropic"), ProviderType::Anthropic);
assert_eq!(
ProviderType::from("custom"),
ProviderType::Custom("custom".to_string())
);
}
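    #[test]
    fn test_provider_error_to_gateway_error() {
        // Spot-check the error mapping, including the legacy `RateLimited`
        // alias, which must convert identically to `RateLimit`.
        let e = utils::provider_error_to_gateway_error(ProviderError::Timeout("t".into()));
        assert!(matches!(e, GatewayError::Network(_)));
        let e = utils::provider_error_to_gateway_error(ProviderError::RateLimited("r".into()));
        assert!(matches!(e, GatewayError::RateLimit(_)));
    }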
#[test]
fn test_extract_model_name() {
assert_eq!(utils::extract_model_name("openai/gpt-4"), "gpt-4");
assert_eq!(utils::extract_model_name("gpt-4"), "gpt-4");
assert_eq!(utils::extract_model_name("provider/sub/model"), "model");
}
#[test]
fn test_normalize_model_name() {
assert_eq!(
utils::normalize_model_name("gpt-4", "openai"),
"openai/gpt-4"
);
assert_eq!(
utils::normalize_model_name("openai/gpt-4", "openai"),
"openai/gpt-4"
);
}
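    #[test]
    fn test_model_pricing_cost_formula() {
        // Illustrates the per-1k-token arithmetic that implementations of
        // `Provider::calculate_cost` are expected to follow (a convention
        // assumed here, not enforced by the trait).
        let pricing = ModelPricing {
            model: "example-model".to_string(),
            input_cost_per_1k: 0.01,
            output_cost_per_1k: 0.03,
            currency: "USD".to_string(),
            updated_at: chrono::Utc::now(),
        };
        let cost = (1_000_f64 / 1000.0) * pricing.input_cost_per_1k
            + (500_f64 / 1000.0) * pricing.output_cost_per_1k;
        assert!((cost - 0.025).abs() < 1e-12);
    }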
#[test]
fn test_estimate_token_count() {
assert_eq!(utils::estimate_token_count("hello"), 2); assert_eq!(utils::estimate_token_count("hello world"), 3); }
}