kalosm_sound/transform/
voice_audio_detector_ext.rs

1use futures_core::ready;
2use rodio::buffer::SamplesBuffer;
3use std::{collections::VecDeque, task::Poll, time::Duration};
4
5/// The output of a [`crate::VoiceActivityDetectorStream`]
6pub struct VoiceActivityDetectorOutput {
7    /// The probability of voice activity (between 0 and 1)
8    pub probability: f32,
9    /// The audio sample associated with the voice activity probability
10    pub samples: rodio::buffer::SamplesBuffer<f32>,
11}
12
13/// An extension trait for audio streams with voice activity detection information
14pub trait VoiceActivityStreamExt: futures_core::Stream<Item = VoiceActivityDetectorOutput> {
15    /// Only keep audio chunks that have a probability of voice activity above the given threshold
16    fn filter_voice_activity(self, threshold: f32) -> VoiceActivityFilterStream<Self>
17    where
18        Self: Sized + Unpin,
19    {
20        VoiceActivityFilterStream::new(self, threshold)
21    }
22
23    /// Rechunk the audio into chunks of audio with a rolling average over the given duration more than the given threshold
24    fn rechunk_voice_activity(self) -> VoiceActivityRechunkerStream<Self>
25    where
26        Self: Sized + Unpin,
27    {
28        let start_threshold = 0.6;
29        let start_window = Duration::from_millis(100);
30        let end_threshold = 0.2;
31        let end_window = Duration::from_millis(2000);
32        let time_before_speech = Duration::from_millis(500);
33        VoiceActivityRechunkerStream::new(
34            self,
35            start_threshold,
36            start_window,
37            end_threshold,
38            end_window,
39            time_before_speech,
40        )
41    }
42}
43
44impl<S: futures_core::Stream<Item = VoiceActivityDetectorOutput>> VoiceActivityStreamExt for S {}
45
/// A stream that forwards only the audio chunks whose voice activity probability is above a
/// given threshold
pub struct VoiceActivityFilterStream<S> {
    /// The wrapped stream the chunks are read from
    source: S,
    /// Chunks must score strictly above this probability to be forwarded
    threshold: f32,
}

impl<S> VoiceActivityFilterStream<S> {
    /// Wrap `source` so that only chunks scoring above `threshold` are yielded
    fn new(source: S, threshold: f32) -> Self {
        VoiceActivityFilterStream { threshold, source }
    }
}
57
58impl<S: futures_core::Stream<Item = VoiceActivityDetectorOutput> + Unpin> futures_core::Stream
59    for VoiceActivityFilterStream<S>
60{
61    type Item = SamplesBuffer<f32>;
62
63    fn poll_next(
64        self: std::pin::Pin<&mut Self>,
65        cx: &mut std::task::Context<'_>,
66    ) -> std::task::Poll<Option<Self::Item>> {
67        let this = self.get_mut();
68        let mut source = std::pin::pin!(&mut this.source);
69        loop {
70            let next = ready!(source.as_mut().poll_next(cx));
71            if let Some(next) = next {
72                if next.probability > this.threshold {
73                    return Poll::Ready(Some(next.samples));
74                }
75            } else {
76                return Poll::Ready(None);
77            }
78        }
79    }
80}
81
82/// A stream of audio chunks with a voice activity probability rolling average above a given threshold
83pub struct VoiceActivityRechunkerStream<S> {
84    source: S,
85    start_threshold: f32,
86    start_window: Duration,
87    end_threshold: f32,
88    end_window: Duration,
89    include_duration_before: Duration,
90    duration_before_window: Duration,
91    in_voice_run: bool,
92    buffer: VecDeque<SamplesBuffer<f32>>,
93    channels: u16,
94    sample_rate: u32,
95    voice_probabilities_window: VecDeque<(f32, Duration)>,
96    duration_in_window: Duration,
97    sum: f32,
98}
99
100impl<S> VoiceActivityRechunkerStream<S> {
101    /// Set the threshold for the start of a voice activity run
102    pub fn with_start_threshold(mut self, start_threshold: f32) -> Self {
103        self.start_threshold = start_threshold;
104        self
105    }
106
107    /// Set the window for the start of a voice activity run
108    pub fn with_start_window(mut self, start_window: Duration) -> Self {
109        self.start_window = start_window;
110        self
111    }
112
113    /// Set the threshold for the end of a voice activity run
114    pub fn with_end_threshold(mut self, end_threshold: f32) -> Self {
115        self.end_threshold = end_threshold;
116        self
117    }
118
119    /// Set the window for the end of a voice activity run
120    pub fn with_end_window(mut self, end_window: Duration) -> Self {
121        self.end_window = end_window;
122        self
123    }
124
125    /// Set the time before the speech run starts to include in the output
126    pub fn with_time_before_speech(mut self, time_before_speech: Duration) -> Self {
127        self.include_duration_before = time_before_speech;
128        self
129    }
130}
131
132impl<S> VoiceActivityRechunkerStream<S> {
133    fn new(
134        source: S,
135        start_threshold: f32,
136        start_window: Duration,
137        end_threshold: f32,
138        end_window: Duration,
139        include_duration_before: Duration,
140    ) -> Self {
141        Self {
142            source,
143            start_threshold,
144            start_window,
145            end_threshold,
146            end_window,
147            include_duration_before,
148            duration_before_window: Duration::ZERO,
149            in_voice_run: false,
150            buffer: VecDeque::new(),
151            channels: 1,
152            sample_rate: 0,
153            voice_probabilities_window: VecDeque::new(),
154            duration_in_window: Duration::ZERO,
155            sum: 0.0,
156        }
157    }
158
159    fn add_sample(&mut self, probability: f32, len: Duration, window: Duration) {
160        // Add the samples to the rolling average
161        self.voice_probabilities_window
162            .push_front((probability, len));
163        self.sum += probability;
164        self.duration_in_window += len;
165        // If the buffer is full, remove the first probability from the rolling average
166        while self.duration_in_window > window {
167            self.pop_last_sample();
168        }
169    }
170
171    fn pop_last_sample(&mut self) {
172        let (probability, len) = self.voice_probabilities_window.pop_back().unwrap();
173        self.sum -= probability;
174        self.duration_in_window -= len;
175    }
176
177    fn rolling_average(&self) -> f32 {
178        self.sum / self.voice_probabilities_window.len() as f32
179    }
180
181    fn finish_voice_run(&mut self) -> SamplesBuffer<f32> {
182        let samples = SamplesBuffer::new(
183            self.channels,
184            self.sample_rate,
185            std::mem::take(&mut self.buffer)
186                .into_iter()
187                .flatten()
188                .collect::<Vec<_>>(),
189        );
190        self.sum = 0.0;
191        self.duration_in_window = Duration::ZERO;
192        self.voice_probabilities_window.clear();
193        self.in_voice_run = false;
194        self.duration_before_window = Duration::ZERO;
195        self.buffer.clear();
196        samples
197    }
198}
199
200impl<S: futures_core::Stream<Item = VoiceActivityDetectorOutput> + Unpin> futures_core::Stream
201    for VoiceActivityRechunkerStream<S>
202{
203    type Item = SamplesBuffer<f32>;
204
205    fn poll_next(
206        self: std::pin::Pin<&mut Self>,
207        cx: &mut std::task::Context<'_>,
208    ) -> std::task::Poll<Option<Self::Item>> {
209        let this = self.get_mut();
210        loop {
211            let source = std::pin::pin!(&mut this.source);
212            let next = ready!(source.poll_next(cx));
213            if let Some(next) = next {
214                // Set the sample rate from the stream
215                this.sample_rate = rodio::Source::sample_rate(&next.samples);
216                let sample_duration = rodio::Source::total_duration(&next.samples)
217                    .expect("samples must have a duration");
218                let window = if this.in_voice_run {
219                    this.end_window
220                } else {
221                    this.start_window
222                };
223                this.add_sample(next.probability, sample_duration, window);
224                // If we are inside a chunk that looks like voice, set the in voice run flag
225                if this.rolling_average() > this.start_threshold {
226                    this.in_voice_run = true;
227                }
228                // Add the samples to the buffer
229                this.buffer.push_back(next.samples);
230                // If this is inside a voice run, add the sample to the buffer
231                if this.in_voice_run {
232                    // Otherwise, if we just left a chunk that looks like voice, add the buffer to the output
233                    if this.rolling_average() < this.end_threshold {
234                        let samples = this.finish_voice_run();
235                        return Poll::Ready(Some(samples));
236                    }
237                } else {
238                    // Otherwise, add it to the pre-voice buffer
239                    this.duration_before_window += sample_duration;
240                    // If the pre-voice buffer is full, remove the first sample from it
241                    while this.duration_before_window >= this.include_duration_before {
242                        let sample = this.buffer.pop_front().unwrap();
243                        this.duration_before_window -= rodio::Source::total_duration(&sample)
244                            .expect("samples must have a duration");
245                    }
246                }
247            } else {
248                // Finish off the current voice run if there is one
249                if this.in_voice_run {
250                    let samples = this.finish_voice_run();
251                    return Poll::Ready(Some(samples));
252                }
253                // Otherwise, return None and finish the stream
254                return Poll::Ready(None);
255            }
256        }
257    }
258}