livespeech_sdk/types/
messages.rs

1//! WebSocket message types for the LiveSpeech SDK
2
3use serde::{Deserialize, Serialize};
4use super::config::Tool;
5
6/// Tool response payload for function calling
7#[derive(Debug, Clone, Serialize)]
8#[serde(rename_all = "camelCase")]
9pub struct ToolResponsePayload {
10    /// Tool call ID from the toolCall event
11    pub id: String,
12    /// Function execution result
13    pub response: serde_json::Value,
14}
15
16/// System message payload
17#[derive(Debug, Clone, Serialize)]
18#[serde(rename_all = "camelCase")]
19pub struct SystemMessagePayload {
20    /// The message text (max 500 characters)
21    pub text: String,
22    /// If true, AI responds immediately (default: true)
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub trigger_response: Option<bool>,
25}
26
27/// Client message types (sent from client to server)
28#[derive(Debug, Clone, Serialize)]
29#[serde(tag = "action", rename_all = "camelCase")]
30pub enum ClientMessage {
31    /// Start a new session
32    #[serde(rename = "startSession")]
33    StartSession {
34        #[serde(rename = "prePrompt", skip_serializing_if = "Option::is_none")]
35        pre_prompt: Option<String>,
36        #[serde(skip_serializing_if = "Option::is_none")]
37        language: Option<String>,
38        #[serde(rename = "pipelineMode", skip_serializing_if = "Option::is_none")]
39        pipeline_mode: Option<String>,
40        #[serde(rename = "aiSpeaksFirst", skip_serializing_if = "Option::is_none")]
41        ai_speaks_first: Option<bool>,
42        #[serde(rename = "allowHarmCategory", skip_serializing_if = "Option::is_none")]
43        allow_harm_category: Option<bool>,
44        #[serde(skip_serializing_if = "Option::is_none")]
45        tools: Option<Vec<Tool>>,
46    },
47    /// End the current session
48    #[serde(rename = "endSession")]
49    EndSession,
50    /// Start audio streaming session
51    #[serde(rename = "audioStart")]
52    AudioStart,
53    /// Send audio chunk
54    #[serde(rename = "audioChunk")]
55    AudioChunk {
56        /// Base64 encoded PCM16 audio data
57        data: String,
58    },
59    /// End audio streaming session
60    #[serde(rename = "audioEnd")]
61    AudioEnd,
62    /// Send a system message to the AI during a live session
63    #[serde(rename = "systemMessage")]
64    SystemMessage {
65        /// System message payload
66        payload: SystemMessagePayload,
67    },
68    /// Send tool response (function execution result) back to AI
69    #[serde(rename = "toolResponse")]
70    ToolResponse {
71        /// Tool response payload
72        payload: ToolResponsePayload,
73    },
74    /// Update user ID (guest-to-user migration)
75    #[serde(rename = "updateUserId")]
76    UpdateUserId {
77        /// The authenticated user's unique identifier
78        #[serde(rename = "userId")]
79        user_id: String,
80    },
81    /// Explicit interrupt (stop AI response)
82    #[serde(rename = "interrupt")]
83    Interrupt,
84    /// Ping for keep-alive
85    #[serde(rename = "ping")]
86    Ping,
87}
88
89/// Server message types (received from server)
90#[derive(Debug, Clone, Deserialize)]
91#[serde(tag = "type", rename_all = "camelCase")]
92pub enum ServerMessage {
93    /// Session started
94    #[serde(rename = "sessionStarted")]
95    SessionStarted {
96        #[serde(rename = "sessionId")]
97        session_id: String,
98        timestamp: String,
99    },
100    /// Session ended
101    #[serde(rename = "sessionEnded")]
102    SessionEnded {
103        #[serde(rename = "sessionId")]
104        session_id: String,
105        timestamp: String,
106    },
107    /// Ready - session is ready for audio input
108    #[serde(rename = "ready")]
109    Ready { timestamp: String },
110    /// User transcript - user's speech transcription
111    #[serde(rename = "userTranscript")]
112    UserTranscript {
113        text: String,
114        timestamp: String,
115    },
116    /// Response - AI's text response
117    #[serde(rename = "response")]
118    Response {
119        text: String,
120        #[serde(rename = "isFinal")]
121        is_final: bool,
122        timestamp: String,
123    },
124    /// Audio - AI's audio response
125    #[serde(rename = "audio")]
126    Audio {
127        /// Base64 encoded PCM16 audio data
128        data: String,
129        format: String,
130        #[serde(rename = "sampleRate")]
131        sample_rate: u32,
132        timestamp: String,
133    },
134    /// Turn complete - AI has finished its response turn
135    #[serde(rename = "turnComplete")]
136    TurnComplete { timestamp: String },
137    /// Error message
138    #[serde(rename = "error")]
139    Error {
140        code: String,
141        message: String,
142        timestamp: String,
143    },
144    /// Tool call from AI (function calling)
145    #[serde(rename = "toolCall")]
146    ToolCall {
147        /// Unique ID for this tool call
148        id: String,
149        /// Function name to execute
150        name: String,
151        /// Function arguments
152        #[serde(default)]
153        args: serde_json::Value,
154        timestamp: String,
155    },
156    /// User ID updated (guest-to-user migration complete)
157    #[serde(rename = "userIdUpdated")]
158    UserIdUpdated {
159        /// The new user ID that was set
160        #[serde(rename = "userId")]
161        user_id: String,
162        /// Number of messages migrated from guest to user partition
163        #[serde(rename = "migratedMessages", default)]
164        migrated_messages: usize,
165        timestamp: String,
166    },
167    /// Interrupted - AI response was interrupted by user speech (barge-in)
168    #[serde(rename = "interrupted")]
169    Interrupted {
170        timestamp: String,
171    },
172    /// Pong response
173    #[serde(rename = "pong")]
174    Pong { timestamp: String },
175}
176
177impl ClientMessage {
178    /// Create a start session message
179    pub fn start_session(
180        pre_prompt: Option<String>,
181        language: Option<String>,
182        pipeline_mode: Option<String>,
183        ai_speaks_first: Option<bool>,
184        allow_harm_category: Option<bool>,
185        tools: Option<Vec<Tool>>,
186    ) -> Self {
187        ClientMessage::StartSession {
188            pre_prompt,
189            language,
190            pipeline_mode,
191            ai_speaks_first,
192            allow_harm_category,
193            tools,
194        }
195    }
196
197    /// Create an end session message
198    pub fn end_session() -> Self {
199        ClientMessage::EndSession
200    }
201
202    /// Create an audio start message
203    pub fn audio_start() -> Self {
204        ClientMessage::AudioStart
205    }
206
207    /// Create an audio chunk message
208    pub fn audio_chunk(data: impl Into<String>) -> Self {
209        ClientMessage::AudioChunk { data: data.into() }
210    }
211
212    /// Create an audio end message
213    pub fn audio_end() -> Self {
214        ClientMessage::AudioEnd
215    }
216
217    /// Create a ping message
218    pub fn ping() -> Self {
219        ClientMessage::Ping
220    }
221
222    /// Create a system message (AI responds immediately)
223    pub fn system_message(text: impl Into<String>) -> Self {
224        ClientMessage::SystemMessage {
225            payload: SystemMessagePayload {
226                text: text.into(),
227                trigger_response: None, // defaults to true on server
228            },
229        }
230    }
231
232    /// Create a system message with explicit trigger_response option
233    pub fn system_message_with_options(text: impl Into<String>, trigger_response: bool) -> Self {
234        ClientMessage::SystemMessage {
235            payload: SystemMessagePayload {
236                text: text.into(),
237                trigger_response: Some(trigger_response),
238            },
239        }
240    }
241
242    /// Create a tool response message (function execution result)
243    pub fn tool_response(id: impl Into<String>, response: serde_json::Value) -> Self {
244        ClientMessage::ToolResponse {
245            payload: ToolResponsePayload {
246                id: id.into(),
247                response,
248            },
249        }
250    }
251
252    /// Create an update user ID message (guest-to-user migration)
253    pub fn update_user_id(user_id: impl Into<String>) -> Self {
254        ClientMessage::UpdateUserId {
255            user_id: user_id.into(),
256        }
257    }
258
259    /// Create an interrupt message (explicit stop)
260    pub fn interrupt() -> Self {
261        ClientMessage::Interrupt
262    }
263
264    /// Serialize to JSON string
265    pub fn to_json(&self) -> Result<String, serde_json::Error> {
266        serde_json::to_string(self)
267    }
268}
269
270impl ServerMessage {
271    /// Parse from JSON string
272    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
273        serde_json::from_str(json)
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Serializes `msg` and parses the text back into a [`Value`], so assertions
    /// compare the exact wire shape (including which optional keys are omitted)
    /// instead of searching for substrings.
    fn wire(msg: &ClientMessage) -> Value {
        let text = msg
            .to_json()
            .expect("serializing a ClientMessage should succeed");
        serde_json::from_str(&text).expect("to_json should emit valid JSON")
    }

    #[test]
    fn test_serialize_start_session() {
        let msg = ClientMessage::start_session(
            Some("You are helpful".to_string()),
            None,
            None,
            None,
            None,
            None,
        );
        // Every `None` option must be left out of the payload entirely.
        assert_eq!(
            wire(&msg),
            json!({"action": "startSession", "prePrompt": "You are helpful"})
        );
    }

    #[test]
    fn test_serialize_start_session_with_language() {
        let msg = ClientMessage::start_session(
            Some("You are helpful".to_string()),
            Some("ko-KR".to_string()),
            None,
            None,
            None,
            None,
        );
        assert_eq!(
            wire(&msg),
            json!({
                "action": "startSession",
                "prePrompt": "You are helpful",
                "language": "ko-KR"
            })
        );
    }

    #[test]
    fn test_serialize_start_session_with_ai_speaks_first() {
        let msg = ClientMessage::start_session(
            Some("You are helpful. Greet the user.".to_string()),
            None,
            Some("live".to_string()),
            Some(true),
            None,
            None,
        );
        assert_eq!(
            wire(&msg),
            json!({
                "action": "startSession",
                "prePrompt": "You are helpful. Greet the user.",
                "pipelineMode": "live",
                "aiSpeaksFirst": true
            })
        );
    }

    #[test]
    fn test_serialize_start_session_with_allow_harm_category() {
        let msg = ClientMessage::start_session(
            Some("You are helpful.".to_string()),
            None,
            None,
            None,
            Some(false),
            None,
        );
        // `false` must still be sent: only `None` is skipped.
        assert_eq!(
            wire(&msg),
            json!({
                "action": "startSession",
                "prePrompt": "You are helpful.",
                "allowHarmCategory": false
            })
        );
    }

    #[test]
    fn test_serialize_unit_actions() {
        let cases = [
            (ClientMessage::end_session(), "endSession"),
            (ClientMessage::audio_end(), "audioEnd"),
            (ClientMessage::interrupt(), "interrupt"),
            (ClientMessage::ping(), "ping"),
        ];
        for (msg, action) in cases.iter() {
            assert_eq!(wire(msg), json!({ "action": *action }));
        }
    }

    #[test]
    fn test_serialize_system_message() {
        // Without options the `triggerResponse` key is omitted (server default applies).
        assert_eq!(
            wire(&ClientMessage::system_message("Be brief")),
            json!({"action": "systemMessage", "payload": {"text": "Be brief"}})
        );
        assert_eq!(
            wire(&ClientMessage::system_message_with_options("Be brief", false)),
            json!({
                "action": "systemMessage",
                "payload": {"text": "Be brief", "triggerResponse": false}
            })
        );
    }

    #[test]
    fn test_serialize_tool_response() {
        let msg = ClientMessage::tool_response("call_123", json!({"success": true}));
        assert_eq!(
            wire(&msg),
            json!({
                "action": "toolResponse",
                "payload": {"id": "call_123", "response": {"success": true}}
            })
        );
    }

    #[test]
    fn test_deserialize_tool_call() {
        let frame = r#"{"type":"toolCall","id":"call_abc","name":"open_login","args":{},"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::ToolCall { id, name, args, .. } => {
                assert_eq!(id, "call_abc");
                assert_eq!(name, "open_login");
                assert_eq!(args, json!({}));
            }
            other => panic!("Expected ToolCall message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_tool_call_without_args() {
        let frame = r#"{"type":"toolCall","id":"call_abc","name":"open_login","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::ToolCall { args, .. } => assert_eq!(args, Value::Null),
            other => panic!("Expected ToolCall message, got {:?}", other),
        }
    }

    #[test]
    fn test_serialize_audio_chunk() {
        let msg = ClientMessage::audio_chunk("base64data");
        assert_eq!(
            wire(&msg),
            json!({"action": "audioChunk", "data": "base64data"})
        );
    }

    #[test]
    fn test_serialize_audio_start() {
        let msg = ClientMessage::audio_start();
        assert_eq!(wire(&msg), json!({"action": "audioStart"}));
    }

    #[test]
    fn test_deserialize_session_started() {
        let frame = r#"{"type":"sessionStarted","sessionId":"abc123","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::SessionStarted {
                session_id,
                timestamp,
            } => {
                assert_eq!(session_id, "abc123");
                assert_eq!(timestamp, "2024-01-01T00:00:00Z");
            }
            other => panic!("Expected SessionStarted message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_session_ended() {
        let frame = r#"{"type":"sessionEnded","sessionId":"abc123","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::SessionEnded { session_id, .. } => assert_eq!(session_id, "abc123"),
            other => panic!("Expected SessionEnded message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_ready() {
        let frame = r#"{"type":"ready","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::Ready { .. } => {}
            other => panic!("Expected Ready message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_user_transcript() {
        let frame = r#"{"type":"userTranscript","text":"Hello world","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::UserTranscript { text, .. } => assert_eq!(text, "Hello world"),
            other => panic!("Expected UserTranscript message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_response() {
        let frame = r#"{"type":"response","text":"Hello!","isFinal":true,"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::Response { text, is_final, .. } => {
                assert_eq!(text, "Hello!");
                assert!(is_final);
            }
            other => panic!("Expected Response message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_audio() {
        let frame = r#"{"type":"audio","data":"AAAA","format":"pcm16","sampleRate":24000,"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::Audio {
                data,
                format,
                sample_rate,
                ..
            } => {
                assert_eq!(data, "AAAA");
                assert_eq!(format, "pcm16");
                assert_eq!(sample_rate, 24000);
            }
            other => panic!("Expected Audio message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_turn_complete() {
        let frame = r#"{"type":"turnComplete","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::TurnComplete { .. } => {}
            other => panic!("Expected TurnComplete message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_error() {
        let frame = r#"{"type":"error","code":"INVALID","message":"bad request","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::Error { code, message, .. } => {
                assert_eq!(code, "INVALID");
                assert_eq!(message, "bad request");
            }
            other => panic!("Expected Error message, got {:?}", other),
        }
    }

    #[test]
    fn test_serialize_update_user_id() {
        let msg = ClientMessage::update_user_id("user_abc123");
        assert_eq!(
            wire(&msg),
            json!({"action": "updateUserId", "userId": "user_abc123"})
        );
    }

    #[test]
    fn test_deserialize_user_id_updated() {
        let frame = r#"{"type":"userIdUpdated","userId":"user_abc123","migratedMessages":12,"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::UserIdUpdated {
                user_id,
                migrated_messages,
                ..
            } => {
                assert_eq!(user_id, "user_abc123");
                assert_eq!(migrated_messages, 12);
            }
            other => panic!("Expected UserIdUpdated message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_user_id_updated_without_migrated_messages() {
        let frame = r#"{"type":"userIdUpdated","userId":"user_abc123","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::UserIdUpdated {
                migrated_messages, ..
            } => assert_eq!(migrated_messages, 0),
            other => panic!("Expected UserIdUpdated message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_interrupted() {
        let frame = r#"{"type":"interrupted","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::Interrupted { timestamp } => {
                assert_eq!(timestamp, "2024-01-01T00:00:00Z");
            }
            other => panic!("Expected Interrupted message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_pong() {
        let frame = r#"{"type":"pong","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(frame).unwrap() {
            ServerMessage::Pong { timestamp } => assert_eq!(timestamp, "2024-01-01T00:00:00Z"),
            other => panic!("Expected Pong message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_rejects_malformed_frames() {
        // Unknown `type` tag.
        assert!(
            ServerMessage::from_json(r#"{"type":"bogus","timestamp":"2024-01-01T00:00:00Z"}"#)
                .is_err()
        );
        // `sessionId` is required for `sessionStarted`.
        assert!(ServerMessage::from_json(
            r#"{"type":"sessionStarted","timestamp":"2024-01-01T00:00:00Z"}"#
        )
        .is_err());
        // Not JSON at all.
        assert!(ServerMessage::from_json("not json").is_err());
    }
}
456}