use std::sync::Arc;
use crate::error::CodecError;
use crate::traits::CodecBackend;
#[derive(Clone, Debug)]
pub struct CodecBackendHandle<B: CodecBackend> {
backend: Arc<B>,
}
impl<B: CodecBackend> CodecBackendHandle<B> {
    /// Wraps an owned backend in a new handle, placing it behind an [`Arc`].
    #[must_use]
    pub fn new(backend: B) -> Self {
        Self::from_arc(Arc::new(backend))
    }

    /// Builds a handle from a backend that is already reference-counted.
    #[must_use]
    pub fn from_arc(backend: Arc<B>) -> Self {
        Self { backend }
    }

    /// Borrows the shared pointer to the wrapped backend.
    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        &self.backend
    }

    /// Converts this handle into a type-erased [`DynCodecProvider`].
    ///
    /// The existing `Arc` is reused; only the static type is erased. The
    /// target `Arc<dyn CodecBackend>` is implicitly `'static`, hence the
    /// `B: 'static` bound.
    #[must_use]
    pub fn into_dyn(self) -> DynCodecProvider
    where
        B: 'static,
    {
        // The annotated binding performs the unsizing coercion.
        let erased: Arc<dyn CodecBackend> = self.backend;
        DynCodecProvider { backend: erased }
    }

    /// Forwards `samples` and `sample_rate` unchanged to the wrapped
    /// backend's `encode_pcm` and returns its result.
    ///
    /// # Errors
    ///
    /// Returns whatever [`CodecError`] the backend reports.
    pub async fn encode_pcm(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<u32>, CodecError> {
        self.backend.encode_pcm(samples, sample_rate).await
    }

    /// Forwards `tokens` and `num_codebooks` unchanged to the wrapped
    /// backend's `decode_tokens` and returns its result.
    ///
    /// # Errors
    ///
    /// Returns whatever [`CodecError`] the backend reports.
    pub async fn decode_tokens(
        &self,
        tokens: &[u32],
        num_codebooks: usize,
    ) -> Result<Vec<f32>, CodecError> {
        self.backend.decode_tokens(tokens, num_codebooks).await
    }
}
/// Type-erased counterpart of the typed handle: holds any [`CodecBackend`]
/// behind an `Arc<dyn CodecBackend>`.
pub struct DynCodecProvider {
    /// Shared, dynamically-dispatched backend.
    backend: Arc<dyn CodecBackend>,
}

impl Clone for DynCodecProvider {
    /// Produces another provider pointing at the same backend; only the
    /// reference count changes.
    fn clone(&self) -> Self {
        let backend = Arc::clone(&self.backend);
        Self { backend }
    }
}
impl DynCodecProvider {
#[must_use]
pub fn new(backend: Arc<dyn CodecBackend>) -> Self {
Self { backend }
}
#[must_use]
pub fn backend(&self) -> &Arc<dyn CodecBackend> {
&self.backend
}
pub async fn encode_pcm(
&self,
samples: &[f32],
sample_rate: u32,
) -> Result<Vec<u32>, CodecError> {
self.backend.encode_pcm(samples, sample_rate).await
}
pub async fn decode_tokens(
&self,
tokens: &[u32],
num_codebooks: usize,
) -> Result<Vec<f32>, CodecError> {
self.backend.decode_tokens(tokens, num_codebooks).await
}
}
/// Debug output shows the backend's `id()` and `provider_kind()` instead of
/// the backend value itself; the impl calls nothing else on the trait object.
impl std::fmt::Debug for DynCodecProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let backend = &self.backend;
        let mut repr = f.debug_struct("DynCodecProvider");
        repr.field("backend_id", &backend.id());
        repr.field("provider_kind", &backend.provider_kind());
        repr.finish()
    }
}
#[cfg(test)]
mod tests {
    // Exercises both the typed handle and the type-erased provider against a
    // small deterministic fake backend.
    use super::*;
    use async_trait::async_trait;
    use blazen_audio::AudioBackend;

    /// Stateless backend with fixed identity strings and arithmetic-only
    /// encode/decode, so expected values are easy to state.
    struct FakeCodec;

    #[async_trait]
    impl AudioBackend for FakeCodec {
        fn id(&self) -> &'static str {
            "fake-codec"
        }

        fn provider_kind(&self) -> &'static str {
            "codec"
        }
    }

    #[async_trait]
    impl CodecBackend for FakeCodec {
        /// Maps each sample to `|s| * 1000`, truncated to `u32`.
        async fn encode_pcm(
            &self,
            samples: &[f32],
            _sample_rate: u32,
        ) -> Result<Vec<u32>, CodecError> {
            // The value is non-negative after `abs()` and the tests only feed
            // small inputs, so the float-to-int cast is intentional here.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            Ok(samples.iter().map(|s| (s.abs() * 1000.0) as u32).collect())
        }

        /// Maps each token to `t / 1000`, rejecting token counts that are not
        /// a multiple of `num_codebooks`.
        async fn decode_tokens(
            &self,
            tokens: &[u32],
            num_codebooks: usize,
        ) -> Result<Vec<f32>, CodecError> {
            if !tokens.len().is_multiple_of(num_codebooks) {
                return Err(CodecError::invalid_input("misaligned"));
            }
            // Test tokens are small, so the precision loss of `u32 -> f32` is
            // irrelevant here.
            #[allow(clippy::cast_precision_loss)]
            Ok(tokens.iter().map(|&t| (t as f32) / 1000.0).collect())
        }

        fn num_codebooks(&self) -> usize {
            1
        }
    }

    #[tokio::test]
    async fn typed_provider_forwards_to_backend() {
        let provider = CodecBackendHandle::new(FakeCodec);
        let tokens = provider.encode_pcm(&[0.1, 0.2, 0.3], 24_000).await.unwrap();
        assert_eq!(tokens.len(), 3);
        let pcm = provider.decode_tokens(&tokens, 1).await.unwrap();
        assert_eq!(pcm.len(), 3);
    }

    #[tokio::test]
    async fn dyn_provider_forwards_to_backend() {
        let dyn_provider = CodecBackendHandle::new(FakeCodec).into_dyn();
        let tokens = dyn_provider
            .encode_pcm(&[0.5], 24_000)
            .await
            .expect("encode succeeds");
        assert_eq!(tokens, vec![500]);
    }

    // Nothing is awaited here, so a plain synchronous test is sufficient.
    #[test]
    fn dyn_provider_debug_includes_id() {
        let dyn_provider = CodecBackendHandle::new(FakeCodec).into_dyn();
        let dbg = format!("{dyn_provider:?}");
        // Match the labelled fields: a bare `contains("codec")` would already
        // be satisfied by the "fake-codec" id and never check `provider_kind`.
        assert!(dbg.contains("backend_id: \"fake-codec\""));
        assert!(dbg.contains("provider_kind: \"codec\""));
    }
}