Skip to main content

ferrum_models/
audio_processor.rs

//! Audio preprocessing for Whisper ASR.
//!
//! Load audio files → decode → mix down to mono → resample (16kHz by default) → f32 PCM samples.
//! Supports WAV natively; M4A/MP3/FLAC/OGG via ffmpeg auto-conversion.
5
6use ferrum_types::{FerrumError, Result};
7use std::path::Path;
8
/// Number of samples in one 30-second Whisper chunk at 16kHz (480,000).
pub const CHUNK_SAMPLES: usize = 30 * 16_000;
11
12/// Load audio file and return 16kHz mono f32 PCM samples.
13///
14/// If the file is not WAV, tries ffmpeg conversion automatically.
15pub fn load_audio(path: &str) -> Result<Vec<f32>> {
16    let p = Path::new(path);
17    let ext = p
18        .extension()
19        .and_then(|e| e.to_str())
20        .unwrap_or("")
21        .to_lowercase();
22
23    // WAV: direct load
24    if ext == "wav" {
25        return load_wav_file(path);
26    }
27
28    // Non-WAV: convert via ffmpeg
29    convert_with_ffmpeg(path)
30}
31
32/// Load audio file and return mono f32 PCM samples at a configurable sample rate.
33///
34/// Similar to `load_audio` but resamples to `target_rate` instead of 16kHz.
35/// Useful for TTS speaker encoder which expects 24kHz input.
36pub fn load_audio_at_rate(path: &str, target_rate: u32) -> Result<Vec<f32>> {
37    let p = Path::new(path);
38    let ext = p
39        .extension()
40        .and_then(|e| e.to_str())
41        .unwrap_or("")
42        .to_lowercase();
43
44    // WAV: decode then resample to target_rate
45    if ext == "wav" {
46        return load_wav_file_at_rate(path, target_rate);
47    }
48
49    // Non-WAV: convert via ffmpeg to target_rate
50    convert_with_ffmpeg_at_rate(path, target_rate)
51}
52
53/// Load audio from raw bytes.
54/// Tries WAV first; if that fails and bytes look non-WAV, tries ffmpeg.
55pub fn load_audio_bytes(data: &[u8]) -> Result<Vec<f32>> {
56    // Try WAV first
57    match load_wav_bytes(data) {
58        Ok(pcm) => return Ok(pcm),
59        Err(_) => {}
60    }
61
62    // Fallback: write to temp file and convert via ffmpeg
63    let tmp = std::env::temp_dir().join("ferrum_audio_tmp");
64    std::fs::write(&tmp, data).map_err(|e| FerrumError::model(format!("write temp audio: {e}")))?;
65    let result = convert_with_ffmpeg(tmp.to_str().unwrap_or(""));
66    let _ = std::fs::remove_file(&tmp);
67    result
68}
69
70/// Split PCM samples into 30-second chunks for Whisper processing.
71pub fn chunk_pcm(pcm: &[f32]) -> Vec<&[f32]> {
72    if pcm.len() <= CHUNK_SAMPLES {
73        return vec![pcm];
74    }
75    pcm.chunks(CHUNK_SAMPLES).collect()
76}
77
78// ── WAV loading ─────────────────────────────────────────────────────────
79
80fn load_wav_file(path: &str) -> Result<Vec<f32>> {
81    let reader = hound::WavReader::open(path)
82        .map_err(|e| FerrumError::model(format!("open audio {path}: {e}")))?;
83    decode_wav(reader)
84}
85
86fn load_wav_bytes(data: &[u8]) -> Result<Vec<f32>> {
87    let cursor = std::io::Cursor::new(data);
88    let reader =
89        hound::WavReader::new(cursor).map_err(|e| FerrumError::model(format!("decode: {e}")))?;
90    decode_wav(reader)
91}
92
93fn decode_wav<R: std::io::Read>(reader: hound::WavReader<R>) -> Result<Vec<f32>> {
94    let spec = reader.spec();
95    let sample_rate = spec.sample_rate as f64;
96    let channels = spec.channels as usize;
97
98    let samples: Vec<f32> = match spec.sample_format {
99        hound::SampleFormat::Float => reader
100            .into_samples::<f32>()
101            .filter_map(|s| s.ok())
102            .collect(),
103        hound::SampleFormat::Int => {
104            let bits = spec.bits_per_sample;
105            let max_val = (1u32 << (bits - 1)) as f32;
106            reader
107                .into_samples::<i32>()
108                .filter_map(|s| s.ok())
109                .map(|s| s as f32 / max_val)
110                .collect()
111        }
112    };
113
114    // Mix to mono
115    let mono: Vec<f32> = if channels == 1 {
116        samples
117    } else {
118        samples
119            .chunks(channels)
120            .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
121            .collect()
122    };
123
124    // Resample to 16kHz if needed
125    if (sample_rate - 16000.0).abs() < 1.0 {
126        Ok(mono)
127    } else {
128        Ok(resample(&mono, sample_rate, 16000.0))
129    }
130}
131
132// ── ffmpeg conversion ───────────────────────────────────────────────────
133
134fn convert_with_ffmpeg(input_path: &str) -> Result<Vec<f32>> {
135    let output = std::env::temp_dir().join("ferrum_ffmpeg_out.wav");
136    let output_str = output.to_string_lossy().to_string();
137
138    let status = std::process::Command::new("ffmpeg")
139        .args([
140            "-y",
141            "-i",
142            input_path,
143            "-ar",
144            "16000",
145            "-ac",
146            "1",
147            "-sample_fmt",
148            "s16",
149            "-f",
150            "wav",
151            &output_str,
152        ])
153        .stdout(std::process::Stdio::null())
154        .stderr(std::process::Stdio::null())
155        .status();
156
157    match status {
158        Ok(s) if s.success() => {
159            let result = load_wav_file(&output_str);
160            let _ = std::fs::remove_file(&output);
161            result
162        }
163        Ok(s) => Err(FerrumError::model(format!(
164            "ffmpeg exited with code {}. Is the audio file valid?",
165            s.code().unwrap_or(-1)
166        ))),
167        Err(_) => Err(FerrumError::model(
168            "ffmpeg not found. Install ffmpeg to process non-WAV audio (brew install ffmpeg)",
169        )),
170    }
171}
172
173// ── WAV loading at configurable rate ─────────────────────────────────────
174
175fn load_wav_file_at_rate(path: &str, target_rate: u32) -> Result<Vec<f32>> {
176    let reader = hound::WavReader::open(path)
177        .map_err(|e| FerrumError::model(format!("open audio {path}: {e}")))?;
178    decode_wav_at_rate(reader, target_rate)
179}
180
181fn decode_wav_at_rate<R: std::io::Read>(
182    reader: hound::WavReader<R>,
183    target_rate: u32,
184) -> Result<Vec<f32>> {
185    let spec = reader.spec();
186    let sample_rate = spec.sample_rate as f64;
187    let channels = spec.channels as usize;
188
189    let samples: Vec<f32> = match spec.sample_format {
190        hound::SampleFormat::Float => reader
191            .into_samples::<f32>()
192            .filter_map(|s| s.ok())
193            .collect(),
194        hound::SampleFormat::Int => {
195            let bits = spec.bits_per_sample;
196            let max_val = (1u32 << (bits - 1)) as f32;
197            reader
198                .into_samples::<i32>()
199                .filter_map(|s| s.ok())
200                .map(|s| s as f32 / max_val)
201                .collect()
202        }
203    };
204
205    // Mix to mono
206    let mono: Vec<f32> = if channels == 1 {
207        samples
208    } else {
209        samples
210            .chunks(channels)
211            .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
212            .collect()
213    };
214
215    // Resample to target_rate if needed
216    let target = target_rate as f64;
217    if (sample_rate - target).abs() < 1.0 {
218        Ok(mono)
219    } else {
220        Ok(resample(&mono, sample_rate, target))
221    }
222}
223
224fn convert_with_ffmpeg_at_rate(input_path: &str, target_rate: u32) -> Result<Vec<f32>> {
225    let output = std::env::temp_dir().join("ferrum_ffmpeg_out_rate.wav");
226    let output_str = output.to_string_lossy().to_string();
227    let rate_str = target_rate.to_string();
228
229    let status = std::process::Command::new("ffmpeg")
230        .args([
231            "-y",
232            "-i",
233            input_path,
234            "-ar",
235            &rate_str,
236            "-ac",
237            "1",
238            "-sample_fmt",
239            "s16",
240            "-f",
241            "wav",
242            &output_str,
243        ])
244        .stdout(std::process::Stdio::null())
245        .stderr(std::process::Stdio::null())
246        .status();
247
248    match status {
249        Ok(s) if s.success() => {
250            let result = load_wav_file_at_rate(&output_str, target_rate);
251            let _ = std::fs::remove_file(&output);
252            result
253        }
254        Ok(s) => Err(FerrumError::model(format!(
255            "ffmpeg exited with code {}. Is the audio file valid?",
256            s.code().unwrap_or(-1)
257        ))),
258        Err(_) => Err(FerrumError::model(
259            "ffmpeg not found. Install ffmpeg to process non-WAV audio (brew install ffmpeg)",
260        )),
261    }
262}
263
264// ── Resampler ───────────────────────────────────────────────────────────
265
266pub(crate) fn resample(input: &[f32], from_rate: f64, to_rate: f64) -> Vec<f32> {
267    use rubato::{
268        audioadapter::Adapter, Async, FixedAsync, Resampler as RubatoResampler,
269        SincInterpolationParameters, SincInterpolationType, WindowFunction,
270    };
271
272    let ratio = to_rate / from_rate;
273    let chunk_size = 1024;
274
275    let params = SincInterpolationParameters {
276        sinc_len: 128,
277        f_cutoff: 0.95,
278        interpolation: SincInterpolationType::Linear,
279        oversampling_factor: 128,
280        window: WindowFunction::BlackmanHarris2,
281    };
282
283    let mut resampler =
284        Async::<f32>::new_sinc(ratio, 1.0, &params, chunk_size, 1, FixedAsync::Input)
285            .expect("resample init");
286
287    let mut output = Vec::new();
288    let mut pos = 0;
289    while pos < input.len() {
290        let end = (pos + chunk_size).min(input.len());
291        let chunk = &input[pos..end];
292        let data: Vec<f32> = if chunk.len() < chunk_size {
293            let mut p = chunk.to_vec();
294            p.resize(chunk_size, 0.0);
295            p
296        } else {
297            chunk.to_vec()
298        };
299
300        let input_vecs = vec![data];
301        let input_adapter =
302            audioadapter_buffers::direct::SequentialSliceOfVecs::new(&input_vecs, 1, chunk_size)
303                .expect("input adapter");
304        let result = resampler
305            .process(&input_adapter, 0, None)
306            .expect("resample");
307        let frames = result.frames();
308        for i in 0..frames {
309            output.push(result.read_sample(0, i).unwrap_or(0.0));
310        }
311        pos += chunk_size;
312    }
313    output
314}