use anyhow::{Context, Result};
use async_trait::async_trait;
use super::{TranscriptionBackend, TranscriptionRequest, TranscriptionResult};
use crate::model_manager;
/// Transcription backend that runs Whisper inference on the local machine
/// (the `_client` in the async path goes unused). Stateless: the loaded
/// model is owned and cached by `crate::model_manager`, not by this struct.
#[derive(Debug, Default, Clone)]
pub struct LocalWhisperProvider;
#[async_trait]
impl TranscriptionBackend for LocalWhisperProvider {
    /// Stable machine-readable identifier for this backend.
    fn name(&self) -> &'static str {
        "local-whisper"
    }

    /// Human-readable label for UI display.
    fn display_name(&self) -> &'static str {
        "Local Whisper"
    }

    /// Blocking transcription: runs the full pipeline on the caller's thread.
    fn transcribe_sync(
        &self,
        model_path: &str, request: TranscriptionRequest,
    ) -> Result<TranscriptionResult> {
        transcribe_local(model_path, request)
    }

    /// Async transcription: offloads the CPU-bound Whisper work to tokio's
    /// blocking pool so the async executor is not stalled. The HTTP client
    /// is ignored — this backend never talks to the network.
    async fn transcribe_async(
        &self,
        _client: &reqwest::Client, model_path: &str,
        request: TranscriptionRequest,
    ) -> Result<TranscriptionResult> {
        // Own the path so the closure can be 'static for spawn_blocking.
        let owned_path = model_path.to_owned();
        let join_handle =
            tokio::task::spawn_blocking(move || transcribe_local(&owned_path, request));
        // Outer `?` surfaces a task-join failure (panic/cancel); the inner
        // Result from transcribe_local is returned as-is.
        join_handle.await.context("Task join failed")?
    }
}
/// Shared synchronous pipeline: report the `Transcribing` stage, decode the
/// request's MP3 payload into f32 PCM (resampled via `decode_and_resample`),
/// then run Whisper inference on the samples.
fn transcribe_local(
    model_path: &str,
    request: TranscriptionRequest,
) -> Result<TranscriptionResult> {
    use super::TranscriptionStage;
    request.report(TranscriptionStage::Transcribing);
    let pcm = decode_and_resample(&request.audio_data)?;
    let language = request.language.as_deref();
    transcribe_samples(model_path, &pcm, language)
}
/// Transcribe already-decoded PCM samples directly, bypassing the MP3
/// decode step and progress reporting used by the request-based path.
///
/// NOTE(review): callers presumably must supply samples in the same format
/// `decode_and_resample` produces (16 kHz, per `resample_to_16k`) — confirm
/// at the call sites, since nothing here validates the sample rate.
pub fn transcribe_raw(
    model_path: &str,
    samples: &[f32],
    language: Option<&str>,
) -> Result<TranscriptionResult> {
    transcribe_samples(model_path, samples, language)
}
/// Run Whisper inference over `samples` with the model at `model_path`,
/// concatenating all decoded segments into a single trimmed string.
///
/// The model handle comes from `model_manager` (which caches/loads it);
/// after use the guard is dropped and `maybe_unload` is given a chance to
/// evict idle models. `language` of `None` leaves Whisper's default in place.
fn transcribe_samples(
    model_path: &str,
    samples: &[f32],
    language: Option<&str>,
) -> Result<TranscriptionResult> {
    use whisper_rs::{FullParams, SamplingStrategy};

    let mut model_guard = model_manager::get_model(model_path)?;
    let state = model_guard.state_mut();

    // Greedy decoding with a single candidate; silence all console output.
    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);
    if let Some(lang) = language {
        params.set_language(Some(lang));
    }

    state
        .full(params, samples)
        .context("Transcription failed")?;

    // Stitch segment texts together; segments that are missing or fail
    // UTF-8 conversion are silently skipped, matching best-effort output.
    let mut transcript = String::new();
    for idx in 0..state.full_n_segments() {
        let Some(segment) = state.get_segment(idx) else {
            continue;
        };
        if let Ok(piece) = segment.to_str() {
            transcript.push_str(piece);
        }
    }

    // Release the model handle before asking the manager to evict idle models.
    drop(model_guard);
    model_manager::maybe_unload();

    Ok(TranscriptionResult {
        text: transcript.trim().to_string(),
    })
}
/// Decode an in-memory MP3 byte stream to normalized f32 PCM and resample
/// it to 16 kHz for Whisper.
///
/// Samples are scaled by `i16::MAX` into roughly [-1.0, 1.0]. Returns an
/// error if decoding fails or the stream yields no samples at all.
///
/// NOTE(review): the sample rate / channel count of the *last* frame win if
/// they vary mid-stream — presumably MP3 inputs here are uniform; confirm.
fn decode_and_resample(mp3_data: &[u8]) -> Result<Vec<f32>> {
    use minimp3::{Decoder, Frame};

    const SCALE: f32 = i16::MAX as f32;

    let mut decoder = Decoder::new(mp3_data);
    let mut pcm: Vec<f32> = Vec::new();
    let mut rate = 0u32;
    let mut channel_count = 0u16;

    loop {
        // Eof terminates the stream cleanly; any other error is fatal.
        let frame: Frame = match decoder.next_frame() {
            Ok(frame) => frame,
            Err(minimp3::Error::Eof) => break,
            Err(e) => anyhow::bail!("MP3 decode error: {:?}", e),
        };
        rate = frame.sample_rate as u32;
        channel_count = frame.channels as u16;
        pcm.extend(frame.data.iter().map(|&s| s as f32 / SCALE));
    }

    if pcm.is_empty() {
        anyhow::bail!("No audio data decoded from MP3");
    }

    crate::resample::resample_to_16k(&pcm, rate, channel_count)
}