use crate::config::Config;
use crate::error::TranscribeError;
use std::path::Path;
#[cfg(feature = "whisper")]
use std::path::PathBuf;
pub fn transcribe(audio_path: &Path, config: &Config) -> Result<String, TranscribeError> {
let samples = load_audio_samples(audio_path)?;
if samples.is_empty() {
return Err(TranscribeError::EmptyAudio);
}
#[cfg(feature = "whisper")]
{
transcribe_with_whisper(&samples, audio_path, config)
}
#[cfg(not(feature = "whisper"))]
{
let _ = config; let duration_secs = samples.len() as f64 / 16000.0;
Ok(format!(
"[Transcription placeholder — whisper feature not enabled]\n\
Audio file: {}\n\
Duration: {:.1}s ({} samples at 16kHz)\n\
\n\
Build with `cargo build --features whisper` and download a model\n\
via `minutes setup` to enable real transcription.",
audio_path.display(),
duration_secs,
samples.len(),
))
}
}
#[cfg(feature = "whisper")]
fn transcribe_with_whisper(
samples: &[f32],
_audio_path: &Path,
config: &Config,
) -> Result<String, TranscribeError> {
#[cfg(not(feature = "whisper"))]
let _ = config;
let model_path = resolve_model_path(config)?;
tracing::info!(model = %model_path.display(), "loading whisper model");
let ctx = whisper_rs::WhisperContext::new_with_params(
model_path
.to_str()
.ok_or_else(|| TranscribeError::ModelLoadError("invalid model path encoding".into()))?,
whisper_rs::WhisperContextParameters::default(),
)
.map_err(|e| TranscribeError::ModelLoadError(format!("{}", e)))?;
tracing::info!(
samples = samples.len(),
duration_secs = samples.len() as f64 / 16000.0,
"starting whisper transcription"
);
let mut state = ctx
.create_state()
.map_err(|e| TranscribeError::TranscriptionFailed(format!("create state: {}", e)))?;
let mut params =
whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
params.set_n_threads(num_cpus());
params.set_language(config.transcription.language.as_deref());
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
params.set_token_timestamps(true);
state
.full(params, samples)
.map_err(|e| TranscribeError::TranscriptionFailed(format!("{}", e)))?;
let num_segments = state.full_n_segments();
let mut transcript = String::new();
for i in 0..num_segments {
let segment = match state.get_segment(i) {
Some(seg) => seg,
None => continue,
};
let start_ts = segment.start_timestamp();
let text = segment
.to_str_lossy()
.map_err(|e| TranscribeError::TranscriptionFailed(format!("get text: {}", e)))?;
let text = text.trim();
if text.is_empty() {
continue;
}
let mins = start_ts / 6000;
let secs = (start_ts % 6000) / 100;
transcript.push_str(&format!("[{}:{:02}] {}\n", mins, secs, text));
}
let word_count = transcript.split_whitespace().count();
tracing::info!(
segments = num_segments,
words = word_count,
"transcription complete"
);
Ok(transcript)
}
fn load_audio_samples(path: &Path) -> Result<Vec<f32>, TranscribeError> {
let ext = path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("")
.to_lowercase();
match ext.as_str() {
"wav" => load_wav(path),
"m4a" | "mp3" | "ogg" | "webm" | "mp4" | "aac" => decode_with_symphonia(path),
other => Err(TranscribeError::UnsupportedFormat(other.to_string())),
}
}
fn load_wav(path: &Path) -> Result<Vec<f32>, TranscribeError> {
let reader = hound::WavReader::open(path).map_err(|e| {
if e.to_string().contains("Not a WAVE file") || e.to_string().contains("unexpected EOF") {
TranscribeError::UnsupportedFormat("corrupt or invalid WAV file".into())
} else {
TranscribeError::Io(std::io::Error::other(e.to_string()))
}
})?;
let spec = reader.spec();
let sample_rate = spec.sample_rate;
let channels = spec.channels as usize;
let bits = spec.bits_per_sample;
let max_val = (1_i64 << (bits - 1)) as f32; let raw_samples: Vec<f32> = match spec.sample_format {
hound::SampleFormat::Int => reader
.into_samples::<i32>()
.filter_map(|s| s.ok())
.map(|s| s as f32 / max_val)
.collect(),
hound::SampleFormat::Float => reader
.into_samples::<f32>()
.filter_map(|s| s.ok())
.collect(),
};
if raw_samples.is_empty() {
return Err(TranscribeError::EmptyAudio);
}
let mono = if channels > 1 {
raw_samples
.chunks(channels)
.map(|frame| frame.iter().sum::<f32>() / channels as f32)
.collect()
} else {
raw_samples
};
let resampled = if sample_rate != 16000 {
resample(&mono, sample_rate, 16000)
} else {
mono
};
Ok(normalize_audio(resampled))
}
fn decode_with_symphonia(path: &Path) -> Result<Vec<f32>, TranscribeError> {
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
let file = std::fs::File::open(path)?;
let mss = MediaSourceStream::new(Box::new(file), Default::default());
let mut hint = Hint::new();
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
hint.with_extension(ext);
}
let format_opts = FormatOptions::default();
let metadata_opts = MetadataOptions::default();
let probed = symphonia::default::get_probe()
.format(&hint, mss, &format_opts, &metadata_opts)
.map_err(|e| TranscribeError::UnsupportedFormat(format!("probe failed: {}", e)))?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.ok_or_else(|| TranscribeError::UnsupportedFormat("no audio track found".into()))?;
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(1);
let decoder_opts = DecoderOptions::default();
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &decoder_opts)
.map_err(|e| TranscribeError::UnsupportedFormat(format!("decoder: {}", e)))?;
let mut all_samples: Vec<f32> = Vec::new();
loop {
let packet = match format.next_packet() {
Ok(packet) => packet,
Err(symphonia::core::errors::Error::IoError(ref e))
if e.kind() == std::io::ErrorKind::UnexpectedEof =>
{
break; }
Err(_) => break,
};
if packet.track_id() != track_id {
continue;
}
let decoded = match decoder.decode(&packet) {
Ok(decoded) => decoded,
Err(_) => continue, };
let spec = *decoded.spec();
let duration = decoded.capacity();
let mut sample_buf = SampleBuffer::<f32>::new(duration as u64, spec);
sample_buf.copy_interleaved_ref(decoded);
let samples = sample_buf.samples();
if channels > 1 {
for chunk in samples.chunks(channels) {
let mono_sample = chunk.iter().sum::<f32>() / channels as f32;
all_samples.push(mono_sample);
}
} else {
all_samples.extend_from_slice(samples);
}
}
if all_samples.is_empty() {
return Err(TranscribeError::EmptyAudio);
}
let resampled = if sample_rate != 16000 {
resample(&all_samples, sample_rate, 16000)
} else {
all_samples
};
Ok(normalize_audio(resampled))
}
fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
if from_rate == to_rate {
return samples.to_vec();
}
let ratio = from_rate as f64 / to_rate as f64;
let output_len = (samples.len() as f64 / ratio) as usize;
let mut output = Vec::with_capacity(output_len);
for i in 0..output_len {
let src_pos = i as f64 * ratio;
let src_idx = src_pos as usize;
let frac = src_pos - src_idx as f64;
let sample = if src_idx + 1 < samples.len() {
samples[src_idx] as f64 * (1.0 - frac) + samples[src_idx + 1] as f64 * frac
} else {
samples[src_idx] as f64
};
output.push(sample as f32);
}
output
}
fn normalize_audio(mut samples: Vec<f32>) -> Vec<f32> {
if samples.is_empty() {
return samples;
}
let peak = samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
const TARGET_PEAK: f32 = 0.5;
const QUIET_THRESHOLD: f32 = 0.1;
const NOISE_FLOOR: f32 = 0.0001;
if peak < QUIET_THRESHOLD && peak > NOISE_FLOOR {
let gain = TARGET_PEAK / peak;
let gain = gain.min(100.0);
tracing::info!(
peak = format!("{:.4}", peak),
gain = format!("{:.1}x", gain),
"auto-normalizing quiet audio"
);
for s in &mut samples {
*s = (*s * gain).clamp(-1.0, 1.0);
}
}
samples
}
#[cfg(feature = "whisper")]
pub fn resolve_model_path_for_dictation(config: &Config) -> Result<PathBuf, TranscribeError> {
let model_name = &config.dictation.model;
let model_dir = &config.transcription.model_path;
let candidates = [
model_dir.join(format!("ggml-{}.bin", model_name)),
model_dir.join(format!("whisper-{}.bin", model_name)),
model_dir.join(format!("{}.bin", model_name)),
];
for candidate in &candidates {
if candidate.exists() {
return Ok(candidate.clone());
}
}
let direct = PathBuf::from(model_name);
if direct.exists() {
return Ok(direct);
}
Err(TranscribeError::ModelNotFound(format!(
"Expected model file \"ggml-{}.bin\" in {}",
model_name,
model_dir.display(),
)))
}
#[cfg(feature = "whisper")]
fn resolve_model_path(config: &Config) -> Result<PathBuf, TranscribeError> {
let model_name = &config.transcription.model;
let model_dir = &config.transcription.model_path;
let candidates = [
model_dir.join(format!("ggml-{}.bin", model_name)),
model_dir.join(format!("whisper-{}.bin", model_name)),
model_dir.join(format!("{}.bin", model_name)),
];
for candidate in &candidates {
if candidate.exists() {
return Ok(candidate.clone());
}
}
let direct = PathBuf::from(model_name);
if direct.exists() {
return Ok(direct);
}
Err(TranscribeError::ModelNotFound(format!(
"Expected model file \"ggml-{}.bin\" in {}",
model_name,
model_dir.display(),
)))
}
#[cfg(feature = "whisper")]
fn num_cpus() -> i32 {
std::thread::available_parallelism()
.map(|p| p.get() as i32)
.unwrap_or(4)
.min(8) }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resample_preserves_length_proportionally() {
let samples: Vec<f32> = (0..44100).map(|i| (i as f32 / 44100.0).sin()).collect();
let resampled = resample(&samples, 44100, 16000);
let expected = 16000;
assert!(
(resampled.len() as i64 - expected as i64).unsigned_abs() < 10,
"expected ~{} samples, got {}",
expected,
resampled.len()
);
}
#[test]
fn resample_noop_at_same_rate() {
let samples = vec![1.0f32, 2.0, 3.0, 4.0];
let resampled = resample(&samples, 16000, 16000);
assert_eq!(samples, resampled);
}
#[test]
fn normalize_boosts_quiet_audio() {
let samples = vec![0.005f32, -0.008, 0.01, -0.003, 0.007];
let normalized = normalize_audio(samples);
let peak = normalized.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
assert!(peak > 0.4, "expected peak > 0.4, got {}", peak);
assert!(peak <= 0.5, "expected peak <= 0.5, got {}", peak);
}
#[test]
fn normalize_leaves_loud_audio_untouched() {
let samples = vec![0.3f32, -0.5, 0.2, -0.1];
let normalized = normalize_audio(samples.clone());
assert_eq!(samples, normalized);
}
#[test]
fn normalize_ignores_noise_floor() {
let samples = vec![0.00001f32, -0.00002, 0.00001];
let normalized = normalize_audio(samples.clone());
assert_eq!(samples, normalized);
}
#[test]
#[cfg(feature = "whisper")]
fn resolve_model_path_returns_error_for_missing() {
let config = Config {
transcription: crate::config::TranscriptionConfig {
model: "nonexistent".into(),
model_path: PathBuf::from("/tmp/no-such-dir"),
min_words: 10,
language: Some("en".into()),
},
..Config::default()
};
let result = resolve_model_path(&config);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("minutes setup --model tiny"),
"error should tell user how to fix it: {}",
err
);
assert!(
err.contains("ggml-nonexistent.bin"),
"error should include expected model filename: {}",
err
);
assert!(
err.contains("/tmp/no-such-dir"),
"error should include the model directory: {}",
err
);
}
#[test]
fn load_wav_rejects_empty_file() {
let dir = tempfile::TempDir::new().unwrap();
let path = dir.path().join("empty.wav");
std::fs::write(&path, "").unwrap();
let result = load_wav(&path);
assert!(result.is_err());
}
#[test]
fn load_wav_reads_valid_wav() {
let dir = tempfile::TempDir::new().unwrap();
let path = dir.path().join("test.wav");
let spec = hound::WavSpec {
channels: 1,
sample_rate: 16000,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
let mut writer = hound::WavWriter::create(&path, spec).unwrap();
for i in 0..16000 {
let sample =
(10000.0 * (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 16000.0).sin()) as i16;
writer.write_sample(sample).unwrap();
}
writer.finalize().unwrap();
let samples = load_wav(&path).unwrap();
assert!(!samples.is_empty());
assert_eq!(samples.len(), 16000);
}
#[test]
fn load_audio_rejects_unknown_extension() {
let dir = tempfile::TempDir::new().unwrap();
let path = dir.path().join("test.xyz");
std::fs::write(&path, "not audio").unwrap();
let result = load_audio_samples(&path);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("xyz"));
}
}