// gproxy-protocol 1.0.20
//
// Wire-format types and cross-protocol transforms for Claude, OpenAI, and Gemini LLM APIs.
// Documentation
use std::collections::BTreeMap;

use crate::claude::create_message::stream::{BetaRawContentBlockDelta, ClaudeStreamEvent};
use crate::claude::create_message::types::{BetaContentBlock, BetaStopReason};
use crate::claude::types::BetaError;
use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
use crate::gemini::generate_content::types::{
    GeminiBlockReason, GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiPromptFeedback,
    GeminiUsageMetadata,
};
use crate::transform::claude::utils::claude_model_to_string;
use crate::transform::gemini::stream_generate_content::utils::parse_json_object_or_empty;
use crate::transform::utils::TransformError;

/// Per-content-block bookkeeping kept while a Claude content block is open.
///
/// An entry is created from the block announced by `ContentBlockStart` and is
/// looked up again by index when deltas for that block arrive.
#[derive(Debug, Clone)]
enum ClaudeBlockState {
    /// A thinking block.
    Thinking {
        /// Signature taken from the block start and overwritten by later
        /// signature deltas for the same block index.
        signature: String,
    },
    /// A tool-use block.
    ToolUse {
        /// Tool call id copied from the block start.
        id: String,
        /// Tool name copied from the block start.
        name: String,
        /// Concatenation of every input-JSON delta fragment received so far;
        /// starts out empty.
        partial_json: String,
    },
    /// Any other block kind; no per-block state is needed for these.
    Other,
}

/// Stateful converter that turns a sequence of Claude stream events into
/// Gemini `generateContent` response chunks.
///
/// Events are fed in one at a time through `on_event`; the struct carries the
/// message identity, running token counters and open content blocks between
/// calls.
#[derive(Debug, Default, Clone)]
pub struct ClaudeToGeminiStream {
    /// Claude message id from `MessageStart`, echoed as the Gemini `response_id`.
    response_id: Option<String>,
    /// Claude model rendered as a string, echoed as the Gemini `model_version`.
    model_version: Option<String>,
    /// Latest `input_tokens` value reported by Claude.
    input_tokens: u64,
    /// Latest `cache_creation_input_tokens` value reported by Claude.
    cache_creation_input_tokens: u64,
    /// Latest `cache_read_input_tokens` value reported by Claude.
    cached_input_tokens: u64,
    /// Latest `output_tokens` value reported by Claude.
    output_tokens: u64,
    /// Gemini usage metadata derived from the four counters above; attached to
    /// every emitted chunk once it has been computed.
    usage_metadata: Option<GeminiUsageMetadata>,
    /// State of content blocks that have started and not yet stopped, keyed by
    /// Claude content-block index.
    blocks: BTreeMap<u64, ClaudeBlockState>,
    /// Set after `MessageStop` or an error event; further events are ignored.
    finished: bool,
}

impl ClaudeToGeminiStream {
    pub fn is_finished(&self) -> bool {
        self.finished
    }

    /// Builds Gemini usage metadata from Claude's token counters.
    ///
    /// Claude reports uncached input, cache-creation input and cache-read input
    /// as three separate counters, while Gemini's `promptTokenCount` is the
    /// total effective prompt size *including* cached content. The prompt count
    /// is therefore the sum of all three input counters, so that
    /// `total == prompt + candidates` and `prompt - cached` never underflows
    /// for Gemini clients.
    ///
    /// * `input_tokens` - Claude `input_tokens` (uncached prompt tokens).
    /// * `cache_creation_tokens` - Claude `cache_creation_input_tokens`.
    /// * `cached_tokens` - Claude `cache_read_input_tokens`.
    /// * `output_tokens` - Claude `output_tokens`.
    ///
    /// All sums saturate at `u64::MAX` instead of overflowing. Fields not
    /// listed here keep their `Default` value.
    fn usage_from_counts(
        input_tokens: u64,
        cache_creation_tokens: u64,
        cached_tokens: u64,
        output_tokens: u64,
    ) -> GeminiUsageMetadata {
        // Gemini's prompt count already contains the cached portion, so cache
        // reads are folded in here rather than only into the total.
        let prompt_tokens = input_tokens
            .saturating_add(cache_creation_tokens)
            .saturating_add(cached_tokens);
        GeminiUsageMetadata {
            prompt_token_count: Some(prompt_tokens),
            cached_content_token_count: Some(cached_tokens),
            candidates_token_count: Some(output_tokens),
            total_token_count: Some(prompt_tokens.saturating_add(output_tokens)),
            ..GeminiUsageMetadata::default()
        }
    }

    fn sync_usage_metadata(&mut self) {
        self.usage_metadata = Some(Self::usage_from_counts(
            self.input_tokens,
            self.cache_creation_input_tokens,
            self.cached_input_tokens,
            self.output_tokens,
        ));
    }

    /// Maps a Claude stop reason onto the closest Gemini finish reason.
    ///
    /// * token-limit reasons map to `MaxTokens`;
    /// * a refusal maps to `Safety`;
    /// * compaction and paused turns map to `Other`;
    /// * a normal end of turn, a stop sequence, a tool-use stop and a missing
    ///   stop reason all map to `Stop`.
    ///
    /// A Claude `tool_use` stop is a regular, successful way to end a turn.
    /// Gemini signals the equivalent situation with `STOP` plus `functionCall`
    /// parts; `UNEXPECTED_TOOL_CALL` is reserved for the error case where the
    /// model called a tool although none were enabled, so it must not be used
    /// for ordinary tool calls.
    fn finish_reason_from_stop_reason(stop_reason: Option<BetaStopReason>) -> GeminiFinishReason {
        match stop_reason {
            Some(BetaStopReason::MaxTokens | BetaStopReason::ModelContextWindowExceeded) => {
                GeminiFinishReason::MaxTokens
            }
            Some(BetaStopReason::Refusal) => GeminiFinishReason::Safety,
            Some(BetaStopReason::Compaction | BetaStopReason::PauseTurn) => {
                GeminiFinishReason::Other
            }
            Some(
                BetaStopReason::EndTurn | BetaStopReason::StopSequence | BetaStopReason::ToolUse,
            )
            | None => GeminiFinishReason::Stop,
        }
    }

    /// Extracts the `message` string carried by a Claude error, whichever
    /// variant it is.
    fn error_message(error: BetaError) -> String {
        match error {
            // Errors attributable to the request or the caller's account.
            BetaError::InvalidRequest(details) => details.message,
            BetaError::Authentication(details) => details.message,
            BetaError::Permission(details) => details.message,
            BetaError::Billing(details) => details.message,
            BetaError::NotFound(details) => details.message,
            BetaError::RateLimit(details) => details.message,
            // Errors attributable to the service side.
            BetaError::Api(details) => details.message,
            BetaError::Overloaded(details) => details.message,
            BetaError::GatewayTimeout(details) => details.message,
        }
    }

    /// Wraps `parts` in a single-candidate Gemini response chunk.
    ///
    /// The candidate always has index 0 and the `Model` role. The chunk also
    /// carries clones of the stream-level usage metadata, model version and
    /// response id as currently stored on `self`; `model_status` is left unset.
    ///
    /// * `parts` - content parts for the candidate (may be empty).
    /// * `finish_reason` - finish reason to attach to the candidate, if any.
    /// * `prompt_feedback` - prompt feedback to attach to the chunk, if any.
    fn chunk_from_parts(
        &self,
        parts: Vec<GeminiPart>,
        finish_reason: Option<GeminiFinishReason>,
        prompt_feedback: Option<GeminiPromptFeedback>,
    ) -> GeminiGenerateContentResponseBody {
        let content = GeminiContent {
            parts,
            role: Some(GeminiContentRole::Model),
        };
        let candidate = GeminiCandidate {
            content: Some(content),
            finish_reason,
            index: Some(0),
            ..GeminiCandidate::default()
        };

        GeminiGenerateContentResponseBody {
            candidates: Some(vec![candidate]),
            prompt_feedback,
            usage_metadata: self.usage_metadata.clone(),
            model_version: self.model_version.clone(),
            response_id: self.response_id.clone(),
            model_status: None,
        }
    }

    /// Builds a chunk holding one plain text part.
    ///
    /// Returns `None` for empty text so that no empty chunk is emitted.
    fn text_chunk(&self, text: String) -> Option<GeminiGenerateContentResponseBody> {
        if text.is_empty() {
            return None;
        }

        let part = GeminiPart {
            text: Some(text),
            ..GeminiPart::default()
        };
        Some(self.chunk_from_parts(vec![part], None, None))
    }

    /// Builds a chunk holding one thought part.
    ///
    /// The part is flagged as a thought, carries `signature` as its thought
    /// signature and `thinking` as its text. Returns `None` when `thinking` is
    /// empty, regardless of the signature.
    fn thinking_chunk(
        &self,
        signature: String,
        thinking: String,
    ) -> Option<GeminiGenerateContentResponseBody> {
        if thinking.is_empty() {
            return None;
        }

        let part = GeminiPart {
            thought: Some(true),
            thought_signature: Some(signature),
            text: Some(thinking),
            ..GeminiPart::default()
        };
        Some(self.chunk_from_parts(vec![part], None, None))
    }

    /// Builds a chunk holding one function-call part.
    ///
    /// * `id` - tool call id, forwarded as the function call id.
    /// * `name` - tool name, forwarded as the function name.
    /// * `arguments` - JSON text of the call arguments; converted with
    ///   `parse_json_object_or_empty` (presumably an empty object when the text
    ///   is not a valid JSON object — confirm against that helper).
    fn function_call_chunk(
        &self,
        id: String,
        name: String,
        arguments: String,
    ) -> GeminiGenerateContentResponseBody {
        let args = parse_json_object_or_empty(&arguments);
        let call = GeminiFunctionCall {
            id: Some(id),
            name,
            args: Some(args),
        };
        let part = GeminiPart {
            function_call: Some(call),
            ..GeminiPart::default()
        };
        self.chunk_from_parts(vec![part], None, None)
    }

    /// Feeds one Claude stream event into the converter and appends any
    /// resulting Gemini chunks to `out`.
    ///
    /// Behaviour per event:
    /// * `MessageStart` records the response id, model and initial token
    ///   counters; nothing is emitted.
    /// * `ContentBlockStart` records per-block state keyed by block index.
    /// * `ContentBlockDelta` emits text / thought chunks immediately; tool
    ///   input JSON fragments are only buffered.
    /// * `ContentBlockStop` drops the block state and, for a tool-use block,
    ///   emits exactly one function-call chunk built from the buffered JSON.
    /// * `MessageDelta` updates the token counters and emits a part-less chunk
    ///   carrying the finish reason (plus safety prompt feedback on refusal).
    /// * `MessageStop` marks the stream finished.
    /// * `Error` emits the error message as a text chunk and marks the stream
    ///   finished.
    /// * `Ping` is ignored.
    ///
    /// Events received after the stream is finished are ignored.
    ///
    /// # Errors
    ///
    /// No code path in this method currently produces an `Err`; the `Result`
    /// return type is kept for the callers' interface.
    pub fn on_event(
        &mut self,
        event: ClaudeStreamEvent,
        out: &mut Vec<GeminiGenerateContentResponseBody>,
    ) -> Result<(), TransformError> {
        if self.finished {
            return Ok(());
        }

        match event {
            ClaudeStreamEvent::MessageStart { message } => {
                self.response_id = Some(message.id);
                self.model_version = Some(claude_model_to_string(&message.model));
                self.input_tokens = message.usage.input_tokens;
                self.cache_creation_input_tokens = message.usage.cache_creation_input_tokens;
                self.cached_input_tokens = message.usage.cache_read_input_tokens;
                self.output_tokens = message.usage.output_tokens;
                self.sync_usage_metadata();
            }
            ClaudeStreamEvent::ContentBlockStart {
                content_block,
                index,
            } => {
                let state = match content_block {
                    BetaContentBlock::Thinking(block) => ClaudeBlockState::Thinking {
                        signature: block.signature,
                    },
                    BetaContentBlock::ToolUse(block) => ClaudeBlockState::ToolUse {
                        id: block.id,
                        name: block.name,
                        partial_json: String::new(),
                    },
                    _ => ClaudeBlockState::Other,
                };
                self.blocks.insert(index, state);
            }
            ClaudeStreamEvent::ContentBlockDelta { delta, index } => match delta {
                BetaRawContentBlockDelta::Text { text } => {
                    if let Some(chunk) = self.text_chunk(text) {
                        out.push(chunk);
                    }
                }
                BetaRawContentBlockDelta::Thinking { thinking } => {
                    // Fall back to a synthetic signature when no thinking
                    // block was registered for this index.
                    let signature = match self.blocks.get(&index) {
                        Some(ClaudeBlockState::Thinking { signature }) => signature.clone(),
                        _ => format!("thought_{index}"),
                    };
                    if let Some(chunk) = self.thinking_chunk(signature, thinking) {
                        out.push(chunk);
                    }
                }
                BetaRawContentBlockDelta::InputJson { partial_json } => {
                    // Only buffer here. The fragments are not valid JSON until
                    // the block is complete, and every emitted `functionCall`
                    // part reads as a separate call to a Gemini client, so the
                    // call is emitted once when the block stops.
                    if let Some(ClaudeBlockState::ToolUse {
                        partial_json: accumulated,
                        ..
                    }) = self.blocks.get_mut(&index)
                    {
                        accumulated.push_str(&partial_json);
                    }
                }
                BetaRawContentBlockDelta::Signature { signature } => {
                    // NOTE(review): the stored signature is only attached to
                    // thought chunks emitted *after* this delta. Claude's
                    // streaming docs describe the signature delta as arriving
                    // after the thinking deltas, in which case it is never
                    // forwarded — confirm and consider emitting it.
                    if let Some(ClaudeBlockState::Thinking { signature: sig }) =
                        self.blocks.get_mut(&index)
                    {
                        *sig = signature;
                    }
                }
                BetaRawContentBlockDelta::Compaction { content } => {
                    if let Some(content) = content
                        && let Some(chunk) = self.text_chunk(content)
                    {
                        out.push(chunk);
                    }
                }
                BetaRawContentBlockDelta::Citations { .. } => {}
            },
            ClaudeStreamEvent::ContentBlockStop { index } => {
                // A finished tool-use block yields exactly one function call,
                // including when it received no input deltas at all.
                if let Some(ClaudeBlockState::ToolUse {
                    id,
                    name,
                    partial_json,
                }) = self.blocks.remove(&index)
                {
                    out.push(self.function_call_chunk(id, name, partial_json));
                }
            }
            ClaudeStreamEvent::MessageDelta {
                delta,
                usage,
                context_management: _,
            } => {
                // Input-side counters are optional in the delta; keep the
                // previous value when a counter is absent.
                if let Some(input_tokens) = usage.input_tokens {
                    self.input_tokens = input_tokens;
                }
                if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens {
                    self.cache_creation_input_tokens = cache_creation_input_tokens;
                }
                if let Some(cached_input_tokens) = usage.cache_read_input_tokens {
                    self.cached_input_tokens = cached_input_tokens;
                }
                self.output_tokens = usage.output_tokens;
                self.sync_usage_metadata();

                let finish_reason = Self::finish_reason_from_stop_reason(delta.stop_reason);
                let prompt_feedback = if matches!(finish_reason, GeminiFinishReason::Safety) {
                    Some(GeminiPromptFeedback {
                        block_reason: Some(GeminiBlockReason::Safety),
                        safety_ratings: None,
                    })
                } else {
                    None
                };

                out.push(self.chunk_from_parts(Vec::new(), Some(finish_reason), prompt_feedback));
            }
            ClaudeStreamEvent::MessageStop {} => {
                self.finished = true;
            }
            ClaudeStreamEvent::Error { error } => {
                let message = Self::error_message(error);
                if let Some(chunk) = self.text_chunk(message) {
                    out.push(chunk);
                }
                self.finished = true;
            }
            ClaudeStreamEvent::Ping {} => {}
        }

        Ok(())
    }
}