1use crate::core::AudioData;
2use thiserror::Error;
3use rayon::prelude::*;
4
5#[derive(Error, Debug)]
10pub enum SignalOpError {
11 #[error("Sample length mismatch: {0} vs {0}")]
13 LengthMismatch(usize, usize),
14
15 #[error("Division by zero at sample index {0}")]
17 DivisionByZero(usize),
18
19 #[error("Invalid input parameter: {0}")]
21 InvalidInput(String),
22
23 #[error("Computation failed: {0}")]
25 ComputationFailed(String),
26}
27
28pub fn mix_signals(signals: &[&AudioData]) -> Result<AudioData, SignalOpError> {
41 if signals.is_empty() {
42 return Err(SignalOpError::InvalidInput("Signal array is empty".to_string()));
43 }
44
45 let length = signals[0].samples.len();
46 let sample_rate = signals[0].sample_rate;
47 let channels = signals[0].channels;
48
49 for &signal in signals.iter().skip(1) {
50 if signal.samples.len() != length {
51 return Err(SignalOpError::LengthMismatch(length, signal.samples.len()));
52 }
53 if signal.sample_rate != sample_rate || signal.channels != channels {
54 return Err(SignalOpError::InvalidInput(
55 format!(
56 "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
57 sample_rate, channels, signal.sample_rate, signal.channels
58 )
59 ));
60 }
61 }
62
63 let mixed_samples: Vec<f32> = (0..length)
64 .into_par_iter()
65 .map(|i| {
66 let sum: f32 = signals.iter().map(|s| s.samples[i]).sum();
67 sum / signals.len() as f32
68 })
69 .collect();
70
71 Ok(AudioData::new(mixed_samples, sample_rate, channels))
72}
73
74pub fn subtract_signals(signal1: &AudioData, signal2: &AudioData) -> Result<AudioData, SignalOpError> {
87 if signal1.samples.len() != signal2.samples.len() {
88 return Err(SignalOpError::LengthMismatch(signal1.samples.len(), signal2.samples.len()));
89 }
90 if signal1.sample_rate != signal2.sample_rate || signal1.channels != signal2.channels {
91 return Err(SignalOpError::InvalidInput(
92 format!(
93 "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
94 signal1.sample_rate, signal1.channels, signal2.sample_rate, signal2.channels
95 )
96 ));
97 }
98
99 let samples: Vec<f32> = signal1.samples
100 .par_iter()
101 .zip(&signal2.samples)
102 .map(|(&s1, &s2)| s1 - s2)
103 .collect();
104
105 Ok(AudioData::new(samples, signal1.sample_rate, signal1.channels))
106}
107
108pub fn multiply_signals(signal1: &AudioData, signal2: &AudioData) -> Result<AudioData, SignalOpError> {
122 if signal1.samples.len() != signal2.samples.len() {
123 return Err(SignalOpError::LengthMismatch(signal1.samples.len(), signal2.samples.len()));
124 }
125 if signal1.sample_rate != signal2.sample_rate || signal1.channels != signal2.channels {
126 return Err(SignalOpError::InvalidInput(
127 format!(
128 "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
129 signal1.sample_rate, signal1.channels, signal2.sample_rate, signal2.channels
130 )
131 ));
132 }
133
134 let samples: Vec<f32> = signal1.samples
135 .par_iter()
136 .zip(&signal2.samples)
137 .map(|(&s1, &s2)| s1 * s2)
138 .collect();
139
140 if samples.iter().any(|&s| !s.is_finite()) {
141 return Err(SignalOpError::ComputationFailed("Non-finite result detected".to_string()));
142 }
143
144 Ok(AudioData::new(samples, signal1.sample_rate, signal1.channels))
145}
146
147pub fn divide_signals(signal1: &AudioData, signal2: &AudioData) -> Result<AudioData, SignalOpError> {
161 if signal1.samples.len() != signal2.samples.len() {
162 return Err(SignalOpError::LengthMismatch(signal1.samples.len(), signal2.samples.len()));
163 }
164 if signal1.sample_rate != signal2.sample_rate || signal1.channels != signal2.channels {
165 return Err(SignalOpError::InvalidInput(
166 format!(
167 "Metadata mismatch: expected {} Hz, {} channels; got {} Hz, {} channels",
168 signal1.sample_rate, signal1.channels, signal2.sample_rate, signal2.channels
169 )
170 ));
171 }
172
173 let samples: Vec<f32> = signal1.samples
174 .par_iter()
175 .zip(&signal2.samples)
176 .enumerate()
177 .map(|(i, (&s1, &s2))| {
178 if s2 == 0.0 {
179 eprintln!("Warning: Division by zero at index {}, clamping to 0.0", i);
180 0.0
181 } else {
182 s1 / s2
183 }
184 })
185 .collect();
186
187 if samples.iter().any(|&s| !s.is_finite()) {
188 return Err(SignalOpError::ComputationFailed("Non-finite result detected".to_string()));
189 }
190
191 Ok(AudioData::new(samples, signal1.sample_rate, signal1.channels))
192}
193
194pub fn scalar_operation(signal: &AudioData, scalar: f32, op: &str) -> Result<AudioData, SignalOpError> {
209 let samples: Vec<f32> = match op.to_lowercase().as_str() {
210 "add" => signal.samples.par_iter().map(|&s| s + scalar).collect(),
211 "subtract" => signal.samples.par_iter().map(|&s| s - scalar).collect(),
212 "multiply" => signal.samples.par_iter().map(|&s| s * scalar).collect(),
213 "divide" => {
214 if scalar == 0.0 {
215 return Err(SignalOpError::DivisionByZero(0));
216 }
217 signal.samples.par_iter().map(|&s| s / scalar).collect()
218 }
219 _ => return Err(SignalOpError::InvalidInput(format!("Unsupported operation: {}", op))),
220 };
221
222 if samples.iter().any(|&s| !s.is_finite()) {
223 return Err(SignalOpError::ComputationFailed("Non-finite result detected".to_string()));
224 }
225
226 Ok(AudioData::new(samples, signal.sample_rate, signal.channels))
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-channel, 44100 Hz fixture signal from raw samples.
    fn mono(samples: Vec<f32>) -> AudioData {
        AudioData::new(samples, 44100, 1)
    }

    // ---- mix_signals ---------------------------------------------------

    #[test]
    fn test_mix_signals_basic() {
        let quiet = mono(vec![1.0, 2.0, 3.0]);
        let loud = mono(vec![2.0, 4.0, 6.0]);
        let out = mix_signals(&[&quiet, &loud]).expect("equal-shape signals must mix");

        assert_eq!(out.samples, vec![1.5, 3.0, 4.5]);
        assert_eq!(out.sample_rate, 44100);
        assert_eq!(out.channels, 1);
    }

    #[test]
    fn test_mix_signals_empty() {
        let outcome = mix_signals(&[]);
        assert!(matches!(outcome, Err(SignalOpError::InvalidInput(_))));
    }

    #[test]
    fn test_mix_signals_length_mismatch() {
        let short = mono(vec![1.0, 2.0]);
        let long = mono(vec![2.0, 4.0, 6.0]);
        let outcome = mix_signals(&[&short, &long]);
        assert!(matches!(outcome, Err(SignalOpError::LengthMismatch(2, 3))));
    }

    // ---- subtract_signals ----------------------------------------------

    #[test]
    fn test_subtract_signals() {
        let minuend = mono(vec![2.0, 4.0, 6.0]);
        let subtrahend = mono(vec![1.0, 2.0, 3.0]);
        let diff = subtract_signals(&minuend, &subtrahend).expect("equal shapes must subtract");
        assert_eq!(diff.samples, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_subtract_signals_mismatch() {
        let minuend = mono(vec![2.0, 4.0]);
        let subtrahend = mono(vec![1.0, 2.0, 3.0]);
        let outcome = subtract_signals(&minuend, &subtrahend);
        assert!(matches!(outcome, Err(SignalOpError::LengthMismatch(2, 3))));
    }

    // ---- multiply_signals ----------------------------------------------

    #[test]
    fn test_multiply_signals() {
        let signal = mono(vec![1.0, 2.0, 3.0]);
        let gain = mono(vec![2.0, 2.0, 2.0]);
        let product = multiply_signals(&signal, &gain).expect("finite product expected");
        assert_eq!(product.samples, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_multiply_signals_overflow() {
        let signal = mono(vec![f32::MAX, 2.0]);
        let gain = mono(vec![2.0, 2.0]);
        let outcome = multiply_signals(&signal, &gain);
        assert!(matches!(outcome, Err(SignalOpError::ComputationFailed(_))));
    }

    // ---- divide_signals ------------------------------------------------

    #[test]
    fn test_divide_signals() {
        // The zero divisor in the middle must be clamped to 0.0.
        let numerator = mono(vec![4.0, 6.0, 8.0]);
        let denominator = mono(vec![2.0, 0.0, 4.0]);
        let quotient = divide_signals(&numerator, &denominator).expect("clamped division");
        assert_eq!(quotient.samples, vec![2.0, 0.0, 2.0]);
    }

    #[test]
    fn test_divide_signals_infinity() {
        let numerator = mono(vec![f32::MAX, 1.0]);
        let denominator = mono(vec![0.001, 1.0]);
        let outcome = divide_signals(&numerator, &denominator);
        assert!(matches!(outcome, Err(SignalOpError::ComputationFailed(_))));
    }

    // ---- scalar_operation ----------------------------------------------

    #[test]
    fn test_scalar_operation_add() {
        let signal = mono(vec![1.0, 2.0, 3.0]);
        let shifted = scalar_operation(&signal, 1.0, "add").expect("add is supported");
        assert_eq!(shifted.samples, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_operation_multiply() {
        let signal = mono(vec![1.0, 2.0, 3.0]);
        let scaled = scalar_operation(&signal, 2.0, "multiply").expect("multiply is supported");
        assert_eq!(scaled.samples, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scalar_operation_divide_by_zero() {
        let signal = mono(vec![1.0, 2.0]);
        let outcome = scalar_operation(&signal, 0.0, "divide");
        assert!(matches!(outcome, Err(SignalOpError::DivisionByZero(0))));
    }

    #[test]
    fn test_scalar_operation_invalid_op() {
        let signal = mono(vec![1.0, 2.0]);
        let outcome = scalar_operation(&signal, 1.0, "invalid");
        assert!(matches!(outcome, Err(SignalOpError::InvalidInput(_))));
    }
}