1use nemotron_asr_sys as ffi;
2use std::ffi::{CStr, CString};
3use std::path::Path;
4use std::ptr;
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum InitializationError {
9 #[error("Failed to initialize model")]
10 InitializationFailed,
11}
12
13#[derive(Error, Debug)]
14pub enum StreamInitializationError {
15 #[error("Failed to create streaming context")]
16 StreamInitializationFailed,
17}
18
/// Latency preset for streaming recognition, accepted by
/// [`CacheConfig::with_latency`].
///
/// Each variant maps one-to-one onto a native `nemo_latency_mode` value.
/// `Hash` is derived alongside `Eq` so the preset can be used as a map or set
/// key.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum LatencyMode {
    /// Native value `NEMO_LATENCY_PURE_CAUSAL`.
    PureCausal,
    /// Native value `NEMO_LATENCY_ULTRA_LOW`.
    UltraLow,
    /// Native value `NEMO_LATENCY_LOW`.
    Low,
    /// Native value `NEMO_LATENCY_DEFAULT`.
    Default,
}
31
32impl From<LatencyMode> for ffi::nemo_latency_mode {
33 fn from(mode: LatencyMode) -> Self {
34 match mode {
35 LatencyMode::PureCausal => ffi::nemo_latency_mode::NEMO_LATENCY_PURE_CAUSAL,
36 LatencyMode::UltraLow => ffi::nemo_latency_mode::NEMO_LATENCY_ULTRA_LOW,
37 LatencyMode::Low => ffi::nemo_latency_mode::NEMO_LATENCY_LOW,
38 LatencyMode::Default => ffi::nemo_latency_mode::NEMO_LATENCY_DEFAULT,
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct CacheConfig {
46 inner: ffi::nemo_cache_config,
47}
48
49impl CacheConfig {
50 pub fn default() -> Self {
52 Self {
53 inner: unsafe { ffi::nemo_cache_config_default() },
54 }
55 }
56
57 pub fn with_latency(mode: LatencyMode) -> Self {
59 Self {
60 inner: unsafe { ffi::nemo_cache_config_with_latency(mode.into()) },
61 }
62 }
63
64 pub fn set_right_context(&mut self, context: i32) -> &mut Self {
66 self.inner.att_right_context = context;
67 self
68 }
69
70 pub fn chunk_mel_frames(&self) -> usize {
72 unsafe { ffi::nemo_cache_config_get_chunk_mel_frames(&self.inner) }
73 }
74
75 pub fn chunk_samples(&self) -> i32 {
77 unsafe { ffi::nemo_cache_config_get_chunk_samples(&self.inner) }
78 }
79
80 pub fn latency_ms(&self) -> i32 {
82 unsafe { ffi::nemo_cache_config_get_latency_ms(&self.inner) }
83 }
84}
85
86impl Default for CacheConfig {
87 fn default() -> Self {
88 Self::default()
89 }
90}
91
92#[derive(Debug)]
94pub struct BackendDevice {
95 ptr: ffi::ggml_backend_dev_t,
96}
97
98impl BackendDevice {
99 pub fn name(&self) -> &str {
101 unsafe {
102 let name_ptr = ffi::ggml_backend_dev_name(self.ptr);
103 if name_ptr.is_null() {
104 panic!("ggml_backend_dev_name returned NULL");
105 }
106 CStr::from_ptr(name_ptr)
107 .to_str()
108 .expect("ggml_backend_dev_name returned invalid UTF-8")
109 }
110 }
111}
112
/// Asks ggml to load every dynamically loadable backend found under `path`
/// (via `ggml_backend_load_all_from_path`).
///
/// Only compiled with the `ggml_backend_dl` feature.
///
/// # Panics
///
/// Panics if `path` is not valid UTF-8 or contains an interior NUL byte.
#[cfg(feature = "ggml_backend_dl")]
pub fn load_backends_from_path(path: impl AsRef<Path>) {
    let search_dir: &Path = path.as_ref();
    let c_dir = search_dir
        .to_str()
        .map(CString::new)
        .expect("path must be valid UTF-8")
        .expect("path must not contain null bytes");

    // SAFETY: `c_dir` is a NUL-terminated string that outlives the call.
    unsafe { ffi::ggml_backend_load_all_from_path(c_dir.as_ptr()) };
}
122
123pub fn backend_count() -> usize {
125 unsafe { ffi::ggml_backend_dev_count() }
126}
127
128pub fn get_backend(index: usize) -> Option<BackendDevice> {
130 unsafe {
131 let ptr = ffi::ggml_backend_dev_get(index);
132 if ptr.is_null() {
133 None
134 } else {
135 Some(BackendDevice { ptr })
136 }
137 }
138}
139
140pub fn list_backends() -> Vec<BackendDevice> {
142 let count = backend_count();
143 (0..count).filter_map(get_backend).collect()
144}
145
146pub struct Context {
148 ptr: *mut ffi::nemo_context_ffi,
149}
150
151impl Context {
152 pub fn new(
158 model_path: impl AsRef<Path>,
159 backend: Option<&str>,
160 ) -> Result<Self, InitializationError> {
161 let path_str = model_path
162 .as_ref()
163 .to_str()
164 .expect("model path must be valid UTF-8");
165 let model_path_c = CString::new(path_str).expect("model path must not contain null bytes");
166 let backend_c = backend.map(|s| CString::new(s).unwrap());
167
168 let ptr = unsafe {
169 ffi::c_nemo_init_with_backend(
170 model_path_c.as_ptr(),
171 backend_c.as_ref().map_or(ptr::null(), |s| s.as_ptr()),
172 )
173 };
174
175 if ptr.is_null() {
176 Err(InitializationError::InitializationFailed)
177 } else {
178 Ok(Self { ptr })
179 }
180 }
181
182 pub fn backend_name(&self) -> &str {
184 unsafe {
185 let name_ptr = ffi::c_nemo_get_backend_name(self.ptr);
186 if name_ptr.is_null() {
187 panic!("c_nemo_get_backend_name returned NULL");
188 }
189 CStr::from_ptr(name_ptr)
190 .to_str()
191 .expect("c_nemo_get_backend_name returned invalid UTF-8")
192 }
193 }
194
195 pub fn create_stream(
197 &mut self,
198 config: Option<&CacheConfig>,
199 ) -> Result<Stream, StreamInitializationError> {
200 let config_ptr = config.map_or(ptr::null(), |c| &c.inner);
201
202 let ptr = unsafe { ffi::c_nemo_stream_init(self.ptr, config_ptr) };
203
204 if ptr.is_null() {
205 Err(StreamInitializationError::StreamInitializationFailed)
206 } else {
207 Ok(Stream { ptr })
208 }
209 }
210}
211
212impl Drop for Context {
213 fn drop(&mut self) {
214 unsafe {
215 ffi::c_nemo_free(self.ptr);
216 }
217 }
218}
219
220unsafe impl Send for Context {}
221
222pub struct Stream {
224 ptr: *mut ffi::nemo_stream_context_ffi,
225}
226
227impl Stream {
228 pub fn process(&mut self, audio: &[i16]) -> String {
235 let text_ptr = unsafe {
236 ffi::c_nemo_stream_process_incremental(self.ptr, audio.as_ptr(), audio.len() as i32)
237 };
238
239 if text_ptr.is_null() {
240 return String::new();
241 }
242
243 unsafe {
244 let text = CStr::from_ptr(text_ptr).to_string_lossy().to_string();
245 ffi::c_nemo_free_string(text_ptr);
246 text
247 }
248 }
249
250 pub fn finalize(&mut self) -> String {
254 let text_ptr = unsafe { ffi::c_nemo_stream_finalize(self.ptr) };
255
256 if text_ptr.is_null() {
257 return String::new();
258 }
259
260 unsafe {
261 let text = CStr::from_ptr(text_ptr).to_string_lossy().to_string();
262 ffi::c_nemo_free_string(text_ptr);
263 text
264 }
265 }
266
267 pub fn get_transcript(&self) -> String {
269 let text_ptr = unsafe { ffi::c_nemo_stream_get_transcript(self.ptr) };
270
271 if text_ptr.is_null() {
272 return String::new();
273 }
274
275 unsafe {
276 let text = CStr::from_ptr(text_ptr).to_string_lossy().to_string();
277 ffi::c_nemo_free_string(text_ptr);
278 text
279 }
280 }
281
282 pub fn reset(&mut self) {
284 unsafe {
285 ffi::c_nemo_stream_reset(self.ptr);
286 }
287 }
288}
289
290impl Drop for Stream {
291 fn drop(&mut self) {
292 unsafe {
293 ffi::c_nemo_stream_free(self.ptr);
294 }
295 }
296}
297
298unsafe impl Send for Stream {}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// The default cache configuration must report a positive chunk size and
    /// a positive latency.
    #[test]
    fn test_cache_config() {
        let cfg = CacheConfig::default();
        assert!(cfg.chunk_samples() > 0);
        assert!(cfg.latency_ms() > 0);
    }

    /// Smoke test: enumerating backends must not crash. Names are printed for
    /// manual inspection.
    #[test]
    fn test_backend_list() {
        let total = backend_count();
        println!("Available backends: {}", total);

        let found = (0..total).filter_map(|idx| get_backend(idx).map(|dev| (idx, dev)));
        for (idx, dev) in found {
            println!(" Backend {}: {}", idx, dev.name());
        }
    }
}