//! omnillm 0.1.5
//!
//! Production-grade LLM API gateway with multi-key load balancing, per-key
//! rate limiting, circuit breaking, and cost tracking.
use serde_json::{Map, Value};

use crate::api::{
    AudioInput, AudioSegment, AudioSpeechRequest, AudioTranscriptionRequest,
    AudioTranscriptionResponse, HttpMethod, MultipartField, MultipartValue, RequestBody,
    TranscribedWord, TransportRequest, WireFormat,
};

use super::common::*;
use super::generation::wire_path;
use super::ApiProtocolError;

/// Lowers an [`AudioTranscriptionRequest`] into a multipart `POST` transport
/// request in the order OpenAI documents: `model`, `file`, then the optional
/// tuning parameters.
///
/// # Errors
/// Returns [`ApiProtocolError::UnsupportedFeature`] when the audio input is a
/// URL, since a multipart body can only carry a raw file payload.
pub(super) fn emit_openai_audio_transcription_transport(
    request: &AudioTranscriptionRequest,
) -> Result<TransportRequest, ApiProtocolError> {
    // Small constructor for the plain-text form fields used below.
    let text_field = |name: &str, value: String| MultipartField {
        name: name.into(),
        value: MultipartValue::Text { value },
    };

    let mut fields = Vec::new();
    fields.push(text_field("model", request.model.clone()));

    // Guard: only file payloads have an OpenAI-compatible multipart shape.
    let AudioInput::File {
        filename,
        data_base64,
        media_type,
    } = &request.audio
    else {
        return Err(ApiProtocolError::UnsupportedFeature {
            wire_format: WireFormat::OpenAiAudioTranscriptions,
            message: "audio URL inputs are not supported by multipart transcription requests"
                .into(),
        });
    };
    fields.push(MultipartField {
        name: "file".into(),
        value: MultipartValue::File {
            filename: filename.clone(),
            data_base64: data_base64.clone(),
            media_type: media_type.clone(),
        },
    });

    // Optional tuning parameters, emitted only when present.
    if let Some(prompt) = &request.prompt {
        fields.push(text_field("prompt", prompt.clone()));
    }
    if let Some(response_format) = &request.response_format {
        fields.push(text_field("response_format", response_format.clone()));
    }
    if let Some(language) = &request.language {
        fields.push(text_field("language", language.clone()));
    }
    if let Some(temperature) = request.temperature {
        fields.push(text_field("temperature", temperature.to_string()));
    }
    // One repeated `timestamp_granularities[]` entry per requested granularity.
    fields.extend(
        request
            .timestamp_granularities
            .iter()
            .map(|granularity| text_field("timestamp_granularities[]", granularity.clone())),
    );

    Ok(TransportRequest {
        method: HttpMethod::Post,
        path: wire_path(WireFormat::OpenAiAudioTranscriptions, &request.model),
        headers: Default::default(),
        accept: Some("application/json".into()),
        body: RequestBody::Multipart { fields },
    })
}

/// Serializes an [`AudioSpeechRequest`] into an OpenAI `/audio/speech` JSON
/// body: required fields first, then optional knobs, then vendor passthrough.
pub(super) fn emit_openai_audio_speech_request(request: &AudioSpeechRequest) -> Value {
    let mut body = Map::new();
    // The three fields OpenAI always requires.
    for (key, text) in [
        ("model", &request.model),
        ("input", &request.input),
        ("voice", &request.voice),
    ] {
        body.insert(key.into(), Value::String(text.clone()));
    }
    if let Some(format) = request.response_format.as_deref() {
        body.insert("response_format".into(), Value::String(format.to_owned()));
    }
    if let Some(speed) = request.speed {
        body.insert("speed".into(), Value::from(speed));
    }
    // Vendor-specific keys ride along untouched.
    extend_with_vendor_extensions(&mut body, &request.vendor_extensions);
    Value::Object(body)
}

/// Parses an OpenAI `/audio/speech` JSON body into an [`AudioSpeechRequest`].
///
/// # Errors
/// Fails when any of the required string fields (`model`, `input`, `voice`)
/// is missing or not a string; optional fields are best-effort and silently
/// become `None` when absent or mistyped.
pub(super) fn parse_openai_audio_speech_request(
    body: &Value,
) -> Result<AudioSpeechRequest, ApiProtocolError> {
    let model = required_str(body, "model")?.to_owned();
    let input = required_str(body, "input")?.to_owned();
    let voice = required_str(body, "voice")?.to_owned();
    let response_format = body
        .get("response_format")
        .and_then(Value::as_str)
        .map(String::from);
    // JSON only has f64; narrow to the f32 the model type carries.
    let speed = body.get("speed").and_then(Value::as_f64).map(|s| s as f32);
    // Everything outside the known keys is preserved as a vendor extension.
    let vendor_extensions = collect_vendor_extensions(
        body,
        &["model", "input", "voice", "response_format", "speed"],
    );
    Ok(AudioSpeechRequest {
        model,
        input,
        voice,
        response_format,
        speed,
        vendor_extensions,
    })
}

pub(super) fn emit_openai_audio_transcription_response(
    response: &AudioTranscriptionResponse,
) -> Value {
    let mut map = Map::new();
    map.insert("text".into(), Value::String(response.text.clone()));
    if let Some(language) = &response.language {
        map.insert("language".into(), Value::String(language.clone()));
    }
    if let Some(duration) = response.duration_seconds {
        map.insert("duration".into(), Value::from(duration));
    }
    if !response.segments.is_empty() {
        map.insert(
            "segments".into(),
            Value::Array(
                response
                    .segments
                    .iter()
                    .map(|segment| {
                        let mut segment_map = Map::new();
                        if let Some(id) = segment.id {
                            segment_map.insert("id".into(), Value::from(id));
                        }
                        if let Some(start) = segment.start {
                            segment_map.insert("start".into(), Value::from(start));
                        }
                        if let Some(end) = segment.end {
                            segment_map.insert("end".into(), Value::from(end));
                        }
                        segment_map.insert("text".into(), Value::String(segment.text.clone()));
                        Value::Object(segment_map)
                    })
                    .collect(),
            ),
        );
    }
    if !response.words.is_empty() {
        map.insert(
            "words".into(),
            Value::Array(
                response
                    .words
                    .iter()
                    .map(|word| {
                        let mut word_map = Map::new();
                        word_map.insert("word".into(), Value::String(word.word.clone()));
                        if let Some(start) = word.start {
                            word_map.insert("start".into(), Value::from(start));
                        }
                        if let Some(end) = word.end {
                            word_map.insert("end".into(), Value::from(end));
                        }
                        Value::Object(word_map)
                    })
                    .collect(),
            ),
        );
    }
    extend_with_vendor_extensions(&mut map, &response.vendor_extensions);
    Value::Object(map)
}

/// Parses an OpenAI verbose-json transcription body into an
/// [`AudioTranscriptionResponse`].
///
/// # Errors
/// Fails only when the required `text` field is missing or not a string;
/// every other field is best-effort (absent or mistyped values become
/// `None` / empty collections / empty strings).
pub(super) fn parse_openai_audio_transcription_response(
    body: &Value,
) -> Result<AudioTranscriptionResponse, ApiProtocolError> {
    // Reads an optional numeric field, narrowing JSON's f64 to f32.
    let f32_at = |value: &Value, key: &str| value.get(key).and_then(Value::as_f64).map(|n| n as f32);

    let parse_segment = |segment: &Value| AudioSegment {
        id: segment
            .get("id")
            .and_then(Value::as_u64)
            .map(|id| id as u32),
        start: f32_at(segment, "start"),
        end: f32_at(segment, "end"),
        text: segment
            .get("text")
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_owned(),
    };

    let parse_word = |word: &Value| TranscribedWord {
        // Some backends label the token "text" instead of "word".
        word: word
            .get("word")
            .or_else(|| word.get("text"))
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_owned(),
        start: f32_at(word, "start"),
        end: f32_at(word, "end"),
    };

    let array_of = |key: &str| body.get(key).and_then(Value::as_array);

    Ok(AudioTranscriptionResponse {
        text: required_str(body, "text")?.to_string(),
        language: body
            .get("language")
            .and_then(Value::as_str)
            .map(str::to_owned),
        // Accept either the OpenAI key or the internal spelling.
        duration_seconds: body
            .get("duration")
            .or_else(|| body.get("duration_seconds"))
            .and_then(Value::as_f64)
            .map(|n| n as f32),
        segments: array_of("segments")
            .map(|items| items.iter().map(parse_segment).collect())
            .unwrap_or_default(),
        words: array_of("words")
            .map(|items| items.iter().map(parse_word).collect())
            .unwrap_or_default(),
        // Any keys outside the known set survive as vendor extensions.
        vendor_extensions: collect_vendor_extensions(
            body,
            &[
                "text",
                "language",
                "duration",
                "duration_seconds",
                "segments",
                "words",
            ],
        ),
    })
}