aic_sdk/model.rs
1use crate::error::*;
2
3use aic_sdk_sys::*;
4
5use std::{
6 ffi::{CStr, CString},
7 marker::PhantomData,
8 path::Path,
9 ptr,
10};
11
12/// High-level wrapper for the ai-coustics audio enhancement model.
13///
14/// This struct provides a safe, Rust-friendly interface to the underlying C library.
15/// It handles memory management automatically and converts C-style error codes
16/// to Rust `Result` types.
17///
18/// # Sharing and Multi-threading
19///
20/// `Model` is `Send` and `Sync`, so you can share it across threads. It does not implement
21/// `Clone`, so wrap it in an `Arc` if you need shared ownership.
22///
23/// # Example
24///
25/// ```rust,no_run
26/// # use aic_sdk::{Model, ProcessorConfig, Processor};
27/// # let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
28/// let model = Model::from_file("/path/to/model.aicmodel").unwrap();
29/// let config = ProcessorConfig::optimal(&model).with_num_channels(2);
30/// let mut processor = Processor::new(&model, &license_key).unwrap();
31/// processor.initialize(&config).unwrap();
32/// let mut audio_buffer = vec![0.0f32; config.num_channels as usize * config.num_frames];
33/// processor.process_interleaved(&mut audio_buffer).unwrap();
34/// ```
35///
36/// # Multi-threaded Example
37///
38/// ```rust,no_run
39/// # use aic_sdk::{Model, ProcessorConfig, Processor};
40/// # use std::{thread, sync::Arc};
41/// let model = Arc::new(Model::from_file("/path/to/model.aicmodel").unwrap());
42///
43/// // Spawn multiple threads, each with its own processor but sharing the same model
44/// let handles: Vec<_> = (0..4)
45/// .map(|i| {
46/// let model_clone = Arc::clone(&model);
47/// thread::spawn(move || {
48/// let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
49/// let mut processor = Processor::new(&model_clone, &license_key).unwrap();
50/// // Process audio in this thread...
51/// })
52/// })
53/// .collect();
54///
55/// for handle in handles {
56/// handle.join().unwrap();
57/// }
58/// ```
59pub struct Model<'a> {
60 /// Raw pointer to the C model structure
61 ptr: *mut AicModel,
62 /// Marker to tie the lifetime of the model to the lifetime of its weights
63 marker: PhantomData<&'a [u8]>,
64}
65
66impl<'a> Model<'a> {
67 /// Creates a new audio enhancement model instance.
68 ///
69 /// Multiple models can be created to process different audio streams simultaneously
70 /// or to switch between different enhancement algorithms during runtime.
71 ///
72 /// # Arguments
73 ///
74 /// * `path` - Filesystem path to a model file.
75 ///
76 /// # Returns
77 ///
78 /// Returns a `Result` containing the new `Model` instance or an `AicError` if creation fails.
79 ///
80 /// # Example
81 ///
82 /// ```rust,no_run
83 /// # use aic_sdk::Model;
84 /// let model = Model::from_file("/path/to/model.aicmodel").unwrap();
85 /// ```
86 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Model<'static>, AicError> {
87 let mut model_ptr: *mut AicModel = ptr::null_mut();
88 let c_path = CString::new(path.as_ref().to_string_lossy().as_bytes()).unwrap();
89
90 // SAFETY:
91 // - `model_ptr` points to stack memory we own.
92 // - `c_path` is a valid, null-terminated string.
93 let error_code = unsafe { aic_model_create_from_file(&mut model_ptr, c_path.as_ptr()) };
94
95 handle_error(error_code)?;
96
97 // This should never happen if the C library is well-behaved, but let's be defensive
98 assert!(
99 !model_ptr.is_null(),
100 "C library returned success but null pointer"
101 );
102
103 Ok(Model {
104 ptr: model_ptr,
105 marker: PhantomData,
106 })
107 }
108
109 /// Creates a new model instance from an in-memory buffer.
110 ///
111 /// The buffer must be 64-byte aligned.
112 ///
113 /// Consider using [`include_model!`](macro@crate::include_model) to embed a model file at compile time with
114 /// the correct alignment.
115 ///
116 /// # Arguments
117 ///
118 /// * `buffer` - Raw bytes of the model file.
119 ///
120 /// # Returns
121 ///
122 /// Returns a `Result` containing the new `Model` instance or an `AicError` if creation fails.
123 ///
124 /// # Example
125 ///
126 /// ```rust,ignore
127 /// # use aic_sdk::{include_model, Model};
128 /// static MODEL: &'static [u8] = include_model!("/path/to/model.aicmodel");
129 /// let model = Model::from_buffer(MODEL).unwrap();
130 /// ```
131 pub fn from_buffer(buffer: &'a [u8]) -> Result<Self, AicError> {
132 let mut model_ptr: *mut AicModel = ptr::null_mut();
133
134 // SAFETY:
135 // - `buffer` is a valid slice and immutable for `'a`.
136 // - The SDK only reads from `buffer` for the lifetime of the model.
137 let error_code =
138 unsafe { aic_model_create_from_buffer(&mut model_ptr, buffer.as_ptr(), buffer.len()) };
139
140 handle_error(error_code)?;
141
142 // This should never happen if the C library is well-behaved, but let's be defensive
143 assert!(
144 !model_ptr.is_null(),
145 "C library returned success but null pointer"
146 );
147
148 Ok(Model {
149 ptr: model_ptr,
150 marker: PhantomData,
151 })
152 }
153
154 /// Returns the model identifier string.
155 pub fn id(&self) -> &str {
156 // SAFETY: `self` owns a valid model pointer created by the SDK.
157 let id_ptr = unsafe { aic_model_get_id(self.as_const_ptr()) };
158 if id_ptr.is_null() {
159 return "unknown";
160 }
161
162 // SAFETY: Pointer is valid for the lifetime of `self` and is null-terminated.
163 unsafe { CStr::from_ptr(id_ptr).to_str().unwrap_or("unknown") }
164 }
165
166 /// Retrieves the native sample rate of the processor's model.
167 ///
168 /// Each model is optimized for a specific sample rate, which determines the frequency
169 /// range of the enhanced audio output. While you can process audio at any sample rate,
170 /// understanding the model's native rate helps predict the enhancement quality.
171 ///
172 /// **How sample rate affects enhancement:**
173 /// - Models trained at lower sample rates (e.g., 8 kHz) can only enhance frequencies
174 /// up to their Nyquist limit (4 kHz for 8 kHz models)
175 /// - When processing higher sample rate input (e.g., 48 kHz) with a lower-rate model,
176 /// only the lower frequency components will be enhanced
177 ///
178 /// **Enhancement blending:**
179 /// When enhancement strength is set below 1.0, the enhanced signal is blended with
180 /// the original, maintaining the full frequency spectrum of your input while adding
181 /// the model's noise reduction capabilities to the lower frequencies.
182 ///
183 /// **Sample rate and optimal frames relationship:**
184 /// When using different sample rates than the model's native rate, the optimal number
185 /// of frames (returned by `optimal_num_frames`) will change. The model's output
186 /// delay remains constant regardless of sample rate as long as you use the optimal frame
187 /// count for that rate.
188 ///
189 /// **Recommendation:**
190 /// For maximum enhancement quality across the full frequency spectrum, match your
191 /// input sample rate to the model's native rate when possible.
192 ///
193 /// # Returns
194 ///
195 /// Returns the model's native sample rate.
196 ///
197 /// # Example
198 ///
199 /// ```rust,no_run
200 /// # use aic_sdk::{Model, Processor};
201 /// # let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
202 /// # let model = Model::from_file("/path/to/model.aicmodel").unwrap();
203 /// let optimal_sample_rate = model.optimal_sample_rate();
204 /// println!("Optimal sample rate: {optimal_sample_rate} Hz");
205 /// ```
206 pub fn optimal_sample_rate(&self) -> u32 {
207 let mut sample_rate: u32 = 0;
208 // SAFETY:
209 // - `self.as_const_ptr()` is a valid pointer to a live model.
210 // - `sample_rate` points to stack storage for output.
211 let error_code =
212 unsafe { aic_model_get_optimal_sample_rate(self.as_const_ptr(), &mut sample_rate) };
213
214 // This should never fail. If it does, it's a bug in the SDK.
215 // `aic_get_optimal_sample_rate` is documented to always succeed if given a valid processor pointer.
216 assert_success(
217 error_code,
218 "`aic_model_get_optimal_sample_rate` failed. This is a bug, please open an issue on GitHub for further investigation.",
219 );
220
221 // This should never fail
222 sample_rate
223 }
224
225 /// Retrieves the optimal number of frames for the selected model at a given sample rate.
226 ///
227 ///
228 /// Using the optimal number of frames minimizes latency by avoiding internal buffering.
229 ///
230 /// **When you use a different frame count than the optimal value, the model will
231 /// introduce additional buffering latency on top of its base processing delay.**
232 ///
233 /// The optimal frame count varies based on the sample rate. Each model operates on a
234 /// fixed time window duration, so the required number of frames changes with sample rate.
235 /// For example, a model designed for 10 ms processing windows requires 480 frames at
236 /// 48 kHz, but only 160 frames at 16 kHz to capture the same duration of audio.
237 ///
238 /// Call this function with your intended sample rate before calling
239 /// [`Processor::initialize`](crate::Processor::initialize) to determine the best frame count for minimal latency.
240 ///
241 /// # Arguments
242 ///
243 /// * `sample_rate` - The sample rate in Hz for which to calculate the optimal frame count.
244 ///
245 /// # Returns
246 ///
247 /// Returns the optimal frame count.
248 ///
249 /// # Example
250 ///
251 /// ```rust,no_run
252 /// # use aic_sdk::{Model, Processor};
253 /// # let license_key = std::env::var("AIC_SDK_LICENSE").unwrap();
254 /// # let model = Model::from_file("/path/to/model.aicmodel").unwrap();
255 /// # let sample_rate = model.optimal_sample_rate();
256 /// let optimal_frames = model.optimal_num_frames(sample_rate);
257 /// println!("Optimal frame count: {optimal_frames}");
258 /// ```
259 pub fn optimal_num_frames(&self, sample_rate: u32) -> usize {
260 let mut num_frames: usize = 0;
261 // SAFETY:
262 // - `self.as_const_ptr()` is a valid pointer to a live model.
263 // - `num_frames` points to stack storage for output.
264 let error_code = unsafe {
265 aic_model_get_optimal_num_frames(self.as_const_ptr(), sample_rate, &mut num_frames)
266 };
267
268 // This should never fail. If it does, it's a bug in the SDK.
269 // `aic_get_optimal_num_frames` is documented to always succeed if given valid pointers.
270 assert_success(
271 error_code,
272 "`aic_model_get_optimal_num_frames` failed. This is a bug, please open an issue on GitHub for further investigation.",
273 );
274
275 num_frames
276 }
277
278 /// Downloads a model file from the ai-coustics artifact CDN.
279 ///
280 /// This method fetches the model manifest, verifies that the requested model
281 /// exists in a version compatible with this library, and downloads the model
282 /// file to the specified directory. If the model file already exists, it will not
283 /// be re-downloaded. If the existing file's checksum does not match, the model will
284 /// be downloaded and the existing file will be replaced.
285 ///
286 /// The manifest file is not cached and will always be downloaded on every call
287 /// to ensure the latest model versions are always used.
288 ///
289 /// Available models can be browsed at [artifacts.ai-coustics.io](https://artifacts.ai-coustics.io/).
290 ///
291 /// # Arguments
292 ///
293 /// * `model_id` - The model identifier (e.g., `"quail-l-16khz"`).
294 /// * `download_dir` - Directory where the model file will be stored.
295 ///
296 /// # Returns
297 ///
298 /// Returns the full path to the model file on success, or an [`AicError`] if the
299 /// operation fails.
300 ///
301 /// # Note
302 ///
303 /// This is a blocking operation that performs network I/O.
304 #[cfg(feature = "download-model")]
305 pub fn download<P: AsRef<Path>>(
306 model_id: &str,
307 download_dir: P,
308 ) -> Result<std::path::PathBuf, AicError> {
309 let compatible_version = crate::get_compatible_model_version();
310 crate::download::download(model_id, compatible_version, download_dir)
311 .map_err(|err| AicError::ModelDownload(err.to_string()))
312 }
313
314 pub(crate) fn as_const_ptr(&self) -> *const AicModel {
315 self.ptr as *const AicModel
316 }
317}
318
319impl<'a> Drop for Model<'a> {
320 fn drop(&mut self) {
321 if !self.ptr.is_null() {
322 // SAFETY:
323 // - `self.ptr` was allocated by the SDK and is still owned by this wrapper.
324 unsafe { aic_model_destroy(self.ptr) };
325 }
326 }
327}
328
329// SAFETY:
330// - Model wraps a raw pointer to an AicModel which is immutable after creation and it
331// does not provide access to it through its public API.
332// - Methods only pass the pointer to SDK calls documented as thread-safe for const access.
333unsafe impl<'a> Send for Model<'a> {}
334// SAFETY:
335// - Model wraps a raw pointer to an AicModel which is immutable after creation and it
336// does not provide access to it through its public API.
337// - Methods only pass the pointer to SDK calls documented as thread-safe for const access.
338unsafe impl<'a> Sync for Model<'a> {}
339
/// Embeds the bytes of a model file at compile time, aligned to 64 bytes.
///
/// This macro uses the standard [`include_bytes!`](core::include_bytes) macro to include
/// the model file, then copies the bytes into a `#[repr(C, align(64))]` wrapper so the
/// resulting `&'static [u8; N]` satisfies the alignment required by
/// [`Model::from_buffer`](crate::Model::from_buffer).
///
/// The path is resolved the same way `include_bytes!` resolves it: relative to the file
/// that invokes the macro.
///
/// # Example
///
/// ```rust,ignore
/// # use aic_sdk::{include_model, Model};
///
/// static MODEL: &'static [u8] = include_model!("/path/to/model.aicmodel");
/// let model = Model::from_buffer(MODEL).unwrap();
/// ```
#[macro_export]
macro_rules! include_model {
    ($path:expr) => {{
        // Wrapper whose only purpose is to raise the alignment of the embedded bytes.
        #[repr(C, align(64))]
        struct __Aligned<T: ?Sized>(T);

        // Fully qualified path keeps the macro hygienic even if the caller shadows
        // `include_bytes!`.
        const __DATA: &__Aligned<[u8; ::core::include_bytes!($path).len()]> =
            &__Aligned(*::core::include_bytes!($path));

        &__DATA.0
    }};
}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn include_model_aligns_to_64_bytes() {
        // Use the README.md as a dummy file for testing
        let data = include_model!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"));

        let ptr = data.as_ptr() as usize;
        assert!(
            ptr.is_multiple_of(64),
            "include_model should align data to 64 bytes"
        );
    }

    #[test]
    fn include_model_preserves_file_contents() {
        // Alignment alone is not enough: the embedded bytes must match the file exactly.
        let aligned: &[u8] = include_model!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"));
        let plain: &[u8] = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"));

        assert_eq!(
            aligned, plain,
            "include_model should embed the file byte-for-byte"
        );
    }

    #[test]
    fn model_is_send_and_sync() {
        // Compile-time check that Model implements Send and Sync.
        // This ensures the model can be safely shared across threads.
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<Model>();
        assert_sync::<Model>();
    }
}
393
// Host module for a rustdoc `compile_fail` doctest. It contains no code, so it is only
// compiled while rustdoc collects doctests (`cfg(doctest)`), never into normal builds.
#[cfg(doctest)]
#[doc(hidden)]
mod _compile_fail_tests {
    //! Compile-fail regression: a `Model` created from a buffer must not outlive the buffer.
    //!
    //! ```rust,compile_fail
    //! use aic_sdk::Model;
    //!
    //! fn leak_model_from_buffer() -> Model<'static> {
    //!     let bytes = vec![0u8; 64];
    //!     let model = Model::from_buffer(&bytes).unwrap();
    //!     model
    //! }
    //!
    //! fn main() {
    //!     let _ = leak_model_from_buffer();
    //! }
    //! ```
}