Skip to main content

gemini_live/
codec.rs

1//! JSON ↔ Rust codec and semantic event decomposition.
2//!
3//! This module sits between the transport layer (raw WebSocket frames) and the
4//! session layer (typed events).  It provides three operations:
5//!
6//! - [`encode`] — serialise a [`ClientMessage`] to a JSON string for sending.
7//! - [`decode`] — parse a JSON string from the wire into a [`ServerMessage`].
8//! - [`into_events`] — decompose one [`ServerMessage`] into a `Vec<ServerEvent>`.
9//!
10//! The split between `decode` and `into_events` allows the session layer to
11//! inspect raw fields (e.g. `session_resumption_update`) before broadcasting
12//! the higher-level events to the application.
13
14use std::time::Duration;
15
16use base64::Engine;
17
18use crate::error::CodecError;
19use crate::types::{ClientMessage, ServerEvent, ServerMessage};
20
21/// Serialise a [`ClientMessage`] to its JSON wire representation.
22///
23/// **Performance note:** allocates a new `String` per call.  A future
24/// `encode_into` variant accepting a reusable buffer is planned
25/// (see `docs/roadmap.md` P-2).
26pub fn encode(msg: &ClientMessage) -> Result<String, CodecError> {
27    serde_json::to_string(msg).map_err(CodecError::Serialize)
28}
29
30/// Parse a JSON string from the wire into a [`ServerMessage`].
31pub fn decode(json: &str) -> Result<ServerMessage, CodecError> {
32    serde_json::from_str(json).map_err(CodecError::Deserialize)
33}
34
35/// Decompose a single [`ServerMessage`] into a sequence of semantic
36/// [`ServerEvent`]s.
37///
38/// One wire message often carries several pieces of information (e.g.
39/// `serverContent` with both a `modelTurn` and `inputTranscription`, plus
40/// `usageMetadata`).  This function teases them apart into discrete,
41/// easy-to-match events.
42///
43/// The event ordering follows a natural "content first, metadata last"
44/// convention:
45/// 1. `SetupComplete`
46/// 2. Transcriptions (input, then output)
47/// 3. Model content (text / audio parts in wire order)
48/// 4. Flags (`Interrupted`, `GenerationComplete`, `TurnComplete`)
49/// 5. Tool calls / cancellations
50/// 6. Session lifecycle (`SessionResumption`, `GoAway`)
51/// 7. `Usage`
52/// 8. `Error`
53pub fn into_events(msg: ServerMessage) -> Vec<ServerEvent> {
54    let mut events = Vec::new();
55
56    // 1. Setup handshake
57    if msg.setup_complete.is_some() {
58        events.push(ServerEvent::SetupComplete);
59    }
60
61    // 2–4. Server content
62    if let Some(sc) = msg.server_content {
63        if let Some(t) = sc.input_transcription {
64            events.push(ServerEvent::InputTranscription(t.text));
65        }
66        if let Some(t) = sc.output_transcription {
67            events.push(ServerEvent::OutputTranscription(t.text));
68        }
69
70        if let Some(turn) = sc.model_turn {
71            for part in turn.parts {
72                if let Some(text) = part.text {
73                    events.push(ServerEvent::ModelText(text));
74                }
75                if let Some(blob) = part.inline_data {
76                    match base64::engine::general_purpose::STANDARD.decode(&blob.data) {
77                        Ok(bytes) => events.push(ServerEvent::ModelAudio(bytes)),
78                        Err(e) => {
79                            tracing::warn!(error = %e, "failed to base64-decode model audio");
80                        }
81                    }
82                }
83            }
84        }
85
86        if sc.interrupted == Some(true) {
87            events.push(ServerEvent::Interrupted);
88        }
89        if sc.generation_complete == Some(true) {
90            events.push(ServerEvent::GenerationComplete);
91        }
92        if sc.turn_complete == Some(true) {
93            events.push(ServerEvent::TurnComplete);
94        }
95    }
96
97    // 5. Tool calls
98    if let Some(tc) = msg.tool_call {
99        events.push(ServerEvent::ToolCall(tc.function_calls));
100    }
101    if let Some(tcc) = msg.tool_call_cancellation {
102        events.push(ServerEvent::ToolCallCancellation(tcc.ids));
103    }
104
105    // 6. Session lifecycle
106    if let Some(sr) = msg.session_resumption_update {
107        events.push(ServerEvent::SessionResumption {
108            new_handle: sr.new_handle,
109            resumable: sr.resumable.unwrap_or(false),
110        });
111    }
112    if let Some(ga) = msg.go_away {
113        events.push(ServerEvent::GoAway {
114            time_left: ga.time_left.as_deref().and_then(parse_protobuf_duration),
115        });
116    }
117
118    // 7. Usage
119    if let Some(usage) = msg.usage_metadata {
120        events.push(ServerEvent::Usage(usage));
121    }
122
123    // 8. Error
124    if let Some(err) = msg.error {
125        events.push(ServerEvent::Error(err));
126    }
127
128    events
129}
130
/// Parse a protobuf Duration string (e.g. `"30s"`, `"1.5s"`) into a
/// [`std::time::Duration`].
///
/// Surrounding whitespace is ignored.  The value must be a decimal number of
/// seconds followed by a single `s` suffix.
///
/// Returns `None` — and never panics — when:
/// - the `s` suffix is missing (`"30m"`, `"abc"`),
/// - the numeric part does not parse as a float,
/// - the value is negative, NaN or infinite (`"-1s"`, `"NaNs"`, `"infs"`), or
/// - the value is too large to be represented as a `Duration`.
///
/// The input comes straight from the wire, so malformed values must degrade
/// to "unknown time left" rather than abort the decoder.
fn parse_protobuf_duration(s: &str) -> Option<Duration> {
    let secs_str = s.trim().strip_suffix('s')?;
    let secs: f64 = secs_str.parse().ok()?;
    // `Duration::from_secs_f64` panics on negative, non-finite or overflowing
    // input, and `f64::from_str` happily accepts "-1", "NaN" and "inf".  The
    // fallible constructor turns all of those into `None` instead.
    Duration::try_from_secs_f64(secs).ok()
}
139
#[cfg(test)]
mod tests {
    //! Unit tests for the codec: client-message encoding, server-message
    //! decoding, event decomposition and duration parsing.

    use super::*;
    use crate::types::*;

    /// Model identifier shared by the setup-encoding tests.
    const MODEL: &str = "models/gemini-3.1-flash-live-preview";

    /// Encode `msg` and parse the resulting wire JSON back into a generic
    /// value so individual fields can be inspected.
    fn encode_to_value(msg: &ClientMessage) -> serde_json::Value {
        let wire = encode(msg).unwrap();
        serde_json::from_str(&wire).unwrap()
    }

    /// Decode `json` and decompose the message into events in one step.
    fn events_from(json: &str) -> Vec<ServerEvent> {
        into_events(decode(json).unwrap())
    }

    /// Build a text-only [`Part`].
    fn text_part(text: &str) -> Part {
        Part {
            text: Some(text.into()),
            inline_data: None,
        }
    }

    // ── encode ───────────────────────────────────────────────────────────

    #[test]
    fn encode_setup_minimal() {
        let config = SetupConfig {
            model: MODEL.into(),
            ..Default::default()
        };

        let value = encode_to_value(&ClientMessage::Setup(config));
        let setup = &value["setup"];

        assert_eq!(setup["model"], MODEL);
        // Unset optional fields must be omitted entirely rather than sent as null.
        assert!(setup.get("generationConfig").is_none());
    }

    #[test]
    fn encode_setup_full() {
        let speech = SpeechConfig {
            voice_config: VoiceConfig {
                prebuilt_voice_config: PrebuiltVoiceConfig {
                    voice_name: "Kore".into(),
                },
            },
        };
        let thinking = ThinkingConfig {
            thinking_level: Some(ThinkingLevel::Medium),
            ..Default::default()
        };
        let generation = GenerationConfig {
            response_modalities: Some(vec![Modality::Audio, Modality::Text]),
            speech_config: Some(speech),
            thinking_config: Some(thinking),
            ..Default::default()
        };
        let instruction = Content {
            role: None,
            parts: vec![text_part("You are a helpful assistant.")],
        };
        let config = SetupConfig {
            model: MODEL.into(),
            generation_config: Some(generation),
            system_instruction: Some(instruction),
            input_audio_transcription: Some(AudioTranscriptionConfig {}),
            output_audio_transcription: Some(AudioTranscriptionConfig {}),
            session_resumption: Some(SessionResumptionConfig { handle: None }),
            ..Default::default()
        };

        let value = encode_to_value(&ClientMessage::Setup(config));
        let setup = &value["setup"];
        let generation_json = &setup["generationConfig"];
        let modalities = &generation_json["responseModalities"];

        assert_eq!(modalities[0], "AUDIO");
        assert_eq!(modalities[1], "TEXT");
        assert_eq!(
            generation_json["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"],
            "Kore"
        );
        assert_eq!(generation_json["thinkingConfig"]["thinkingLevel"], "medium");
        assert_eq!(
            setup["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        // Configs that are switched on merely by being present serialise as `{}`.
        let empty_object = serde_json::json!({});
        assert_eq!(setup["inputAudioTranscription"], empty_object);
        assert_eq!(setup["outputAudioTranscription"], empty_object);
    }

    #[test]
    fn encode_client_content() {
        let turn = |role: &str, text: &str| Content {
            role: Some(role.into()),
            parts: vec![text_part(text)],
        };
        let message = ClientMessage::ClientContent(ClientContent {
            turns: Some(vec![turn("user", "Hello"), turn("model", "Hi!")]),
            turn_complete: Some(true),
        });

        let value = encode_to_value(&message);
        let content = &value["clientContent"];

        assert_eq!(content["turns"][0]["role"], "user");
        assert_eq!(content["turnComplete"], true);
    }

    #[test]
    fn encode_realtime_input_audio() {
        let chunk = Blob {
            data: "AAAA".into(),
            mime_type: "audio/pcm;rate=16000".into(),
        };
        let message = ClientMessage::RealtimeInput(RealtimeInput {
            audio: Some(chunk),
            video: None,
            text: None,
            activity_start: None,
            activity_end: None,
            audio_stream_end: None,
        });

        let value = encode_to_value(&message);
        let input = &value["realtimeInput"];

        assert_eq!(input["audio"]["mimeType"], "audio/pcm;rate=16000");
        // Fields left as `None` must not appear on the wire.
        assert!(input.get("video").is_none());
    }

    #[test]
    fn encode_tool_response() {
        let reply = FunctionResponse {
            id: "call_123".into(),
            name: "get_weather".into(),
            response: serde_json::json!({"temperature": 72}),
        };
        let message = ClientMessage::ToolResponse(ToolResponseMessage {
            function_responses: vec![reply],
        });

        let value = encode_to_value(&message);
        let first = &value["toolResponse"]["functionResponses"][0];

        assert_eq!(first["id"], "call_123");
        assert_eq!(first["response"]["temperature"], 72);
    }

    // ── decode ───────────────────────────────────────────────────────────

    #[test]
    fn decode_setup_complete() {
        let message = decode(r#"{"setupComplete":{}}"#).unwrap();

        assert!(message.setup_complete.is_some());
        assert!(message.server_content.is_none());
    }

    #[test]
    fn decode_server_content_text() {
        let wire = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello there!"}]
                },
                "turnComplete": true
            }
        }"#;

        let content = decode(wire).unwrap().server_content.unwrap();

        assert_eq!(content.turn_complete, Some(true));
        let turn = content.model_turn.unwrap();
        assert_eq!(turn.parts[0].text.as_deref(), Some("Hello there!"));
    }

    #[test]
    fn decode_server_content_with_transcription() {
        let wire = r#"{
            "serverContent": {
                "inputTranscription": {"text": "What's the weather?"},
                "outputTranscription": {"text": "It's sunny today."}
            }
        }"#;

        let content = decode(wire).unwrap().server_content.unwrap();
        let heard = content.input_transcription.unwrap();
        let spoken = content.output_transcription.unwrap();

        assert_eq!(heard.text, "What's the weather?");
        assert_eq!(spoken.text, "It's sunny today.");
    }

    #[test]
    fn decode_tool_call() {
        let wire = r#"{
            "toolCall": {
                "functionCalls": [{
                    "id": "call_abc",
                    "name": "get_weather",
                    "args": {"city": "Tokyo"}
                }]
            }
        }"#;

        let tool_call = decode(wire).unwrap().tool_call.unwrap();
        let call = &tool_call.function_calls[0];

        assert_eq!(call.id, "call_abc");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.args["city"], "Tokyo");
    }

    #[test]
    fn decode_tool_call_cancellation() {
        let message = decode(r#"{"toolCallCancellation":{"ids":["call_1","call_2"]}}"#).unwrap();

        let cancelled = message.tool_call_cancellation.unwrap();

        assert_eq!(cancelled.ids, vec!["call_1", "call_2"]);
    }

    #[test]
    fn decode_go_away() {
        let message = decode(r#"{"goAway":{"timeLeft":"30s"}}"#).unwrap();

        let notice = message.go_away.unwrap();

        assert_eq!(notice.time_left.as_deref(), Some("30s"));
    }

    #[test]
    fn decode_session_resumption() {
        let wire = r#"{"sessionResumptionUpdate":{"newHandle":"handle-xyz","resumable":true}}"#;

        let update = decode(wire).unwrap().session_resumption_update.unwrap();

        assert_eq!(update.new_handle.as_deref(), Some("handle-xyz"));
        assert_eq!(update.resumable, Some(true));
    }

    #[test]
    fn decode_usage_metadata() {
        let wire = r#"{
            "usageMetadata": {
                "promptTokenCount": 100,
                "responseTokenCount": 50,
                "totalTokenCount": 150
            }
        }"#;

        let usage = decode(wire).unwrap().usage_metadata.unwrap();

        assert_eq!(usage.prompt_token_count, 100);
        assert_eq!(usage.response_token_count, 50);
        assert_eq!(usage.total_token_count, 150);
        // Counters that are absent from the payload fall back to zero.
        assert_eq!(usage.cached_content_token_count, 0);
    }

    #[test]
    fn decode_error() {
        let message = decode(r#"{"error":{"message":"rate limit exceeded"}}"#).unwrap();

        let failure = message.error.unwrap();

        assert_eq!(failure.message, "rate limit exceeded");
    }

    #[test]
    fn decode_combined_content_and_usage() {
        let wire = r#"{
            "serverContent": {
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 42}
        }"#;

        let message = decode(wire).unwrap();

        assert!(message.server_content.is_some());
        let usage = message.usage_metadata.unwrap();
        assert_eq!(usage.total_token_count, 42);
    }

    // ── into_events ──────────────────────────────────────────────────────

    #[test]
    fn into_events_setup_complete() {
        let events = events_from(r#"{"setupComplete":{}}"#);

        // Exactly one event, and it is the handshake acknowledgement.
        assert!(matches!(events.as_slice(), [ServerEvent::SetupComplete]));
    }

    #[test]
    fn into_events_model_text_and_turn_complete() {
        let events = events_from(
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hello"}]},"turnComplete":true}}"#,
        );

        let has_text = events
            .iter()
            .any(|event| matches!(event, ServerEvent::ModelText(text) if text == "hello"));
        let has_turn_complete = events
            .iter()
            .any(|event| matches!(event, ServerEvent::TurnComplete));

        assert!(has_text);
        assert!(has_turn_complete);
    }

    #[test]
    fn into_events_model_audio_base64_decoded() {
        // The payload "AQID" is the base64 encoding of the bytes [1, 2, 3].
        let events = events_from(
            r#"{"serverContent":{"modelTurn":{"parts":[{"inlineData":{"data":"AQID","mimeType":"audio/pcm;rate=24000"}}]}}}"#,
        );

        let has_audio = events
            .iter()
            .any(|event| matches!(event, ServerEvent::ModelAudio(bytes) if bytes == &[1, 2, 3]));

        assert!(has_audio);
    }

    #[test]
    fn into_events_go_away_parses_duration() {
        let events = events_from(r#"{"goAway":{"timeLeft":"30s"}}"#);

        let thirty_seconds = Duration::from_secs(30);
        let has_go_away = events.iter().any(|event| {
            matches!(
                event,
                ServerEvent::GoAway { time_left: Some(left) } if *left == thirty_seconds
            )
        });

        assert!(has_go_away);
    }

    #[test]
    fn into_events_combined_message() {
        let wire = r#"{
            "serverContent": {
                "inputTranscription": {"text": "hey"},
                "modelTurn": {"parts": [{"text": "hi"}]},
                "turnComplete": true
            },
            "usageMetadata": {"totalTokenCount": 10}
        }"#;

        let events = events_from(wire);

        // Exactly four events, in order:
        // InputTranscription, ModelText, TurnComplete, Usage.
        assert!(matches!(
            events.as_slice(),
            [
                ServerEvent::InputTranscription(heard),
                ServerEvent::ModelText(said),
                ServerEvent::TurnComplete,
                ServerEvent::Usage(_),
            ] if heard == "hey" && said == "hi"
        ));
    }

    // ── parse_protobuf_duration ──────────────────────────────────────────

    #[test]
    fn parse_duration_integer_seconds() {
        let parsed = parse_protobuf_duration("30s");

        assert_eq!(parsed, Some(Duration::from_secs(30)));
    }

    #[test]
    fn parse_duration_fractional_seconds() {
        let parsed = parse_protobuf_duration("1.5s");

        assert_eq!(parsed, Some(Duration::from_secs_f64(1.5)));
    }

    #[test]
    fn parse_duration_invalid() {
        assert!(parse_protobuf_duration("30m").is_none());
        assert!(parse_protobuf_duration("abc").is_none());
    }
}