1use std::time::Duration;
15
16use base64::Engine;
17
18use crate::error::CodecError;
19use crate::types::{ClientMessage, ServerEvent, ServerMessage};
20
21pub fn encode(msg: &ClientMessage) -> Result<String, CodecError> {
27 serde_json::to_string(msg).map_err(CodecError::Serialize)
28}
29
30pub fn decode(json: &str) -> Result<ServerMessage, CodecError> {
32 serde_json::from_str(json).map_err(CodecError::Deserialize)
33}
34
35pub fn into_events(msg: ServerMessage) -> Vec<ServerEvent> {
54 let mut events = Vec::new();
55
56 if msg.setup_complete.is_some() {
58 events.push(ServerEvent::SetupComplete);
59 }
60
61 if let Some(sc) = msg.server_content {
63 if let Some(t) = sc.input_transcription {
64 events.push(ServerEvent::InputTranscription(t.text));
65 }
66 if let Some(t) = sc.output_transcription {
67 events.push(ServerEvent::OutputTranscription(t.text));
68 }
69
70 if let Some(turn) = sc.model_turn {
71 for part in turn.parts {
72 if let Some(text) = part.text {
73 events.push(ServerEvent::ModelText(text));
74 }
75 if let Some(blob) = part.inline_data {
76 match base64::engine::general_purpose::STANDARD.decode(&blob.data) {
77 Ok(bytes) => events.push(ServerEvent::ModelAudio(bytes)),
78 Err(e) => {
79 tracing::warn!(error = %e, "failed to base64-decode model audio");
80 }
81 }
82 }
83 }
84 }
85
86 if sc.interrupted == Some(true) {
87 events.push(ServerEvent::Interrupted);
88 }
89 if sc.generation_complete == Some(true) {
90 events.push(ServerEvent::GenerationComplete);
91 }
92 if sc.turn_complete == Some(true) {
93 events.push(ServerEvent::TurnComplete);
94 }
95 }
96
97 if let Some(tc) = msg.tool_call {
99 events.push(ServerEvent::ToolCall(tc.function_calls));
100 }
101 if let Some(tcc) = msg.tool_call_cancellation {
102 events.push(ServerEvent::ToolCallCancellation(tcc.ids));
103 }
104
105 if let Some(sr) = msg.session_resumption_update {
107 events.push(ServerEvent::SessionResumption {
108 new_handle: sr.new_handle,
109 resumable: sr.resumable.unwrap_or(false),
110 });
111 }
112 if let Some(ga) = msg.go_away {
113 events.push(ServerEvent::GoAway {
114 time_left: ga.time_left.as_deref().and_then(parse_protobuf_duration),
115 });
116 }
117
118 if let Some(usage) = msg.usage_metadata {
120 events.push(ServerEvent::Usage(usage));
121 }
122
123 if let Some(err) = msg.error {
125 events.push(ServerEvent::Error(err));
126 }
127
128 events
129}
130
/// Parses a duration written as decimal seconds followed by `s`
/// (for example `"30s"` or `"1.5s"`).
///
/// Surrounding whitespace is ignored. Returns `None` when the `s` suffix is
/// missing, when the numeric part does not parse as a float, or when the
/// parsed value cannot be represented as a [`Duration`] (negative, NaN,
/// infinite, or too large).
fn parse_protobuf_duration(s: &str) -> Option<Duration> {
    let s = s.trim();
    let secs_str = s.strip_suffix('s')?;
    let secs: f64 = secs_str.parse().ok()?;
    // `f64::from_str` accepts inputs such as "-1", "inf" and "NaN", all of
    // which make `Duration::from_secs_f64` panic. The fallible constructor
    // turns those into `None` so malformed input cannot crash the caller.
    Duration::try_from_secs_f64(secs).ok()
}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// Encodes `msg` and re-parses the output as a generic JSON value so tests
    /// can assert on individual fields.
    fn encoded_value(msg: &ClientMessage) -> serde_json::Value {
        let text = encode(msg).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    /// Decodes `raw` and expands the resulting message into events.
    fn events_from(raw: &str) -> Vec<ServerEvent> {
        into_events(decode(raw).unwrap())
    }

    // ---- encode ----

    #[test]
    fn encode_setup_minimal() {
        let model = "models/gemini-3.1-flash-live-preview";
        let value = encoded_value(&ClientMessage::Setup(SetupConfig {
            model: model.into(),
            ..Default::default()
        }));

        assert_eq!(value["setup"]["model"], model);
        // Unset optional sections must be omitted entirely.
        assert!(value["setup"].get("generationConfig").is_none());
    }

    #[test]
    fn encode_setup_full() {
        let generation_config = GenerationConfig {
            response_modalities: Some(vec![Modality::Audio, Modality::Text]),
            speech_config: Some(SpeechConfig {
                voice_config: VoiceConfig {
                    prebuilt_voice_config: PrebuiltVoiceConfig {
                        voice_name: "Kore".into(),
                    },
                },
            }),
            thinking_config: Some(ThinkingConfig {
                thinking_level: Some(ThinkingLevel::Medium),
                ..Default::default()
            }),
            ..Default::default()
        };
        let system_instruction = Content {
            role: None,
            parts: vec![Part {
                text: Some("You are a helpful assistant.".into()),
                inline_data: None,
            }],
        };

        let value = encoded_value(&ClientMessage::Setup(SetupConfig {
            model: "models/gemini-3.1-flash-live-preview".into(),
            generation_config: Some(generation_config),
            system_instruction: Some(system_instruction),
            input_audio_transcription: Some(AudioTranscriptionConfig {}),
            output_audio_transcription: Some(AudioTranscriptionConfig {}),
            session_resumption: Some(SessionResumptionConfig { handle: None }),
            ..Default::default()
        }));

        let setup = &value["setup"];
        let gen_cfg = &setup["generationConfig"];
        assert_eq!(gen_cfg["responseModalities"][0], "AUDIO");
        assert_eq!(gen_cfg["responseModalities"][1], "TEXT");
        assert_eq!(
            gen_cfg["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"],
            "Kore"
        );
        assert_eq!(gen_cfg["thinkingConfig"]["thinkingLevel"], "medium");
        assert_eq!(
            setup["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        // Empty config structs serialize as empty JSON objects.
        assert_eq!(setup["inputAudioTranscription"], serde_json::json!({}));
        assert_eq!(setup["outputAudioTranscription"], serde_json::json!({}));
    }

    #[test]
    fn encode_client_content() {
        let turn = |role: &str, text: &str| Content {
            role: Some(role.into()),
            parts: vec![Part {
                text: Some(text.into()),
                inline_data: None,
            }],
        };

        let value = encoded_value(&ClientMessage::ClientContent(ClientContent {
            turns: Some(vec![turn("user", "Hello"), turn("model", "Hi!")]),
            turn_complete: Some(true),
        }));

        let content = &value["clientContent"];
        assert_eq!(content["turns"][0]["role"], "user");
        assert_eq!(content["turnComplete"], true);
    }

    #[test]
    fn encode_realtime_input_audio() {
        let audio = Blob {
            data: "AAAA".into(),
            mime_type: "audio/pcm;rate=16000".into(),
        };

        let value = encoded_value(&ClientMessage::RealtimeInput(RealtimeInput {
            audio: Some(audio),
            video: None,
            text: None,
            activity_start: None,
            activity_end: None,
            audio_stream_end: None,
        }));

        let input = &value["realtimeInput"];
        assert_eq!(input["audio"]["mimeType"], "audio/pcm;rate=16000");
        // `None` fields must not appear in the output.
        assert!(input.get("video").is_none());
    }

    #[test]
    fn encode_tool_response() {
        let reply = FunctionResponse {
            id: "call_123".into(),
            name: "get_weather".into(),
            response: serde_json::json!({"temperature": 72}),
        };

        let value = encoded_value(&ClientMessage::ToolResponse(ToolResponseMessage {
            function_responses: vec![reply],
        }));

        let first = &value["toolResponse"]["functionResponses"][0];
        assert_eq!(first["id"], "call_123");
        assert_eq!(first["response"]["temperature"], 72);
    }

    // ---- decode ----

    #[test]
    fn decode_setup_complete() {
        let message = decode(r#"{"setupComplete":{}}"#).unwrap();
        assert!(message.setup_complete.is_some());
        assert!(message.server_content.is_none());
    }

    #[test]
    fn decode_server_content_text() {
        let raw = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello there!"}]
                },
                "turnComplete": true
            }
        }"#;

        let content = decode(raw).unwrap().server_content.unwrap();
        let model_turn = content.model_turn.unwrap();
        assert_eq!(model_turn.parts[0].text.as_deref(), Some("Hello there!"));
        assert_eq!(content.turn_complete, Some(true));
    }

    #[test]
    fn decode_server_content_with_transcription() {
        let raw = r#"{
            "serverContent": {
                "inputTranscription": {"text": "What's the weather?"},
                "outputTranscription": {"text": "It's sunny today."}
            }
        }"#;

        let content = decode(raw).unwrap().server_content.unwrap();
        assert_eq!(
            content.input_transcription.unwrap().text,
            "What's the weather?"
        );
        assert_eq!(
            content.output_transcription.unwrap().text,
            "It's sunny today."
        );
    }

    #[test]
    fn decode_tool_call() {
        let raw = r#"{
            "toolCall": {
                "functionCalls": [{
                    "id": "call_abc",
                    "name": "get_weather",
                    "args": {"city": "Tokyo"}
                }]
            }
        }"#;

        let tool_call = decode(raw).unwrap().tool_call.unwrap();
        let call = &tool_call.function_calls[0];
        assert_eq!(call.id, "call_abc");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], "Tokyo");
    }

    #[test]
    fn decode_tool_call_cancellation() {
        let message = decode(r#"{"toolCallCancellation":{"ids":["call_1","call_2"]}}"#).unwrap();
        let cancellation = message.tool_call_cancellation.unwrap();
        assert_eq!(cancellation.ids, vec!["call_1", "call_2"]);
    }

    #[test]
    fn decode_go_away() {
        let message = decode(r#"{"goAway":{"timeLeft":"30s"}}"#).unwrap();
        let notice = message.go_away.unwrap();
        // `decode` keeps the raw string; parsing happens in `into_events`.
        assert_eq!(notice.time_left.as_deref(), Some("30s"));
    }

    #[test]
    fn decode_session_resumption() {
        let raw = r#"{"sessionResumptionUpdate":{"newHandle":"handle-xyz","resumable":true}}"#;
        let update = decode(raw).unwrap().session_resumption_update.unwrap();
        assert_eq!(update.new_handle.as_deref(), Some("handle-xyz"));
        assert_eq!(update.resumable, Some(true));
    }

    #[test]
    fn decode_usage_metadata() {
        let raw = r#"{
            "usageMetadata": {
                "promptTokenCount": 100,
                "responseTokenCount": 50,
                "totalTokenCount": 150
            }
        }"#;

        let usage = decode(raw).unwrap().usage_metadata.unwrap();
        assert_eq!(usage.prompt_token_count, 100);
        assert_eq!(usage.response_token_count, 50);
        assert_eq!(usage.total_token_count, 150);
        // A counter absent from the payload reads back as zero.
        assert_eq!(usage.cached_content_token_count, 0);
    }

    #[test]
    fn decode_error() {
        let message = decode(r#"{"error":{"message":"rate limit exceeded"}}"#).unwrap();
        assert_eq!(message.error.unwrap().message, "rate limit exceeded");
    }

    #[test]
    fn decode_combined_content_and_usage() {
        let raw = r#"{
            "serverContent": {
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 42}
        }"#;

        let message = decode(raw).unwrap();
        assert!(message.server_content.is_some());
        assert_eq!(message.usage_metadata.unwrap().total_token_count, 42);
    }

    // ---- into_events ----

    #[test]
    fn into_events_setup_complete() {
        let events = events_from(r#"{"setupComplete":{}}"#);
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], ServerEvent::SetupComplete));
    }

    #[test]
    fn into_events_model_text_and_turn_complete() {
        let events = events_from(
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hello"}]},"turnComplete":true}}"#,
        );

        let has_text = events
            .iter()
            .any(|ev| matches!(ev, ServerEvent::ModelText(t) if t == "hello"));
        let has_turn_complete = events
            .iter()
            .any(|ev| matches!(ev, ServerEvent::TurnComplete));
        assert!(has_text);
        assert!(has_turn_complete);
    }

    #[test]
    fn into_events_model_audio_base64_decoded() {
        // "AQID" is the base64 encoding of the bytes 0x01 0x02 0x03.
        let events = events_from(
            r#"{"serverContent":{"modelTurn":{"parts":[{"inlineData":{"data":"AQID","mimeType":"audio/pcm;rate=24000"}}]}}}"#,
        );

        let has_audio = events
            .iter()
            .any(|ev| matches!(ev, ServerEvent::ModelAudio(b) if b == &[1, 2, 3]));
        assert!(has_audio);
    }

    #[test]
    fn into_events_go_away_parses_duration() {
        let events = events_from(r#"{"goAway":{"timeLeft":"30s"}}"#);

        let expected = Duration::from_secs(30);
        let has_go_away = events.iter().any(
            |ev| matches!(ev, ServerEvent::GoAway { time_left: Some(d) } if *d == expected),
        );
        assert!(has_go_away);
    }

    #[test]
    fn into_events_combined_message() {
        let raw = r#"{
            "serverContent": {
                "inputTranscription": {"text": "hey"},
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 10}
        }"#;

        let events = events_from(raw);
        assert_eq!(events.len(), 4);
        // Order matters: transcription, model text, turn completion, then usage.
        assert!(matches!(
            events.as_slice(),
            [
                ServerEvent::InputTranscription(heard),
                ServerEvent::ModelText(said),
                ServerEvent::TurnComplete,
                ServerEvent::Usage(_),
            ] if heard == "hey" && said == "hi"
        ));
    }

    // ---- parse_protobuf_duration ----

    #[test]
    fn parse_duration_integer_seconds() {
        let parsed = parse_protobuf_duration("30s");
        assert_eq!(parsed, Some(Duration::from_secs(30)));
    }

    #[test]
    fn parse_duration_fractional_seconds() {
        let parsed = parse_protobuf_duration("1.5s");
        assert_eq!(parsed, Some(Duration::from_secs_f64(1.5)));
    }

    #[test]
    fn parse_duration_invalid() {
        for raw in ["30m", "abc"] {
            assert_eq!(parse_protobuf_duration(raw), None);
        }
    }
}