Skip to main content

proteus_lib/dsp/effects/
distortion.rs

1//! Distortion effect based on rodio's distortion source.
2
3use serde::{Deserialize, Serialize};
4
5use super::EffectContext;
6
/// Gain used when none is configured, or when the configured gain is not finite.
const DEFAULT_GAIN: f32 = 1.0_f32;
/// Clipping threshold used when none is configured, or when the configured
/// threshold is non-finite or effectively zero.
const DEFAULT_THRESHOLD: f32 = 1.0_f32;
9
10/// Serialized configuration for distortion parameters.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(default)]
13pub struct DistortionSettings {
14    pub gain: f32,
15    pub threshold: f32,
16}
17
18impl DistortionSettings {
19    /// Create a distortion settings payload.
20    pub fn new(gain: f32, threshold: f32) -> Self {
21        Self { gain, threshold }
22    }
23}
24
25impl Default for DistortionSettings {
26    fn default() -> Self {
27        Self {
28            gain: DEFAULT_GAIN,
29            threshold: DEFAULT_THRESHOLD,
30        }
31    }
32}
33
34/// Configured distortion effect.
35#[derive(Clone, Serialize, Deserialize)]
36#[serde(default)]
37pub struct DistortionEffect {
38    pub enabled: bool,
39    #[serde(flatten)]
40    pub settings: DistortionSettings,
41}
42
43impl std::fmt::Debug for DistortionEffect {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("DistortionEffect")
46            .field("enabled", &self.enabled)
47            .field("settings", &self.settings)
48            .finish()
49    }
50}
51
52impl Default for DistortionEffect {
53    fn default() -> Self {
54        Self {
55            enabled: false,
56            settings: DistortionSettings::default(),
57        }
58    }
59}
60
61impl DistortionEffect {
62    /// Process interleaved samples through the distortion effect.
63    ///
64    /// # Arguments
65    /// - `samples`: Interleaved input samples.
66    /// - `context`: Environment details (unused for this effect).
67    /// - `drain`: Unused for this effect.
68    ///
69    /// # Returns
70    /// Processed interleaved samples.
71    pub fn process(&mut self, samples: &[f32], _context: &EffectContext, _drain: bool) -> Vec<f32> {
72        if !self.enabled {
73            return samples.to_vec();
74        }
75
76        let gain = sanitize_gain(self.settings.gain);
77        let threshold = sanitize_threshold(self.settings.threshold);
78        if samples.is_empty() {
79            return Vec::new();
80        }
81
82        let mut out = Vec::with_capacity(samples.len());
83        for &sample in samples {
84            let v = sample * gain;
85            out.push(v.clamp(-threshold, threshold));
86        }
87
88        out
89    }
90
91    /// Reset any internal state (none for distortion).
92    pub fn reset_state(&mut self) {}
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal mono context; distortion ignores it, but `process` requires one.
    fn context() -> EffectContext {
        EffectContext {
            sample_rate: 44_100,
            channels: 1,
            container_path: None,
            impulse_response_spec: None,
            impulse_response_tail_db: -60.0,
        }
    }

    /// Build an enabled effect with the given raw (unsanitized) settings.
    fn enabled_effect(gain: f32, threshold: f32) -> DistortionEffect {
        DistortionEffect {
            enabled: true,
            settings: DistortionSettings::new(gain, threshold),
        }
    }

    #[test]
    fn distortion_disabled_passthrough() {
        let mut effect = DistortionEffect::default();
        let samples = vec![0.25_f32, -0.25, 0.5, -0.5];
        let output = effect.process(&samples, &context(), false);
        assert_eq!(output, samples);
    }

    #[test]
    fn distortion_clamps_output() {
        let mut effect = DistortionEffect::default();
        effect.enabled = true;
        effect.settings.gain = 2.0;
        effect.settings.threshold = 0.5;
        let samples = vec![0.4_f32, -0.4, 0.6, -0.6];
        let output = effect.process(&samples, &context(), false);
        assert_eq!(output.len(), samples.len());
        assert_eq!(output[0], 0.5);
        assert_eq!(output[1], -0.5);
        assert_eq!(output[2], 0.5);
        assert_eq!(output[3], -0.5);
    }

    #[test]
    fn distortion_empty_input_yields_empty_output() {
        let mut effect = enabled_effect(2.0, 0.5);
        assert!(effect.process(&[], &context(), true).is_empty());
    }

    #[test]
    fn distortion_non_finite_gain_falls_back_to_default() {
        let samples = [0.25_f32, -0.25];
        for gain in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let mut effect = enabled_effect(gain, 1.0);
            assert_eq!(effect.process(&samples, &context(), false), samples.to_vec());
        }
    }

    #[test]
    fn distortion_negative_threshold_uses_magnitude() {
        let mut effect = enabled_effect(1.0, -0.5);
        let output = effect.process(&[0.75_f32, -0.75, 0.25], &context(), false);
        assert_eq!(output, vec![0.5, -0.5, 0.25]);
    }

    #[test]
    fn distortion_degenerate_threshold_falls_back_to_default() {
        let samples = [2.0_f32, -2.0, 0.5];
        for threshold in [0.0_f32, f32::NAN, f32::INFINITY] {
            let mut effect = enabled_effect(1.0, threshold);
            assert_eq!(effect.process(&samples, &context(), false), vec![1.0, -1.0, 0.5]);
        }
    }
}
132
133fn sanitize_gain(gain: f32) -> f32 {
134    if gain.is_finite() {
135        gain
136    } else {
137        DEFAULT_GAIN
138    }
139}
140
141fn sanitize_threshold(threshold: f32) -> f32 {
142    if !threshold.is_finite() {
143        return DEFAULT_THRESHOLD;
144    }
145    let t = threshold.abs();
146    if t <= f32::EPSILON {
147        DEFAULT_THRESHOLD
148    } else {
149        t
150    }
151}