omnillm 0.1.5

Production-grade LLM API gateway with multi-key load balancing, per-key rate limiting, circuit breaking, and cost tracking
Documentation
use serde_json::{Map, Value};

use crate::api::{RerankDocument, RerankRequest, RerankResponse, RerankResult, RerankUsage};

use super::common::*;
use super::ApiProtocolError;

/// Serializes a [`RerankRequest`] into an OpenAI-style rerank JSON body.
///
/// Text documents are emitted as bare JSON strings and structured documents
/// are emitted verbatim. `top_n` and `return_documents` are written only when
/// set. Vendor extensions are merged in last via `extend_with_vendor_extensions`.
pub(super) fn emit_openai_rerank_request(request: &RerankRequest) -> Value {
    let documents: Vec<Value> = request
        .documents
        .iter()
        .map(|doc| match doc {
            RerankDocument::Text { text } => Value::String(text.clone()),
            RerankDocument::Json { value } => value.clone(),
        })
        .collect();

    let mut body = Map::new();
    body.insert("model".into(), Value::String(request.model.clone()));
    body.insert("query".into(), Value::String(request.query.clone()));
    body.insert("documents".into(), Value::Array(documents));

    // Optional knobs, inserted in a fixed order and only when present.
    let optional_fields = [
        ("top_n", request.top_n.map(Value::from)),
        ("return_documents", request.return_documents.map(Value::Bool)),
    ];
    for (key, maybe_value) in optional_fields {
        if let Some(value) = maybe_value {
            body.insert(key.into(), value);
        }
    }

    extend_with_vendor_extensions(&mut body, &request.vendor_extensions);
    Value::Object(body)
}

/// Parses an OpenAI-style rerank JSON body into a [`RerankRequest`].
///
/// JSON string documents become [`RerankDocument::Text`]. Any other JSON value is
/// kept verbatim as [`RerankDocument::Json`].
///
/// An integer `top_n` above `u32::MAX` is clamped to `u32::MAX`. A `top_n` that
/// is not a non-negative integer is ignored. A `return_documents` that is not a
/// boolean is ignored.
///
/// Top-level keys other than `model`, `query`, `documents`, `top_n` and
/// `return_documents` are collected as vendor extensions.
///
/// # Errors
///
/// Returns [`ApiProtocolError::MissingField`] when `documents` is absent or is not
/// an array. Propagates whatever error `required_str` reports for `model` or `query`.
pub(super) fn parse_openai_rerank_request(body: &Value) -> Result<RerankRequest, ApiProtocolError> {
    let documents = body
        .get("documents")
        .and_then(Value::as_array)
        .ok_or_else(|| ApiProtocolError::MissingField("documents".into()))?
        .iter()
        .map(|document| match document {
            Value::String(text) => RerankDocument::Text { text: text.clone() },
            value => RerankDocument::Json {
                value: value.clone(),
            },
        })
        .collect();

    Ok(RerankRequest {
        model: required_str(body, "model")?.to_string(),
        query: required_str(body, "query")?.to_string(),
        documents,
        // Clamp instead of `as`-truncating: a huge `top_n` must not wrap around
        // to a small (or zero) result count. After the `min`, the cast is lossless.
        top_n: body
            .get("top_n")
            .and_then(Value::as_u64)
            .map(|value| value.min(u64::from(u32::MAX)) as u32),
        return_documents: body.get("return_documents").and_then(Value::as_bool),
        vendor_extensions: collect_vendor_extensions(
            body,
            &["model", "query", "documents", "top_n", "return_documents"],
        ),
    })
}

pub(super) fn emit_openai_rerank_response(response: &RerankResponse) -> Value {
    let mut map = Map::new();
    map.insert("model".into(), Value::String(response.model.clone()));
    map.insert(
        "results".into(),
        Value::Array(
            response
                .results
                .iter()
                .map(|result| {
                    let mut result_map = Map::new();
                    result_map.insert("index".into(), Value::from(result.index));
                    result_map.insert(
                        "relevance_score".into(),
                        Value::from(result.relevance_score),
                    );
                    if let Some(document) = &result.document {
                        result_map.insert("document".into(), document.clone());
                    }
                    Value::Object(result_map)
                })
                .collect(),
        ),
    );
    if let Some(usage) = &response.usage {
        let mut usage_map = Map::new();
        if let Some(total_tokens) = usage.total_tokens {
            usage_map.insert("total_tokens".into(), Value::from(total_tokens));
        }
        map.insert("usage".into(), Value::Object(usage_map));
    }
    extend_with_vendor_extensions(&mut map, &response.vendor_extensions);
    Value::Object(map)
}

/// Parses an OpenAI-style rerank JSON response into a [`RerankResponse`].
///
/// `results` must be an array. Each entry must carry an integer `index` that fits
/// in a `u32`, because a result that cannot be tied back to its document is
/// meaningless.
///
/// Other fields are read leniently:
/// - A missing or non-numeric `relevance_score` is read as `0.0`.
/// - `document` is copied through unchanged when present.
/// - A missing `model` becomes an empty string.
/// - A `usage.total_tokens` above `u32::MAX` is clamped.
///
/// Top-level keys other than `model`, `results` and `usage` are collected as
/// vendor extensions.
///
/// # Errors
///
/// Returns [`ApiProtocolError::MissingField`] when `results` is absent or is not an
/// array, or when any result lacks a non-negative integer `index` that fits in `u32`.
pub(super) fn parse_openai_rerank_response(
    body: &Value,
) -> Result<RerankResponse, ApiProtocolError> {
    let results = body
        .get("results")
        .and_then(Value::as_array)
        .ok_or_else(|| ApiProtocolError::MissingField("results".into()))?
        .iter()
        .map(|result| {
            // Reject rather than default to 0: defaulting would silently attribute
            // this score to the first document. Values above u32::MAX are rejected
            // too instead of being wrapped by an `as` cast.
            let index = result
                .get("index")
                .and_then(Value::as_u64)
                .filter(|&value| value <= u64::from(u32::MAX))
                .map(|value| value as u32)
                .ok_or_else(|| ApiProtocolError::MissingField("results[].index".into()))?;
            Ok(RerankResult {
                index,
                relevance_score: result
                    .get("relevance_score")
                    .and_then(Value::as_f64)
                    .unwrap_or_default() as f32,
                document: result.get("document").cloned(),
            })
        })
        .collect::<Result<Vec<_>, ApiProtocolError>>()?;

    let usage = body.get("usage").map(|usage| RerankUsage {
        // Clamp rather than truncate so an oversized count cannot wrap to a small one.
        total_tokens: usage
            .get("total_tokens")
            .and_then(Value::as_u64)
            .map(|value| value.min(u64::from(u32::MAX)) as u32),
    });

    Ok(RerankResponse {
        model: body
            .get("model")
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_string(),
        results,
        usage,
        vendor_extensions: collect_vendor_extensions(body, &["model", "results", "usage"]),
    })
}