#![allow(unsafe_code)]
use core::{ptr::NonNull, str};
use std::sync::Arc;
use crate::{
context::Context,
error::{WhisperError, WhisperResult},
lang::Lang,
params::Params,
sys,
};
/// An inference state bound to a parent [`Context`].
///
/// `ptr` is `None` once the underlying whisper state has been lost
/// ("poisoned"), e.g. after a fatal error inside [`State::full`]; all
/// accessors then degrade to empty/`None` results instead of touching FFI.
pub struct State {
    // Owned raw whisper state; `None` marks a poisoned/lost state.
    ptr: Option<NonNull<sys::whisper_state>>,
    // Shared handle to the parent context; keeps the model alive for as
    // long as this state (and anything borrowed from it) exists.
    ctx: Arc<Context>,
}
// SAFETY: the raw state is only reached through `&self`/`&mut self` methods,
// and the heavyweight FFI call in `full` is serialized via the context's
// `full_lock`, so transferring ownership of a `State` across threads is sound.
// NOTE(review): `Sync` is deliberately NOT implemented — presumably because
// whisper.cpp state is not safe for concurrent shared access; confirm upstream.
unsafe impl Send for State {}
impl State {
    /// Wraps a raw, non-null whisper state together with the context that
    /// created it. The state starts out healthy (`ptr` is `Some`).
    pub(crate) fn from_raw(ptr: NonNull<sys::whisper_state>, ctx: Arc<Context>) -> Self {
        Self {
            ptr: Some(ptr),
            ctx,
        }
    }

    /// The context this state was created from.
    pub fn context(&self) -> &Arc<Context> {
        &self.ctx
    }

    /// `true` once the underlying state has been lost (see [`State::full`]).
    pub fn is_poisoned(&self) -> bool {
        self.ptr.is_none()
    }

    /// Raw pointer accessor; `None` when the state is poisoned.
    #[inline]
    fn raw(&self) -> Option<*mut sys::whisper_state> {
        self.ptr.map(NonNull::as_ptr)
    }

    /// Runs a full transcription pass over `samples`.
    ///
    /// Validates the input before crossing the FFI boundary: minimum sample
    /// count, `i32` length overflow, prompt-token vocabulary range, language
    /// selection, and offset/duration bounds.
    ///
    /// # Errors
    /// Returns `SamplesTooShort`, `SamplesOverflow`, `TokenOutOfRange`,
    /// `LanguageNotSupportedByModel`, `InvalidDuration`, `StateLost`,
    /// `ContextPoisoned`, or `Full` (generic non-zero return code).
    ///
    /// On fatal FFI failures this poisons the state (and marks the context
    /// lost) so later calls fail fast instead of touching freed memory.
    pub fn full(&mut self, params: &Params, samples: &[f32]) -> WhisperResult<()> {
        // Minimum input length required by whisper.cpp's reflective padding
        // in the log-mel stage (per the constant/error naming).
        // NOTE(review): 201 presumably derives from the mel filter width —
        // confirm against the pinned whisper.cpp version.
        const MIN_SAMPLES_FOR_REFLECTIVE_PAD: usize = 201;
        if samples.len() < MIN_SAMPLES_FOR_REFLECTIVE_PAD {
            return Err(WhisperError::SamplesTooShort {
                samples: samples.len(),
                min_required: MIN_SAMPLES_FOR_REFLECTIVE_PAD,
            });
        }
        // The C API takes the sample count as an i32.
        let len = i32::try_from(samples.len()).map_err(|_| WhisperError::SamplesOverflow {
            samples: samples.len(),
        })?;
        // Reject prompt tokens outside the model vocabulary before the call;
        // whisper.cpp would otherwise index out of range.
        if let Some(prompt) = params.prompt_tokens() {
            let vocab = self.ctx.n_vocab();
            for &tok in prompt {
                if tok < 0 || tok >= vocab {
                    return Err(WhisperError::TokenOutOfRange {
                        token: tok,
                        vocab_size: vocab,
                    });
                }
            }
        }
        // Language selection: `detect_language`, an absent language, an empty
        // string, or the literal "auto" all mean auto-detection.
        let lang_opt = params.language();
        let auto_detect = params.detect_language()
            || match lang_opt {
                None => true,
                Some(s) => s.is_empty() || s == "auto",
            };
        if auto_detect {
            // Auto-detection is only meaningful on multilingual models.
            if !self.ctx.is_multilingual() {
                return Err(WhisperError::LanguageNotSupportedByModel(
                    smol_str::SmolStr::new_static("auto"),
                ));
            }
        } else if let Some(lang) = lang_opt {
            let model_supports = if let Some(lang_id) = crate::lang::lang_id_for(lang) {
                if self.ctx.is_multilingual() {
                    // A usable language token must fall strictly between the
                    // start-of-transcript token and the translate task token.
                    // NOTE(review): relies on whisper.cpp's vocab layout
                    // (SOT < lang tokens < task tokens) — confirm upstream.
                    let token = unsafe { sys::whisper_token_lang(self.ctx.as_raw(), lang_id) };
                    let sot = self.ctx.token_sot();
                    let translate = unsafe { sys::whisper_token_translate(self.ctx.as_raw()) };
                    token > sot && token < translate
                } else {
                    // English-only models support exactly lang_id 0.
                    lang_id == 0
                }
            } else {
                false
            };
            if !model_supports {
                return Err(WhisperError::LanguageNotSupportedByModel(
                    smol_str::SmolStr::new(lang),
                ));
            }
        }
        // Offset/duration sanity check against the audio length.
        // NOTE(review): the divisor assumes 16 kHz mono input — confirm the
        // crate's documented sample-rate contract.
        let offset_ms = i64::from(params.offset_ms());
        let duration_ms = i64::from(params.duration_ms());
        let audio_duration_ms = (samples.len() as i64) * 1000 / 16_000;
        if offset_ms < 0
            || duration_ms < 0
            || (duration_ms > 0 && offset_ms.saturating_add(duration_ms) > audio_duration_ms)
            || (duration_ms == 0 && offset_ms > audio_duration_ms)
        {
            return Err(WhisperError::InvalidDuration {
                offset_ms: params.offset_ms(),
                duration_ms: params.duration_ms(),
                audio_duration_ms,
            });
        }
        let state_ptr = match self.ptr {
            Some(p) => p.as_ptr(),
            // Already poisoned: report with the same code used for the lost
            // path below.
            None => return Err(WhisperError::StateLost { code: -7 }),
        };
        // Serialize the heavyweight FFI call on the context, then re-check
        // poisoning under the lock so a concurrent failure is observed.
        let _full_guard = self.ctx.full_lock();
        if self.ctx.is_poisoned() {
            return Err(WhisperError::ContextPoisoned);
        }
        // SAFETY: ctx and state_ptr are live (checked above), `samples`
        // outlives the call, and `len` equals `samples.len()`.
        let rc = unsafe {
            sys::whispercpp_full_with_state(
                self.ctx.as_raw(),
                state_ptr,
                params.as_raw(),
                samples.as_ptr(),
                len,
            )
        };
        if rc == 0 {
            return Ok(());
        }
        if rc == -7 {
            // NOTE(review): -7 drops the pointer WITHOUT calling
            // whisper_free_state — presumably the C side has already
            // invalidated/freed the state for this code. Confirm this does
            // not leak; contrast with the branch below, which frees eagerly.
            self.ptr = None;
            self.ctx.mark_lost();
            return Err(WhisperError::StateLost { code: rc });
        }
        if rc <= sys::WHISPERCPP_ERR_BAD_ALLOC {
            // Allocation-class fatal error: poison the context and free the
            // state eagerly so nothing else touches it.
            self.ctx.mark_lost();
            if let Some(p) = self.ptr.take() {
                // SAFETY: pointer came from whisper and `take` clears it, so
                // Drop cannot double-free.
                unsafe { sys::whisper_free_state(p.as_ptr()) };
            }
            return Err(WhisperError::StateLost { code: rc });
        }
        Err(WhisperError::Full { code: rc })
    }

    /// Number of decoded segments; 0 when poisoned.
    pub fn n_segments(&self) -> i32 {
        let Some(state) = self.raw() else { return 0 };
        unsafe { sys::whisper_full_n_segments_from_state(state) }
    }

    /// Number of mel frames held by this state; 0 when poisoned.
    pub fn n_mel_frames(&self) -> i32 {
        let Some(state) = self.raw() else { return 0 };
        unsafe { sys::whisper_n_len_from_state(state) }
    }

    /// Prints whisper.cpp timing statistics; no-op when poisoned.
    pub fn print_timings(&self) {
        let Some(state) = self.raw() else { return };
        let ctx = self.ctx.as_raw();
        unsafe { sys::whispercpp_print_timings_with_state(ctx, state) };
    }

    /// Resets whisper.cpp timing statistics; no-op when poisoned.
    pub fn reset_timings(&mut self) {
        let Some(state) = self.raw() else { return };
        unsafe { sys::whispercpp_reset_timings_with_state(state) };
    }

    /// Borrowed view of segment `idx`; `None` when out of range or poisoned.
    pub fn segment(&self, idx: i32) -> Option<Segment<'_>> {
        let state_ptr = self.ptr?;
        if idx < 0 || idx >= self.n_segments() {
            return None;
        }
        Some(Segment {
            state: state_ptr,
            idx,
            _marker: core::marker::PhantomData,
        })
    }

    /// Iterator over all decoded segments (empty when poisoned).
    pub fn segments_iter(&self) -> Segments<'_> {
        Segments {
            state: self,
            next: 0,
            // Defensive clamp: a negative count from FFI must not become a
            // huge usize later.
            end: self.n_segments().max(0),
        }
    }

    /// Language detected by the last run, or `None` when poisoned, nothing
    /// was detected, or the language string cannot be resolved.
    pub fn detected_lang(&self) -> Option<Lang> {
        let state = self.raw()?;
        let id = unsafe { sys::whisper_full_lang_id_from_state(state) };
        if id < 0 {
            return None;
        }
        let raw = unsafe { sys::whisper_lang_str(id) };
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is non-null (checked above) and presumably points to
        // a NUL-terminated string owned by whisper for valid ids.
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        let code = str::from_utf8(bytes).ok()?;
        Some(Lang::from_iso639_1(code))
    }
}
impl Drop for State {
    fn drop(&mut self) {
        // `take` clears `ptr`: a poisoned state (already `None`) frees
        // nothing, and a live one is freed exactly once.
        if let Some(p) = self.ptr.take() {
            // SAFETY: the pointer was obtained from whisper and has not been
            // freed (the error paths in `full` clear `ptr` when they free).
            unsafe { sys::whisper_free_state(p.as_ptr()) }
        }
    }
}
/// A cheap, copyable view of one decoded segment: a state pointer plus an
/// index. The `PhantomData` lifetime ties it to the `&State` borrow it was
/// created from, so it cannot outlive (or overlap a mutation of) the state.
#[derive(Clone, Copy)]
pub struct Segment<'a> {
    state: NonNull<sys::whisper_state>,
    idx: i32,
    _marker: core::marker::PhantomData<&'a ()>,
}
impl<'a> Segment<'a> {
    /// Segment start timestamp.
    /// NOTE(review): units presumably 10 ms ticks (whisper convention) —
    /// confirm against the pinned whisper.cpp headers.
    pub fn t0(&self) -> i64 {
        unsafe { sys::whisper_full_get_segment_t0_from_state(self.state.as_ptr(), self.idx) }
    }

    /// Segment end timestamp (same units as [`Segment::t0`]).
    pub fn t1(&self) -> i64 {
        unsafe { sys::whisper_full_get_segment_t1_from_state(self.state.as_ptr(), self.idx) }
    }

    /// Segment text, borrowed from the state's internal buffer.
    ///
    /// Returns `Ok("")` for a null pointer; errors only on invalid UTF-8.
    /// The `'a` lifetime keeps the borrow tied to the `State`, so the buffer
    /// cannot be invalidated by a later `full()` (which needs `&mut State`).
    pub fn text(&self) -> WhisperResult<&'a str> {
        let raw =
            unsafe { sys::whisper_full_get_segment_text_from_state(self.state.as_ptr(), self.idx) };
        if raw.is_null() {
            return Ok("");
        }
        // SAFETY: non-null (checked) and NUL-terminated per the whisper API.
        let bytes = unsafe { core::ffi::CStr::from_ptr(raw).to_bytes() };
        str::from_utf8(bytes).map_err(WhisperError::from)
    }

    /// No-speech probability whisper reports for this segment.
    pub fn no_speech_prob(&self) -> f32 {
        unsafe {
            sys::whisper_full_get_segment_no_speech_prob_from_state(self.state.as_ptr(), self.idx)
        }
    }

    /// Number of tokens in this segment.
    pub fn n_tokens(&self) -> i32 {
        unsafe { sys::whisper_full_n_tokens_from_state(self.state.as_ptr(), self.idx) }
    }

    /// Token data at `tok_idx`, or `None` when out of range.
    pub fn token(&self, tok_idx: i32) -> Option<Token> {
        if tok_idx < 0 || tok_idx >= self.n_tokens() {
            return None;
        }
        // SAFETY: 0 <= tok_idx < n_tokens, state pointer is live for 'a.
        let raw = unsafe {
            sys::whisper_full_get_token_data_from_state(self.state.as_ptr(), self.idx, tok_idx)
        };
        Some(Token::from_raw(raw))
    }

    /// Iterator over this segment's tokens.
    pub fn tokens_iter(&self) -> Tokens<'a> {
        Tokens {
            segment: *self,
            next: 0,
            // Clamp a (defensive) negative FFI count to zero.
            end: self.n_tokens().max(0),
        }
    }

    /// Whether whisper flags a speaker turn after this segment.
    /// NOTE(review): presumably only meaningful with diarization enabled —
    /// confirm against upstream docs.
    pub fn speaker_turn_next(&self) -> bool {
        unsafe {
            sys::whisper_full_get_segment_speaker_turn_next_from_state(self.state.as_ptr(), self.idx)
        }
    }
}
pub struct Segments<'a> {
state: &'a State,
next: i32,
end: i32,
}
impl<'a> Iterator for Segments<'a> {
    type Item = Segment<'a>;

    /// Yields segments front-to-back; stops immediately if the window is
    /// exhausted or the state no longer has a live pointer to hand out.
    fn next(&mut self) -> Option<Self::Item> {
        let in_range = self.next < self.end;
        if !in_range {
            return None;
        }
        let ptr = self.state.ptr?;
        let current = self.next;
        self.next += 1;
        Some(Segment {
            state: ptr,
            idx: current,
            _marker: core::marker::PhantomData,
        })
    }

    /// Exact remaining count derived from the index window.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = usize::try_from(self.end - self.next).unwrap_or(0);
        (left, Some(left))
    }
}
impl DoubleEndedIterator for Segments<'_> {
    /// Yields segments back-to-front by shrinking the window from the end.
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.end <= self.next {
            return None;
        }
        let ptr = self.state.ptr?;
        self.end -= 1;
        Some(Segment {
            state: ptr,
            idx: self.end,
            _marker: core::marker::PhantomData,
        })
    }
}
impl ExactSizeIterator for Segments<'_> {
    /// Remaining length; a (defensively) negative window collapses to 0.
    fn len(&self) -> usize {
        usize::try_from(self.end - self.next).unwrap_or(0)
    }
}
// Fused: once `next >= end` (or the state pointer is gone) `next()` returns
// `None` forever — neither condition can reverse while the borrow is held.
impl core::iter::FusedIterator for Segments<'_> {}
/// Double-ended iterator over one segment's tokens.
///
/// Holds a copy of the (Copy) `Segment` handle plus a half-open index
/// window [next, end); `end` is fixed to the token count at creation.
pub struct Tokens<'state> {
    segment: Segment<'state>,
    next: i32,
    end: i32,
}
impl Iterator for Tokens<'_> {
    type Item = Token;

    /// Yields owned `Token` values front-to-back.
    fn next(&mut self) -> Option<Self::Item> {
        if self.next >= self.end {
            return None;
        }
        let tok_idx = self.next;
        self.next += 1;
        // SAFETY: the state pointer is live for 'state, and
        // 0 <= tok_idx < end == n_tokens captured at construction.
        let raw = unsafe {
            sys::whisper_full_get_token_data_from_state(
                self.segment.state.as_ptr(),
                self.segment.idx,
                tok_idx,
            )
        };
        Some(Token::from_raw(raw))
    }

    /// Exact remaining count derived from the index window.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = (self.end - self.next).max(0) as usize;
        (remaining, Some(remaining))
    }
}
impl DoubleEndedIterator for Tokens<'_> {
    /// Yields tokens back-to-front by shrinking the window from the end.
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.next >= self.end {
            return None;
        }
        self.end -= 1;
        // SAFETY: after the decrement, next <= end < n_tokens, so the index
        // is in range; the state pointer is live for 'state.
        let raw = unsafe {
            sys::whisper_full_get_token_data_from_state(
                self.segment.state.as_ptr(),
                self.segment.idx,
                self.end,
            )
        };
        Some(Token::from_raw(raw))
    }
}
impl ExactSizeIterator for Tokens<'_> {
    /// Remaining length; a (defensively) negative window collapses to 0.
    fn len(&self) -> usize {
        usize::try_from(self.end - self.next).unwrap_or(0)
    }
}
// Fused: `next()` keeps returning `None` once the index window is empty.
impl core::iter::FusedIterator for Tokens<'_> {}
impl<'a> IntoIterator for &'a State {
type Item = Segment<'a>;
type IntoIter = Segments<'a>;
fn into_iter(self) -> Segments<'a> {
self.segments_iter()
}
}
impl<'a> IntoIterator for Segment<'a> {
type Item = Token;
type IntoIter = Tokens<'a>;
fn into_iter(self) -> Tokens<'a> {
self.tokens_iter()
}
}
impl<'a> IntoIterator for &Segment<'a> {
type Item = Token;
type IntoIter = Tokens<'a>;
fn into_iter(self) -> Tokens<'a> {
self.tokens_iter()
}
}
/// Owned snapshot of one token's data, projected from the raw FFI struct.
#[derive(Debug, Clone, Copy)]
pub struct Token {
    // Token id in the model vocabulary.
    id: i32,
    // Probability of the token.
    p: f32,
    // Log-probability of the token.
    plog: f32,
    // Probability of the timestamp token.
    pt: f32,
    // Sum of probabilities of all timestamp tokens.
    ptsum: f32,
    // Start timestamp (units per whisper convention — see accessor docs).
    t0: i64,
    // End timestamp.
    t1: i64,
    // DTW timestamp; negative means "unavailable" (see `Token::t_dtw`).
    t_dtw: i64,
    // Voice length of the token.
    vlen: f32,
}
impl Token {
    // The `cfg_attr` gymnastics keep `inline(always)` out of tarpaulin
    // coverage builds, where forced inlining would hide these accessors.

    /// Token id in the model vocabulary.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn id(&self) -> i32 {
        self.id
    }

    /// Probability of the token.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn p(&self) -> f32 {
        self.p
    }

    /// Log-probability of the token.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn plog(&self) -> f32 {
        self.plog
    }

    /// Probability of the timestamp token.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn pt(&self) -> f32 {
        self.pt
    }

    /// Sum of probabilities of all timestamp tokens.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn ptsum(&self) -> f32 {
        self.ptsum
    }

    /// Start timestamp of the token.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn t0(&self) -> i64 {
        self.t0
    }

    /// End timestamp of the token.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn t1(&self) -> i64 {
        self.t1
    }

    /// DTW timestamp, or `None` when unavailable.
    ///
    /// Any negative raw value (whisper's sentinel is -1) maps to `None`;
    /// 0 is a valid timestamp (token at audio start) — see the unit tests.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn t_dtw(&self) -> Option<i64> {
        if self.t_dtw < 0 {
            None
        } else {
            Some(self.t_dtw)
        }
    }

    /// Voice length of the token.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn vlen(&self) -> f32 {
        self.vlen
    }

    /// Projects every field of the raw FFI struct (except `tid`, which is
    /// intentionally not exposed) into an owned `Token`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub(crate) const fn from_raw(raw: crate::sys::whisper_token_data) -> Self {
        Self {
            id: raw.id,
            p: raw.p,
            plog: raw.plog,
            pt: raw.pt,
            ptsum: raw.ptsum,
            t0: raw.t0,
            t1: raw.t1,
            t_dtw: raw.t_dtw,
            vlen: raw.vlen,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Pure-Rust tests: `Token` projection semantics plus iterator-contract
    //! checks driven through a deliberately poisoned (`ptr: None`) `State`,
    //! so no real whisper model or FFI call is ever needed.
    use super::*;

    #[test]
    fn token_from_raw_projects_every_field_including_t_dtw() {
        let raw = sys::whisper_token_data {
            id: 1234,
            tid: 5678,
            p: 0.8,
            plog: -0.22,
            pt: 0.05,
            ptsum: 0.12,
            t0: 100,
            t1: 250,
            t_dtw: 175,
            vlen: 0.42,
        };
        let tok = Token::from_raw(raw);
        assert_eq!(tok.id(), 1234);
        assert!((tok.p() - 0.8).abs() < 1e-6);
        assert!((tok.plog() - -0.22).abs() < 1e-6);
        assert!((tok.pt() - 0.05).abs() < 1e-6);
        assert!((tok.ptsum() - 0.12).abs() < 1e-6);
        assert_eq!(tok.t0(), 100);
        assert_eq!(tok.t1(), 250);
        assert_eq!(
            tok.t_dtw(),
            Some(175),
            "Token::from_raw must project the DTW timestamp",
        );
        assert!((tok.vlen() - 0.42).abs() < 1e-6);
    }

    #[test]
    fn t_dtw_is_independent_of_t0_t1() {
        let raw = sys::whisper_token_data {
            id: 0,
            tid: 0,
            p: 0.0,
            plog: 0.0,
            pt: 0.0,
            ptsum: 0.0,
            t0: 100,
            t1: 200,
            t_dtw: 150,
            vlen: 0.0,
        };
        let tok = Token::from_raw(raw);
        assert_eq!(tok.t0(), 100);
        assert_eq!(tok.t1(), 200);
        assert_eq!(tok.t_dtw(), Some(150));
        assert_ne!(tok.t_dtw(), Some(tok.t0()));
        assert_ne!(tok.t_dtw(), Some(tok.t1()));
    }

    #[test]
    fn t_dtw_sentinel_minus_one_maps_to_none() {
        let raw = sys::whisper_token_data {
            id: 0,
            tid: 0,
            p: 0.0,
            plog: 0.0,
            pt: 0.0,
            ptsum: 0.0,
            t0: 0,
            t1: 0,
            t_dtw: -1,
            vlen: 0.0,
        };
        let tok = Token::from_raw(raw);
        assert_eq!(
            tok.t_dtw(),
            None,
            "t_dtw == -1 must surface as None (DTW unavailable for token)",
        );
    }

    #[test]
    fn t_dtw_zero_maps_to_some_zero() {
        let raw = sys::whisper_token_data {
            id: 0,
            tid: 0,
            p: 0.0,
            plog: 0.0,
            pt: 0.0,
            ptsum: 0.0,
            t0: 0,
            t1: 0,
            t_dtw: 0,
            vlen: 0.0,
        };
        let tok = Token::from_raw(raw);
        assert_eq!(
            tok.t_dtw(),
            Some(0),
            "t_dtw == 0 is a valid timestamp (token at audio start), not the sentinel",
        );
    }

    /// A `State` with `ptr: None` over a dangling test context. The `Arc`
    /// clone is wrapped in `ManuallyDrop` so `Context::drop` never runs on
    /// the dangling handle — hence the intentional leak and the `miri`
    /// ignores on every test using this fixture.
    struct PoisonedStateFixture {
        state: State,
        _leaked_ctx: core::mem::ManuallyDrop<Arc<Context>>,
    }

    impl PoisonedStateFixture {
        fn new() -> Self {
            let ctx = Arc::new(unsafe { Context::dangling_for_test() });
            let state = State {
                ptr: None,
                ctx: Arc::clone(&ctx),
            };
            Self {
                state,
                _leaked_ctx: core::mem::ManuallyDrop::new(ctx),
            }
        }
    }

    impl core::ops::Deref for PoisonedStateFixture {
        type Target = State;
        fn deref(&self) -> &State {
            &self.state
        }
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_empty_state_yields_zero_items() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let count = state.segments_iter().count();
        assert_eq!(count, 0, "poisoned state must yield zero segments");
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_count_matches_n_segments() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let expected = state.n_segments();
        assert_eq!(expected, 0, "test fixture is poisoned");
        let actual = state.segments_iter().count() as i32;
        assert_eq!(actual, expected);
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_exact_size_len_matches_count() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let iter = state.segments_iter();
        let len_before = iter.len();
        let counted = iter.count();
        assert_eq!(len_before, counted);
        assert_eq!(len_before, 0);
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_size_hint_is_exact() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let iter = state.segments_iter();
        let (lower, upper) = iter.size_hint();
        assert_eq!(lower, 0);
        assert_eq!(upper, Some(0));
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_fused_after_exhaustion() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let mut iter = state.segments_iter();
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn multiple_segments_iter_alive_concurrently() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let it1 = state.segments_iter();
        let it2 = state.segments_iter();
        assert_eq!(it1.len(), it2.len());
        assert_eq!(it1.count(), 0);
        assert_eq!(it2.count(), 0);
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_composes_with_adapters() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let collected: Vec<_> = state.segments_iter().map(|seg| seg.t0()).collect();
        assert!(collected.is_empty());
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn nested_segments_and_tokens_iter_compiles() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let mut total: i32 = 0;
        for seg in state.segments_iter() {
            for tok in seg.tokens_iter() {
                total = total.wrapping_add(tok.id());
            }
        }
        assert_eq!(total, 0);
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn iterator_type_bounds_are_correct() {
        // Compile-time checks that Segments implements the iterator traits.
        fn assert_iter<I: Iterator>(_: I) {}
        fn assert_exact_size<I: ExactSizeIterator>(_: I) {}
        fn assert_fused<I: core::iter::FusedIterator>(_: I) {}
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        assert_iter(state.segments_iter());
        assert_exact_size(state.segments_iter());
        assert_fused(state.segments_iter());
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn tokens_iter_composes_with_flat_map() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        let total: usize = state
            .segments_iter()
            .flat_map(|seg| seg.tokens_iter())
            .count();
        assert_eq!(total, 0);
        fn assert_token_iter<I: Iterator<Item = Token>>(_: I) {}
        let fixture2 = PoisonedStateFixture::new();
        let state2 = &*fixture2;
        assert_token_iter(state2.segments_iter().flat_map(|seg| seg.tokens_iter()));
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn into_iter_for_state_ref_yields_segments() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        fn assert_segment_iter<'a, I: IntoIterator<Item = Segment<'a>>>(_: I) {}
        assert_segment_iter(state);
        let count = state.into_iter().count();
        assert_eq!(count, 0);
    }

    #[test]
    fn into_iter_for_segment_compiles() {
        // Note: the `use` below the fn is fine — `use` is a block-level item
        // and is in scope for the whole block regardless of position.
        fn assert_token_iter<I: IntoIterator<Item = Token>>(_: PhantomData<I>) {}
        use core::marker::PhantomData;
        assert_token_iter::<Segment<'_>>(PhantomData);
        assert_token_iter::<&Segment<'_>>(PhantomData);
    }

    #[cfg_attr(miri, ignore = "intentional Arc leak in PoisonedStateFixture")]
    #[test]
    fn segments_iter_double_ended_compiles_and_empty_yields_none() {
        let fixture = PoisonedStateFixture::new();
        let state = &*fixture;
        fn assert_dei<I: DoubleEndedIterator>(_: I) {}
        assert_dei(state.segments_iter());
        let mut iter = state.segments_iter();
        assert!(iter.next_back().is_none());
        let rev_count = state.segments_iter().rev().count();
        assert_eq!(rev_count, 0);
    }

    #[test]
    fn tokens_iter_double_ended_compiles() {
        fn assert_dei<I: DoubleEndedIterator>(_: PhantomData<I>) {}
        use core::marker::PhantomData;
        assert_dei::<Tokens<'_>>(PhantomData);
    }
}