// gproxy-protocol 1.0.20
//
// Wire-format types and cross-protocol transforms for Claude, OpenAI, and Gemini LLM APIs.
use std::collections::BTreeMap;

use crate::gemini::generate_content::request::GeminiGenerateContentRequest;
use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
use crate::gemini::generate_content::types::GeminiCandidate;
use crate::gemini::stream_generate_content::request::GeminiStreamGenerateContentRequest;
use crate::transform::utils::TransformError;

/// Folds one `incoming` candidate into the accumulated `target` candidate.
///
/// * `target.index` is always set to `Some(index)`.
/// * Content: when `target` already has content, the incoming parts are
///   appended to it and the incoming role is adopted only if the existing
///   role is unset; otherwise the incoming content is adopted as-is.
/// * Grounding attributions are appended to any existing list, or adopted
///   when `target` has none.
/// * Every other optional field is overwritten only when `incoming` carries
///   a value, so a value from an earlier merge survives a later `None`.
fn merge_candidate(target: &mut GeminiCandidate, incoming: GeminiCandidate, index: u32) {
    // Overwrites `slot` only when `value` is present; a `None` leaves it untouched.
    fn keep_latest<T>(slot: &mut Option<T>, value: Option<T>) {
        if value.is_some() {
            *slot = value;
        }
    }

    target.index = Some(index);

    if let Some(content) = incoming.content {
        match target.content.as_mut() {
            Some(existing) => {
                if existing.role.is_none() {
                    existing.role = content.role;
                }
                existing.parts.extend(content.parts);
            }
            None => target.content = Some(content),
        }
    }

    keep_latest(&mut target.finish_reason, incoming.finish_reason);
    keep_latest(&mut target.safety_ratings, incoming.safety_ratings);
    keep_latest(&mut target.citation_metadata, incoming.citation_metadata);
    keep_latest(&mut target.token_count, incoming.token_count);

    if let Some(attributions) = incoming.grounding_attributions {
        match target.grounding_attributions.as_mut() {
            Some(existing) => existing.extend(attributions),
            None => target.grounding_attributions = Some(attributions),
        }
    }

    keep_latest(&mut target.grounding_metadata, incoming.grounding_metadata);
    keep_latest(&mut target.avg_logprobs, incoming.avg_logprobs);
    keep_latest(&mut target.logprobs_result, incoming.logprobs_result);
    keep_latest(
        &mut target.url_context_metadata,
        incoming.url_context_metadata,
    );
    keep_latest(&mut target.finish_message, incoming.finish_message);
}

/// Merges one response `chunk` into the running aggregate.
///
/// Candidates from the chunk are merged into `candidate_map`, keyed by the
/// candidate's own `index`, or by its position within the chunk when it has
/// none. A key seen for the first time starts from a default candidate
/// carrying that index.
///
/// The chunk's `prompt_feedback`, `usage_metadata`, `model_version`,
/// `response_id` and `model_status` each replace the value in `merged` only
/// when the chunk actually provides one.
pub fn merge_chunk(
    merged: &mut GeminiGenerateContentResponseBody,
    candidate_map: &mut BTreeMap<u32, GeminiCandidate>,
    chunk: GeminiGenerateContentResponseBody,
) {
    // A missing candidate list simply yields no iterations.
    for (pos, candidate) in chunk.candidates.into_iter().flatten().enumerate() {
        let index = candidate.index.unwrap_or(pos as u32);
        let slot = candidate_map
            .entry(index)
            .or_insert_with(|| GeminiCandidate {
                index: Some(index),
                ..Default::default()
            });
        merge_candidate(slot, candidate, index);
    }

    if let Some(feedback) = chunk.prompt_feedback {
        merged.prompt_feedback = Some(feedback);
    }
    if let Some(usage) = chunk.usage_metadata {
        merged.usage_metadata = Some(usage);
    }
    if let Some(version) = chunk.model_version {
        merged.model_version = Some(version);
    }
    if let Some(id) = chunk.response_id {
        merged.response_id = Some(id);
    }
    if let Some(status) = chunk.model_status {
        merged.model_status = Some(status);
    }
}

/// Produces the final response body from the aggregate and its candidates.
///
/// `merged.candidates` is set to `None` when `candidate_map` is empty;
/// otherwise it is set to the map's values, in ascending key order (the
/// iteration order of a `BTreeMap`).
pub fn finalize_body(
    mut merged: GeminiGenerateContentResponseBody,
    candidate_map: BTreeMap<u32, GeminiCandidate>,
) -> GeminiGenerateContentResponseBody {
    let has_candidates = !candidate_map.is_empty();
    merged.candidates = has_candidates.then(|| candidate_map.into_values().collect());
    merged
}

impl TryFrom<&GeminiStreamGenerateContentRequest> for GeminiGenerateContentRequest {
    type Error = TransformError;

    fn try_from(value: &GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
        Ok(GeminiGenerateContentRequest {
            method: value.method,
            path: crate::gemini::generate_content::request::PathParameters {
                model: value.path.model.clone(),
            },
            query: crate::gemini::generate_content::request::QueryParameters::default(),
            headers: crate::gemini::generate_content::request::RequestHeaders::default(),
            body: value.body.clone(),
        })
    }
}

impl TryFrom<GeminiStreamGenerateContentRequest> for GeminiGenerateContentRequest {
    type Error = TransformError;

    fn try_from(value: GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
        GeminiGenerateContentRequest::try_from(&value)
    }
}