#![allow(unsafe_code)]
use core::{
ptr::NonNull,
sync::atomic::{AtomicBool, Ordering},
};
use std::{
ffi::CString,
path::Path,
sync::{Arc, Mutex, MutexGuard},
};
use smol_str::SmolStr;
use crate::{
error::{WhisperError, WhisperResult},
state::State,
sys,
};
/// Acquires the process-wide lock that serializes native whisper
/// init/teardown FFI calls.
///
/// A poisoned mutex is deliberately recovered rather than propagated: the
/// guarded section protects native-call ordering only and holds no Rust
/// state that a panicking holder could leave inconsistent.
pub(crate) fn init_lock() -> MutexGuard<'static, ()> {
    static LOCK: Mutex<()> = Mutex::new(());
    match LOCK.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Default scratch-memory budget (128 MiB) passed to whisper.cpp for DTW
/// token-timestamp computation.
pub const DEFAULT_DTW_MEM_SIZE: usize = 128 * 1024 * 1024;
/// Lower clamp bound for user-supplied DTW budgets; pinned to the default so
/// callers can never undersize the native scratch arena.
pub const MIN_DTW_MEM_SIZE: usize = DEFAULT_DTW_MEM_SIZE;
/// Upper clamp bound: 4 GiB on 64-bit targets…
#[cfg(target_pointer_width = "64")]
pub const MAX_DTW_MEM_SIZE: usize = 4 * 1024 * 1024 * 1024;
/// …and 1 GiB elsewhere, so the value always fits in `usize`.
#[cfg(not(target_pointer_width = "64"))]
pub const MAX_DTW_MEM_SIZE: usize = 1024 * 1024 * 1024;
/// Clamps a requested DTW scratch size into `[MIN_DTW_MEM_SIZE,
/// MAX_DTW_MEM_SIZE]`; relies on `MIN <= MAX` (pinned by the test suite).
#[cfg_attr(not(tarpaulin), inline(always))]
const fn clamp_dtw_mem_size(n: usize) -> usize {
    match n {
        too_small if too_small < MIN_DTW_MEM_SIZE => MIN_DTW_MEM_SIZE,
        too_large if too_large > MAX_DTW_MEM_SIZE => MAX_DTW_MEM_SIZE,
        in_range => in_range,
    }
}
#[cfg_attr(not(tarpaulin), inline(always))]
const fn alignment_head_count(preset: AlignmentHeadsPreset) -> usize {
match preset {
AlignmentHeadsPreset::None => 0,
AlignmentHeadsPreset::TinyEn => 8,
AlignmentHeadsPreset::Tiny => 6,
AlignmentHeadsPreset::BaseEn => 5,
AlignmentHeadsPreset::Base => 8,
AlignmentHeadsPreset::SmallEn => 19,
AlignmentHeadsPreset::Small => 10,
AlignmentHeadsPreset::MediumEn => 18,
AlignmentHeadsPreset::Medium => 6,
AlignmentHeadsPreset::LargeV1 => 9,
AlignmentHeadsPreset::LargeV2 => 23,
AlignmentHeadsPreset::LargeV3 => 10,
AlignmentHeadsPreset::LargeV3Turbo => 6,
}
}
/// Text-context length (448) of standard whisper checkpoints; the DTW scratch
/// budget in [`required_dtw_mem_size_for`] is derived from this value, and
/// `Context::new` rejects DTW for models exceeding it.
pub const SUPPORTED_DTW_N_TEXT_CTX: i32 = 448;
/// Minimum DTW scratch budget (bytes) for `preset`, clamped into
/// `[MIN_DTW_MEM_SIZE, MAX_DTW_MEM_SIZE]`; returns 0 when the preset selects
/// no alignment heads (DTW disabled).
///
/// Derivation: 4 bytes × 1500 × n_text_ctx per head (presumably an f32 cost
/// over the audio/text context grid — matches the upstream DTW scratch shape;
/// confirm against whisper.cpp), times 3 tensors, times a 1.5 safety factor.
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn required_dtw_mem_size_for(preset: AlignmentHeadsPreset) -> usize {
    let heads = alignment_head_count(preset);
    if heads == 0 {
        return 0;
    }
    let bytes_per_tensor = (SUPPORTED_DTW_N_TEXT_CTX as usize) * 1500 * heads * 4;
    // Left-associative: ((per_tensor * 3) * 3) / 2 — three tensors, ×1.5.
    let budget = bytes_per_tensor * 3 * 3 / 2;
    if budget > MAX_DTW_MEM_SIZE {
        MAX_DTW_MEM_SIZE
    } else if budget < MIN_DTW_MEM_SIZE {
        MIN_DTW_MEM_SIZE
    } else {
        budget
    }
}
/// Alignment-heads preset choosing which cross-attention heads whisper.cpp
/// uses for DTW token-level timestamps. Variants mirror the upstream
/// `WHISPER_AHEADS_*` presets one-to-one (see [`AlignmentHeadsPreset::to_raw`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlignmentHeadsPreset {
    /// No preset selected; DTW head selection disabled.
    None,
    TinyEn,
    Tiny,
    BaseEn,
    Base,
    SmallEn,
    Small,
    MediumEn,
    Medium,
    LargeV1,
    LargeV2,
    LargeV3,
    LargeV3Turbo,
}
impl AlignmentHeadsPreset {
    /// Converts the Rust preset to the raw `sys` enum value.
    /// Exhaustive match: adding a variant forces this mapping to be updated.
    #[cfg_attr(not(tarpaulin), inline(always))]
    const fn to_raw(self) -> sys::whisper_alignment_heads_preset {
        match self {
            Self::None => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE,
            Self::TinyEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN,
            Self::Tiny => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY,
            Self::BaseEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN,
            Self::Base => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE,
            Self::SmallEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN,
            Self::Small => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL,
            Self::MediumEn => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN,
            Self::Medium => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM,
            Self::LargeV1 => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1,
            Self::LargeV2 => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2,
            Self::LargeV3 => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3,
            Self::LargeV3Turbo => sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3_TURBO,
        }
    }
}
/// Builder-style configuration consumed by [`Context::new`]; mirrors the
/// fields of whisper.cpp's `whisper_context_params` that this wrapper exposes.
#[derive(Debug, Clone, Copy)]
pub struct ContextParams {
    // Offload to GPU if the build supports it.
    use_gpu: bool,
    // GPU device index handed straight to whisper.cpp.
    gpu_device: i32,
    // Flash-attention toggle; incompatible with DTW (enforced in Context::new).
    flash_attn: bool,
    // Request DTW token-level timestamps.
    dtw_token_timestamps: bool,
    // Which alignment heads DTW should use; None effectively disables DTW.
    dtw_aheads_preset: AlignmentHeadsPreset,
    // DTW scratch budget in bytes; always kept within [MIN, MAX] by the setter.
    dtw_mem_size: usize,
}
impl ContextParams {
    /// Conservative defaults: GPU enabled on device 0, flash-attention off,
    /// DTW token timestamps off, 128 MiB DTW scratch budget.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn new() -> Self {
        Self {
            dtw_mem_size: DEFAULT_DTW_MEM_SIZE,
            dtw_aheads_preset: AlignmentHeadsPreset::None,
            dtw_token_timestamps: false,
            flash_attn: false,
            gpu_device: 0,
            use_gpu: true,
        }
    }
    /// Whether GPU offload is requested.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn use_gpu(&self) -> bool {
        self.use_gpu
    }
    /// Builder-style toggle for GPU offload.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_use_gpu(mut self, enabled: bool) -> Self {
        self.use_gpu = enabled;
        self
    }
    /// GPU device index passed to whisper.cpp.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn gpu_device(&self) -> i32 {
        self.gpu_device
    }
    /// Builder-style setter for the GPU device index (unvalidated here;
    /// whisper.cpp interprets it at load time).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_gpu_device(mut self, device_index: i32) -> Self {
        self.gpu_device = device_index;
        self
    }
    /// Whether flash-attention is requested.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn flash_attn(&self) -> bool {
        self.flash_attn
    }
    /// Builder-style toggle for flash-attention. Note: combining this with
    /// DTW timestamps is rejected by [`Context::new`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_flash_attn(mut self, enabled: bool) -> Self {
        self.flash_attn = enabled;
        self
    }
    /// Whether DTW token-level timestamps are requested.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn dtw_token_timestamps(&self) -> bool {
        self.dtw_token_timestamps
    }
    /// Builder-style toggle for DTW token timestamps; only takes effect when
    /// a non-`None` alignment-heads preset is also configured.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_dtw_token_timestamps(mut self, enabled: bool) -> Self {
        self.dtw_token_timestamps = enabled;
        self
    }
    /// The configured DTW alignment-heads preset.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn dtw_aheads_preset(&self) -> AlignmentHeadsPreset {
        self.dtw_aheads_preset
    }
    /// Builder-style setter for the DTW alignment-heads preset.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_dtw_aheads_preset(mut self, heads: AlignmentHeadsPreset) -> Self {
        self.dtw_aheads_preset = heads;
        self
    }
    /// The DTW scratch budget in bytes (always within `[MIN, MAX]`).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn dtw_mem_size(&self) -> usize {
        self.dtw_mem_size
    }
    /// Builder-style setter for the DTW scratch budget; the value is clamped
    /// into `[MIN_DTW_MEM_SIZE, MAX_DTW_MEM_SIZE]` at set time so the stored
    /// field is always in range.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_dtw_mem_size(mut self, bytes: usize) -> Self {
        self.dtw_mem_size = clamp_dtw_mem_size(bytes);
        self
    }
}
impl Default for ContextParams {
#[cfg_attr(not(tarpaulin), inline(always))]
fn default() -> Self {
Self::new()
}
}
/// Owned wrapper around a native `whisper_context`.
pub struct Context {
    // Non-null pointer to the native context; freed exactly once in Drop.
    ptr: NonNull<sys::whisper_context>,
    // Poison flag set by mark_lost(): once true, the context must no longer
    // be used to create states (checked with Acquire loads).
    lost: AtomicBool,
    // Serializes whole-context native operations (handed out by full_lock()).
    full_lock: Mutex<()>,
}
// SAFETY: NOTE(review) — the raw context pointer is touched only through FFI
// calls on `&self`, and whole-context operations are serialized via
// `full_lock`; this assumes whisper.cpp's read-only context queries are safe
// to call from multiple threads. Confirm against upstream thread-safety
// guarantees before relying on heavy cross-thread use.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
impl Context {
    /// Loads a whisper model from `path` with the given `params` and wraps
    /// the resulting native context.
    ///
    /// # Errors
    /// - [`WhisperError::InvalidCString`] if the path contains an interior NUL.
    /// - [`WhisperError::ContextLoad`] when DTW and flash_attn are both
    ///   requested, when DTW is on but the model's text context exceeds
    ///   [`SUPPORTED_DTW_N_TEXT_CTX`], or when the native loader fails.
    pub fn new(path: impl AsRef<Path>, params: ContextParams) -> WhisperResult<Self> {
        let path_ref = path.as_ref();
        // Lossy UTF-8 conversion of the path; the only CString failure mode
        // left is an interior NUL byte.
        let path_str = path_ref.to_string_lossy();
        let cpath = CString::new(path_str.as_ref())
            .map_err(|_| WhisperError::InvalidCString(smol_str::SmolStr::new(path_str.as_ref())))?;
        let mut cparams = unsafe { sys::whisper_context_default_params() };
        cparams.use_gpu = params.use_gpu();
        cparams.gpu_device = params.gpu_device();
        cparams.flash_attn = params.flash_attn();
        // DTW counts as "on" only when both the flag and a non-None preset
        // are configured.
        let dtw_on =
            params.dtw_token_timestamps() && params.dtw_aheads_preset() != AlignmentHeadsPreset::None;
        // Reject the DTW + flash_attn combination up front rather than let
        // the native side silently drop DTW.
        if dtw_on && params.flash_attn() {
            return Err(WhisperError::ContextLoad {
                path: smol_str::SmolStr::new(path_str.as_ref()),
                reason: SmolStr::new_static(
                    "DTW token timestamps cannot be combined with flash_attn — \
                     whisper.cpp silently disables DTW under flash_attn. \
                     Set with_flash_attn(false) or with_dtw_token_timestamps(false).",
                ),
                code: None,
            });
        }
        cparams.dtw_token_timestamps = dtw_on;
        cparams.dtw_aheads_preset = if dtw_on {
            params.dtw_aheads_preset().to_raw()
        } else {
            // Force the preset off when DTW is off so the native side never
            // sees a head preset without the timestamps flag.
            AlignmentHeadsPreset::None.to_raw()
        };
        // Clamp the user budget, then raise it to the per-preset minimum so
        // the DTW scratch arena can never be undersized.
        let clamped_user = clamp_dtw_mem_size(params.dtw_mem_size());
        cparams.dtw_mem_size = if dtw_on {
            let required = required_dtw_mem_size_for(params.dtw_aheads_preset());
            if clamped_user >= required {
                clamped_user
            } else {
                required
            }
        } else {
            clamped_user
        };
        // NOTE(review): native construction is serialized through the
        // process-wide init lock — presumably because the loader is not safe
        // to run concurrently; confirm upstream.
        let _lock = init_lock();
        let raw = unsafe { sys::whispercpp_init_from_file_no_state(cpath.as_ptr(), cparams) };
        if let Some(ptr) = NonNull::new(raw) {
            if dtw_on {
                // The scratch sizing above assumed n_text_ctx == 448; reject
                // larger models instead of risking a native failure later.
                let n_text_ctx = unsafe { sys::whisper_n_text_ctx(ptr.as_ptr()) };
                if n_text_ctx > SUPPORTED_DTW_N_TEXT_CTX {
                    // Free the context we just created before erroring out.
                    unsafe { sys::whisper_free(ptr.as_ptr()) };
                    return Err(WhisperError::ContextLoad {
                        path: smol_str::SmolStr::new(path_str.as_ref()),
                        reason: SmolStr::new_static(
                            "DTW enabled with a model whose n_text_ctx exceeds SUPPORTED_DTW_N_TEXT_CTX (448) — \
                             disable DTW (with_dtw_token_timestamps(false)) or use a standard checkpoint",
                        ),
                        code: None,
                    });
                }
            }
            return Ok(Self {
                ptr,
                lost: AtomicBool::new(false),
                full_lock: Mutex::new(()),
            });
        }
        // NULL return: distinguish a caught C++ exception (non-zero code in
        // the shim's slot) from a plain load failure.
        let exc = unsafe { sys::whispercpp_take_last_constructor_exception() };
        let (reason, code) = if exc != 0 {
            (
                SmolStr::new_static(
                    "whispercpp_init_from_file_no_state caught C++ exception; \
                     native cleanup completed via init_context RAII exit + model_load RAII guards",
                ),
                Some(exc),
            )
        } else {
            (
                SmolStr::new_static(
                    "whispercpp_init_from_file_no_state returned NULL (upstream load failure, no native exception caught)",
                ),
                None,
            )
        };
        Err(WhisperError::ContextLoad {
            path: smol_str::SmolStr::new(path_str.as_ref()),
            reason,
            code,
        })
    }
    /// Creates a fresh decode [`State`] bound to this context (hence the
    /// `Arc` receiver: the state keeps the context alive).
    ///
    /// # Errors
    /// [`WhisperError::ContextPoisoned`] if the context was marked lost
    /// before or during creation; [`WhisperError::StateInit`] if the native
    /// allocation fails.
    pub fn create_state(self: &Arc<Self>) -> WhisperResult<State> {
        if self.lost.load(Ordering::Acquire) {
            return Err(WhisperError::ContextPoisoned);
        }
        let _lock = init_lock();
        let raw = unsafe { sys::whispercpp_init_state(self.ptr.as_ptr()) };
        // Re-check after the FFI call: the context may have been poisoned
        // concurrently. If so, free the just-created state instead of leaking
        // it, and drain the shim's exception slot either way.
        if self.lost.load(Ordering::Acquire) {
            if let Some(state_ptr) = NonNull::new(raw) {
                unsafe { sys::whisper_free_state(state_ptr.as_ptr()) };
                let _ = unsafe { sys::whispercpp_take_last_constructor_exception() };
                return Err(WhisperError::ContextPoisoned);
            }
            let _ = unsafe { sys::whispercpp_take_last_constructor_exception() };
            return Err(WhisperError::ContextPoisoned);
        }
        if let Some(state_ptr) = NonNull::new(raw) {
            // Drain any stale exception code so it cannot leak into later calls.
            let _ = unsafe { sys::whispercpp_take_last_constructor_exception() };
            return Ok(State::from_raw(state_ptr, Arc::clone(self)));
        }
        let exc = unsafe { sys::whispercpp_take_last_constructor_exception() };
        Err(WhisperError::StateInit {
            code: if exc == 0 { None } else { Some(exc) },
        })
    }
    /// Raw context pointer for sibling modules making FFI calls.
    pub(crate) fn as_raw(&self) -> *mut sys::whisper_context {
        self.ptr.as_ptr()
    }
    /// Test-only constructor carrying a dangling pointer.
    ///
    /// # Safety
    /// The result must never reach any FFI call and must not be dropped
    /// normally — `Drop` would free the dangling pointer. Tests are expected
    /// to `mem::forget` it.
    #[cfg(test)]
    pub(crate) unsafe fn dangling_for_test() -> Self {
        Self {
            ptr: NonNull::<sys::whisper_context>::dangling(),
            lost: AtomicBool::new(false),
            full_lock: Mutex::new(()),
        }
    }
    /// Marks the context as lost/poisoned (Release store so later Acquire
    /// loads in other threads observe it). Irreversible.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn mark_lost(&self) {
        self.lost.store(true, Ordering::Release);
    }
    /// Whether the context has been marked lost.
    pub fn is_poisoned(&self) -> bool {
        self.lost.load(Ordering::Acquire)
    }
    /// Acquires the per-context lock serializing whole-context native
    /// operations; a poisoned mutex is recovered rather than propagated.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) fn full_lock(&self) -> MutexGuard<'_, ()> {
        self
            .full_lock
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }
    /// Whether the loaded model is multilingual.
    pub fn is_multilingual(&self) -> bool {
        unsafe { sys::whisper_is_multilingual(self.ptr.as_ptr()) != 0 }
    }
    /// Vocabulary size of the loaded model.
    pub fn n_vocab(&self) -> i32 {
        unsafe { sys::whisper_n_vocab(self.ptr.as_ptr()) }
    }
    /// Audio context length of the loaded model.
    pub fn n_audio_ctx(&self) -> i32 {
        unsafe { sys::whisper_n_audio_ctx(self.ptr.as_ptr()) }
    }
    /// Text context length of the loaded model.
    pub fn n_text_ctx(&self) -> i32 {
        unsafe { sys::whisper_n_text_ctx(self.ptr.as_ptr()) }
    }
    /// Human-readable model type, or `None` if the native call returns NULL
    /// or non-UTF-8. NOTE(review): the `'static` lifetime assumes the native
    /// string is a static table entry — confirm upstream.
    pub fn model_type(&self) -> Option<&'static str> {
        let raw = unsafe { sys::whisper_model_type_readable(self.ptr.as_ptr()) };
        if raw.is_null() {
            return None;
        }
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        core::str::from_utf8(bytes).ok()
    }
    /// End-of-transcript token id.
    pub fn token_eot(&self) -> i32 {
        unsafe { sys::whisper_token_eot(self.ptr.as_ptr()) }
    }
    /// Start-of-transcript token id.
    pub fn token_sot(&self) -> i32 {
        unsafe { sys::whisper_token_sot(self.ptr.as_ptr()) }
    }
    /// Timestamp-begin token id.
    pub fn token_beg(&self) -> i32 {
        unsafe { sys::whisper_token_beg(self.ptr.as_ptr()) }
    }
    /// Text for `token`, or `None` when the id is out of `[0, n_vocab)`, the
    /// native call returns NULL, or the bytes are not valid UTF-8.
    pub fn token_to_str(&self, token: i32) -> Option<&str> {
        let n = self.n_vocab();
        if token < 0 || token >= n {
            return None;
        }
        let raw = unsafe { sys::whisper_token_to_str(self.ptr.as_ptr(), token) };
        if raw.is_null() {
            return None;
        }
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        core::str::from_utf8(bytes).ok()
    }
    /// Raw bytes for `token` (no UTF-8 validation; useful for byte-level
    /// vocab entries), or `None` on an out-of-range id or NULL return.
    pub fn token_to_bytes(&self, token: i32) -> Option<&[u8]> {
        let n = self.n_vocab();
        if token < 0 || token >= n {
            return None;
        }
        let raw = unsafe { sys::whisper_token_to_str(self.ptr.as_ptr(), token) };
        if raw.is_null() {
            return None;
        }
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        Some(bytes)
    }
    /// Translate-task token id.
    pub fn token_translate(&self) -> i32 {
        unsafe { sys::whisper_token_translate(self.ptr.as_ptr()) }
    }
    /// Transcribe-task token id.
    pub fn token_transcribe(&self) -> i32 {
        unsafe { sys::whisper_token_transcribe(self.ptr.as_ptr()) }
    }
    /// Previous-text token id.
    pub fn token_prev(&self) -> i32 {
        unsafe { sys::whisper_token_prev(self.ptr.as_ptr()) }
    }
    /// No-speech token id.
    pub fn token_nosp(&self) -> i32 {
        unsafe { sys::whisper_token_nosp(self.ptr.as_ptr()) }
    }
    /// No-timestamps token id.
    pub fn token_not(&self) -> i32 {
        unsafe { sys::whisper_token_not(self.ptr.as_ptr()) }
    }
    /// Start-of-LM token id.
    pub fn token_solm(&self) -> i32 {
        unsafe { sys::whisper_token_solm(self.ptr.as_ptr()) }
    }
    /// Vocabulary token for `lang`, or `None` when the model is monolingual,
    /// the language is unknown, or the returned token falls outside the
    /// expected range.
    pub fn token_for_lang(&self, lang: &crate::Lang) -> Option<i32> {
        if !self.is_multilingual() {
            return None;
        }
        let lang_id = crate::lang_id_for(lang.as_str())?;
        let token = unsafe { sys::whisper_token_lang(self.ptr.as_ptr(), lang_id) };
        let sot = self.token_sot();
        let translate = unsafe { sys::whisper_token_translate(self.ptr.as_ptr()) };
        // Sanity range check: language tokens are expected to sit strictly
        // between SOT and the translate task token — NOTE(review): confirm
        // this layout holds for all supported checkpoints.
        if token > sot && token < translate {
            Some(token)
        } else {
            None
        }
    }
    /// Tokenizes `text` with the model's vocabulary.
    ///
    /// Returns `None` on interior NUL, allocation failure, or a native error.
    /// Shim return protocol (as used below): `i32::MIN` is a hard error, any
    /// other negative value encodes the required capacity (negated), and a
    /// non-negative value is the count of tokens written.
    pub fn tokenize(&self, text: &str) -> Option<Vec<i32>> {
        let bytes = text.as_bytes();
        // CString would reject interior NULs anyway; check up front and build
        // the NUL-terminated buffer with fallible allocation.
        if bytes.contains(&0) {
            return None;
        }
        let mut nul_terminated: Vec<u8> = Vec::new();
        nul_terminated.try_reserve_exact(bytes.len() + 1).ok()?;
        nul_terminated.extend_from_slice(bytes);
        nul_terminated.push(0);
        let cstr_ptr: *const core::ffi::c_char = nul_terminated.as_ptr().cast();
        let ctx_ptr = self.ptr.as_ptr();
        // First attempt with a fixed-size buffer; covers the common case.
        const INITIAL_CAPACITY: usize = 256;
        let mut buf: Vec<i32> = Vec::new();
        buf.try_reserve_exact(INITIAL_CAPACITY).ok()?;
        let written = unsafe {
            sys::whispercpp_tokenize(ctx_ptr, cstr_ptr, buf.as_mut_ptr(), INITIAL_CAPACITY as i32)
        };
        if written == i32::MIN {
            return None;
        }
        if written >= 0 {
            let written_usize = written as usize;
            // Defensive: never set_len beyond what we actually reserved.
            if written_usize > buf.capacity() {
                return None;
            }
            unsafe { buf.set_len(written_usize) };
            return Some(buf);
        }
        // Negative (and not i32::MIN): retry with the exact required capacity.
        let needed = (-written) as usize;
        let mut buf: Vec<i32> = Vec::new();
        buf.try_reserve_exact(needed).ok()?;
        let written =
            unsafe { sys::whispercpp_tokenize(ctx_ptr, cstr_ptr, buf.as_mut_ptr(), needed as i32) };
        if written == i32::MIN {
            return None;
        }
        if written < 0 {
            return None;
        }
        let written_usize = written as usize;
        if written_usize > buf.capacity() {
            return None;
        }
        unsafe { buf.set_len(written_usize) };
        Some(buf)
    }
    /// Tokenizes `text` expecting exactly one token; `None` on interior NUL,
    /// allocation failure, or when the text does not map to a single token.
    pub fn tokenize_one(&self, text: &str) -> Option<i32> {
        let bytes = text.as_bytes();
        if bytes.contains(&0) {
            return None;
        }
        let mut nul_terminated: Vec<u8> = Vec::new();
        nul_terminated.try_reserve_exact(bytes.len() + 1).ok()?;
        nul_terminated.extend_from_slice(bytes);
        nul_terminated.push(0);
        let cstr_ptr: *const core::ffi::c_char = nul_terminated.as_ptr().cast();
        let ctx_ptr = self.ptr.as_ptr();
        let mut out = [0i32; 1];
        let written = unsafe { sys::whispercpp_tokenize(ctx_ptr, cstr_ptr, out.as_mut_ptr(), 1) };
        // Anything other than exactly one written token (including the
        // negative "need more space" reply) means "not a single token".
        if written == 1 { Some(out[0]) } else { None }
    }
    /// Snapshot of the model's hyper-parameters as reported by the native
    /// context.
    pub fn model_dims(&self) -> ModelDims {
        let p = self.ptr.as_ptr();
        unsafe {
            ModelDims {
                n_audio_state: sys::whisper_model_n_audio_state(p),
                n_audio_head: sys::whisper_model_n_audio_head(p),
                n_audio_layer: sys::whisper_model_n_audio_layer(p),
                n_text_state: sys::whisper_model_n_text_state(p),
                n_text_head: sys::whisper_model_n_text_head(p),
                n_text_layer: sys::whisper_model_n_text_layer(p),
                n_mels: sys::whisper_model_n_mels(p),
                model_ftype: sys::whisper_model_ftype(p),
            }
        }
    }
}
/// Model hyper-parameters reported by the native context (see
/// [`Context::model_dims`]); values come straight from whisper.cpp.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelDims {
    pub n_audio_state: i32,
    pub n_audio_head: i32,
    pub n_audio_layer: i32,
    pub n_text_state: i32,
    pub n_text_head: i32,
    pub n_text_layer: i32,
    pub n_mels: i32,
    // Raw ftype code as reported by whisper.cpp (quantization/precision id).
    pub model_ftype: i32,
}
/// Returns whisper.cpp's system/feature banner, computed once per process and
/// cached; `None` when the native call returns NULL or non-UTF-8 bytes.
pub fn system_info() -> Option<smol_str::SmolStr> {
    use std::sync::OnceLock;
    static CACHE: OnceLock<Option<smol_str::SmolStr>> = OnceLock::new();
    CACHE
        .get_or_init(|| {
            // Only the first caller pays the FFI cost; concurrent callers
            // block inside get_or_init until the slot is filled.
            let raw = unsafe { sys::whispercpp_print_system_info() };
            if raw.is_null() {
                return None;
            }
            let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
            core::str::from_utf8(bytes).ok().map(smol_str::SmolStr::new)
        })
        .clone()
}
impl Drop for Context {
    /// Releases the native context.
    fn drop(&mut self) {
        // SAFETY: `ptr` was produced by whispercpp_init_from_file_no_state and
        // is owned exclusively by this struct, so it is freed exactly once
        // here. Tests that build a Context around a dangling pointer must
        // mem::forget it (they do) so this free never sees a bogus pointer.
        unsafe {
            sys::whisper_free(self.ptr.as_ptr());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Poison flag starts false, flips to true after mark_lost(); the dangling
    // Context is forgotten so Drop never frees the bogus pointer.
    #[test]
    fn fresh_context_marker_starts_unpoisoned() {
        let dangling = NonNull::<sys::whisper_context>::dangling();
        let ctx = Context {
            ptr: dangling,
            lost: AtomicBool::new(false),
            full_lock: Mutex::new(()),
        };
        assert!(!ctx.is_poisoned());
        ctx.mark_lost();
        assert!(ctx.is_poisoned());
        core::mem::forget(ctx);
    }
    // Repeated mark_lost() calls keep the flag true — no toggle behavior.
    #[test]
    fn mark_lost_is_idempotent_and_monotonic() {
        let dangling = NonNull::<sys::whisper_context>::dangling();
        let ctx = Context {
            ptr: dangling,
            lost: AtomicBool::new(false),
            full_lock: Mutex::new(()),
        };
        ctx.mark_lost();
        ctx.mark_lost();
        ctx.mark_lost();
        assert!(ctx.is_poisoned(), "stays true across repeated marks");
        core::mem::forget(ctx);
    }
    // Release store in one thread must be visible to an Acquire load after join.
    #[test]
    fn mark_lost_visible_across_threads() {
        let dangling = NonNull::<sys::whisper_context>::dangling();
        let ctx = Arc::new(Context {
            ptr: dangling,
            lost: AtomicBool::new(false),
            full_lock: Mutex::new(()),
        });
        let ctx_b = Arc::clone(&ctx);
        let handle = std::thread::spawn(move || {
            ctx_b.mark_lost();
        });
        handle.join().unwrap();
        assert!(
            ctx.is_poisoned(),
            "the post-join Acquire load must see the spawn-side Release store"
        );
        core::mem::forget(Arc::try_unwrap(ctx).ok().unwrap());
    }
    // Pins the documented defaults of ContextParams::new().
    #[test]
    fn default_context_params_have_dtw_off_and_default_mem_budget() {
        let p = ContextParams::new();
        assert!(!p.dtw_token_timestamps());
        assert_eq!(p.dtw_aheads_preset(), AlignmentHeadsPreset::None);
        assert_eq!(p.dtw_mem_size(), DEFAULT_DTW_MEM_SIZE);
        assert_eq!(DEFAULT_DTW_MEM_SIZE, 128 * 1024 * 1024);
    }
    // Builder setters compose and none clobbers another field.
    #[test]
    fn context_params_chained_dtw_setters_compose() {
        let custom_mem = MIN_DTW_MEM_SIZE * 2;
        let p = ContextParams::new()
            .with_use_gpu(false)
            .with_dtw_token_timestamps(true)
            .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo)
            .with_dtw_mem_size(custom_mem);
        assert!(!p.use_gpu());
        assert!(p.dtw_token_timestamps());
        assert_eq!(p.dtw_aheads_preset(), AlignmentHeadsPreset::LargeV3Turbo);
        assert_eq!(p.dtw_mem_size(), custom_mem);
    }
    // Boundary sweep of the clamp: below MIN, at MIN, in range, at/above MAX.
    #[test]
    fn clamp_dtw_mem_size_pins_invariants() {
        assert_eq!(clamp_dtw_mem_size(0), MIN_DTW_MEM_SIZE);
        assert_eq!(clamp_dtw_mem_size(1), MIN_DTW_MEM_SIZE);
        assert_eq!(clamp_dtw_mem_size(1024), MIN_DTW_MEM_SIZE);
        assert_eq!(clamp_dtw_mem_size(MIN_DTW_MEM_SIZE - 1), MIN_DTW_MEM_SIZE);
        assert_eq!(clamp_dtw_mem_size(MIN_DTW_MEM_SIZE), MIN_DTW_MEM_SIZE);
        assert_eq!(
            clamp_dtw_mem_size(MIN_DTW_MEM_SIZE + 1),
            MIN_DTW_MEM_SIZE + 1
        );
        assert_eq!(
            clamp_dtw_mem_size(256 * 1024 * 1024),
            256 * 1024 * 1024,
            "256 MiB sits between MIN ({MIN_DTW_MEM_SIZE}) and MAX ({MAX_DTW_MEM_SIZE})",
        );
        assert_eq!(clamp_dtw_mem_size(MAX_DTW_MEM_SIZE), MAX_DTW_MEM_SIZE);
        assert_eq!(
            clamp_dtw_mem_size(MAX_DTW_MEM_SIZE - 1),
            MAX_DTW_MEM_SIZE - 1
        );
        assert_eq!(clamp_dtw_mem_size(MAX_DTW_MEM_SIZE + 1), MAX_DTW_MEM_SIZE);
        assert_eq!(clamp_dtw_mem_size(usize::MAX), MAX_DTW_MEM_SIZE);
        // Compile-time pin of the clamp's ordering precondition.
        const { assert!(MIN_DTW_MEM_SIZE <= MAX_DTW_MEM_SIZE) };
        assert_eq!(MIN_DTW_MEM_SIZE, DEFAULT_DTW_MEM_SIZE);
    }
    // The setter clamps at set time, so the stored field is always in range.
    #[test]
    fn with_dtw_mem_size_clamps_zero_and_usize_max() {
        let p = ContextParams::new().with_dtw_mem_size(0);
        assert_eq!(
            p.dtw_mem_size(),
            MIN_DTW_MEM_SIZE,
            "0 → MIN (defends against ggml_init NULL on zero arena)",
        );
        let p = ContextParams::new().with_dtw_mem_size(usize::MAX);
        assert_eq!(
            p.dtw_mem_size(),
            MAX_DTW_MEM_SIZE,
            "usize::MAX → MAX (defends against ggml_init internal arena math overflow)",
        );
        let p = ContextParams::new().with_dtw_mem_size(MIN_DTW_MEM_SIZE * 2);
        assert_eq!(p.dtw_mem_size(), MIN_DTW_MEM_SIZE * 2);
    }
    // DTW + flash_attn must be rejected before any file access is attempted.
    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_context_default_params")]
    fn context_new_rejects_dtw_plus_flash_attn() {
        let params = ContextParams::new()
            .with_flash_attn(true)
            .with_dtw_token_timestamps(true)
            .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo);
        let result = Context::new("/nonexistent/dtw+flash-attn-test.bin", params);
        match result {
            Err(WhisperError::ContextLoad { reason, .. }) => {
                assert!(
                    reason.contains("DTW") && reason.contains("flash_attn"),
                    "ContextLoad reason must explain DTW + flash_attn incompatibility — got: {}",
                    reason,
                );
            }
            Err(e) => panic!(
                "expected ContextLoad with DTW + flash_attn rejection, got: {:?}",
                e,
            ),
            Ok(_) => panic!("expected error for DTW + flash_attn config, got Ok(Context)"),
        }
    }
    // Same rejection regardless of the order the builder setters were applied.
    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_context_default_params")]
    fn context_new_rejects_dtw_plus_flash_attn_setter_order_invariant() {
        let params = ContextParams::new()
            .with_dtw_token_timestamps(true)
            .with_dtw_aheads_preset(AlignmentHeadsPreset::LargeV3Turbo)
            .with_flash_attn(true);
        let result = Context::new("/nonexistent/dtw+flash-attn-test2.bin", params);
        let err = result
            .err()
            .expect("expected ContextLoad error, got Ok(Context)");
        assert!(
            matches!(err, WhisperError::ContextLoad { .. }),
            "expected ContextLoad regardless of setter order, got: {:?}",
            err,
        );
    }
    // Timestamps flag without a preset means DTW is off, so flash_attn is OK;
    // any load error must come from the missing file, not the DTW guard.
    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_context_default_params")]
    fn context_new_accepts_flash_attn_with_dtw_timestamps_but_no_preset() {
        let params = ContextParams::new()
            .with_flash_attn(true)
            .with_dtw_token_timestamps(true);
        let result = Context::new("/nonexistent/no-dtw-fine.bin", params);
        if let Err(WhisperError::ContextLoad { reason, .. }) = &result {
            assert!(
                !(reason.contains("DTW") && reason.contains("flash_attn")),
                "DTW + flash_attn rejection fired for a config where DTW is off: {}",
                reason,
            );
        }
    }
    // Guards the constant the DTW memory derivation is built on.
    #[test]
    fn supported_dtw_n_text_ctx_pins_to_standard_whisper_value() {
        assert_eq!(
            SUPPORTED_DTW_N_TEXT_CTX, 448,
            "Standard whisper checkpoints all use n_text_ctx = 448. \
             If you changed this, also re-derive the DTW scratch budget \
             and update the byte-count pins in \
             `required_dtw_mem_size_pins_per_preset_minimums`.",
        );
    }
    // Head counts must track whisper.cpp's per-model alignment-head tables.
    #[test]
    fn alignment_head_count_matches_whisper_cpp_tables() {
        use AlignmentHeadsPreset::*;
        assert_eq!(alignment_head_count(None), 0);
        assert_eq!(alignment_head_count(TinyEn), 8);
        assert_eq!(alignment_head_count(Tiny), 6);
        assert_eq!(alignment_head_count(BaseEn), 5);
        assert_eq!(alignment_head_count(Base), 8);
        assert_eq!(alignment_head_count(SmallEn), 19);
        assert_eq!(alignment_head_count(Small), 10);
        assert_eq!(alignment_head_count(MediumEn), 18);
        assert_eq!(alignment_head_count(Medium), 6);
        assert_eq!(alignment_head_count(LargeV1), 9);
        assert_eq!(alignment_head_count(LargeV2), 23);
        assert_eq!(alignment_head_count(LargeV3), 10);
        assert_eq!(alignment_head_count(LargeV3Turbo), 6);
    }
    // Budget is clamped into [MIN, MAX] for every preset; high-head presets
    // must exceed MIN, and three exact byte counts are pinned as regressions.
    #[test]
    fn required_dtw_mem_size_pins_per_preset_minimums() {
        use AlignmentHeadsPreset::*;
        assert_eq!(required_dtw_mem_size_for(None), 0);
        for preset in [
            TinyEn,
            Tiny,
            BaseEn,
            Base,
            Small,
            Medium,
            LargeV1,
            LargeV3,
            LargeV3Turbo,
        ] {
            let req = required_dtw_mem_size_for(preset);
            assert!(
                req >= MIN_DTW_MEM_SIZE,
                "{:?} requires {} bytes; must be ≥ MIN_DTW_MEM_SIZE ({})",
                preset,
                req,
                MIN_DTW_MEM_SIZE,
            );
            assert!(
                req <= MAX_DTW_MEM_SIZE,
                "{:?} requires {} bytes; must be ≤ MAX_DTW_MEM_SIZE ({})",
                preset,
                req,
                MAX_DTW_MEM_SIZE,
            );
        }
        for preset in [SmallEn, MediumEn, LargeV2] {
            let req = required_dtw_mem_size_for(preset);
            assert!(
                req > MIN_DTW_MEM_SIZE,
                "{:?} requires only {} bytes — must exceed MIN_DTW_MEM_SIZE ({}) \
                 to fit its high-head DTW pipeline; without this the wrapper's \
                 floor would let whisper.cpp abort during decode",
                preset,
                req,
                MIN_DTW_MEM_SIZE,
            );
        }
        assert_eq!(required_dtw_mem_size_for(LargeV2), 278_208_000);
        assert_eq!(required_dtw_mem_size_for(SmallEn), 229_824_000);
        assert_eq!(required_dtw_mem_size_for(MediumEn), 217_728_000);
    }
    // to_raw must be injective: no two presets may share a raw sys value.
    #[test]
    fn alignment_heads_preset_maps_to_distinct_raw_values() {
        use AlignmentHeadsPreset::*;
        let presets = [
            None,
            TinyEn,
            Tiny,
            BaseEn,
            Base,
            SmallEn,
            Small,
            MediumEn,
            Medium,
            LargeV1,
            LargeV2,
            LargeV3,
            LargeV3Turbo,
        ];
        let raws: Vec<sys::whisper_alignment_heads_preset> =
            presets.iter().map(|p| p.to_raw()).collect();
        let mut sorted = raws.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(
            sorted.len(),
            presets.len(),
            "AlignmentHeadsPreset → raw mapping must be injective: got {:?}",
            raws,
        );
    }
    // Mutual exclusion: the in-section counter must never observe another
    // holder while a thread holds full_lock().
    #[test]
    fn full_lock_serialises_concurrent_holders() {
        let dangling = NonNull::<sys::whisper_context>::dangling();
        let ctx = Arc::new(Context {
            ptr: dangling,
            lost: AtomicBool::new(false),
            full_lock: Mutex::new(()),
        });
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let ctx_t = Arc::clone(&ctx);
            let counter_t = Arc::clone(&counter);
            handles.push(std::thread::spawn(move || {
                let _g = ctx_t.full_lock();
                let pre = counter_t.fetch_add(1, Ordering::SeqCst);
                assert_eq!(pre, 0, "another holder slipped past the mutex");
                std::thread::sleep(std::time::Duration::from_millis(2));
                let post = counter_t.fetch_sub(1, Ordering::SeqCst);
                assert_eq!(post, 1, "another holder is concurrent with us");
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        core::mem::forget(Arc::try_unwrap(ctx).ok().unwrap());
    }
}