use std::path::PathBuf;
use std::sync::Mutex;
use async_trait::async_trait;
use bytes::Bytes;
use tracing::{debug, info};
use whisper_rs::{
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState,
};
use super::model::{WhisperModel, model_path};
#[cfg(target_os = "macos")]
use super::model::{coreml_encoder_exists, coreml_encoder_path, ensure_coreml_encoder};
use super::{Result, TranscribeError, Transcriber};
/// Settings for a [`LocalWhisperClient`].
#[derive(Debug, Clone)]
pub struct LocalWhisperConfig {
/// Model to run. Used to resolve the default model file and, on macOS, the CoreML encoder.
pub model: WhisperModel,
/// Explicit model file to load. When `None`, the path is resolved from `model`.
pub model_path: Option<PathBuf>,
/// Whether to make sure a CoreML encoder is available before transcribing.
/// Only consulted on macOS builds; other targets ignore it.
pub coreml: bool,
}
impl LocalWhisperConfig {
    /// Builds a configuration for `model` with no explicit model path.
    ///
    /// CoreML is enabled exactly when the crate is compiled for macOS.
    pub fn new(model: WhisperModel) -> Self {
        let coreml = cfg!(target_os = "macos");
        Self {
            model,
            model_path: None,
            coreml,
        }
    }

    /// Returns the configuration with `path` as the model file to load,
    /// replacing any previously set override.
    pub fn with_model_path(self, path: PathBuf) -> Self {
        Self {
            model_path: Some(path),
            ..self
        }
    }

    /// Returns the configuration with CoreML support switched on or off.
    pub fn with_coreml(self, enabled: bool) -> Self {
        Self {
            coreml: enabled,
            ..self
        }
    }
}
/// A loaded Whisper model together with the inference state created from it.
struct WhisperInstance {
// Never read directly (hence the leading underscore).
// NOTE(review): presumably held so the context stays alive for as long as
// `state` does — confirm against the whisper-rs ownership model.
_context: WhisperContext,
/// Inference state used to run transcriptions and read back the resulting segments.
state: WhisperState,
}
impl WhisperInstance {
    /// Creates an inference state from `context` and stores both in one value.
    ///
    /// # Errors
    ///
    /// Returns `TranscribeError::TranscriptionFailed` if the state cannot be created.
    fn new(context: WhisperContext) -> Result<Self> {
        let created = context.create_state();
        let state = created.map_err(|err| {
            let message = format!("Failed to create state: {}", err);
            TranscribeError::TranscriptionFailed(message)
        })?;
        Ok(Self {
            state,
            _context: context,
        })
    }
}
/// [`Transcriber`] implementation that runs Whisper locally through `whisper_rs`.
pub struct LocalWhisperClient {
/// Model selection and CoreML settings supplied at construction.
config: LocalWhisperConfig,
/// Lazily created model instance: `None` until the first successful load.
/// Wrapped in a mutex because transcription needs mutable access through `&self`.
instance: Mutex<Option<WhisperInstance>>,
}
impl LocalWhisperClient {
/// Creates a client for `config`.
///
/// No model is loaded here; the instance slot starts out empty and is filled
/// the first time an instance is requested.
pub fn new(config: LocalWhisperConfig) -> Self {
    let instance = Mutex::new(None);
    Self { config, instance }
}
/// Makes sure the CoreML encoder for the configured model is present,
/// downloading it when it is missing.
///
/// Returns immediately when CoreML is disabled in the configuration.
///
/// # Errors
///
/// Returns `TranscribeError::TranscriptionFailed` if the existence check or the
/// download fails.
#[cfg(target_os = "macos")]
async fn ensure_coreml_setup(&self) -> Result<()> {
    if !self.config.coreml {
        return Ok(());
    }
    let model = self.config.model;

    let encoder_present = coreml_encoder_exists(model)
        .map_err(|e| TranscribeError::TranscriptionFailed(e.to_string()))?;
    if encoder_present {
        info!(
            model = ?model,
            path = ?coreml_encoder_path(model).ok(),
            "CoreML encoder available"
        );
        return Ok(());
    }

    info!(
        model = ?model,
        "Downloading CoreML encoder for faster transcription..."
    );
    let download = ensure_coreml_encoder(model, |downloaded, total| {
        // Only progress values that land on a whole multiple of 10% are logged.
        let percent = (downloaded as f64 / total as f64 * 100.0) as u32;
        if percent.is_multiple_of(10) {
            debug!("CoreML encoder download: {}%", percent);
        }
    });
    download.await.map_err(|e| {
        TranscribeError::TranscriptionFailed(format!(
            "Failed to download CoreML encoder: {}",
            e
        ))
    })?;
    Ok(())
}
/// No-op counterpart of the macOS implementation: the CoreML helpers are only
/// compiled on macOS, so on every other target this always succeeds.
#[cfg(not(target_os = "macos"))]
async fn ensure_coreml_setup(&self) -> Result<()> {
Ok(())
}
/// Locks the instance slot and loads the model into it on first use.
///
/// The returned guard always holds `Some(instance)`. The model file comes from
/// `config.model_path` when set, otherwise from the default path for
/// `config.model`.
///
/// # Errors
///
/// Returns `TranscribeError::TranscriptionFailed` if the mutex is poisoned, the
/// model path cannot be resolved or is not valid UTF-8, or the model or its
/// state fails to load.
fn ensure_instance(&self) -> Result<std::sync::MutexGuard<'_, Option<WhisperInstance>>> {
    let mut slot = self.instance.lock().map_err(|e| {
        TranscribeError::TranscriptionFailed(format!("Failed to lock instance: {}", e))
    })?;

    // Already loaded by an earlier call: hand the guard straight back.
    if slot.is_some() {
        return Ok(slot);
    }

    let path = match self.config.model_path.as_ref() {
        Some(custom) => custom.clone(),
        None => model_path(self.config.model)
            .map_err(|e| TranscribeError::TranscriptionFailed(e.to_string()))?,
    };
    info!(path = ?path, "Loading Whisper model");

    let path_str = path.to_str().ok_or_else(|| {
        TranscribeError::TranscriptionFailed("Invalid model path".to_string())
    })?;
    let ctx = WhisperContext::new_with_params(path_str, WhisperContextParameters::default())
        .map_err(|e| {
            TranscribeError::TranscriptionFailed(format!("Failed to load model: {}", e))
        })?;

    let instance = WhisperInstance::new(ctx)?;
    info!("Whisper model loaded successfully");
    *slot = Some(instance);
    Ok(slot)
}
fn convert_audio(&self, audio: &[u8]) -> Result<Vec<f32>> {
use std::io::Cursor;
let cursor = Cursor::new(audio);
let reader = hound::WavReader::new(cursor).map_err(|e| {
TranscribeError::InvalidAudioFormat(format!("Failed to read WAV: {}", e))
})?;
let spec = reader.spec();
let sample_rate = spec.sample_rate;
let channels = spec.channels as usize;
debug!(
sample_rate = sample_rate,
channels = channels,
bits_per_sample = spec.bits_per_sample,
"Converting audio"
);
let samples: Vec<f32> = match spec.sample_format {
hound::SampleFormat::Float => reader
.into_samples::<f32>()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
TranscribeError::InvalidAudioFormat(format!("Failed to read samples: {}", e))
})?,
hound::SampleFormat::Int => {
let bits = spec.bits_per_sample;
let max_val = (1u32 << (bits - 1)) as f32;
reader
.into_samples::<i32>()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
TranscribeError::InvalidAudioFormat(format!(
"Failed to read samples: {}",
e
))
})?
.into_iter()
.map(|s| s as f32 / max_val)
.collect()
}
};
let original_sample_count = samples.len();
let mono_samples: Vec<f32> = if channels > 1 {
samples
.chunks(channels)
.map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
.collect()
} else {
samples
};
let target_rate = 16000;
let resampled = if sample_rate != target_rate {
resample(&mono_samples, sample_rate, target_rate)
} else {
mono_samples
};
debug!(
original_samples = original_sample_count,
resampled_samples = resampled.len(),
"Audio conversion complete"
);
Ok(resampled)
}
}
/// Resamples `samples` from `from_rate` to `to_rate` using linear interpolation.
///
/// The output has `floor(samples.len() * to_rate / from_rate)` samples. No
/// low-pass filtering is applied, so downsampling does not suppress aliasing.
///
/// Equal rates return a copy of the input. If the rates differ and either one
/// is zero, the result is empty: a zero rate has no meaningful duration, and a
/// zero `from_rate` would otherwise turn the length computation into a division
/// by zero that requests an impossibly large buffer.
fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate {
        return samples.to_vec();
    }
    if from_rate == 0 || to_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(from_rate) / f64::from(to_rate);
    let new_len = (samples.len() as f64 / ratio) as usize;

    (0..new_len)
        .map(|i| {
            // Fractional position of output sample `i` in the input.
            let src_idx = i as f64 * ratio;
            let lower = src_idx.floor() as usize;
            let frac = src_idx - lower as f64;
            match (samples.get(lower), samples.get(lower + 1)) {
                // Between two input samples: blend them by distance.
                (Some(&s0), Some(&s1)) => {
                    (f64::from(s0) * (1.0 - frac) + f64::from(s1) * frac) as f32
                }
                // On the final input sample: nothing to blend with.
                (Some(&s0), None) => s0,
                // Past the end (only reachable through rounding): pad with silence.
                _ => 0.0,
            }
        })
        .collect()
}
#[async_trait]
impl Transcriber for LocalWhisperClient {
    /// Transcribes WAV-encoded `audio` and returns the trimmed text of all segments.
    ///
    /// `language` is passed to Whisper's language setting unchanged
    /// (NOTE(review): `None` presumably lets Whisper detect the language — confirm
    /// against whisper-rs). Segments whose text cannot be read are skipped.
    ///
    /// After the initial CoreML check nothing here awaits: audio conversion,
    /// model loading and inference all run synchronously while the instance
    /// lock is held.
    ///
    /// # Errors
    ///
    /// Propagates failures from CoreML setup, audio conversion and model
    /// loading, and returns `TranscribeError::TranscriptionFailed` if inference
    /// fails.
    async fn transcribe(&self, audio: Bytes, language: Option<&str>) -> Result<String> {
        self.ensure_coreml_setup().await?;
        let pcm = self.convert_audio(&audio)?;

        let mut slot = self.ensure_instance()?;
        let whisper = slot.as_mut().expect("instance should be initialized");

        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        params.set_language(language);
        params.set_print_special(false);
        params.set_print_progress(false);
        params.set_print_realtime(false);
        params.set_print_timestamps(false);

        whisper.state.full(params, &pcm).map_err(|e| {
            TranscribeError::TranscriptionFailed(format!("Transcription failed: {}", e))
        })?;

        let segment_count = whisper.state.full_n_segments();
        let mut transcript = String::new();
        for index in 0..segment_count {
            let Some(segment) = whisper.state.get_segment(index) else {
                continue;
            };
            if let Ok(text) = segment.to_str_lossy() {
                transcript.push_str(&text);
            }
        }
        Ok(transcript.trim().to_string())
    }

    /// Identifier of this transcriber backend.
    fn name(&self) -> &str {
        "local-whisper"
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Downsampling 48 kHz to 16 kHz must keep exactly every third input sample.
#[test]
fn test_resample() {
    let samples: Vec<f32> = (0..48000).map(|i| (i as f32 / 48000.0).sin()).collect();
    let resampled = resample(&samples, 48000, 16000);
    assert_eq!(resampled.len(), 16000);
    // The 3:1 ratio is exact, so every output position lands on an input sample
    // with zero interpolation weight on its neighbour; the values must match
    // bit for bit, not just the length.
    for (i, &value) in resampled.iter().enumerate() {
        assert_eq!(
            value,
            samples[i * 3],
            "output sample {} should equal input sample {}",
            i,
            i * 3
        );
    }
}
/// A freshly built config keeps the requested model and has no path override.
#[test]
fn test_config_new() {
    let LocalWhisperConfig {
        model, model_path, ..
    } = LocalWhisperConfig::new(WhisperModel::BaseQ8_0);
    assert_eq!(model, WhisperModel::BaseQ8_0);
    assert!(model_path.is_none());
}
/// Manual end-to-end check: transcribes `test.wav` from the project root with the
/// tiny model and writes the converted audio to `test-processed.wav` for inspection.
///
/// Ignored by default, and returns early (passing) when `test.wav` is absent.
#[test]
#[ignore]
fn test_transcribe_wav() {
    use crate::transcribe::model::ensure_model;
    use std::io::Cursor;

    let project_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .unwrap()
        .parent()
        .unwrap();
    let test_file = project_root.join("test.wav");
    if !test_file.exists() {
        eprintln!("Skipping test: {} not found", test_file.display());
        eprintln!("Place a test.wav file in the project root to run this test");
        return;
    }
    let audio_data = std::fs::read(&test_file).expect("Failed to read test.wav");
    eprintln!(
        "Read {} bytes from {}",
        audio_data.len(),
        test_file.display()
    );

    let model = WhisperModel::TinyQ8_0;
    // Constructing the client does not load the model, so it can be created
    // before the model has been downloaded.
    let client = LocalWhisperClient::new(LocalWhisperConfig::new(model));

    let processed_file = project_root.join("test-processed.wav");
    {
        let spec = hound::WavReader::new(Cursor::new(&audio_data))
            .expect("Failed to read WAV")
            .spec();
        eprintln!(
            "Input: {}Hz, {} channels, {} bits",
            spec.sample_rate, spec.channels, spec.bits_per_sample
        );

        // Use the client's own conversion so the dumped file is exactly what
        // gets fed to Whisper, rather than a hand-maintained copy of the pipeline.
        let resampled = client
            .convert_audio(&audio_data)
            .expect("Failed to convert audio");
        let target_rate = 16000u32;
        eprintln!(
            "Output: {}Hz, 1 channel, {} samples ({:.2}s)",
            target_rate,
            resampled.len(),
            resampled.len() as f64 / target_rate as f64
        );

        let out_spec = hound::WavSpec {
            channels: 1,
            sample_rate: target_rate,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        };
        let mut writer =
            hound::WavWriter::create(&processed_file, out_spec).expect("Failed to create WAV");
        for sample in &resampled {
            writer
                .write_sample(*sample)
                .expect("Failed to write sample");
        }
        writer.finalize().expect("Failed to finalize WAV");
        eprintln!("Wrote processed audio to {}", processed_file.display());
    }

    eprintln!("Ensuring model {:?} is available...", model);
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        ensure_model(model, |downloaded, total| {
            let percent = (downloaded as f64 / total as f64 * 100.0) as u32;
            if percent.is_multiple_of(25) {
                eprintln!("Downloading model: {}%", percent);
            }
        })
        .await
        .expect("Failed to download model");
    });

    let result = rt.block_on(async { client.transcribe(audio_data.into(), None).await });
    match result {
        Ok(text) => {
            eprintln!("Transcription successful!");
            eprintln!("---");
            eprintln!("{}", text);
            eprintln!("---");
            assert!(!text.is_empty(), "Transcription should not be empty");
        }
        Err(e) => {
            panic!("Transcription failed: {:?}", e);
        }
    }
}
}