#![allow(unsafe_code)]
use core::{ptr::NonNull, str};
use std::sync::Arc;
use crate::{
context::Context,
error::{WhisperError, WhisperResult},
lang::Lang,
params::Params,
sys,
};
/// Owning wrapper around a native `whisper_state` handle, paired with the
/// [`Context`] it is used with.
///
/// The handle is released with `whisper_free_state` when the value is dropped,
/// unless it has already been cleared (see [`State::is_poisoned`]).
pub struct State {
// Native handle; `None` once the state has been given up after a fatal
// return code from `full`. `is_poisoned` reports exactly this condition.
ptr: Option<NonNull<sys::whisper_state>>,
// Shared context handed to `from_raw`; this strong reference keeps it alive
// for as long as the state exists.
ctx: Arc<Context>,
}
// SAFETY: the `NonNull` handle makes `State` `!Send` by default; this impl
// asserts that the native state may be moved to, and used from, another thread.
// NOTE(review): that property of `whisper_state` is not visible in this file —
// confirm against the native library before relying on it.
unsafe impl Send for State {}
impl State {
    /// Wraps an existing native state handle together with its context.
    ///
    /// The new value takes over `ptr`: it is released with
    /// `whisper_free_state` when the `State` is dropped.
    pub(crate) fn from_raw(ptr: NonNull<sys::whisper_state>, ctx: Arc<Context>) -> Self {
        let ptr = Some(ptr);
        Self { ptr, ctx }
    }

    /// Returns the shared [`Context`] this state was created with.
    pub fn context(&self) -> &Arc<Context> {
        &self.ctx
    }

    /// Returns `true` once this state no longer holds a native handle.
    ///
    /// The handle is cleared by [`State::full`] when the native call reports a
    /// fatal return code.
    pub fn is_poisoned(&self) -> bool {
        self.raw().is_none()
    }

    /// Raw pointer form of the handle, or `None` if the handle is gone.
    #[inline]
    fn raw(&self) -> Option<*mut sys::whisper_state> {
        self.ptr.map(|handle| handle.as_ptr())
    }

    /// Runs the native `whispercpp_full_with_state` call over `samples` using
    /// `params`.
    ///
    /// Checks performed before the native call, in this order:
    /// 1. `samples` holds at least 201 values;
    /// 2. `samples.len()` fits in an `i32`;
    /// 3. every prompt token in `params` (if any) lies in `0..n_vocab` of the
    ///    context;
    /// 4. this state still owns its native handle;
    /// 5. the context is not poisoned (checked while holding the guard
    ///    returned by the context's `full_lock`).
    ///
    /// # Errors
    ///
    /// * `SamplesTooShort` / `SamplesOverflow` — the length checks failed.
    /// * `TokenOutOfRange` — reports the first offending prompt token.
    /// * `StateLost` — the handle was already gone (code `-7`), or the native
    ///   call returned `-7` or a code `<= sys::WHISPERCPP_ERR_BAD_ALLOC`. In
    ///   the latter two cases the handle is cleared and the context is marked
    ///   lost, so the state is unusable afterwards.
    /// * `ContextPoisoned` — the context reported itself poisoned.
    /// * `Full` — any other non-zero return code; the handle is kept.
    pub fn full(&mut self, params: &Params, samples: &[f32]) -> WhisperResult<()> {
        // NOTE(review): the name suggests the native front-end reflect-pads the
        // input and needs this many samples to do so — confirm.
        const MIN_SAMPLES_FOR_REFLECTIVE_PAD: usize = 201;

        let n_samples = samples.len();
        if n_samples < MIN_SAMPLES_FOR_REFLECTIVE_PAD {
            return Err(WhisperError::SamplesTooShort {
                samples: n_samples,
                min_required: MIN_SAMPLES_FOR_REFLECTIVE_PAD,
            });
        }

        // The native API takes the sample count as an `i32`.
        let Ok(len) = i32::try_from(n_samples) else {
            return Err(WhisperError::SamplesOverflow { samples: n_samples });
        };

        // Reject prompt tokens outside the vocabulary before they reach native code.
        if let Some(prompt) = params.prompt_tokens() {
            let vocab = self.ctx.n_vocab();
            for &tok in prompt {
                if !(0..vocab).contains(&tok) {
                    return Err(WhisperError::TokenOutOfRange {
                        token: tok,
                        vocab_size: vocab,
                    });
                }
            }
        }

        // A previously lost state is reported with the same code (-7) that the
        // native call uses below.
        let Some(handle) = self.ptr else {
            return Err(WhisperError::StateLost { code: -7 });
        };
        let state_ptr = handle.as_ptr();

        // Held until this function returns, so the native call and any cleanup
        // below both happen under the context's lock.
        let _full_guard = self.ctx.full_lock();
        if self.ctx.is_poisoned() {
            return Err(WhisperError::ContextPoisoned);
        }

        // SAFETY: `state_ptr` comes from `self.ptr`, which is only `Some` while
        // this wrapper still owns the handle; `samples` is a live slice and
        // `len` is its length, checked above to fit in `i32`.
        // NOTE(review): validity of `ctx.as_raw()` / `params.as_raw()` is
        // established by those types, not visible here.
        let rc = unsafe {
            sys::whispercpp_full_with_state(
                self.ctx.as_raw(),
                state_ptr,
                params.as_raw(),
                samples.as_ptr(),
                len,
            )
        };

        match rc {
            0 => Ok(()),
            -7 => {
                // NOTE(review): the handle is dropped here WITHOUT calling
                // `whisper_free_state`, unlike the arm below — presumably the
                // native side has already released it. Confirm; otherwise this
                // leaks the state.
                self.ptr = None;
                self.ctx.mark_lost();
                Err(WhisperError::StateLost { code: rc })
            }
            fatal if fatal <= sys::WHISPERCPP_ERR_BAD_ALLOC => {
                self.ctx.mark_lost();
                if let Some(owned) = self.ptr.take() {
                    // SAFETY: `take()` cleared `self.ptr`, so the handle is
                    // freed exactly once and never used again by this wrapper.
                    unsafe { sys::whisper_free_state(owned.as_ptr()) };
                }
                Err(WhisperError::StateLost { code: fatal })
            }
            other => Err(WhisperError::Full { code: other }),
        }
    }

    /// Number of segments currently stored in the native state, or `0` if the
    /// handle is gone.
    pub fn n_segments(&self) -> i32 {
        match self.raw() {
            // SAFETY: `state` comes from `self.ptr`, which is only `Some`
            // while this wrapper still owns the handle.
            Some(state) => unsafe { sys::whisper_full_n_segments_from_state(state) },
            None => 0,
        }
    }

    /// Returns a view of segment `idx`, or `None` if the handle is gone or
    /// `idx` is outside `0..n_segments()`.
    ///
    /// The returned [`Segment`] borrows `self`, so the state cannot be mutated
    /// or dropped while the segment is alive.
    pub fn segment(&self, idx: i32) -> Option<Segment<'_>> {
        let state = self.ptr?;
        // `&&` keeps the native count query from running for negative indices.
        let in_range = idx >= 0 && idx < self.n_segments();
        in_range.then_some(Segment {
            state,
            idx,
            _marker: core::marker::PhantomData,
        })
    }

    /// Language reported by the native state, if any.
    ///
    /// Returns `None` when the handle is gone, when the native language id is
    /// negative, when the native lookup yields a null string, or when that
    /// string is not valid UTF-8. Otherwise the string is passed to
    /// `Lang::from_iso639_1`.
    pub fn detected_lang(&self) -> Option<Lang> {
        let state = self.raw()?;
        // SAFETY: `state` comes from `self.ptr`, which is only `Some` while
        // this wrapper still owns the handle.
        let lang_id = unsafe { sys::whisper_full_lang_id_from_state(state) };
        if lang_id < 0 {
            return None;
        }

        // SAFETY: plain lookup by id; the result is null-checked before use.
        let name_ptr = unsafe { sys::whisper_lang_str(lang_id) };
        if name_ptr.is_null() {
            return None;
        }

        // SAFETY: `name_ptr` is non-null.
        // NOTE(review): assumes the native side returns a NUL-terminated string
        // that stays valid for the duration of this call — confirm.
        let name = unsafe { core::ffi::CStr::from_ptr(name_ptr) };
        match str::from_utf8(name.to_bytes()) {
            Ok(code) => Some(Lang::from_iso639_1(code)),
            Err(_) => None,
        }
    }
}
impl Drop for State {
    /// Releases the native handle, if this state still owns one.
    fn drop(&mut self) {
        // Nothing to release when the handle was already cleared.
        let Some(owned) = self.ptr.take() else {
            return;
        };
        // SAFETY: `take()` cleared `self.ptr`, so the handle is freed exactly
        // once and cannot be reached through this wrapper afterwards.
        unsafe { sys::whisper_free_state(owned.as_ptr()) };
    }
}
/// Copyable view of a single segment stored in a native `whisper_state`.
///
/// The lifetime `'a` ties the view to the borrow it was created from, so the
/// underlying state outlives every accessor call.
#[derive(Clone, Copy)]
pub struct Segment<'a> {
// Raw handle of the state this segment is read from.
state: NonNull<sys::whisper_state>,
// Segment index handed to every native accessor.
idx: i32,
// Carries `'a` without storing an actual reference.
_marker: core::marker::PhantomData<&'a ()>,
}
// SAFETY (applies to every `unsafe` block in this impl unless noted): `state`
// and `idx` are private and, as far as this file shows, only filled in by
// `State::segment`, which copies a live handle and range-checks `idx`. The
// lifetime `'a` borrows that `State`, so the handle cannot be freed or mutated
// through `State::full` while the segment exists.
impl<'a> Segment<'a> {
    /// Start timestamp of the segment as reported by the native library.
    ///
    /// NOTE(review): the unit is not visible here — confirm against the native API.
    pub fn t0(&self) -> i64 {
        let state = self.state.as_ptr();
        unsafe { sys::whisper_full_get_segment_t0_from_state(state, self.idx) }
    }

    /// End timestamp of the segment as reported by the native library.
    ///
    /// NOTE(review): the unit is not visible here — confirm against the native API.
    pub fn t1(&self) -> i64 {
        let state = self.state.as_ptr();
        unsafe { sys::whisper_full_get_segment_t1_from_state(state, self.idx) }
    }

    /// Text of the segment, borrowed from the native state.
    ///
    /// Returns an empty string when the native side yields a null pointer.
    ///
    /// # Errors
    ///
    /// Fails (via `WhisperError::from`) when the native bytes are not valid UTF-8.
    pub fn text(&self) -> WhisperResult<&'a str> {
        let state = self.state.as_ptr();
        let text_ptr = unsafe { sys::whisper_full_get_segment_text_from_state(state, self.idx) };
        if text_ptr.is_null() {
            Ok("")
        } else {
            // SAFETY: `text_ptr` is non-null.
            // NOTE(review): assumes a NUL-terminated string that stays valid
            // for `'a`, i.e. while the state is borrowed — confirm.
            let text = unsafe { core::ffi::CStr::from_ptr(text_ptr) };
            str::from_utf8(text.to_bytes()).map_err(WhisperError::from)
        }
    }

    /// "No speech" probability the native library reports for this segment.
    pub fn no_speech_prob(&self) -> f32 {
        let state = self.state.as_ptr();
        unsafe { sys::whisper_full_get_segment_no_speech_prob_from_state(state, self.idx) }
    }

    /// Number of tokens in this segment.
    pub fn n_tokens(&self) -> i32 {
        let state = self.state.as_ptr();
        unsafe { sys::whisper_full_n_tokens_from_state(state, self.idx) }
    }

    /// Returns token `tok_idx` of this segment, or `None` when `tok_idx` is
    /// outside `0..n_tokens()`.
    pub fn token(&self, tok_idx: i32) -> Option<Token> {
        // `&&` keeps the native count query from running for negative indices.
        if tok_idx >= 0 && tok_idx < self.n_tokens() {
            let state = self.state.as_ptr();
            // SAFETY: see the impl-level note; `tok_idx` was range-checked just above.
            let data =
                unsafe { sys::whisper_full_get_token_data_from_state(state, self.idx, tok_idx) };
            Some(Token::from_raw(data))
        } else {
            None
        }
    }

    /// Whether the native library flags a speaker turn after this segment.
    pub fn speaker_turn_next(&self) -> bool {
        let state = self.state.as_ptr();
        unsafe { sys::whisper_full_get_segment_speaker_turn_next_from_state(state, self.idx) }
    }
}
#[derive(Debug, Clone, Copy)]
pub struct Token {
id: i32,
p: f32,
plog: f32,
pt: f32,
ptsum: f32,
t0: i64,
t1: i64,
vlen: f32,
}
impl Token {
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn id(&self) -> i32 {
self.id
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn p(&self) -> f32 {
self.p
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn plog(&self) -> f32 {
self.plog
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn pt(&self) -> f32 {
self.pt
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn ptsum(&self) -> f32 {
self.ptsum
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn t0(&self) -> i64 {
self.t0
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn t1(&self) -> i64 {
self.t1
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn vlen(&self) -> f32 {
self.vlen
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) const fn from_raw(raw: crate::sys::whisper_token_data) -> Self {
Self {
id: raw.id,
p: raw.p,
plog: raw.plog,
pt: raw.pt,
ptsum: raw.ptsum,
t0: raw.t0,
t1: raw.t1,
vlen: raw.vlen,
}
}
}