omnillm 0.1.5

Production-grade LLM API gateway with multi-key load balancing, per-key rate limiting, circuit breaking, and cost tracking
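//! Protocol helpers for the generation (Responses) endpoint: protocol
//! resolution, endpoint validation, native vs. bridged conversion reports,
//! and request sanitization that drops features a target wire format cannot
//! express while recording loss reasons.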
use crate::api::{
    ApiRequest, ApiResponse, ConversionReport, EndpointKind, TransportRequest, WireFormat,
};
use crate::protocol::ProviderProtocol;
use crate::types::{BuiltinTool, CacheBreakpoint, LlmRequest, MessageRole, PromptCachePolicy};

use super::common::*;
use super::ApiProtocolError;

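/// Resolves the provider protocol for a generation wire format, reporting
/// `UnsupportedWireFormat` against the Responses endpoint on failure.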
pub(super) fn generation_protocol(
    wire_format: WireFormat,
) -> Result<ProviderProtocol, ApiProtocolError> {
    ProviderProtocol::try_from(wire_format).map_err(|_| ApiProtocolError::UnsupportedWireFormat {
        endpoint: EndpointKind::Responses,
        wire_format,
    })
}

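/// Checks that the requested endpoint is the canonical endpoint for the wire
/// format, returning `EndpointMismatch` otherwise.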
pub(super) fn ensure_matching_endpoint(
    wire_format: WireFormat,
    endpoint: EndpointKind,
) -> Result<(), ApiProtocolError> {
    let expected = wire_format.canonical_endpoint_kind();
    if expected == endpoint {
        Ok(())
    } else {
        Err(ApiProtocolError::EndpointMismatch {
            expected,
            actual: endpoint,
        })
    }
}

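/// Builds the conversion report for a generation request: `native` when the
/// target wire format is `OpenAiResponses` and no information was lost,
/// `bridged` otherwise.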
pub(super) fn generation_request_report(
    wire_format: WireFormat,
    value: ApiRequest,
    loss_reasons: Vec<String>,
) -> ConversionReport<ApiRequest> {
    if wire_format == WireFormat::OpenAiResponses && loss_reasons.is_empty() {
        ConversionReport::native(value, EndpointKind::Responses, wire_format)
    } else {
        ConversionReport::bridged(value, EndpointKind::Responses, wire_format, loss_reasons)
    }
}

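/// Same native/bridged rule as the request report, applied to a converted response.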
pub(super) fn generation_response_report(
    wire_format: WireFormat,
    value: ApiResponse,
    loss_reasons: Vec<String>,
) -> ConversionReport<ApiResponse> {
    if wire_format == WireFormat::OpenAiResponses && loss_reasons.is_empty() {
        ConversionReport::native(value, EndpointKind::Responses, wire_format)
    } else {
        ConversionReport::bridged(value, EndpointKind::Responses, wire_format, loss_reasons)
    }
}

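/// Same native/bridged rule as the request report, applied to a string payload.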
pub(super) fn generation_string_report(
    wire_format: WireFormat,
    value: String,
    loss_reasons: Vec<String>,
) -> ConversionReport<String> {
    if wire_format == WireFormat::OpenAiResponses && loss_reasons.is_empty() {
        ConversionReport::native(value, EndpointKind::Responses, wire_format)
    } else {
        ConversionReport::bridged(value, EndpointKind::Responses, wire_format, loss_reasons)
    }
}

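/// Same native/bridged rule as the request report, applied to a transport request.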
pub(super) fn generation_transport_report(
    wire_format: WireFormat,
    value: TransportRequest,
    loss_reasons: Vec<String>,
) -> ConversionReport<TransportRequest> {
    if wire_format == WireFormat::OpenAiResponses && loss_reasons.is_empty() {
        ConversionReport::native(value, EndpointKind::Responses, wire_format)
    } else {
        ConversionReport::bridged(value, EndpointKind::Responses, wire_format, loss_reasons)
    }
}

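/// Clones the request and strips features the target wire format cannot
/// express (builtin tools, reasoning, structured output, sampling parameters,
/// metadata, prompt cache), recording a loss reason for each drop. Errors if
/// a prompt cache policy marked required cannot be bridged; otherwise returns
/// the sanitized request with deduplicated loss reasons.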
pub(super) fn sanitize_generation_request(
    wire_format: WireFormat,
    request: &LlmRequest,
) -> Result<(LlmRequest, Vec<String>), ApiProtocolError> {
    let mut sanitized = request.clone();
    let mut loss_reasons = Vec::new();

    match wire_format {
        WireFormat::OpenAiResponses => {}
        WireFormat::OpenAiChatCompletions => {
            if !sanitized.capabilities.builtin_tools.is_empty() {
                sanitized.capabilities.builtin_tools.clear();
                loss_reasons.push(
                    "builtin tools are dropped when bridging to open_ai_chat_completions".into(),
                );
            }
            if sanitized.capabilities.reasoning.take().is_some() {
                loss_reasons.push(
                    "reasoning settings are dropped when bridging to open_ai_chat_completions"
                        .into(),
                );
            }
        }
        WireFormat::AnthropicMessages => {
            if !sanitized.capabilities.builtin_tools.is_empty() {
                sanitized.capabilities.builtin_tools.clear();
                loss_reasons
                    .push("builtin tools are dropped when bridging to anthropic_messages".into());
            }
            if sanitized.capabilities.structured_output.take().is_some() {
                loss_reasons.push(
                    "structured output is dropped when bridging to anthropic_messages".into(),
                );
            }
            if sanitized.capabilities.reasoning.take().is_some() {
                loss_reasons.push(
                    "reasoning settings are dropped when bridging to anthropic_messages".into(),
                );
            }
            if sanitized.generation.top_k.take().is_some() {
                loss_reasons.push("top_k is dropped when bridging to anthropic_messages".into());
            }
            if sanitized.generation.presence_penalty.take().is_some() {
                loss_reasons
                    .push("presence_penalty is dropped when bridging to anthropic_messages".into());
            }
            if sanitized.generation.frequency_penalty.take().is_some() {
                loss_reasons.push(
                    "frequency_penalty is dropped when bridging to anthropic_messages".into(),
                );
            }
            if sanitized.generation.seed.take().is_some() {
                loss_reasons.push("seed is dropped when bridging to anthropic_messages".into());
            }
        }
        WireFormat::GeminiGenerateContent => {
            let before = sanitized.capabilities.builtin_tools.len();
            sanitized
                .capabilities
                .builtin_tools
                .retain(|tool| matches!(tool, BuiltinTool::CodeExecution));
            if sanitized.capabilities.builtin_tools.len() != before {
                loss_reasons.push(
                    "only code_execution builtin tools are preserved for gemini_generate_content"
                        .into(),
                );
            }
            if sanitized.capabilities.reasoning.take().is_some() {
                loss_reasons.push(
                    "reasoning settings are dropped when bridging to gemini_generate_content"
                        .into(),
                );
            }
            if sanitized.generation.presence_penalty.take().is_some() {
                loss_reasons.push(
                    "presence_penalty is dropped when bridging to gemini_generate_content".into(),
                );
            }
            if sanitized.generation.frequency_penalty.take().is_some() {
                loss_reasons.push(
                    "frequency_penalty is dropped when bridging to gemini_generate_content".into(),
                );
            }
        }
        // Non-generation wire formats need no format-specific stripping here;
        // the shared metadata and prompt-cache checks below still apply.
        _ => {}
    }

    if wire_format != WireFormat::OpenAiResponses && !sanitized.metadata.is_empty() {
        sanitized.metadata.clear();
        loss_reasons.push(format!(
            "metadata is dropped when bridging to {}",
            wire_format_name(wire_format)
        ));
    }

    sanitize_prompt_cache_policy(wire_format, &mut sanitized, &mut loss_reasons)?;

    if request_has_unemitted_vendor_extensions(wire_format, request) {
        loss_reasons.push(format!(
            "some vendor_extensions and raw_message fields are not emitted to {}",
            wire_format_name(wire_format)
        ));
    }

    Ok((sanitized, dedupe_loss_reasons(loss_reasons)))
}

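/// Validates the effective prompt cache policy against the target wire
/// format. Does nothing when no policy is set or it is disabled; fails with
/// `UnsupportedFeature` when a required policy cannot be represented;
/// otherwise drops the unsupported parts and records loss reasons.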
pub(super) fn sanitize_prompt_cache_policy(
    wire_format: WireFormat,
    request: &mut LlmRequest,
    loss_reasons: &mut Vec<String>,
) -> Result<(), ApiProtocolError> {
    let Some(policy) = request.capabilities.effective_prompt_cache() else {
        return Ok(());
    };
    if policy.is_disabled() {
        return Ok(());
    }

    match wire_format {
        WireFormat::OpenAiResponses | WireFormat::OpenAiChatCompletions => {
            if !policy.breakpoint().is_auto() {
                if policy.is_required() {
                    return unsupported_prompt_cache(
                        wire_format,
                        "OpenAI prompt cache does not support explicit breakpoints",
                    );
                }
                loss_reasons.push(format!(
                    "prompt cache breakpoint is not emitted when bridging to {}",
                    wire_format_name(wire_format)
                ));
            }
        }
        WireFormat::AnthropicMessages => {
            if policy.key().is_some() {
                if policy.is_required() {
                    return unsupported_prompt_cache(
                        wire_format,
                        "Claude prompt cache does not support explicit cache keys",
                    );
                }
                loss_reasons
                    .push("prompt cache key is dropped when bridging to anthropic_messages".into());
            }
            if !claude_prompt_cache_placement_available(&policy, request) {
                if policy.is_required() {
                    return unsupported_prompt_cache(
                        wire_format,
                        "Claude prompt cache breakpoint cannot be represented for this request",
                    );
                }
                clear_prompt_cache_policy(request);
                loss_reasons.push(
                    "prompt cache policy is dropped when bridging to anthropic_messages".into(),
                );
            }
        }
        WireFormat::GeminiGenerateContent => {
            if policy.is_required() {
                return unsupported_prompt_cache(
                    wire_format,
                    "prompt cache is not supported by gemini_generate_content",
                );
            }
            clear_prompt_cache_policy(request);
            loss_reasons.push(
                "prompt cache policy is dropped when bridging to gemini_generate_content".into(),
            );
        }
        _ => {
            if policy.is_required() {
                return unsupported_prompt_cache(
                    wire_format,
                    "prompt cache is only supported for generation wire formats",
                );
            }
            clear_prompt_cache_policy(request);
            loss_reasons.push(format!(
                "prompt cache policy is dropped when bridging to {}",
                wire_format_name(wire_format)
            ));
        }
    }

    Ok(())
}

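/// Shorthand for the `UnsupportedFeature` error raised by prompt cache checks.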
pub(super) fn unsupported_prompt_cache<T>(
    wire_format: WireFormat,
    message: impl Into<String>,
) -> Result<T, ApiProtocolError> {
    Err(ApiProtocolError::UnsupportedFeature {
        wire_format,
        message: message.into(),
    })
}

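/// Clears both cache-related capability fields so no cache policy remains on
/// the request.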
pub(super) fn clear_prompt_cache_policy(request: &mut LlmRequest) {
    request.capabilities.prompt_cache = None;
    request.capabilities.cache = None;
}

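/// Determines whether the requested cache breakpoint can actually be placed
/// for Anthropic Messages: the referenced tools, instructions, message, or
/// content block must exist, with system/developer messages excluded from
/// message indexing.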
pub(super) fn claude_prompt_cache_placement_available(
    policy: &PromptCachePolicy,
    request: &LlmRequest,
) -> bool {
    match policy.breakpoint() {
        CacheBreakpoint::Auto => {
            !request.capabilities.tools.is_empty() || request.normalized_instructions().is_some()
        }
        CacheBreakpoint::EndOfTools => !request.capabilities.tools.is_empty(),
        CacheBreakpoint::EndOfInstructions => request.normalized_instructions().is_some(),
        CacheBreakpoint::EndOfMessage { index } => request
            .normalized_messages()
            .into_iter()
            .filter(|message| !matches!(message.role, MessageRole::System | MessageRole::Developer))
            .nth(index)
            .is_some_and(|message| !message.parts.is_empty()),
        CacheBreakpoint::EndOfContentBlock {
            message_index,
            part_index,
        } => request
            .normalized_messages()
            .into_iter()
            .filter(|message| !matches!(message.role, MessageRole::System | MessageRole::Developer))
            .nth(message_index)
            .is_some_and(|message| part_index < message.parts.len()),
    }
}

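/// Reports whether the request carries vendor extensions or raw messages that
/// will not be emitted to the target wire format, so the caller can record a
/// loss reason. Top-level request extensions are only emitted for the OpenAI
/// wire formats.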
pub(super) fn request_has_unemitted_vendor_extensions(
    wire_format: WireFormat,
    request: &LlmRequest,
) -> bool {
    let top_level_request_vendor_extensions_are_emitted = matches!(
        wire_format,
        WireFormat::OpenAiResponses | WireFormat::OpenAiChatCompletions
    );

    if (!top_level_request_vendor_extensions_are_emitted && !request.vendor_extensions.is_empty())
        || !request.capabilities.vendor_extensions.is_empty()
        || !request.generation.vendor_extensions.is_empty()
    {
        return true;
    }

    if request.normalized_input().iter().any(|item| match item {
        crate::types::RequestItem::Message { message } => {
            message.raw_message.is_some() || !message.vendor_extensions.is_empty()
        }
        crate::types::RequestItem::ToolResult { .. } => false,
    }) {
        return true;
    }

    request
        .capabilities
        .reasoning
        .as_ref()
        .is_some_and(|reasoning| !reasoning.vendor_extensions.is_empty())
        || request
            .capabilities
            .tools
            .iter()
            .any(|tool| !tool.vendor_extensions.is_empty())
}

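/// Maps a wire format to its request path; only `GeminiGenerateContent`
/// interpolates the model name.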
pub(super) fn wire_path(wire_format: WireFormat, model: &str) -> String {
    match wire_format {
        WireFormat::OpenAiResponses => "/responses".into(),
        WireFormat::OpenAiChatCompletions => "/chat/completions".into(),
        WireFormat::AnthropicMessages => "/messages".into(),
        WireFormat::GeminiGenerateContent => format!("/models/{model}:generateContent"),
        WireFormat::OpenAiEmbeddings => "/embeddings".into(),
        WireFormat::OpenAiImageGenerations => "/images/generations".into(),
        WireFormat::OpenAiAudioTranscriptions => "/audio/transcriptions".into(),
        WireFormat::OpenAiAudioSpeech => "/audio/speech".into(),
        WireFormat::OpenAiRerank => "/rerank".into(),
    }
}