#![allow(unsafe_code)]
use core::{cell::UnsafeCell, ffi::c_void};
use std::{
ffi::CString,
panic::{AssertUnwindSafe, catch_unwind},
};
/// Owned storage for a user abort callback, shaped so its address can be handed to C.
///
/// The inner `Box<dyn FnMut() -> bool>` erases the closure type. The outer `Box` puts the
/// `UnsafeCell` at a heap address that does not change when the owning `Params` is moved,
/// and because `UnsafeCell<Box<dyn ...>>` is `Sized`, a pointer to it is thin and fits in a
/// `void*`. `UnsafeCell` lets the C trampoline obtain `&mut` to the closure from that pointer.
type AbortCallback = Box<UnsafeCell<Box<dyn FnMut() -> bool>>>;
/// Upper bound applied to `n_threads` by `Params::new` and `Params::set_n_threads`.
///
/// NOTE(review): the reason for capping at 1 is not visible in this file — confirm upstream.
pub const MAX_N_THREADS: i32 = 1;
/// Upper bound applied to `greedy.best_of` / `beam_search.beam_size` by `Params::new`.
pub const MAX_BEAM_SIZE: i32 = 64;
/// Pins a top-k style knob (`best_of` / `beam_size`) into `1..=MAX_BEAM_SIZE`.
///
/// Zero and negative values become `1`; anything at or above the cap becomes the cap.
#[cfg_attr(not(tarpaulin), inline(always))]
const fn clamp_topk(k: i32) -> i32 {
    match k {
        i32::MIN..=0 => 1,
        MAX_BEAM_SIZE..=i32::MAX => MAX_BEAM_SIZE,
        within => within,
    }
}
/// Pins a worker-thread count into `1..=MAX_N_THREADS`.
///
/// The lower bound is checked first, so non-positive counts always become exactly `1`.
#[cfg_attr(not(tarpaulin), inline(always))]
const fn clamp_n_threads(n: i32) -> i32 {
    match n {
        _ if n < 1 => 1,
        _ if n > MAX_N_THREADS => MAX_N_THREADS,
        _ => n,
    }
}
/// Upper bound applied by `Params::set_max_initial_ts` (the `_S` suffix suggests seconds).
pub const MAX_INITIAL_TS_S: f32 = 30.0;
/// Upper bound applied by `Params::set_temperature`.
pub const MAX_TEMPERATURE: f32 = 1.0;
/// Smallest increment `Params::set_temperature_inc` keeps; smaller, negative or
/// non-finite increments are replaced with `0.0`.
pub const MIN_TEMPERATURE_INC: f32 = 1e-3;
/// Normalises a `max_initial_ts` value.
///
/// Finite, non-negative inputs are capped at `MAX_INITIAL_TS_S`. Every other input
/// (negative, NaN, either infinity) collapses to `0.0`.
#[cfg_attr(not(tarpaulin), inline(always))]
const fn clamp_max_initial_ts(t: f32) -> f32 {
    if t.is_finite() && t >= 0.0 {
        if t <= MAX_INITIAL_TS_S { t } else { MAX_INITIAL_TS_S }
    } else {
        0.0
    }
}
/// Normalises a sampling temperature.
///
/// Finite, non-negative inputs are capped at `MAX_TEMPERATURE`. Negative and non-finite
/// inputs (NaN, ±infinity) collapse to `0.0`.
#[cfg_attr(not(tarpaulin), inline(always))]
const fn clamp_temperature(t: f32) -> f32 {
    if t.is_finite() && t >= 0.0 {
        if t <= MAX_TEMPERATURE { t } else { MAX_TEMPERATURE }
    } else {
        0.0
    }
}
/// Normalises the temperature fallback increment.
///
/// Increments below `MIN_TEMPERATURE_INC` (including negatives) and non-finite values
/// become `0.0`. Otherwise the increment is capped at `1.0`.
#[cfg_attr(not(tarpaulin), inline(always))]
const fn clamp_temperature_inc(inc: f32) -> f32 {
    let usable = inc.is_finite() && inc >= MIN_TEMPERATURE_INC;
    if !usable {
        return 0.0;
    }
    if inc <= 1.0 { inc } else { 1.0 }
}
use crate::{
error::{WhisperError, WhisperResult},
sys,
};
/// Decoding strategy selected when building [`Params`] via [`Params::new`].
///
/// `Params::new` loads the native defaults for the chosen strategy and then writes the
/// variant's fields over them (see each field for the normalisation applied).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SamplingStrategy {
    /// Greedy decoding.
    Greedy {
        /// Written to `greedy.best_of` after clamping into `1..=MAX_BEAM_SIZE`.
        best_of: i32,
    },
    /// Beam-search decoding.
    BeamSearch {
        /// Written to `beam_search.beam_size` after clamping into `1..=MAX_BEAM_SIZE`.
        beam_size: i32,
        /// Written to `beam_search.patience` verbatim; no validation is applied.
        patience: f32,
    },
}
/// Owned wrapper around the native `whisper_full_params`.
///
/// Several setters store a raw pointer in `raw` that points into one of the other
/// fields. Those fields keep the pointed-to buffers alive for as long as `raw`
/// refers to them.
pub struct Params {
    // Native parameter block; a bitwise copy is handed out by `as_raw`.
    raw: sys::whisper_full_params,
    // Backing storage for `raw.initial_prompt` (installed by `set_initial_prompt`).
    _initial_prompt: Option<CString>,
    // Backing storage for `raw.language` (installed by `set_language`).
    _language: Option<CString>,
    // Backing storage for `raw.prompt_tokens` (installed by `set_tokens`).
    _prompt_tokens: Option<Vec<sys::whisper_token>>,
    // Owner of the closure whose address is `raw.abort_callback_user_data`
    // (installed by `set_abort_callback`).
    _abort_callback: Option<AbortCallback>,
}
impl Params {
pub fn new(strategy: SamplingStrategy) -> Self {
let cstrategy = match strategy {
SamplingStrategy::Greedy { .. } => sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY,
SamplingStrategy::BeamSearch { .. } => {
sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH
}
};
let mut raw = unsafe { sys::whisper_full_default_params(cstrategy as _) };
match strategy {
SamplingStrategy::Greedy { best_of } => {
raw.greedy.best_of = clamp_topk(best_of);
}
SamplingStrategy::BeamSearch {
beam_size,
patience,
} => {
raw.beam_search.beam_size = clamp_topk(beam_size);
raw.beam_search.patience = patience;
}
}
raw.n_threads = clamp_n_threads(raw.n_threads);
Self {
raw,
_initial_prompt: None,
_language: None,
_prompt_tokens: None,
_abort_callback: None,
}
}
pub fn set_language(&mut self, lang: &str) -> WhisperResult<&mut Self> {
let cstr =
CString::new(lang).map_err(|_| WhisperError::InvalidCString(smol_str::SmolStr::new(lang)))?;
self.raw.language = cstr.as_ptr();
self._language = Some(cstr);
Ok(self)
}
pub fn set_initial_prompt(&mut self, prompt: &str) -> WhisperResult<&mut Self> {
let cstr = CString::new(prompt).map_err(|_| {
let head: String = prompt.chars().take(64).collect();
WhisperError::InvalidCString(smol_str::SmolStr::new(head))
})?;
self.raw.initial_prompt = cstr.as_ptr();
self._initial_prompt = Some(cstr);
Ok(self)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_detect_language(&mut self, on: bool) -> &mut Self {
self.raw.detect_language = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_n_threads(&mut self, n: i32) -> &mut Self {
self.raw.n_threads = clamp_n_threads(n);
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const unsafe fn set_n_threads_unchecked(&mut self, n: i32) -> &mut Self {
self.raw.n_threads = if n < 1 { 1 } else { n };
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const unsafe fn set_beam_size_unchecked(&mut self, n: i32) -> &mut Self {
self.raw.beam_search.beam_size = if n < 1 { 1 } else { n };
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const unsafe fn set_best_of_unchecked(&mut self, n: i32) -> &mut Self {
self.raw.greedy.best_of = if n < 1 { 1 } else { n };
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_no_context(&mut self, on: bool) -> &mut Self {
self.raw.no_context = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_no_speech_thold(&mut self, t: f32) -> &mut Self {
self.raw.no_speech_thold = t;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_temperature(&mut self, t: f32) -> &mut Self {
self.raw.temperature = clamp_temperature(t);
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_temperature_inc(&mut self, inc: f32) -> &mut Self {
self.raw.temperature_inc = clamp_temperature_inc(inc);
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_suppress_blank(&mut self, on: bool) -> &mut Self {
self.raw.suppress_blank = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_suppress_nst(&mut self, on: bool) -> &mut Self {
self.raw.suppress_nst = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn silence_print_toggles(&mut self) -> &mut Self {
self.raw.print_special = false;
self.raw.print_progress = false;
self.raw.print_realtime = false;
self.raw.print_timestamps = false;
self
}
pub fn set_abort_callback<F>(&mut self, f: F) -> &mut Self
where
F: FnMut() -> bool + 'static,
{
self.raw.abort_callback = None;
self.raw.abort_callback_user_data = core::ptr::null_mut();
let _old = self._abort_callback.take();
drop(_old);
let outer: AbortCallback = Box::new(UnsafeCell::new(Box::new(f)));
let user_data = (&*outer) as *const UnsafeCell<Box<dyn FnMut() -> bool>> as *mut c_void;
self._abort_callback = Some(outer);
self.raw.abort_callback_user_data = user_data;
self.raw.abort_callback = Some(abort_trampoline);
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_offset_ms(&mut self, ms: i32) -> &mut Self {
self.raw.offset_ms = if ms < 0 { 0 } else { ms };
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_duration_ms(&mut self, ms: i32) -> &mut Self {
self.raw.duration_ms = ms;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_audio_ctx(&mut self, n: i32) -> &mut Self {
self.raw.audio_ctx = n;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_max_tokens(&mut self, n: i32) -> &mut Self {
self.raw.max_tokens = n;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_max_len(&mut self, n: i32) -> &mut Self {
self.raw.max_len = n;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_max_initial_ts(&mut self, t: f32) -> &mut Self {
self.raw.max_initial_ts = clamp_max_initial_ts(t);
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_n_max_text_ctx(&mut self, n: i32) -> &mut Self {
self.raw.n_max_text_ctx = n;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_single_segment(&mut self, on: bool) -> &mut Self {
self.raw.single_segment = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_logprob_thold(&mut self, t: f32) -> &mut Self {
self.raw.logprob_thold = t;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_entropy_thold(&mut self, t: f32) -> &mut Self {
self.raw.entropy_thold = t;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_thold_pt(&mut self, t: f32) -> &mut Self {
self.raw.thold_pt = t;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_thold_ptsum(&mut self, t: f32) -> &mut Self {
self.raw.thold_ptsum = t;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_length_penalty(&mut self, p: f32) -> &mut Self {
self.raw.length_penalty = p;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_no_timestamps(&mut self, on: bool) -> &mut Self {
self.raw.no_timestamps = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_split_on_word(&mut self, on: bool) -> &mut Self {
self.raw.split_on_word = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_token_timestamps(&mut self, on: bool) -> &mut Self {
self.raw.token_timestamps = on;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_translate(&mut self, on: bool) -> &mut Self {
self.raw.translate = on;
self
}
pub fn set_tokens(&mut self, tokens: &[i32]) -> &mut Self {
if tokens.is_empty() {
self.raw.prompt_tokens = core::ptr::null();
self.raw.prompt_n_tokens = 0;
self._prompt_tokens = None;
} else {
let max_len = i32::MAX as usize;
let take = tokens.len().min(max_len);
let owned: Vec<sys::whisper_token> = tokens[..take].to_vec();
self.raw.prompt_tokens = owned.as_ptr();
self.raw.prompt_n_tokens = owned.len() as i32;
self._prompt_tokens = Some(owned);
}
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) const fn as_raw(&self) -> sys::whisper_full_params {
self.raw
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn prompt_tokens(&self) -> Option<&[i32]> {
self._prompt_tokens.as_deref()
}
}
impl core::fmt::Debug for Params {
    /// Formats the native struct and owned strings.
    ///
    /// The abort callback is shown only as installed / not installed, and prompt tokens
    /// only by their count (0 when none are set).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let callback_state = match self._abort_callback {
            Some(_) => "<installed>",
            None => "<none>",
        };
        let token_count = self._prompt_tokens.as_ref().map_or(0, Vec::len);

        let mut out = f.debug_struct("Params");
        out.field("raw", &self.raw);
        out.field("language", &self._language);
        out.field("initial_prompt", &self._initial_prompt);
        out.field("abort_callback", &callback_state);
        out.field("prompt_tokens", &token_count);
        out.finish()
    }
}
/// C-ABI shim installed as the native abort callback by [`Params::set_abort_callback`].
///
/// Returns the closure's verdict (`true` = abort). A panic inside the closure is caught,
/// because a panic escaping an `extern "C"` function aborts the process. The caught panic
/// is reported as `true`, so a misbehaving callback stops the run instead. A null
/// `user_data` is treated as "no callback installed" and yields `false`.
///
/// # Safety
///
/// `user_data` must be null or point to a live `UnsafeCell<Box<dyn FnMut() -> bool>>`,
/// such as the allocation owned by `Params::_abort_callback`. No other reference to that
/// closure may be active for the duration of the call.
unsafe extern "C" fn abort_trampoline(user_data: *mut c_void) -> bool {
    if user_data.is_null() {
        return false;
    }
    // SAFETY: non-null and, per the contract above, points to a live cell of this type.
    let cell: &UnsafeCell<Box<dyn FnMut() -> bool>> =
        unsafe { &*(user_data as *const UnsafeCell<Box<dyn FnMut() -> bool>>) };
    // SAFETY: `UnsafeCell` permits mutation through a shared reference, and the contract
    // above guarantees this is the only live access to the closure.
    let boxed: &mut Box<dyn FnMut() -> bool> = unsafe { &mut *cell.get() };
    catch_unwind(AssertUnwindSafe(boxed)).unwrap_or(true)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh greedy params with `best_of = 1`; the FFI-backed tests below start from this.
    fn greedy() -> Params {
        Params::new(SamplingStrategy::Greedy { best_of: 1 })
    }

    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")]
    fn default_params_n_threads_normalises_to_at_least_one() {
        let n = greedy().raw.n_threads;
        assert!(
            n >= 1,
            "default n_threads = {}; must be ≥ 1 to dodge the upstream vector<thread>(n - 1) underflow",
            n,
        );
        assert!(
            n <= MAX_N_THREADS,
            "default n_threads = {} above MAX_N_THREADS = {}",
            n,
            MAX_N_THREADS,
        );
    }

    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")]
    fn set_n_threads_clamps_zero_negative_and_oversized() {
        let cases = [
            (0, 1, "0 → 1"),
            (-42, 1, "negative → 1"),
            (i32::MIN, 1, "i32::MIN → 1"),
            (MAX_N_THREADS + 1, MAX_N_THREADS, "above-cap → MAX_N_THREADS"),
            (i32::MAX, MAX_N_THREADS, "i32::MAX → MAX_N_THREADS"),
        ];
        let mut p = greedy();
        for (input, expected, why) in cases {
            p.set_n_threads(input);
            assert_eq!(p.raw.n_threads, expected, "{}", why);
        }
    }

    #[test]
    fn clamp_n_threads_pins_invariants() {
        let cases = [
            (0, 1),
            (-1, 1),
            (1, 1),
            (MAX_N_THREADS, MAX_N_THREADS),
            (MAX_N_THREADS + 1, MAX_N_THREADS),
            (2, MAX_N_THREADS),
            (8, MAX_N_THREADS),
        ];
        for (input, expected) in cases {
            assert_eq!(clamp_n_threads(input), expected);
        }
    }

    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")]
    fn set_n_threads_unchecked_bypasses_upper_cap_only() {
        let mut p = greedy();
        for n in [8, MAX_N_THREADS + 1, 64] {
            // SAFETY: the value is only stored; these params never reach the native library.
            unsafe { p.set_n_threads_unchecked(n) };
            assert_eq!(p.raw.n_threads, n);
        }
        // SAFETY: as above.
        unsafe { p.set_n_threads_unchecked(0) };
        assert_eq!(p.raw.n_threads, 1, "0 → 1 even on unchecked path");
        // SAFETY: as above.
        unsafe { p.set_n_threads_unchecked(-7) };
        assert_eq!(p.raw.n_threads, 1, "negative → 1 even on unchecked path");
    }

    #[test]
    fn max_beam_size_pins_to_64() {
        assert_eq!(MAX_BEAM_SIZE, 64);
    }

    #[test]
    #[cfg_attr(miri, ignore = "FFI: whisper_full_default_params")]
    fn unchecked_topk_setters_bypass_upper_cap_only() {
        let mut beam = Params::new(SamplingStrategy::BeamSearch {
            beam_size: 1,
            patience: -1.0,
        });
        let beam_cases = [
            (MAX_BEAM_SIZE + 1, MAX_BEAM_SIZE + 1),
            (128, 128),
            (0, 1),
            (-9, 1),
        ];
        for (input, expected) in beam_cases {
            // SAFETY: the value is only stored; these params never reach the native library.
            unsafe { beam.set_beam_size_unchecked(input) };
            assert_eq!(beam.raw.beam_search.beam_size, expected);
        }

        let mut best = greedy();
        for (input, expected) in [(MAX_BEAM_SIZE + 1, MAX_BEAM_SIZE + 1), (0, 1)] {
            // SAFETY: as above.
            unsafe { best.set_best_of_unchecked(input) };
            assert_eq!(best.raw.greedy.best_of, expected);
        }
    }
}