Skip to main content

gproxy_protocol/transform/gemini/websocket/from_http/
request.rs

1use crate::gemini::count_tokens::types::{GeminiContent, GeminiPart};
2use crate::gemini::generate_content::request::GeminiGenerateContentRequest;
3use crate::gemini::live::request::GeminiLiveConnectRequest;
4use crate::gemini::live::types::{
5    GeminiBidiGenerateContentClientContent, GeminiBidiGenerateContentClientMessage,
6    GeminiBidiGenerateContentClientMessageType, GeminiBidiGenerateContentSetup,
7    GeminiBidiGenerateContentToolResponse, GeminiFunctionResponse,
8};
9use crate::gemini::stream_generate_content::request::GeminiStreamGenerateContentRequest;
10use crate::transform::gemini::model_get::utils::ensure_models_prefix;
11use crate::transform::gemini::websocket::context::GeminiWebsocketTransformContext;
12use crate::transform::utils::TransformError;
13
14fn setup_message(
15    request: &GeminiStreamGenerateContentRequest,
16) -> GeminiBidiGenerateContentClientMessage {
17    GeminiBidiGenerateContentClientMessage {
18        message_type: GeminiBidiGenerateContentClientMessageType::Setup {
19            setup: GeminiBidiGenerateContentSetup {
20                model: ensure_models_prefix(&request.path.model),
21                generation_config: request.body.generation_config.clone(),
22                system_instruction: request.body.system_instruction.clone(),
23                tools: request.body.tools.clone(),
24                ..GeminiBidiGenerateContentSetup::default()
25            },
26        },
27    }
28}
29
30fn content_message(turns: Vec<GeminiContent>) -> Option<GeminiBidiGenerateContentClientMessage> {
31    if turns.is_empty() {
32        return None;
33    }
34
35    Some(GeminiBidiGenerateContentClientMessage {
36        message_type: GeminiBidiGenerateContentClientMessageType::ClientContent {
37            client_content: GeminiBidiGenerateContentClientContent {
38                turns: Some(turns),
39                turn_complete: Some(true),
40            },
41        },
42    })
43}
44
45fn part_as_pure_function_response(part: &GeminiPart) -> Option<GeminiFunctionResponse> {
46    let function_response = part.function_response.clone()?;
47    let has_non_response_fields = part.text.is_some()
48        || part.inline_data.is_some()
49        || part.function_call.is_some()
50        || part.file_data.is_some()
51        || part.executable_code.is_some()
52        || part.code_execution_result.is_some();
53    if has_non_response_fields {
54        return None;
55    }
56    Some(function_response)
57}
58
59fn split_turns_and_tool_responses(
60    request: &GeminiStreamGenerateContentRequest,
61    _ctx: &mut GeminiWebsocketTransformContext,
62) -> (Vec<GeminiContent>, Vec<GeminiFunctionResponse>) {
63    let mut turns = Vec::new();
64    let mut function_responses = Vec::new();
65
66    for content in &request.body.contents {
67        let extracted = content
68            .parts
69            .iter()
70            .map(part_as_pure_function_response)
71            .collect::<Option<Vec<_>>>();
72        if let Some(responses) = extracted {
73            if responses.is_empty() {
74                turns.push(content.clone());
75            } else {
76                function_responses.extend(responses);
77            }
78        } else {
79            turns.push(content.clone());
80        }
81    }
82
83    (turns, function_responses)
84}
85
86fn tool_response_message(
87    function_responses: Vec<GeminiFunctionResponse>,
88) -> Option<GeminiBidiGenerateContentClientMessage> {
89    if function_responses.is_empty() {
90        return None;
91    }
92
93    Some(GeminiBidiGenerateContentClientMessage {
94        message_type: GeminiBidiGenerateContentClientMessageType::ToolResponse {
95            tool_response: GeminiBidiGenerateContentToolResponse {
96                function_responses: Some(function_responses),
97            },
98        },
99    })
100}
101
102pub fn gemini_stream_request_to_live_frames_with_context(
103    value: &GeminiStreamGenerateContentRequest,
104) -> Result<
105    (
106        Vec<GeminiBidiGenerateContentClientMessage>,
107        GeminiWebsocketTransformContext,
108    ),
109    TransformError,
110> {
111    let mut ctx = GeminiWebsocketTransformContext::default();
112    let mut frames = vec![setup_message(value)];
113    let (turns, function_responses) = split_turns_and_tool_responses(value, &mut ctx);
114    if let Some(content) = content_message(turns) {
115        frames.push(content);
116    }
117    if let Some(tool_response) = tool_response_message(function_responses) {
118        frames.push(tool_response);
119    }
120    Ok((frames, ctx))
121}
122
123pub fn gemini_stream_request_to_live_connect_with_context(
124    value: &GeminiStreamGenerateContentRequest,
125) -> Result<(GeminiLiveConnectRequest, GeminiWebsocketTransformContext), TransformError> {
126    Ok((
127        GeminiLiveConnectRequest {
128            body: Some(setup_message(value)),
129            ..GeminiLiveConnectRequest::default()
130        },
131        GeminiWebsocketTransformContext::default(),
132    ))
133}
134
135pub fn gemini_nonstream_request_to_live_frames_with_context(
136    value: &GeminiGenerateContentRequest,
137) -> Result<
138    (
139        Vec<GeminiBidiGenerateContentClientMessage>,
140        GeminiWebsocketTransformContext,
141    ),
142    TransformError,
143> {
144    let stream_request = GeminiStreamGenerateContentRequest::try_from(value)?;
145    gemini_stream_request_to_live_frames_with_context(&stream_request)
146}
147
148pub fn gemini_nonstream_request_to_live_connect_with_context(
149    value: &GeminiGenerateContentRequest,
150) -> Result<(GeminiLiveConnectRequest, GeminiWebsocketTransformContext), TransformError> {
151    let stream_request = GeminiStreamGenerateContentRequest::try_from(value)?;
152    gemini_stream_request_to_live_connect_with_context(&stream_request)
153}
154
155impl TryFrom<&GeminiStreamGenerateContentRequest> for Vec<GeminiBidiGenerateContentClientMessage> {
156    type Error = TransformError;
157
158    fn try_from(value: &GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
159        Ok(gemini_stream_request_to_live_frames_with_context(value)?.0)
160    }
161}
162
163impl TryFrom<&GeminiStreamGenerateContentRequest> for GeminiLiveConnectRequest {
164    type Error = TransformError;
165
166    fn try_from(value: &GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
167        Ok(gemini_stream_request_to_live_connect_with_context(value)?.0)
168    }
169}
170
171impl TryFrom<GeminiStreamGenerateContentRequest> for GeminiLiveConnectRequest {
172    type Error = TransformError;
173
174    fn try_from(value: GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
175        GeminiLiveConnectRequest::try_from(&value)
176    }
177}
178
179impl TryFrom<&GeminiGenerateContentRequest> for Vec<GeminiBidiGenerateContentClientMessage> {
180    type Error = TransformError;
181
182    fn try_from(value: &GeminiGenerateContentRequest) -> Result<Self, TransformError> {
183        Ok(gemini_nonstream_request_to_live_frames_with_context(value)?.0)
184    }
185}
186
187impl TryFrom<&GeminiGenerateContentRequest> for GeminiLiveConnectRequest {
188    type Error = TransformError;
189
190    fn try_from(value: &GeminiGenerateContentRequest) -> Result<Self, TransformError> {
191        Ok(gemini_nonstream_request_to_live_connect_with_context(value)?.0)
192    }
193}
194
195impl TryFrom<GeminiGenerateContentRequest> for GeminiLiveConnectRequest {
196    type Error = TransformError;
197
198    fn try_from(value: GeminiGenerateContentRequest) -> Result<Self, TransformError> {
199        GeminiLiveConnectRequest::try_from(&value)
200    }
201}