proteus_lib/dsp/
impulse_response.rs1use std::fmt;
4use std::fs::File;
5use std::io::{BufReader, Cursor, Read, Seek};
6use std::path::Path;
7
8use log::warn;
9use matroska::Matroska;
10use rodio::{Decoder, Source};
11
/// A decoded impulse response, stored as one sample buffer per channel.
///
/// `channels[i]` holds the de-interleaved `f32` samples of channel `i`. Channels may differ in
/// length; nothing in this type pads them to a common size.
#[derive(Debug, Clone, PartialEq)]
pub struct ImpulseResponse {
    /// Sample rate reported by the decoder for the source audio.
    pub sample_rate: u32,
    /// Per-channel sample buffers.
    pub channels: Vec<Vec<f32>>,
}
20
21impl ImpulseResponse {
22 pub fn channel_count(&self) -> usize {
24 self.channels.len()
25 }
26
27 pub fn channel_for_output(&self, index: usize) -> &[f32] {
32 if self.channels.is_empty() {
33 return &[];
34 }
35
36 if self.channels.len() == 1 {
37 return &self.channels[0];
38 }
39
40 let channel_index = index % self.channels.len();
41 &self.channels[channel_index]
42 }
43}
44
45#[derive(Debug)]
47pub enum ImpulseResponseError {
48 Io(std::io::Error),
49 Matroska(matroska::Error),
50 Decode(rodio::decoder::DecoderError),
51 AttachmentNotFound(String),
52 InvalidChannels,
53}
54
55impl fmt::Display for ImpulseResponseError {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 Self::Io(err) => write!(f, "failed to read impulse response: {}", err),
59 Self::Matroska(err) => write!(f, "failed to read prot container: {}", err),
60 Self::Decode(err) => write!(f, "failed to decode impulse response: {}", err),
61 Self::AttachmentNotFound(name) => {
62 write!(f, "impulse response attachment not found: {}", name)
63 }
64 Self::InvalidChannels => write!(f, "impulse response has invalid channel count"),
65 }
66 }
67}
68
69impl std::error::Error for ImpulseResponseError {}
70
71impl From<std::io::Error> for ImpulseResponseError {
72 fn from(err: std::io::Error) -> Self {
73 Self::Io(err)
74 }
75}
76
77impl From<rodio::decoder::DecoderError> for ImpulseResponseError {
78 fn from(err: rodio::decoder::DecoderError) -> Self {
79 Self::Decode(err)
80 }
81}
82
83impl From<matroska::Error> for ImpulseResponseError {
84 fn from(err: matroska::Error) -> Self {
85 Self::Matroska(err)
86 }
87}
88
89pub fn load_impulse_response_from_file(
99 path: impl AsRef<Path>,
100) -> Result<ImpulseResponse, ImpulseResponseError> {
101 load_impulse_response_from_file_with_tail(path, Some(-60.0))
102}
103
104pub fn load_impulse_response_from_file_with_tail(
106 path: impl AsRef<Path>,
107 tail_db: Option<f32>,
108) -> Result<ImpulseResponse, ImpulseResponseError> {
109 let file = File::open(path)?;
110 decode_impulse_response(BufReader::new(file), tail_db)
111}
112
113pub fn load_impulse_response_from_bytes(
115 bytes: &[u8],
116) -> Result<ImpulseResponse, ImpulseResponseError> {
117 load_impulse_response_from_bytes_with_tail(bytes, Some(-60.0))
118}
119
120pub fn load_impulse_response_from_bytes_with_tail(
122 bytes: &[u8],
123 tail_db: Option<f32>,
124) -> Result<ImpulseResponse, ImpulseResponseError> {
125 decode_impulse_response(BufReader::new(Cursor::new(bytes.to_vec())), tail_db)
126}
127
128pub fn load_impulse_response_from_prot_attachment(
130 prot_path: impl AsRef<Path>,
131 attachment_name: &str,
132) -> Result<ImpulseResponse, ImpulseResponseError> {
133 load_impulse_response_from_prot_attachment_with_tail(prot_path, attachment_name, Some(-60.0))
134}
135
136pub fn load_impulse_response_from_prot_attachment_with_tail(
138 prot_path: impl AsRef<Path>,
139 attachment_name: &str,
140 tail_db: Option<f32>,
141) -> Result<ImpulseResponse, ImpulseResponseError> {
142 let file = File::open(prot_path)?;
143 let mka: Matroska = Matroska::open(file)?;
144
145 let attachment = mka
146 .attachments
147 .iter()
148 .find(|attachment| attachment.name == attachment_name)
149 .ok_or_else(|| ImpulseResponseError::AttachmentNotFound(attachment_name.to_string()))?;
150
151 load_impulse_response_from_bytes_with_tail(&attachment.data, tail_db)
152}
153
154fn decode_impulse_response<R>(
155 reader: R,
156 tail_db: Option<f32>,
157) -> Result<ImpulseResponse, ImpulseResponseError>
158where
159 R: Read + Seek + Send + Sync + 'static,
160{
161 let source = Decoder::new(reader)?;
162 let channels = source.channels() as usize;
163 if channels == 0 {
164 return Err(ImpulseResponseError::InvalidChannels);
165 }
166
167 let sample_rate = source.sample_rate();
168 let mut channel_samples = vec![Vec::new(); channels];
169
170 for (index, sample) in source.enumerate() {
171 channel_samples[index % channels].push(sample as f32);
172 }
173
174 normalize_impulse_response_channels(&mut channel_samples, tail_db);
175
176 if channel_samples.iter().any(|channel| channel.is_empty()) {
177 warn!("Impulse response includes empty channels; results may be silent.");
178 }
179
180 Ok(ImpulseResponse {
181 sample_rate,
182 channels: channel_samples,
183 })
184}
185
186pub fn normalize_impulse_response_channels(
192 channel_samples: &mut [Vec<f32>],
193 tail_db: Option<f32>,
194) {
195 let mut max_abs = 0.0_f32;
196 for channel in channel_samples.iter() {
197 for sample in channel {
198 let abs = sample.abs();
199 if abs > max_abs {
200 max_abs = abs;
201 }
202 }
203 }
204
205 if max_abs > 0.0 {
206 let scale = 1.0 / max_abs;
207 for channel in channel_samples.iter_mut() {
208 for sample in channel {
209 *sample *= scale;
210 }
211 }
212 }
213
214 if let Some(tail_db) = tail_db {
215 if tail_db.is_finite() {
216 trim_impulse_response_tail(channel_samples, tail_db);
217 }
218 }
219
220 let mut max_energy = 0.0_f32;
222 for channel in channel_samples.iter() {
223 let mut sum_sq = 0.0_f32;
224 for sample in channel {
225 sum_sq += sample * sample;
226 }
227 if sum_sq > max_energy {
228 max_energy = sum_sq;
229 }
230 }
231 if max_energy > 0.0 {
232 let mut scale = 1.0_f32 / max_energy.sqrt();
233 if scale > 1.0 {
234 scale = 1.0;
235 }
236 if scale < 1.0 {
237 for channel in channel_samples.iter_mut() {
238 for sample in channel {
239 *sample *= scale;
240 }
241 }
242 }
243 }
244}
245
/// Cuts the quiet tail off an impulse response.
///
/// The threshold is the linear amplitude `|10^(tail_db / 20)|`. The function finds the last
/// sample, in any channel, whose magnitude is at or above the threshold. Every channel is then
/// truncated to end at that index, so all channels are cut at the same point; channels already
/// shorter are left as they are.
///
/// Nothing is trimmed when the threshold underflows to zero, or when no sample reaches it. The
/// second case covers a silent response, or a `tail_db` above the response's peak. Previously
/// that case truncated every channel to a single sample and destroyed the response.
fn trim_impulse_response_tail(channels: &mut [Vec<f32>], tail_db: f32) {
    let threshold = 10.0_f32.powf(tail_db / 20.0).abs();
    if threshold <= 0.0 {
        // Extremely negative levels underflow to 0: every sample would qualify, so no trim.
        return;
    }

    // Latest index, across all channels, of a sample that still reaches the threshold.
    let last_audible = channels
        .iter()
        .filter_map(|channel| channel.iter().rposition(|sample| sample.abs() >= threshold))
        .max();

    if let Some(last_audible) = last_audible {
        let keep_len = last_audible + 1;
        for channel in channels.iter_mut() {
            // `truncate` is a no-op for channels already at or below `keep_len`.
            channel.truncate(keep_len);
        }
    }
    // Otherwise no sample reaches the threshold: there is no audible signal to anchor a cut on,
    // so the response is left intact rather than collapsed.
}