Skip to main content

gproxy_protocol/transform/gemini/websocket/to_http/
request.rs

1use crate::gemini::count_tokens::types::{GeminiContent, GeminiContentRole, GeminiPart};
2use crate::gemini::generate_content::request::GeminiGenerateContentRequest;
3use crate::gemini::live::request::GeminiLiveConnectRequest;
4use crate::gemini::live::types::{
5    GeminiBidiGenerateContentClientMessage, GeminiBidiGenerateContentClientMessageType,
6};
7use crate::gemini::stream_generate_content::request::{
8    AltQueryParameter, GeminiStreamGenerateContentRequest, PathParameters, QueryParameters,
9    RequestBody,
10};
11use crate::transform::gemini::model_get::utils::ensure_models_prefix;
12use crate::transform::gemini::websocket::context::GeminiWebsocketTransformContext;
13use crate::transform::utils::TransformError;
14
/// Placeholder model path used when the input provides no setup model or no initial body.
const FALLBACK_MODEL: &str = "models/unknown";
/// Warning text recorded when the converted frames contain no setup frame with a model.
const MISSING_SETUP_MODEL: &str =
    "cannot convert Gemini websocket frames to streamGenerateContent request without setup model";
/// Warning text recorded for a `realtimeInput` frame, which is not converted.
const UNSUPPORTED_REALTIME_INPUT: &str =
    "cannot convert Gemini realtimeInput websocket frame to streamGenerateContent request";
20
21pub fn gemini_live_client_messages_to_stream_request_with_context(
22    value: &[GeminiBidiGenerateContentClientMessage],
23) -> Result<
24    (
25        GeminiStreamGenerateContentRequest,
26        GeminiWebsocketTransformContext,
27    ),
28    TransformError,
29> {
30    let mut ctx = GeminiWebsocketTransformContext::default();
31    let mut model = None::<String>;
32    let mut generation_config = None;
33    let mut system_instruction = None;
34    let mut tools = None;
35    let mut contents = Vec::<GeminiContent>::new();
36
37    for message in value {
38        match &message.message_type {
39            GeminiBidiGenerateContentClientMessageType::Setup { setup } => {
40                model = Some(ensure_models_prefix(&setup.model));
41                generation_config = setup.generation_config.clone();
42                system_instruction = setup.system_instruction.clone();
43                tools = setup.tools.clone();
44                if let Some(prefix_turns) = &setup.prefix_turns {
45                    contents.extend(prefix_turns.clone());
46                }
47            }
48            GeminiBidiGenerateContentClientMessageType::ClientContent { client_content } => {
49                if let Some(turns) = &client_content.turns {
50                    contents.extend(turns.clone());
51                }
52            }
53            GeminiBidiGenerateContentClientMessageType::ToolResponse { tool_response } => {
54                if let Some(function_responses) = &tool_response.function_responses {
55                    let parts = function_responses
56                        .iter()
57                        .cloned()
58                        .map(|response| GeminiPart {
59                            function_response: Some(response),
60                            ..GeminiPart::default()
61                        })
62                        .collect::<Vec<_>>();
63                    if !parts.is_empty() {
64                        contents.push(GeminiContent {
65                            parts,
66                            role: Some(GeminiContentRole::User),
67                        });
68                    }
69                }
70            }
71            GeminiBidiGenerateContentClientMessageType::RealtimeInput { .. } => {
72                ctx.push_warning(UNSUPPORTED_REALTIME_INPUT.to_string());
73            }
74        }
75    }
76
77    let model = model.unwrap_or_else(|| {
78        ctx.push_warning(MISSING_SETUP_MODEL.to_string());
79        FALLBACK_MODEL.to_string()
80    });
81
82    Ok((
83        GeminiStreamGenerateContentRequest {
84            path: PathParameters { model },
85            query: QueryParameters {
86                alt: Some(AltQueryParameter::Sse),
87            },
88            body: RequestBody {
89                contents,
90                tools,
91                tool_config: None,
92                safety_settings: None,
93                system_instruction,
94                generation_config,
95                cached_content: None,
96                store: None,
97            },
98            ..GeminiStreamGenerateContentRequest::default()
99        },
100        ctx,
101    ))
102}
103
104pub fn gemini_live_connect_to_stream_request_with_context(
105    value: &GeminiLiveConnectRequest,
106) -> Result<
107    (
108        GeminiStreamGenerateContentRequest,
109        GeminiWebsocketTransformContext,
110    ),
111    TransformError,
112> {
113    let Some(frame) = value.body.as_ref() else {
114        let mut ctx = GeminiWebsocketTransformContext::default();
115        ctx.push_warning(
116            "cannot convert Gemini live connect request without initial body; downgraded to empty streamGenerateContent request"
117                .to_string(),
118        );
119        return Ok((
120            GeminiStreamGenerateContentRequest {
121                path: PathParameters {
122                    model: FALLBACK_MODEL.to_string(),
123                },
124                query: QueryParameters {
125                    alt: Some(AltQueryParameter::Sse),
126                },
127                ..GeminiStreamGenerateContentRequest::default()
128            },
129            ctx,
130        ));
131    };
132    gemini_live_client_messages_to_stream_request_with_context(std::slice::from_ref(frame))
133}
134
135pub fn gemini_live_client_messages_to_nonstream_request_with_context(
136    value: &[GeminiBidiGenerateContentClientMessage],
137) -> Result<
138    (
139        GeminiGenerateContentRequest,
140        GeminiWebsocketTransformContext,
141    ),
142    TransformError,
143> {
144    let (stream_request, ctx) = gemini_live_client_messages_to_stream_request_with_context(value)?;
145    let request = GeminiGenerateContentRequest::try_from(stream_request)?;
146    Ok((request, ctx))
147}
148
149pub fn gemini_live_connect_to_nonstream_request_with_context(
150    value: &GeminiLiveConnectRequest,
151) -> Result<
152    (
153        GeminiGenerateContentRequest,
154        GeminiWebsocketTransformContext,
155    ),
156    TransformError,
157> {
158    let (stream_request, ctx) = gemini_live_connect_to_stream_request_with_context(value)?;
159    let request = GeminiGenerateContentRequest::try_from(stream_request)?;
160    Ok((request, ctx))
161}
162
163impl TryFrom<&GeminiBidiGenerateContentClientMessage> for GeminiStreamGenerateContentRequest {
164    type Error = TransformError;
165
166    fn try_from(value: &GeminiBidiGenerateContentClientMessage) -> Result<Self, TransformError> {
167        Ok(
168            gemini_live_client_messages_to_stream_request_with_context(std::slice::from_ref(
169                value,
170            ))?
171            .0,
172        )
173    }
174}
175
176impl TryFrom<&[GeminiBidiGenerateContentClientMessage]> for GeminiStreamGenerateContentRequest {
177    type Error = TransformError;
178
179    fn try_from(value: &[GeminiBidiGenerateContentClientMessage]) -> Result<Self, TransformError> {
180        Ok(gemini_live_client_messages_to_stream_request_with_context(value)?.0)
181    }
182}
183
184impl TryFrom<&GeminiLiveConnectRequest> for GeminiStreamGenerateContentRequest {
185    type Error = TransformError;
186
187    fn try_from(value: &GeminiLiveConnectRequest) -> Result<Self, TransformError> {
188        Ok(gemini_live_connect_to_stream_request_with_context(value)?.0)
189    }
190}
191
192impl TryFrom<GeminiLiveConnectRequest> for GeminiStreamGenerateContentRequest {
193    type Error = TransformError;
194
195    fn try_from(value: GeminiLiveConnectRequest) -> Result<Self, TransformError> {
196        GeminiStreamGenerateContentRequest::try_from(&value)
197    }
198}
199
200impl TryFrom<&GeminiBidiGenerateContentClientMessage> for GeminiGenerateContentRequest {
201    type Error = TransformError;
202
203    fn try_from(value: &GeminiBidiGenerateContentClientMessage) -> Result<Self, TransformError> {
204        Ok(
205            gemini_live_client_messages_to_nonstream_request_with_context(std::slice::from_ref(
206                value,
207            ))?
208            .0,
209        )
210    }
211}
212
213impl TryFrom<&[GeminiBidiGenerateContentClientMessage]> for GeminiGenerateContentRequest {
214    type Error = TransformError;
215
216    fn try_from(value: &[GeminiBidiGenerateContentClientMessage]) -> Result<Self, TransformError> {
217        Ok(gemini_live_client_messages_to_nonstream_request_with_context(value)?.0)
218    }
219}
220
221impl TryFrom<&GeminiLiveConnectRequest> for GeminiGenerateContentRequest {
222    type Error = TransformError;
223
224    fn try_from(value: &GeminiLiveConnectRequest) -> Result<Self, TransformError> {
225        Ok(gemini_live_connect_to_nonstream_request_with_context(value)?.0)
226    }
227}
228
229impl TryFrom<GeminiLiveConnectRequest> for GeminiGenerateContentRequest {
230    type Error = TransformError;
231
232    fn try_from(value: GeminiLiveConnectRequest) -> Result<Self, TransformError> {
233        GeminiGenerateContentRequest::try_from(&value)
234    }
235}