1use crate::audio::AudioEncoder;
4use crate::error::{LiveSpeechError, Result};
5use crate::types::*;
6
7use futures_util::{SinkExt, StreamExt};
8use std::sync::Arc;
9use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use tracing::{debug, error, info, warn};
12use url::Url;
13
/// Lifecycle state of the WebSocket connection held by a `LiveSpeechClient`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Initial state; also the state after `disconnect()` or after the socket closes.
    Disconnected,
    /// A `connect()` call is performing the WebSocket handshake.
    Connecting,
    /// The handshake succeeded and the background I/O tasks are running.
    Connected,
    /// NOTE(review): no code visible alongside this enum ever assigns this
    /// variant — confirm whether reconnect logic exists elsewhere.
    Reconnecting,
}
22
23pub struct LiveSpeechClient {
25 config: Config,
26 state: Arc<RwLock<ConnectionState>>,
27 connection_id: Arc<RwLock<Option<String>>>,
28 session_id: Arc<RwLock<Option<String>>>,
29 is_streaming: Arc<RwLock<bool>>,
30 audio_encoder: AudioEncoder,
31
32 ws_sender: Arc<Mutex<Option<mpsc::Sender<ClientMessage>>>>,
34
35 response_handler: Arc<RwLock<Option<ResponseHandler>>>,
37 audio_handler: Arc<RwLock<Option<AudioHandler>>>,
38 error_handler: Arc<RwLock<Option<ErrorHandler>>>,
39
40 event_sender: broadcast::Sender<LiveSpeechEvent>,
42}
43
44impl LiveSpeechClient {
45 pub fn new(config: Config) -> Self {
47 let (event_sender, _) = broadcast::channel(100);
48
49 Self {
50 config,
51 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
52 connection_id: Arc::new(RwLock::new(None)),
53 session_id: Arc::new(RwLock::new(None)),
54 is_streaming: Arc::new(RwLock::new(false)),
55 audio_encoder: AudioEncoder::new(),
56 ws_sender: Arc::new(Mutex::new(None)),
57 response_handler: Arc::new(RwLock::new(None)),
58 audio_handler: Arc::new(RwLock::new(None)),
59 error_handler: Arc::new(RwLock::new(None)),
60 event_sender,
61 }
62 }
63
64 pub async fn connection_state(&self) -> ConnectionState {
66 *self.state.read().await
67 }
68
69 pub async fn connection_id(&self) -> Option<String> {
71 self.connection_id.read().await.clone()
72 }
73
74 pub async fn session_id(&self) -> Option<String> {
76 self.session_id.read().await.clone()
77 }
78
79 pub async fn is_connected(&self) -> bool {
81 *self.state.read().await == ConnectionState::Connected
82 }
83
84 pub async fn has_active_session(&self) -> bool {
86 self.session_id.read().await.is_some()
87 }
88
89 pub async fn is_audio_streaming(&self) -> bool {
91 *self.is_streaming.read().await
92 }
93
94 pub async fn on_response<F>(&self, handler: F)
96 where
97 F: Fn(&str, bool) + Send + Sync + 'static,
98 {
99 *self.response_handler.write().await = Some(Box::new(handler));
100 }
101
102 pub async fn on_audio<F>(&self, handler: F)
104 where
105 F: Fn(&[u8]) + Send + Sync + 'static,
106 {
107 *self.audio_handler.write().await = Some(Box::new(handler));
108 }
109
110 pub async fn on_error<F>(&self, handler: F)
112 where
113 F: Fn(&ErrorEvent) + Send + Sync + 'static,
114 {
115 *self.error_handler.write().await = Some(Box::new(handler));
116 }
117
118 pub fn subscribe(&self) -> broadcast::Receiver<LiveSpeechEvent> {
120 self.event_sender.subscribe()
121 }
122
123 pub async fn connect(&self) -> Result<()> {
125 let current_state = *self.state.read().await;
126 if current_state == ConnectionState::Connected || current_state == ConnectionState::Connecting {
127 warn!("Already connected or connecting");
128 return Ok(());
129 }
130
131 *self.state.write().await = ConnectionState::Connecting;
132
133 let mut url = Url::parse(&self.config.endpoint)?;
135 url.query_pairs_mut()
136 .append_pair("apiKey", &self.config.api_key);
137
138 if let Some(ref user_id) = self.config.user_id {
140 url.query_pairs_mut().append_pair("userId", user_id);
141 }
142
143 info!("Connecting to {}", url.host_str().unwrap_or("unknown"));
144
145 let connect_future = connect_async(url.as_str());
147 let (ws_stream, _response) = tokio::time::timeout(
148 self.config.connection_timeout,
149 connect_future,
150 )
151 .await
152 .map_err(|_| LiveSpeechError::ConnectionTimeout)?
153 .map_err(LiveSpeechError::WebSocket)?;
154
155 info!("WebSocket connected");
156
157 *self.state.write().await = ConnectionState::Connected;
158 let conn_id = generate_connection_id();
159 *self.connection_id.write().await = Some(conn_id.clone());
160
161 let timestamp = chrono::Utc::now().to_rfc3339();
163 let _ = self.event_sender.send(LiveSpeechEvent::Connected(ConnectedEvent {
164 connection_id: conn_id,
165 timestamp,
166 }));
167
168 let (write, read) = ws_stream.split();
169 let write = Arc::new(Mutex::new(write));
170
171 let (msg_sender, mut msg_receiver) = mpsc::channel::<ClientMessage>(100);
173 *self.ws_sender.lock().await = Some(msg_sender);
174
175 let state = self.state.clone();
177 let session_id = self.session_id.clone();
178 let is_streaming = self.is_streaming.clone();
179 let response_handler = self.response_handler.clone();
180 let audio_handler = self.audio_handler.clone();
181 let error_handler = self.error_handler.clone();
182 let event_sender = self.event_sender.clone();
183 let audio_encoder = self.audio_encoder.clone();
184
185 let write_clone = write.clone();
187 tokio::spawn(async move {
188 while let Some(msg) = msg_receiver.recv().await {
189 if let Ok(json) = msg.to_json() {
190 debug!("Sending message: {:?}", msg);
191 let mut writer = write_clone.lock().await;
192 if let Err(e) = writer.send(Message::Text(json)).await {
193 error!("Failed to send message: {}", e);
194 break;
195 }
196 }
197 }
198 });
199
200 tokio::spawn(async move {
202 let mut read = read;
203 while let Some(result) = read.next().await {
204 match result {
205 Ok(Message::Text(text)) => {
206 debug!("Received message: {}", text);
207 match ServerMessage::from_json(&text) {
208 Ok(msg) => {
209 Self::handle_message(
210 msg,
211 &state,
212 &session_id,
213 &is_streaming,
214 &response_handler,
215 &audio_handler,
216 &error_handler,
217 &event_sender,
218 &audio_encoder,
219 )
220 .await;
221 }
222 Err(e) => {
223 warn!("Failed to parse message: {}", e);
224 }
225 }
226 }
227 Ok(Message::Close(_)) => {
228 info!("WebSocket closed by server");
229 *state.write().await = ConnectionState::Disconnected;
230 break;
231 }
232 Err(e) => {
233 error!("WebSocket error: {}", e);
234 *state.write().await = ConnectionState::Disconnected;
235 break;
236 }
237 _ => {}
238 }
239 }
240 });
241
242 let ws_sender = self.ws_sender.clone();
244 tokio::spawn(async move {
245 let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
246 loop {
247 interval.tick().await;
248 if let Some(sender) = ws_sender.lock().await.as_ref() {
249 if sender.send(ClientMessage::Ping).await.is_err() {
250 break;
251 }
252 } else {
253 break;
254 }
255 }
256 });
257
258 Ok(())
259 }
260
261 pub async fn disconnect(&self) {
263 info!("Disconnecting");
264 *self.ws_sender.lock().await = None;
265 *self.state.write().await = ConnectionState::Disconnected;
266 *self.connection_id.write().await = None;
267 *self.session_id.write().await = None;
268 *self.is_streaming.write().await = false;
269 }
270
271 pub async fn start_session(&self, config: Option<SessionConfig>) -> Result<String> {
273 if !self.is_connected().await {
274 return Err(LiveSpeechError::NotConnected);
275 }
276
277 if self.session_id.read().await.is_some() {
278 return Err(LiveSpeechError::SessionAlreadyActive);
279 }
280
281 let (pre_prompt, language, pipeline_mode, ai_speaks_first, allow_harm_category, tools) = config
282 .map(|c| (
283 c.pre_prompt,
284 c.language,
285 Some(c.pipeline_mode.as_str().to_string()),
286 if c.ai_speaks_first { Some(true) } else { None },
287 if c.allow_harm_category { Some(true) } else { None },
289 c.tools,
290 ))
291 .unwrap_or((None, None, None, None, None, None));
292 let msg = ClientMessage::start_session(pre_prompt, language, pipeline_mode, ai_speaks_first, allow_harm_category, tools);
293
294 let mut events = self.event_sender.subscribe();
296
297 self.send_message(msg).await?;
298
299 let timeout_duration = self.config.connection_timeout;
301 let result = tokio::time::timeout(timeout_duration, async {
302 while let Ok(event) = events.recv().await {
303 match event {
304 LiveSpeechEvent::SessionStarted(e) => {
305 return Ok(e.session_id);
306 }
307 LiveSpeechEvent::Error(e) if matches!(e.code, ErrorCode::SessionError) => {
308 return Err(LiveSpeechError::SessionError(e.message));
309 }
310 _ => continue,
311 }
312 }
313 Err(LiveSpeechError::ChannelReceive)
314 })
315 .await
316 .map_err(|_| LiveSpeechError::ConnectionTimeout)?;
317
318 result
319 }
320
321 pub async fn end_session(&self) -> Result<()> {
323 if self.session_id.read().await.is_none() {
324 warn!("No active session to end");
325 return Ok(());
326 }
327
328 if *self.is_streaming.read().await {
330 self.audio_end().await?;
331 }
332
333 let mut events = self.event_sender.subscribe();
335
336 self.send_message(ClientMessage::end_session()).await?;
337
338 let timeout_duration = self.config.connection_timeout;
340 tokio::time::timeout(timeout_duration, async {
341 while let Ok(event) = events.recv().await {
342 if matches!(event, LiveSpeechEvent::SessionEnded(_)) {
343 return;
344 }
345 }
346 })
347 .await
348 .map_err(|_| LiveSpeechError::ConnectionTimeout)?;
349
350 Ok(())
351 }
352
353 pub async fn audio_start(&self) -> Result<()> {
355 if !self.is_connected().await {
356 return Err(LiveSpeechError::NotConnected);
357 }
358
359 if self.session_id.read().await.is_none() {
360 return Err(LiveSpeechError::NoActiveSession);
361 }
362
363 if *self.is_streaming.read().await {
364 return Err(LiveSpeechError::AlreadyStreaming);
365 }
366
367 let mut events = self.event_sender.subscribe();
369
370 self.send_message(ClientMessage::audio_start()).await?;
371
372 let timeout_duration = self.config.connection_timeout;
375 let result = tokio::time::timeout(timeout_duration, async {
376 while let Ok(event) = events.recv().await {
377 match event {
378 LiveSpeechEvent::Ready(_) => {
379 return Ok(());
380 }
381 LiveSpeechEvent::Error(e) if matches!(e.code, ErrorCode::StreamingError) => {
382 return Err(LiveSpeechError::SessionError(e.message));
383 }
384 _ => continue,
385 }
386 }
387 Err(LiveSpeechError::ChannelReceive)
388 })
389 .await
390 .map_err(|_| LiveSpeechError::ConnectionTimeout)?;
391
392 result?;
393
394 *self.is_streaming.write().await = true;
395 info!("LiveSpeech audio stream started");
396 Ok(())
397 }
398
399 pub async fn send_audio_chunk(&self, data: &[u8]) -> Result<()> {
401 if !self.is_connected().await {
402 return Err(LiveSpeechError::NotConnected);
403 }
404
405 if !*self.is_streaming.read().await {
406 return Err(LiveSpeechError::NotStreaming);
407 }
408
409 let base64_data = self.audio_encoder.encode(data);
410 self.send_message(ClientMessage::audio_chunk(base64_data)).await
411 }
412
413 pub async fn audio_end(&self) -> Result<()> {
415 if !*self.is_streaming.read().await {
416 warn!("Not streaming");
417 return Ok(());
418 }
419
420 self.send_message(ClientMessage::audio_end()).await?;
421 *self.is_streaming.write().await = false;
422 Ok(())
423 }
424
425 pub async fn send_system_message(&self, text: &str) -> Result<()> {
436 self.send_system_message_with_options(text, true).await
437 }
438
439 pub async fn send_system_message_with_options(&self, text: &str, trigger_response: bool) -> Result<()> {
448 if !self.is_connected().await {
449 return Err(LiveSpeechError::NotConnected);
450 }
451
452 if !*self.is_streaming.read().await {
453 return Err(LiveSpeechError::NotStreaming);
454 }
455
456 if text.len() > 500 {
457 return Err(LiveSpeechError::InvalidParameter("System message too long (max 500 characters)".to_string()));
458 }
459
460 info!("Sending system message: {} (trigger_response: {})", text, trigger_response);
461 self.send_message(ClientMessage::system_message_with_options(text, trigger_response)).await
462 }
463
464 pub async fn send_tool_response(&self, id: &str, response: serde_json::Value) -> Result<()> {
485 if !self.is_connected().await {
486 return Err(LiveSpeechError::NotConnected);
487 }
488
489 if !*self.is_streaming.read().await {
490 return Err(LiveSpeechError::NotStreaming);
491 }
492
493 info!("Sending tool response for id: {}", id);
494 self.send_message(ClientMessage::tool_response(id, response)).await
495 }
496
497 pub async fn interrupt(&self) -> Result<()> {
516 if !self.is_connected().await {
517 return Err(LiveSpeechError::NotConnected);
518 }
519
520 if !*self.is_streaming.read().await {
521 return Err(LiveSpeechError::NotStreaming);
522 }
523
524 info!("Sending explicit interrupt");
525 self.send_message(ClientMessage::interrupt()).await
526 }
527
528 pub async fn update_user_id(&self, user_id: &str) -> Result<()> {
549 if !self.is_connected().await {
550 return Err(LiveSpeechError::NotConnected);
551 }
552
553 if user_id.trim().is_empty() {
554 return Err(LiveSpeechError::InvalidParameter("userId cannot be empty".to_string()));
555 }
556
557 info!("Updating user ID: {}", user_id);
558 self.send_message(ClientMessage::update_user_id(user_id)).await
559 }
560
561 async fn send_message(&self, msg: ClientMessage) -> Result<()> {
563 let sender = self.ws_sender.lock().await;
564 if let Some(sender) = sender.as_ref() {
565 sender
566 .send(msg)
567 .await
568 .map_err(|_| LiveSpeechError::ChannelSend)?;
569 Ok(())
570 } else {
571 Err(LiveSpeechError::NotConnected)
572 }
573 }
574
575 async fn handle_message(
577 msg: ServerMessage,
578 state: &Arc<RwLock<ConnectionState>>,
579 session_id: &Arc<RwLock<Option<String>>>,
580 is_streaming: &Arc<RwLock<bool>>,
581 response_handler: &Arc<RwLock<Option<ResponseHandler>>>,
582 audio_handler: &Arc<RwLock<Option<AudioHandler>>>,
583 error_handler: &Arc<RwLock<Option<ErrorHandler>>>,
584 event_sender: &broadcast::Sender<LiveSpeechEvent>,
585 audio_encoder: &AudioEncoder,
586 ) {
587 let timestamp = chrono::Utc::now().to_rfc3339();
588
589 match msg {
590 ServerMessage::SessionStarted { session_id: sess_id, .. } => {
591 *session_id.write().await = Some(sess_id.clone());
592 let _ = event_sender.send(LiveSpeechEvent::SessionStarted(SessionStartedEvent {
593 session_id: sess_id,
594 timestamp,
595 }));
596 }
597
598 ServerMessage::SessionEnded { session_id: sess_id, .. } => {
599 *session_id.write().await = None;
600 *is_streaming.write().await = false;
601 let _ = event_sender.send(LiveSpeechEvent::SessionEnded(SessionEndedEvent {
602 session_id: sess_id,
603 timestamp,
604 }));
605 }
606
607 ServerMessage::Ready { .. } => {
608 info!("Session ready for audio input");
609 let _ = event_sender.send(LiveSpeechEvent::Ready(ReadyEvent { timestamp }));
610 }
611
612 ServerMessage::UserTranscript { text, .. } => {
613 info!("User transcript: {}", text);
614 let _ = event_sender.send(LiveSpeechEvent::UserTranscript(UserTranscriptEvent {
615 text,
616 timestamp,
617 }));
618 }
619
620 ServerMessage::Response { text, is_final, .. } => {
621 if let Some(handler) = response_handler.read().await.as_ref() {
622 handler(&text, is_final);
623 }
624 let _ = event_sender.send(LiveSpeechEvent::Response(ResponseEvent {
625 text,
626 is_final,
627 timestamp,
628 }));
629 }
630
631 ServerMessage::Audio { data, format, sample_rate, .. } => {
632 if let Ok(audio_data) = audio_encoder.decode(&data) {
633 if let Some(handler) = audio_handler.read().await.as_ref() {
634 handler(&audio_data);
635 }
636 let _ = event_sender.send(LiveSpeechEvent::Audio(AudioEvent {
637 data: audio_data,
638 format,
639 sample_rate,
640 timestamp,
641 }));
642 }
643 }
644
645 ServerMessage::TurnComplete { .. } => {
646 info!("Turn complete");
647 let _ = event_sender.send(LiveSpeechEvent::TurnComplete(TurnCompleteEvent { timestamp }));
648 }
649
650 ServerMessage::ToolCall { id, name, args, .. } => {
651 info!("Tool call received: {} (id: {})", name, id);
652 let _ = event_sender.send(LiveSpeechEvent::ToolCall(ToolCallEvent {
653 id,
654 name,
655 args,
656 timestamp,
657 }));
658 }
659
660 ServerMessage::UserIdUpdated { user_id, migrated_messages, .. } => {
661 info!("User ID updated: {}, migrated {} messages", user_id, migrated_messages);
662 let _ = event_sender.send(LiveSpeechEvent::UserIdUpdated(UserIdUpdatedEvent {
663 user_id,
664 migrated_messages,
665 timestamp,
666 }));
667 }
668
669 ServerMessage::Interrupted { .. } => {
670 info!("AI response interrupted (barge-in)");
671 let _ = event_sender.send(LiveSpeechEvent::Interrupted(InterruptedEvent {
672 timestamp,
673 }));
674 }
675
676 ServerMessage::Error { code, message, .. } => {
677 let error_code = match code.as_str() {
678 "connection_failed" => ErrorCode::ConnectionFailed,
679 "authentication_failed" => ErrorCode::AuthenticationFailed,
680 "session_error" => ErrorCode::SessionError,
681 "audio_error" => ErrorCode::AudioError,
682 "streaming_error" => ErrorCode::StreamingError,
683 "stt_error" => ErrorCode::SttError,
684 "llm_error" => ErrorCode::LlmError,
685 "tts_error" => ErrorCode::TtsError,
686 "rate_limit" => ErrorCode::RateLimit,
687 "user_id_update_error" => ErrorCode::UserIdUpdateError,
688 _ => ErrorCode::InternalError,
689 };
690
691 let error_event = ErrorEvent {
692 code: error_code,
693 message: message.clone(),
694 details: None,
695 timestamp: timestamp.clone(),
696 };
697
698 if let Some(handler) = error_handler.read().await.as_ref() {
699 handler(&error_event);
700 }
701 let _ = event_sender.send(LiveSpeechEvent::Error(error_event));
702 }
703
704 ServerMessage::Pong { .. } => {
705 debug!("Pong received");
706 }
707 }
708
709 let _ = state;
711 }
712}
713
714fn generate_connection_id() -> String {
716 use std::time::{SystemTime, UNIX_EPOCH};
717 let timestamp = SystemTime::now()
718 .duration_since(UNIX_EPOCH)
719 .unwrap_or_default()
720 .as_millis();
721 format!("client_{}_{:x}", timestamp, rand::random::<u32>())
722}