Skip to main content

blazen_audio_codec/
provider.rs

//! Provider wrappers exposing [`crate::CodecBackend`] to in-process
//! callers and to the binding / bridge layer.
//!
//! See `PR_AUDIO_PLAN.md` Appendix B for the full dual-shape rationale.
//! Short version: Rust callers benefit from monomorphization through
//! [`CodecBackendHandle<B>`]; bindings (Python / Node / UniFFI / cabi) can't
//! cross generics through their C ABI and need an erased
//! [`DynCodecProvider`] (`Arc<dyn CodecBackend>`) instead.
9
10use std::sync::Arc;
11
12use crate::error::CodecError;
13use crate::traits::CodecBackend;
14
15/// Typed codec provider — wraps a concrete [`CodecBackend`] implementation
16/// and is monomorphized by the compiler. Use this from Rust callers in
17/// hot loops.
18#[derive(Clone, Debug)]
19pub struct CodecBackendHandle<B: CodecBackend> {
20    backend: Arc<B>,
21}
22
23impl<B: CodecBackend> CodecBackendHandle<B> {
24    /// Wrap an existing backend.
25    pub fn new(backend: B) -> Self {
26        Self {
27            backend: Arc::new(backend),
28        }
29    }
30
31    /// Wrap an already-`Arc`'d backend (lets two providers share one
32    /// loaded model instance).
33    #[must_use]
34    pub fn from_arc(backend: Arc<B>) -> Self {
35        Self { backend }
36    }
37
38    /// Borrow the underlying backend.
39    #[must_use]
40    pub fn backend(&self) -> &Arc<B> {
41        &self.backend
42    }
43
44    /// Erase to a [`DynCodecProvider`] for the bindings boundary.
45    #[must_use]
46    pub fn into_dyn(self) -> DynCodecProvider
47    where
48        B: 'static,
49    {
50        DynCodecProvider {
51            backend: self.backend as Arc<dyn CodecBackend>,
52        }
53    }
54
55    /// Forward to [`CodecBackend::encode_pcm`].
56    ///
57    /// # Errors
58    ///
59    /// Propagates [`CodecError`] from the underlying backend.
60    pub async fn encode_pcm(
61        &self,
62        samples: &[f32],
63        sample_rate: u32,
64    ) -> Result<Vec<u32>, CodecError> {
65        self.backend.encode_pcm(samples, sample_rate).await
66    }
67
68    /// Forward to [`CodecBackend::decode_tokens`].
69    ///
70    /// # Errors
71    ///
72    /// Propagates [`CodecError`] from the underlying backend.
73    pub async fn decode_tokens(
74        &self,
75        tokens: &[u32],
76        num_codebooks: usize,
77    ) -> Result<Vec<f32>, CodecError> {
78        self.backend.decode_tokens(tokens, num_codebooks).await
79    }
80}
81
82// ---------------------------------------------------------------------------
83// Erased provider for the bindings boundary
84// ---------------------------------------------------------------------------
85
86/// Type-erased codec provider — wraps `Arc<dyn CodecBackend>`. Use this
87/// from the Python / Node / UniFFI / cabi bridge layers where generic
88/// providers can't cross the FFI boundary.
89#[derive(Clone)]
90pub struct DynCodecProvider {
91    backend: Arc<dyn CodecBackend>,
92}
93
94impl DynCodecProvider {
95    /// Wrap a pre-erased backend.
96    #[must_use]
97    pub fn new(backend: Arc<dyn CodecBackend>) -> Self {
98        Self { backend }
99    }
100
101    /// Borrow the underlying backend.
102    #[must_use]
103    pub fn backend(&self) -> &Arc<dyn CodecBackend> {
104        &self.backend
105    }
106
107    /// Forward to [`CodecBackend::encode_pcm`].
108    ///
109    /// # Errors
110    ///
111    /// Propagates [`CodecError`] from the underlying backend.
112    pub async fn encode_pcm(
113        &self,
114        samples: &[f32],
115        sample_rate: u32,
116    ) -> Result<Vec<u32>, CodecError> {
117        self.backend.encode_pcm(samples, sample_rate).await
118    }
119
120    /// Forward to [`CodecBackend::decode_tokens`].
121    ///
122    /// # Errors
123    ///
124    /// Propagates [`CodecError`] from the underlying backend.
125    pub async fn decode_tokens(
126        &self,
127        tokens: &[u32],
128        num_codebooks: usize,
129    ) -> Result<Vec<f32>, CodecError> {
130        self.backend.decode_tokens(tokens, num_codebooks).await
131    }
132}
133
134impl std::fmt::Debug for DynCodecProvider {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("DynCodecProvider")
137            .field("backend_id", &self.backend.id())
138            .field("provider_kind", &self.backend.provider_kind())
139            .finish()
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use async_trait::async_trait;
    use blazen_audio::AudioBackend;

    /// Minimal backend: encodes each sample as `|s| * 1000` (truncated to
    /// `u32`) and decodes each token as `t / 1000`, rejecting token counts
    /// that are not a multiple of `num_codebooks`.
    struct FakeCodec;

    #[async_trait]
    impl AudioBackend for FakeCodec {
        fn id(&self) -> &'static str {
            "fake-codec"
        }
        fn provider_kind(&self) -> &'static str {
            "codec"
        }
    }

    #[async_trait]
    impl CodecBackend for FakeCodec {
        // The float -> int cast *is* the fake encoding; `abs()` keeps the
        // value non-negative and test inputs are tiny, so truncation is the
        // intended behavior.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        async fn encode_pcm(
            &self,
            samples: &[f32],
            _sample_rate: u32,
        ) -> Result<Vec<u32>, CodecError> {
            // Identity-ish encode: one token per sample.
            Ok(samples.iter().map(|s| (s.abs() * 1000.0) as u32).collect())
        }

        // Test tokens are small, so the u32 -> f32 conversion is exact.
        #[allow(clippy::cast_precision_loss)]
        async fn decode_tokens(
            &self,
            tokens: &[u32],
            num_codebooks: usize,
        ) -> Result<Vec<f32>, CodecError> {
            if !tokens.len().is_multiple_of(num_codebooks) {
                return Err(CodecError::invalid_input("misaligned"));
            }
            Ok(tokens.iter().map(|&t| (t as f32) / 1000.0).collect())
        }

        fn num_codebooks(&self) -> usize {
            1
        }
    }

    #[tokio::test]
    async fn typed_provider_forwards_to_backend() {
        let provider = CodecBackendHandle::new(FakeCodec);
        let tokens = provider.encode_pcm(&[0.1, 0.2, 0.3], 24_000).await.unwrap();
        assert_eq!(tokens.len(), 3);
        let pcm = provider.decode_tokens(&tokens, 1).await.unwrap();
        assert_eq!(pcm.len(), 3);
    }

    #[tokio::test]
    async fn typed_provider_propagates_backend_errors() {
        let provider = CodecBackendHandle::new(FakeCodec);
        // Three tokens cannot be split evenly across two codebooks.
        let result = provider.decode_tokens(&[1, 2, 3], 2).await;
        assert!(result.is_err());
    }

    #[test]
    fn from_arc_shares_one_backend_instance() {
        let shared = Arc::new(FakeCodec);
        let first = CodecBackendHandle::from_arc(Arc::clone(&shared));
        let second = CodecBackendHandle::from_arc(Arc::clone(&shared));
        assert!(Arc::ptr_eq(first.backend(), second.backend()));
        assert!(Arc::ptr_eq(first.backend(), &shared));
    }

    #[tokio::test]
    async fn dyn_provider_forwards_to_backend() {
        let dyn_provider = CodecBackendHandle::new(FakeCodec).into_dyn();
        let tokens = dyn_provider
            .encode_pcm(&[0.5], 24_000)
            .await
            .expect("encode succeeds");
        assert_eq!(tokens, vec![500]);
        let pcm = dyn_provider
            .decode_tokens(&tokens, 1)
            .await
            .expect("decode succeeds");
        assert_eq!(pcm, vec![0.5]);
    }

    #[tokio::test]
    async fn dyn_provider_new_accepts_erased_backend() {
        let dyn_provider = DynCodecProvider::new(Arc::new(FakeCodec));
        let tokens = dyn_provider
            .encode_pcm(&[0.25], 24_000)
            .await
            .expect("encode succeeds");
        assert_eq!(tokens, vec![250]);
    }

    // Formatting needs no async runtime, so a plain `#[test]` suffices.
    #[test]
    fn dyn_provider_debug_reports_id_and_kind() {
        let dyn_provider = CodecBackendHandle::new(FakeCodec).into_dyn();
        let dbg = format!("{dyn_provider:?}");
        // Match on `field: value` pairs: a bare `contains("codec")` would be
        // satisfied by "fake-codec" alone and never check `provider_kind`.
        assert!(dbg.contains(r#"backend_id: "fake-codec""#));
        assert!(dbg.contains(r#"provider_kind: "codec""#));
    }
}