kalosm_sound/transform/
voice_audio_detector_ext.rs1use futures_core::ready;
2use rodio::buffer::SamplesBuffer;
3use std::{collections::VecDeque, task::Poll, time::Duration};
4
5pub struct VoiceActivityDetectorOutput {
7 pub probability: f32,
9 pub samples: rodio::buffer::SamplesBuffer<f32>,
11}
12
13pub trait VoiceActivityStreamExt: futures_core::Stream<Item = VoiceActivityDetectorOutput> {
15 fn filter_voice_activity(self, threshold: f32) -> VoiceActivityFilterStream<Self>
17 where
18 Self: Sized + Unpin,
19 {
20 VoiceActivityFilterStream::new(self, threshold)
21 }
22
23 fn rechunk_voice_activity(self) -> VoiceActivityRechunkerStream<Self>
25 where
26 Self: Sized + Unpin,
27 {
28 let start_threshold = 0.6;
29 let start_window = Duration::from_millis(100);
30 let end_threshold = 0.2;
31 let end_window = Duration::from_millis(2000);
32 let time_before_speech = Duration::from_millis(500);
33 VoiceActivityRechunkerStream::new(
34 self,
35 start_threshold,
36 start_window,
37 end_threshold,
38 end_window,
39 time_before_speech,
40 )
41 }
42}
43
44impl<S: futures_core::Stream<Item = VoiceActivityDetectorOutput>> VoiceActivityStreamExt for S {}
45
/// Stream adapter that forwards only the chunks whose voice-activity score is
/// strictly above a fixed threshold.
pub struct VoiceActivityFilterStream<S> {
    /// The wrapped stream of scored chunks.
    source: S,
    /// Chunks must score strictly above this value to be yielded.
    threshold: f32,
}
51
52impl<S> VoiceActivityFilterStream<S> {
53 fn new(source: S, threshold: f32) -> Self {
54 Self { source, threshold }
55 }
56}
57
58impl<S: futures_core::Stream<Item = VoiceActivityDetectorOutput> + Unpin> futures_core::Stream
59 for VoiceActivityFilterStream<S>
60{
61 type Item = SamplesBuffer<f32>;
62
63 fn poll_next(
64 self: std::pin::Pin<&mut Self>,
65 cx: &mut std::task::Context<'_>,
66 ) -> std::task::Poll<Option<Self::Item>> {
67 let this = self.get_mut();
68 let mut source = std::pin::pin!(&mut this.source);
69 loop {
70 let next = ready!(source.as_mut().poll_next(cx));
71 if let Some(next) = next {
72 if next.probability > this.threshold {
73 return Poll::Ready(Some(next.samples));
74 }
75 } else {
76 return Poll::Ready(None);
77 }
78 }
79 }
80}
81
82pub struct VoiceActivityRechunkerStream<S> {
84 source: S,
85 start_threshold: f32,
86 start_window: Duration,
87 end_threshold: f32,
88 end_window: Duration,
89 include_duration_before: Duration,
90 duration_before_window: Duration,
91 in_voice_run: bool,
92 buffer: VecDeque<SamplesBuffer<f32>>,
93 channels: u16,
94 sample_rate: u32,
95 voice_probabilities_window: VecDeque<(f32, Duration)>,
96 duration_in_window: Duration,
97 sum: f32,
98}
99
100impl<S> VoiceActivityRechunkerStream<S> {
101 pub fn with_start_threshold(mut self, start_threshold: f32) -> Self {
103 self.start_threshold = start_threshold;
104 self
105 }
106
107 pub fn with_start_window(mut self, start_window: Duration) -> Self {
109 self.start_window = start_window;
110 self
111 }
112
113 pub fn with_end_threshold(mut self, end_threshold: f32) -> Self {
115 self.end_threshold = end_threshold;
116 self
117 }
118
119 pub fn with_end_window(mut self, end_window: Duration) -> Self {
121 self.end_window = end_window;
122 self
123 }
124
125 pub fn with_time_before_speech(mut self, time_before_speech: Duration) -> Self {
127 self.include_duration_before = time_before_speech;
128 self
129 }
130}
131
132impl<S> VoiceActivityRechunkerStream<S> {
133 fn new(
134 source: S,
135 start_threshold: f32,
136 start_window: Duration,
137 end_threshold: f32,
138 end_window: Duration,
139 include_duration_before: Duration,
140 ) -> Self {
141 Self {
142 source,
143 start_threshold,
144 start_window,
145 end_threshold,
146 end_window,
147 include_duration_before,
148 duration_before_window: Duration::ZERO,
149 in_voice_run: false,
150 buffer: VecDeque::new(),
151 channels: 1,
152 sample_rate: 0,
153 voice_probabilities_window: VecDeque::new(),
154 duration_in_window: Duration::ZERO,
155 sum: 0.0,
156 }
157 }
158
159 fn add_sample(&mut self, probability: f32, len: Duration, window: Duration) {
160 self.voice_probabilities_window
162 .push_front((probability, len));
163 self.sum += probability;
164 self.duration_in_window += len;
165 while self.duration_in_window > window {
167 self.pop_last_sample();
168 }
169 }
170
171 fn pop_last_sample(&mut self) {
172 let (probability, len) = self.voice_probabilities_window.pop_back().unwrap();
173 self.sum -= probability;
174 self.duration_in_window -= len;
175 }
176
177 fn rolling_average(&self) -> f32 {
178 self.sum / self.voice_probabilities_window.len() as f32
179 }
180
181 fn finish_voice_run(&mut self) -> SamplesBuffer<f32> {
182 let samples = SamplesBuffer::new(
183 self.channels,
184 self.sample_rate,
185 std::mem::take(&mut self.buffer)
186 .into_iter()
187 .flatten()
188 .collect::<Vec<_>>(),
189 );
190 self.sum = 0.0;
191 self.duration_in_window = Duration::ZERO;
192 self.voice_probabilities_window.clear();
193 self.in_voice_run = false;
194 self.duration_before_window = Duration::ZERO;
195 self.buffer.clear();
196 samples
197 }
198}
199
200impl<S: futures_core::Stream<Item = VoiceActivityDetectorOutput> + Unpin> futures_core::Stream
201 for VoiceActivityRechunkerStream<S>
202{
203 type Item = SamplesBuffer<f32>;
204
205 fn poll_next(
206 self: std::pin::Pin<&mut Self>,
207 cx: &mut std::task::Context<'_>,
208 ) -> std::task::Poll<Option<Self::Item>> {
209 let this = self.get_mut();
210 loop {
211 let source = std::pin::pin!(&mut this.source);
212 let next = ready!(source.poll_next(cx));
213 if let Some(next) = next {
214 this.sample_rate = rodio::Source::sample_rate(&next.samples);
216 let sample_duration = rodio::Source::total_duration(&next.samples)
217 .expect("samples must have a duration");
218 let window = if this.in_voice_run {
219 this.end_window
220 } else {
221 this.start_window
222 };
223 this.add_sample(next.probability, sample_duration, window);
224 if this.rolling_average() > this.start_threshold {
226 this.in_voice_run = true;
227 }
228 this.buffer.push_back(next.samples);
230 if this.in_voice_run {
232 if this.rolling_average() < this.end_threshold {
234 let samples = this.finish_voice_run();
235 return Poll::Ready(Some(samples));
236 }
237 } else {
238 this.duration_before_window += sample_duration;
240 while this.duration_before_window >= this.include_duration_before {
242 let sample = this.buffer.pop_front().unwrap();
243 this.duration_before_window -= rodio::Source::total_duration(&sample)
244 .expect("samples must have a duration");
245 }
246 }
247 } else {
248 if this.in_voice_run {
250 let samples = this.finish_voice_run();
251 return Poll::Ready(Some(samples));
252 }
253 return Poll::Ready(None);
255 }
256 }
257 }
258}