use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::Deserialize;
use crate::error::{LiterLlmError, Result};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time earlier than the epoch,
/// so this never fails.
pub(crate) fn unix_timestamp_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Wire format a provider uses for streamed responses.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum StreamFormat {
    /// Server-sent events (the trait default).
    Sse,
    /// AWS event-stream framing.
    AwsEventStream,
}
/// Provider catalogue compiled into the binary from `schemas/providers.json`.
const PROVIDERS_JSON: &str = include_str!("../../schemas/providers.json");

/// Parsed form of [`PROVIDERS_JSON`], built on first access.
///
/// A parse failure is kept as its message string so every later access
/// observes the same outcome without re-parsing.
static REGISTRY: LazyLock<std::result::Result<ProviderRegistry, String>> = LazyLock::new(|| {
    serde_json::from_str::<ProviderRegistry>(PROVIDERS_JSON).map_err(|err| err.to_string())
});

/// Returns the parsed embedded provider registry.
///
/// # Errors
///
/// Returns [`LiterLlmError::ServerError`] when the embedded JSON did not parse.
fn registry() -> Result<&'static ProviderRegistry> {
    match &*REGISTRY {
        Ok(parsed) => Ok(parsed),
        Err(reason) => Err(LiterLlmError::ServerError {
            message: format!("embedded schemas/providers.json is invalid: {reason}"),
        }),
    }
}
/// Top-level shape of the embedded `schemas/providers.json` document.
#[derive(Debug, Deserialize)]
struct ProviderRegistry {
    /// Every provider entry listed in the document.
    providers: Vec<ProviderConfig>,
    /// Names of providers that `detect_provider` never routes through the
    /// generic `ConfigDrivenProvider` path. Empty when the key is absent.
    #[serde(default, deserialize_with = "deserialize_hashset")]
    complex_providers: HashSet<String>,
}
/// Serde helper: reads a sequence of strings and collects it into a
/// `HashSet`, collapsing any duplicate entries.
fn deserialize_hashset<'de, D>(deserializer: D) -> std::result::Result<HashSet<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    Vec::<String>::deserialize(deserializer)
        .map(|names| names.into_iter().collect::<HashSet<String>>())
}
/// One provider entry from the embedded registry.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderConfig {
    /// Provider identifier. `detect_provider` compares it with the part of a
    /// model string before the first `/`.
    pub name: String,
    /// Human-readable provider name, if the registry supplies one.
    pub display_name: Option<String>,
    /// API base URL. `detect_provider` only selects entries that have one.
    pub base_url: Option<String>,
    /// Authentication scheme. `ConfigDrivenProvider` uses bearer auth when absent.
    pub auth: Option<AuthConfig>,
    /// Endpoints listed for this provider (not read in this section).
    pub endpoints: Option<Vec<String>>,
    /// Model-name prefixes used to auto-detect this provider.
    pub model_prefixes: Option<Vec<String>>,
    /// Request-body key renames (`from` -> `to`) applied by
    /// `ConfigDrivenProvider::transform_request`.
    pub(crate) param_mappings: Option<HashMap<String, String>>,
}
/// Authentication scheme named by a registry entry's `auth.type` field.
///
/// Variants deserialize from kebab-case names (`"bearer"`, `"api-key"`,
/// `"none"`, `"unknown"`); any other string becomes [`AuthType::Unknown`].
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum AuthType {
    /// `Authorization: Bearer <key>`.
    Bearer,
    /// Raw key in a dedicated header (`ConfigDrivenProvider` uses `x-api-key`).
    /// The spellings `"header"` and `"x-api-key"` are accepted as aliases.
    #[serde(alias = "header", alias = "x-api-key")]
    ApiKey,
    /// No authentication header is sent.
    None,
    /// Catch-all for unrecognized values; `ConfigDrivenProvider` treats it as `Bearer`.
    #[serde(other)]
    Unknown,
}
/// The `auth` object of a registry entry.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
    /// Scheme, read from the JSON key `type`.
    #[serde(rename = "type")]
    pub auth_type: AuthType,
    /// Name of the environment variable associated with the key.
    /// NOTE(review): not read in this section; presumably consumed elsewhere.
    pub env_var: Option<String>,
}
/// Behaviour every LLM backend implements: identity, authentication,
/// endpoint layout, request/response rewriting and stream decoding.
///
/// Only `name`, `base_url`, `auth_header` and `matches_model` are required;
/// all other methods have defaults.
pub trait Provider: Send + Sync {
    /// Checks provider-specific configuration. The default accepts anything.
    fn validate(&self) -> Result<()> {
        Ok(())
    }

    /// Short provider identifier; the default `strip_model_prefix` removes
    /// `"<name>/"` from model strings.
    fn name(&self) -> &str;

    /// Base URL that the default `build_url` prepends to endpoint paths.
    fn base_url(&self) -> &str;

    /// Header name/value carrying `api_key`, or `None` to send no auth header.
    fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)>;

    /// Constant extra headers. Default: none.
    fn extra_headers(&self) -> &'static [(&'static str, &'static str)] {
        &[]
    }

    /// Extra headers computed from the request body. Default: none.
    fn dynamic_headers(&self, _body: &serde_json::Value) -> Vec<(String, String)> {
        Vec::new()
    }

    /// Whether this provider claims the given model string.
    fn matches_model(&self, model: &str) -> bool;

    /// Drops a leading `"<name>/"` from `model`; otherwise returns it unchanged.
    fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
        model
            .strip_prefix(self.name())
            .and_then(|rest| rest.strip_prefix('/'))
            .unwrap_or(model)
    }

    /// Endpoint path for chat completions (default `/chat/completions`).
    fn chat_completions_path(&self) -> &str {
        "/chat/completions"
    }

    /// Endpoint path for embeddings (default `/embeddings`).
    fn embeddings_path(&self) -> &str {
        "/embeddings"
    }

    /// Endpoint path for model listing (default `/models`).
    fn models_path(&self) -> &str {
        "/models"
    }

    /// Endpoint path for image generation (default `/images/generations`).
    fn image_generations_path(&self) -> &str {
        "/images/generations"
    }

    /// Endpoint path for text-to-speech (default `/audio/speech`).
    fn audio_speech_path(&self) -> &str {
        "/audio/speech"
    }

    /// Endpoint path for transcription (default `/audio/transcriptions`).
    fn audio_transcriptions_path(&self) -> &str {
        "/audio/transcriptions"
    }

    /// Endpoint path for moderation (default `/moderations`).
    fn moderations_path(&self) -> &str {
        "/moderations"
    }

    /// Endpoint path for reranking (default `/rerank`).
    fn rerank_path(&self) -> &str {
        "/rerank"
    }

    /// Endpoint path for file operations (default `/files`).
    fn files_path(&self) -> &str {
        "/files"
    }

    /// Endpoint path for batch operations (default `/batches`).
    fn batches_path(&self) -> &str {
        "/batches"
    }

    /// Endpoint path for the responses API (default `/responses`).
    fn responses_path(&self) -> &str {
        "/responses"
    }

    /// Endpoint path for search (default `/search`).
    fn search_path(&self) -> &str {
        "/search"
    }

    /// Endpoint path for OCR (default `/ocr`).
    fn ocr_path(&self) -> &str {
        "/ocr"
    }

    /// Whether streaming is supported. Default: `true`.
    // Currently has no caller in the crate, hence the lint allowance.
    #[allow(dead_code)]
    fn supports_streaming(&self) -> bool {
        true
    }

    /// Rewrites an outgoing request body in place. Default: no change.
    fn transform_request(&self, _body: &mut serde_json::Value) -> Result<()> {
        Ok(())
    }

    /// Rewrites an incoming response body in place. Default: no change.
    fn transform_response(&self, _body: &mut serde_json::Value) -> Result<()> {
        Ok(())
    }

    /// Full request URL: `base_url()` immediately followed by `endpoint_path`.
    /// The default ignores the model.
    fn build_url(&self, endpoint_path: &str, _model: &str) -> String {
        [self.base_url(), endpoint_path].concat()
    }

    /// Decodes one stream event payload as a chat-completion chunk.
    ///
    /// The default parses `event_data` as JSON and never yields `Ok(None)`;
    /// overrides may return `Ok(None)`, presumably to skip an event.
    ///
    /// # Errors
    ///
    /// Returns `LiterLlmError::Streaming` when the payload is not a valid chunk.
    fn parse_stream_event(&self, event_data: &str) -> Result<Option<crate::types::ChatCompletionChunk>> {
        match serde_json::from_str::<crate::types::ChatCompletionChunk>(event_data) {
            Ok(chunk) => Ok(Some(chunk)),
            Err(err) => Err(LiterLlmError::Streaming {
                message: format!("failed to parse SSE data: {err}"),
            }),
        }
    }

    /// Wire format of streamed responses. Default: [`StreamFormat::Sse`].
    fn stream_format(&self) -> StreamFormat {
        StreamFormat::Sse
    }

    /// URL for streaming requests. Default: same as `build_url`.
    fn build_stream_url(&self, endpoint_path: &str, model: &str) -> String {
        self.build_url(endpoint_path, model)
    }

    /// Headers derived from the final method, URL and body (presumably for
    /// request signing). Default: none.
    fn signing_headers(&self, _method: &str, _url: &str, _body: &[u8]) -> Vec<(String, String)> {
        Vec::new()
    }
}
pub mod anthropic;
pub mod azure;
pub mod bedrock;
pub mod cohere;
pub mod custom;
pub mod google_ai;
pub mod mistral;
pub mod vertex;
/// Built-in provider for the OpenAI API.
pub struct OpenAiProvider;

impl Provider for OpenAiProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn base_url(&self) -> &str {
        "https://api.openai.com/v1"
    }

    /// Sends the key as `Authorization: Bearer <key>`.
    fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
        let value = format!("Bearer {api_key}");
        Some((Cow::Borrowed("Authorization"), Cow::Owned(value)))
    }

    /// Claims OpenAI model families by prefix, the bare reasoning-model names,
    /// and anything explicitly namespaced as `openai/`.
    fn matches_model(&self, model: &str) -> bool {
        const EXACT: &[&str] = &["o1", "o3", "o4"];
        const PREFIXES: &[&str] = &[
            "gpt-",
            "o1-",
            "o3-",
            "o4-",
            "dall-e-",
            "whisper-",
            "tts-",
            "text-embedding-",
            "chatgpt-",
            "openai/",
        ];
        EXACT.contains(&model) || PREFIXES.iter().any(|&prefix| model.starts_with(prefix))
    }

    /// Removes a leading `openai/`, if present.
    fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
        match model.strip_prefix("openai/") {
            Some(rest) => rest,
            None => model,
        }
    }
}
/// Generic provider for services that expose an OpenAI-compatible API
/// at their own base URL.
pub struct OpenAiCompatibleProvider {
    /// Provider identifier (also the `"<name>/"` prefix removed by the
    /// default `strip_model_prefix`).
    pub name: String,
    /// Base URL that endpoint paths are appended to.
    pub base_url: String,
    /// Environment variable associated with this provider's key. It is not
    /// read by this impl, hence the dead-code allowance.
    #[allow(dead_code)]
    pub env_var: Option<&'static str>,
    /// A model matches when it starts with any of these prefixes.
    pub model_prefixes: Vec<String>,
}

impl Provider for OpenAiCompatibleProvider {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// Sends the key as `Authorization: Bearer <key>`.
    fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
        let value = format!("Bearer {api_key}");
        Some((Cow::Borrowed("Authorization"), Cow::Owned(value)))
    }

    fn matches_model(&self, model: &str) -> bool {
        self.model_prefixes
            .iter()
            .map(String::as_str)
            .any(|prefix| model.starts_with(prefix))
    }
}
/// Provider whose behaviour comes entirely from one registry entry in
/// the embedded `schemas/providers.json`.
pub struct ConfigDrivenProvider {
    /// Registry entry supplying name, base URL, auth scheme, prefixes and mappings.
    config: &'static ProviderConfig,
}

impl ConfigDrivenProvider {
    /// Wraps a registry entry.
    #[must_use]
    pub(crate) fn new(config: &'static ProviderConfig) -> Self {
        ConfigDrivenProvider { config }
    }
}
impl Provider for ConfigDrivenProvider {
    fn name(&self) -> &str {
        &self.config.name
    }

    /// The entry's base URL, or `""` if the registry omits it.
    fn base_url(&self) -> &str {
        self.config.base_url.as_deref().unwrap_or("")
    }

    /// Renames top-level request keys according to the entry's `param_mappings`.
    ///
    /// All source keys are detached before any target key is written. This
    /// gives simultaneous-rename semantics: with `a -> b` and `b -> c`, the
    /// result is `b = old a`, `c = old b`, no matter how the mapping table
    /// is ordered. Sources are visited in sorted key order. When two sources
    /// share a target, the lexicographically later source deterministically
    /// wins, instead of depending on `HashMap` iteration order.
    /// Non-object bodies are left untouched.
    fn transform_request(&self, body: &mut serde_json::Value) -> Result<()> {
        let (Some(mappings), Some(obj)) = (&self.config.param_mappings, body.as_object_mut()) else {
            return Ok(());
        };

        let mut ordered: Vec<(&String, &String)> = mappings.iter().collect();
        ordered.sort_unstable();

        // Phase 1: pull every mapped source value out of the body.
        let detached: Vec<(&String, serde_json::Value)> = ordered
            .into_iter()
            .filter_map(|(from, to)| obj.remove(from.as_str()).map(|val| (to, val)))
            .collect();

        // Phase 2: write the values back under their target names.
        for (to, val) in detached {
            obj.insert(to.clone(), val);
        }
        Ok(())
    }

    /// Builds the auth header for the entry's scheme. Missing auth config and
    /// unrecognized schemes both fall back to `Authorization: Bearer <key>`.
    fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
        let auth_type = self
            .config
            .auth
            .as_ref()
            .map_or(&AuthType::Bearer, |auth| &auth.auth_type);
        match auth_type {
            AuthType::None => None,
            AuthType::ApiKey => Some((Cow::Borrowed("x-api-key"), Cow::Borrowed(api_key))),
            AuthType::Bearer | AuthType::Unknown => {
                Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}"))))
            }
        }
    }

    /// True when `model` starts with any of the entry's `model_prefixes`;
    /// false when the entry lists none.
    fn matches_model(&self, model: &str) -> bool {
        self.config
            .model_prefixes
            .as_deref()
            .is_some_and(|prefixes| prefixes.iter().any(|p| model.starts_with(p.as_str())))
    }
}
pub fn detect_provider(model: &str) -> Option<Box<dyn Provider>> {
if let Some(provider) = custom::detect_custom_provider(model) {
return Some(provider);
}
let openai = OpenAiProvider;
if openai.matches_model(model) {
return Some(Box::new(openai));
}
let anthropic = anthropic::AnthropicProvider;
if anthropic.matches_model(model) {
return Some(Box::new(anthropic));
}
if model.starts_with("azure/") {
return Some(Box::new(azure::AzureProvider::new()));
}
if model.starts_with("gemini/") || model.starts_with("google_ai/") {
return Some(Box::new(google_ai::GoogleAiProvider));
}
if model.starts_with("vertex_ai/") {
return Some(Box::new(vertex::VertexAiProvider::from_env()));
}
if model.starts_with("bedrock/") {
return Some(Box::new(bedrock::BedrockProvider::from_env()));
}
if model.starts_with("command-") || model.starts_with("cohere/") {
return Some(Box::new(cohere::CohereProvider));
}
if model.starts_with("mistral-")
|| model.starts_with("codestral-")
|| model.starts_with("pixtral-")
|| model.starts_with("mistral/")
{
return Some(Box::new(mistral::MistralProvider));
}
let reg = match REGISTRY.as_ref() {
Ok(r) => r,
Err(_) => return None,
};
if let Some((prefix, _)) = model.split_once('/')
&& let Some(cfg) = reg.providers.iter().find(|p| p.name == prefix)
&& cfg.base_url.is_some()
&& !reg.complex_providers.contains(&cfg.name)
{
return Some(Box::new(ConfigDrivenProvider::new(cfg)));
}
for cfg in ®.providers {
if reg.complex_providers.contains(&cfg.name) {
continue;
}
if let Some(prefixes) = &cfg.model_prefixes {
let matches = prefixes
.iter()
.any(|p| model.starts_with(p.as_str()) && !p.ends_with('/'));
if matches && cfg.base_url.is_some() {
return Some(Box::new(ConfigDrivenProvider::new(cfg)));
}
}
}
None
}
pub fn all_providers() -> Result<&'static [ProviderConfig]> {
Ok(®istry()?.providers)
}
pub fn complex_provider_names() -> Result<&'static HashSet<String>> {
Ok(®istry()?.complex_providers)
}