use std::path::{Path, PathBuf};
use std::sync::Arc;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
use crate::inference::{TranscriptionRequest, TranscriptionResponse, TranscriptionSegment};
/// Speech-to-text provider backed by a local whisper.cpp model via `whisper_rs`.
pub struct WhisperProvider {
    // Shared handle to the loaded model; `Arc` so `transcribe_async` can move a
    // clone onto tokio's blocking pool without reloading the model.
    ctx: Arc<WhisperContext>,
    // Kept only so `model_path()` can report which model file was loaded.
    model_path: PathBuf,
}
impl WhisperProvider {
    /// Loads the whisper model at `model_path` and wraps it in a provider.
    ///
    /// # Errors
    /// Fails when the path is not valid UTF-8 or when `whisper_rs` cannot
    /// load the model file.
    pub fn new(model_path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let path = model_path.as_ref().to_path_buf();
        let params = WhisperContextParameters::new();
        let ctx = WhisperContext::new_with_params(
            path.to_str()
                .ok_or_else(|| anyhow::anyhow!("invalid model path"))?,
            params,
        )
        .map_err(|e| anyhow::anyhow!("failed to load whisper model: {e:?}"))?;
        Ok(Self {
            ctx: Arc::new(ctx),
            model_path: path,
        })
    }

    /// Returns the path of the model file this provider was built from.
    pub fn model_path(&self) -> &Path {
        &self.model_path
    }

    /// Transcribes the WAV payload in `request`, blocking the current thread.
    ///
    /// Honors `request.language` (falling back to whisper's own language
    /// detection when unset) and `request.word_timestamps`.
    ///
    /// NOTE(review): samples are fed to whisper at whatever rate the WAV header
    /// declares; whisper.cpp expects 16 kHz mono — confirm callers provide
    /// that, or add resampling upstream.
    ///
    /// # Errors
    /// Fails when the audio cannot be decoded, the whisper state cannot be
    /// created, or inference itself reports an error.
    pub fn transcribe(
        &self,
        request: &TranscriptionRequest,
    ) -> anyhow::Result<TranscriptionResponse> {
        let decoded = decode_audio(&request.audio)?;
        let samples = &decoded.samples;
        let mut state = self
            .ctx
            .create_state()
            .map_err(|e| anyhow::anyhow!("failed to create whisper state: {e:?}"))?;
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        // Keep whisper.cpp from printing its own progress/transcript output.
        params.set_print_progress(false);
        params.set_print_special(false);
        params.set_print_realtime(false);
        params.set_print_timestamps(false);
        if let Some(lang) = &request.language {
            params.set_language(Some(lang));
        } else {
            params.set_detect_language(true);
        }
        if request.word_timestamps {
            params.set_token_timestamps(true);
        }
        // Use every available core; fall back to 4 threads when the
        // parallelism cannot be determined.
        params.set_n_threads(
            std::thread::available_parallelism()
                .map(|n| n.get() as i32)
                .unwrap_or(4),
        );
        state
            .full(params, samples)
            .map_err(|e| anyhow::anyhow!("whisper transcription failed: {e:?}"))?;
        // Collect the segments once; the full transcript is derived from them
        // below instead of cloning every segment's text into a second buffer.
        let mut segments = Vec::new();
        for segment in state.as_iter() {
            segments.push(TranscriptionSegment {
                text: segment.to_string(),
                // whisper timestamps are in centiseconds.
                start_secs: segment.start_timestamp() as f64 / 100.0,
                end_secs: segment.end_timestamp() as f64 / 100.0,
            });
        }
        let text: String = segments.iter().map(|s| s.text.as_str()).collect();
        let lang_id = state.full_lang_id_from_state();
        let language = whisper_rs::get_lang_str(lang_id)
            .unwrap_or("unknown")
            .to_string();
        // Guard against a zero sample rate from a malformed header.
        let duration_secs = samples.len() as f64 / decoded.sample_rate.max(1) as f64;
        Ok(TranscriptionResponse {
            text: text.trim().to_string(),
            language,
            duration_secs,
            segments,
        })
    }

    /// Async wrapper around [`Self::transcribe`] that runs the CPU-heavy work
    /// on tokio's blocking thread pool so the async runtime is not stalled.
    pub async fn transcribe_async(
        &self,
        request: TranscriptionRequest,
    ) -> anyhow::Result<TranscriptionResponse> {
        // Clone the cheap handles into an owned provider so the closure is
        // 'static, as `spawn_blocking` requires; the model itself is shared
        // behind the Arc, not copied.
        let provider = WhisperProvider {
            ctx: Arc::clone(&self.ctx),
            model_path: self.model_path.clone(),
        };
        tokio::task::spawn_blocking(move || provider.transcribe(&request)).await?
    }
}
/// PCM audio decoded from a WAV payload.
#[derive(Debug)]
struct DecodedAudio {
    // 16-bit samples rescaled to f32 in [-1.0, 1.0).
    samples: Vec<f32>,
    // Rate declared by the WAV `fmt ` chunk (defaults to 16000 when absent).
    sample_rate: u32,
}
/// Decodes a 16-bit PCM WAV payload into mono f32 samples in [-1.0, 1.0).
///
/// Walks the RIFF chunk list, reading the sample rate and channel count from
/// the `fmt ` chunk and the samples from the `data` chunk. Multi-channel audio
/// is downmixed to mono by averaging each frame (the previous version returned
/// interleaved channels, which also doubled the apparent duration). Non-PCM or
/// non-16-bit files are rejected instead of being decoded as garbage.
///
/// # Errors
/// Returns an error for truncated buffers, non-WAV data, unsupported sample
/// formats, odd-length sample data, or a missing `data` chunk.
fn decode_audio(data: &[u8]) -> anyhow::Result<DecodedAudio> {
    if data.len() < 44 {
        return Err(anyhow::anyhow!("audio data too short for WAV header"));
    }
    if &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
        return Err(anyhow::anyhow!("not a WAV file"));
    }
    // Defaults used when no `fmt ` chunk precedes `data` (matches prior behavior).
    let mut sample_rate: u32 = 16000;
    let mut channels: usize = 1;
    let mut pos = 12;
    // `<=` so a zero-sized chunk ending exactly at the buffer edge is still visited.
    while pos + 8 <= data.len() {
        let chunk_id = &data[pos..pos + 4];
        let chunk_size =
            u32::from_le_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]])
                as usize;
        // Clamp the chunk body to the buffer so truncated files cannot panic.
        let body = &data[pos + 8..pos + 8 + chunk_size.min(data.len() - pos - 8)];
        if chunk_id == b"fmt " && body.len() >= 16 {
            let audio_format = u16::from_le_bytes([body[0], body[1]]);
            channels = u16::from_le_bytes([body[2], body[3]]).max(1) as usize;
            sample_rate = u32::from_le_bytes([body[4], body[5], body[6], body[7]]);
            let bits_per_sample = u16::from_le_bytes([body[14], body[15]]);
            // 1 = WAVE_FORMAT_PCM; 0xFFFE = WAVE_FORMAT_EXTENSIBLE, accepted as
            // long as the sample width is 16 bits.
            if !(audio_format == 1 || audio_format == 0xFFFE) || bits_per_sample != 16 {
                return Err(anyhow::anyhow!(
                    "unsupported WAV format (format tag {audio_format}, {bits_per_sample}-bit): only 16-bit PCM is supported"
                ));
            }
        }
        if chunk_id == b"data" {
            if !body.len().is_multiple_of(2) {
                return Err(anyhow::anyhow!(
                    "audio data has odd byte count ({}), expected 16-bit PCM samples",
                    body.len()
                ));
            }
            // One frame = one little-endian i16 per channel; average to mono.
            let frame_bytes = 2 * channels;
            let mut samples = Vec::with_capacity(body.len() / frame_bytes);
            for frame in body.chunks_exact(frame_bytes) {
                let sum: f32 = frame
                    .chunks_exact(2)
                    .map(|b| i16::from_le_bytes([b[0], b[1]]) as f32 / 32768.0)
                    .sum();
                samples.push(sum / channels as f32);
            }
            return Ok(DecodedAudio {
                samples,
                sample_rate,
            });
        }
        // Advance past this chunk; RIFF pads odd-sized chunks to an even boundary.
        let Some(next_pos) = pos.checked_add(8).and_then(|p| p.checked_add(chunk_size)) else {
            return Err(anyhow::anyhow!("malformed WAV: chunk size overflow"));
        };
        pos = next_pos;
        if !chunk_size.is_multiple_of(2) {
            pos += 1;
        }
    }
    Err(anyhow::anyhow!("no data chunk found in WAV file"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal mono 16 kHz, 16-bit PCM WAV containing `data_bytes`.
    fn wav_with_data(data_bytes: &[u8]) -> Vec<u8> {
        let mut wav = Vec::new();
        wav.extend_from_slice(b"RIFF");
        wav.extend_from_slice(&(36 + data_bytes.len() as u32).to_le_bytes());
        wav.extend_from_slice(b"WAVE");
        wav.extend_from_slice(b"fmt ");
        wav.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        wav.extend_from_slice(&1u16.to_le_bytes()); // PCM
        wav.extend_from_slice(&1u16.to_le_bytes()); // mono
        wav.extend_from_slice(&16000u32.to_le_bytes()); // sample rate
        wav.extend_from_slice(&32000u32.to_le_bytes()); // byte rate
        wav.extend_from_slice(&2u16.to_le_bytes()); // block align
        wav.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        wav.extend_from_slice(b"data");
        wav.extend_from_slice(&(data_bytes.len() as u32).to_le_bytes());
        wav.extend_from_slice(data_bytes);
        wav
    }

    #[test]
    fn decode_audio_rejects_short_data() {
        let result = decode_audio(&[0; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn decode_audio_rejects_non_wav() {
        let result = decode_audio(&[0; 100]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not a WAV"));
    }

    #[test]
    fn decode_audio_rejects_odd_data_chunk() {
        // A 3-byte data payload cannot hold whole 16-bit samples.
        let err = decode_audio(&wav_with_data(&[0u8; 3])).unwrap_err();
        assert!(err.to_string().contains("odd byte count"));
    }

    #[test]
    fn decode_audio_valid_wav() {
        let mut pcm = Vec::new();
        pcm.extend_from_slice(&0i16.to_le_bytes());
        pcm.extend_from_slice(&16384i16.to_le_bytes());
        let decoded = decode_audio(&wav_with_data(&pcm)).unwrap();
        assert_eq!(decoded.samples.len(), 2);
        assert!((decoded.samples[0] - 0.0).abs() < f32::EPSILON);
        assert!((decoded.samples[1] - 0.5).abs() < 0.001);
        assert_eq!(decoded.sample_rate, 16000);
    }
}