//! omnillm 0.1.5
//!
//! Production-grade LLM API gateway with multi-key load balancing, per-key
//! rate limiting, circuit breaking, and cost tracking.
use std::collections::BTreeMap;
use std::time::Duration;

use serde_json::Value;

use crate::api::ResponseBody;
use crate::types::{PromptCacheUsage, TokenUsage};

use super::{
    PrimitiveAsyncJobStatus, PrimitiveProviderError, PrimitiveProviderKind, PrimitiveResponse,
    PrimitiveUsageTelemetry, ProviderPrimitiveWireFormat,
};

/// Pulls the provider-reported usage object out of a JSON response body and
/// wraps it in `PrimitiveUsageTelemetry`.
///
/// Returns `None` for non-JSON bodies or when no usage object is present at
/// the wire format's expected location.
pub(crate) fn extract_usage(
    wire_format: ProviderPrimitiveWireFormat,
    body: &ResponseBody,
) -> Option<PrimitiveUsageTelemetry> {
    let value = match body {
        ResponseBody::Json { value } => value,
        _ => return None,
    };

    // Each wire format keeps usage under a different key. OpenAI HTTP and
    // Anthropic formats all use the plain `usage` field, which is also the
    // fallback for formats without a dedicated mapping.
    let usage = match wire_format {
        ProviderPrimitiveWireFormat::OpenAiRealtime => value
            .get("usage")
            .or_else(|| value.pointer("/response/usage")),
        ProviderPrimitiveWireFormat::GeminiGenerateContent
        | ProviderPrimitiveWireFormat::GeminiStreamGenerateContent
        | ProviderPrimitiveWireFormat::GeminiLive => value
            .get("usageMetadata")
            .or_else(|| value.pointer("/serverContent/usageMetadata")),
        _ => value.get("usage"),
    }?;

    Some(PrimitiveUsageTelemetry {
        raw_usage: usage.clone(),
        token_usage: token_usage_from_raw(wire_format, usage),
        billable_units: Vec::new(),
        vendor_extensions: BTreeMap::new(),
    })
}

pub(crate) fn primitive_error_from_body(
    provider: PrimitiveProviderKind,
    wire_format: ProviderPrimitiveWireFormat,
    status: Option<u16>,
    retry_after: Option<Duration>,
    raw_body: String,
) -> PrimitiveProviderError {
    let parsed = serde_json::from_str::<Value>(&raw_body).ok();
    let code = parsed.as_ref().and_then(extract_error_code);
    let message = parsed
        .as_ref()
        .and_then(extract_error_message)
        .filter(|message| !message.is_empty())
        .unwrap_or_else(|| raw_body.clone());

    PrimitiveProviderError {
        provider,
        wire_format,
        status,
        code,
        message,
        retry_after,
        raw_body: Some(raw_body),
        vendor_extensions: BTreeMap::new(),
    }
}

/// Extracts a provider-assigned job identifier from an async job response.
///
/// Probes the common identifier fields in priority order: `id`, `name`,
/// then `batch_id`; only string values count.
pub(crate) fn extract_async_job_id(response: &PrimitiveResponse) -> Option<String> {
    let value = match &response.body {
        ResponseBody::Json { value } => value,
        _ => return None,
    };
    for field in ["id", "name", "batch_id"] {
        if let Some(id) = value.get(field).and_then(Value::as_str) {
            return Some(id.to_string());
        }
    }
    None
}

/// Maps a provider's raw async-job status string onto `PrimitiveAsyncJobStatus`.
///
/// Reads the `status` field or, failing that, `state`; comparison is
/// case-insensitive. A leading `JOB_STATE_` prefix is stripped so that
/// Gemini-style enum values (e.g. `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`)
/// map onto the same arms as the bare spellings. Missing or unrecognized
/// values yield `Unknown`.
pub(crate) fn extract_async_job_status(response: &PrimitiveResponse) -> PrimitiveAsyncJobStatus {
    let ResponseBody::Json { value } = &response.body else {
        return PrimitiveAsyncJobStatus::Unknown;
    };
    let status = value
        .get("status")
        .or_else(|| value.get("state"))
        .and_then(Value::as_str)
        .unwrap_or_default()
        .to_ascii_lowercase();
    // Gemini batch/long-running jobs report `JOB_STATE_*` enum values; the
    // `state` field probed above is the Gemini-style field, so without this
    // strip those statuses would all fall through to `Unknown`.
    let status = status.strip_prefix("job_state_").unwrap_or(&status);
    match status {
        "pending" | "queued" | "validating" => PrimitiveAsyncJobStatus::Pending,
        "running" | "in_progress" | "processing" => PrimitiveAsyncJobStatus::Running,
        "succeeded" | "completed" | "ended" | "done" => PrimitiveAsyncJobStatus::Succeeded,
        "failed" | "errored" | "expired" => PrimitiveAsyncJobStatus::Failed,
        "cancelled" | "canceled" | "cancelling" | "canceling" => PrimitiveAsyncJobStatus::Cancelled,
        _ => PrimitiveAsyncJobStatus::Unknown,
    }
}

fn token_usage_from_raw(
    wire_format: ProviderPrimitiveWireFormat,
    usage: &Value,
) -> Option<TokenUsage> {
    match wire_format {
        ProviderPrimitiveWireFormat::OpenAiResponses => Some(TokenUsage {
            prompt_tokens: usage_u32(usage, &["input_tokens"]),
            completion_tokens: usage_u32(usage, &["output_tokens"]),
            total_tokens: usage_u32_opt(usage, &["total_tokens"]),
            prompt_cache: openai_prompt_cache_usage(usage),
        }),
        ProviderPrimitiveWireFormat::OpenAiChatCompletions
        | ProviderPrimitiveWireFormat::OpenAiCompatibleChatCompletions => Some(TokenUsage {
            prompt_tokens: usage_u32(usage, &["prompt_tokens"]),
            completion_tokens: usage_u32(usage, &["completion_tokens"]),
            total_tokens: usage_u32_opt(usage, &["total_tokens"]),
            prompt_cache: openai_prompt_cache_usage(usage),
        }),
        ProviderPrimitiveWireFormat::AnthropicMessages => Some(TokenUsage {
            prompt_tokens: usage_u32(usage, &["input_tokens"]),
            completion_tokens: usage_u32(usage, &["output_tokens"]),
            total_tokens: None,
            prompt_cache: anthropic_prompt_cache_usage(usage),
        }),
        ProviderPrimitiveWireFormat::GeminiGenerateContent
        | ProviderPrimitiveWireFormat::GeminiStreamGenerateContent => Some(TokenUsage {
            prompt_tokens: usage_u32(usage, &["promptTokenCount"]),
            completion_tokens: usage_u32(usage, &["candidatesTokenCount"]),
            total_tokens: usage_u32_opt(usage, &["totalTokenCount"]),
            prompt_cache: None,
        }),
        _ => generic_token_usage(usage),
    }
}

/// Best-effort token extraction for wire formats without a dedicated mapping.
///
/// Probes the well-known field names from each supported provider family;
/// returns `None` when no token count is found at all, rather than an
/// all-zero record.
fn generic_token_usage(usage: &Value) -> Option<TokenUsage> {
    const PROMPT_FIELDS: [&str; 4] = [
        "input_tokens",
        "prompt_tokens",
        "promptTokenCount",
        "total_tokens",
    ];
    const COMPLETION_FIELDS: [&str; 3] =
        ["output_tokens", "completion_tokens", "candidatesTokenCount"];

    let prompt_tokens = usage_u32(usage, &PROMPT_FIELDS);
    let completion_tokens = usage_u32(usage, &COMPLETION_FIELDS);
    (prompt_tokens != 0 || completion_tokens != 0).then(|| TokenUsage {
        prompt_tokens,
        completion_tokens,
        total_tokens: usage_u32_opt(usage, &["total_tokens", "totalTokenCount"]),
        prompt_cache: None,
    })
}

/// Reads an OpenAI-style cached-prompt-token count, if reported.
///
/// The counter lives under `input_tokens_details` in one format and under
/// `prompt_tokens_details` in the other; the first present value wins.
fn openai_prompt_cache_usage(usage: &Value) -> Option<PromptCacheUsage> {
    let paths = [
        "/input_tokens_details/cached_tokens",
        "/prompt_tokens_details/cached_tokens",
    ];
    let cached_input_tokens = paths
        .iter()
        .find_map(|path| usage.pointer(path).and_then(value_to_u32))?;
    Some(PromptCacheUsage {
        cached_input_tokens: Some(cached_input_tokens),
        ..Default::default()
    })
}

/// Collects Anthropic prompt-cache counters from a `usage` object.
///
/// The 5m/1h cache-creation counters are probed both as flat fields and
/// under the nested `cache_creation` object. Returns `None` when no cache
/// counter is present at all.
fn anthropic_prompt_cache_usage(usage: &Value) -> Option<PromptCacheUsage> {
    let short_lived = usage
        .get("cache_creation_5m_input_tokens")
        .or_else(|| usage.pointer("/cache_creation/ephemeral_5m_input_tokens"))
        .and_then(value_to_u32);
    let long_lived = usage
        .get("cache_creation_1h_input_tokens")
        .or_else(|| usage.pointer("/cache_creation/ephemeral_1h_input_tokens"))
        .and_then(value_to_u32);

    let prompt_cache = PromptCacheUsage {
        cache_read_input_tokens: usage.get("cache_read_input_tokens").and_then(value_to_u32),
        cache_creation_input_tokens: usage
            .get("cache_creation_input_tokens")
            .and_then(value_to_u32),
        cache_creation_short_input_tokens: short_lived,
        cache_creation_long_input_tokens: long_lived,
        ..Default::default()
    };

    // Only emit a record when at least one counter was actually reported.
    let any_present = [
        prompt_cache.cached_input_tokens,
        prompt_cache.cache_read_input_tokens,
        prompt_cache.cache_creation_input_tokens,
        prompt_cache.cache_creation_short_input_tokens,
        prompt_cache.cache_creation_long_input_tokens,
    ]
    .iter()
    .any(Option::is_some);
    any_present.then_some(prompt_cache)
}

/// Like [`usage_u32_opt`], but a missing/unparseable field reads as zero.
fn usage_u32(usage: &Value, fields: &[&str]) -> u32 {
    match usage_u32_opt(usage, fields) {
        Some(count) => count,
        None => 0,
    }
}

/// Returns the value of the first of `fields` present on `usage` that holds
/// a number representable as `u32`.
fn usage_u32_opt(usage: &Value, fields: &[&str]) -> Option<u32> {
    for field in fields {
        if let Some(count) = usage.get(*field).and_then(value_to_u32) {
            return Some(count);
        }
    }
    None
}

/// Narrows a JSON number to `u32`; `None` for non-integers, negatives, or
/// values above `u32::MAX`.
fn value_to_u32(value: &Value) -> Option<u32> {
    u32::try_from(value.as_u64()?).ok()
}

/// Pulls an error code from the common provider error-body shapes.
///
/// Checks `/error/code`, `/error/type`, then top-level `code` and `type`,
/// taking the first field that exists; string codes are cloned and numeric
/// codes stringified, anything else yields `None`.
fn extract_error_code(value: &Value) -> Option<String> {
    let candidates = [
        value.pointer("/error/code"),
        value.pointer("/error/type"),
        value.get("code"),
        value.get("type"),
    ];
    match candidates.into_iter().flatten().next()? {
        Value::String(code) => Some(code.clone()),
        Value::Number(code) => Some(code.to_string()),
        _ => None,
    }
}

fn extract_error_message(value: &Value) -> Option<String> {
    value
        .pointer("/error/message")
        .or_else(|| value.get("message"))
        .and_then(Value::as_str)
        .map(str::to_string)
}