Skip to main content

firewheel_nodes/
fast_rms.rs

1use bevy_platform::sync::atomic::{AtomicU32, Ordering};
2use firewheel_core::{
3    atomic_float::AtomicF32,
4    channel_config::{ChannelConfig, ChannelCount},
5    collector::ArcGc,
6    diff::{Diff, Patch},
7    dsp::volume::amp_to_db,
8    event::ProcEvents,
9    node::{
10        AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, EmptyConfig,
11        ProcBuffers, ProcExtra, ProcInfo, ProcStreamCtx, ProcessStatus,
12    },
13    StreamInfo,
14};
15
16#[cfg(not(feature = "std"))]
17use num_traits::Float;
18
/// A lightweight node that measures the loudness of a mono signal using a rough RMS
/// (root mean square) estimate.
///
/// Note that this node doesn't calculate the true RMS, which would require a much more
/// expensive sliding-window algorithm. Instead, the squared input samples are averaged
/// over consecutive, non-overlapping windows of `window_size_secs`, and one estimate is
/// produced per window. This should be good enough for games that simply wish to react
/// to player audio.
///
/// The node has one mono input and no outputs. The measured value is read through the
/// [`FastRmsState`] attached to the node as its custom state.
#[derive(Debug, Diff, Patch, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FastRmsNode {
    /// Whether or not this node is enabled.
    ///
    /// While disabled, the processor discards its partial measurement and reports
    /// silence.
    pub enabled: bool,
    /// The size of the window used for measuring the RMS value, in seconds.
    ///
    /// Smaller values are better at detecting short bursts of loudness (transients),
    /// while larger values are better for measuring loudness on a broader time scale.
    /// The value is rounded to a whole number of frames at the stream's sample rate.
    ///
    /// By default this is set to `0.05` (50ms).
    pub window_size_secs: f32,
}
40
41impl Default for FastRmsNode {
42    fn default() -> Self {
43        Self {
44            enabled: true,
45            window_size_secs: 50.0 / 1_000.0,
46        }
47    }
48}
49
/// The state of a [`FastRmsNode`], used to read the RMS value calculated by the node.
///
/// Clones share the same underlying state, so a clone can be handed to whatever code
/// needs to poll the loudness.
#[derive(Clone)]
pub struct FastRmsState {
    // Shared with the node's processor, which writes the RMS value into it.
    shared_state: ArcGc<SharedState>,
}
55
56impl FastRmsState {
57    fn new() -> Self {
58        Self {
59            shared_state: ArcGc::new(SharedState {
60                rms_value: AtomicF32::new(0.0),
61                read_count: AtomicU32::new(1),
62            }),
63        }
64    }
65
66    /// Get the estimated RMS value in decibels.
67    ///
68    /// * `db_epsilon` - If the RMS value is less than or equal to this value, then it
69    /// will be clamped to `f32::NEG_INFINITY` (silence). (You can use
70    /// [firewheel_core::dsp::volume::DEFAULT_DB_EPSILON].)
71    ///
72    /// If the node is currently disabled, then this will return a value
73    /// of `f32::NEG_INFINITY` (silence).
74    ///
75    /// Note this node doesn't calculate the true RMS (That requires a much more expensive
76    /// algorithm using a sliding window.) But it should be good enough for games that
77    /// simply wish to react to player audio.
78    pub fn rms_db(&self, db_epsilon: f32) -> f32 {
79        let rms = amp_to_db(self.shared_state.rms_value.load(Ordering::Relaxed));
80        self.shared_state.read_count.fetch_add(1, Ordering::Relaxed);
81
82        if rms <= db_epsilon {
83            f32::NEG_INFINITY
84        } else {
85            rms
86        }
87    }
88}
89
90impl AudioNode for FastRmsNode {
91    type Configuration = EmptyConfig;
92
93    fn info(&self, _config: &Self::Configuration) -> AudioNodeInfo {
94        AudioNodeInfo::new()
95            .debug_name("fast_rms")
96            .channel_config(ChannelConfig {
97                num_inputs: ChannelCount::MONO,
98                num_outputs: ChannelCount::ZERO,
99            })
100            .custom_state(FastRmsState::new())
101    }
102
103    fn construct_processor(
104        &self,
105        _config: &Self::Configuration,
106        cx: ConstructProcessorContext,
107    ) -> impl AudioNodeProcessor {
108        let window_frames =
109            (self.window_size_secs * cx.stream_info.sample_rate.get() as f32).round() as usize;
110
111        let custom_state = cx.custom_state::<FastRmsState>().unwrap();
112
113        Processor {
114            params: self.clone(),
115            shared_state: ArcGc::clone(&custom_state.shared_state),
116            squares: 0.0,
117            num_squared_values: 0,
118            window_frames,
119            last_read_count: 0,
120        }
121    }
122}
123
/// Realtime processor for [`FastRmsNode`].
///
/// Accumulates squared input samples over a window of `window_frames` samples, then
/// publishes the resulting RMS estimate into `shared_state`.
struct Processor {
    /// Current parameters, updated by applying incoming patches.
    params: FastRmsNode,
    /// State shared with [`FastRmsState`]; the processor writes RMS values into it.
    shared_state: ArcGc<SharedState>,
    /// Running sum of squared samples for the current window.
    squares: f32,
    /// Number of samples accumulated into `squares` so far in the current window.
    num_squared_values: usize,
    /// Length of one measurement window, in frames.
    window_frames: usize,
    /// Value of `SharedState::read_count` seen when the previous window completed.
    last_read_count: u32,
}
132
133impl AudioNodeProcessor for Processor {
134    fn process(
135        &mut self,
136        info: &ProcInfo,
137        buffers: ProcBuffers,
138        events: &mut ProcEvents,
139        _extra: &mut ProcExtra,
140    ) -> ProcessStatus {
141        for patch in events.drain_patches::<FastRmsNode>() {
142            match patch {
143                FastRmsNodePatch::WindowSizeSecs(window_size_secs) => {
144                    let window_frames =
145                        (window_size_secs * info.sample_rate.get() as f32).round() as usize;
146
147                    if self.window_frames != window_frames {
148                        self.window_frames = window_frames;
149
150                        self.squares = 0.0;
151                        self.num_squared_values = 0;
152                    }
153                }
154                _ => {}
155            }
156
157            self.params.apply(patch);
158        }
159
160        if !self.params.enabled {
161            self.shared_state.rms_value.store(0.0, Ordering::Relaxed);
162
163            self.squares = 0.0;
164            self.num_squared_values = 0;
165
166            return ProcessStatus::Bypass;
167        }
168
169        let mut frames_processed = 0;
170        while frames_processed < info.frames {
171            let process_frames =
172                (info.frames - frames_processed).min(self.window_frames - self.num_squared_values);
173
174            if !info.in_silence_mask.is_channel_silent(0) {
175                for &s in
176                    buffers.inputs[0][frames_processed..frames_processed + process_frames].iter()
177                {
178                    self.squares += s * s;
179                }
180            }
181
182            self.num_squared_values += process_frames;
183            frames_processed += process_frames;
184
185            if self.num_squared_values == self.window_frames {
186                let mean = self.squares / self.window_frames as f32;
187                let rms = mean.sqrt();
188
189                let latest_read_count = self.shared_state.read_count.load(Ordering::Relaxed);
190                let previous_rms = self.shared_state.rms_value.load(Ordering::Relaxed);
191
192                if latest_read_count != self.last_read_count || rms > previous_rms {
193                    self.shared_state.rms_value.store(rms, Ordering::Relaxed);
194                }
195
196                self.squares = 0.0;
197                self.num_squared_values = 0;
198                self.last_read_count = latest_read_count;
199            }
200        }
201
202        // There are no outputs in this node.
203        ProcessStatus::Bypass
204    }
205
206    fn new_stream(&mut self, stream_info: &StreamInfo, _context: &mut ProcStreamCtx) {
207        self.window_frames =
208            (self.params.window_size_secs * stream_info.sample_rate.get() as f32).round() as usize;
209
210        self.squares = 0.0;
211        self.num_squared_values = 0;
212    }
213}
214
/// State shared between [`FastRmsState`] (the reader) and the node's processor (the
/// writer).
#[derive(Debug)]
struct SharedState {
    /// Most recently published RMS estimate, as a linear amplitude (not decibels);
    /// `FastRmsState::rms_db` converts it with `amp_to_db`.
    rms_value: AtomicF32,
    // A simple counter used to keep track of when the processor should update
    // the RMS value. It is incremented on every `FastRmsState::rms_db` call. If the
    // count is unchanged since the processor's last completed window, the processor
    // only overwrites `rms_value` with a louder value (peak hold between reads).
    read_count: AtomicU32,
}