Skip to main content

proteus_lib/dsp/effects/
distortion.rs

1//! Distortion effect based on rodio's distortion source.
2
3use serde::{Deserialize, Serialize};
4
5use super::level::deserialize_linear_gain;
6use super::EffectContext;
7
/// Gain used by default and whenever the configured gain is not finite.
const DEFAULT_GAIN: f32 = 1.0_f32;
/// Clipping threshold used by default and whenever the configured threshold is
/// non-finite or has a magnitude at or below `f32::EPSILON`.
const DEFAULT_THRESHOLD: f32 = 1.0_f32;
10
11/// Serialized configuration for distortion parameters.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(default)]
14pub struct DistortionSettings {
15    #[serde(deserialize_with = "deserialize_linear_gain")]
16    pub gain: f32,
17    #[serde(deserialize_with = "deserialize_linear_gain")]
18    pub threshold: f32,
19}
20
21impl DistortionSettings {
22    /// Create a distortion settings payload.
23    pub fn new(gain: f32, threshold: f32) -> Self {
24        Self { gain, threshold }
25    }
26}
27
28impl Default for DistortionSettings {
29    fn default() -> Self {
30        Self {
31            gain: DEFAULT_GAIN,
32            threshold: DEFAULT_THRESHOLD,
33        }
34    }
35}
36
37/// Configured distortion effect.
38#[derive(Clone, Serialize, Deserialize)]
39#[serde(default)]
40pub struct DistortionEffect {
41    pub enabled: bool,
42    #[serde(flatten)]
43    pub settings: DistortionSettings,
44}
45
46impl std::fmt::Debug for DistortionEffect {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("DistortionEffect")
49            .field("enabled", &self.enabled)
50            .field("settings", &self.settings)
51            .finish()
52    }
53}
54
55impl Default for DistortionEffect {
56    fn default() -> Self {
57        Self {
58            enabled: false,
59            settings: DistortionSettings::default(),
60        }
61    }
62}
63
64impl DistortionEffect {
65    /// Process interleaved samples through the distortion effect.
66    ///
67    /// # Arguments
68    /// - `samples`: Interleaved input samples.
69    /// - `context`: Environment details (unused for this effect).
70    /// - `drain`: Unused for this effect.
71    ///
72    /// # Returns
73    /// Processed interleaved samples.
74    pub fn process(&mut self, samples: &[f32], _context: &EffectContext, _drain: bool) -> Vec<f32> {
75        if !self.enabled {
76            return samples.to_vec();
77        }
78
79        let gain = sanitize_gain(self.settings.gain);
80        let threshold = sanitize_threshold(self.settings.threshold);
81        if samples.is_empty() {
82            return Vec::new();
83        }
84
85        let mut out = Vec::with_capacity(samples.len());
86        for &sample in samples {
87            let v = sample * gain;
88            out.push(v.clamp(-threshold, threshold));
89        }
90
91        out
92    }
93
94    /// Reset any internal state (none for distortion).
95    pub fn reset_state(&mut self) {}
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    fn context() -> EffectContext {
        EffectContext {
            sample_rate: 44_100,
            channels: 1,
            container_path: None,
            impulse_response_spec: None,
            impulse_response_tail_db: -60.0,
        }
    }

    /// Build an enabled effect with the given raw (unsanitized) settings.
    fn enabled_effect(gain: f32, threshold: f32) -> DistortionEffect {
        DistortionEffect {
            enabled: true,
            settings: DistortionSettings::new(gain, threshold),
        }
    }

    #[test]
    fn distortion_disabled_passthrough() {
        let mut effect = DistortionEffect::default();
        let samples = vec![0.25_f32, -0.25, 0.5, -0.5];
        let output = effect.process(&samples, &context(), false);
        assert_eq!(output, samples);
    }

    #[test]
    fn distortion_clamps_output() {
        let mut effect = DistortionEffect::default();
        effect.enabled = true;
        effect.settings.gain = 2.0;
        effect.settings.threshold = 0.5;
        let samples = vec![0.4_f32, -0.4, 0.6, -0.6];
        let output = effect.process(&samples, &context(), false);
        assert_eq!(output.len(), samples.len());
        assert_eq!(output[0], 0.5);
        assert_eq!(output[1], -0.5);
        assert_eq!(output[2], 0.5);
        assert_eq!(output[3], -0.5);
    }

    #[test]
    fn distortion_empty_input_yields_empty_output() {
        let mut effect = enabled_effect(2.0, 0.5);
        assert!(effect.process(&[], &context(), false).is_empty());
    }

    #[test]
    fn distortion_non_finite_gain_falls_back_to_unity() {
        let mut effect = enabled_effect(f32::NAN, 0.5);
        let output = effect.process(&[0.25, -0.75], &context(), false);
        assert_eq!(output, vec![0.25, -0.5]);
    }

    #[test]
    fn distortion_uses_threshold_magnitude() {
        let mut effect = enabled_effect(1.0, -0.5);
        let output = effect.process(&[0.75, -0.75, 0.25], &context(), false);
        assert_eq!(output, vec![0.5, -0.5, 0.25]);
    }

    #[test]
    fn distortion_zero_threshold_falls_back_to_default() {
        let mut effect = enabled_effect(1.0, 0.0);
        let output = effect.process(&[2.0, -2.0, 0.5], &context(), false);
        assert_eq!(output, vec![DEFAULT_THRESHOLD, -DEFAULT_THRESHOLD, 0.5]);
    }

    #[test]
    fn distortion_deserializes_db_gain() {
        let json = r#"{"enabled":true,"gain":"6db","threshold":"-6db"}"#;
        let effect: DistortionEffect = serde_json::from_str(json).expect("deserialize distortion");
        assert!(effect.settings.gain > 1.0);
        assert!(effect.settings.threshold > 0.0);
    }
}
143
144fn sanitize_gain(gain: f32) -> f32 {
145    if gain.is_finite() {
146        gain
147    } else {
148        DEFAULT_GAIN
149    }
150}
151
152fn sanitize_threshold(threshold: f32) -> f32 {
153    if !threshold.is_finite() {
154        return DEFAULT_THRESHOLD;
155    }
156    let t = threshold.abs();
157    if t <= f32::EPSILON {
158        DEFAULT_THRESHOLD
159    } else {
160        t
161    }
162}