use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket};
use tokio::sync::mpsc;
use crate::frames::{AudioRawData, Frame, FrameDirection, FrameProcessor};
use crate::transport::{BaseTransport, TransportParams};
use crate::transport::output::OutputMessage;
/// Construction-time settings for `WebSocketTransport`.
#[derive(Debug, Clone)]
pub struct WebSocketParams {
/// Parameters handed to `BaseTransport::new` when the transport is created.
pub transport: TransportParams,
}
impl Default for WebSocketParams {
    /// Returns the stock WebSocket configuration.
    ///
    /// Starts from `TransportParams::default()` and overrides the audio-input fields:
    /// input enabled, sample rate `Some(16_000)`, one channel, passthrough on, and
    /// stream-on-start on. Every other field keeps its `TransportParams` default.
    fn default() -> Self {
        let transport = TransportParams {
            audio_in_enabled: true,
            audio_in_sample_rate: Some(16_000),
            audio_in_channels: 1,
            audio_in_passthrough: true,
            audio_in_stream_on_start: true,
            ..TransportParams::default()
        };
        Self { transport }
    }
}
/// Transport that pairs a shared `BaseTransport` with the receiving end of the outbound
/// message channel that `run_socket` drains into a WebSocket.
pub struct WebSocketTransport {
/// Shared base transport; `input()` and `output()` delegate to it.
base: Arc<BaseTransport>,
/// Receiver for outbound `OutputMessage`s (audio, text and interruption markers, despite
/// the field name). Kept in `Mutex<Option<..>>` so `run_socket`, which only has `&self`,
/// can take ownership of it exactly once.
audio_out_rx: std::sync::Mutex<Option<mpsc::Receiver<OutputMessage>>>,
}
/// Capacity, in messages, of the bounded `mpsc` channel created in `WebSocketTransport::new`
/// to carry `OutputMessage`s to the socket loop.
const AUDIO_OUT_CHANNEL_CAP: usize = 150;
impl WebSocketTransport {
pub fn new(name: &str, params: WebSocketParams) -> Self {
let base = Arc::new(BaseTransport::new(name, params.transport));
let (audio_out_tx, audio_out_rx) = mpsc::channel::<OutputMessage>(AUDIO_OUT_CHANNEL_CAP);
base.set_audio_out_tx(audio_out_tx);
Self {
base,
audio_out_rx: std::sync::Mutex::new(Some(audio_out_rx)),
}
}
pub fn input(&self) -> FrameProcessor {
self.base.input()
}
pub fn output(&self) -> FrameProcessor {
self.base.output()
}
pub async fn run_socket(
&self,
mut socket: WebSocket,
push_tx: mpsc::Sender<(Frame, FrameDirection)>,
) {
let mut audio_out_rx = self
.audio_out_rx
.lock()
.unwrap()
.take()
.expect("run_socket called more than once on the same WebSocketTransport");
let base = self.base.clone();
loop {
tokio::select! {
msg = socket.recv() => {
match msg {
Some(Ok(Message::Binary(bytes))) => {
let data = AudioRawData::new(bytes.to_vec(), 16_000, 1);
base.push_audio_frame(data).await;
}
Some(Ok(Message::Text(text))) => {
handle_incoming_text(&text, &push_tx).await;
}
Some(Ok(Message::Close(_))) | None => {
log::debug!("WebSocketTransport: client closed connection");
break;
}
Some(Ok(_)) => {}
Some(Err(e)) => {
log::warn!("WebSocketTransport: socket error: {}", e);
break;
}
}
}
output_msg = audio_out_rx.recv() => {
match output_msg {
Some(OutputMessage::Audio(bytes)) => {
if socket.send(Message::Binary(bytes.into())).await.is_err() {
log::warn!("WebSocketTransport: failed to send audio");
break;
}
}
Some(OutputMessage::Text(json)) => {
if socket.send(Message::Text(json.into())).await.is_err() {
log::warn!("WebSocketTransport: failed to send RAVI text message");
break;
}
}
Some(OutputMessage::Interruption) => {
while let Ok(queued) = audio_out_rx.try_recv() {
match queued {
OutputMessage::Audio(_) => {} OutputMessage::Interruption => break,
OutputMessage::Text(_) => {} }
}
let json = r#"{"type":"interruption"}"#;
if socket.send(Message::Text(json.into())).await.is_err() {
log::warn!("WebSocketTransport: failed to send interruption");
break;
}
log::debug!("WebSocketTransport: sent interruption to client");
}
None => break, }
}
}
}
let _ = push_tx
.send((Frame::end(), FrameDirection::Downstream))
.await;
}
}
/// Interprets one text message received from the client and forwards the matching frame.
///
/// * Text that is not valid JSON is logged and ignored.
/// * A message whose `label` is `"ravi"` becomes a RAVI client-message frame built from its
///   `id`, its `type` (empty string when absent) and its `data` serialized back to a JSON
///   string. Without a string `id` the message is dropped with a warning.
/// * Otherwise a `type` of `"client_interruption"` pushes an interruption frame.
/// * Anything else is ignored without logging.
///
/// Frames are sent downstream on `push_tx`; a failed send is ignored.
async fn handle_incoming_text(
    text: &str,
    push_tx: &mpsc::Sender<(Frame, FrameDirection)>,
) {
    /// Returns the value under `key` only when it is present and is a JSON string.
    fn string_field<'a>(value: &'a serde_json::Value, key: &str) -> Option<&'a str> {
        value.get(key).and_then(serde_json::Value::as_str)
    }

    let payload: serde_json::Value = match serde_json::from_str(text) {
        Ok(value) => value,
        Err(_) => {
            log::warn!("WebSocketTransport: ignoring non-JSON text message");
            return;
        }
    };

    let kind = string_field(&payload, "type").unwrap_or_default();
    let is_ravi = string_field(&payload, "label") == Some("ravi");

    if is_ravi {
        match string_field(&payload, "id") {
            Some(id) => {
                let body = payload.get("data").map(|value| value.to_string());
                let frame = Frame::ravi_client_message(id, kind, body);
                let _ = push_tx.send((frame, FrameDirection::Downstream)).await;
                log::trace!("WebSocketTransport: RAVI '{}' (id={})", kind, id);
            }
            None => {
                log::warn!("WebSocketTransport: RAVI message missing 'id' field — dropping");
            }
        }
        return;
    }

    if kind == "client_interruption" {
        log::info!("WebSocketTransport: legacy client-initiated interruption");
        let _ = push_tx
            .send((Frame::interruption(), FrameDirection::Downstream))
            .await;
    }
}