use crate::audio_extraction::AudioFile;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use thiserror::Error;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
/// Errors produced while turning an audio file into text with Whisper.
#[derive(Debug, Error)]
pub enum SpeechToTextError {
/// The Whisper model file could not be loaded (non-UTF-8 path or a
/// whisper.cpp load failure).
#[error("Failed to load Whisper model from {path}: {message}")]
ModelLoadFailed { path: PathBuf, message: String },
/// The WAV file could not be opened or its samples could not be decoded.
#[error("Failed to read audio file {path}: {message}")]
AudioReadFailed { path: PathBuf, message: String },
/// The audio is readable but not in the format Whisper requires
/// (16 kHz mono), or sample conversion failed.
#[error("Invalid audio format: {0}")]
InvalidAudioFormat(String),
/// whisper.cpp failed during state creation or the inference run.
#[error("Transcription failed: {0}")]
TranscriptionFailed(String),
/// The language ID returned after inference did not map to a known language.
#[error("Failed to detect language: invalid language ID {0}")]
LanguageDetectionFailed(i32),
/// Reserved for call sites that require a pre-initialized model.
/// NOTE(review): not constructed anywhere in this file — verify callers.
#[error("Whisper model not initialized")]
ModelNotInitialized,
}
/// Result of a successful transcription run.
///
/// Field order is significant for serde-derived (de)serialization — do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Transcript {
// Full transcript text, whitespace-trimmed, segments concatenated in order.
pub text: String,
// Language code reported by whisper (e.g. "en") via `get_lang_str`.
pub language: String,
}
pub(crate) fn audio_to_text(
audio: &AudioFile,
model_path: &Path,
) -> Result<Transcript, SpeechToTextError> {
whisper_rs::install_logging_hooks();
let mut params = WhisperContextParameters::default();
params.use_gpu(true);
let ctx = WhisperContext::new_with_params(
model_path
.to_str()
.ok_or_else(|| SpeechToTextError::ModelLoadFailed {
path: model_path.to_path_buf(),
message: "Invalid UTF-8 in model path".to_string(),
})?,
params,
)
.map_err(|e| SpeechToTextError::ModelLoadFailed {
path: model_path.to_path_buf(),
message: e.to_string(),
})?;
let reader =
hound::WavReader::open(audio.deref()).map_err(|e| SpeechToTextError::AudioReadFailed {
path: audio.deref().to_path_buf(),
message: e.to_string(),
})?;
let spec = reader.spec();
if spec.sample_rate != 16000 {
return Err(SpeechToTextError::InvalidAudioFormat(format!(
"Expected 16kHz sample rate, got {} Hz",
spec.sample_rate
)));
}
if spec.channels != 1 {
return Err(SpeechToTextError::InvalidAudioFormat(format!(
"Expected mono audio (1 channel), got {} channels",
spec.channels
)));
}
let samples: Vec<i16> = reader
.into_samples::<i16>()
.collect::<Result<Vec<i16>, _>>()
.map_err(|e| SpeechToTextError::AudioReadFailed {
path: audio.deref().to_path_buf(),
message: e.to_string(),
})?;
let mut audio_data = vec![0.0f32; samples.len()];
whisper_rs::convert_integer_to_float_audio(&samples, &mut audio_data)
.map_err(|e| SpeechToTextError::InvalidAudioFormat(e.to_string()))?;
drop(samples);
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
let mut state = ctx.create_state().map_err(|e| {
SpeechToTextError::TranscriptionFailed(format!("Failed to create state: {}", e))
})?;
state
.full(params, &audio_data[..])
.map_err(|e| SpeechToTextError::TranscriptionFailed(e.to_string()))?;
drop(audio_data);
let lang_id = state.full_lang_id_from_state();
let language = whisper_rs::get_lang_str(lang_id)
.ok_or(SpeechToTextError::LanguageDetectionFailed(lang_id))?
.to_string();
let mut text = String::new();
for segment in state.as_iter() {
text.push_str(&format!("{}", segment));
}
Ok(Transcript {
text: text.trim().to_string(),
language,
})
}