pub mod base;
#[cfg(feature = "providers-extended")]
pub mod ai21;
#[cfg(feature = "providers-extended")]
pub mod amazon_nova;
pub mod anthropic;
#[cfg(feature = "providers-extra")]
pub mod azure;
#[cfg(feature = "providers-extra")]
pub mod azure_ai;
#[cfg(feature = "providers-extended")]
pub mod baseten;
#[cfg(feature = "providers-extra")]
pub mod bedrock;
#[cfg(feature = "providers-extended")]
pub mod clarifai;
pub mod cloudflare;
#[cfg(feature = "providers-extended")]
pub mod codestral;
#[cfg(feature = "providers-extended")]
pub mod cohere;
#[cfg(feature = "providers-extended")]
pub mod custom_api;
#[cfg(feature = "providers-extended")]
pub mod databricks;
#[cfg(feature = "providers-extended")]
pub mod datarobot;
#[cfg(feature = "providers-extended")]
pub mod deepgram;
#[cfg(feature = "providers-extended")]
pub mod deepl;
#[cfg(feature = "providers-extended")]
pub mod elevenlabs;
#[cfg(feature = "providers-extended")]
pub mod empower;
#[cfg(feature = "providers-extended")]
pub mod exa_ai;
#[cfg(feature = "providers-extended")]
pub mod fal_ai;
#[cfg(feature = "providers-extended")]
pub mod firecrawl;
#[cfg(feature = "providers-extended")]
pub mod gemini;
#[cfg(feature = "providers-extended")]
pub mod gigachat;
#[cfg(feature = "providers-extended")]
pub mod github;
#[cfg(feature = "providers-extended")]
pub mod github_copilot;
#[cfg(feature = "providers-extended")]
pub mod google_pse;
#[cfg(feature = "providers-extended")]
pub mod gradient_ai;
#[cfg(feature = "providers-extended")]
pub mod huggingface;
#[cfg(feature = "providers-extended")]
pub mod jina;
#[cfg(feature = "providers-extended")]
pub mod langgraph;
#[cfg(feature = "providers-extended")]
pub mod manus;
#[cfg(feature = "providers-extra")]
pub mod meta_llama;
#[cfg(feature = "providers-extended")]
pub mod milvus;
pub mod mistral;
#[cfg(feature = "providers-extended")]
pub mod morph;
#[cfg(feature = "providers-extended")]
pub mod nlp_cloud;
#[cfg(feature = "providers-extended")]
pub mod oci;
#[cfg(feature = "providers-extended")]
pub mod ollama;
pub mod openai;
pub mod openai_like;
#[cfg(feature = "providers-extended")]
pub mod petals;
#[cfg(feature = "providers-extended")]
pub mod pg_vector;
#[cfg(feature = "providers-extended")]
pub mod predibase;
#[cfg(feature = "providers-extended")]
pub mod ragflow;
#[cfg(feature = "providers-extended")]
pub mod recraft;
#[cfg(feature = "providers-extended")]
pub mod replicate;
#[cfg(feature = "providers-extended")]
pub mod runwayml;
#[cfg(feature = "providers-extended")]
pub mod sagemaker;
#[cfg(feature = "providers-extended")]
pub mod sap_ai;
#[cfg(feature = "providers-extended")]
pub mod searxng;
#[cfg(feature = "providers-extended")]
pub mod snowflake;
#[cfg(feature = "providers-extended")]
pub mod spark;
#[cfg(feature = "providers-extended")]
pub mod stability;
#[cfg(feature = "providers-extended")]
pub mod tavily;
#[cfg(feature = "providers-extended")]
pub mod topaz;
#[cfg(feature = "providers-extended")]
pub mod triton;
#[cfg(feature = "providers-extra")]
pub mod v0;
#[cfg(feature = "providers-extended")]
pub mod vercel_ai;
#[cfg(feature = "providers-extra")]
pub mod vertex_ai;
#[cfg(feature = "providers-extended")]
pub mod voyage;
#[cfg(feature = "providers-extended")]
pub mod watsonx;
pub mod macros; pub mod shared; pub mod thinking; pub mod transform;
pub mod provider_type;
pub use provider_type::ProviderType;
pub mod factory;
pub use factory::{create_provider, is_provider_selector_supported};
pub mod contextual_error;
pub mod provider_error_conversions;
pub mod provider_registry;
pub mod registry; pub mod unified_provider;
#[cfg(test)]
mod unified_provider_tests;
pub use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
use crate::core::types::responses::{
ChatChunk, ChatResponse, EmbeddingResponse, ImageGenerationResponse,
};
use crate::core::types::{
chat::ChatRequest, embedding::EmbeddingRequest, image::ImageGenerationRequest,
};
use crate::core::types::{context::RequestContext, model::ProviderCapability};
use chrono::{DateTime, Utc};
pub use contextual_error::ContextualError;
pub use provider_registry::ProviderRegistry;
pub use unified_provider::ProviderError;
/// Per-model pricing record used for request cost estimation.
///
/// Costs are expressed per 1,000 tokens, denominated in `currency`.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    // Model identifier this pricing applies to (e.g. "gpt-4").
    pub model: String,
    // Cost per 1,000 input (prompt) tokens.
    pub input_cost_per_1k: f64,
    // Cost per 1,000 output (completion) tokens.
    pub output_cost_per_1k: f64,
    // Currency code for both costs (e.g. "USD").
    pub currency: String,
    // Timestamp of the last refresh of this pricing entry.
    pub updated_at: DateTime<Utc>,
}
/// Statically dispatches a method call to the concrete provider wrapped by a
/// `Provider` enum value. Four dispatch modes:
///
/// * `sync`         — inherent (non-trait) call; result returned as-is.
/// * `value`        — `LLMProvider` trait call; result returned as-is.
/// * `async_err`    — `LLMProvider` trait call, awaited, with the error
///                    converted through `ProviderError::from`.
/// * `async_direct` — zero-argument `LLMProvider` trait call, awaited,
///                    returned without error mapping.
///
/// The public arms normalize arguments and forward to the internal `@expand`
/// arms, which contain the per-variant match. Adding a `Provider` variant
/// requires extending every `@expand` arm.
macro_rules! dispatch_provider {
    // sync, no extra arguments.
    (sync, $self:expr, $method:ident) => {
        dispatch_provider!(@expand sync, $self, $method,)
    };
    // sync, one or more arguments (trailing comma tolerated).
    (sync, $self:expr, $method:ident, $($arg:expr),+ $(,)?) => {
        dispatch_provider!(@expand sync, $self, $method, $($arg),+)
    };
    // async with error conversion; zero or more arguments.
    (async_err, $self:expr, $method:ident $(, $arg:expr)* $(,)?) => {
        dispatch_provider!(@expand async_err, $self, $method, $($arg),*)
    };
    // trait call returning a plain value, no extra arguments.
    (value, $self:expr, $method:ident) => {
        dispatch_provider!(@expand value, $self, $method,)
    };
    // trait call returning a plain value, with arguments.
    (value, $self:expr, $method:ident, $($arg:expr),+ $(,)?) => {
        dispatch_provider!(@expand value, $self, $method, $($arg),+)
    };
    // async without error mapping; only the no-argument form exists.
    (async_direct, $self:expr, $method:ident) => {
        dispatch_provider!(@expand async_direct, $self, $method,)
    };
    // Internal: inherent method call on the wrapped provider.
    (@expand sync, $self:expr, $method:ident, $($arg:expr),*) => {
        match $self {
            Provider::OpenAI(p) => p.$method($($arg),*),
            Provider::Anthropic(p) => p.$method($($arg),*),
            Provider::Mistral(p) => p.$method($($arg),*),
            Provider::Cloudflare(p) => p.$method($($arg),*),
            Provider::OpenAILike(p) => p.$method($($arg),*),
        }
    };
    // Internal: awaited trait call, mapping the error into `ProviderError`.
    (@expand async_err, $self:expr, $method:ident, $($arg:expr),*) => {
        match $self {
            Provider::OpenAI(p) => LLMProvider::$method(p, $($arg),*).await.map_err(ProviderError::from),
            Provider::Anthropic(p) => LLMProvider::$method(p, $($arg),*).await.map_err(ProviderError::from),
            Provider::Mistral(p) => LLMProvider::$method(p, $($arg),*).await.map_err(ProviderError::from),
            Provider::Cloudflare(p) => LLMProvider::$method(p, $($arg),*).await.map_err(ProviderError::from),
            Provider::OpenAILike(p) => LLMProvider::$method(p, $($arg),*).await.map_err(ProviderError::from),
        }
    };
    // Internal: non-async trait call returned verbatim.
    (@expand value, $self:expr, $method:ident, $($arg:expr),*) => {
        match $self {
            Provider::OpenAI(p) => LLMProvider::$method(p, $($arg),*),
            Provider::Anthropic(p) => LLMProvider::$method(p, $($arg),*),
            Provider::Mistral(p) => LLMProvider::$method(p, $($arg),*),
            Provider::Cloudflare(p) => LLMProvider::$method(p, $($arg),*),
            Provider::OpenAILike(p) => LLMProvider::$method(p, $($arg),*),
        }
    };
    // Internal: awaited trait call with no arguments and no error mapping.
    (@expand async_direct, $self:expr, $method:ident, $($arg:expr),*) => {
        match $self {
            Provider::OpenAI(p) => LLMProvider::$method(p).await,
            Provider::Anthropic(p) => LLMProvider::$method(p).await,
            Provider::Mistral(p) => LLMProvider::$method(p).await,
            Provider::Cloudflare(p) => LLMProvider::$method(p).await,
            Provider::OpenAILike(p) => LLMProvider::$method(p).await,
        }
    };
}
/// Dispatches a method only for the listed `Provider` variants; every other
/// variant falls through to `$default`. Two forms: a zero-argument method
/// call and a call with arguments.
///
/// NOTE(review): currently unused (hence the allow); kept for providers that
/// implement only a subset of operations.
#[allow(unused_macros)]
macro_rules! dispatch_provider_selective {
    // Zero-argument method on each listed variant.
    ($self:expr, $method:ident, { $($provider:ident),+ }, $default:expr) => {
        match $self {
            $(Provider::$provider(p) => p.$method()),+,
            _ => $default,
        }
    };
    // Method with arguments on each listed variant.
    ($self:expr, $method:ident($($arg:expr),+), { $($provider:ident),+ }, $default:expr) => {
        match $self {
            $(Provider::$provider(p) => p.$method($($arg),+)),+,
            _ => $default,
        }
    };
}
/// Statically dispatched wrapper over the always-compiled core providers.
///
/// NOTE(review): the feature-gated providers (`providers-extended` /
/// `providers-extra` modules above) are not variants here — presumably they
/// are reached through `factory` / the registry instead; confirm before
/// treating this enum as exhaustive over all providers.
#[derive(Debug, Clone)]
pub enum Provider {
    OpenAI(openai::OpenAIProvider),
    Anthropic(anthropic::AnthropicProvider),
    Mistral(mistral::MistralProvider),
    Cloudflare(cloudflare::CloudflareProvider),
    // Generic OpenAI-compatible endpoint (reports its own name at runtime).
    OpenAILike(openai_like::OpenAILikeProvider),
}
impl Provider {
pub fn name(&self) -> &'static str {
match self {
Provider::OpenAI(_) => "openai",
Provider::Anthropic(_) => "anthropic",
Provider::Mistral(_) => "mistral",
Provider::Cloudflare(_) => "cloudflare",
Provider::OpenAILike(p) => {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
p.name()
}
}
}
pub fn provider_type(&self) -> ProviderType {
match self {
Provider::OpenAI(_) => ProviderType::OpenAI,
Provider::Anthropic(_) => ProviderType::Anthropic,
Provider::Mistral(_) => ProviderType::Mistral,
Provider::Cloudflare(_) => ProviderType::Cloudflare,
Provider::OpenAILike(_) => ProviderType::OpenAICompatible,
}
}
pub fn factory_supported_provider_types() -> &'static [ProviderType] {
static SUPPORTED: &[ProviderType] = &[
ProviderType::OpenAI,
ProviderType::Anthropic,
ProviderType::Mistral,
ProviderType::Cloudflare,
ProviderType::OpenAICompatible,
ProviderType::MetaLlama,
ProviderType::V0,
ProviderType::AzureAI,
ProviderType::AmazonNova,
ProviderType::FalAI,
ProviderType::Azure,
ProviderType::Bedrock,
ProviderType::VertexAI,
ProviderType::Replicate,
ProviderType::GitHub,
ProviderType::GitHubCopilot,
ProviderType::Groq,
ProviderType::OpenRouter,
ProviderType::DeepSeek,
ProviderType::DeepInfra,
ProviderType::Moonshot,
ProviderType::Minimax,
ProviderType::Dashscope,
ProviderType::XAI,
ProviderType::Perplexity,
ProviderType::Hyperbolic,
ProviderType::Infinity,
ProviderType::Novita,
ProviderType::Volcengine,
ProviderType::Nebius,
ProviderType::Nscale,
];
SUPPORTED
}
pub fn supports_model(&self, model: &str) -> bool {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
dispatch_provider!(value, self, supports_model, model)
}
pub fn capabilities(&self) -> &'static [ProviderCapability] {
dispatch_provider!(sync, self, capabilities)
}
pub async fn chat_completion(
&self,
request: ChatRequest,
context: RequestContext,
) -> Result<ChatResponse, ProviderError> {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
dispatch_provider!(async_err, self, chat_completion, request, context)
}
pub async fn health_check(&self) -> crate::core::types::health::HealthStatus {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
dispatch_provider!(async_direct, self, health_check)
}
pub fn list_models(&self) -> &[crate::core::types::model::ModelInfo] {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
dispatch_provider!(value, self, models)
}
pub async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64, ProviderError> {
let usage = crate::core::providers::base::pricing::Usage {
prompt_tokens: input_tokens,
completion_tokens: output_tokens,
total_tokens: input_tokens + output_tokens,
reasoning_tokens: None,
};
Ok(crate::core::providers::base::get_pricing_db().calculate(model, &usage))
}
pub async fn chat_completion_stream(
&self,
request: ChatRequest,
context: RequestContext,
) -> Result<
std::pin::Pin<
Box<dyn futures::Stream<Item = Result<ChatChunk, ProviderError>> + Send + 'static>,
>,
ProviderError,
> {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
dispatch_provider!(async_err, self, chat_completion_stream, request, context)
}
pub async fn create_embeddings(
&self,
request: EmbeddingRequest,
context: RequestContext,
) -> Result<EmbeddingResponse, ProviderError> {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
match self {
Provider::OpenAI(p) => LLMProvider::embeddings(p, request, context).await,
_ => Err(ProviderError::not_implemented(
"unknown",
format!("Embeddings not supported by {}", self.name()),
)),
}
}
pub async fn create_images(
&self,
request: ImageGenerationRequest,
context: RequestContext,
) -> Result<ImageGenerationResponse, ProviderError> {
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
match self {
Provider::OpenAI(p) => LLMProvider::image_generation(p, request, context).await,
_ => Err(ProviderError::not_implemented(
"unknown",
format!("Image generation not supported by {}", self.name()),
)),
}
}
pub async fn get_model(
&self,
model_id: &str,
) -> Result<Option<crate::core::types::model::ModelInfo>, ProviderError> {
let models = self.list_models();
for model in models {
if model.id == model_id || model.name == model_id {
return Ok(Some(model.clone()));
}
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a USD-denominated `ModelPricing` fixture (removes the four
    /// near-identical struct literals the tests previously repeated).
    fn usd_pricing(model: &str, input_cost_per_1k: f64, output_cost_per_1k: f64) -> ModelPricing {
        ModelPricing {
            model: model.to_string(),
            input_cost_per_1k,
            output_cost_per_1k,
            currency: "USD".to_string(),
            updated_at: Utc::now(),
        }
    }

    #[test]
    fn test_model_pricing_creation() {
        let pricing = usd_pricing("gpt-4", 0.03, 0.06);
        assert_eq!(pricing.model, "gpt-4");
        assert_eq!(pricing.input_cost_per_1k, 0.03);
        assert_eq!(pricing.output_cost_per_1k, 0.06);
        assert_eq!(pricing.currency, "USD");
    }

    #[test]
    fn test_model_pricing_clone() {
        let pricing = usd_pricing("claude-3-opus", 0.015, 0.075);
        let cloned = pricing.clone();
        assert_eq!(cloned.model, pricing.model);
        assert_eq!(cloned.input_cost_per_1k, pricing.input_cost_per_1k);
        assert_eq!(cloned.output_cost_per_1k, pricing.output_cost_per_1k);
    }

    #[test]
    fn test_model_pricing_zero_cost() {
        let pricing = usd_pricing("free-model", 0.0, 0.0);
        assert_eq!(pricing.input_cost_per_1k, 0.0);
        assert_eq!(pricing.output_cost_per_1k, 0.0);
    }

    #[test]
    fn test_model_pricing_debug() {
        let pricing = usd_pricing("gpt-4", 0.03, 0.06);
        let debug_str = format!("{:?}", pricing);
        assert!(debug_str.contains("gpt-4"));
        assert!(debug_str.contains("0.03"));
    }

    // Renamed from `test_provider_enum_is_send_sync`: the body never checked
    // Send/Sync — it only parses a provider-type string. A real auto-trait
    // check would be `fn assert_send_sync<T: Send + Sync>() {}` applied to
    // `Provider`; TODO(review): add it once Provider is confirmed Send + Sync.
    #[test]
    fn test_provider_type_from_openai_str() {
        assert!(matches!(ProviderType::from("openai"), ProviderType::OpenAI));
    }
}