omnillm 0.1.5

Production-grade LLM API gateway with multi-key load balancing, per-key rate limiting, circuit breaking, and cost tracking
Documentation
use std::collections::BTreeMap;

use serde_json::Value;

use crate::api::RequestBody;
use crate::protocol::AuthScheme;

use super::{PrimitiveProviderKind, PrimitiveRequest, ProviderPrimitiveWireFormat};

/// Builds the default request headers for `provider`.
///
/// Only Anthropic needs an extra header by default: every request must
/// carry `anthropic-version: 2023-06-01`. All other providers start with
/// an empty header map.
pub(super) fn default_headers_for(provider: PrimitiveProviderKind) -> BTreeMap<String, String> {
    match provider {
        PrimitiveProviderKind::Anthropic => BTreeMap::from([(
            "anthropic-version".to_string(),
            "2023-06-01".to_string(),
        )]),
        _ => BTreeMap::new(),
    }
}

pub(super) fn default_auth_for(provider: PrimitiveProviderKind) -> AuthScheme {
    match provider {
        PrimitiveProviderKind::OpenAi
        | PrimitiveProviderKind::AzureOpenAi
        | PrimitiveProviderKind::OpenAiCompatible
        | PrimitiveProviderKind::Bedrock
        | PrimitiveProviderKind::Custom => AuthScheme::Bearer,
        PrimitiveProviderKind::Anthropic => AuthScheme::Header {
            name: "x-api-key".into(),
        },
        PrimitiveProviderKind::Gemini | PrimitiveProviderKind::VertexAi => AuthScheme::Header {
            name: "x-goog-api-key".into(),
        },
    }
}

pub(super) fn default_path(request: &PrimitiveRequest) -> Option<String> {
    match request.wire_format {
        ProviderPrimitiveWireFormat::OpenAiResponses => Some("/responses".into()),
        ProviderPrimitiveWireFormat::OpenAiChatCompletions
        | ProviderPrimitiveWireFormat::OpenAiCompatibleChatCompletions => {
            Some("/chat/completions".into())
        }
        ProviderPrimitiveWireFormat::OpenAiImages => Some("/images/generations".into()),
        ProviderPrimitiveWireFormat::OpenAiImageEdits => Some("/images/edits".into()),
        ProviderPrimitiveWireFormat::OpenAiImageVariations => Some("/images/variations".into()),
        ProviderPrimitiveWireFormat::OpenAiRealtime => Some("/realtime/sessions".into()),
        ProviderPrimitiveWireFormat::OpenAiAudioTranscriptions => {
            Some("/audio/transcriptions".into())
        }
        ProviderPrimitiveWireFormat::OpenAiAudioTranslations => Some("/audio/translations".into()),
        ProviderPrimitiveWireFormat::OpenAiAudioSpeech => Some("/audio/speech".into()),
        ProviderPrimitiveWireFormat::OpenAiEmbeddings => Some("/embeddings".into()),
        ProviderPrimitiveWireFormat::OpenAiFiles => Some("/files".into()),
        ProviderPrimitiveWireFormat::OpenAiUploads => Some("/uploads".into()),
        ProviderPrimitiveWireFormat::OpenAiModels => Some("/models".into()),
        ProviderPrimitiveWireFormat::OpenAiBatches => Some("/batches".into()),
        ProviderPrimitiveWireFormat::AnthropicMessages => Some("/messages".into()),
        ProviderPrimitiveWireFormat::AnthropicCountTokens => Some("/messages/count_tokens".into()),
        ProviderPrimitiveWireFormat::AnthropicMessageBatches => Some("/messages/batches".into()),
        ProviderPrimitiveWireFormat::AnthropicFiles => Some("/files".into()),
        ProviderPrimitiveWireFormat::AnthropicModels => Some("/models".into()),
        ProviderPrimitiveWireFormat::GeminiGenerateContent => {
            model_path(request, "generateContent")
        }
        ProviderPrimitiveWireFormat::GeminiStreamGenerateContent => {
            model_path(request, "streamGenerateContent")
        }
        ProviderPrimitiveWireFormat::GeminiCountTokens => model_path(request, "countTokens"),
        ProviderPrimitiveWireFormat::GeminiEmbedContent => model_path(request, "embedContent"),
        ProviderPrimitiveWireFormat::GeminiLive => None,
        ProviderPrimitiveWireFormat::GeminiFiles => Some("/files".into()),
        ProviderPrimitiveWireFormat::GeminiCaches => Some("/cachedContents".into()),
        ProviderPrimitiveWireFormat::GeminiModels => Some("/models".into()),
        ProviderPrimitiveWireFormat::GeminiOperations => Some("/operations".into()),
        ProviderPrimitiveWireFormat::GeminiBatches => Some("/batches".into()),
        ProviderPrimitiveWireFormat::BedrockConverse
        | ProviderPrimitiveWireFormat::BedrockInvokeModel
        | ProviderPrimitiveWireFormat::CustomHttp => None,
    }
}

/// Builds a Gemini-style model-scoped path, `/models/{model}:{action}`.
/// Yields `None` when the request has no model to interpolate.
fn model_path(request: &PrimitiveRequest, action: &str) -> Option<String> {
    request
        .model
        .as_ref()
        .map(|model| format!("/models/{model}:{action}"))
}

/// Rough token estimate for a request body (e.g. for rate-limit/cost
/// accounting before the provider reports real usage).
///
/// Counts payload characters and applies the common ~4-characters-per-token
/// heuristic. Base64 file/binary payloads are first discounted by dividing
/// their encoded length by 8, since raw media bytes do not tokenize like
/// text. Always returns at least 1.
pub(super) fn estimate_body_tokens(body: &RequestBody) -> u32 {
    // Approximate payload size in characters.
    let chars = match body {
        // Serializing just to measure length is the simplest exact count.
        RequestBody::Json { value } => value.to_string().len(),
        RequestBody::Multipart { fields } => fields
            .iter()
            .map(|field| match &field.value {
                crate::api::MultipartValue::Text { value } => value.len(),
                // NOTE(review): /8 is a heuristic discount for media
                // payloads — confirm the intended divisor.
                crate::api::MultipartValue::File { data_base64, .. } => data_base64.len() / 8,
            })
            .sum(),
        RequestBody::Text { text } => text.len(),
        RequestBody::Binary { data_base64, .. } => data_base64.len() / 8,
    };
    // Saturate instead of the previous `as u32`, which would silently
    // truncate for pathologically large bodies.
    u32::try_from((chars / 4).max(1)).unwrap_or(u32::MAX)
}

/// Extracts an explicit output-token cap from a request body, checking the
/// field names used across wire formats: `max_output_tokens`, `max_tokens`,
/// `maxOutputTokens`, and the nested Gemini `generationConfig.maxOutputTokens`.
/// Returns `None` when no candidate is present, non-numeric, or out of
/// `u32` range.
pub(super) fn known_output_token_limit(value: &Value) -> Option<u32> {
    // Checked in priority order; first present key wins.
    const TOP_LEVEL_KEYS: [&str; 3] = ["max_output_tokens", "max_tokens", "maxOutputTokens"];
    let raw = TOP_LEVEL_KEYS
        .iter()
        .find_map(|key| value.get(key))
        .or_else(|| value.pointer("/generationConfig/maxOutputTokens"))?
        .as_u64()?;
    u32::try_from(raw).ok()
}