Skip to main content

nemotron_asr/
lib.rs

1use nemotron_asr_sys as ffi;
2use std::ffi::{CStr, CString};
3use std::path::Path;
4use std::ptr;
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum InitializationError {
9    #[error("Failed to initialize model")]
10    InitializationFailed,
11}
12
13#[derive(Error, Debug)]
14pub enum StreamInitializationError {
15    #[error("Failed to create streaming context")]
16    StreamInitializationFailed,
17}
18
/// Latency/quality trade-off used when configuring streaming ASR.
///
/// Modes are listed from lowest to highest latency; the highest-latency mode is
/// documented as giving the best quality.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LatencyMode {
    /// Pure causal decoding: 80ms latency.
    PureCausal,
    /// Ultra-low latency: 160ms.
    UltraLow,
    /// Low latency: 560ms.
    Low,
    /// Default latency: 1.12s (best quality).
    Default,
}
31
32impl From<LatencyMode> for ffi::nemo_latency_mode {
33    fn from(mode: LatencyMode) -> Self {
34        match mode {
35            LatencyMode::PureCausal => ffi::nemo_latency_mode::NEMO_LATENCY_PURE_CAUSAL,
36            LatencyMode::UltraLow => ffi::nemo_latency_mode::NEMO_LATENCY_ULTRA_LOW,
37            LatencyMode::Low => ffi::nemo_latency_mode::NEMO_LATENCY_LOW,
38            LatencyMode::Default => ffi::nemo_latency_mode::NEMO_LATENCY_DEFAULT,
39        }
40    }
41}
42
43/// Cache configuration for streaming
44#[derive(Debug, Clone)]
45pub struct CacheConfig {
46    inner: ffi::nemo_cache_config,
47}
48
49impl CacheConfig {
50    /// Create default cache configuration
51    pub fn default() -> Self {
52        Self {
53            inner: unsafe { ffi::nemo_cache_config_default() },
54        }
55    }
56
57    /// Create cache configuration with specific latency mode
58    pub fn with_latency(mode: LatencyMode) -> Self {
59        Self {
60            inner: unsafe { ffi::nemo_cache_config_with_latency(mode.into()) },
61        }
62    }
63
64    /// Set right context (lookahead frames)
65    pub fn set_right_context(&mut self, context: i32) -> &mut Self {
66        self.inner.att_right_context = context;
67        self
68    }
69
70    /// Get chunk size in mel frames
71    pub fn chunk_mel_frames(&self) -> usize {
72        unsafe { ffi::nemo_cache_config_get_chunk_mel_frames(&self.inner) }
73    }
74
75    /// Get chunk size in audio samples
76    pub fn chunk_samples(&self) -> i32 {
77        unsafe { ffi::nemo_cache_config_get_chunk_samples(&self.inner) }
78    }
79
80    /// Get latency in milliseconds
81    pub fn latency_ms(&self) -> i32 {
82        unsafe { ffi::nemo_cache_config_get_latency_ms(&self.inner) }
83    }
84}
85
86impl Default for CacheConfig {
87    fn default() -> Self {
88        Self::default()
89    }
90}
91
92/// Backend device information
93#[derive(Debug)]
94pub struct BackendDevice {
95    ptr: ffi::ggml_backend_dev_t,
96}
97
98impl BackendDevice {
99    /// Get device name
100    pub fn name(&self) -> &str {
101        unsafe {
102            let name_ptr = ffi::ggml_backend_dev_name(self.ptr);
103            if name_ptr.is_null() {
104                panic!("ggml_backend_dev_name returned NULL");
105            }
106            CStr::from_ptr(name_ptr)
107                .to_str()
108                .expect("ggml_backend_dev_name returned invalid UTF-8")
109        }
110    }
111}
112
/// Ask ggml to load all available backends from `path`.
///
/// Only compiled with the `ggml_backend_dl` feature.
///
/// # Panics
/// Panics if `path` is not valid UTF-8 or contains an interior NUL byte.
#[cfg(feature = "ggml_backend_dl")]
pub fn load_backends_from_path(path: impl AsRef<Path>) {
    let dir: &Path = path.as_ref();
    let dir_utf8 = dir.to_str().expect("path must be valid UTF-8");
    let dir_c = CString::new(dir_utf8).expect("path must not contain null bytes");

    // SAFETY: `dir_c` is a NUL-terminated string that outlives this call.
    unsafe { ffi::ggml_backend_load_all_from_path(dir_c.as_ptr()) };
}
122
123/// Get number of available backend devices
124pub fn backend_count() -> usize {
125    unsafe { ffi::ggml_backend_dev_count() }
126}
127
128/// Get backend device by index
129pub fn get_backend(index: usize) -> Option<BackendDevice> {
130    unsafe {
131        let ptr = ffi::ggml_backend_dev_get(index);
132        if ptr.is_null() {
133            None
134        } else {
135            Some(BackendDevice { ptr })
136        }
137    }
138}
139
140/// List all available backends
141pub fn list_backends() -> Vec<BackendDevice> {
142    let count = backend_count();
143    (0..count).filter_map(get_backend).collect()
144}
145
146/// Main model context
147pub struct Context {
148    ptr: *mut ffi::nemo_context_ffi,
149}
150
151impl Context {
152    /// Initialize model from GGUF file with optional backend selection
153    ///
154    /// # Arguments
155    /// * `model_path` - Path to the GGUF model file
156    /// * `backend` - Optional backend name (e.g., "CPU", "Vulkan"). None for auto-select.
157    pub fn new(
158        model_path: impl AsRef<Path>,
159        backend: Option<&str>,
160    ) -> Result<Self, InitializationError> {
161        let path_str = model_path
162            .as_ref()
163            .to_str()
164            .expect("model path must be valid UTF-8");
165        let model_path_c = CString::new(path_str).expect("model path must not contain null bytes");
166        let backend_c = backend.map(|s| CString::new(s).unwrap());
167
168        let ptr = unsafe {
169            ffi::c_nemo_init_with_backend(
170                model_path_c.as_ptr(),
171                backend_c.as_ref().map_or(ptr::null(), |s| s.as_ptr()),
172            )
173        };
174
175        if ptr.is_null() {
176            Err(InitializationError::InitializationFailed)
177        } else {
178            Ok(Self { ptr })
179        }
180    }
181
182    /// Get the name of the backend being used
183    pub fn backend_name(&self) -> &str {
184        unsafe {
185            let name_ptr = ffi::c_nemo_get_backend_name(self.ptr);
186            if name_ptr.is_null() {
187                panic!("c_nemo_get_backend_name returned NULL");
188            }
189            CStr::from_ptr(name_ptr)
190                .to_str()
191                .expect("c_nemo_get_backend_name returned invalid UTF-8")
192        }
193    }
194
195    /// Create a streaming context
196    pub fn create_stream(
197        &mut self,
198        config: Option<&CacheConfig>,
199    ) -> Result<Stream, StreamInitializationError> {
200        let config_ptr = config.map_or(ptr::null(), |c| &c.inner);
201
202        let ptr = unsafe { ffi::c_nemo_stream_init(self.ptr, config_ptr) };
203
204        if ptr.is_null() {
205            Err(StreamInitializationError::StreamInitializationFailed)
206        } else {
207            Ok(Stream { ptr })
208        }
209    }
210}
211
212impl Drop for Context {
213    fn drop(&mut self) {
214        unsafe {
215            ffi::c_nemo_free(self.ptr);
216        }
217    }
218}
219
220unsafe impl Send for Context {}
221
222/// Streaming transcription context
223pub struct Stream {
224    ptr: *mut ffi::nemo_stream_context_ffi,
225}
226
227impl Stream {
228    /// Process audio chunk incrementally
229    ///
230    /// # Arguments
231    /// * `audio` - PCM audio samples (16-bit signed, 16kHz, mono)
232    ///
233    /// Returns new transcription text (may be empty if no new tokens)
234    pub fn process(&mut self, audio: &[i16]) -> String {
235        let text_ptr = unsafe {
236            ffi::c_nemo_stream_process_incremental(self.ptr, audio.as_ptr(), audio.len() as i32)
237        };
238
239        if text_ptr.is_null() {
240            return String::new();
241        }
242
243        unsafe {
244            let text = CStr::from_ptr(text_ptr).to_string_lossy().to_string();
245            ffi::c_nemo_free_string(text_ptr);
246            text
247        }
248    }
249
250    /// Finalize streaming and flush remaining audio
251    ///
252    /// Returns final transcription text
253    pub fn finalize(&mut self) -> String {
254        let text_ptr = unsafe { ffi::c_nemo_stream_finalize(self.ptr) };
255
256        if text_ptr.is_null() {
257            return String::new();
258        }
259
260        unsafe {
261            let text = CStr::from_ptr(text_ptr).to_string_lossy().to_string();
262            ffi::c_nemo_free_string(text_ptr);
263            text
264        }
265    }
266
267    /// Get full accumulated transcript
268    pub fn get_transcript(&self) -> String {
269        let text_ptr = unsafe { ffi::c_nemo_stream_get_transcript(self.ptr) };
270
271        if text_ptr.is_null() {
272            return String::new();
273        }
274
275        unsafe {
276            let text = CStr::from_ptr(text_ptr).to_string_lossy().to_string();
277            ffi::c_nemo_free_string(text_ptr);
278            text
279        }
280    }
281
282    /// Reset streaming state (clear caches and transcript)
283    pub fn reset(&mut self) {
284        unsafe {
285            ffi::c_nemo_stream_reset(self.ptr);
286        }
287    }
288}
289
290impl Drop for Stream {
291    fn drop(&mut self) {
292        unsafe {
293            ffi::c_nemo_stream_free(self.ptr);
294        }
295    }
296}
297
298unsafe impl Send for Stream {}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// The default cache configuration should describe a non-empty chunk with a
    /// positive latency.
    #[test]
    fn test_cache_config() {
        let config: CacheConfig = Default::default();
        assert!(config.chunk_samples() > 0);
        assert!(config.latency_ms() > 0);
    }

    /// Smoke test: enumerate the registered backends and print their names.
    #[test]
    fn test_backend_list() {
        let count = backend_count();
        println!("Available backends: {}", count);

        let found = (0..count).filter_map(|i| get_backend(i).map(|backend| (i, backend)));
        for (i, backend) in found {
            println!("  Backend {}: {}", i, backend.name());
        }
    }
}