use std::ffi::{CStr, CString, c_char};
use std::ptr;
use crate::inference::{Engine, OwnedReservation, SessionTriplet, StreamingState, audio};
/// Opaque engine handle exposed to C callers.
///
/// Allocated with `Box::into_raw` by `gigastt_engine_new` /
/// `gigastt_engine_new_with_pool_size` and released with `gigastt_engine_free`.
pub struct GigasttEngine {
    /// The loaded inference engine. Its `pool` is where file transcription and
    /// streaming calls check out inference sessions.
    engine: Engine,
}
/// Opaque per-stream handle for incremental (chunked) transcription.
///
/// Created by `gigastt_stream_new`. That function checks a session triplet out of the
/// engine's pool and keeps it for the lifetime of the stream.
/// `gigastt_stream_free` checks the triplet back in through `reservation`.
pub struct GigasttStream {
    /// Streaming inference state. It is carried across `gigastt_stream_process_chunk`
    /// calls and drained by `gigastt_stream_flush`.
    state: StreamingState,
    /// Session triplet taken out of the pool (via `into_owned`) for this stream's use.
    triplet: SessionTriplet,
    /// Pool reservation used to return `triplet` in `gigastt_stream_free`.
    reservation: OwnedReservation<SessionTriplet>,
}
/// Loads an engine from `model_dir` using the default session-pool size of 4.
///
/// This is a convenience wrapper around [`gigastt_engine_new_with_pool_size`].
/// It returns null on failure; the reason is logged.
///
/// # Safety
///
/// `model_dir` must be null or point to a NUL-terminated string that remains valid
/// for the duration of the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_engine_new(model_dir: *const c_char) -> *mut GigasttEngine {
    const DEFAULT_POOL_SIZE: usize = 4;
    // SAFETY: the caller's contract on `model_dir` is forwarded unchanged.
    unsafe { gigastt_engine_new_with_pool_size(model_dir, DEFAULT_POOL_SIZE) }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_engine_new_with_pool_size(
model_dir: *const c_char,
pool_size: usize,
) -> *mut GigasttEngine {
if model_dir.is_null() {
tracing::error!("gigastt_engine_new_with_pool_size: model_dir is null");
eprintln!("gigastt_engine_new_with_pool_size: model_dir is null");
return ptr::null_mut();
}
let dir_str = match unsafe { CStr::from_ptr(model_dir) }.to_str() {
Ok(s) => s,
Err(e) => {
tracing::error!("gigastt_engine_new_with_pool_size: model_dir is not valid UTF-8: {e}");
eprintln!("gigastt_engine_new_with_pool_size: model_dir is not valid UTF-8: {e}");
return ptr::null_mut();
}
};
match Engine::load_with_pool_size(dir_str, pool_size) {
Ok(engine) => {
let handle = Box::new(GigasttEngine { engine });
Box::into_raw(handle)
}
Err(e) => {
tracing::error!("gigastt_engine_new_with_pool_size: failed to load engine: {e}");
eprintln!("gigastt_engine_new_with_pool_size: failed to load engine: {e}");
ptr::null_mut()
}
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_transcribe_file(
engine: *mut GigasttEngine,
wav_path: *const c_char,
) -> *mut c_char {
if engine.is_null() {
tracing::error!("gigastt_transcribe_file: engine is null");
eprintln!("gigastt_transcribe_file: engine is null");
return ptr::null_mut();
}
if wav_path.is_null() {
tracing::error!("gigastt_transcribe_file: wav_path is null");
eprintln!("gigastt_transcribe_file: wav_path is null");
return ptr::null_mut();
}
let path_str = match unsafe { CStr::from_ptr(wav_path) }.to_str() {
Ok(s) => s,
Err(e) => {
tracing::error!("gigastt_transcribe_file: wav_path is not valid UTF-8: {e}");
eprintln!("gigastt_transcribe_file: wav_path is not valid UTF-8: {e}");
return ptr::null_mut();
}
};
let engine_ref = unsafe { &(*engine).engine };
let mut guard = match engine_ref.pool.checkout_blocking() {
Ok(g) => g,
Err(e) => {
tracing::error!("gigastt_transcribe_file: failed to checkout session from pool: {e}");
eprintln!("gigastt_transcribe_file: failed to checkout session from pool: {e}");
return ptr::null_mut();
}
};
let result = match engine_ref.transcribe_file(path_str, &mut guard) {
Ok(r) => r,
Err(e) => {
tracing::error!("gigastt_transcribe_file: transcription failed: {e}");
eprintln!("gigastt_transcribe_file: transcription failed: {e}");
return ptr::null_mut();
}
};
match CString::new(result.text) {
Ok(cstr) => cstr.into_raw(),
Err(e) => {
tracing::error!("gigastt_transcribe_file: result contains interior NUL: {e}");
eprintln!("gigastt_transcribe_file: result contains interior NUL: {e}");
ptr::null_mut()
}
}
}
/// Releases a string previously returned by this library. Passing null is a no-op.
///
/// # Safety
///
/// `s` must be null or a pointer returned by one of this library's string-returning
/// functions. It must not have been freed already, and it must not be used afterwards.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_string_free(s: *mut c_char) {
    if s.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `s` came from `CString::into_raw` and is freed only once.
    drop(unsafe { CString::from_raw(s) });
}
/// Destroys an engine handle created by `gigastt_engine_new*`. Passing null is a no-op.
///
/// # Safety
///
/// `engine` must be null or a live handle from `gigastt_engine_new*`. It must not
/// have been freed already, and it must not be used (including by open streams) afterwards.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_engine_free(engine: *mut GigasttEngine) {
    if engine.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `engine` came from `Box::into_raw` and is freed only once.
    drop(unsafe { Box::from_raw(engine) });
}
/// Quantizes the model's encoder to INT8 and reports the outcome as a C string.
///
/// Reads `<model_dir>/v3_e2e_rnnt_encoder.onnx` and writes
/// `<model_dir>/v3_e2e_rnnt_encoder_int8.onnx` via `crate::quantize::quantize_model`.
/// When `force` is false and the output file already exists, nothing is done.
///
/// # Returns
///
/// Always a non-null, heap-allocated C string, which must be released with
/// `gigastt_string_free`:
/// - `"ok"` on success, including the skipped case;
/// - otherwise a human-readable error message.
///
/// Errors are also logged via `tracing` and on stderr.
///
/// # Safety
///
/// `model_dir` must be null or point to a NUL-terminated string that remains valid
/// for the duration of the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_quantize_model(
    model_dir: *const c_char,
    force: bool,
) -> *mut c_char {
    // Converts `msg` into an owned C string. Interior NUL bytes are dropped rather than
    // discarding the whole message, so error details always reach the caller.
    fn to_c_message(msg: &str) -> *mut c_char {
        let bytes: Vec<u8> = msg.bytes().filter(|&b| b != 0).collect();
        CString::new(bytes)
            .expect("interior NUL bytes were filtered out")
            .into_raw()
    }

    if model_dir.is_null() {
        tracing::error!("gigastt_quantize_model: model_dir is null");
        eprintln!("gigastt_quantize_model: model_dir is null");
        return to_c_message("model_dir is null");
    }

    // SAFETY: non-null was checked above; the caller guarantees NUL termination and validity.
    let dir_str = match unsafe { CStr::from_ptr(model_dir) }.to_str() {
        Ok(s) => s,
        Err(e) => {
            tracing::error!("gigastt_quantize_model: model_dir is not valid UTF-8: {e}");
            eprintln!("gigastt_quantize_model: model_dir is not valid UTF-8: {e}");
            return to_c_message(&format!("model_dir is not valid UTF-8: {e}"));
        }
    };

    let model_dir = std::path::Path::new(dir_str);
    let input = model_dir.join("v3_e2e_rnnt_encoder.onnx");
    let output = model_dir.join("v3_e2e_rnnt_encoder_int8.onnx");

    // Reuse an existing INT8 model unless the caller explicitly forces re-quantization.
    if !force && output.exists() {
        return to_c_message("ok");
    }

    if let Err(e) = crate::quantize::quantize_model(&input, &output) {
        tracing::error!("gigastt_quantize_model: quantization failed: {e}");
        eprintln!("gigastt_quantize_model: quantization failed: {e}");
        return to_c_message(&format!("quantization failed: {e}"));
    }

    to_c_message("ok")
}
/// Opens a streaming transcription session on `engine`.
///
/// Blocks until a session can be checked out of the engine's pool. The checked-out
/// session stays with the stream until `gigastt_stream_free`. Returns null if
/// `engine` is null or the checkout fails; failures are logged via `tracing` and on stderr.
///
/// # Safety
///
/// `engine` must be null or a live handle from `gigastt_engine_new*`. The handle must
/// outlive the returned stream.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_stream_new(engine: *mut GigasttEngine) -> *mut GigasttStream {
    // SAFETY: the caller guarantees `engine` is either null or a live, properly aligned handle.
    let maybe_handle = unsafe { engine.as_ref() };
    let Some(handle) = maybe_handle else {
        tracing::error!("gigastt_stream_new: engine is null");
        eprintln!("gigastt_stream_new: engine is null");
        return ptr::null_mut();
    };
    let engine_ref = &handle.engine;

    let checked_out = match engine_ref.pool.checkout_blocking() {
        Ok(guard) => guard,
        Err(e) => {
            tracing::error!("gigastt_stream_new: pool checkout failed: {e}");
            eprintln!("gigastt_stream_new: pool checkout failed: {e}");
            return ptr::null_mut();
        }
    };

    // Detach the session from the guard so the stream can own it across FFI calls.
    let (triplet, reservation) = checked_out.into_owned();
    Box::into_raw(Box::new(GigasttStream {
        state: engine_ref.create_state(false),
        triplet,
        reservation,
    }))
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_stream_process_chunk(
engine: *mut GigasttEngine,
stream: *mut GigasttStream,
pcm16_bytes: *const u8,
len: usize,
sample_rate: u32,
) -> *mut c_char {
if engine.is_null() {
tracing::error!("gigastt_stream_process_chunk: engine is null");
return ptr::null_mut();
}
if stream.is_null() {
tracing::error!("gigastt_stream_process_chunk: stream is null");
return ptr::null_mut();
}
if pcm16_bytes.is_null() {
tracing::error!("gigastt_stream_process_chunk: pcm16_bytes is null");
return ptr::null_mut();
}
let engine_ref = unsafe { &(*engine).engine };
let stream_ref = unsafe { &mut (*stream) };
let bytes = unsafe { std::slice::from_raw_parts(pcm16_bytes, len) };
let pcm16: Vec<i16> = bytes
.chunks_exact(2)
.map(|c| i16::from_le_bytes([c[0], c[1]]))
.collect();
let mut samples_f32: Vec<f32> = pcm16.iter().map(|&s| s as f32 / 32768.0).collect();
if sample_rate != 16000 {
samples_f32 = match audio::resample(&samples_f32, sample_rate, 16000) {
Ok(s) => s,
Err(e) => {
tracing::error!("gigastt_stream_process_chunk: resample failed: {e}");
return ptr::null_mut();
}
};
}
let segments = match engine_ref.process_chunk(
&samples_f32,
&mut stream_ref.state,
&mut stream_ref.triplet,
) {
Ok(segs) => segs,
Err(e) => {
tracing::error!("gigastt_stream_process_chunk: inference failed: {e}");
return ptr::null_mut();
}
};
let json = serde_json::to_string(&segments).unwrap_or_else(|_| "[]".into());
match CString::new(json) {
Ok(cstr) => cstr.into_raw(),
Err(_) => ptr::null_mut(),
}
}
/// Flushes any pending output from a streaming session.
///
/// Whatever `flush_state` yields is collected into a list of segments.
///
/// # Returns
///
/// A heap-allocated, NUL-terminated JSON array (serialized with `serde_json`), or
/// `"[]"` if serialization fails. Release it with `gigastt_string_free`. Returns null
/// if either pointer is null; the reason is logged via `tracing`.
///
/// # Safety
///
/// - `engine` must be null or a live handle from `gigastt_engine_new*`.
/// - `stream` must be null or a live handle from `gigastt_stream_new`. It must not be
///   accessed concurrently from another thread during the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_stream_flush(
    engine: *mut GigasttEngine,
    stream: *mut GigasttStream,
) -> *mut c_char {
    if engine.is_null() {
        tracing::error!("gigastt_stream_flush: engine is null");
        return ptr::null_mut();
    }
    if stream.is_null() {
        tracing::error!("gigastt_stream_flush: stream is null");
        return ptr::null_mut();
    }

    // SAFETY: both pointers were null-checked above; the caller guarantees they are live
    // and that `stream` is not aliased during this call.
    let (engine_ref, stream_ref) = unsafe { (&(*engine).engine, &mut *stream) };

    let mut pending: Vec<crate::inference::TranscriptSegment> = Vec::new();
    pending.extend(engine_ref.flush_state(&mut stream_ref.state));

    let json = match serde_json::to_string(&pending) {
        Ok(text) => text,
        Err(_) => String::from("[]"),
    };
    CString::new(json).map_or(ptr::null_mut(), CString::into_raw)
}
/// Closes a streaming session and returns its session triplet to the engine's pool.
/// Passing null is a no-op.
///
/// # Safety
///
/// `stream` must be null or a live handle from `gigastt_stream_new`. It must not have
/// been freed already, and it must not be used afterwards.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn gigastt_stream_free(stream: *mut GigasttStream) {
    if stream.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `stream` came from `Box::into_raw` and is freed only once.
    let owned = unsafe { Box::from_raw(stream) };
    owned.reservation.checkin(owned.triplet);
}
#[cfg(test)]
mod tests {
    //! Argument-validation tests for the C ABI. None of them needs model files on disk.
    use super::*;

    #[test]
    fn test_engine_new_null_dir() {
        let engine = unsafe { gigastt_engine_new(ptr::null()) };
        assert!(engine.is_null());
    }

    #[test]
    fn test_engine_new_with_pool_size_invalid_utf8() {
        let bad = [0xFFu8, 0xFE, 0];
        let engine =
            unsafe { gigastt_engine_new_with_pool_size(bad.as_ptr() as *const c_char, 1) };
        assert!(engine.is_null());
    }

    #[test]
    fn test_transcribe_file_null_args() {
        let r = unsafe { gigastt_transcribe_file(ptr::null_mut(), ptr::null()) };
        assert!(r.is_null());
    }

    #[test]
    fn test_engine_free_null() {
        unsafe { gigastt_engine_free(ptr::null_mut()) };
    }

    #[test]
    fn test_string_free_null_and_owned() {
        unsafe { gigastt_string_free(ptr::null_mut()) };
        // A string produced by CString::into_raw must round-trip through the free function.
        let owned = CString::new("transcript").unwrap().into_raw();
        unsafe { gigastt_string_free(owned) };
    }

    #[test]
    fn test_stream_new_null_engine() {
        let stream = unsafe { gigastt_stream_new(ptr::null_mut()) };
        assert!(stream.is_null());
    }

    #[test]
    fn test_stream_process_chunk_null_args() {
        let r = unsafe {
            gigastt_stream_process_chunk(ptr::null_mut(), ptr::null_mut(), ptr::null(), 0, 16000)
        };
        assert!(r.is_null());
    }

    #[test]
    fn test_stream_flush_null_args() {
        let r = unsafe { gigastt_stream_flush(ptr::null_mut(), ptr::null_mut()) };
        assert!(r.is_null());
    }

    #[test]
    fn test_stream_free_null() {
        unsafe { gigastt_stream_free(ptr::null_mut()) };
    }

    #[test]
    fn test_quantize_model_null_dir() {
        let r = unsafe { gigastt_quantize_model(ptr::null(), false) };
        assert!(!r.is_null());
        let s = unsafe { CStr::from_ptr(r) }.to_str().unwrap();
        assert!(s.contains("null"));
        unsafe { gigastt_string_free(r) };
    }

    #[test]
    fn test_quantize_model_invalid_utf8() {
        let bad = [0x80u8, 0x81, 0x82, 0];
        let r = unsafe { gigastt_quantize_model(bad.as_ptr() as *const c_char, false) };
        assert!(!r.is_null());
        let s = unsafe { CStr::from_ptr(r) }.to_str().unwrap();
        assert!(s.contains("UTF-8"));
        unsafe { gigastt_string_free(r) };
    }
}