aic_sdk/
model.rs

1use crate::error::*;
2
3use aic_sdk_sys::*;
4
5use std::{
6    ffi::{CStr, CString},
7    marker::PhantomData,
8    path::Path,
9    ptr,
10};
11
12/// High-level wrapper for the ai-coustics audio enhancement model.
13///
14/// This struct provides a safe, Rust-friendly interface to the underlying C library.
15/// It handles memory management automatically and converts C-style error codes
16/// to Rust `Result` types.
17///
18/// # Sharing and Multi-threading
19///
20/// `Model` is `Send` and `Sync`, so you can share it across threads. It does not implement
21/// `Clone`, so wrap it in an `Arc` if you need shared ownership.
22///
23/// # Example
24///
25/// ```rust,no_run
26/// # use aic_sdk::{Model, ProcessorConfig, Processor};
27/// # let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
28/// let model = Model::from_file("/path/to/model.aicmodel").unwrap();
29/// let config = ProcessorConfig::optimal(&model).with_num_channels(2);
30/// let mut processor = Processor::new(&model, &license_key).unwrap();
31/// processor.initialize(&config).unwrap();
32/// let mut audio_buffer = vec![0.0f32; config.num_channels as usize * config.num_frames];
33/// processor.process_interleaved(&mut audio_buffer).unwrap();
34/// ```
35///
36/// # Multi-threaded Example
37///
38/// ```rust,no_run
39/// # use aic_sdk::{Model, ProcessorConfig, Processor};
40/// # use std::{thread, sync::Arc};
41/// let model = Arc::new(Model::from_file("/path/to/model.aicmodel").unwrap());
42///
43/// // Spawn multiple threads, each with its own processor but sharing the same model
44/// let handles: Vec<_> = (0..4)
45///     .map(|i| {
46///         let model_clone = Arc::clone(&model);
47///         thread::spawn(move || {
48///             let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
49///             let mut processor = Processor::new(&model_clone, &license_key).unwrap();
50///             // Process audio in this thread...
51///         })
52///     })
53///     .collect();
54///
55/// for handle in handles {
56///     handle.join().unwrap();
57/// }
58/// ```
59pub struct Model<'a> {
60    /// Raw pointer to the C model structure
61    ptr: *mut AicModel,
62    /// Marker to tie the lifetime of the model to the lifetime of its weights
63    marker: PhantomData<&'a [u8]>,
64}
65
66impl<'a> Model<'a> {
67    /// Creates a new audio enhancement model instance.
68    ///
69    /// Multiple models can be created to process different audio streams simultaneously
70    /// or to switch between different enhancement algorithms during runtime.
71    ///
72    /// # Arguments
73    ///
74    /// * `path` - Filesystem path to a model file.
75    ///
76    /// # Returns
77    ///
78    /// Returns a `Result` containing the new `Model` instance or an `AicError` if creation fails.
79    ///
80    /// # Example
81    ///
82    /// ```rust,no_run
83    /// # use aic_sdk::Model;
84    /// let model = Model::from_file("/path/to/model.aicmodel").unwrap();
85    /// ```
86    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Model<'static>, AicError> {
87        let mut model_ptr: *mut AicModel = ptr::null_mut();
88        let c_path = CString::new(path.as_ref().to_string_lossy().as_bytes()).unwrap();
89
90        // SAFETY:
91        // - `model_ptr` points to stack memory we own.
92        // - `c_path` is a valid, null-terminated string.
93        let error_code = unsafe { aic_model_create_from_file(&mut model_ptr, c_path.as_ptr()) };
94
95        handle_error(error_code)?;
96
97        // This should never happen if the C library is well-behaved, but let's be defensive
98        assert!(
99            !model_ptr.is_null(),
100            "C library returned success but null pointer"
101        );
102
103        Ok(Model {
104            ptr: model_ptr,
105            marker: PhantomData,
106        })
107    }
108
109    /// Creates a new model instance from an in-memory buffer.
110    ///
111    /// The buffer must be 64-byte aligned.
112    ///
113    /// Consider using [`include_model!`](macro@crate::include_model) to embed a model file at compile time with
114    /// the correct alignment.
115    ///
116    /// # Arguments
117    ///
118    /// * `buffer` - Raw bytes of the model file.
119    ///
120    /// # Returns
121    ///
122    /// Returns a `Result` containing the new `Model` instance or an `AicError` if creation fails.
123    ///
124    /// # Example
125    ///
126    /// ```rust,ignore
127    /// # use aic_sdk::{include_model, Model};
128    /// static MODEL: &'static [u8] = include_model!("/path/to/model.aicmodel");
129    /// let model = Model::from_buffer(MODEL).unwrap();
130    /// ```
131    pub fn from_buffer(buffer: &'a [u8]) -> Result<Self, AicError> {
132        let mut model_ptr: *mut AicModel = ptr::null_mut();
133
134        // SAFETY:
135        // - `buffer` is a valid slice and immutable for `'a`.
136        // - The SDK only reads from `buffer` for the lifetime of the model.
137        let error_code =
138            unsafe { aic_model_create_from_buffer(&mut model_ptr, buffer.as_ptr(), buffer.len()) };
139
140        handle_error(error_code)?;
141
142        // This should never happen if the C library is well-behaved, but let's be defensive
143        assert!(
144            !model_ptr.is_null(),
145            "C library returned success but null pointer"
146        );
147
148        Ok(Model {
149            ptr: model_ptr,
150            marker: PhantomData,
151        })
152    }
153
154    /// Returns the model identifier string.
155    pub fn id(&self) -> &str {
156        // SAFETY: `self` owns a valid model pointer created by the SDK.
157        let id_ptr = unsafe { aic_model_get_id(self.as_const_ptr()) };
158        if id_ptr.is_null() {
159            return "unknown";
160        }
161
162        // SAFETY: Pointer is valid for the lifetime of `self` and is null-terminated.
163        unsafe { CStr::from_ptr(id_ptr).to_str().unwrap_or("unknown") }
164    }
165
166    /// Retrieves the native sample rate of the processor's model.
167    ///
168    /// Each model is optimized for a specific sample rate, which determines the frequency
169    /// range of the enhanced audio output. While you can process audio at any sample rate,
170    /// understanding the model's native rate helps predict the enhancement quality.
171    ///
172    /// **How sample rate affects enhancement:**
173    /// - Models trained at lower sample rates (e.g., 8 kHz) can only enhance frequencies
174    ///   up to their Nyquist limit (4 kHz for 8 kHz models)
175    /// - When processing higher sample rate input (e.g., 48 kHz) with a lower-rate model,
176    ///   only the lower frequency components will be enhanced
177    ///
178    /// **Enhancement blending:**
179    /// When enhancement strength is set below 1.0, the enhanced signal is blended with
180    /// the original, maintaining the full frequency spectrum of your input while adding
181    /// the model's noise reduction capabilities to the lower frequencies.
182    ///
183    /// **Sample rate and optimal frames relationship:**
184    /// When using different sample rates than the model's native rate, the optimal number
185    /// of frames (returned by `optimal_num_frames`) will change. The model's output
186    /// delay remains constant regardless of sample rate as long as you use the optimal frame
187    /// count for that rate.
188    ///
189    /// **Recommendation:**
190    /// For maximum enhancement quality across the full frequency spectrum, match your
191    /// input sample rate to the model's native rate when possible.
192    ///
193    /// # Returns
194    ///
195    /// Returns the model's native sample rate.
196    ///
197    /// # Example
198    ///
199    /// ```rust,no_run
200    /// # use aic_sdk::{Model, Processor};
201    /// # let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
202    /// # let model = Model::from_file("/path/to/model.aicmodel").unwrap();
203    /// let optimal_sample_rate = model.optimal_sample_rate();
204    /// println!("Optimal sample rate: {optimal_sample_rate} Hz");
205    /// ```
206    pub fn optimal_sample_rate(&self) -> u32 {
207        let mut sample_rate: u32 = 0;
208        // SAFETY:
209        // - `self.as_const_ptr()` is a valid pointer to a live model.
210        // - `sample_rate` points to stack storage for output.
211        let error_code =
212            unsafe { aic_model_get_optimal_sample_rate(self.as_const_ptr(), &mut sample_rate) };
213
214        // This should never fail. If it does, it's a bug in the SDK.
215        // `aic_get_optimal_sample_rate` is documented to always succeed if given a valid processor pointer.
216        assert_success(
217            error_code,
218            "`aic_model_get_optimal_sample_rate` failed. This is a bug, please open an issue on GitHub for further investigation.",
219        );
220
221        // This should never fail
222        sample_rate
223    }
224
225    /// Retrieves the optimal number of frames for the selected model at a given sample rate.
226    ///
227    ///
228    /// Using the optimal number of frames minimizes latency by avoiding internal buffering.
229    ///
230    /// **When you use a different frame count than the optimal value, the model will
231    /// introduce additional buffering latency on top of its base processing delay.**
232    ///
233    /// The optimal frame count varies based on the sample rate. Each model operates on a
234    /// fixed time window duration, so the required number of frames changes with sample rate.
235    /// For example, a model designed for 10 ms processing windows requires 480 frames at
236    /// 48 kHz, but only 160 frames at 16 kHz to capture the same duration of audio.
237    ///
238    /// Call this function with your intended sample rate before calling
239    /// [`Processor::initialize`](crate::Processor::initialize) to determine the best frame count for minimal latency.
240    ///
241    /// # Arguments
242    ///
243    /// * `sample_rate` - The sample rate in Hz for which to calculate the optimal frame count.
244    ///
245    /// # Returns
246    ///
247    /// Returns the optimal frame count.
248    ///
249    /// # Example
250    ///
251    /// ```rust,no_run
252    /// # use aic_sdk::{Model, Processor};
253    /// # let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
254    /// # let model = Model::from_file("/path/to/model.aicmodel").unwrap();
255    /// # let sample_rate = model.optimal_sample_rate();
256    /// let optimal_frames = model.optimal_num_frames(sample_rate);
257    /// println!("Optimal frame count: {optimal_frames}");
258    /// ```
259    pub fn optimal_num_frames(&self, sample_rate: u32) -> usize {
260        let mut num_frames: usize = 0;
261        // SAFETY:
262        // - `self.as_const_ptr()` is a valid pointer to a live model.
263        // - `num_frames` points to stack storage for output.
264        let error_code = unsafe {
265            aic_model_get_optimal_num_frames(self.as_const_ptr(), sample_rate, &mut num_frames)
266        };
267
268        // This should never fail. If it does, it's a bug in the SDK.
269        // `aic_get_optimal_num_frames` is documented to always succeed if given valid pointers.
270        assert_success(
271            error_code,
272            "`aic_model_get_optimal_num_frames` failed. This is a bug, please open an issue on GitHub for further investigation.",
273        );
274
275        num_frames
276    }
277
278    /// Downloads a model file from the ai-coustics artifact CDN.
279    ///
280    /// This method fetches the model manifest, verifies that the requested model
281    /// exists in a version compatible with this library, and downloads the model
282    /// file to the specified directory. If the model file already exists, it will not
283    /// be re-downloaded. If the existing file's checksum does not match, the model will
284    /// be downloaded and the existing file will be replaced.
285    ///
286    /// The manifest file is not cached and will always be downloaded on every call
287    /// to ensure the latest model versions are always used.
288    ///
289    /// Available models can be browsed at [artifacts.ai-coustics.io](https://artifacts.ai-coustics.io/).
290    ///
291    /// # Arguments
292    ///
293    /// * `model_id` - The model identifier (e.g., `"quail-l-16khz"`).
294    /// * `download_dir` - Directory where the model file will be stored.
295    ///
296    /// # Returns
297    ///
298    /// Returns the full path to the model file on success, or an [`AicError`] if the
299    /// operation fails.
300    ///
301    /// # Note
302    ///
303    /// This is a blocking operation that performs network I/O.
304    #[cfg(feature = "download-model")]
305    pub fn download<P: AsRef<Path>>(
306        model_id: &str,
307        download_dir: P,
308    ) -> Result<std::path::PathBuf, AicError> {
309        let compatible_version = crate::get_compatible_model_version();
310        crate::download::download(model_id, compatible_version, download_dir)
311            .map_err(|err| AicError::ModelDownload(err.to_string()))
312    }
313
314    pub(crate) fn as_const_ptr(&self) -> *const AicModel {
315        self.ptr as *const AicModel
316    }
317}
318
319impl<'a> Drop for Model<'a> {
320    fn drop(&mut self) {
321        if !self.ptr.is_null() {
322            // SAFETY:
323            // - `self.ptr` was allocated by the SDK and is still owned by this wrapper.
324            unsafe { aic_model_destroy(self.ptr) };
325        }
326    }
327}
328
329// SAFETY:
330// - Model wraps a raw pointer to an AicModel which is immutable after creation and it
331//   does not provide access to it through its public API.
332// - Methods only pass the pointer to SDK calls documented as thread-safe for const access.
333unsafe impl<'a> Send for Model<'a> {}
334// SAFETY:
335// - Model wraps a raw pointer to an AicModel which is immutable after creation and it
336//   does not provide access to it through its public API.
337// - Methods only pass the pointer to SDK calls documented as thread-safe for const access.
338unsafe impl<'a> Sync for Model<'a> {}
339
/// Embeds the bytes of a model file at compile time, aligned to 64 bytes.
///
/// This macro uses Rust's standard library's [`include_bytes!`](std::include_bytes) macro
/// to include the model file at compile time. It then wraps the bytes in a
/// `#[repr(C, align(64))]` container, so the data starts on a 64-byte boundary as required by
/// [`Model::from_buffer`](crate::Model::from_buffer).
///
/// The macro evaluates to a `&'static [u8; N]`, which coerces to `&'static [u8]`.
///
/// # Example
///
/// ```rust,ignore
/// # use aic_sdk::{include_model, Model};
///
/// static MODEL: &'static [u8] = include_model!("/path/to/model.aicmodel");
/// let model = Model::from_buffer(MODEL).unwrap();
/// ```
#[macro_export]
macro_rules! include_model {
    ($path:expr) => {{
        // `align(64)` forces the wrapped bytes to start on a 64-byte boundary.
        #[repr(C, align(64))]
        struct __Aligned<T: ?Sized>(T);

        // Fully qualified: macro_rules! resolves macro paths at the call site, so a bare
        // `include_bytes!` could be shadowed by a caller's own macro of that name.
        const __DATA: &'static __Aligned<[u8; ::core::include_bytes!($path).len()]> =
            &__Aligned(*::core::include_bytes!($path));

        &__DATA.0
    }};
}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// `include_model!` must place the embedded bytes on a 64-byte boundary.
    #[test]
    fn include_model_aligns_to_64_bytes() {
        // Any file in the crate works as input; the README is a convenient dummy.
        let bytes = include_model!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"));
        let address = bytes.as_ptr() as usize;

        assert!(
            address.is_multiple_of(64),
            "include_model should align data to 64 bytes"
        );
    }

    /// Compile-time proof that `Model` may be moved to and shared between threads.
    #[test]
    fn model_is_send_and_sync() {
        fn require_send_and_sync<T: Send + Sync>() {}

        require_send_and_sync::<Model>();
    }
}
393
// Only compiled while rustdoc collects doctests, so this doc-only module never reaches
// normal builds or the rendered documentation.
#[cfg(doctest)]
mod _compile_fail_tests {
    //! Compile-fail regression: a `Model` created from a buffer must not outlive the buffer.
    //!
    //! The returned model claims `'static` but borrows a local `Vec`, so the borrow checker
    //! must reject this program.
    //!
    //! ```rust,compile_fail
    //! use aic_sdk::Model;
    //!
    //! fn leak_model_from_buffer() -> Model<'static> {
    //!     let bytes = vec![0u8; 64];
    //!     let model = Model::from_buffer(&bytes).unwrap();
    //!     model
    //! }
    //!
    //! fn main() {
    //!     let _ = leak_model_from_buffer();
    //! }
    //! ```
}
409    //! }
410    //! ```
411}