// vad_marine.rs — (web-extraction header removed; original path garbled, likely src/vad_marine.rs)

1// VAD with Marine Algorithm - "Semper Fi to voice detection!" 🎖️
2// Voice Activity Detection using MEM8's marine salience algorithm
3// "Standing watch at the boundaries of speech!" - Hue
4
5use anyhow::Result;
6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// Voice Activity Detector using Marine algorithm
/// Detects when someone is speaking vs silence
///
/// All state is held behind `Arc<RwLock<...>>` so a single detector can be
/// shared across async tasks; `process_audio` is the main entry point.
pub struct MarineVAD {
    /// Marine detector state (thresholds, salience, tick bookkeeping)
    detector: Arc<RwLock<MarineDetectorState>>,

    /// Audio input monitoring (RMS/peak levels, noise floor, SNR)
    audio_monitor: Arc<RwLock<AudioMonitor>>,

    /// Current VAD state: true while voice is considered active
    is_voice_active: Arc<RwLock<bool>>,

    /// Callback invoked on voice-state transitions (see `StateCallback`)
    state_callback: StateCallback,
}
26
/// Optional shared callback fired with the new state (true = voice active)
/// whenever `process_audio` observes a voice-activity transition.
type StateCallback = Arc<RwLock<Option<Box<dyn Fn(bool) + Send + Sync>>>>;
28
/// Marine detector state for VAD
///
/// Mutable state for the marine-salience evaluation, owned behind a lock by
/// `MarineVAD` and updated on each evaluation tick.
struct MarineDetectorState {
    /// Clip threshold for voice detection (dB); frames quieter than this
    /// only decay the salience score instead of being analyzed
    voice_threshold: f64,

    /// Grid tick rate (Hz) - how often we evaluate
    tick_rate: f64,

    /// Peak history for voice pattern analysis
    /// NOTE(review): never pushed to in this file — presumably filled by a
    /// peak-picking stage elsewhere; confirm before relying on it
    peak_history: VecDeque<PeakEvent>,

    /// Period tracking for speech patterns
    /// NOTE(review): not updated by the visible code
    period_ema: ExponentialMovingAverage,

    /// Amplitude tracking for voice energy (EMA of per-frame RMS)
    amplitude_ema: ExponentialMovingAverage,

    /// Speech pattern detector (ZCR-based heuristic in this file)
    speech_detector: SpeechPatternDetector,

    /// Current salience score (0.0 to 1.0), smoothed across ticks
    voice_salience: f64,

    /// Last evaluation time, used to rate-limit evaluation to `tick_rate`
    last_tick: Instant,

    /// Time of the most recent voice onset (None until first onset)
    voice_onset: Option<Instant>,

    /// Time of the most recent voice offset (None until first offset)
    voice_offset: Option<Instant>,
}
61
/// Peak event in audio signal
///
/// NOTE(review): not constructed anywhere in the visible code — presumably
/// produced by a peak-detection stage elsewhere; verify against callers.
#[derive(Clone, Debug)]
struct PeakEvent {
    timestamp: Instant,
    amplitude: f64,
    frequency: f64,  // Estimated frequency
    is_voiced: bool, // Voiced vs unvoiced
}
70
/// Exponentially-weighted moving average used to smooth noisy measurements.
struct ExponentialMovingAverage {
    value: f64,
    alpha: f64, // Smoothing factor
}

impl ExponentialMovingAverage {
    /// Build an EMA with smoothing factor `alpha`
    /// (0.0 ignores new samples entirely, 1.0 tracks them exactly).
    fn new(alpha: f64) -> Self {
        Self { value: 0.0, alpha }
    }

    /// Fold `sample` into the running average and return the new value.
    fn update(&mut self, sample: f64) -> f64 {
        let retained = (1.0 - self.alpha) * self.value;
        let blended = self.alpha * sample + retained;
        self.value = blended;
        blended
    }

    /// Absolute deviation of `sample` from the current average.
    fn jitter(&self, sample: f64) -> f64 {
        (sample - self.value).abs()
    }
}
91
/// Speech pattern detector
///
/// NOTE(review): only `voice_quality.zero_crossing_rate` and the F0 bounds
/// are used by the visible `analyze`; the formant and syllable trackers are
/// configured but never consulted here — presumably for a later stage.
struct SpeechPatternDetector {
    /// Typical speech fundamental frequency range (Hz)
    f0_min: f64, // ~80 Hz for deep male voice
    f0_max: f64, // ~400 Hz for high female/child voice

    /// Formant tracking
    formant_tracker: FormantTracker,

    /// Syllable rate detector (2-7 Hz typical)
    syllable_detector: SyllableRateDetector,

    /// Voice quality metrics (exposed via `VoiceQualityReport`)
    voice_quality: VoiceQuality,
}
107
/// Formant tracker for vowel detection
///
/// NOTE(review): the ranges are set in `SpeechPatternDetector::new` but not
/// read by any visible code — confirm they are used elsewhere.
struct FormantTracker {
    f1_range: (f64, f64), // First formant range (200-1000 Hz)
    f2_range: (f64, f64), // Second formant range (500-2500 Hz)
    f3_range: (f64, f64), // Third formant range (1500-3500 Hz)
}
114
/// Syllable rate detector
///
/// NOTE(review): initialized but not consulted by the visible `analyze` —
/// presumably driven by a later processing stage.
struct SyllableRateDetector {
    energy_envelope: VecDeque<f64>,
    peak_times: VecDeque<Instant>,
    min_syllable_gap: Duration, // ~100ms minimum
    max_syllable_gap: Duration, // ~500ms maximum
}
122
/// Voice quality metrics
///
/// Only `zero_crossing_rate` is updated by the visible code; the other
/// fields stay at their initial 0.0 until populated elsewhere.
struct VoiceQuality {
    harmonicity: f64,        // Harmonic-to-noise ratio
    spectral_tilt: f64,      // High vs low frequency energy
    zero_crossing_rate: f64, // Voiced vs unvoiced
    energy_variance: f64,    // Speech dynamics
}
130
/// Audio input monitor
///
/// Tracks per-frame levels and a slowly-adapting noise floor, from which an
/// SNR estimate is derived for the salience computation.
struct AudioMonitor {
    /// Current audio level (RMS of the most recent frame)
    current_level: f64,

    /// Peak absolute sample value in the most recent frame
    peak_level: f64,

    /// Noise floor estimate in dB (adapts at 1% per frame)
    noise_floor: f64,

    /// Signal-to-noise ratio: frame level (dB) minus noise floor
    snr: f64,

    /// Audio source (mic, line-in, etc)
    source: AudioSource,
}
148
/// Where the monitored audio originates.
#[derive(Clone, Debug)]
enum AudioSource {
    Microphone,
    LineIn,
    Virtual, // For testing
}
155
156impl MarineVAD {
157    /// Create new VAD with marine algorithm
158    pub fn new() -> Result<Self> {
159        Ok(Self {
160            detector: Arc::new(RwLock::new(MarineDetectorState::new())),
161            audio_monitor: Arc::new(RwLock::new(AudioMonitor::new())),
162            is_voice_active: Arc::new(RwLock::new(false)),
163            state_callback: Arc::new(RwLock::new(None)),
164        })
165    }
166
167    /// Process audio samples
168    pub async fn process_audio(&self, samples: &[f32], sample_rate: u32) -> Result<bool> {
169        let mut detector = self.detector.write().await;
170        let mut monitor = self.audio_monitor.write().await;
171
172        // Update audio monitor
173        monitor.update_levels(samples);
174
175        // Check if we should evaluate (based on tick rate)
176        let now = Instant::now();
177        let tick_duration = Duration::from_secs_f64(1.0 / detector.tick_rate);
178
179        if now.duration_since(detector.last_tick) < tick_duration {
180            return Ok(*self.is_voice_active.read().await);
181        }
182
183        detector.last_tick = now;
184
185        // Marine algorithm evaluation
186        let voice_detected = detector.evaluate_voice(samples, sample_rate, monitor.snr);
187
188        // Update state if changed
189        let mut is_active = self.is_voice_active.write().await;
190        if voice_detected != *is_active {
191            *is_active = voice_detected;
192
193            // Call state change callback
194            if let Some(callback) = &*self.state_callback.read().await {
195                callback(voice_detected);
196            }
197
198            // Log state change
199            if voice_detected {
200                println!("🎤 Voice detected - switching to minimal output mode");
201                detector.voice_onset = Some(now);
202            } else {
203                println!("🔇 Voice ended - returning to normal output mode");
204                detector.voice_offset = Some(now);
205            }
206        }
207
208        Ok(voice_detected)
209    }
210
211    /// Set callback for voice state changes
212    pub async fn set_state_callback<F>(&self, callback: F)
213    where
214        F: Fn(bool) + Send + Sync + 'static,
215    {
216        let mut cb = self.state_callback.write().await;
217        *cb = Some(Box::new(callback));
218    }
219
220    /// Get current voice activity state
221    pub async fn is_voice_active(&self) -> bool {
222        *self.is_voice_active.read().await
223    }
224
225    /// Get voice salience score (0.0 to 1.0)
226    pub async fn get_salience(&self) -> f64 {
227        self.detector.read().await.voice_salience
228    }
229
230    /// Get voice quality metrics
231    pub async fn get_voice_quality(&self) -> VoiceQualityReport {
232        let detector = self.detector.read().await;
233        VoiceQualityReport {
234            salience: detector.voice_salience,
235            harmonicity: detector.speech_detector.voice_quality.harmonicity,
236            spectral_tilt: detector.speech_detector.voice_quality.spectral_tilt,
237            zero_crossing_rate: detector.speech_detector.voice_quality.zero_crossing_rate,
238            energy_variance: detector.speech_detector.voice_quality.energy_variance,
239        }
240    }
241}
242
243impl MarineDetectorState {
244    fn new() -> Self {
245        Self {
246            voice_threshold: -40.0, // -40 dB threshold
247            tick_rate: 100.0,       // 100 Hz evaluation rate
248            peak_history: VecDeque::with_capacity(100),
249            period_ema: ExponentialMovingAverage::new(0.1),
250            amplitude_ema: ExponentialMovingAverage::new(0.05),
251            speech_detector: SpeechPatternDetector::new(),
252            voice_salience: 0.0,
253            last_tick: Instant::now(),
254            voice_onset: None,
255            voice_offset: None,
256        }
257    }
258
259    /// Evaluate voice presence using marine algorithm
260    fn evaluate_voice(&mut self, samples: &[f32], sample_rate: u32, snr: f64) -> bool {
261        // Calculate RMS energy
262        let energy: f64 =
263            samples.iter().map(|&s| (s as f64).powi(2)).sum::<f64>() / samples.len() as f64;
264        let rms = energy.sqrt();
265        let db = 20.0 * rms.log10();
266
267        // Update amplitude tracking
268        self.amplitude_ema.update(rms);
269
270        // Check against threshold
271        if db < self.voice_threshold {
272            self.voice_salience *= 0.9; // Decay salience
273            return false;
274        }
275
276        // Analyze for speech patterns
277        let has_speech_pattern = self.speech_detector.analyze(samples, sample_rate);
278
279        // Calculate salience score
280        let mut salience = 0.0;
281
282        // Energy contribution (30%)
283        let energy_score = ((db - self.voice_threshold) / 20.0).clamp(0.0, 1.0);
284        salience += energy_score * 0.3;
285
286        // SNR contribution (20%)
287        let snr_score = (snr / 20.0).clamp(0.0, 1.0);
288        salience += snr_score * 0.2;
289
290        // Speech pattern contribution (50%)
291        if has_speech_pattern {
292            salience += 0.5;
293        }
294
295        // Update salience with smoothing
296        self.voice_salience = 0.7 * salience + 0.3 * self.voice_salience;
297
298        // Voice detected if salience > 0.5
299        self.voice_salience > 0.5
300    }
301}
302
303impl SpeechPatternDetector {
304    fn new() -> Self {
305        Self {
306            f0_min: 80.0,
307            f0_max: 400.0,
308            formant_tracker: FormantTracker {
309                f1_range: (200.0, 1000.0),
310                f2_range: (500.0, 2500.0),
311                f3_range: (1500.0, 3500.0),
312            },
313            syllable_detector: SyllableRateDetector {
314                energy_envelope: VecDeque::with_capacity(100),
315                peak_times: VecDeque::with_capacity(20),
316                min_syllable_gap: Duration::from_millis(100),
317                max_syllable_gap: Duration::from_millis(500),
318            },
319            voice_quality: VoiceQuality {
320                harmonicity: 0.0,
321                spectral_tilt: 0.0,
322                zero_crossing_rate: 0.0,
323                energy_variance: 0.0,
324            },
325        }
326    }
327
328    fn analyze(&mut self, samples: &[f32], sample_rate: u32) -> bool {
329        // Simple zero-crossing rate for voiced/unvoiced detection
330        let mut zero_crossings = 0;
331        for i in 1..samples.len() {
332            if samples[i - 1] * samples[i] < 0.0 {
333                zero_crossings += 1;
334            }
335        }
336
337        let zcr = zero_crossings as f64 / samples.len() as f64;
338        self.voice_quality.zero_crossing_rate = zcr;
339
340        // Voiced speech has lower ZCR (< 0.3), unvoiced has higher
341        let is_voiced = zcr < 0.3;
342
343        // Check if in speech frequency range
344        let estimated_freq = zcr * sample_rate as f64 / 2.0;
345        let in_speech_range = estimated_freq >= self.f0_min && estimated_freq <= self.f0_max * 10.0;
346
347        is_voiced && in_speech_range
348    }
349}
350
351impl AudioMonitor {
352    fn new() -> Self {
353        Self {
354            current_level: 0.0,
355            peak_level: 0.0,
356            noise_floor: -60.0, // Start with -60 dB assumption
357            snr: 0.0,
358            source: AudioSource::Microphone,
359        }
360    }
361
362    fn update_levels(&mut self, samples: &[f32]) {
363        // Calculate RMS
364        let sum_squares: f32 = samples.iter().map(|&s| s * s).sum();
365        let rms = (sum_squares / samples.len() as f32).sqrt();
366        self.current_level = rms as f64;
367
368        // Find peak
369        let peak = samples.iter().map(|&s| s.abs()).fold(0.0f32, f32::max) as f64;
370        self.peak_level = peak;
371
372        // Update noise floor estimate (slow adaptation)
373        if rms as f64 > 0.0 {
374            let db = 20.0 * (rms as f64).log10();
375            self.noise_floor = 0.99 * self.noise_floor + 0.01 * db;
376            self.snr = db - self.noise_floor;
377        }
378    }
379}
380
/// Voice quality report
///
/// Read-only snapshot of the detector's salience and `VoiceQuality` metrics,
/// returned by `MarineVAD::get_voice_quality`.
#[derive(Debug, Clone)]
pub struct VoiceQualityReport {
    pub salience: f64,
    pub harmonicity: f64,
    pub spectral_tilt: f64,
    pub zero_crossing_rate: f64,
    pub energy_variance: f64,
}
390
391/// Integration with rust_shell
392impl super::rust_shell::RustShell {
393    /// Enable VAD with marine algorithm
394    pub async fn enable_marine_vad(&self) -> Result<()> {
395        println!("🎖️ Enabling Marine VAD - Semper Fi to voice detection!");
396
397        let vad = MarineVAD::new()?;
398
399        // Set callback to adjust verbosity
400        let output_mode = self.output_mode.clone();
401        vad.set_state_callback(move |is_voice| {
402            // This would be called when voice state changes
403            let mode = output_mode.clone();
404            tokio::spawn(async move {
405                let mut m = mode.write().await;
406                if is_voice {
407                    m.verbosity = super::rust_shell::VerbosityLevel::Minimal;
408                    m.format = super::rust_shell::OutputFormat::Voice;
409                } else {
410                    m.verbosity = super::rust_shell::VerbosityLevel::Normal;
411                    m.format = super::rust_shell::OutputFormat::Text;
412                }
413            });
414        })
415        .await;
416
417        // Store VAD instance (would need to add field to RustShell)
418        // self.vad = Some(vad);
419
420        Ok(())
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction should never fail.
    #[tokio::test]
    async fn test_marine_vad_creation() {
        assert!(MarineVAD::new().is_ok());
    }

    /// Feed a voice-like tone through the detector and make sure the
    /// pipeline runs end to end.
    #[tokio::test]
    async fn test_voice_detection() {
        let vad = MarineVAD::new().unwrap();

        // Synthesize 100 ms of a 200 Hz sine wave (a typical voice F0).
        let sample_rate = 16000;
        let frequency = 200.0;
        let num_samples = (sample_rate as f64 * 0.1) as usize;
        let samples: Vec<f32> = (0..num_samples)
            .map(|i| {
                let t = i as f64 / sample_rate as f64;
                (2.0 * std::f64::consts::PI * frequency * t).sin() as f32 * 0.5
            })
            .collect();

        // Process audio
        let _is_voice = vad.process_audio(&samples, sample_rate).await.unwrap();

        // Should detect voice-like signal
        // (In real implementation would need proper training)
    }
}