Skip to main content

gemini_live/
codec.rs

1//! JSON ↔ Rust codec and semantic event decomposition.
2//!
3//! This module sits between the transport layer (raw WebSocket frames) and the
4//! session layer (typed events).  It provides three operations:
5//!
6//! - [`encode`] — serialise a [`ClientMessage`] to a JSON string for sending.
7//! - [`decode`] — parse a JSON string from the wire into a [`ServerMessage`].
8//! - [`into_events`] — decompose one [`ServerMessage`] into a `Vec<ServerEvent>`.
9//!
10//! The split between `decode` and `into_events` allows the session layer to
11//! inspect raw fields (e.g. `session_resumption_update`) before broadcasting
12//! the higher-level events to the application.
13
14use std::time::Duration;
15
16use base64::Engine;
17
18use crate::error::CodecError;
19use crate::types::{ClientMessage, ServerEvent, ServerMessage};
20
21/// Serialise a [`ClientMessage`] to its JSON wire representation.
22///
23/// **Performance note:** allocates a new `String` per call.  A future
24/// `encode_into` variant accepting a reusable buffer is planned
25/// (see `docs/roadmap.md` P-2).
26pub fn encode(msg: &ClientMessage) -> Result<String, CodecError> {
27    serde_json::to_string(msg).map_err(CodecError::Serialize)
28}
29
30/// Parse a JSON string from the wire into a [`ServerMessage`].
31pub fn decode(json: &str) -> Result<ServerMessage, CodecError> {
32    serde_json::from_str(json).map_err(CodecError::Deserialize)
33}
34
35/// Decompose a single [`ServerMessage`] into a sequence of semantic
36/// [`ServerEvent`]s.
37///
38/// One wire message often carries several pieces of information (e.g.
39/// `serverContent` with both a `modelTurn` and `inputTranscription`, plus
40/// `usageMetadata`).  This function teases them apart into discrete,
41/// easy-to-match events.
42///
43/// The event ordering follows a natural "content first, metadata last"
44/// convention:
45/// 1. `SetupComplete`
46/// 2. Transcriptions (input, then output)
47/// 3. Model content (text / audio parts in wire order)
48/// 4. Flags (`Interrupted`, `GenerationComplete`, `TurnComplete`)
49/// 5. Tool calls / cancellations
50/// 6. Session lifecycle (`SessionResumption`, `GoAway`)
51/// 7. `Usage`
52/// 8. `Error`
53pub fn into_events(msg: ServerMessage) -> Vec<ServerEvent> {
54    let mut events = Vec::new();
55
56    // 1. Setup handshake
57    if msg.setup_complete.is_some() {
58        events.push(ServerEvent::SetupComplete);
59    }
60
61    // 2–4. Server content
62    if let Some(sc) = msg.server_content {
63        if let Some(t) = sc.input_transcription
64            && let Some(text) = t.text
65        {
66            events.push(ServerEvent::InputTranscription(text));
67        }
68        if let Some(t) = sc.output_transcription
69            && let Some(text) = t.text
70        {
71            events.push(ServerEvent::OutputTranscription(text));
72        }
73
74        if let Some(turn) = sc.model_turn {
75            for part in turn.parts {
76                if let Some(text) = part.text {
77                    events.push(ServerEvent::ModelText(text));
78                }
79                if let Some(blob) = part.inline_data {
80                    match base64::engine::general_purpose::STANDARD.decode(&blob.data) {
81                        Ok(bytes) => events.push(ServerEvent::ModelAudio(bytes)),
82                        Err(e) => {
83                            tracing::warn!(error = %e, "failed to base64-decode model audio");
84                        }
85                    }
86                }
87            }
88        }
89
90        if sc.interrupted == Some(true) {
91            events.push(ServerEvent::Interrupted);
92        }
93        if sc.generation_complete == Some(true) {
94            events.push(ServerEvent::GenerationComplete);
95        }
96        if sc.turn_complete == Some(true) {
97            events.push(ServerEvent::TurnComplete);
98        }
99    }
100
101    // 5. Tool calls
102    if let Some(tc) = msg.tool_call {
103        events.push(ServerEvent::ToolCall(tc.function_calls));
104    }
105    if let Some(tcc) = msg.tool_call_cancellation {
106        events.push(ServerEvent::ToolCallCancellation(tcc.ids));
107    }
108
109    // 6. Session lifecycle
110    if let Some(sr) = msg.session_resumption_update {
111        events.push(ServerEvent::SessionResumption {
112            new_handle: sr.new_handle,
113            resumable: sr.resumable.unwrap_or(false),
114        });
115    }
116    if let Some(ga) = msg.go_away {
117        events.push(ServerEvent::GoAway {
118            time_left: ga.time_left.as_deref().and_then(parse_protobuf_duration),
119        });
120    }
121
122    // 7. Usage
123    if let Some(usage) = msg.usage_metadata {
124        events.push(ServerEvent::Usage(usage));
125    }
126
127    // 8. Error
128    if let Some(err) = msg.error {
129        events.push(ServerEvent::Error(err));
130    }
131
132    events
133}
134
/// Parse a protobuf Duration string (e.g. `"30s"`, `"1.5s"`) into a
/// [`std::time::Duration`].
///
/// Surrounding whitespace is trimmed, the value must end in the `s` unit
/// suffix, and what precedes the suffix must parse as an `f64` number of
/// seconds.
///
/// Returns `None` when the suffix is missing, the number does not parse, or
/// the value cannot be represented as a [`Duration`] (negative, NaN,
/// infinite, or too large).  This function never panics.
fn parse_protobuf_duration(s: &str) -> Option<Duration> {
    let secs: f64 = s.trim().strip_suffix('s')?.parse().ok()?;
    // The string comes from the server and `f64::from_str` happily accepts
    // "-1", "nan", "inf" and "1e40".  `Duration::from_secs_f64` panics on all
    // of those, so use the checked conversion and report them as `None`.
    Duration::try_from_secs_f64(secs).ok()
}
143
// Unit tests for the codec: JSON encoding of client messages, decoding of
// server messages, event decomposition, and duration parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// Model identifier shared by the setup-encoding tests.
    const MODEL: &str = "models/gemini-3.1-flash-live-preview";

    /// Encode `msg` and parse the produced JSON back into a generic value so
    /// that individual wire fields can be inspected.
    fn encoded(msg: &ClientMessage) -> serde_json::Value {
        let wire = encode(msg).unwrap();
        serde_json::from_str(&wire).unwrap()
    }

    /// Decode `wire` and immediately expand the message into events.
    fn events_from(wire: &str) -> Vec<ServerEvent> {
        into_events(decode(wire).unwrap())
    }

    /// Build a [`Part`] that carries only text.
    fn text_part(text: &str) -> Part {
        Part {
            text: Some(text.into()),
            inline_data: None,
        }
    }

    // ── encode ───────────────────────────────────────────────────────────

    /// A setup message with only the model set serialises just that field.
    #[test]
    fn encode_setup_minimal() {
        let setup = SetupConfig {
            model: MODEL.into(),
            ..Default::default()
        };
        let value = encoded(&ClientMessage::Setup(setup));

        assert_eq!(value["setup"]["model"], MODEL);
        // Unset optional fields must be omitted entirely, not sent as null.
        assert!(value["setup"].get("generationConfig").is_none());
    }

    /// A fully populated setup message uses the expected camelCase keys and
    /// enum spellings.
    #[test]
    fn encode_setup_full() {
        let speech = SpeechConfig {
            voice_config: VoiceConfig {
                prebuilt_voice_config: PrebuiltVoiceConfig {
                    voice_name: "Kore".into(),
                },
            },
        };
        let thinking = ThinkingConfig {
            thinking_level: Some(ThinkingLevel::Medium),
            ..Default::default()
        };
        let generation = GenerationConfig {
            response_modalities: Some(vec![Modality::Audio, Modality::Text]),
            speech_config: Some(speech),
            thinking_config: Some(thinking),
            ..Default::default()
        };
        let instruction = Content {
            role: None,
            parts: vec![text_part("You are a helpful assistant.")],
        };
        let compression = ContextWindowCompressionConfig {
            sliding_window: Some(SlidingWindow::default()),
            trigger_tokens: None,
        };

        let value = encoded(&ClientMessage::Setup(SetupConfig {
            model: MODEL.into(),
            generation_config: Some(generation),
            system_instruction: Some(instruction),
            input_audio_transcription: Some(AudioTranscriptionConfig {}),
            output_audio_transcription: Some(AudioTranscriptionConfig {}),
            session_resumption: Some(SessionResumptionConfig { handle: None }),
            context_window_compression: Some(compression),
            ..Default::default()
        }));

        let setup = &value["setup"];
        let generation_json = &setup["generationConfig"];
        assert_eq!(generation_json["responseModalities"][0], "AUDIO");
        assert_eq!(generation_json["responseModalities"][1], "TEXT");
        assert_eq!(
            generation_json["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"],
            "Kore"
        );
        assert_eq!(generation_json["thinkingConfig"]["thinkingLevel"], "medium");
        assert_eq!(
            setup["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        // Configs that are switched on by mere presence serialise as `{}`.
        assert_eq!(setup["inputAudioTranscription"], serde_json::json!({}));
        assert_eq!(setup["outputAudioTranscription"], serde_json::json!({}));
        assert_eq!(
            setup["contextWindowCompression"],
            serde_json::json!({ "slidingWindow": {} })
        );
    }

    /// Built-in tools and function declarations can be mixed in one list.
    #[test]
    fn encode_setup_with_builtin_and_function_tools() {
        let read_file = FunctionDeclaration {
            name: "read_file".into(),
            description: "Read a file from the workspace.".into(),
            parameters: serde_json::json!({
                "type": "object",
                "required": ["path"],
                "properties": {
                    "path": { "type": "string" }
                }
            }),
            scheduling: None,
            behavior: None,
        };
        let setup = SetupConfig {
            model: MODEL.into(),
            tools: Some(vec![
                Tool::GoogleSearch(GoogleSearchTool {}),
                Tool::FunctionDeclarations(vec![read_file]),
            ]),
            ..Default::default()
        };

        let value = encoded(&ClientMessage::Setup(setup));
        let tools = value["setup"]["tools"].as_array().expect("tools array");
        assert_eq!(tools[0]["googleSearch"], serde_json::json!({}));
        assert_eq!(tools[1]["functionDeclarations"][0]["name"], "read_file");
    }

    /// Client content keeps turn order and the `turnComplete` flag.
    #[test]
    fn encode_client_content() {
        let user_turn = Content {
            role: Some("user".into()),
            parts: vec![text_part("Hello")],
        };
        let model_turn = Content {
            role: Some("model".into()),
            parts: vec![text_part("Hi!")],
        };
        let content = ClientContent {
            turns: Some(vec![user_turn, model_turn]),
            turn_complete: Some(true),
        };

        let value = encoded(&ClientMessage::ClientContent(content));
        assert_eq!(value["clientContent"]["turns"][0]["role"], "user");
        assert_eq!(value["clientContent"]["turnComplete"], true);
    }

    /// Realtime audio input serialises the blob and nothing else.
    #[test]
    fn encode_realtime_input_audio() {
        let chunk = Blob {
            data: "AAAA".into(),
            mime_type: "audio/pcm;rate=16000".into(),
        };
        let input = RealtimeInput {
            audio: Some(chunk),
            video: None,
            text: None,
            activity_start: None,
            activity_end: None,
            audio_stream_end: None,
        };

        let value = encoded(&ClientMessage::RealtimeInput(input));
        let realtime = &value["realtimeInput"];
        assert_eq!(realtime["audio"]["mimeType"], "audio/pcm;rate=16000");
        // Fields left as `None` must not appear on the wire.
        assert!(realtime.get("video").is_none());
    }

    /// Tool responses carry the call id and the arbitrary JSON payload.
    #[test]
    fn encode_tool_response() {
        let reply = FunctionResponse {
            id: "call_123".into(),
            name: "get_weather".into(),
            response: serde_json::json!({"temperature": 72}),
        };
        let value = encoded(&ClientMessage::ToolResponse(ToolResponseMessage {
            function_responses: vec![reply],
        }));

        let first = &value["toolResponse"]["functionResponses"][0];
        assert_eq!(first["id"], "call_123");
        assert_eq!(first["response"]["temperature"], 72);
    }

    // ── decode ───────────────────────────────────────────────────────────

    /// `setupComplete` alone populates only that field.
    #[test]
    fn decode_setup_complete() {
        let message = decode(r#"{"setupComplete":{}}"#).unwrap();
        assert!(message.setup_complete.is_some());
        assert!(message.server_content.is_none());
    }

    /// Text parts and the turn-complete flag are read from `serverContent`.
    #[test]
    fn decode_server_content_text() {
        let wire = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello there!"}]
                },
                "turnComplete": true
            }
        }"#;
        let content = decode(wire).unwrap().server_content.unwrap();
        let finished = content.turn_complete;
        let turn = content.model_turn.unwrap();

        assert_eq!(turn.parts[0].text.as_deref(), Some("Hello there!"));
        assert_eq!(finished, Some(true));
    }

    /// Input and output transcriptions are decoded independently.
    #[test]
    fn decode_server_content_with_transcription() {
        let wire = r#"{
            "serverContent": {
                "inputTranscription": {"text": "What's the weather?"},
                "outputTranscription": {"text": "It's sunny today."}
            }
        }"#;
        let content = decode(wire).unwrap().server_content.unwrap();
        let input = content.input_transcription.unwrap();
        let output = content.output_transcription.unwrap();

        assert_eq!(input.text.as_deref(), Some("What's the weather?"));
        assert_eq!(output.text.as_deref(), Some("It's sunny today."));
    }

    /// A transcription marker may carry `finished` without any text.
    #[test]
    fn decode_transcription_finished_without_text() {
        let wire = r#"{
            "serverContent": {
                "outputTranscription": {"finished": true}
            }
        }"#;
        let content = decode(wire).unwrap().server_content.unwrap();
        let marker = content.output_transcription.unwrap();

        assert_eq!(marker.text, None);
        assert_eq!(marker.finished, Some(true));
    }

    /// Function calls expose id, name and JSON arguments.
    #[test]
    fn decode_tool_call() {
        let wire = r#"{
            "toolCall": {
                "functionCalls": [{
                    "id": "call_abc",
                    "name": "get_weather",
                    "args": {"city": "Tokyo"}
                }]
            }
        }"#;
        let tool_call = decode(wire).unwrap().tool_call.unwrap();
        let call = &tool_call.function_calls[0];

        assert_eq!(call.id, "call_abc");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], "Tokyo");
    }

    /// Cancellations list the ids of the calls being withdrawn.
    #[test]
    fn decode_tool_call_cancellation() {
        let message = decode(r#"{"toolCallCancellation":{"ids":["call_1","call_2"]}}"#).unwrap();
        let cancelled = message.tool_call_cancellation.unwrap();
        assert_eq!(cancelled.ids, vec!["call_1", "call_2"]);
    }

    /// `goAway.timeLeft` is kept as the raw string at the message level.
    #[test]
    fn decode_go_away() {
        let message = decode(r#"{"goAway":{"timeLeft":"30s"}}"#).unwrap();
        let notice = message.go_away.unwrap();
        assert_eq!(notice.time_left.as_deref(), Some("30s"));
    }

    /// Session resumption updates expose the new handle and resumable flag.
    #[test]
    fn decode_session_resumption() {
        let wire = r#"{"sessionResumptionUpdate":{"newHandle":"handle-xyz","resumable":true}}"#;
        let update = decode(wire).unwrap().session_resumption_update.unwrap();

        assert_eq!(update.new_handle.as_deref(), Some("handle-xyz"));
        assert_eq!(update.resumable, Some(true));
    }

    /// Token counts are read as given; absent counters fall back to zero.
    #[test]
    fn decode_usage_metadata() {
        let wire = r#"{
            "usageMetadata": {
                "promptTokenCount": 100,
                "responseTokenCount": 50,
                "totalTokenCount": 150
            }
        }"#;
        let usage = decode(wire).unwrap().usage_metadata.unwrap();

        assert_eq!(usage.prompt_token_count, 100);
        assert_eq!(usage.response_token_count, 50);
        assert_eq!(usage.total_token_count, 150);
        // Counters missing from the wire default to 0.
        assert_eq!(usage.cached_content_token_count, 0);
    }

    /// Server errors expose their message.
    #[test]
    fn decode_error() {
        let message = decode(r#"{"error":{"message":"rate limit exceeded"}}"#).unwrap();
        assert_eq!(message.error.unwrap().message, "rate limit exceeded");
    }

    /// Content and usage metadata can arrive in the same message.
    #[test]
    fn decode_combined_content_and_usage() {
        let wire = r#"{
            "serverContent": {
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 42}
        }"#;
        let message = decode(wire).unwrap();

        assert!(message.server_content.is_some());
        assert_eq!(message.usage_metadata.unwrap().total_token_count, 42);
    }

    // ── into_events ──────────────────────────────────────────────────────

    /// `setupComplete` yields exactly one `SetupComplete` event.
    #[test]
    fn into_events_setup_complete() {
        let events = events_from(r#"{"setupComplete":{}}"#);
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], ServerEvent::SetupComplete));
    }

    /// Model text and the turn-complete flag both surface as events.
    #[test]
    fn into_events_model_text_and_turn_complete() {
        let events = events_from(
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hello"}]},"turnComplete":true}}"#,
        );

        let has_text = events
            .iter()
            .any(|e| matches!(e, ServerEvent::ModelText(t) if t == "hello"));
        let has_turn_complete = events
            .iter()
            .any(|e| matches!(e, ServerEvent::TurnComplete));
        assert!(has_text);
        assert!(has_turn_complete);
    }

    /// Inline audio is delivered as decoded bytes, not as base64 text.
    #[test]
    fn into_events_model_audio_base64_decoded() {
        // "AQID" is the base64 encoding of the bytes [1, 2, 3].
        let events = events_from(
            r#"{"serverContent":{"modelTurn":{"parts":[{"inlineData":{"data":"AQID","mimeType":"audio/pcm;rate=24000"}}]}}}"#,
        );

        let has_audio = events
            .iter()
            .any(|e| matches!(e, ServerEvent::ModelAudio(b) if b == &[1, 2, 3]));
        assert!(has_audio);
    }

    /// `goAway.timeLeft` is converted into a `Duration` on the event.
    #[test]
    fn into_events_go_away_parses_duration() {
        let events = events_from(r#"{"goAway":{"timeLeft":"30s"}}"#);
        let expected = Duration::from_secs(30);

        let has_go_away = events.iter().any(
            |e| matches!(e, ServerEvent::GoAway { time_left: Some(d) } if *d == expected),
        );
        assert!(has_go_away);
    }

    /// A message bundling several facts yields them in the documented order.
    #[test]
    fn into_events_combined_message() {
        let wire = r#"{
            "serverContent": {
                "inputTranscription": {"text": "hey"},
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 10}
        }"#;
        let events = events_from(wire);

        // Expected sequence: InputTranscription, ModelText, TurnComplete, Usage.
        assert_eq!(events.len(), 4);
        assert!(matches!(
            events.as_slice(),
            [
                ServerEvent::InputTranscription(heard),
                ServerEvent::ModelText(said),
                ServerEvent::TurnComplete,
                ServerEvent::Usage(_),
            ] if heard == "hey" && said == "hi"
        ));
    }

    /// A transcription marker without text does not become an event.
    #[test]
    fn into_events_ignores_transcription_markers_without_text() {
        let wire = r#"{
            "serverContent": {
                "outputTranscription": {"finished": true},
                "turnComplete": true
            }
        }"#;
        let events = events_from(wire);

        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], ServerEvent::TurnComplete));
    }

    // ── parse_protobuf_duration ──────────────────────────────────────────

    /// Whole seconds parse exactly.
    #[test]
    fn parse_duration_integer_seconds() {
        let parsed = parse_protobuf_duration("30s");
        assert_eq!(parsed, Some(Duration::from_secs(30)));
    }

    /// Fractional seconds are supported.
    #[test]
    fn parse_duration_fractional_seconds() {
        let parsed = parse_protobuf_duration("1.5s");
        assert_eq!(parsed, Some(Duration::from_secs_f64(1.5)));
    }

    /// Anything that is not `<number>s` is rejected.
    #[test]
    fn parse_duration_invalid() {
        assert_eq!(parse_protobuf_duration("30m"), None);
        assert_eq!(parse_protobuf_duration("abc"), None);
    }
}