Skip to main content

gproxy_protocol/transform/gemini/websocket/from_http/
response.rs

1use crate::gemini::count_tokens::types::GeminiContent;
2use crate::gemini::generate_content::response::GeminiGenerateContentResponse;
3use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
4use crate::gemini::generate_content::types::{GeminiCandidate, GeminiUsageMetadata};
5use crate::gemini::live::response::GeminiLiveMessageResponse;
6use crate::gemini::live::types::{
7    GeminiBidiGenerateContentServerContent, GeminiBidiGenerateContentServerMessage,
8    GeminiBidiGenerateContentServerMessageType, GeminiBidiGenerateContentToolCall,
9    GeminiFunctionCall, GeminiLiveUsageMetadata,
10};
11use crate::gemini::stream_generate_content::response::GeminiStreamGenerateContentResponse;
12use crate::transform::gemini::websocket::context::GeminiWebsocketTransformContext;
13use crate::transform::utils::TransformError;
14
15pub fn usage_generate_to_live(
16    usage: Option<GeminiUsageMetadata>,
17) -> Option<GeminiLiveUsageMetadata> {
18    usage.map(|usage| GeminiLiveUsageMetadata {
19        prompt_token_count: usage.prompt_token_count,
20        cached_content_token_count: usage.cached_content_token_count,
21        response_token_count: usage.candidates_token_count,
22        tool_use_prompt_token_count: usage.tool_use_prompt_token_count,
23        thoughts_token_count: usage.thoughts_token_count,
24        total_token_count: usage.total_token_count,
25        prompt_tokens_details: usage.prompt_tokens_details,
26        cache_tokens_details: usage.cache_tokens_details,
27        response_tokens_details: usage.candidates_tokens_details,
28        tool_use_prompt_tokens_details: usage.tool_use_prompt_tokens_details,
29    })
30}
31
32pub fn candidate_to_server_message(
33    candidate: GeminiCandidate,
34    usage_metadata: Option<GeminiLiveUsageMetadata>,
35) -> Option<GeminiLiveMessageResponse> {
36    let generation_complete = candidate.finish_reason.is_some();
37
38    let has_payload = candidate.content.is_some()
39        || candidate.finish_reason.is_some()
40        || candidate.grounding_metadata.is_some()
41        || candidate.url_context_metadata.is_some()
42        || usage_metadata.is_some();
43
44    if !has_payload {
45        return None;
46    }
47
48    let as_tool_calls = candidate
49        .content
50        .as_ref()
51        .and_then(content_as_pure_function_calls);
52
53    Some(GeminiLiveMessageResponse::Message(match as_tool_calls {
54        Some(function_calls) => GeminiBidiGenerateContentServerMessage {
55            usage_metadata,
56            message_type: GeminiBidiGenerateContentServerMessageType::ToolCall {
57                tool_call: GeminiBidiGenerateContentToolCall {
58                    function_calls: Some(function_calls),
59                },
60            },
61        },
62        None => GeminiBidiGenerateContentServerMessage {
63            usage_metadata,
64            message_type: GeminiBidiGenerateContentServerMessageType::ServerContent {
65                server_content: GeminiBidiGenerateContentServerContent {
66                    generation_complete: generation_complete.then_some(true),
67                    turn_complete: generation_complete.then_some(true),
68                    interrupted: None,
69                    grounding_metadata: candidate.grounding_metadata,
70                    input_transcription: None,
71                    output_transcription: None,
72                    url_context_metadata: candidate.url_context_metadata,
73                    model_turn: candidate.content,
74                },
75            },
76        },
77    }))
78}
79
80fn content_as_pure_function_calls(content: &GeminiContent) -> Option<Vec<GeminiFunctionCall>> {
81    let mut calls = Vec::new();
82    for part in &content.parts {
83        let call = part.function_call.clone()?;
84        let has_non_call_fields = part.text.is_some()
85            || part.inline_data.is_some()
86            || part.function_response.is_some()
87            || part.file_data.is_some()
88            || part.executable_code.is_some()
89            || part.code_execution_result.is_some();
90        if has_non_call_fields {
91            return None;
92        }
93        calls.push(call);
94    }
95
96    if calls.is_empty() { None } else { Some(calls) }
97}
98
99fn chunk_to_live_messages(
100    chunk: GeminiGenerateContentResponseBody,
101    ctx: &mut GeminiWebsocketTransformContext,
102) -> Vec<GeminiLiveMessageResponse> {
103    if chunk.prompt_feedback.is_some() {
104        ctx.push_warning("gemini websocket from_http response: dropped promptFeedback".to_string());
105    }
106    if chunk.model_version.is_some() {
107        ctx.push_warning("gemini websocket from_http response: dropped modelVersion".to_string());
108    }
109    if chunk.response_id.is_some() {
110        ctx.push_warning("gemini websocket from_http response: dropped responseId".to_string());
111    }
112    if chunk.model_status.is_some() {
113        ctx.push_warning("gemini websocket from_http response: dropped modelStatus".to_string());
114    }
115
116    let usage_metadata = usage_generate_to_live(chunk.usage_metadata);
117
118    let mut messages = Vec::new();
119    if let Some(candidates) = chunk.candidates {
120        for candidate in candidates {
121            if let Some(message) = candidate_to_server_message(candidate, usage_metadata.clone()) {
122                messages.push(message);
123            }
124        }
125    }
126
127    if messages.is_empty() && usage_metadata.is_some() {
128        messages.push(GeminiLiveMessageResponse::Message(
129            GeminiBidiGenerateContentServerMessage {
130                usage_metadata,
131                message_type: GeminiBidiGenerateContentServerMessageType::ServerContent {
132                    server_content: GeminiBidiGenerateContentServerContent::default(),
133                },
134            },
135        ));
136    }
137
138    messages
139}
140
141impl TryFrom<GeminiStreamGenerateContentResponse> for Vec<GeminiLiveMessageResponse> {
142    type Error = TransformError;
143
144    fn try_from(value: GeminiStreamGenerateContentResponse) -> Result<Self, TransformError> {
145        Ok(gemini_stream_response_to_live_messages_with_context(value)?.0)
146    }
147}
148
149impl TryFrom<GeminiGenerateContentResponse> for Vec<GeminiLiveMessageResponse> {
150    type Error = TransformError;
151
152    fn try_from(value: GeminiGenerateContentResponse) -> Result<Self, TransformError> {
153        Ok(gemini_nonstream_response_to_live_messages_with_context(value)?.0)
154    }
155}
156
157pub fn gemini_nonstream_response_to_live_messages_with_context(
158    value: GeminiGenerateContentResponse,
159) -> Result<
160    (
161        Vec<GeminiLiveMessageResponse>,
162        GeminiWebsocketTransformContext,
163    ),
164    TransformError,
165> {
166    let mut ctx = GeminiWebsocketTransformContext::default();
167    let mut out = Vec::new();
168    match value {
169        GeminiGenerateContentResponse::Success { body, .. } => {
170            out.extend(chunk_to_live_messages(body, &mut ctx));
171        }
172        GeminiGenerateContentResponse::Error { body, .. } => {
173            out.push(GeminiLiveMessageResponse::Error(body));
174        }
175    }
176    Ok((out, ctx))
177}
178
179pub fn gemini_stream_response_to_live_messages_with_context(
180    value: GeminiStreamGenerateContentResponse,
181) -> Result<
182    (
183        Vec<GeminiLiveMessageResponse>,
184        GeminiWebsocketTransformContext,
185    ),
186    TransformError,
187> {
188    let ctx = GeminiWebsocketTransformContext::default();
189    let mut out = Vec::new();
190    match value {
191        GeminiStreamGenerateContentResponse::Success { .. } => {
192            // Stream envelope has no body data; chunks are handled by the
193            // transport layer, so nothing to convert here.
194        }
195        GeminiStreamGenerateContentResponse::Error { body, .. } => {
196            out.push(GeminiLiveMessageResponse::Error(body));
197        }
198    }
199
200    Ok((out, ctx))
201}