Skip to main content

proteus_lib/dsp/
impulse_response.rs

1//! Load and normalize impulse responses for convolution reverb.
2
3use std::fmt;
4use std::fs::File;
5use std::io::{BufReader, Cursor, Read, Seek};
6use std::path::Path;
7
8use log::warn;
9use matroska::Matroska;
10use rodio::{Decoder, Source};
11
/// Decoded impulse response audio data.
///
/// Samples are stored per channel (de-interleaved); any interleaving is left
/// to consumers.
#[derive(Debug, Clone, PartialEq)]
pub struct ImpulseResponse {
    /// Sample rate in samples per second.
    pub sample_rate: u32,
    /// One sample buffer per channel.
    pub channels: Vec<Vec<f32>>,
}

impl ImpulseResponse {
    /// Return the number of channels contained in the impulse response.
    pub fn channel_count(&self) -> usize {
        self.channels.len()
    }

    /// Select the IR channel to use for output channel `index`.
    ///
    /// Channels are reused round-robin (`index % channel_count`), so a mono IR
    /// is shared by every output and a multi-channel IR wraps around. Returns
    /// an empty slice when the IR has no channels.
    pub fn channel_for_output(&self, index: usize) -> &[f32] {
        match self.channels.len() {
            // Guard the modulo below against division by zero.
            0 => &[],
            count => self.channels[index % count].as_slice(),
        }
    }
}
44
45/// Errors that can occur while loading or decoding impulse responses.
46#[derive(Debug)]
47pub enum ImpulseResponseError {
48    Io(std::io::Error),
49    Matroska(matroska::Error),
50    Decode(rodio::decoder::DecoderError),
51    AttachmentNotFound(String),
52    InvalidChannels,
53}
54
55impl fmt::Display for ImpulseResponseError {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        match self {
58            Self::Io(err) => write!(f, "failed to read impulse response: {}", err),
59            Self::Matroska(err) => write!(f, "failed to read prot container: {}", err),
60            Self::Decode(err) => write!(f, "failed to decode impulse response: {}", err),
61            Self::AttachmentNotFound(name) => {
62                write!(f, "impulse response attachment not found: {}", name)
63            }
64            Self::InvalidChannels => write!(f, "impulse response has invalid channel count"),
65        }
66    }
67}
68
69impl std::error::Error for ImpulseResponseError {}
70
71impl From<std::io::Error> for ImpulseResponseError {
72    fn from(err: std::io::Error) -> Self {
73        Self::Io(err)
74    }
75}
76
77impl From<rodio::decoder::DecoderError> for ImpulseResponseError {
78    fn from(err: rodio::decoder::DecoderError) -> Self {
79        Self::Decode(err)
80    }
81}
82
83impl From<matroska::Error> for ImpulseResponseError {
84    fn from(err: matroska::Error) -> Self {
85        Self::Matroska(err)
86    }
87}
88
89/// Load an impulse response from a file path.
90///
91/// # Example
92/// ```no_run
93/// use proteus_lib::dsp::impulse_response::load_impulse_response_from_file;
94///
95/// let ir = load_impulse_response_from_file("ir.wav").unwrap();
96/// assert!(ir.channel_count() > 0);
97/// ```
98pub fn load_impulse_response_from_file(
99    path: impl AsRef<Path>,
100) -> Result<ImpulseResponse, ImpulseResponseError> {
101    load_impulse_response_from_file_with_tail(path, Some(-60.0))
102}
103
104/// Load an impulse response from disk and optionally trim its tail in dB.
105pub fn load_impulse_response_from_file_with_tail(
106    path: impl AsRef<Path>,
107    tail_db: Option<f32>,
108) -> Result<ImpulseResponse, ImpulseResponseError> {
109    let file = File::open(path)?;
110    decode_impulse_response(BufReader::new(file), tail_db)
111}
112
113/// Load an impulse response from in-memory audio bytes.
114pub fn load_impulse_response_from_bytes(
115    bytes: &[u8],
116) -> Result<ImpulseResponse, ImpulseResponseError> {
117    load_impulse_response_from_bytes_with_tail(bytes, Some(-60.0))
118}
119
120/// Load an impulse response from bytes and optionally trim its tail in dB.
121pub fn load_impulse_response_from_bytes_with_tail(
122    bytes: &[u8],
123    tail_db: Option<f32>,
124) -> Result<ImpulseResponse, ImpulseResponseError> {
125    decode_impulse_response(BufReader::new(Cursor::new(bytes.to_vec())), tail_db)
126}
127
128/// Load an impulse response attachment from a `.prot`/`.mka` container.
129pub fn load_impulse_response_from_prot_attachment(
130    prot_path: impl AsRef<Path>,
131    attachment_name: &str,
132) -> Result<ImpulseResponse, ImpulseResponseError> {
133    load_impulse_response_from_prot_attachment_with_tail(prot_path, attachment_name, Some(-60.0))
134}
135
136/// Load a container attachment and optionally trim its tail in dB.
137pub fn load_impulse_response_from_prot_attachment_with_tail(
138    prot_path: impl AsRef<Path>,
139    attachment_name: &str,
140    tail_db: Option<f32>,
141) -> Result<ImpulseResponse, ImpulseResponseError> {
142    let file = File::open(prot_path)?;
143    let mka: Matroska = Matroska::open(file)?;
144
145    let attachment = mka
146        .attachments
147        .iter()
148        .find(|attachment| attachment.name == attachment_name)
149        .ok_or_else(|| ImpulseResponseError::AttachmentNotFound(attachment_name.to_string()))?;
150
151    load_impulse_response_from_bytes_with_tail(&attachment.data, tail_db)
152}
153
154fn decode_impulse_response<R>(
155    reader: R,
156    tail_db: Option<f32>,
157) -> Result<ImpulseResponse, ImpulseResponseError>
158where
159    R: Read + Seek + Send + Sync + 'static,
160{
161    let source = Decoder::new(reader)?;
162    let channels = source.channels() as usize;
163    if channels == 0 {
164        return Err(ImpulseResponseError::InvalidChannels);
165    }
166
167    let sample_rate = source.sample_rate();
168    let mut channel_samples = vec![Vec::new(); channels];
169
170    for (index, sample) in source.enumerate() {
171        channel_samples[index % channels].push(sample as f32);
172    }
173
174    normalize_impulse_response_channels(&mut channel_samples, tail_db);
175
176    if channel_samples.iter().any(|channel| channel.is_empty()) {
177        warn!("Impulse response includes empty channels; results may be silent.");
178    }
179
180    Ok(ImpulseResponse {
181        sample_rate,
182        channels: channel_samples,
183    })
184}
185
/// Normalize and optionally trim impulse-response channels in place.
///
/// Processing happens in three stages, in this order:
/// 1. Peak normalization: all channels are scaled by one common gain so the
///    largest sample magnitude across every channel becomes 1.0. Skipped when
///    all samples are zero.
/// 2. Optional tail trim: when `tail_db` is `Some` finite value, trailing
///    samples below that level (dB relative to full scale, i.e. to the
///    normalized peak) are removed from every channel, keeping channels at a
///    common length. If no sample reaches the level, nothing is trimmed.
/// 3. Energy normalization (attenuation only): if the most energetic channel
///    (sum of squared samples) exceeds unit energy, all channels are scaled
///    down by a common gain so that channel has unit energy. IRs are never
///    amplified by this stage.
///
/// Because every stage applies the same gain to all channels, the balance
/// between channels is preserved.
pub fn normalize_impulse_response_channels(
    channel_samples: &mut [Vec<f32>],
    tail_db: Option<f32>,
) {
    peak_normalize_channels(channel_samples);

    if let Some(tail_db) = tail_db.filter(|db| db.is_finite()) {
        trim_impulse_response_tail(channel_samples, tail_db);
    }

    // Attenuate only, so long IRs don't explode the wet gain.
    energy_normalize_channels(channel_samples);
}

/// Scale all channels by a common gain so the largest |sample| becomes 1.0.
fn peak_normalize_channels(channel_samples: &mut [Vec<f32>]) {
    let max_abs = channel_samples
        .iter()
        .flatten()
        .fold(0.0_f32, |peak, sample| peak.max(sample.abs()));

    if max_abs > 0.0 {
        scale_channels(channel_samples, 1.0 / max_abs);
    }
}

/// Attenuate all channels so the most energetic one has at most unit energy.
fn energy_normalize_channels(channel_samples: &mut [Vec<f32>]) {
    // Accumulate in f64: IRs can be hundreds of thousands of samples long and
    // an f32 running sum loses precision as it grows.
    let max_energy = channel_samples
        .iter()
        .map(|channel| {
            channel
                .iter()
                .map(|&sample| f64::from(sample) * f64::from(sample))
                .sum::<f64>()
        })
        .fold(0.0_f64, f64::max);

    // The gain 1/sqrt(E) is below 1.0 exactly when E > 1.0; quieter IRs are
    // left as-is rather than amplified.
    if max_energy > 1.0 {
        let gain = (1.0 / max_energy.sqrt()) as f32;
        scale_channels(channel_samples, gain);
    }
}

/// Multiply every sample of every channel by `gain`.
fn scale_channels(channel_samples: &mut [Vec<f32>], gain: f32) {
    channel_samples
        .iter_mut()
        .flatten()
        .for_each(|sample| *sample *= gain);
}

/// Trim trailing low-level samples from all channels.
///
/// `tail_db` is converted to a linear threshold `10^(tail_db / 20)`. Every
/// channel is truncated to a common length: one past the last sample (in any
/// channel) whose magnitude is at or above the threshold. Channels already
/// shorter than that are left alone.
///
/// If no sample in any channel reaches the threshold (an all-silent IR, or a
/// threshold above the signal's peak), the channels are left unchanged rather
/// than collapsed to a single sample.
fn trim_impulse_response_tail(channels: &mut [Vec<f32>], tail_db: f32) {
    let threshold = 10.0_f32.powf(tail_db / 20.0);
    // A very negative `tail_db` underflows to zero: nothing would be trimmed.
    if threshold <= 0.0 {
        return;
    }

    // Scan each channel from the end: only the last loud sample matters.
    let last_loud = channels
        .iter()
        .filter_map(|channel| channel.iter().rposition(|sample| sample.abs() >= threshold))
        .max();

    if let Some(last_loud) = last_loud {
        let keep_len = last_loud + 1;
        for channel in channels.iter_mut() {
            // `truncate` is a no-op for channels already within `keep_len`.
            channel.truncate(keep_len);
        }
    }
}
276        if channel.len() > keep_len {
277            channel.truncate(keep_len);
278        }
279    }
280}