use std::sync::Arc;
use tokio::sync::mpsc;
use crate::frames::{AudioRawData, Frame, FrameDirection, FrameProcessor};
use crate::transport::{BaseTransport, OutputMessage, TransportParams};
use crate::transport::incoming::dispatch_text_message;
/// Message exchanged with a `ChannelTransport` over tokio `mpsc` channels.
///
/// The same type is used in both directions: `ChannelTransport::run` reads
/// these from its incoming receiver and writes them to its outgoing sender.
#[derive(Debug, Clone)]
pub enum ChannelMessage {
/// Raw audio bytes. Inbound payloads are wrapped in `AudioRawData` together
/// with the transport's input sample rate and channel count; outbound
/// payloads carry the bytes of an `OutputMessage::Audio`.
/// NOTE(review): the sample encoding is not visible here — confirm with the producer.
Audio(Vec<u8>),
/// Text payload. Inbound text is handed to `dispatch_text_message`; outbound
/// text carries the payload of an `OutputMessage::Text` (presumably JSON,
/// judging by the binding name used in `run` — confirm).
Text(String),
/// Interruption signal. Inbound, it is pushed downstream as
/// `Frame::interruption()`; outbound, it is emitted when the output side
/// reports an `OutputMessage::Interruption`.
Interruption,
/// Client-side VAD start event; the value is passed unchanged to
/// `Frame::client_vad_user_started_speaking`.
/// NOTE(review): presumably a timestamp — unit and epoch are not visible here.
ClientVadStart(f64),
/// Client-side VAD stop event; the value is passed unchanged to
/// `Frame::client_vad_user_stopped_speaking`.
/// NOTE(review): presumably a timestamp — unit and epoch are not visible here.
ClientVadStop(f64),
}
/// Transport that bridges a pair of `mpsc` channels carrying
/// [`ChannelMessage`]s with a shared `BaseTransport`.
///
/// Both receivers are stored as `Mutex<Option<_>>` so that `run`, which only
/// has `&self`, can take ownership of them exactly once.
pub struct ChannelTransport {
/// Shared base transport; supplies the input/output processors and receives
/// inbound audio via `push_audio_frame`.
base: Arc<BaseTransport>,
/// Receiver for inbound messages; `Some` until `run` takes it.
incoming_rx: std::sync::Mutex<Option<mpsc::Receiver<ChannelMessage>>>,
/// Receiver for the output channel whose sender was registered on `base`
/// with `set_audio_out_tx`; `Some` until `run` takes it.
audio_out_rx: std::sync::Mutex<Option<mpsc::Receiver<OutputMessage>>>,
/// Sample rate attached to inbound audio (from `audio_in_sample_rate`,
/// falling back to 16_000 when unset).
sample_rate: u32,
/// Channel count attached to inbound audio (from `audio_in_channels`,
/// raised to at least 1).
channels: u16,
}
/// Capacity of the bounded `mpsc` channel that carries `OutputMessage`s from
/// the base transport to `ChannelTransport::run`.
/// NOTE(review): the rationale for 150 is not visible in this file — confirm
/// against the expected output chunk size before changing it.
const AUDIO_OUT_CHANNEL_CAP: usize = 150;
impl ChannelTransport {
pub fn new(
name: &str,
params: TransportParams,
incoming_rx: mpsc::Receiver<ChannelMessage>,
) -> Self {
let base = Arc::new(BaseTransport::new(name, params.clone()));
let (audio_out_tx, audio_out_rx) = mpsc::channel::<OutputMessage>(AUDIO_OUT_CHANNEL_CAP);
base.set_audio_out_tx(audio_out_tx);
let sample_rate = params.audio_in_sample_rate.unwrap_or(16_000);
let channels = params.audio_in_channels.max(1);
Self {
base,
incoming_rx: std::sync::Mutex::new(Some(incoming_rx)),
audio_out_rx: std::sync::Mutex::new(Some(audio_out_rx)),
sample_rate,
channels,
}
}
pub fn input(&self) -> FrameProcessor {
self.base.input()
}
pub fn output(&self) -> FrameProcessor {
self.base.output()
}
pub async fn run(
&self,
push_tx: mpsc::Sender<(Frame, FrameDirection)>,
outgoing_tx: mpsc::Sender<ChannelMessage>,
) {
let mut incoming_rx = self
.incoming_rx
.lock()
.unwrap()
.take()
.expect("run called more than once on the same ChannelTransport");
let mut audio_out_rx = self
.audio_out_rx
.lock()
.unwrap()
.take()
.expect("run called more than once on the same ChannelTransport");
let base = self.base.clone();
let sample_rate = self.sample_rate;
let channels = self.channels;
loop {
tokio::select! {
msg = incoming_rx.recv() => {
match msg {
Some(ChannelMessage::Audio(bytes)) => {
let data = AudioRawData::new(bytes, sample_rate, channels);
base.push_audio_frame(data).await;
}
Some(ChannelMessage::Text(text)) => {
dispatch_text_message(&text, &push_tx).await;
}
Some(ChannelMessage::Interruption) => {
let _ = push_tx
.send((Frame::interruption(), FrameDirection::Downstream))
.await;
}
Some(ChannelMessage::ClientVadStart(ts)) => {
let _ = push_tx
.send((Frame::client_vad_user_started_speaking(ts), FrameDirection::Downstream))
.await;
}
Some(ChannelMessage::ClientVadStop(ts)) => {
let _ = push_tx
.send((Frame::client_vad_user_stopped_speaking(ts), FrameDirection::Downstream))
.await;
}
None => {
log::debug!("ChannelTransport: incoming channel closed");
break;
}
}
}
output_msg = audio_out_rx.recv() => {
match output_msg {
Some(OutputMessage::Audio(bytes)) => {
if outgoing_tx.send(ChannelMessage::Audio(bytes)).await.is_err() {
log::warn!("ChannelTransport: outgoing channel closed");
break;
}
}
Some(OutputMessage::Text(json)) => {
if outgoing_tx.send(ChannelMessage::Text(json)).await.is_err() {
log::warn!("ChannelTransport: outgoing channel closed");
break;
}
}
Some(OutputMessage::Interruption) => {
while let Ok(queued) = audio_out_rx.try_recv() {
match queued {
OutputMessage::Audio(_) => {} OutputMessage::Interruption => break,
OutputMessage::Text(_) => {} }
}
if outgoing_tx.send(ChannelMessage::Interruption).await.is_err() {
log::warn!("ChannelTransport: outgoing channel closed");
break;
}
log::debug!("ChannelTransport: sent interruption");
}
None => {
log::debug!("ChannelTransport: audio out channel closed");
break;
}
}
}
}
}
let _ = push_tx
.send((Frame::end(), FrameDirection::Downstream))
.await;
}
}