// dasp_rs/signal_processing/amplitude.rs

1use crate::core::io::AudioData;
2use thiserror::Error;
3
4/// Custom error types for amplitude processing operations.
5///
6/// This enum defines errors specific to amplitude manipulation functions like
7/// amplification, attenuation, and normalization.
8#[derive(Error, Debug)]
9pub enum AmplitudeError {
10    /// Error when an invalid gain factor is provided (e.g., negative or zero where not allowed).
11    #[error("Invalid gain factor: {0}")]
12    InvalidGain(String),
13
14    /// Error when the input signal is invalid (e.g., empty or zero amplitude).
15    #[error("Invalid signal: {0}")]
16    InvalidSignal(String),
17}
18
19/// Amplifies an audio signal by a specified gain factor.
20///
21/// This function increases the amplitude of the signal by multiplying each sample
22/// by the gain factor. A gain > 1.0 increases amplitude; values <= 0 are invalid.
23///
24/// # Arguments
25/// * `signal` - The input audio signal.
26/// * `gain` - The amplification factor (must be positive).
27///
28/// # Returns
29/// Returns `Result<AudioData, AmplitudeError>` containing the amplified signal or an error.
30///
31/// # Examples
32/// ```
33/// let signal = AudioData { samples: vec![0.5, 1.0, 0.5], sample_rate: 44100, channels: 1 };
34/// let amplified = amplify(&signal, 2.0)?;
35/// assert_eq!(amplified.samples, vec![1.0, 2.0, 1.0]);
36/// ```
37pub fn amplify(signal: &AudioData, gain: f32) -> Result<AudioData, AmplitudeError> {
38    if gain <= 0.0 {
39        return Err(AmplitudeError::InvalidGain(
40            "Gain must be positive".to_string(),
41        ));
42    }
43
44    let samples = signal.samples.iter().map(|&s| s * gain).collect();
45    Ok(AudioData {
46        samples,
47        sample_rate: signal.sample_rate,
48        channels: signal.channels,
49    })
50}
51
52/// Attenuates an audio signal by a specified gain factor.
53///
54/// This function decreases the amplitude of the signal by multiplying each sample
55/// by the gain factor. A gain between 0.0 and 1.0 reduces amplitude; values < 0 are invalid.
56///
57/// # Arguments
58/// * `signal` - The input audio signal.
59/// * `gain` - The attenuation factor (must be non-negative, typically < 1.0).
60///
61/// # Returns
62/// Returns `Result<AudioData, AmplitudeError>` containing the attenuated signal or an error.
63///
64/// # Examples
65/// ```
66/// let signal = AudioData { samples: vec![1.0, 2.0, 1.0], sample_rate: 44100, channels: 1 };
67/// let attenuated = attenuate(&signal, 0.5)?;
68/// assert_eq!(attenuated.samples, vec![0.5, 1.0, 0.5]);
69/// ```
70pub fn attenuate(signal: &AudioData, gain: f32) -> Result<AudioData, AmplitudeError> {
71    if gain < 0.0 {
72        return Err(AmplitudeError::InvalidGain(
73            "Gain must be non-negative".to_string(),
74        ));
75    }
76
77    let samples = signal.samples.iter().map(|&s| s * gain).collect();
78    Ok(AudioData {
79        samples,
80        sample_rate: signal.sample_rate,
81        channels: signal.channels,
82    })
83}
84
85/// Normalizes an audio signal to a target peak or RMS level.
86///
87/// This function scales the signal so its peak amplitude or RMS (root mean square)
88/// level matches the target value. Useful for ensuring consistent loudness.
89///
90/// # Arguments
91/// * `signal` - The input audio signal.
92/// * `target` - The target level (e.g., 1.0 for full scale).
93/// * `mode` - "peak" for peak normalization, "rms" for RMS normalization.
94///
95/// # Returns
96/// Returns `Result<AudioData, AmplitudeError>` containing the normalized signal or an error.
97///
98/// # Examples
99/// ```
100/// let signal = AudioData { samples: vec![0.5, 1.0, 0.5], sample_rate: 44100, channels: 1 };
101/// let normalized = normalize(&signal, 1.0, "peak")?;
102/// assert_eq!(normalized.samples, vec![0.5, 1.0, 0.5]); // Already at peak 1.0
103///
104/// let signal = AudioData { samples: vec![0.2, 0.4, 0.2], sample_rate: 44100, channels: 1 };
105/// let normalized = normalize(&signal, 1.0, "peak")?;
106/// assert_eq!(normalized.samples, vec![0.5, 1.0, 0.5]); // Scaled to peak 1.0
107/// ```
108pub fn normalize(signal: &AudioData, target: f32, mode: &str) -> Result<AudioData, AmplitudeError> {
109    if target <= 0.0 {
110        return Err(AmplitudeError::InvalidGain(
111            "Target level must be positive".to_string(),
112        ));
113    }
114    if signal.samples.is_empty() {
115        return Err(AmplitudeError::InvalidSignal(
116            "Signal cannot be empty".to_string(),
117        ));
118    }
119
120    let gain = match mode.to_lowercase().as_str() {
121        "peak" => {
122            let max_amplitude = signal
123                .samples
124                .iter()
125                .map(|s| s.abs())
126                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
127                .unwrap_or(0.0);
128            if max_amplitude == 0.0 {
129                return Err(AmplitudeError::InvalidSignal(
130                    "Signal has no amplitude to normalize".to_string(),
131                ));
132            }
133            target / max_amplitude
134        }
135        "rms" => {
136            let rms = (signal.samples.iter().map(|&s| s * s).sum::<f32>() / signal.samples.len() as f32)
137                .sqrt();
138            if rms == 0.0 {
139                return Err(AmplitudeError::InvalidSignal(
140                    "Signal has no RMS level to normalize".to_string(),
141                ));
142            }
143            target / rms
144        }
145        _ => {
146            return Err(AmplitudeError::InvalidGain(format!(
147                "Unknown normalization mode: {}",
148                mode
149            )))
150        }
151    };
152
153    let samples = signal.samples.iter().map(|&s| s * gain).collect();
154    Ok(AudioData {
155        samples,
156        sample_rate: signal.sample_rate,
157        channels: signal.channels,
158    })
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mono 44.1 kHz signal from the given samples.
    fn mono(samples: Vec<f32>) -> AudioData {
        AudioData {
            samples,
            sample_rate: 44100,
            channels: 1,
        }
    }

    #[test]
    fn test_amplify() {
        let signal = mono(vec![0.5, 1.0, 0.5]);
        let amplified = amplify(&signal, 2.0).unwrap();
        assert_eq!(amplified.samples, vec![1.0, 2.0, 1.0]);

        let result = amplify(&signal, 0.0);
        assert!(matches!(result, Err(AmplitudeError::InvalidGain(_))));

        let result = amplify(&signal, -1.0);
        assert!(matches!(result, Err(AmplitudeError::InvalidGain(_))));
    }

    #[test]
    fn test_amplify_preserves_metadata() {
        let signal = AudioData {
            samples: vec![0.25, -0.25],
            sample_rate: 48000,
            channels: 2,
        };
        let amplified = amplify(&signal, 2.0).unwrap();
        assert_eq!(amplified.samples, vec![0.5, -0.5]);
        assert_eq!(amplified.sample_rate, signal.sample_rate);
        assert_eq!(amplified.channels, signal.channels);
    }

    #[test]
    fn test_attenuate() {
        let signal = mono(vec![1.0, 2.0, 1.0]);
        let attenuated = attenuate(&signal, 0.5).unwrap();
        assert_eq!(attenuated.samples, vec![0.5, 1.0, 0.5]);

        // Zero gain is allowed and silences the signal.
        let silenced = attenuate(&signal, 0.0).unwrap();
        assert_eq!(silenced.samples, vec![0.0, 0.0, 0.0]);

        let result = attenuate(&signal, -1.0);
        assert!(matches!(result, Err(AmplitudeError::InvalidGain(_))));
    }

    #[test]
    fn test_normalize_peak() {
        let signal = mono(vec![0.2, 0.4, 0.2]);
        let normalized = normalize(&signal, 1.0, "peak").unwrap();
        assert_eq!(normalized.samples, vec![0.5, 1.0, 0.5]);

        let silent = mono(vec![0.0, 0.0]);
        let result = normalize(&silent, 1.0, "peak");
        assert!(matches!(result, Err(AmplitudeError::InvalidSignal(_))));

        let empty = mono(vec![]);
        let result = normalize(&empty, 1.0, "peak");
        assert!(matches!(result, Err(AmplitudeError::InvalidSignal(_))));
    }

    #[test]
    fn test_normalize_peak_uses_absolute_value() {
        let signal = mono(vec![-0.5, 0.25]);
        let normalized = normalize(&signal, 1.0, "peak").unwrap();
        assert_eq!(normalized.samples, vec![-1.0, 0.5]);
    }

    #[test]
    fn test_normalize_mode_is_case_insensitive() {
        let signal = mono(vec![0.2, 0.4, 0.2]);
        let normalized = normalize(&signal, 1.0, "Peak").unwrap();
        assert_eq!(normalized.samples, vec![0.5, 1.0, 0.5]);
    }

    #[test]
    fn test_normalize_rms() {
        let signal = mono(vec![1.0, 1.0]);
        let normalized = normalize(&signal, 0.5, "rms").unwrap();
        assert_eq!(normalized.samples, vec![0.5, 0.5]);

        let empty = mono(vec![]);
        let result = normalize(&empty, 1.0, "rms");
        assert!(matches!(result, Err(AmplitudeError::InvalidSignal(_))));
    }

    #[test]
    fn test_normalize_rejects_invalid_target_and_mode() {
        let signal = mono(vec![0.5, 1.0]);
        assert!(matches!(
            normalize(&signal, 0.0, "peak"),
            Err(AmplitudeError::InvalidGain(_))
        ));
        assert!(matches!(
            normalize(&signal, -1.0, "rms"),
            Err(AmplitudeError::InvalidGain(_))
        ));
        assert!(matches!(
            normalize(&signal, 1.0, "lufs"),
            Err(AmplitudeError::InvalidGain(_))
        ));
    }
}