livespeech_sdk/
client.rs

//! LiveSpeech client implementation
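//!
//! # Usage sketch
//!
//! A minimal end-to-end flow (illustrative only; `Config` construction and the
//! audio capture source shown here are assumptions, not part of this module):
//!
//! ```ignore
//! // Build a client from an existing `Config` (endpoint, API key, optional user ID).
//! let client = LiveSpeechClient::new(config);
//!
//! // Connect, start a session, then open the audio stream.
//! client.connect().await?;
//! let session_id = client.start_session(None).await?;
//! client.audio_start().await?;
//!
//! // Forward microphone PCM16 frames (capture source not shown here).
//! client.send_audio_chunk(&pcm16_frame).await?;
//!
//! // Tear down in reverse order.
//! client.audio_end().await?;
//! client.end_session().await?;
//! client.disconnect().await;
//! ```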

use crate::audio::AudioEncoder;
use crate::error::{LiveSpeechError, Result};
use crate::types::*;

use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, error, info, warn};
use url::Url;

/// Connection state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    Disconnected,
    Connecting,
    Connected,
    Reconnecting,
}

/// LiveSpeech client for real-time speech-to-speech AI conversations
pub struct LiveSpeechClient {
    config: Config,
    state: Arc<RwLock<ConnectionState>>,
    connection_id: Arc<RwLock<Option<String>>>,
    session_id: Arc<RwLock<Option<String>>>,
    is_streaming: Arc<RwLock<bool>>,
    audio_encoder: AudioEncoder,

    // Message sender to the WebSocket task
    ws_sender: Arc<Mutex<Option<mpsc::Sender<ClientMessage>>>>,

    // Event handlers
    response_handler: Arc<RwLock<Option<ResponseHandler>>>,
    audio_handler: Arc<RwLock<Option<AudioHandler>>>,
    error_handler: Arc<RwLock<Option<ErrorHandler>>>,

    // Broadcast channel for external subscribers (allows multiple receivers)
    event_sender: broadcast::Sender<LiveSpeechEvent>,
}

impl LiveSpeechClient {
    /// Create a new LiveSpeech client
    pub fn new(config: Config) -> Self {
        let (event_sender, _) = broadcast::channel(100);

        Self {
            config,
            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
            connection_id: Arc::new(RwLock::new(None)),
            session_id: Arc::new(RwLock::new(None)),
            is_streaming: Arc::new(RwLock::new(false)),
            audio_encoder: AudioEncoder::new(),
            ws_sender: Arc::new(Mutex::new(None)),
            response_handler: Arc::new(RwLock::new(None)),
            audio_handler: Arc::new(RwLock::new(None)),
            error_handler: Arc::new(RwLock::new(None)),
            event_sender,
        }
    }

    /// Get current connection state
    pub async fn connection_state(&self) -> ConnectionState {
        *self.state.read().await
    }

    /// Get connection ID
    pub async fn connection_id(&self) -> Option<String> {
        self.connection_id.read().await.clone()
    }

    /// Get current session ID
    pub async fn session_id(&self) -> Option<String> {
        self.session_id.read().await.clone()
    }

    /// Check if connected
    pub async fn is_connected(&self) -> bool {
        *self.state.read().await == ConnectionState::Connected
    }

    /// Check if session is active
    pub async fn has_active_session(&self) -> bool {
        self.session_id.read().await.is_some()
    }

    /// Check if audio streaming is active
    pub async fn is_audio_streaming(&self) -> bool {
        *self.is_streaming.read().await
    }

    /// Set response handler (AI's text response)
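    ///
    /// A minimal registration sketch (the closure body is illustrative):
    ///
    /// ```ignore
    /// client.on_response(|text, is_final| {
    ///     // Print partial transcripts as they stream in; mark the final one.
    ///     println!("AI{}: {}", if is_final { " (final)" } else { "" }, text);
    /// }).await;
    /// ```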
    pub async fn on_response<F>(&self, handler: F)
    where
        F: Fn(&str, bool) + Send + Sync + 'static,
    {
        *self.response_handler.write().await = Some(Box::new(handler));
    }

    /// Set audio handler (AI's audio response)
    pub async fn on_audio<F>(&self, handler: F)
    where
        F: Fn(&[u8]) + Send + Sync + 'static,
    {
        *self.audio_handler.write().await = Some(Box::new(handler));
    }

    /// Set error handler
    pub async fn on_error<F>(&self, handler: F)
    where
        F: Fn(&ErrorEvent) + Send + Sync + 'static,
    {
        *self.error_handler.write().await = Some(Box::new(handler));
    }

    /// Subscribe to all events (recommended for full control)
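    ///
    /// Each call returns an independent `broadcast::Receiver`. A typical
    /// consumer loop looks like this (sketch; the event handling is illustrative):
    ///
    /// ```ignore
    /// let mut events = client.subscribe();
    /// tokio::spawn(async move {
    ///     while let Ok(event) = events.recv().await {
    ///         match event {
    ///             LiveSpeechEvent::Response(r) => println!("AI: {}", r.text),
    ///             LiveSpeechEvent::TurnComplete(_) => println!("-- turn complete --"),
    ///             _ => {}
    ///         }
    ///     }
    /// });
    /// ```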
    pub fn subscribe(&self) -> broadcast::Receiver<LiveSpeechEvent> {
        self.event_sender.subscribe()
    }

    /// Connect to the server
    pub async fn connect(&self) -> Result<()> {
        let current_state = *self.state.read().await;
        if current_state == ConnectionState::Connected || current_state == ConnectionState::Connecting {
            warn!("Already connected or connecting");
            return Ok(());
        }

        *self.state.write().await = ConnectionState::Connecting;

        // Parse and modify URL to include API key and optional user ID
        let mut url = Url::parse(&self.config.endpoint)?;
        url.query_pairs_mut()
            .append_pair("apiKey", &self.config.api_key);

        // Add userId if provided (for conversation memory persistence)
        if let Some(ref user_id) = self.config.user_id {
            url.query_pairs_mut().append_pair("userId", user_id);
        }

        info!("Connecting to {}", url.host_str().unwrap_or("unknown"));

        // Connect with timeout
        let connect_future = connect_async(url.as_str());
        let (ws_stream, _response) = tokio::time::timeout(
            self.config.connection_timeout,
            connect_future,
        )
        .await
        .map_err(|_| LiveSpeechError::ConnectionTimeout)?
        .map_err(LiveSpeechError::WebSocket)?;

        info!("WebSocket connected");

        *self.state.write().await = ConnectionState::Connected;
        let conn_id = generate_connection_id();
        *self.connection_id.write().await = Some(conn_id.clone());

        // Emit connected event
        let timestamp = chrono::Utc::now().to_rfc3339();
        let _ = self.event_sender.send(LiveSpeechEvent::Connected(ConnectedEvent {
            connection_id: conn_id,
            timestamp,
        }));

        let (write, read) = ws_stream.split();
        let write = Arc::new(Mutex::new(write));

        // Create channel for sending messages
        let (msg_sender, mut msg_receiver) = mpsc::channel::<ClientMessage>(100);
        *self.ws_sender.lock().await = Some(msg_sender);

        // Clone references for the tasks
        let state = self.state.clone();
        let session_id = self.session_id.clone();
        let is_streaming = self.is_streaming.clone();
        let response_handler = self.response_handler.clone();
        let audio_handler = self.audio_handler.clone();
        let error_handler = self.error_handler.clone();
        let event_sender = self.event_sender.clone();
        let audio_encoder = self.audio_encoder.clone();

        // Spawn write task
        let write_clone = write.clone();
        tokio::spawn(async move {
            while let Some(msg) = msg_receiver.recv().await {
                if let Ok(json) = msg.to_json() {
                    debug!("Sending message: {:?}", msg);
                    let mut writer = write_clone.lock().await;
                    if let Err(e) = writer.send(Message::Text(json)).await {
                        error!("Failed to send message: {}", e);
                        break;
                    }
                }
            }
        });

        // Spawn read task
        tokio::spawn(async move {
            let mut read = read;
            while let Some(result) = read.next().await {
                match result {
                    Ok(Message::Text(text)) => {
                        debug!("Received message: {}", text);
                        match ServerMessage::from_json(&text) {
                            Ok(msg) => {
                                Self::handle_message(
                                    msg,
                                    &state,
                                    &session_id,
                                    &is_streaming,
                                    &response_handler,
                                    &audio_handler,
                                    &error_handler,
                                    &event_sender,
                                    &audio_encoder,
                                )
                                .await;
                            }
                            Err(e) => {
                                warn!("Failed to parse message: {}", e);
                            }
                        }
                    }
                    Ok(Message::Close(_)) => {
                        info!("WebSocket closed by server");
                        *state.write().await = ConnectionState::Disconnected;
                        break;
                    }
                    Err(e) => {
                        error!("WebSocket error: {}", e);
                        *state.write().await = ConnectionState::Disconnected;
                        break;
                    }
                    _ => {}
                }
            }
        });

        // Start ping task
        let ws_sender = self.ws_sender.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
            loop {
                interval.tick().await;
                if let Some(sender) = ws_sender.lock().await.as_ref() {
                    if sender.send(ClientMessage::Ping).await.is_err() {
                        break;
                    }
                } else {
                    break;
                }
            }
        });

        Ok(())
    }

    /// Disconnect from the server
    pub async fn disconnect(&self) {
        info!("Disconnecting");
        *self.ws_sender.lock().await = None;
        *self.state.write().await = ConnectionState::Disconnected;
        *self.connection_id.write().await = None;
        *self.session_id.write().await = None;
        *self.is_streaming.write().await = false;
    }

    /// Start a new session
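    ///
    /// Resolves with the new session ID once the server confirms it. A minimal
    /// sketch (passing `None` uses the server-side defaults; a `SessionConfig`
    /// can be supplied instead for pre-prompt, language, pipeline mode, etc.):
    ///
    /// ```ignore
    /// let session_id = client.start_session(None).await?;
    /// println!("session started: {}", session_id);
    /// ```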
    pub async fn start_session(&self, config: Option<SessionConfig>) -> Result<String> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if self.session_id.read().await.is_some() {
            return Err(LiveSpeechError::SessionAlreadyActive);
        }

        let (pre_prompt, language, pipeline_mode, ai_speaks_first, allow_harm_category, tools) = config
            .map(|c| (
                c.pre_prompt,
                c.language,
                Some(c.pipeline_mode.as_str().to_string()),
                if c.ai_speaks_first { Some(true) } else { None },
                // allow_harm_category defaults to false; only send Some(true) if explicitly enabled
                if c.allow_harm_category { Some(true) } else { None },
                c.tools,
            ))
            .unwrap_or((None, None, None, None, None, None));
        let msg = ClientMessage::start_session(pre_prompt, language, pipeline_mode, ai_speaks_first, allow_harm_category, tools);

        // Subscribe to events before sending message
        let mut events = self.event_sender.subscribe();

        self.send_message(msg).await?;

        // Wait for SessionStarted or Error event with timeout
        let timeout_duration = self.config.connection_timeout;
        let result = tokio::time::timeout(timeout_duration, async {
            while let Ok(event) = events.recv().await {
                match event {
                    LiveSpeechEvent::SessionStarted(e) => {
                        return Ok(e.session_id);
                    }
                    LiveSpeechEvent::Error(e) if matches!(e.code, ErrorCode::SessionError) => {
                        return Err(LiveSpeechError::SessionError(e.message));
                    }
                    _ => continue,
                }
            }
            Err(LiveSpeechError::ChannelReceive)
        })
        .await
        .map_err(|_| LiveSpeechError::ConnectionTimeout)?;

        result
    }

    /// End the current session
    pub async fn end_session(&self) -> Result<()> {
        if self.session_id.read().await.is_none() {
            warn!("No active session to end");
            return Ok(());
        }

        // Stop streaming if active
        if *self.is_streaming.read().await {
            self.audio_end().await?;
        }

        // Subscribe to events before sending message
        let mut events = self.event_sender.subscribe();

        self.send_message(ClientMessage::end_session()).await?;

        // Wait for SessionEnded event with timeout
        let timeout_duration = self.config.connection_timeout;
        tokio::time::timeout(timeout_duration, async {
            while let Ok(event) = events.recv().await {
                if matches!(event, LiveSpeechEvent::SessionEnded(_)) {
                    return;
                }
            }
        })
        .await
        .map_err(|_| LiveSpeechError::ConnectionTimeout)?;

        Ok(())
    }

    /// Start audio streaming session
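    ///
    /// Returns once the server reports it is ready for audio input. A typical
    /// streaming loop looks like this (sketch; `next_pcm16_frame()` stands in
    /// for whatever capture source the application uses):
    ///
    /// ```ignore
    /// client.audio_start().await?;
    /// while let Some(frame) = next_pcm16_frame().await {
    ///     client.send_audio_chunk(&frame).await?;
    /// }
    /// client.audio_end().await?;
    /// ```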
    pub async fn audio_start(&self) -> Result<()> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if self.session_id.read().await.is_none() {
            return Err(LiveSpeechError::NoActiveSession);
        }

        if *self.is_streaming.read().await {
            return Err(LiveSpeechError::AlreadyStreaming);
        }

        // Subscribe to events before sending message to ensure we don't miss Ready
        let mut events = self.event_sender.subscribe();

        self.send_message(ClientMessage::audio_start()).await?;

        // Wait for Ready or Error event with timeout
        // This ensures the Gemini Live session is fully established before returning
        let timeout_duration = self.config.connection_timeout;
        let result = tokio::time::timeout(timeout_duration, async {
            while let Ok(event) = events.recv().await {
                match event {
                    LiveSpeechEvent::Ready(_) => {
                        return Ok(());
                    }
                    LiveSpeechEvent::Error(e) if matches!(e.code, ErrorCode::StreamingError) => {
                        return Err(LiveSpeechError::SessionError(e.message));
                    }
                    _ => continue,
                }
            }
            Err(LiveSpeechError::ChannelReceive)
        })
        .await
        .map_err(|_| LiveSpeechError::ConnectionTimeout)?;

        result?;

        *self.is_streaming.write().await = true;
        info!("LiveSpeech audio stream started");
        Ok(())
    }

    /// Send audio chunk (PCM16 bytes)
    pub async fn send_audio_chunk(&self, data: &[u8]) -> Result<()> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if !*self.is_streaming.read().await {
            return Err(LiveSpeechError::NotStreaming);
        }

        let base64_data = self.audio_encoder.encode(data);
        self.send_message(ClientMessage::audio_chunk(base64_data)).await
    }

    /// End audio streaming session
    pub async fn audio_end(&self) -> Result<()> {
        if !*self.is_streaming.read().await {
            warn!("Not streaming");
            return Ok(());
        }

        self.send_message(ClientMessage::audio_end()).await?;
        *self.is_streaming.write().await = false;
        Ok(())
    }

    /// Send a system message to the AI (AI responds immediately)
    ///
    /// This allows your application to inject context or trigger AI responses
    /// based on external events (e.g., game events, app state changes, timers).
    ///
    /// # Arguments
    /// * `text` - The system message text (max 500 characters)
    ///
    /// # Errors
    /// Returns an error if not connected or not streaming.
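    ///
    /// # Example
    /// The message text below is illustrative:
    /// ```ignore
    /// // Nudge the AI when something happens outside the conversation.
    /// client.send_system_message("The user just reached level 5 - congratulate them").await?;
    /// ```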
    pub async fn send_system_message(&self, text: &str) -> Result<()> {
        self.send_system_message_with_options(text, true).await
    }

    /// Send a system message to the AI with explicit trigger_response option
    ///
    /// # Arguments
    /// * `text` - The system message text (max 500 characters)
    /// * `trigger_response` - If true, AI responds immediately; if false, context only
    ///
    /// # Errors
    /// Returns an error if not connected, not streaming, or text exceeds 500 chars.
    pub async fn send_system_message_with_options(&self, text: &str, trigger_response: bool) -> Result<()> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if !*self.is_streaming.read().await {
            return Err(LiveSpeechError::NotStreaming);
        }

        if text.len() > 500 {
            return Err(LiveSpeechError::InvalidParameter("System message too long (max 500 characters)".to_string()));
        }

        info!("Sending system message: {} (trigger_response: {})", text, trigger_response);
        self.send_message(ClientMessage::system_message_with_options(text, trigger_response)).await
    }

    /// Send a tool response (function execution result) back to the AI
    ///
    /// After receiving a `ToolCall` event, execute the function and send
    /// the result back using this method. The AI will then continue the
    /// conversation based on the result.
    ///
    /// # Arguments
    /// * `id` - The tool call ID from the `ToolCallEvent`
    /// * `response` - The function execution result (will be JSON serialized)
    ///
    /// # Example
    /// ```ignore
    /// client.on_tool_call(|event| {
    ///     let result = execute_function(&event.name, &event.args);
    ///     client.send_tool_response(&event.id, result).await?;
    /// });
    /// ```
    ///
    /// # Errors
    /// Returns an error if not connected or not streaming.
    pub async fn send_tool_response(&self, id: &str, response: serde_json::Value) -> Result<()> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if !*self.is_streaming.read().await {
            return Err(LiveSpeechError::NotStreaming);
        }

        info!("Sending tool response for id: {}", id);
        self.send_message(ClientMessage::tool_response(id, response)).await
    }

    /// Explicitly interrupt the current AI response
    ///
    /// Use this method for:
    /// - UI "Stop" button functionality
    /// - Programmatic control to stop AI mid-response
    ///
    /// Note: In most cases, simply speaking will trigger automatic
    /// interruption via Gemini's voice activity detection (VAD).
    /// This method is for explicit programmatic control.
    ///
    /// # Example
    /// ```ignore
    /// // User clicks "Stop" button
    /// client.interrupt().await?;
    /// ```
    ///
    /// # Errors
    /// Returns an error if not connected or not streaming.
    pub async fn interrupt(&self) -> Result<()> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if !*self.is_streaming.read().await {
            return Err(LiveSpeechError::NotStreaming);
        }

        info!("Sending explicit interrupt");
        self.send_message(ClientMessage::interrupt()).await
    }

    /// Update user ID for the current connection (guest-to-user migration)
    ///
    /// When a guest user logs in during a session, call this method to migrate
    /// their conversation history to a persistent user partition. This enables:
    /// - Entity extraction and long-term memory for the user
    /// - Conversation continuity across sessions
    /// - Personalization based on past interactions
    ///
    /// # Arguments
    /// * `user_id` - The authenticated user's unique identifier
    ///
    /// # Example
    /// ```ignore
    /// // After successful login
    /// let user = authenticate(credentials).await?;
    /// client.update_user_id(&user.id).await?;
    /// ```
    ///
    /// # Errors
    /// Returns an error if not connected or user_id is empty.
    pub async fn update_user_id(&self, user_id: &str) -> Result<()> {
        if !self.is_connected().await {
            return Err(LiveSpeechError::NotConnected);
        }

        if user_id.trim().is_empty() {
            return Err(LiveSpeechError::InvalidParameter("userId cannot be empty".to_string()));
        }

        info!("Updating user ID: {}", user_id);
        self.send_message(ClientMessage::update_user_id(user_id)).await
    }

    /// Send a client message
    async fn send_message(&self, msg: ClientMessage) -> Result<()> {
        let sender = self.ws_sender.lock().await;
        if let Some(sender) = sender.as_ref() {
            sender
                .send(msg)
                .await
                .map_err(|_| LiveSpeechError::ChannelSend)?;
            Ok(())
        } else {
            Err(LiveSpeechError::NotConnected)
        }
    }

    /// Handle incoming server message
    async fn handle_message(
        msg: ServerMessage,
        state: &Arc<RwLock<ConnectionState>>,
        session_id: &Arc<RwLock<Option<String>>>,
        is_streaming: &Arc<RwLock<bool>>,
        response_handler: &Arc<RwLock<Option<ResponseHandler>>>,
        audio_handler: &Arc<RwLock<Option<AudioHandler>>>,
        error_handler: &Arc<RwLock<Option<ErrorHandler>>>,
        event_sender: &broadcast::Sender<LiveSpeechEvent>,
        audio_encoder: &AudioEncoder,
    ) {
        let timestamp = chrono::Utc::now().to_rfc3339();

        match msg {
            ServerMessage::SessionStarted { session_id: sess_id, .. } => {
                *session_id.write().await = Some(sess_id.clone());
                let _ = event_sender.send(LiveSpeechEvent::SessionStarted(SessionStartedEvent {
                    session_id: sess_id,
                    timestamp,
                }));
            }

            ServerMessage::SessionEnded { session_id: sess_id, .. } => {
                *session_id.write().await = None;
                *is_streaming.write().await = false;
                let _ = event_sender.send(LiveSpeechEvent::SessionEnded(SessionEndedEvent {
                    session_id: sess_id,
                    timestamp,
                }));
            }

            ServerMessage::Ready { .. } => {
                info!("Session ready for audio input");
                let _ = event_sender.send(LiveSpeechEvent::Ready(ReadyEvent { timestamp }));
            }

            ServerMessage::UserTranscript { text, .. } => {
                info!("User transcript: {}", text);
                let _ = event_sender.send(LiveSpeechEvent::UserTranscript(UserTranscriptEvent {
                    text,
                    timestamp,
                }));
            }

            ServerMessage::Response { text, is_final, .. } => {
                if let Some(handler) = response_handler.read().await.as_ref() {
                    handler(&text, is_final);
                }
                let _ = event_sender.send(LiveSpeechEvent::Response(ResponseEvent {
                    text,
                    is_final,
                    timestamp,
                }));
            }

            ServerMessage::Audio { data, format, sample_rate, .. } => {
                if let Ok(audio_data) = audio_encoder.decode(&data) {
                    if let Some(handler) = audio_handler.read().await.as_ref() {
                        handler(&audio_data);
                    }
                    let _ = event_sender.send(LiveSpeechEvent::Audio(AudioEvent {
                        data: audio_data,
                        format,
                        sample_rate,
                        timestamp,
                    }));
                }
            }

            ServerMessage::TurnComplete { .. } => {
                info!("Turn complete");
                let _ = event_sender.send(LiveSpeechEvent::TurnComplete(TurnCompleteEvent { timestamp }));
            }

            ServerMessage::ToolCall { id, name, args, .. } => {
                info!("Tool call received: {} (id: {})", name, id);
                let _ = event_sender.send(LiveSpeechEvent::ToolCall(ToolCallEvent {
                    id,
                    name,
                    args,
                    timestamp,
                }));
            }

            ServerMessage::UserIdUpdated { user_id, migrated_messages, .. } => {
                info!("User ID updated: {}, migrated {} messages", user_id, migrated_messages);
                let _ = event_sender.send(LiveSpeechEvent::UserIdUpdated(UserIdUpdatedEvent {
                    user_id,
                    migrated_messages,
                    timestamp,
                }));
            }

            ServerMessage::Interrupted { .. } => {
                info!("AI response interrupted (barge-in)");
                let _ = event_sender.send(LiveSpeechEvent::Interrupted(InterruptedEvent {
                    timestamp,
                }));
            }

            ServerMessage::Error { code, message, .. } => {
                let error_code = match code.as_str() {
                    "connection_failed" => ErrorCode::ConnectionFailed,
                    "authentication_failed" => ErrorCode::AuthenticationFailed,
                    "session_error" => ErrorCode::SessionError,
                    "audio_error" => ErrorCode::AudioError,
                    "streaming_error" => ErrorCode::StreamingError,
                    "stt_error" => ErrorCode::SttError,
                    "llm_error" => ErrorCode::LlmError,
                    "tts_error" => ErrorCode::TtsError,
                    "rate_limit" => ErrorCode::RateLimit,
                    "user_id_update_error" => ErrorCode::UserIdUpdateError,
                    _ => ErrorCode::InternalError,
                };

                let error_event = ErrorEvent {
                    code: error_code,
                    message: message.clone(),
                    details: None,
                    timestamp: timestamp.clone(),
                };

                if let Some(handler) = error_handler.read().await.as_ref() {
                    handler(&error_event);
                }
                let _ = event_sender.send(LiveSpeechEvent::Error(error_event));
            }

            ServerMessage::Pong { .. } => {
                debug!("Pong received");
            }
        }

        // Suppress unused warning
        let _ = state;
    }
}

/// Generate a client-side connection ID
fn generate_connection_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    format!("client_{}_{:x}", timestamp, rand::random::<u32>())
}