Skip to main content

ferrum_models/
audio_processor.rs

1//! Audio preprocessing for Whisper ASR.
2//!
3//! Load audio files → decode → resample to 16kHz mono → f32 PCM samples.
4//! Pure-Rust pipeline via `symphonia` — no ffmpeg runtime dependency.
5//! Supports WAV / MP3 / FLAC / M4A (AAC) / OGG (Vorbis).
6
7use ferrum_types::{FerrumError, Result};
8use std::path::Path;
9
/// Length of one Whisper input window in samples: 30 seconds of audio at a
/// 16 kHz sample rate, i.e. 480,000 samples.
pub const CHUNK_SAMPLES: usize = 30 * 16_000;
12
13/// Load audio file and return 16kHz mono f32 PCM samples.
14pub fn load_audio(path: &str) -> Result<Vec<f32>> {
15    load_audio_at_rate(path, 16000)
16}
17
18/// Load audio file and return mono f32 PCM samples at a configurable sample rate.
19///
20/// Accepts WAV / MP3 / FLAC / M4A (AAC) / OGG (Vorbis). Non-mono sources are
21/// downmixed by averaging channels; sample rate is converted via sinc resampling.
22/// Useful for TTS speaker encoder which expects 24kHz input.
23pub fn load_audio_at_rate(path: &str, target_rate: u32) -> Result<Vec<f32>> {
24    let file = std::fs::File::open(path)
25        .map_err(|e| FerrumError::model(format!("open audio {path}: {e}")))?;
26    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(file), Default::default());
27
28    // Hint the decoder with the file extension — purely an optimisation; the
29    // probe still content-sniffs if the hint is absent or wrong.
30    let mut hint = symphonia::core::probe::Hint::new();
31    if let Some(ext) = Path::new(path).extension().and_then(|e| e.to_str()) {
32        hint.with_extension(&ext.to_lowercase());
33    }
34
35    decode_with_symphonia(mss, &hint, target_rate)
36}
37
38/// Load audio from raw bytes (used by the HTTP multipart endpoint).
39///
40/// Content-sniffs the format; supports the same codec set as `load_audio`.
41pub fn load_audio_bytes(data: &[u8]) -> Result<Vec<f32>> {
42    let cursor = std::io::Cursor::new(data.to_vec());
43    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(cursor), Default::default());
44    let hint = symphonia::core::probe::Hint::new();
45    decode_with_symphonia(mss, &hint, 16000)
46}
47
48/// Split PCM samples into 30-second chunks for Whisper processing.
49pub fn chunk_pcm(pcm: &[f32]) -> Vec<&[f32]> {
50    if pcm.len() <= CHUNK_SAMPLES {
51        return vec![pcm];
52    }
53    pcm.chunks(CHUNK_SAMPLES).collect()
54}
55
56// ── symphonia-based decoding ────────────────────────────────────────────
57
58fn decode_with_symphonia(
59    mss: symphonia::core::io::MediaSourceStream,
60    hint: &symphonia::core::probe::Hint,
61    target_rate: u32,
62) -> Result<Vec<f32>> {
63    use symphonia::core::codecs::DecoderOptions;
64    use symphonia::core::errors::Error as SymError;
65    use symphonia::core::formats::FormatOptions;
66    use symphonia::core::meta::MetadataOptions;
67
68    let fmt_opts: FormatOptions = Default::default();
69    let meta_opts: MetadataOptions = Default::default();
70    let probed = symphonia::default::get_probe()
71        .format(hint, mss, &fmt_opts, &meta_opts)
72        .map_err(|e| FerrumError::model(format!("probe audio: {e}")))?;
73
74    let mut format = probed.format;
75
76    // First audio track wins. Whisper-friendly files are mono or stereo; we
77    // downmix to mono below.
78    let track = format
79        .tracks()
80        .iter()
81        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
82        .ok_or_else(|| FerrumError::model("no audio track in file"))?;
83    let track_id = track.id;
84
85    let source_rate = track
86        .codec_params
87        .sample_rate
88        .ok_or_else(|| FerrumError::model("audio track missing sample rate"))?;
89    let channels = track
90        .codec_params
91        .channels
92        .map(|c| c.count())
93        .unwrap_or(1)
94        .max(1);
95
96    let dec_opts: DecoderOptions = Default::default();
97    let mut decoder = symphonia::default::get_codecs()
98        .make(&track.codec_params, &dec_opts)
99        .map_err(|e| FerrumError::model(format!("decoder init: {e}")))?;
100
101    // Accumulate interleaved f32 samples across all decoded packets, then
102    // downmix + resample once at the end. Cheaper in wall time than per-packet
103    // resampling for short files and simpler to reason about.
104    let mut interleaved: Vec<f32> = Vec::new();
105
106    loop {
107        let packet = match format.next_packet() {
108            Ok(p) => p,
109            Err(SymError::IoError(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
110            Err(SymError::ResetRequired) => break,
111            Err(e) => return Err(FerrumError::model(format!("next_packet: {e}"))),
112        };
113        if packet.track_id() != track_id {
114            continue;
115        }
116
117        match decoder.decode(&packet) {
118            Ok(decoded) => append_interleaved_f32(&decoded, &mut interleaved, channels),
119            Err(SymError::DecodeError(_)) => continue, // skip corrupt packet
120            Err(SymError::IoError(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
121            Err(e) => return Err(FerrumError::model(format!("decode: {e}"))),
122        }
123    }
124
125    // Downmix to mono (average channels)
126    let mono: Vec<f32> = if channels == 1 {
127        interleaved
128    } else {
129        interleaved
130            .chunks(channels)
131            .map(|c| c.iter().sum::<f32>() / channels as f32)
132            .collect()
133    };
134
135    // Resample to target_rate if needed
136    if source_rate == target_rate {
137        Ok(mono)
138    } else {
139        Ok(resample(&mono, source_rate as f64, target_rate as f64))
140    }
141}
142
143/// Convert any symphonia AudioBufferRef variant into interleaved f32 samples
144/// and append to `out`. Handles U8 / S16 / S24 / S32 / F32 / F64 — which
145/// covers every codec we compile in.
146fn append_interleaved_f32(
147    buf: &symphonia::core::audio::AudioBufferRef<'_>,
148    out: &mut Vec<f32>,
149    channels: usize,
150) {
151    use symphonia::core::audio::{AudioBuffer, AudioBufferRef, Signal};
152    use symphonia::core::conv::IntoSample;
153
154    fn push_interleaved<S>(buf: &AudioBuffer<S>, out: &mut Vec<f32>, channels: usize)
155    where
156        S: symphonia::core::sample::Sample + IntoSample<f32> + Copy,
157    {
158        let frames = buf.frames();
159        out.reserve(frames * channels);
160        for frame in 0..frames {
161            for ch in 0..channels {
162                let s: S = buf.chan(ch)[frame];
163                out.push(s.into_sample());
164            }
165        }
166    }
167
168    match buf {
169        AudioBufferRef::U8(b) => push_interleaved(b, out, channels),
170        AudioBufferRef::S16(b) => push_interleaved(b, out, channels),
171        AudioBufferRef::S24(b) => push_interleaved(b, out, channels),
172        AudioBufferRef::S32(b) => push_interleaved(b, out, channels),
173        AudioBufferRef::F32(b) => push_interleaved(b, out, channels),
174        AudioBufferRef::F64(b) => push_interleaved(b, out, channels),
175        _ => {
176            // U16, S8 etc. not exposed by the codecs we compile in; silently
177            // skip rather than add dead conversion code.
178        }
179    }
180}
181
182// ── Resampler ───────────────────────────────────────────────────────────
183
184pub(crate) fn resample(input: &[f32], from_rate: f64, to_rate: f64) -> Vec<f32> {
185    use rubato::{
186        audioadapter::Adapter, Async, FixedAsync, Resampler as RubatoResampler,
187        SincInterpolationParameters, SincInterpolationType, WindowFunction,
188    };
189
190    let ratio = to_rate / from_rate;
191    let chunk_size = 1024;
192
193    let params = SincInterpolationParameters {
194        sinc_len: 128,
195        f_cutoff: 0.95,
196        interpolation: SincInterpolationType::Linear,
197        oversampling_factor: 128,
198        window: WindowFunction::BlackmanHarris2,
199    };
200
201    let mut resampler =
202        Async::<f32>::new_sinc(ratio, 1.0, &params, chunk_size, 1, FixedAsync::Input)
203            .expect("resample init");
204
205    let mut output = Vec::new();
206    let mut pos = 0;
207    while pos < input.len() {
208        let end = (pos + chunk_size).min(input.len());
209        let chunk = &input[pos..end];
210        let data: Vec<f32> = if chunk.len() < chunk_size {
211            let mut p = chunk.to_vec();
212            p.resize(chunk_size, 0.0);
213            p
214        } else {
215            chunk.to_vec()
216        };
217
218        let input_vecs = vec![data];
219        let input_adapter =
220            audioadapter_buffers::direct::SequentialSliceOfVecs::new(&input_vecs, 1, chunk_size)
221                .expect("input adapter");
222        let result = resampler
223            .process(&input_adapter, 0, None)
224            .expect("resample");
225        let frames = result.frames();
226        for i in 0..frames {
227            output.push(result.read_sample(0, i).unwrap_or(0.0));
228        }
229        pos += chunk_size;
230    }
231    output
232}