gproxy-protocol 1.0.20

Wire-format types and cross-protocol transforms for Claude, OpenAI, and Gemini LLM APIs.
Documentation
use std::collections::BTreeMap;

use http::StatusCode;

use crate::claude::create_message::response::ClaudeCreateMessageResponse;
use crate::claude::create_message::stream::{BetaRawContentBlockDelta, ClaudeStreamEvent};
use crate::claude::create_message::types::{
    BetaContentBlock, BetaErrorResponse, BetaErrorResponseType, BetaMessage, BetaTextBlock,
    BetaThinkingBlock, BetaToolUseBlock, JsonObject,
};
use crate::claude::types::{BetaError, ClaudeResponseHeaders};
use crate::transform::utils::TransformError;

/// A content block that has been opened by a `content_block_start` event and
/// is still receiving deltas.
///
/// Each variant holds the partially assembled block; `apply_delta` folds
/// streamed deltas into it and `into_content_block` produces the final
/// [`BetaContentBlock`].
#[derive(Debug, Clone)]
enum PendingBlock {
    /// Text block; text deltas are appended and citation deltas are collected.
    Text(BetaTextBlock),
    /// Thinking block; thinking deltas are appended and a signature delta
    /// replaces the stored signature.
    Thinking(BetaThinkingBlock),
    /// Tool-use block whose input arrives as fragments of a JSON document.
    ToolUse {
        /// The block as received in the start event.
        block: BetaToolUseBlock,
        /// Concatenation of every `partial_json` fragment seen so far; parsed
        /// only when the block is finalized.
        input_json_buf: String,
    },
    /// Any other block kind, carried through as-is (only a compaction block
    /// reacts to a delta, by having its content replaced).
    Other(BetaContentBlock),
}

impl PendingBlock {
    /// Folds a single streamed delta into this partially assembled block.
    ///
    /// A delta whose kind does not match the kind of the block it is applied
    /// to is dropped without any effect.
    fn apply_delta(&mut self, delta: BetaRawContentBlockDelta) {
        match self {
            Self::Text(text_block) => match delta {
                BetaRawContentBlockDelta::Text { text } => {
                    text_block.text.push_str(&text);
                }
                BetaRawContentBlockDelta::Citations { citation } => {
                    // Create the citation list on first use, then append.
                    text_block
                        .citations
                        .get_or_insert_with(Vec::new)
                        .push(citation);
                }
                _ => {}
            },
            Self::Thinking(thinking_block) => match delta {
                BetaRawContentBlockDelta::Thinking { thinking } => {
                    thinking_block.thinking.push_str(&thinking);
                }
                BetaRawContentBlockDelta::Signature { signature } => {
                    // The signature is replaced, not appended.
                    thinking_block.signature = signature;
                }
                _ => {}
            },
            Self::ToolUse { input_json_buf, .. } => {
                // Fragments are only buffered here; parsing happens in
                // `into_content_block`.
                if let BetaRawContentBlockDelta::InputJson { partial_json } = delta {
                    input_json_buf.push_str(&partial_json);
                }
            }
            Self::Other(BetaContentBlock::Compaction(compaction)) => {
                if let BetaRawContentBlockDelta::Compaction { content } = delta {
                    compaction.content = content;
                }
            }
            Self::Other(_) => {}
        }
    }

    /// Consumes the pending block and yields the finished content block.
    ///
    /// For a tool-use block with a non-empty JSON buffer, the buffer is parsed
    /// into the block's `input`; if parsing fails the input becomes the
    /// default `JsonObject`. An empty buffer leaves the input from the start
    /// event untouched.
    fn into_content_block(self) -> BetaContentBlock {
        let (mut tool_use, raw_input) = match self {
            Self::Text(done) => return BetaContentBlock::Text(done),
            Self::Thinking(done) => return BetaContentBlock::Thinking(done),
            Self::Other(done) => return done,
            Self::ToolUse {
                block,
                input_json_buf,
            } => (block, input_json_buf),
        };

        if !raw_input.is_empty() {
            tool_use.input = serde_json::from_str::<JsonObject>(&raw_input).unwrap_or_default();
        }
        BetaContentBlock::ToolUse(tool_use)
    }
}

/// Wraps the block carried by a `content_block_start` event in the matching
/// [`PendingBlock`] variant so that later deltas can be applied to it.
///
/// Tool-use blocks start with an empty JSON buffer; every block kind other
/// than text, thinking and tool-use is stored as `PendingBlock::Other`.
fn pending_from_content_block(content_block: BetaContentBlock) -> PendingBlock {
    match content_block {
        BetaContentBlock::ToolUse(tool_use) => PendingBlock::ToolUse {
            block: tool_use,
            input_json_buf: String::new(),
        },
        BetaContentBlock::Thinking(thinking) => PendingBlock::Thinking(thinking),
        BetaContentBlock::Text(text) => PendingBlock::Text(text),
        passthrough => PendingBlock::Other(passthrough),
    }
}

/// Picks the HTTP status code that corresponds to an error reported inside
/// the event stream.
fn status_code_from_stream_error(error: &BetaError) -> StatusCode {
    // 529 is built from its numeric value; should that construction fail,
    // 503 is used instead.
    let overloaded = StatusCode::from_u16(529).unwrap_or(StatusCode::SERVICE_UNAVAILABLE);

    // Arms are listed in ascending status-code order.
    match error {
        BetaError::InvalidRequest(_) => StatusCode::BAD_REQUEST,
        BetaError::Authentication(_) => StatusCode::UNAUTHORIZED,
        BetaError::Billing(_) => StatusCode::PAYMENT_REQUIRED,
        BetaError::Permission(_) => StatusCode::FORBIDDEN,
        BetaError::NotFound(_) => StatusCode::NOT_FOUND,
        BetaError::RateLimit(_) => StatusCode::TOO_MANY_REQUESTS,
        BetaError::Api(_) => StatusCode::INTERNAL_SERVER_ERROR,
        BetaError::GatewayTimeout(_) => StatusCode::GATEWAY_TIMEOUT,
        BetaError::Overloaded(_) => overloaded,
    }
}

impl TryFrom<Vec<ClaudeStreamEvent>> for ClaudeCreateMessageResponse {
    type Error = TransformError;

    fn try_from(value: Vec<ClaudeStreamEvent>) -> Result<Self, TransformError> {
        let mut message: Option<BetaMessage> = None;
        let mut open_blocks: BTreeMap<u64, PendingBlock> = BTreeMap::new();
        let mut closed_blocks: BTreeMap<u64, BetaContentBlock> = BTreeMap::new();

        for event in value {
            match event {
                ClaudeStreamEvent::MessageStart { message: msg } => {
                    if message.is_some() {
                        return Err(TransformError::not_implemented(
                            "multiple message_start events are not supported",
                        ));
                    }
                    message = Some(msg);
                }
                ClaudeStreamEvent::ContentBlockStart {
                    content_block,
                    index,
                } => {
                    open_blocks.insert(index, pending_from_content_block(content_block));
                }
                ClaudeStreamEvent::ContentBlockDelta { delta, index } => {
                    let Some(block) = open_blocks.get_mut(&index) else {
                        return Err(TransformError::not_implemented(
                            "content_block_delta received before content_block_start",
                        ));
                    };
                    block.apply_delta(delta);
                }
                ClaudeStreamEvent::ContentBlockStop { index } => {
                    let Some(block) = open_blocks.remove(&index) else {
                        return Err(TransformError::not_implemented(
                            "content_block_stop received before content_block_start",
                        ));
                    };
                    closed_blocks.insert(index, block.into_content_block());
                }
                ClaudeStreamEvent::MessageDelta {
                    context_management,
                    delta,
                    usage,
                } => {
                    let Some(message) = message.as_mut() else {
                        return Err(TransformError::not_implemented(
                            "message_delta received before message_start",
                        ));
                    };

                    if let Some(context_management) = context_management {
                        message.context_management = Some(context_management);
                    }

                    message.stop_reason = delta.stop_reason;
                    message.stop_sequence = delta.stop_sequence;
                    if let Some(container) = delta.container {
                        message.container = Some(container);
                    }

                    if let Some(input_tokens) = usage.input_tokens {
                        message.usage.input_tokens = input_tokens;
                    }
                    if let Some(cache_read_input_tokens) = usage.cache_read_input_tokens {
                        message.usage.cache_read_input_tokens = cache_read_input_tokens;
                    }
                    if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens {
                        message.usage.cache_creation_input_tokens = cache_creation_input_tokens;
                    }
                    if let Some(iterations) = usage.iterations {
                        message.usage.iterations = iterations;
                    }
                    if let Some(server_tool_use) = usage.server_tool_use {
                        message.usage.server_tool_use = server_tool_use;
                    }
                    message.usage.output_tokens = usage.output_tokens;
                }
                ClaudeStreamEvent::MessageStop {} => {}
                ClaudeStreamEvent::Ping {} => {}
                ClaudeStreamEvent::Error { error } => {
                    return Ok(ClaudeCreateMessageResponse::Error {
                        stats_code: status_code_from_stream_error(&error),
                        headers: ClaudeResponseHeaders {
                            extra: BTreeMap::new(),
                        },
                        body: BetaErrorResponse {
                            error,
                            request_id: String::new(),
                            type_: BetaErrorResponseType::Error,
                        },
                    });
                }
            }
        }

        let Some(mut message) = message else {
            return Err(TransformError::not_implemented(
                "message_start event is required for stream_to_nonstream conversion",
            ));
        };

        for (index, block) in open_blocks {
            closed_blocks.insert(index, block.into_content_block());
        }

        if !closed_blocks.is_empty() {
            message.content = closed_blocks.into_values().collect();
        }

        Ok(ClaudeCreateMessageResponse::Success {
            stats_code: StatusCode::OK,
            headers: ClaudeResponseHeaders {
                extra: BTreeMap::new(),
            },
            body: message,
        })
    }
}