Skip to main content

proteus_lib/dsp/effects/convolution_reverb/
impulse_response.rs

1//! Load and normalize impulse responses for convolution reverb.
2
3use std::fmt;
4use std::fs::File;
5use std::io::{BufReader, Cursor, Read, Seek};
6use std::path::Path;
7
8use log::{info, warn};
9use matroska::Matroska;
10use rodio::{Decoder, Source};
11
/// An impulse response decoded into separate per-channel sample buffers.
///
/// Channels are kept de-interleaved; consumers decide how to map them onto
/// their outputs.
#[derive(Debug, Clone)]
pub struct ImpulseResponse {
    /// Sample rate reported by the decoder for this impulse response.
    pub sample_rate: u32,
    /// One sample vector per channel.
    pub channels: Vec<Vec<f32>>,
}
20
21impl ImpulseResponse {
22    /// Return the number of channels contained in the impulse response.
23    pub fn channel_count(&self) -> usize {
24        self.channels.len()
25    }
26
27    /// Select a channel to use for the requested output index.
28    ///
29    /// Multi-channel IRs are wrapped (round-robin). Mono IRs are reused for
30    /// all outputs.
31    pub fn channel_for_output(&self, index: usize) -> &[f32] {
32        if self.channels.is_empty() {
33            return &[];
34        }
35
36        if self.channels.len() == 1 {
37            return &self.channels[0];
38        }
39
40        let channel_index = index % self.channels.len();
41        &self.channels[channel_index]
42    }
43}
44
45/// Errors that can occur while loading or decoding impulse responses.
46#[derive(Debug)]
47pub enum ImpulseResponseError {
48    Io(std::io::Error),
49    Matroska(matroska::Error),
50    Decode(rodio::decoder::DecoderError),
51    AttachmentNotFound(String),
52    InvalidChannels,
53}
54
55impl fmt::Display for ImpulseResponseError {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        match self {
58            Self::Io(err) => write!(f, "failed to read impulse response: {}", err),
59            Self::Matroska(err) => write!(f, "failed to read prot container: {}", err),
60            Self::Decode(err) => write!(f, "failed to decode impulse response: {}", err),
61            Self::AttachmentNotFound(name) => {
62                write!(f, "impulse response attachment not found: {}", name)
63            }
64            Self::InvalidChannels => write!(f, "impulse response has invalid channel count"),
65        }
66    }
67}
68
69impl std::error::Error for ImpulseResponseError {}
70
71impl From<std::io::Error> for ImpulseResponseError {
72    fn from(err: std::io::Error) -> Self {
73        Self::Io(err)
74    }
75}
76
77impl From<rodio::decoder::DecoderError> for ImpulseResponseError {
78    fn from(err: rodio::decoder::DecoderError) -> Self {
79        Self::Decode(err)
80    }
81}
82
83impl From<matroska::Error> for ImpulseResponseError {
84    fn from(err: matroska::Error) -> Self {
85        Self::Matroska(err)
86    }
87}
88
89/// Load an impulse response from a file path.
90///
91/// # Example
92/// ```no_run
93/// use proteus_lib::dsp::effects::convolution_reverb::impulse_response::load_impulse_response_from_file;
94///
95/// let ir = load_impulse_response_from_file("ir.wav").unwrap();
96/// assert!(ir.channel_count() > 0);
97/// ```
98pub fn load_impulse_response_from_file(
99    path: impl AsRef<Path>,
100) -> Result<ImpulseResponse, ImpulseResponseError> {
101    load_impulse_response_from_file_with_tail(path, Some(-60.0))
102}
103
104/// Load an impulse response from disk and optionally trim its tail in dB.
105pub fn load_impulse_response_from_file_with_tail(
106    path: impl AsRef<Path>,
107    tail_db: Option<f32>,
108) -> Result<ImpulseResponse, ImpulseResponseError> {
109    let file = File::open(path)?;
110    decode_impulse_response(BufReader::new(file), tail_db)
111}
112
113/// Load an impulse response from in-memory audio bytes.
114pub fn load_impulse_response_from_bytes(
115    bytes: &[u8],
116) -> Result<ImpulseResponse, ImpulseResponseError> {
117    load_impulse_response_from_bytes_with_tail(bytes, Some(-60.0))
118}
119
120/// Load an impulse response from bytes and optionally trim its tail in dB.
121pub fn load_impulse_response_from_bytes_with_tail(
122    bytes: &[u8],
123    tail_db: Option<f32>,
124) -> Result<ImpulseResponse, ImpulseResponseError> {
125    decode_impulse_response(BufReader::new(Cursor::new(bytes.to_vec())), tail_db)
126}
127
128/// Load an impulse response attachment from a `.prot`/`.mka` container.
129pub fn load_impulse_response_from_prot_attachment(
130    prot_path: impl AsRef<Path>,
131    attachment_name: &str,
132) -> Result<ImpulseResponse, ImpulseResponseError> {
133    load_impulse_response_from_prot_attachment_with_tail(prot_path, attachment_name, Some(-60.0))
134}
135
136/// Load a container attachment and optionally trim its tail in dB.
137pub fn load_impulse_response_from_prot_attachment_with_tail(
138    prot_path: impl AsRef<Path>,
139    attachment_name: &str,
140    tail_db: Option<f32>,
141) -> Result<ImpulseResponse, ImpulseResponseError> {
142    let file = File::open(prot_path)?;
143    let mka: Matroska = Matroska::open(file)?;
144
145    let attachment = mka
146        .attachments
147        .iter()
148        .find(|attachment| attachment.name.trim_matches('"') == attachment_name)
149        .ok_or_else(|| ImpulseResponseError::AttachmentNotFound(attachment_name.to_string()))?;
150
151    info!("Loading impulse bytes response from {}", attachment.name);
152
153    load_impulse_response_from_bytes_with_tail(&attachment.data, tail_db)
154}
155
156fn decode_impulse_response<R>(
157    reader: R,
158    tail_db: Option<f32>,
159) -> Result<ImpulseResponse, ImpulseResponseError>
160where
161    R: Read + Seek + Send + Sync + 'static,
162{
163    let source = Decoder::new(reader)?;
164    let channels = source.channels() as usize;
165    if channels == 0 {
166        return Err(ImpulseResponseError::InvalidChannels);
167    }
168
169    let sample_rate = source.sample_rate();
170    let mut channel_samples = vec![Vec::new(); channels];
171
172    for (index, sample) in source.enumerate() {
173        channel_samples[index % channels].push(sample as f32);
174    }
175
176    normalize_impulse_response_channels(&mut channel_samples, tail_db);
177
178    if channel_samples.iter().any(|channel| channel.is_empty()) {
179        warn!("Impulse response includes empty channels; results may be silent.");
180    }
181
182    Ok(ImpulseResponse {
183        sample_rate,
184        channels: channel_samples,
185    })
186}
187
188/// Normalize and optionally trim channels in-place.
189///
190/// Normalization is a two-step process:
191/// 1. Peak normalization to avoid clipping.
192/// 2. Energy normalization (attenuation-only) to keep long IRs controlled.
193pub fn normalize_impulse_response_channels(channel_samples: &mut [Vec<f32>], tail_db: Option<f32>) {
194    let mut max_abs = 0.0_f32;
195    for channel in channel_samples.iter() {
196        for sample in channel {
197            let abs = sample.abs();
198            if abs > max_abs {
199                max_abs = abs;
200            }
201        }
202    }
203
204    if max_abs > 0.0 {
205        let scale = 1.0 / max_abs;
206        for channel in channel_samples.iter_mut() {
207            for sample in channel {
208                *sample *= scale;
209            }
210        }
211    }
212
213    if let Some(tail_db) = tail_db {
214        if tail_db.is_finite() {
215            trim_impulse_response_tail(channel_samples, tail_db);
216        }
217    }
218
219    // Energy-normalize (attenuate only) so long IRs don't explode the wet gain.
220    let mut max_energy = 0.0_f32;
221    for channel in channel_samples.iter() {
222        let mut sum_sq = 0.0_f32;
223        for sample in channel {
224            sum_sq += sample * sample;
225        }
226        if sum_sq > max_energy {
227            max_energy = sum_sq;
228        }
229    }
230    if max_energy > 0.0 {
231        let mut scale = 1.0_f32 / max_energy.sqrt();
232        if scale > 1.0 {
233            scale = 1.0;
234        }
235        if scale < 1.0 {
236            for channel in channel_samples.iter_mut() {
237                for sample in channel {
238                    *sample *= scale;
239                }
240            }
241        }
242    }
243}
244
/// Truncate all channels to a common length that drops the tail below `tail_db`.
///
/// `tail_db` is converted to a linear amplitude threshold `10^(tail_db / 20)`.
/// The kept length is one past the latest index, across all channels, whose
/// absolute sample value is at least that threshold. Every channel longer than
/// that is truncated to it.
///
/// If no sample in any channel reaches the threshold, the channels are left
/// untouched rather than collapsed to a single sample. This covers a silent IR
/// and a threshold above every sample's magnitude.
fn trim_impulse_response_tail(channels: &mut [Vec<f32>], tail_db: f32) {
    let threshold = 10.0_f32.powf(tail_db / 20.0);
    // Extremely negative dB levels underflow to 0.0, which would keep everything anyway.
    if threshold <= 0.0 {
        return;
    }

    // Search each channel from the end: only its last loud sample matters.
    let last_loud_index = channels
        .iter()
        .filter_map(|channel| channel.iter().rposition(|sample| sample.abs() >= threshold))
        .max();

    if let Some(last_index) = last_loud_index {
        let keep_len = last_index + 1;
        for channel in channels.iter_mut() {
            // `truncate` is a no-op for channels already at or below `keep_len`.
            channel.truncate(keep_len);
        }
    }
}