blazen_audio_codec/
provider.rs1use std::sync::Arc;
11
12use crate::error::CodecError;
13use crate::traits::CodecBackend;
14
15#[derive(Clone, Debug)]
19pub struct CodecBackendHandle<B: CodecBackend> {
20 backend: Arc<B>,
21}
22
23impl<B: CodecBackend> CodecBackendHandle<B> {
24 pub fn new(backend: B) -> Self {
26 Self {
27 backend: Arc::new(backend),
28 }
29 }
30
31 #[must_use]
34 pub fn from_arc(backend: Arc<B>) -> Self {
35 Self { backend }
36 }
37
38 #[must_use]
40 pub fn backend(&self) -> &Arc<B> {
41 &self.backend
42 }
43
44 #[must_use]
46 pub fn into_dyn(self) -> DynCodecProvider
47 where
48 B: 'static,
49 {
50 DynCodecProvider {
51 backend: self.backend as Arc<dyn CodecBackend>,
52 }
53 }
54
55 pub async fn encode_pcm(
61 &self,
62 samples: &[f32],
63 sample_rate: u32,
64 ) -> Result<Vec<u32>, CodecError> {
65 self.backend.encode_pcm(samples, sample_rate).await
66 }
67
68 pub async fn decode_tokens(
74 &self,
75 tokens: &[u32],
76 num_codebooks: usize,
77 ) -> Result<Vec<f32>, CodecError> {
78 self.backend.decode_tokens(tokens, num_codebooks).await
79 }
80}
81
82#[derive(Clone)]
90pub struct DynCodecProvider {
91 backend: Arc<dyn CodecBackend>,
92}
93
94impl DynCodecProvider {
95 #[must_use]
97 pub fn new(backend: Arc<dyn CodecBackend>) -> Self {
98 Self { backend }
99 }
100
101 #[must_use]
103 pub fn backend(&self) -> &Arc<dyn CodecBackend> {
104 &self.backend
105 }
106
107 pub async fn encode_pcm(
113 &self,
114 samples: &[f32],
115 sample_rate: u32,
116 ) -> Result<Vec<u32>, CodecError> {
117 self.backend.encode_pcm(samples, sample_rate).await
118 }
119
120 pub async fn decode_tokens(
126 &self,
127 tokens: &[u32],
128 num_codebooks: usize,
129 ) -> Result<Vec<f32>, CodecError> {
130 self.backend.decode_tokens(tokens, num_codebooks).await
131 }
132}
133
134impl std::fmt::Debug for DynCodecProvider {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("DynCodecProvider")
137 .field("backend_id", &self.backend.id())
138 .field("provider_kind", &self.backend.provider_kind())
139 .finish()
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use blazen_audio::AudioBackend;

    /// Stateless fake backend: "encodes" each sample as `|sample| * 1000`
    /// truncated to `u32`, and "decodes" each token as `token / 1000`.
    struct FakeCodec;

    #[async_trait]
    impl AudioBackend for FakeCodec {
        fn id(&self) -> &'static str {
            "fake-codec"
        }
        fn provider_kind(&self) -> &'static str {
            "codec"
        }
    }

    #[async_trait]
    impl CodecBackend for FakeCodec {
        async fn encode_pcm(
            &self,
            samples: &[f32],
            _sample_rate: u32,
        ) -> Result<Vec<u32>, CodecError> {
            // The lossy cast is deliberate: the fake only needs a
            // deterministic sample-to-token mapping, and `abs()` keeps the
            // value non-negative before the cast.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            Ok(samples.iter().map(|s| (s.abs() * 1000.0) as u32).collect())
        }

        async fn decode_tokens(
            &self,
            tokens: &[u32],
            num_codebooks: usize,
        ) -> Result<Vec<f32>, CodecError> {
            if !tokens.len().is_multiple_of(num_codebooks) {
                return Err(CodecError::invalid_input("misaligned"));
            }
            // Precision loss is irrelevant for the small token values used
            // in these tests.
            #[allow(clippy::cast_precision_loss)]
            Ok(tokens.iter().map(|&t| (t as f32) / 1000.0).collect())
        }

        fn num_codebooks(&self) -> usize {
            1
        }
    }

    #[tokio::test]
    async fn typed_provider_forwards_to_backend() {
        let provider = CodecBackendHandle::new(FakeCodec);
        let tokens = provider.encode_pcm(&[0.1, 0.2, 0.3], 24_000).await.unwrap();
        assert_eq!(tokens.len(), 3);
        let pcm = provider.decode_tokens(&tokens, 1).await.unwrap();
        assert_eq!(pcm.len(), 3);
    }

    #[tokio::test]
    async fn dyn_provider_forwards_to_backend() {
        let dyn_provider = CodecBackendHandle::new(FakeCodec).into_dyn();
        let tokens = dyn_provider
            .encode_pcm(&[0.5], 24_000)
            .await
            .expect("encode succeeds");
        assert_eq!(tokens, vec![500]);
    }

    // Nothing is awaited here, so a plain synchronous test is enough.
    #[test]
    fn dyn_provider_debug_includes_id() {
        let dyn_provider = CodecBackendHandle::new(FakeCodec).into_dyn();
        let dbg = format!("{dyn_provider:?}");
        // Match complete `field: "value"` pairs. A bare `contains("codec")`
        // would already be satisfied by the "fake-codec" id and so would
        // never check `provider_kind`.
        assert!(dbg.contains(r#"backend_id: "fake-codec""#));
        assert!(dbg.contains(r#"provider_kind: "codec""#));
    }
}