1use std::time::Duration;
15
16use base64::Engine;
17
18use crate::error::CodecError;
19use crate::types::{ClientMessage, ServerEvent, ServerMessage};
20
21pub fn encode(msg: &ClientMessage) -> Result<String, CodecError> {
27 serde_json::to_string(msg).map_err(CodecError::Serialize)
28}
29
30pub fn decode(json: &str) -> Result<ServerMessage, CodecError> {
32 serde_json::from_str(json).map_err(CodecError::Deserialize)
33}
34
35pub fn into_events(msg: ServerMessage) -> Vec<ServerEvent> {
54 let mut events = Vec::new();
55
56 if msg.setup_complete.is_some() {
58 events.push(ServerEvent::SetupComplete);
59 }
60
61 if let Some(sc) = msg.server_content {
63 if let Some(t) = sc.input_transcription
64 && let Some(text) = t.text
65 {
66 events.push(ServerEvent::InputTranscription(text));
67 }
68 if let Some(t) = sc.output_transcription
69 && let Some(text) = t.text
70 {
71 events.push(ServerEvent::OutputTranscription(text));
72 }
73
74 if let Some(turn) = sc.model_turn {
75 for part in turn.parts {
76 if let Some(text) = part.text {
77 events.push(ServerEvent::ModelText(text));
78 }
79 if let Some(blob) = part.inline_data {
80 match base64::engine::general_purpose::STANDARD.decode(&blob.data) {
81 Ok(bytes) => events.push(ServerEvent::ModelAudio(bytes)),
82 Err(e) => {
83 tracing::warn!(error = %e, "failed to base64-decode model audio");
84 }
85 }
86 }
87 }
88 }
89
90 if sc.interrupted == Some(true) {
91 events.push(ServerEvent::Interrupted);
92 }
93 if sc.generation_complete == Some(true) {
94 events.push(ServerEvent::GenerationComplete);
95 }
96 if sc.turn_complete == Some(true) {
97 events.push(ServerEvent::TurnComplete);
98 }
99 }
100
101 if let Some(tc) = msg.tool_call {
103 events.push(ServerEvent::ToolCall(tc.function_calls));
104 }
105 if let Some(tcc) = msg.tool_call_cancellation {
106 events.push(ServerEvent::ToolCallCancellation(tcc.ids));
107 }
108
109 if let Some(sr) = msg.session_resumption_update {
111 events.push(ServerEvent::SessionResumption {
112 new_handle: sr.new_handle,
113 resumable: sr.resumable.unwrap_or(false),
114 });
115 }
116 if let Some(ga) = msg.go_away {
117 events.push(ServerEvent::GoAway {
118 time_left: ga.time_left.as_deref().and_then(parse_protobuf_duration),
119 });
120 }
121
122 if let Some(usage) = msg.usage_metadata {
124 events.push(ServerEvent::Usage(usage));
125 }
126
127 if let Some(err) = msg.error {
129 events.push(ServerEvent::Error(err));
130 }
131
132 events
133}
134
/// Parses a protobuf-JSON duration string such as `"30s"` or `"1.5s"`.
///
/// Surrounding whitespace is ignored. The remaining text must be a number
/// immediately followed by a trailing `s`.
///
/// Returns `None` when the `s` suffix is missing, when the number does not
/// parse, or when the number cannot be represented as a [`Duration`]
/// (negative, NaN, infinite, or too large). This function never panics.
fn parse_protobuf_duration(s: &str) -> Option<Duration> {
    let secs: f64 = s.trim().strip_suffix('s')?.parse().ok()?;
    // `f64::from_str` accepts inputs like "-1", "NaN", "inf" and "1e300", and
    // `Duration::from_secs_f64` panics on all of them. The value comes off the
    // wire, so use the fallible conversion and report `None` instead.
    Duration::try_from_secs_f64(secs).ok()
}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// Model identifier shared by the setup-encoding tests.
    const MODEL: &str = "models/gemini-3.1-flash-live-preview";

    /// Encodes `msg` and parses the output back into a generic JSON value so
    /// individual fields can be asserted on.
    fn encoded_value(msg: &ClientMessage) -> serde_json::Value {
        let wire = encode(msg).unwrap();
        serde_json::from_str(&wire).unwrap()
    }

    /// Decodes `wire` and flattens the resulting message into events.
    fn events_for(wire: &str) -> Vec<ServerEvent> {
        into_events(decode(wire).unwrap())
    }

    /// Builds a single-part text `Content` with the given optional role.
    fn text_content(role: Option<&str>, text: &str) -> Content {
        Content {
            role: role.map(Into::into),
            parts: vec![Part {
                text: Some(text.into()),
                inline_data: None,
            }],
        }
    }

    // ---- encode -----------------------------------------------------------

    #[test]
    fn encode_setup_minimal() {
        let setup = SetupConfig {
            model: MODEL.into(),
            ..Default::default()
        };

        let value = encoded_value(&ClientMessage::Setup(setup));

        assert_eq!(value["setup"]["model"], MODEL);
        assert!(value["setup"].get("generationConfig").is_none());
    }

    #[test]
    fn encode_setup_full() {
        let speech = SpeechConfig {
            voice_config: VoiceConfig {
                prebuilt_voice_config: PrebuiltVoiceConfig {
                    voice_name: "Kore".into(),
                },
            },
        };
        let thinking = ThinkingConfig {
            thinking_level: Some(ThinkingLevel::Medium),
            ..Default::default()
        };
        let generation = GenerationConfig {
            response_modalities: Some(vec![Modality::Audio, Modality::Text]),
            speech_config: Some(speech),
            thinking_config: Some(thinking),
            ..Default::default()
        };
        let compression = ContextWindowCompressionConfig {
            sliding_window: Some(SlidingWindow::default()),
            trigger_tokens: None,
        };
        let setup = SetupConfig {
            model: MODEL.into(),
            generation_config: Some(generation),
            system_instruction: Some(text_content(None, "You are a helpful assistant.")),
            input_audio_transcription: Some(AudioTranscriptionConfig {}),
            output_audio_transcription: Some(AudioTranscriptionConfig {}),
            session_resumption: Some(SessionResumptionConfig { handle: None }),
            context_window_compression: Some(compression),
            ..Default::default()
        };

        let value = encoded_value(&ClientMessage::Setup(setup));
        let setup_json = &value["setup"];
        let generation_json = &setup_json["generationConfig"];

        assert_eq!(generation_json["responseModalities"][0], "AUDIO");
        assert_eq!(generation_json["responseModalities"][1], "TEXT");
        assert_eq!(
            generation_json["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"],
            "Kore"
        );
        assert_eq!(generation_json["thinkingConfig"]["thinkingLevel"], "medium");
        assert_eq!(
            setup_json["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        assert_eq!(setup_json["inputAudioTranscription"], serde_json::json!({}));
        assert_eq!(setup_json["outputAudioTranscription"], serde_json::json!({}));
        assert_eq!(
            setup_json["contextWindowCompression"],
            serde_json::json!({ "slidingWindow": {} })
        );
    }

    #[test]
    fn encode_setup_with_builtin_and_function_tools() {
        let read_file = FunctionDeclaration {
            name: "read_file".into(),
            description: "Read a file from the workspace.".into(),
            parameters: serde_json::json!({
                "type": "object",
                "required": ["path"],
                "properties": {
                    "path": { "type": "string" }
                }
            }),
            scheduling: None,
            behavior: None,
        };
        let setup = SetupConfig {
            model: MODEL.into(),
            tools: Some(vec![
                Tool::GoogleSearch(GoogleSearchTool {}),
                Tool::FunctionDeclarations(vec![read_file]),
            ]),
            ..Default::default()
        };

        let value = encoded_value(&ClientMessage::Setup(setup));
        let tools = value["setup"]["tools"].as_array().expect("tools array");

        assert_eq!(tools[0]["googleSearch"], serde_json::json!({}));
        assert_eq!(tools[1]["functionDeclarations"][0]["name"], "read_file");
    }

    #[test]
    fn encode_client_content() {
        let content = ClientContent {
            turns: Some(vec![
                text_content(Some("user"), "Hello"),
                text_content(Some("model"), "Hi!"),
            ]),
            turn_complete: Some(true),
        };

        let value = encoded_value(&ClientMessage::ClientContent(content));
        let client_content = &value["clientContent"];

        assert_eq!(client_content["turns"][0]["role"], "user");
        assert_eq!(client_content["turnComplete"], true);
    }

    #[test]
    fn encode_realtime_input_audio() {
        let audio = Blob {
            data: "AAAA".into(),
            mime_type: "audio/pcm;rate=16000".into(),
        };
        let input = RealtimeInput {
            audio: Some(audio),
            video: None,
            text: None,
            activity_start: None,
            activity_end: None,
            audio_stream_end: None,
        };

        let value = encoded_value(&ClientMessage::RealtimeInput(input));
        let realtime = &value["realtimeInput"];

        assert_eq!(realtime["audio"]["mimeType"], "audio/pcm;rate=16000");
        assert!(realtime.get("video").is_none());
    }

    #[test]
    fn encode_tool_response() {
        let weather = FunctionResponse {
            id: "call_123".into(),
            name: "get_weather".into(),
            response: serde_json::json!({"temperature": 72}),
        };
        let message = ClientMessage::ToolResponse(ToolResponseMessage {
            function_responses: vec![weather],
        });

        let value = encoded_value(&message);
        let first = &value["toolResponse"]["functionResponses"][0];

        assert_eq!(first["id"], "call_123");
        assert_eq!(first["response"]["temperature"], 72);
    }

    // ---- decode -----------------------------------------------------------

    #[test]
    fn decode_setup_complete() {
        let message = decode(r#"{"setupComplete":{}}"#).unwrap();

        assert!(message.setup_complete.is_some());
        assert!(message.server_content.is_none());
    }

    #[test]
    fn decode_server_content_text() {
        let wire = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello there!"}]
                },
                "turnComplete": true
            }
        }"#;

        let content = decode(wire).unwrap().server_content.unwrap();

        assert_eq!(content.turn_complete, Some(true));
        let first_part = content
            .model_turn
            .unwrap()
            .parts
            .into_iter()
            .next()
            .unwrap();
        assert_eq!(first_part.text.as_deref(), Some("Hello there!"));
    }

    #[test]
    fn decode_server_content_with_transcription() {
        let wire = r#"{
            "serverContent": {
                "inputTranscription": {"text": "What's the weather?"},
                "outputTranscription": {"text": "It's sunny today."}
            }
        }"#;

        let content = decode(wire).unwrap().server_content.unwrap();
        let heard = content.input_transcription.unwrap().text;
        let spoken = content.output_transcription.unwrap().text;

        assert_eq!(heard.as_deref(), Some("What's the weather?"));
        assert_eq!(spoken.as_deref(), Some("It's sunny today."));
    }

    #[test]
    fn decode_transcription_finished_without_text() {
        let wire = r#"{
            "serverContent": {
                "outputTranscription": {"finished": true}
            }
        }"#;

        let marker = decode(wire)
            .unwrap()
            .server_content
            .unwrap()
            .output_transcription
            .unwrap();

        assert!(marker.text.is_none());
        assert_eq!(marker.finished, Some(true));
    }

    #[test]
    fn decode_tool_call() {
        let wire = r#"{
            "toolCall": {
                "functionCalls": [{
                    "id": "call_abc",
                    "name": "get_weather",
                    "args": {"city": "Tokyo"}
                }]
            }
        }"#;

        let calls = decode(wire).unwrap().tool_call.unwrap().function_calls;
        let call = &calls[0];

        assert_eq!(call.id, "call_abc");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], "Tokyo");
    }

    #[test]
    fn decode_tool_call_cancellation() {
        let message = decode(r#"{"toolCallCancellation":{"ids":["call_1","call_2"]}}"#).unwrap();

        let cancelled = message.tool_call_cancellation.unwrap().ids;

        assert_eq!(cancelled, vec!["call_1", "call_2"]);
    }

    #[test]
    fn decode_go_away() {
        let message = decode(r#"{"goAway":{"timeLeft":"30s"}}"#).unwrap();

        let time_left = message.go_away.unwrap().time_left;

        assert_eq!(time_left.as_deref(), Some("30s"));
    }

    #[test]
    fn decode_session_resumption() {
        let wire = r#"{"sessionResumptionUpdate":{"newHandle":"handle-xyz","resumable":true}}"#;

        let update = decode(wire).unwrap().session_resumption_update.unwrap();

        assert_eq!(update.new_handle.as_deref(), Some("handle-xyz"));
        assert_eq!(update.resumable, Some(true));
    }

    #[test]
    fn decode_usage_metadata() {
        let wire = r#"{
            "usageMetadata": {
                "promptTokenCount": 100,
                "responseTokenCount": 50,
                "totalTokenCount": 150
            }
        }"#;

        let usage = decode(wire).unwrap().usage_metadata.unwrap();

        assert_eq!(usage.prompt_token_count, 100);
        assert_eq!(usage.response_token_count, 50);
        assert_eq!(usage.total_token_count, 150);
        // Not present in the payload above.
        assert_eq!(usage.cached_content_token_count, 0);
    }

    #[test]
    fn decode_error() {
        let message = decode(r#"{"error":{"message":"rate limit exceeded"}}"#).unwrap();

        assert_eq!(message.error.unwrap().message, "rate limit exceeded");
    }

    #[test]
    fn decode_combined_content_and_usage() {
        let wire = r#"{
            "serverContent": {
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 42}
        }"#;

        let message = decode(wire).unwrap();

        assert!(message.server_content.is_some());
        assert_eq!(message.usage_metadata.unwrap().total_token_count, 42);
    }

    // ---- into_events ------------------------------------------------------

    #[test]
    fn into_events_setup_complete() {
        let events = events_for(r#"{"setupComplete":{}}"#);

        // Exactly one event, and it is `SetupComplete`.
        assert!(matches!(events.as_slice(), [ServerEvent::SetupComplete]));
    }

    #[test]
    fn into_events_model_text_and_turn_complete() {
        let events = events_for(
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hello"}]},"turnComplete":true}}"#,
        );

        let has_text = events
            .iter()
            .any(|event| matches!(event, ServerEvent::ModelText(text) if text == "hello"));
        let has_turn_complete = events
            .iter()
            .any(|event| matches!(event, ServerEvent::TurnComplete));

        assert!(has_text);
        assert!(has_turn_complete);
    }

    #[test]
    fn into_events_model_audio_base64_decoded() {
        // "AQID" is the base64 payload the assertion below expects to decode
        // to the bytes 1, 2, 3.
        let events = events_for(
            r#"{"serverContent":{"modelTurn":{"parts":[{"inlineData":{"data":"AQID","mimeType":"audio/pcm;rate=24000"}}]}}}"#,
        );

        let has_audio = events
            .iter()
            .any(|event| matches!(event, ServerEvent::ModelAudio(bytes) if bytes == &[1, 2, 3]));

        assert!(has_audio);
    }

    #[test]
    fn into_events_go_away_parses_duration() {
        let events = events_for(r#"{"goAway":{"timeLeft":"30s"}}"#);
        let expected = Some(Duration::from_secs(30));

        let has_go_away = events.iter().any(
            |event| matches!(event, ServerEvent::GoAway { time_left } if *time_left == expected),
        );

        assert!(has_go_away);
    }

    #[test]
    fn into_events_combined_message() {
        let wire = r#"{
            "serverContent": {
                "inputTranscription": {"text": "hey"},
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 10}
        }"#;

        let events = events_for(wire);

        // The slice pattern pins both the count (four) and the ordering.
        assert!(matches!(
            events.as_slice(),
            [
                ServerEvent::InputTranscription(heard),
                ServerEvent::ModelText(said),
                ServerEvent::TurnComplete,
                ServerEvent::Usage(_),
            ] if heard == "hey" && said == "hi"
        ));
    }

    #[test]
    fn into_events_ignores_transcription_markers_without_text() {
        let wire = r#"{
            "serverContent": {
                "outputTranscription": {"finished": true},
                "turnComplete": true
            }
        }"#;

        let events = events_for(wire);

        // The text-less transcription yields nothing; only the turn marker remains.
        assert!(matches!(events.as_slice(), [ServerEvent::TurnComplete]));
    }

    // ---- parse_protobuf_duration -------------------------------------------

    #[test]
    fn parse_duration_integer_seconds() {
        let parsed = parse_protobuf_duration("30s");

        assert_eq!(parsed, Some(Duration::from_secs(30)));
    }

    #[test]
    fn parse_duration_fractional_seconds() {
        let parsed = parse_protobuf_duration("1.5s");

        assert_eq!(parsed, Some(Duration::from_secs_f64(1.5)));
    }

    #[test]
    fn parse_duration_invalid() {
        for input in ["30m", "abc"] {
            assert_eq!(parse_protobuf_duration(input), None);
        }
    }
}
577}