omnillm 0.1.5

Production-grade LLM API gateway with multi-key load balancing, per-key rate limiting, circuit breaking, and cost tracking
Documentation
use serde_json::{json, Map, Value};

use crate::api::{
    EmbeddingInput, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage, EmbeddingVector,
};

use super::common::*;
use super::ApiProtocolError;

pub(super) fn emit_openai_embeddings_request(
    request: &EmbeddingRequest,
) -> Result<Value, ApiProtocolError> {
    let mut map = Map::new();
    map.insert("model".into(), Value::String(request.model.clone()));
    map.insert("input".into(), embedding_inputs_to_value(&request.input));
    if let Some(dimensions) = request.dimensions {
        map.insert("dimensions".into(), Value::from(dimensions));
    }
    if let Some(format) = &request.encoding_format {
        map.insert("encoding_format".into(), Value::String(format.clone()));
    }
    if let Some(user) = &request.user {
        map.insert("user".into(), Value::String(user.clone()));
    }
    extend_with_vendor_extensions(&mut map, &request.vendor_extensions);
    Ok(Value::Object(map))
}

pub(super) fn parse_openai_embeddings_request(
    body: &Value,
) -> Result<EmbeddingRequest, ApiProtocolError> {
    let model = required_str(body, "model")?.to_string();
    let input = parse_embedding_inputs(
        body.get("input")
            .ok_or_else(|| ApiProtocolError::MissingField("input".into()))?,
    )?;

    Ok(EmbeddingRequest {
        model,
        input,
        dimensions: body
            .get("dimensions")
            .and_then(Value::as_u64)
            .map(|value| value as u32),
        encoding_format: body
            .get("encoding_format")
            .and_then(Value::as_str)
            .map(str::to_owned),
        user: body.get("user").and_then(Value::as_str).map(str::to_owned),
        vendor_extensions: collect_vendor_extensions(
            body,
            &["model", "input", "dimensions", "encoding_format", "user"],
        ),
    })
}

pub(super) fn emit_openai_embeddings_response(response: &EmbeddingResponse) -> Value {
    let mut map = Map::new();
    map.insert("object".into(), Value::String("list".into()));
    map.insert("model".into(), Value::String(response.model.clone()));
    map.insert(
        "data".into(),
        Value::Array(
            response
                .data
                .iter()
                .map(|vector| {
                    json!({
                        "object": "embedding",
                        "index": vector.index,
                        "embedding": vector.embedding,
                    })
                })
                .collect(),
        ),
    );
    if let Some(usage) = &response.usage {
        let mut usage_map = Map::new();
        usage_map.insert("prompt_tokens".into(), Value::from(usage.prompt_tokens));
        if let Some(total_tokens) = usage.total_tokens {
            usage_map.insert("total_tokens".into(), Value::from(total_tokens));
        }
        map.insert("usage".into(), Value::Object(usage_map));
    }
    extend_with_vendor_extensions(&mut map, &response.vendor_extensions);
    Value::Object(map)
}

pub(super) fn parse_openai_embeddings_response(
    body: &Value,
) -> Result<EmbeddingResponse, ApiProtocolError> {
    let data = body
        .get("data")
        .and_then(Value::as_array)
        .ok_or_else(|| ApiProtocolError::MissingField("data".into()))?
        .iter()
        .map(|item| {
            Ok(EmbeddingVector {
                index: item.get("index").and_then(Value::as_u64).unwrap_or(0) as usize,
                embedding: item
                    .get("embedding")
                    .and_then(Value::as_array)
                    .ok_or_else(|| ApiProtocolError::MissingField("embedding".into()))?
                    .iter()
                    .map(|value| {
                        value.as_f64().map(|number| number as f32).ok_or_else(|| {
                            ApiProtocolError::InvalidShape("embedding vector entry".into())
                        })
                    })
                    .collect::<Result<Vec<_>, _>>()?,
            })
        })
        .collect::<Result<Vec<_>, ApiProtocolError>>()?;

    let usage = body.get("usage").map(|usage| EmbeddingUsage {
        prompt_tokens: usage
            .get("prompt_tokens")
            .and_then(Value::as_u64)
            .unwrap_or(0) as u32,
        total_tokens: usage
            .get("total_tokens")
            .and_then(Value::as_u64)
            .map(|value| value as u32),
    });

    Ok(EmbeddingResponse {
        model: body
            .get("model")
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_string(),
        data,
        usage,
        vendor_extensions: collect_vendor_extensions(body, &["object", "model", "data", "usage"]),
    })
}

fn embedding_inputs_to_value(inputs: &[EmbeddingInput]) -> Value {
    match inputs {
        [EmbeddingInput::Text { text }] => Value::String(text.clone()),
        [EmbeddingInput::Tokens { tokens }] => {
            Value::Array(tokens.iter().copied().map(Value::from).collect::<Vec<_>>())
        }
        many => Value::Array(
            many.iter()
                .map(|item| match item {
                    EmbeddingInput::Text { text } => Value::String(text.clone()),
                    EmbeddingInput::Tokens { tokens } => {
                        Value::Array(tokens.iter().copied().map(Value::from).collect::<Vec<_>>())
                    }
                })
                .collect(),
        ),
    }
}

fn parse_embedding_inputs(value: &Value) -> Result<Vec<EmbeddingInput>, ApiProtocolError> {
    match value {
        Value::String(text) => Ok(vec![EmbeddingInput::Text { text: text.clone() }]),
        Value::Array(items) => {
            if items.is_empty() {
                return Ok(Vec::new());
            }
            if items.iter().all(Value::is_number) {
                return Ok(vec![EmbeddingInput::Tokens {
                    tokens: items
                        .iter()
                        .map(value_as_i32)
                        .collect::<Result<Vec<_>, _>>()?,
                }]);
            }
            items
                .iter()
                .map(|item| match item {
                    Value::String(text) => Ok(EmbeddingInput::Text { text: text.clone() }),
                    Value::Array(tokens) if tokens.iter().all(Value::is_number) => {
                        Ok(EmbeddingInput::Tokens {
                            tokens: tokens
                                .iter()
                                .map(value_as_i32)
                                .collect::<Result<Vec<_>, _>>()?,
                        })
                    }
                    _ => Err(ApiProtocolError::InvalidShape(
                        "embedding input item".into(),
                    )),
                })
                .collect()
        }
        _ => Err(ApiProtocolError::InvalidShape("embedding input".into())),
    }
}