#![allow(unsafe_code)]
use core::{
ptr::NonNull,
sync::atomic::{AtomicBool, Ordering},
};
use std::{
ffi::CString,
path::Path,
sync::{Arc, Mutex, MutexGuard},
};
use crate::{
error::{WhisperError, WhisperResult},
state::State,
sys,
};
/// Acquires the process-wide lock that serialises whisper.cpp constructor
/// calls (`Context::new`, `create_state` hold it across their FFI calls).
///
/// Mutex poisoning is deliberately ignored: a panicking holder leaves no
/// Rust-side invariant broken here, so we recover the guard and continue.
pub(crate) fn init_lock() -> MutexGuard<'static, ()> {
    static LOCK: Mutex<()> = Mutex::new(());
    match LOCK.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Model-loading options forwarded to `whisper_context_default_params`
/// overrides in [`Context::new`].
#[derive(Debug, Clone, Copy)]
pub struct ContextParams {
    // Forwarded to the native `use_gpu` flag; defaults to `true` (see `new`).
    use_gpu: bool,
    // Forwarded to the native `gpu_device` index; defaults to `0`.
    gpu_device: i32,
    // Forwarded to the native `flash_attn` flag; defaults to `false`.
    flash_attn: bool,
}
impl ContextParams {
    /// Default configuration: GPU on, device index 0, flash attention off.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn new() -> Self {
        Self {
            flash_attn: false,
            gpu_device: 0,
            use_gpu: true,
        }
    }

    /// Whether GPU offloading is requested.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn use_gpu(&self) -> bool {
        self.use_gpu
    }

    /// Returns a copy with GPU offloading switched on or off.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_use_gpu(self, enabled: bool) -> Self {
        Self {
            use_gpu: enabled,
            ..self
        }
    }

    /// Zero-based index of the GPU device to use.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn gpu_device(&self) -> i32 {
        self.gpu_device
    }

    /// Returns a copy targeting the given GPU device index.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_gpu_device(self, device: i32) -> Self {
        Self {
            gpu_device: device,
            ..self
        }
    }

    /// Whether flash attention is requested.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn flash_attn(&self) -> bool {
        self.flash_attn
    }

    /// Returns a copy with flash attention switched on or off.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_flash_attn(self, enabled: bool) -> Self {
        Self {
            flash_attn: enabled,
            ..self
        }
    }
}
impl Default for ContextParams {
    /// Equivalent to [`ContextParams::new`]: GPU on device 0, no flash attention.
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn default() -> Self {
        Self::new()
    }
}
/// Owning wrapper around a native `whisper_context`.
pub struct Context {
    // Non-null handle returned by the native loader; freed exactly once in `Drop`.
    ptr: NonNull<sys::whisper_context>,
    // One-way poison marker: set by `mark_lost` / a native constructor
    // exception, checked by `create_state` and `is_poisoned`. Never cleared.
    lost: AtomicBool,
    // Per-context mutex handed out by `full_lock()`; presumably serialises
    // full-inference calls on this context — confirm at the call sites.
    full_lock: Mutex<()>,
}
// SAFETY: all mutation visible from safe Rust goes through the atomic `lost`
// flag or the `full_lock`/`init_lock` mutexes; the raw pointer itself is only
// passed to whisper.cpp entry points. NOTE(review): this also relies on
// whisper.cpp tolerating use of one context from multiple threads (not
// necessarily concurrently) — confirm against upstream guarantees.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
impl Context {
    /// Loads a whisper model from `path` and wraps the resulting native context.
    ///
    /// The native constructor runs under `init_lock()`; presumably the loader
    /// is not safe to call concurrently — confirm against whisper.cpp.
    ///
    /// # Errors
    /// * `WhisperError::InvalidCString` — `path` contains an interior NUL byte.
    /// * `WhisperError::ConstructorLost` — the shim recorded a native
    ///   constructor exception (non-zero code).
    /// * `WhisperError::ContextLoad` — the loader returned NULL with no
    ///   recorded exception.
    pub fn new(path: impl AsRef<Path>, params: ContextParams) -> WhisperResult<Self> {
        let path_ref = path.as_ref();
        // Lossy UTF-8 conversion: invalid sequences in the OS path are
        // replaced before the string crosses the FFI boundary.
        let path_str = path_ref.to_string_lossy();
        let cpath = CString::new(path_str.as_ref())
            .map_err(|_| WhisperError::InvalidCString(smol_str::SmolStr::new(path_str.as_ref())))?;
        // Start from upstream defaults, then overlay the caller's choices.
        let mut cparams = unsafe { sys::whisper_context_default_params() };
        cparams.use_gpu = params.use_gpu();
        cparams.gpu_device = params.gpu_device();
        cparams.flash_attn = params.flash_attn();
        // Hold the process-wide init lock for the duration of the native call.
        let _lock = init_lock();
        let raw = unsafe { sys::whispercpp_init_from_file_no_state(cpath.as_ptr(), cparams) };
        if let Some(ptr) = NonNull::new(raw) {
            return Ok(Self {
                ptr,
                lost: AtomicBool::new(false),
                full_lock: Mutex::new(()),
            });
        }
        // NULL result: distinguish "shim caught a native exception" from a
        // plain load failure. NOTE(review): a stale exception code left over
        // from an earlier failure could be misattributed here — confirm the
        // shim clears it on successful construction.
        let exc = unsafe { sys::whispercpp_take_last_constructor_exception() };
        if exc != 0 {
            return Err(WhisperError::ConstructorLost {
                origin: "context",
                code: exc,
            });
        }
        Err(WhisperError::ContextLoad {
            path: smol_str::SmolStr::new(path_str.as_ref()),
            reason: smol_str::SmolStr::new(
                "whispercpp_init_from_file_no_state returned NULL (upstream load failure, no native exception caught)",
            ),
        })
    }

    /// Creates a new decoding [`State`] bound to this context.
    ///
    /// # Errors
    /// * `WhisperError::ContextPoisoned` — the context was marked lost before
    ///   or during the native call.
    /// * `WhisperError::ConstructorLost` — the shim recorded a native
    ///   exception; this also permanently poisons the context.
    /// * `WhisperError::StateInit` — the native call returned NULL with no
    ///   recorded exception.
    pub fn create_state(self: &Arc<Self>) -> WhisperResult<State> {
        // Fast path: refuse to build on a context already known to be lost.
        if self.lost.load(Ordering::Acquire) {
            return Err(WhisperError::ContextPoisoned);
        }
        let _lock = init_lock();
        let raw = unsafe { sys::whispercpp_init_state(self.ptr.as_ptr()) };
        // Double-check: another thread may have poisoned the context while the
        // native call was in flight. If so, free the state we just created and
        // drain any exception the shim recorded, then report the poisoning.
        if self.lost.load(Ordering::Acquire) {
            if let Some(state_ptr) = NonNull::new(raw) {
                unsafe { sys::whisper_free_state(state_ptr.as_ptr()) };
            }
            let _ = unsafe { sys::whispercpp_take_last_constructor_exception() };
            return Err(WhisperError::ContextPoisoned);
        }
        if let Some(state_ptr) = NonNull::new(raw) {
            // The state keeps the context alive via the cloned Arc.
            return Ok(State::from_raw(state_ptr, Arc::clone(self)));
        }
        // NULL state: a recorded native exception permanently poisons the
        // context; otherwise report a plain init failure.
        let exc = unsafe { sys::whispercpp_take_last_constructor_exception() };
        if exc != 0 {
            self.lost.store(true, Ordering::Release);
            return Err(WhisperError::ConstructorLost {
                origin: "state",
                code: exc,
            });
        }
        Err(WhisperError::StateInit)
    }

    /// Raw native handle for FFI calls; valid for the lifetime of `self`.
    pub(crate) fn as_raw(&self) -> *mut sys::whisper_context {
        self.ptr.as_ptr()
    }

    /// Marks this context as lost. One-way: the flag is never cleared.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn mark_lost(&self) {
        self.lost.store(true, Ordering::Release);
    }

    /// True once the context has been marked lost (see `mark_lost` /
    /// `create_state`).
    pub fn is_poisoned(&self) -> bool {
        self.lost.load(Ordering::Acquire)
    }

    /// Acquires this context's dedicated mutex, recovering from poisoning the
    /// same way `init_lock` does.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn full_lock(&self) -> MutexGuard<'_, ()> {
        self
            .full_lock
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }

    /// Whether the loaded model is multilingual (native call).
    pub fn is_multilingual(&self) -> bool {
        unsafe { sys::whisper_is_multilingual(self.ptr.as_ptr()) != 0 }
    }

    /// Vocabulary size of the loaded model.
    pub fn n_vocab(&self) -> i32 {
        unsafe { sys::whisper_n_vocab(self.ptr.as_ptr()) }
    }

    /// Audio-context size of the loaded model.
    pub fn n_audio_ctx(&self) -> i32 {
        unsafe { sys::whisper_n_audio_ctx(self.ptr.as_ptr()) }
    }

    /// Text-context size of the loaded model.
    pub fn n_text_ctx(&self) -> i32 {
        unsafe { sys::whisper_n_text_ctx(self.ptr.as_ptr()) }
    }

    /// Human-readable model type, or `None` when the native call returns NULL
    /// or bytes that are not valid UTF-8.
    ///
    /// NOTE(review): the `'static` lifetime assumes whisper.cpp returns a
    /// pointer with static storage duration; if the string is allocated per
    /// call this borrow would dangle or leak — confirm upstream.
    pub fn model_type(&self) -> Option<&'static str> {
        let raw = unsafe { sys::whisper_model_type_readable(self.ptr.as_ptr()) };
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is non-null; assumed to point at a NUL-terminated
        // C string owned by whisper.cpp.
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        core::str::from_utf8(bytes).ok()
    }

    /// End-of-transcript token id.
    pub fn token_eot(&self) -> i32 {
        unsafe { sys::whisper_token_eot(self.ptr.as_ptr()) }
    }

    /// Start-of-transcript token id.
    pub fn token_sot(&self) -> i32 {
        unsafe { sys::whisper_token_sot(self.ptr.as_ptr()) }
    }

    /// "Begin" token id.
    pub fn token_beg(&self) -> i32 {
        unsafe { sys::whisper_token_beg(self.ptr.as_ptr()) }
    }

    /// Text for `token`, or `None` when the token is outside `[0, n_vocab)`,
    /// the native lookup returns NULL, or the bytes are not valid UTF-8.
    pub fn token_to_str(&self, token: i32) -> Option<&str> {
        // Bounds-check before handing the token to the native side.
        let n = self.n_vocab();
        if token < 0 || token >= n {
            return None;
        }
        let raw = unsafe { sys::whisper_token_to_str(self.ptr.as_ptr(), token) };
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is non-null; assumed NUL-terminated and valid at
        // least as long as `self` holds the context alive.
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        core::str::from_utf8(bytes).ok()
    }
}
/// Returns the whisper.cpp system-info string, computed once and cached for
/// the life of the process.
///
/// Returns `None` when the native call yields a NULL pointer or bytes that
/// are not valid UTF-8; the `None` is cached too, so the FFI call runs at
/// most once.
pub fn system_info() -> Option<smol_str::SmolStr> {
    use std::sync::OnceLock;
    static CACHE: OnceLock<Option<smol_str::SmolStr>> = OnceLock::new();
    // `get_or_init` already serialises concurrent initialisers and runs the
    // closure at most once, so the previous manual Mutex + double-checked
    // `get`/`set` dance was redundant.
    CACHE
        .get_or_init(|| {
            let raw = unsafe { sys::whispercpp_print_system_info() };
            if raw.is_null() {
                return None;
            }
            // SAFETY: `raw` is non-null; assumed to point at a NUL-terminated
            // C string owned by whisper.cpp.
            let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
            core::str::from_utf8(bytes).ok().map(smol_str::SmolStr::new)
        })
        .clone()
}
impl Drop for Context {
    /// Releases the native context.
    fn drop(&mut self) {
        // SAFETY: `ptr` was returned non-null by the native loader in
        // `Context::new` and is freed exactly once, here.
        unsafe {
            sys::whisper_free(self.ptr.as_ptr());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Context` around a dangling pointer so the marker flag and
    /// lock can be exercised without loading a model. The pointer is never
    /// dereferenced; every test must `mem::forget` the value (directly or via
    /// its unwrapped `Arc`) so `Drop` never passes it to the native free.
    fn dummy_context() -> Context {
        Context {
            ptr: NonNull::dangling(),
            lost: AtomicBool::new(false),
            full_lock: Mutex::new(()),
        }
    }

    #[test]
    fn fresh_context_marker_starts_unpoisoned() {
        let ctx = dummy_context();
        assert!(!ctx.is_poisoned());
        ctx.mark_lost();
        assert!(ctx.is_poisoned());
        core::mem::forget(ctx);
    }

    #[test]
    fn mark_lost_is_idempotent_and_monotonic() {
        let ctx = dummy_context();
        for _ in 0..3 {
            ctx.mark_lost();
        }
        assert!(ctx.is_poisoned(), "stays true across repeated marks");
        core::mem::forget(ctx);
    }

    #[test]
    fn mark_lost_visible_across_threads() {
        let shared = Arc::new(dummy_context());
        let worker = {
            let remote = Arc::clone(&shared);
            std::thread::spawn(move || remote.mark_lost())
        };
        worker.join().unwrap();
        assert!(
            shared.is_poisoned(),
            "the post-join Acquire load must see the spawn-side Release store"
        );
        core::mem::forget(Arc::try_unwrap(shared).ok().unwrap());
    }

    #[test]
    fn full_lock_serialises_concurrent_holders() {
        let shared = Arc::new(dummy_context());
        // Counts how many threads are inside the guarded section at once;
        // the lock must keep this at exactly one.
        let in_section = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let ctx = Arc::clone(&shared);
                let gauge = Arc::clone(&in_section);
                std::thread::spawn(move || {
                    let _guard = ctx.full_lock();
                    let pre = gauge.fetch_add(1, Ordering::SeqCst);
                    assert_eq!(pre, 0, "another holder slipped past the mutex");
                    std::thread::sleep(std::time::Duration::from_millis(2));
                    let post = gauge.fetch_sub(1, Ordering::SeqCst);
                    assert_eq!(post, 1, "another holder is concurrent with us");
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        core::mem::forget(Arc::try_unwrap(shared).ok().unwrap());
    }
}