use std::ffi::{CStr, CString, c_char};
use std::ptr;
use crate::inference::audio;
use crate::inference::{Engine, OwnedReservation, SessionTriplet, StreamingState};
/// Handle to a loaded [`Engine`] that is passed across the C ABI.
///
/// Instances are heap-allocated and leaked to C with `Box::into_raw` by `phostt_engine_new`,
/// and reclaimed with `Box::from_raw` by `phostt_engine_free`. The type is not `#[repr(C)]`,
/// so C code must treat the pointer as opaque.
pub struct PhosttEngine {
    /// The wrapped engine. Its `pool` supplies sessions to the transcription and streaming calls.
    engine: Engine,
}
impl PhosttEngine {
pub fn new(engine: Engine) -> Self {
Self { engine }
}
pub fn engine(&self) -> &Engine {
&self.engine
}
}
/// State of one incremental (streaming) transcription session, passed to C as an opaque pointer.
///
/// `phostt_stream_new` creates it. That function checks a session triplet out of the engine's
/// pool and keeps it, together with its reservation, for the whole life of the stream.
/// `phostt_stream_free` releases it and checks the triplet back in through the reservation.
pub struct PhosttStream {
    /// Streaming state returned by `Engine::create_state`, threaded through every chunk.
    state: StreamingState,
    /// Session triplet checked out of the pool; passed to `process_chunk` / `flush_state`.
    triplet: SessionTriplet,
    /// Reservation used to return `triplet` to the pool when the stream is freed.
    reservation: OwnedReservation<SessionTriplet>,
    /// Odd trailing byte from the previous chunk, waiting for the byte that completes its
    /// 16-bit sample.
    pending_byte: Option<u8>,
}
/// Loads an [`Engine`] from the model directory `model_dir` and returns an owning handle.
///
/// Returns null in three cases, each logged to both `tracing` and stderr:
/// - `model_dir` is null;
/// - `model_dir` is not valid UTF-8;
/// - [`Engine::load`] fails.
///
/// # Safety
///
/// `model_dir` must be null or point to a NUL-terminated C string that stays valid for the
/// duration of this call. A non-null result must be released exactly once with
/// `phostt_engine_free`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_engine_new(model_dir: *const c_char) -> *mut PhosttEngine {
    if model_dir.is_null() {
        tracing::error!("phostt_engine_new: model_dir is null");
        eprintln!("phostt_engine_new: model_dir is null");
        return ptr::null_mut();
    }

    // SAFETY: `model_dir` is non-null (checked above) and the caller guarantees it is a
    // NUL-terminated string that outlives this call.
    let raw_dir = unsafe { CStr::from_ptr(model_dir) };

    let loaded = match raw_dir.to_str() {
        Err(e) => {
            tracing::error!("phostt_engine_new: model_dir is not valid UTF-8: {e}");
            eprintln!("phostt_engine_new: model_dir is not valid UTF-8: {e}");
            return ptr::null_mut();
        }
        Ok(dir) => Engine::load(dir),
    };

    loaded.map_or_else(
        |e| {
            tracing::error!("phostt_engine_new: failed to load engine: {e}");
            eprintln!("phostt_engine_new: failed to load engine: {e}");
            ptr::null_mut()
        },
        |engine| Box::into_raw(Box::new(PhosttEngine { engine })),
    )
}
/// Transcribes the audio file at `wav_path` and returns the transcript text as a C string.
///
/// The function checks a session out of the engine's pool with `checkout_blocking` (presumably
/// waiting until one is free) and holds it for the duration of
/// [`Engine::transcribe_file`]. The session guard is dropped when this function returns.
///
/// Returns a NUL-terminated UTF-8 string owned by the caller, to be released with
/// `phostt_string_free`. Returns null, after logging to `tracing` and stderr, when:
/// - either pointer is null;
/// - `wav_path` is not UTF-8;
/// - pool checkout fails;
/// - transcription fails;
/// - the transcript contains an interior NUL byte.
///
/// # Safety
///
/// - `engine` must be null or a live pointer returned by `phostt_engine_new` that has not been
///   passed to `phostt_engine_free`.
/// - `wav_path` must be null or point to a NUL-terminated C string valid for this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_transcribe_file(
    engine: *mut PhosttEngine,
    wav_path: *const c_char,
) -> *mut c_char {
    if engine.is_null() {
        tracing::error!("phostt_transcribe_file: engine is null");
        eprintln!("phostt_transcribe_file: engine is null");
        return ptr::null_mut();
    }
    if wav_path.is_null() {
        tracing::error!("phostt_transcribe_file: wav_path is null");
        eprintln!("phostt_transcribe_file: wav_path is null");
        return ptr::null_mut();
    }

    // SAFETY: `wav_path` is non-null and, per the contract, NUL-terminated and live.
    let raw_path = unsafe { CStr::from_ptr(wav_path) };
    let path_str = match raw_path.to_str() {
        Ok(path) => path,
        Err(e) => {
            tracing::error!("phostt_transcribe_file: wav_path is not valid UTF-8: {e}");
            eprintln!("phostt_transcribe_file: wav_path is not valid UTF-8: {e}");
            return ptr::null_mut();
        }
    };

    // SAFETY: `engine` is non-null and, per the contract, points to a live `PhosttEngine`.
    let inner = unsafe { &(*engine).engine };

    let mut session = match inner.pool.checkout_blocking() {
        Ok(checked_out) => checked_out,
        Err(e) => {
            tracing::error!("phostt_transcribe_file: failed to checkout session from pool: {e}");
            eprintln!("phostt_transcribe_file: failed to checkout session from pool: {e}");
            return ptr::null_mut();
        }
    };

    let transcript = match inner.transcribe_file(path_str, &mut session) {
        Ok(done) => done,
        Err(e) => {
            tracing::error!("phostt_transcribe_file: transcription failed: {e}");
            eprintln!("phostt_transcribe_file: transcription failed: {e}");
            return ptr::null_mut();
        }
    };

    CString::new(transcript.text).map_or_else(
        |e| {
            tracing::error!("phostt_transcribe_file: result contains interior NUL: {e}");
            eprintln!("phostt_transcribe_file: result contains interior NUL: {e}");
            ptr::null_mut()
        },
        CString::into_raw,
    )
}
/// Releases a string previously returned by one of this library's functions.
///
/// Passing null is a no-op.
///
/// # Safety
///
/// `s` must be null or a pointer returned by `phostt_transcribe_file`,
/// `phostt_stream_process_chunk` or `phostt_stream_flush` (all of which produce it via
/// `CString::into_raw`). It must not have been freed already, and must not be used afterwards.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_string_free(s: *mut c_char) {
    if s.is_null() {
        return;
    }
    // SAFETY: `s` is non-null and, per the contract, came from `CString::into_raw` and is
    // being reclaimed exactly once.
    drop(unsafe { CString::from_raw(s) });
}
/// Destroys an engine handle created by `phostt_engine_new`.
///
/// Passing null is a no-op.
///
/// # Safety
///
/// `engine` must be null or a pointer returned by `phostt_engine_new` that has not been freed
/// yet. The pointer must not be used, by any other `phostt_*` call, after this returns.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_engine_free(engine: *mut PhosttEngine) {
    if engine.is_null() {
        return;
    }
    // SAFETY: `engine` is non-null and, per the contract, originates from `Box::into_raw`
    // and is reclaimed exactly once.
    drop(unsafe { Box::from_raw(engine) });
}
/// Opens a streaming transcription session on `engine`.
///
/// The function checks a session triplet out of the engine's pool and converts the guard into
/// an owned `(triplet, reservation)` pair that the stream keeps until `phostt_stream_free`.
/// It then creates fresh streaming state with `create_state(false)`.
///
/// Returns null, after logging to `tracing` and stderr, when:
/// - `engine` is null;
/// - pool checkout fails;
/// - state creation fails. In this case the triplet is checked back in first.
///
/// # Safety
///
/// `engine` must be null or a live pointer returned by `phostt_engine_new`. A non-null result
/// must be released exactly once with `phostt_stream_free`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_stream_new(engine: *mut PhosttEngine) -> *mut PhosttStream {
    if engine.is_null() {
        tracing::error!("phostt_stream_new: engine is null");
        eprintln!("phostt_stream_new: engine is null");
        return ptr::null_mut();
    }

    // SAFETY: `engine` is non-null and, per the contract, points to a live `PhosttEngine`.
    let inner = unsafe { &(*engine).engine };

    let (triplet, reservation) = match inner.pool.checkout_blocking() {
        Ok(checked_out) => checked_out.into_owned(),
        Err(e) => {
            tracing::error!("phostt_stream_new: pool checkout failed: {e}");
            eprintln!("phostt_stream_new: pool checkout failed: {e}");
            return ptr::null_mut();
        }
    };

    match inner.create_state(false) {
        Ok(state) => Box::into_raw(Box::new(PhosttStream {
            state,
            triplet,
            reservation,
            pending_byte: None,
        })),
        Err(e) => {
            tracing::error!("phostt_stream_new: failed to create streaming state: {e}");
            eprintln!("phostt_stream_new: failed to create streaming state: {e}");
            // Hand the triplet back so a failed open does not starve the pool.
            reservation.checkin(triplet);
            ptr::null_mut()
        }
    }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_stream_process_chunk(
engine: *mut PhosttEngine,
stream: *mut PhosttStream,
pcm16_bytes: *const u8,
len: usize,
sample_rate: u32,
) -> *mut c_char {
if engine.is_null() {
tracing::error!("phostt_stream_process_chunk: engine is null");
return ptr::null_mut();
}
if stream.is_null() {
tracing::error!("phostt_stream_process_chunk: stream is null");
return ptr::null_mut();
}
if pcm16_bytes.is_null() {
tracing::error!("phostt_stream_process_chunk: pcm16_bytes is null");
return ptr::null_mut();
}
let engine_ref = unsafe { &(*engine).engine };
let stream_ref = unsafe { &mut (*stream) };
if sample_rate == 0 {
tracing::error!("phostt_stream_process_chunk: sample_rate is zero");
return ptr::null_mut();
}
let bytes = unsafe { std::slice::from_raw_parts(pcm16_bytes, len) };
let carry_prev = stream_ref.pending_byte.take();
let mut combined = Vec::with_capacity(bytes.len() + 1);
if let Some(prev) = carry_prev {
combined.push(prev);
}
combined.extend_from_slice(bytes);
if !combined.len().is_multiple_of(2) {
stream_ref.pending_byte = combined.pop();
}
let pcm16: Vec<i16> = combined
.chunks_exact(2)
.map(|c| i16::from_le_bytes([c[0], c[1]]))
.collect();
let mut samples_f32: Vec<f32> = pcm16.iter().map(|&s| s as f32 / 32768.0).collect();
if sample_rate != crate::inference::TARGET_SAMPLE_RATE {
samples_f32 = match audio::resample(
&samples_f32,
sample_rate,
crate::inference::TARGET_SAMPLE_RATE,
) {
Ok(s) => s,
Err(e) => {
tracing::error!("phostt_stream_process_chunk: resample failed: {e}");
return ptr::null_mut();
}
};
}
let segments = match engine_ref.process_chunk(
&samples_f32,
&mut stream_ref.state,
&mut stream_ref.triplet,
) {
Ok(segs) => segs,
Err(e) => {
tracing::error!("phostt_stream_process_chunk: inference failed: {e}");
return ptr::null_mut();
}
};
let json = serde_json::to_string(&segments).unwrap_or_else(|_| "[]".into());
match CString::new(json) {
Ok(cstr) => cstr.into_raw(),
Err(_) => ptr::null_mut(),
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_stream_flush(
engine: *mut PhosttEngine,
stream: *mut PhosttStream,
) -> *mut c_char {
if engine.is_null() {
tracing::error!("phostt_stream_flush: engine is null");
return ptr::null_mut();
}
if stream.is_null() {
tracing::error!("phostt_stream_flush: stream is null");
return ptr::null_mut();
}
let engine_ref = unsafe { &(*engine).engine };
let stream_ref = unsafe { &mut (*stream) };
let segments: Vec<crate::inference::TranscriptSegment> = engine_ref
.flush_state(&mut stream_ref.state, &mut stream_ref.triplet)
.into_iter()
.collect();
let json = serde_json::to_string(&segments).unwrap_or_else(|_| "[]".into());
match CString::new(json) {
Ok(cstr) => cstr.into_raw(),
Err(_) => ptr::null_mut(),
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn phostt_stream_free(stream: *mut PhosttStream) {
if !stream.is_null() {
let stream = unsafe { Box::from_raw(stream) };
stream.reservation.checkin(stream.triplet);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaks `engine` into a raw FFI handle, the same way `phostt_engine_new` hands one to C.
    fn into_handle(engine: Engine) -> *mut PhosttEngine {
        Box::into_raw(Box::new(PhosttEngine::new(engine)))
    }

    #[test]
    fn test_stream_new_null_engine() {
        assert!(unsafe { phostt_stream_new(ptr::null_mut()) }.is_null());
    }

    #[test]
    fn test_stream_new_pool_closed() {
        let stub = Engine::test_stub();
        // With the pool closed, opening a stream is expected to fail.
        stub.pool.close();
        let handle = into_handle(stub);
        let created = unsafe { phostt_stream_new(handle) };
        assert!(created.is_null());
        unsafe { phostt_engine_free(handle) };
    }

    #[test]
    fn test_stream_process_chunk_null_args() {
        let handle = into_handle(Engine::test_stub());
        // The last case pairs a dangling stream with a null PCM buffer. The null-buffer check
        // must reject it before the (invalid) stream pointer is ever dereferenced.
        let cases = [
            (ptr::null_mut(), ptr::null_mut()),
            (handle, ptr::null_mut()),
            (handle, std::ptr::dangling_mut::<PhosttStream>()),
        ];
        for (engine_arg, stream_arg) in cases {
            let out = unsafe {
                phostt_stream_process_chunk(
                    engine_arg,
                    stream_arg,
                    ptr::null(),
                    0,
                    crate::inference::TARGET_SAMPLE_RATE,
                )
            };
            assert!(out.is_null());
        }
        unsafe { phostt_engine_free(handle) };
    }

    #[test]
    fn test_stream_flush_null_args() {
        let handle = into_handle(Engine::test_stub());
        for engine_arg in [ptr::null_mut(), handle] {
            let out = unsafe { phostt_stream_flush(engine_arg, ptr::null_mut()) };
            assert!(out.is_null());
        }
        unsafe { phostt_engine_free(handle) };
    }

    #[test]
    fn test_stream_free_null() {
        // Freeing a null stream must be a harmless no-op.
        unsafe { phostt_stream_free(ptr::null_mut()) };
    }
}