use std::borrow::Cow;
use async_trait::async_trait;
use livekit::webrtc::audio_frame::AudioFrame;
use livekit::webrtc::audio_source::native::NativeAudioSource;
use crate::error::{RealtimeError, Result};
use crate::runner::EventHandler;
/// [`EventHandler`] decorator that hands every event to a wrapped handler and,
/// for audio events, additionally publishes the PCM data to a LiveKit
/// [`NativeAudioSource`].
pub struct LiveKitEventHandler<H>
where
    H: EventHandler,
{
    /// Wrapped handler; it is invoked for every event.
    inner: H,
    /// LiveKit source that audio frames are captured into.
    audio_source: NativeAudioSource,
    /// Sample rate stamped on every outgoing `AudioFrame`.
    sample_rate: u32,
    /// Channel count stamped on every outgoing `AudioFrame`; audio handling
    /// rejects a value of 0.
    num_channels: u32,
}
impl<H> LiveKitEventHandler<H>
where
    H: EventHandler,
{
    /// Wraps `inner` so that audio it receives is also published to `audio_source`.
    ///
    /// `sample_rate` and `num_channels` describe the PCM stream and are copied
    /// verbatim into every frame pushed to LiveKit. Nothing is validated here;
    /// a zero channel count is reported as an error once audio arrives.
    pub fn new(
        inner: H,
        audio_source: NativeAudioSource,
        sample_rate: u32,
        num_channels: u32,
    ) -> Self {
        Self {
            num_channels,
            sample_rate,
            audio_source,
            inner,
        }
    }
}
/// Decodes little-endian signed 16-bit PCM bytes into samples.
///
/// On little-endian targets the bytes are reinterpreted in place (zero-copy)
/// when the slice is suitably aligned and has an even length. Otherwise each
/// byte pair is decoded explicitly into a new buffer. A trailing odd byte, if
/// any, is ignored.
fn pcm16_le_samples(audio: &[u8]) -> Cow<'_, [i16]> {
    #[cfg(target_endian = "little")]
    if let Ok(samples) = bytemuck::try_cast_slice::<u8, i16>(audio) {
        return Cow::Borrowed(samples);
    }
    // Fallback: misaligned/odd-length input, or a big-endian target where the
    // in-memory layout does not match the little-endian wire format.
    Cow::Owned(
        audio
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
            .collect(),
    )
}

#[async_trait]
impl<H: EventHandler> EventHandler for LiveKitEventHandler<H> {
    /// Forwards the chunk to the wrapped handler, then publishes it to LiveKit.
    ///
    /// `audio` is decoded as little-endian 16-bit PCM. Presumably the channels
    /// are interleaved, given the sample-count/channel check below — confirm
    /// against the provider's output format.
    ///
    /// # Errors
    ///
    /// - Any error from the wrapped handler is propagated, and LiveKit is not
    ///   touched in that case.
    /// - A provider error is returned when `num_channels` is 0.
    ///
    /// Chunks that cannot form a valid frame are skipped, which is not treated
    /// as an error. These are empty chunks and chunks whose sample count is not
    /// a multiple of the channel count.
    ///
    /// A failure from `capture_frame` is logged and not returned.
    async fn on_audio(&self, audio: &[u8], item_id: &str) -> Result<()> {
        self.inner.on_audio(audio, item_id).await?;

        // Reject an unusable configuration before doing any decoding work.
        if self.num_channels == 0 {
            return Err(RealtimeError::provider(
                "Cannot push audio to LiveKit NativeAudioSource: num_channels is 0",
            ));
        }

        // PCM16 needs two bytes per sample; make the dropped byte visible
        // instead of losing it silently.
        if audio.len() % 2 != 0 {
            tracing::warn!(
                bytes_len = audio.len(),
                "Audio chunk has an odd byte count; dropping the trailing byte"
            );
        }

        let samples = pcm16_le_samples(audio);
        if samples.is_empty() {
            // Nothing to play; do not push a zero-length frame.
            return Ok(());
        }

        // Lossless on 32/64-bit targets.
        let channels = self.num_channels as usize;
        if samples.len() % channels != 0 {
            tracing::warn!(
                samples_len = samples.len(),
                num_channels = self.num_channels,
                "Skipping invalid audio frame: sample count is not an exact multiple of channels"
            );
            return Ok(());
        }

        // Divide in `usize` first, then convert with a checked conversion
        // instead of a truncating `as` cast.
        let Ok(samples_per_channel) = u32::try_from(samples.len() / channels) else {
            tracing::warn!(
                samples_len = samples.len(),
                "Skipping oversized audio frame: samples per channel exceed u32::MAX"
            );
            return Ok(());
        };

        let frame = AudioFrame {
            data: samples,
            sample_rate: self.sample_rate,
            num_channels: self.num_channels,
            samples_per_channel,
        };
        // Best-effort publish: a capture failure is logged rather than returned
        // to the caller.
        if let Err(e) = self.audio_source.capture_frame(&frame).await {
            tracing::warn!(error = %e, "Failed to push audio to LiveKit NativeAudioSource");
        }
        Ok(())
    }

    /// Delegates to the wrapped handler.
    async fn on_text(&self, text: &str, item_id: &str) -> Result<()> {
        self.inner.on_text(text, item_id).await
    }

    /// Delegates to the wrapped handler.
    async fn on_transcript(&self, transcript: &str, item_id: &str) -> Result<()> {
        self.inner.on_transcript(transcript, item_id).await
    }

    /// Delegates to the wrapped handler.
    async fn on_speech_started(&self, audio_start_ms: u64) -> Result<()> {
        self.inner.on_speech_started(audio_start_ms).await
    }

    /// Delegates to the wrapped handler.
    async fn on_speech_stopped(&self, audio_end_ms: u64) -> Result<()> {
        self.inner.on_speech_stopped(audio_end_ms).await
    }

    /// Delegates to the wrapped handler.
    async fn on_response_done(&self) -> Result<()> {
        self.inner.on_response_done().await
    }

    /// Delegates to the wrapped handler.
    async fn on_error(&self, error: &RealtimeError) -> Result<()> {
        self.inner.on_error(error).await
    }
}