proteus_lib/dsp/effects/convolution_reverb/
impulse_response.rs1use std::fmt;
4use std::fs::File;
5use std::io::{BufReader, Cursor, Read, Seek};
6use std::path::Path;
7
8use log::{info, warn};
9use matroska::Matroska;
10use rodio::{Decoder, Source};
11
/// A decoded impulse response, stored as one de-interleaved sample buffer per channel.
#[derive(Clone, Debug)]
pub struct ImpulseResponse {
    /// Sample rate reported by the decoder for the source audio.
    pub sample_rate: u32,
    /// Per-channel sample buffers: `channels[c][n]` is sample `n` of channel `c`.
    pub channels: Vec<Vec<f32>>,
}
20
21impl ImpulseResponse {
22 pub fn channel_count(&self) -> usize {
24 self.channels.len()
25 }
26
27 pub fn channel_for_output(&self, index: usize) -> &[f32] {
32 if self.channels.is_empty() {
33 return &[];
34 }
35
36 if self.channels.len() == 1 {
37 return &self.channels[0];
38 }
39
40 let channel_index = index % self.channels.len();
41 &self.channels[channel_index]
42 }
43}
44
45#[derive(Debug)]
47pub enum ImpulseResponseError {
48 Io(std::io::Error),
49 Matroska(matroska::Error),
50 Decode(rodio::decoder::DecoderError),
51 AttachmentNotFound(String),
52 InvalidChannels,
53}
54
55impl fmt::Display for ImpulseResponseError {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 Self::Io(err) => write!(f, "failed to read impulse response: {}", err),
59 Self::Matroska(err) => write!(f, "failed to read prot container: {}", err),
60 Self::Decode(err) => write!(f, "failed to decode impulse response: {}", err),
61 Self::AttachmentNotFound(name) => {
62 write!(f, "impulse response attachment not found: {}", name)
63 }
64 Self::InvalidChannels => write!(f, "impulse response has invalid channel count"),
65 }
66 }
67}
68
69impl std::error::Error for ImpulseResponseError {}
70
71impl From<std::io::Error> for ImpulseResponseError {
72 fn from(err: std::io::Error) -> Self {
73 Self::Io(err)
74 }
75}
76
77impl From<rodio::decoder::DecoderError> for ImpulseResponseError {
78 fn from(err: rodio::decoder::DecoderError) -> Self {
79 Self::Decode(err)
80 }
81}
82
83impl From<matroska::Error> for ImpulseResponseError {
84 fn from(err: matroska::Error) -> Self {
85 Self::Matroska(err)
86 }
87}
88
89pub fn load_impulse_response_from_file(
99 path: impl AsRef<Path>,
100) -> Result<ImpulseResponse, ImpulseResponseError> {
101 load_impulse_response_from_file_with_tail(path, Some(-60.0))
102}
103
104pub fn load_impulse_response_from_file_with_tail(
106 path: impl AsRef<Path>,
107 tail_db: Option<f32>,
108) -> Result<ImpulseResponse, ImpulseResponseError> {
109 let file = File::open(path)?;
110 decode_impulse_response(BufReader::new(file), tail_db)
111}
112
113pub fn load_impulse_response_from_bytes(
115 bytes: &[u8],
116) -> Result<ImpulseResponse, ImpulseResponseError> {
117 load_impulse_response_from_bytes_with_tail(bytes, Some(-60.0))
118}
119
120pub fn load_impulse_response_from_bytes_with_tail(
122 bytes: &[u8],
123 tail_db: Option<f32>,
124) -> Result<ImpulseResponse, ImpulseResponseError> {
125 decode_impulse_response(BufReader::new(Cursor::new(bytes.to_vec())), tail_db)
126}
127
128pub fn load_impulse_response_from_prot_attachment(
130 prot_path: impl AsRef<Path>,
131 attachment_name: &str,
132) -> Result<ImpulseResponse, ImpulseResponseError> {
133 load_impulse_response_from_prot_attachment_with_tail(prot_path, attachment_name, Some(-60.0))
134}
135
136pub fn load_impulse_response_from_prot_attachment_with_tail(
138 prot_path: impl AsRef<Path>,
139 attachment_name: &str,
140 tail_db: Option<f32>,
141) -> Result<ImpulseResponse, ImpulseResponseError> {
142 let file = File::open(prot_path)?;
143 let mka: Matroska = Matroska::open(file)?;
144
145 let attachment = mka
146 .attachments
147 .iter()
148 .find(|attachment| attachment.name.trim_matches('"') == attachment_name)
149 .ok_or_else(|| ImpulseResponseError::AttachmentNotFound(attachment_name.to_string()))?;
150
151 info!("Loading impulse bytes response from {}", attachment.name);
152
153 load_impulse_response_from_bytes_with_tail(&attachment.data, tail_db)
154}
155
156fn decode_impulse_response<R>(
157 reader: R,
158 tail_db: Option<f32>,
159) -> Result<ImpulseResponse, ImpulseResponseError>
160where
161 R: Read + Seek + Send + Sync + 'static,
162{
163 let source = Decoder::new(reader)?;
164 let channels = source.channels() as usize;
165 if channels == 0 {
166 return Err(ImpulseResponseError::InvalidChannels);
167 }
168
169 let sample_rate = source.sample_rate();
170 let mut channel_samples = vec![Vec::new(); channels];
171
172 for (index, sample) in source.enumerate() {
173 channel_samples[index % channels].push(sample as f32);
174 }
175
176 normalize_impulse_response_channels(&mut channel_samples, tail_db);
177
178 if channel_samples.iter().any(|channel| channel.is_empty()) {
179 warn!("Impulse response includes empty channels; results may be silent.");
180 }
181
182 Ok(ImpulseResponse {
183 sample_rate,
184 channels: channel_samples,
185 })
186}
187
188pub fn normalize_impulse_response_channels(channel_samples: &mut [Vec<f32>], tail_db: Option<f32>) {
194 let mut max_abs = 0.0_f32;
195 for channel in channel_samples.iter() {
196 for sample in channel {
197 let abs = sample.abs();
198 if abs > max_abs {
199 max_abs = abs;
200 }
201 }
202 }
203
204 if max_abs > 0.0 {
205 let scale = 1.0 / max_abs;
206 for channel in channel_samples.iter_mut() {
207 for sample in channel {
208 *sample *= scale;
209 }
210 }
211 }
212
213 if let Some(tail_db) = tail_db {
214 if tail_db.is_finite() {
215 trim_impulse_response_tail(channel_samples, tail_db);
216 }
217 }
218
219 let mut max_energy = 0.0_f32;
221 for channel in channel_samples.iter() {
222 let mut sum_sq = 0.0_f32;
223 for sample in channel {
224 sum_sq += sample * sample;
225 }
226 if sum_sq > max_energy {
227 max_energy = sum_sq;
228 }
229 }
230 if max_energy > 0.0 {
231 let mut scale = 1.0_f32 / max_energy.sqrt();
232 if scale > 1.0 {
233 scale = 1.0;
234 }
235 if scale < 1.0 {
236 for channel in channel_samples.iter_mut() {
237 for sample in channel {
238 *sample *= scale;
239 }
240 }
241 }
242 }
243}
244
/// Truncates all channels just after the last sample that reaches `tail_db`.
///
/// `tail_db` is converted to a linear amplitude threshold, `10^(tail_db / 20)`. The cut point is
/// the latest index, across all channels, whose absolute sample is at least that threshold.
/// Longer channels are shortened to end there, and channels that are already shorter are left
/// as-is.
///
/// If no sample in any channel reaches the threshold, the channels are left untouched rather
/// than collapsed to a single sample. This covers a silent response, a threshold above the
/// peak, and a NaN `tail_db`. A threshold that underflows to zero also disables trimming.
fn trim_impulse_response_tail(channels: &mut [Vec<f32>], tail_db: f32) {
    let threshold = 10.0_f32.powf(tail_db / 20.0);
    if threshold.is_nan() || threshold <= 0.0 {
        return;
    }

    // `rposition` scans from the end, so long quiet tails are skipped quickly.
    let last_audible = channels
        .iter()
        .filter_map(|channel| channel.iter().rposition(|sample| sample.abs() >= threshold))
        .max();

    if let Some(last_audible) = last_audible {
        let keep_len = last_audible + 1;
        for channel in channels.iter_mut() {
            // `truncate` is a no-op for channels already at or below `keep_len`.
            channel.truncate(keep_len);
        }
    }
}