1use serde::{Deserialize, Serialize};
4use super::config::Tool;
5
6#[derive(Debug, Clone, Serialize)]
8#[serde(rename_all = "camelCase")]
9pub struct ToolResponsePayload {
10 pub id: String,
12 pub response: serde_json::Value,
14}
15
16#[derive(Debug, Clone, Serialize)]
18#[serde(rename_all = "camelCase")]
19pub struct SystemMessagePayload {
20 pub text: String,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub trigger_response: Option<bool>,
25}
26
27#[derive(Debug, Clone, Serialize)]
29#[serde(tag = "action", rename_all = "camelCase")]
30pub enum ClientMessage {
31 #[serde(rename = "startSession")]
33 StartSession {
34 #[serde(rename = "prePrompt", skip_serializing_if = "Option::is_none")]
35 pre_prompt: Option<String>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 language: Option<String>,
38 #[serde(rename = "pipelineMode", skip_serializing_if = "Option::is_none")]
39 pipeline_mode: Option<String>,
40 #[serde(rename = "aiSpeaksFirst", skip_serializing_if = "Option::is_none")]
41 ai_speaks_first: Option<bool>,
42 #[serde(rename = "allowHarmCategory", skip_serializing_if = "Option::is_none")]
43 allow_harm_category: Option<bool>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 tools: Option<Vec<Tool>>,
46 },
47 #[serde(rename = "endSession")]
49 EndSession,
50 #[serde(rename = "audioStart")]
52 AudioStart,
53 #[serde(rename = "audioChunk")]
55 AudioChunk {
56 data: String,
58 },
59 #[serde(rename = "audioEnd")]
61 AudioEnd,
62 #[serde(rename = "systemMessage")]
64 SystemMessage {
65 payload: SystemMessagePayload,
67 },
68 #[serde(rename = "toolResponse")]
70 ToolResponse {
71 payload: ToolResponsePayload,
73 },
74 #[serde(rename = "updateUserId")]
76 UpdateUserId {
77 #[serde(rename = "userId")]
79 user_id: String,
80 },
81 #[serde(rename = "interrupt")]
83 Interrupt,
84 #[serde(rename = "ping")]
86 Ping,
87}
88
89#[derive(Debug, Clone, Deserialize)]
91#[serde(tag = "type", rename_all = "camelCase")]
92pub enum ServerMessage {
93 #[serde(rename = "sessionStarted")]
95 SessionStarted {
96 #[serde(rename = "sessionId")]
97 session_id: String,
98 timestamp: String,
99 },
100 #[serde(rename = "sessionEnded")]
102 SessionEnded {
103 #[serde(rename = "sessionId")]
104 session_id: String,
105 timestamp: String,
106 },
107 #[serde(rename = "ready")]
109 Ready { timestamp: String },
110 #[serde(rename = "userTranscript")]
112 UserTranscript {
113 text: String,
114 timestamp: String,
115 },
116 #[serde(rename = "response")]
118 Response {
119 text: String,
120 #[serde(rename = "isFinal")]
121 is_final: bool,
122 timestamp: String,
123 },
124 #[serde(rename = "audio")]
126 Audio {
127 data: String,
129 format: String,
130 #[serde(rename = "sampleRate")]
131 sample_rate: u32,
132 timestamp: String,
133 },
134 #[serde(rename = "turnComplete")]
136 TurnComplete { timestamp: String },
137 #[serde(rename = "error")]
139 Error {
140 code: String,
141 message: String,
142 timestamp: String,
143 },
144 #[serde(rename = "toolCall")]
146 ToolCall {
147 id: String,
149 name: String,
151 #[serde(default)]
153 args: serde_json::Value,
154 timestamp: String,
155 },
156 #[serde(rename = "userIdUpdated")]
158 UserIdUpdated {
159 #[serde(rename = "userId")]
161 user_id: String,
162 #[serde(rename = "migratedMessages", default)]
164 migrated_messages: usize,
165 timestamp: String,
166 },
167 #[serde(rename = "interrupted")]
169 Interrupted {
170 timestamp: String,
171 },
172 #[serde(rename = "pong")]
174 Pong { timestamp: String },
175}
176
177impl ClientMessage {
178 pub fn start_session(
180 pre_prompt: Option<String>,
181 language: Option<String>,
182 pipeline_mode: Option<String>,
183 ai_speaks_first: Option<bool>,
184 allow_harm_category: Option<bool>,
185 tools: Option<Vec<Tool>>,
186 ) -> Self {
187 ClientMessage::StartSession {
188 pre_prompt,
189 language,
190 pipeline_mode,
191 ai_speaks_first,
192 allow_harm_category,
193 tools,
194 }
195 }
196
197 pub fn end_session() -> Self {
199 ClientMessage::EndSession
200 }
201
202 pub fn audio_start() -> Self {
204 ClientMessage::AudioStart
205 }
206
207 pub fn audio_chunk(data: impl Into<String>) -> Self {
209 ClientMessage::AudioChunk { data: data.into() }
210 }
211
212 pub fn audio_end() -> Self {
214 ClientMessage::AudioEnd
215 }
216
217 pub fn ping() -> Self {
219 ClientMessage::Ping
220 }
221
222 pub fn system_message(text: impl Into<String>) -> Self {
224 ClientMessage::SystemMessage {
225 payload: SystemMessagePayload {
226 text: text.into(),
227 trigger_response: None, },
229 }
230 }
231
232 pub fn system_message_with_options(text: impl Into<String>, trigger_response: bool) -> Self {
234 ClientMessage::SystemMessage {
235 payload: SystemMessagePayload {
236 text: text.into(),
237 trigger_response: Some(trigger_response),
238 },
239 }
240 }
241
242 pub fn tool_response(id: impl Into<String>, response: serde_json::Value) -> Self {
244 ClientMessage::ToolResponse {
245 payload: ToolResponsePayload {
246 id: id.into(),
247 response,
248 },
249 }
250 }
251
252 pub fn update_user_id(user_id: impl Into<String>) -> Self {
254 ClientMessage::UpdateUserId {
255 user_id: user_id.into(),
256 }
257 }
258
259 pub fn interrupt() -> Self {
261 ClientMessage::Interrupt
262 }
263
264 pub fn to_json(&self) -> Result<String, serde_json::Error> {
266 serde_json::to_string(self)
267 }
268}
269
270impl ServerMessage {
271 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
273 serde_json::from_str(json)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Timestamp used by every server-message fixture below.
    const TS: &str = "2024-01-01T00:00:00Z";

    /// Serializes `msg` and re-parses the result as a generic JSON value.
    ///
    /// Comparing whole `Value`s instead of running `json.contains(..)` substring
    /// checks has three benefits:
    /// - it pins every key and value;
    /// - it ignores key order;
    /// - it catches options that should have been omitted.
    fn to_value(msg: &ClientMessage) -> Value {
        let text = msg.to_json().expect("client message should serialize");
        serde_json::from_str(&text).expect("serialized client message should be valid JSON")
    }

    // ---- ClientMessage serialization --------------------------------------

    #[test]
    fn test_serialize_start_session() {
        let msg = ClientMessage::start_session(
            Some("You are helpful".to_string()),
            None,
            None,
            None,
            None,
            None,
        );
        // `None` options must be left out entirely, not sent as `null`.
        assert_eq!(
            to_value(&msg),
            json!({"action": "startSession", "prePrompt": "You are helpful"})
        );
    }

    #[test]
    fn test_serialize_start_session_without_options() {
        let msg = ClientMessage::start_session(None, None, None, None, None, None);
        assert_eq!(to_value(&msg), json!({"action": "startSession"}));
    }

    #[test]
    fn test_serialize_start_session_with_language() {
        let msg = ClientMessage::start_session(
            Some("You are helpful".to_string()),
            Some("ko-KR".to_string()),
            None,
            None,
            None,
            None,
        );
        assert_eq!(
            to_value(&msg),
            json!({
                "action": "startSession",
                "prePrompt": "You are helpful",
                "language": "ko-KR"
            })
        );
    }

    #[test]
    fn test_serialize_start_session_with_ai_speaks_first() {
        let msg = ClientMessage::start_session(
            Some("You are helpful. Greet the user.".to_string()),
            None,
            Some("live".to_string()),
            Some(true),
            None,
            None,
        );
        assert_eq!(
            to_value(&msg),
            json!({
                "action": "startSession",
                "prePrompt": "You are helpful. Greet the user.",
                "pipelineMode": "live",
                "aiSpeaksFirst": true
            })
        );
    }

    #[test]
    fn test_serialize_start_session_with_allow_harm_category() {
        let msg = ClientMessage::start_session(
            Some("You are helpful.".to_string()),
            None,
            None,
            None,
            Some(false),
            None,
        );
        // An explicit `Some(false)` must still be sent.
        assert_eq!(
            to_value(&msg),
            json!({
                "action": "startSession",
                "prePrompt": "You are helpful.",
                "allowHarmCategory": false
            })
        );
    }

    #[test]
    fn test_serialize_tool_response() {
        let msg = ClientMessage::tool_response("call_123", json!({"success": true}));
        assert_eq!(
            to_value(&msg),
            json!({
                "action": "toolResponse",
                "payload": {"id": "call_123", "response": {"success": true}}
            })
        );
    }

    #[test]
    fn test_serialize_audio_chunk() {
        let msg = ClientMessage::audio_chunk("base64data");
        assert_eq!(
            to_value(&msg),
            json!({"action": "audioChunk", "data": "base64data"})
        );
    }

    #[test]
    fn test_serialize_audio_start() {
        let msg = ClientMessage::audio_start();
        assert_eq!(to_value(&msg), json!({"action": "audioStart"}));
    }

    #[test]
    fn test_serialize_unit_actions() {
        // Field-less variants serialize to an object holding only the tag.
        let cases = [
            (ClientMessage::end_session(), "endSession"),
            (ClientMessage::audio_end(), "audioEnd"),
            (ClientMessage::interrupt(), "interrupt"),
            (ClientMessage::ping(), "ping"),
        ];
        for (msg, action) in &cases {
            assert_eq!(to_value(msg), json!({ "action": action }));
        }
    }

    #[test]
    fn test_serialize_system_message() {
        assert_eq!(
            to_value(&ClientMessage::system_message("Be brief")),
            json!({"action": "systemMessage", "payload": {"text": "Be brief"}})
        );
        assert_eq!(
            to_value(&ClientMessage::system_message_with_options("Be brief", true)),
            json!({
                "action": "systemMessage",
                "payload": {"text": "Be brief", "triggerResponse": true}
            })
        );
    }

    #[test]
    fn test_serialize_update_user_id() {
        let msg = ClientMessage::update_user_id("user_abc123");
        assert_eq!(
            to_value(&msg),
            json!({"action": "updateUserId", "userId": "user_abc123"})
        );
    }

    // ---- ServerMessage deserialization ------------------------------------

    #[test]
    fn test_deserialize_tool_call() {
        let raw = r#"{"type":"toolCall","id":"call_abc","name":"open_login","args":{},"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::ToolCall {
                id,
                name,
                args,
                timestamp,
            } => {
                assert_eq!(id, "call_abc");
                assert_eq!(name, "open_login");
                assert_eq!(args, json!({}));
                assert_eq!(timestamp, TS);
            }
            other => panic!("Expected ToolCall message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_tool_call_without_args_defaults_to_null() {
        let raw = r#"{"type":"toolCall","id":"call_abc","name":"open_login","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::ToolCall { args, .. } => assert_eq!(args, Value::Null),
            other => panic!("Expected ToolCall message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_session_started() {
        let raw = r#"{"type":"sessionStarted","sessionId":"abc123","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::SessionStarted {
                session_id,
                timestamp,
            } => {
                assert_eq!(session_id, "abc123");
                assert_eq!(timestamp, TS);
            }
            other => panic!("Expected SessionStarted message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_session_ended() {
        let raw = r#"{"type":"sessionEnded","sessionId":"abc123","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::SessionEnded { session_id, .. } => {
                assert_eq!(session_id, "abc123");
            }
            other => panic!("Expected SessionEnded message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_ready() {
        let raw = r#"{"type":"ready","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::Ready { timestamp } => assert_eq!(timestamp, TS),
            other => panic!("Expected Ready message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_user_transcript() {
        let raw = r#"{"type":"userTranscript","text":"Hello world","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::UserTranscript { text, .. } => {
                assert_eq!(text, "Hello world");
            }
            other => panic!("Expected UserTranscript message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_response() {
        let raw = r#"{"type":"response","text":"Hello!","isFinal":true,"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::Response { text, is_final, .. } => {
                assert_eq!(text, "Hello!");
                assert!(is_final);
            }
            other => panic!("Expected Response message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_audio() {
        let raw = r#"{"type":"audio","data":"AAAA","format":"pcm16","sampleRate":24000,"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::Audio {
                data,
                format,
                sample_rate,
                ..
            } => {
                assert_eq!(data, "AAAA");
                assert_eq!(format, "pcm16");
                assert_eq!(sample_rate, 24000);
            }
            other => panic!("Expected Audio message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_turn_complete() {
        let raw = r#"{"type":"turnComplete","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::TurnComplete { timestamp } => assert_eq!(timestamp, TS),
            other => panic!("Expected TurnComplete message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_error() {
        let raw = r#"{"type":"error","code":"E_TEST","message":"something failed","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::Error { code, message, .. } => {
                assert_eq!(code, "E_TEST");
                assert_eq!(message, "something failed");
            }
            other => panic!("Expected Error message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_user_id_updated() {
        let raw = r#"{"type":"userIdUpdated","userId":"user_abc123","migratedMessages":12,"timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::UserIdUpdated {
                user_id,
                migrated_messages,
                ..
            } => {
                assert_eq!(user_id, "user_abc123");
                assert_eq!(migrated_messages, 12);
            }
            other => panic!("Expected UserIdUpdated message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_user_id_updated_without_count_defaults_to_zero() {
        let raw = r#"{"type":"userIdUpdated","userId":"user_abc123","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::UserIdUpdated {
                migrated_messages, ..
            } => assert_eq!(migrated_messages, 0),
            other => panic!("Expected UserIdUpdated message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_interrupted() {
        let raw = r#"{"type":"interrupted","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::Interrupted { timestamp } => assert_eq!(timestamp, TS),
            other => panic!("Expected Interrupted message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_pong() {
        let raw = r#"{"type":"pong","timestamp":"2024-01-01T00:00:00Z"}"#;
        match ServerMessage::from_json(raw).unwrap() {
            ServerMessage::Pong { timestamp } => assert_eq!(timestamp, TS),
            other => panic!("Expected Pong message, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_unknown_type_is_error() {
        let raw = r#"{"type":"bogus","timestamp":"2024-01-01T00:00:00Z"}"#;
        assert!(ServerMessage::from_json(raw).is_err());
    }

    #[test]
    fn test_deserialize_missing_required_field_is_error() {
        // `isFinal` has no default, so omitting it must fail.
        let raw = r#"{"type":"response","text":"Hello!","timestamp":"2024-01-01T00:00:00Z"}"#;
        assert!(ServerMessage::from_json(raw).is_err());
    }
}