dasp_rs/core/
ops.rs

1use crate::core::AudioData;
2use thiserror::Error;
3use rayon::prelude::*;
4
5/// Enumerates error conditions for signal operation failures in DSP workflows.
6///
7/// Provides detailed diagnostics for binary and scalar operations on audio signals,
8/// tailored for debugging and error recovery in production-grade audio processing pipelines.
9#[derive(Error, Debug)]
10pub enum SignalOpError {
11    /// Signals have incompatible sample lengths for binary operations.
12    #[error("Sample length mismatch: {0} vs {0}")]
13    LengthMismatch(usize, usize),
14
15    /// Division by zero encountered at a specific sample index.
16    #[error("Division by zero at sample index {0}")]
17    DivisionByZero(usize),
18
19    /// Input validation failure (e.g., empty signal array, mismatched metadata).
20    #[error("Invalid input parameter: {0}")]
21    InvalidInput(String),
22
23    /// Numerical computation failure (e.g., overflow, NaN result).
24    #[error("Computation failed: {0}")]
25    ComputationFailed(String),
26}
27
28/// Mixes multiple audio signals by averaging their samples in parallel.
29///
30/// Computes the sample-wise mean of an array of `AudioData` signals, producing a new
31/// `AudioData` instance. All signals must share identical sample lengths, sample rates,
32/// and channel counts. Parallelized using `rayon` for multi-core efficiency.
33///
34/// # Parameters
35/// - `signals`: Slice of `AudioData` references to mix.
36///
37/// # Returns
38/// - `Ok(AudioData)`: Mixed signal with averaged samples.
39/// - `Err(SignalOpError)`: Failure due to empty input, length mismatch, or metadata inconsistency.
40pub fn mix_signals(signals: &[&AudioData]) -> Result<AudioData, SignalOpError> {
41    if signals.is_empty() {
42        return Err(SignalOpError::InvalidInput("Signal array is empty".to_string()));
43    }
44
45    let length = signals[0].samples.len();
46    let sample_rate = signals[0].sample_rate;
47    let channels = signals[0].channels;
48
49    for &signal in signals.iter().skip(1) {
50        if signal.samples.len() != length {
51            return Err(SignalOpError::LengthMismatch(length, signal.samples.len()));
52        }
53        if signal.sample_rate != sample_rate || signal.channels != channels {
54            return Err(SignalOpError::InvalidInput(
55                format!(
56                    "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
57                    sample_rate, channels, signal.sample_rate, signal.channels
58                )
59            ));
60        }
61    }
62
63    let mixed_samples: Vec<f32> = (0..length)
64        .into_par_iter()
65        .map(|i| {
66            let sum: f32 = signals.iter().map(|s| s.samples[i]).sum();
67            sum / signals.len() as f32
68        })
69        .collect();
70
71    Ok(AudioData::new(mixed_samples, sample_rate, channels))
72}
73
74/// Subtracts one audio signal from another with parallel sample processing.
75///
76/// Performs sample-wise subtraction (`signal1 - signal2`), producing a new `AudioData`.
77/// Signals must have identical sample lengths, sample rates, and channel counts.
78///
79/// # Parameters
80/// - `signal1`: Base signal (minuend).
81/// - `signal2`: Signal to subtract (subtrahend).
82///
83/// # Returns
84/// - `Ok(AudioData)`: Resulting difference signal.
85/// - `Err(SignalOpError)`: Failure due to length or metadata mismatch.
86pub fn subtract_signals(signal1: &AudioData, signal2: &AudioData) -> Result<AudioData, SignalOpError> {
87    if signal1.samples.len() != signal2.samples.len() {
88        return Err(SignalOpError::LengthMismatch(signal1.samples.len(), signal2.samples.len()));
89    }
90    if signal1.sample_rate != signal2.sample_rate || signal1.channels != signal2.channels {
91        return Err(SignalOpError::InvalidInput(
92            format!(
93                "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
94                signal1.sample_rate, signal1.channels, signal2.sample_rate, signal2.channels
95            )
96        ));
97    }
98
99    let samples: Vec<f32> = signal1.samples
100        .par_iter()
101        .zip(&signal2.samples)
102        .map(|(&s1, &s2)| s1 - s2)
103        .collect();
104
105    Ok(AudioData::new(samples, signal1.sample_rate, signal1.channels))
106}
107
108/// Multiplies two audio signals sample-wise in parallel (e.g., amplitude modulation).
109///
110/// Computes the product of corresponding samples from `signal1` and `signal2`, producing
111/// a new `AudioData`. Suitable for modulation effects. Signals must match in length,
112/// sample rate, and channels.
113///
114/// # Parameters
115/// - `signal1`: First signal (carrier or base).
116/// - `signal2`: Second signal (modulator).
117///
118/// # Returns
119/// - `Ok(AudioData)`: Product signal.
120/// - `Err(SignalOpError)`: Failure due to length or metadata mismatch.
121pub fn multiply_signals(signal1: &AudioData, signal2: &AudioData) -> Result<AudioData, SignalOpError> {
122    if signal1.samples.len() != signal2.samples.len() {
123        return Err(SignalOpError::LengthMismatch(signal1.samples.len(), signal2.samples.len()));
124    }
125    if signal1.sample_rate != signal2.sample_rate || signal1.channels != signal2.channels {
126        return Err(SignalOpError::InvalidInput(
127            format!(
128                "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
129                signal1.sample_rate, signal1.channels, signal2.sample_rate, signal2.channels
130            )
131        ));
132    }
133
134    let samples: Vec<f32> = signal1.samples
135        .par_iter()
136        .zip(&signal2.samples)
137        .map(|(&s1, &s2)| s1 * s2)
138        .collect();
139
140    if samples.iter().any(|&s| !s.is_finite()) {
141        return Err(SignalOpError::ComputationFailed("Non-finite result detected".to_string()));
142    }
143
144    Ok(AudioData::new(samples, signal1.sample_rate, signal1.channels))
145}
146
147/// Divides one audio signal by another with parallel processing and zero handling.
148///
149/// Performs sample-wise division (`signal1 / signal2`), producing a new `AudioData`.
150/// Handles division by zero by clamping to 0.0 and logs a warning. Signals must match
151/// in length, sample rate, and channels.
152///
153/// # Parameters
154/// - `signal1`: Numerator signal.
155/// - `signal2`: Denominator signal.
156///
157/// # Returns
158/// - `Ok(AudioData)`: Quotient signal.
159/// - `Err(SignalOpError)`: Failure due to length or metadata mismatch.
160pub fn divide_signals(signal1: &AudioData, signal2: &AudioData) -> Result<AudioData, SignalOpError> {
161    if signal1.samples.len() != signal2.samples.len() {
162        return Err(SignalOpError::LengthMismatch(signal1.samples.len(), signal2.samples.len()));
163    }
164    if signal1.sample_rate != signal2.sample_rate || signal1.channels != signal2.channels {
165        return Err(SignalOpError::InvalidInput(
166            format!(
167                "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
168                signal1.sample_rate, signal1.channels, signal2.sample_rate, signal2.channels
169            )
170        ));
171    }
172
173    let samples: Vec<f32> = signal1.samples
174        .par_iter()
175        .zip(&signal2.samples)
176        .enumerate()
177        .map(|(i, (&s1, &s2))| {
178            if s2 == 0.0 {
179                eprintln!("Warning: Division by zero at index {}, clamping to 0.0", i);
180                0.0
181            } else {
182                s1 / s2
183            }
184        })
185        .collect();
186
187    if samples.iter().any(|&s| !s.is_finite()) {
188        return Err(SignalOpError::ComputationFailed("Non-finite result detected".to_string()));
189    }
190
191    Ok(AudioData::new(samples, signal1.sample_rate, signal1.channels))
192}
193
194/// Applies a scalar operation to an audio signal in parallel.
195///
196/// Performs element-wise addition, subtraction, multiplication, or division between
197/// a signal’s samples and a scalar value, producing a new `AudioData`. Division by zero
198/// is explicitly rejected.
199///
200/// # Parameters
201/// - `signal`: Input signal.
202/// - `scalar`: Scalar value for operation.
203/// - `op`: Operation type: `"add"`, `"subtract"`, `"multiply"`, or `"divide"`.
204///
205/// # Returns
206/// - `Ok(AudioData)`: Resulting signal.
207/// - `Err(SignalOpError)`: Failure due to invalid operation or division by zero.
208pub fn scalar_operation(signal: &AudioData, scalar: f32, op: &str) -> Result<AudioData, SignalOpError> {
209    let samples: Vec<f32> = match op.to_lowercase().as_str() {
210        "add" => signal.samples.par_iter().map(|&s| s + scalar).collect(),
211        "subtract" => signal.samples.par_iter().map(|&s| s - scalar).collect(),
212        "multiply" => signal.samples.par_iter().map(|&s| s * scalar).collect(),
213        "divide" => {
214            if scalar == 0.0 {
215                return Err(SignalOpError::DivisionByZero(0));
216            }
217            signal.samples.par_iter().map(|&s| s / scalar).collect()
218        }
219        _ => return Err(SignalOpError::InvalidInput(format!("Unsupported operation: {}", op))),
220    };
221
222    if samples.iter().any(|&s| !s.is_finite()) {
223        return Err(SignalOpError::ComputationFailed("Non-finite result detected".to_string()));
224    }
225
226    Ok(AudioData::new(samples, signal.sample_rate, signal.channels))
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mono 44.1 kHz test signal from raw samples.
    fn test_signal(samples: Vec<f32>) -> AudioData {
        AudioData::new(samples, 44100, 1)
    }

    #[test]
    fn test_mix_signals_basic() {
        let s1 = test_signal(vec![1.0, 2.0, 3.0]);
        let s2 = test_signal(vec![2.0, 4.0, 6.0]);
        let mixed = mix_signals(&[&s1, &s2]).unwrap();
        assert_eq!(mixed.samples, vec![1.5, 3.0, 4.5]);
        assert_eq!(mixed.sample_rate, 44100);
        assert_eq!(mixed.channels, 1);
    }

    #[test]
    fn test_mix_signals_three_signals() {
        let s1 = test_signal(vec![1.0, 2.0]);
        let s2 = test_signal(vec![2.0, 4.0]);
        let s3 = test_signal(vec![3.0, 6.0]);
        let mixed = mix_signals(&[&s1, &s2, &s3]).unwrap();
        assert_eq!(mixed.samples, vec![2.0, 4.0]);
    }

    #[test]
    fn test_mix_signals_single_signal_is_identity() {
        let s = test_signal(vec![0.25, -0.5, 1.0]);
        let mixed = mix_signals(&[&s]).unwrap();
        assert_eq!(mixed.samples, vec![0.25, -0.5, 1.0]);
    }

    #[test]
    fn test_mix_signals_empty() {
        let result = mix_signals(&[]);
        assert!(matches!(result, Err(SignalOpError::InvalidInput(_))));
    }

    #[test]
    fn test_mix_signals_length_mismatch() {
        let s1 = test_signal(vec![1.0, 2.0]);
        let s2 = test_signal(vec![2.0, 4.0, 6.0]);
        let result = mix_signals(&[&s1, &s2]);
        assert!(matches!(result, Err(SignalOpError::LengthMismatch(2, 3))));
    }

    #[test]
    fn test_mix_signals_sample_rate_mismatch() {
        let s1 = test_signal(vec![1.0, 2.0]);
        let s2 = AudioData::new(vec![1.0, 2.0], 48000, 1);
        let result = mix_signals(&[&s1, &s2]);
        assert!(matches!(result, Err(SignalOpError::InvalidInput(_))));
    }

    #[test]
    fn test_subtract_signals() {
        let s1 = test_signal(vec![2.0, 4.0, 6.0]);
        let s2 = test_signal(vec![1.0, 2.0, 3.0]);
        let result = subtract_signals(&s1, &s2).unwrap();
        assert_eq!(result.samples, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_subtract_signals_mismatch() {
        let s1 = test_signal(vec![2.0, 4.0]);
        let s2 = test_signal(vec![1.0, 2.0, 3.0]);
        let result = subtract_signals(&s1, &s2);
        assert!(matches!(result, Err(SignalOpError::LengthMismatch(2, 3))));
    }

    #[test]
    fn test_binary_ops_channel_mismatch() {
        let mono = test_signal(vec![1.0, 2.0]);
        let stereo = AudioData::new(vec![1.0, 2.0], 44100, 2);
        assert!(matches!(
            subtract_signals(&mono, &stereo),
            Err(SignalOpError::InvalidInput(_))
        ));
        assert!(matches!(
            multiply_signals(&mono, &stereo),
            Err(SignalOpError::InvalidInput(_))
        ));
        assert!(matches!(
            divide_signals(&mono, &stereo),
            Err(SignalOpError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_multiply_and_divide_length_mismatch() {
        let s1 = test_signal(vec![1.0, 2.0]);
        let s2 = test_signal(vec![1.0, 2.0, 3.0]);
        assert!(matches!(
            multiply_signals(&s1, &s2),
            Err(SignalOpError::LengthMismatch(2, 3))
        ));
        assert!(matches!(
            divide_signals(&s1, &s2),
            Err(SignalOpError::LengthMismatch(2, 3))
        ));
    }

    #[test]
    fn test_multiply_signals() {
        let s1 = test_signal(vec![1.0, 2.0, 3.0]);
        let s2 = test_signal(vec![2.0, 2.0, 2.0]);
        let result = multiply_signals(&s1, &s2).unwrap();
        assert_eq!(result.samples, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_multiply_signals_overflow() {
        let s1 = test_signal(vec![f32::MAX, 2.0]);
        let s2 = test_signal(vec![2.0, 2.0]);
        let result = multiply_signals(&s1, &s2);
        assert!(matches!(result, Err(SignalOpError::ComputationFailed(_))));
    }

    #[test]
    fn test_divide_signals() {
        let s1 = test_signal(vec![4.0, 6.0, 8.0]);
        let s2 = test_signal(vec![2.0, 0.0, 4.0]);
        let result = divide_signals(&s1, &s2).unwrap();
        assert_eq!(result.samples, vec![2.0, 0.0, 2.0]);
    }

    #[test]
    fn test_divide_signals_zero_and_negative_zero_are_clamped() {
        let s1 = test_signal(vec![0.0, 1.0]);
        let s2 = test_signal(vec![0.0, -0.0]);
        let result = divide_signals(&s1, &s2).unwrap();
        assert_eq!(result.samples, vec![0.0, 0.0]);
    }

    #[test]
    fn test_divide_signals_infinity() {
        let s1 = test_signal(vec![f32::MAX, 1.0]);
        let s2 = test_signal(vec![0.001, 1.0]);
        let result = divide_signals(&s1, &s2);
        assert!(matches!(result, Err(SignalOpError::ComputationFailed(_))));
    }

    #[test]
    fn test_scalar_operation_add() {
        let s = test_signal(vec![1.0, 2.0, 3.0]);
        let result = scalar_operation(&s, 1.0, "add").unwrap();
        assert_eq!(result.samples, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_operation_subtract_and_divide() {
        let s = test_signal(vec![2.0, 4.0, 6.0]);
        let diff = scalar_operation(&s, 1.0, "subtract").unwrap();
        assert_eq!(diff.samples, vec![1.0, 3.0, 5.0]);
        let quot = scalar_operation(&s, 2.0, "divide").unwrap();
        assert_eq!(quot.samples, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_scalar_operation_multiply() {
        let s = test_signal(vec![1.0, 2.0, 3.0]);
        let result = scalar_operation(&s, 2.0, "multiply").unwrap();
        assert_eq!(result.samples, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scalar_operation_is_case_insensitive() {
        let s = test_signal(vec![1.0, 2.0]);
        let result = scalar_operation(&s, 3.0, "MuLtIpLy").unwrap();
        assert_eq!(result.samples, vec![3.0, 6.0]);
    }

    #[test]
    fn test_scalar_operation_preserves_metadata() {
        let s = AudioData::new(vec![1.0, 2.0], 48000, 2);
        let result = scalar_operation(&s, 1.0, "add").unwrap();
        assert_eq!(result.sample_rate, 48000);
        assert_eq!(result.channels, 2);
    }

    #[test]
    fn test_scalar_operation_overflow() {
        let s = test_signal(vec![f32::MAX]);
        let result = scalar_operation(&s, f32::MAX, "add");
        assert!(matches!(result, Err(SignalOpError::ComputationFailed(_))));
    }

    #[test]
    fn test_scalar_operation_divide_by_zero() {
        let s = test_signal(vec![1.0, 2.0]);
        let result = scalar_operation(&s, 0.0, "divide");
        assert!(matches!(result, Err(SignalOpError::DivisionByZero(0))));
    }

    #[test]
    fn test_scalar_operation_invalid_op() {
        let s = test_signal(vec![1.0, 2.0]);
        let result = scalar_operation(&s, 1.0, "invalid");
        assert!(matches!(result, Err(SignalOpError::InvalidInput(_))));
    }
}