autoagents-speech 0.3.7

Speech (TTS/STT) provider abstractions for AutoAgents
Documentation
use super::VadEngine;
use super::config::VadSttConfig;
use super::error::VadError;
use super::segmenter::{SpeechSegment, VadSegmenter};
use crate::{STTError, STTSpeechProvider, TranscriptionRequest, TranscriptionResponse};
use futures::future::try_join_all;

/// A speech segment paired with the transcription the STT provider returned for it.
///
/// Produced by `VadSttPipeline::process_audio` (zero or more per call) and by
/// `VadSttPipeline::finalize` (at most one).
#[derive(Debug, Clone)]
pub struct SegmentTranscription {
    /// The segment emitted by the VAD segmenter; its `audio` is what was sent to the STT provider.
    pub segment: SpeechSegment,
    /// The STT provider's response for `segment`.
    pub transcription: TranscriptionResponse,
}

/// Error returned by `VadSttPipeline` operations.
///
/// Both variants carry `#[from]`, so `?` converts a `VadError` or an `STTError`
/// into this type automatically.
#[derive(Debug, thiserror::Error)]
pub enum VadPipelineError {
    /// A failure on the VAD side of the pipeline (wraps `VadError`).
    #[error("VAD error: {0}")]
    Vad(#[from] VadError),
    /// A failure reported by the STT provider while transcribing a segment (wraps `STTError`).
    #[error("STT error: {0}")]
    Stt(#[from] STTError),
}

/// Couples a VAD segmenter with an STT provider: audio is segmented first, and each
/// segment the segmenter emits is then sent to the provider for transcription.
///
/// `V` is the VAD engine driving the segmenter; `S` is the speech-to-text provider.
pub struct VadSttPipeline<V: VadEngine, S: STTSpeechProvider> {
    // Splits incoming audio into speech segments.
    segmenter: VadSegmenter<V>,
    // Provider that transcribes each emitted segment.
    stt: S,
    // Supplies `language` and `include_timestamps` for every transcription request.
    config: VadSttConfig,
}

impl<V: VadEngine, S: STTSpeechProvider> VadSttPipeline<V, S> {
    pub fn new(segmenter: VadSegmenter<V>, stt: S, config: VadSttConfig) -> Self {
        Self {
            segmenter,
            stt,
            config,
        }
    }

    pub fn window_samples(&self) -> usize {
        self.segmenter.window_samples()
    }

    pub fn window_ms(&self) -> u32 {
        self.segmenter.window_ms()
    }

    pub async fn process_audio(
        &mut self,
        audio: &crate::AudioData,
    ) -> Result<Vec<SegmentTranscription>, VadPipelineError> {
        let segments = self.segmenter.process_audio(audio)?;
        self.transcribe_segments(segments).await
    }

    pub async fn finalize(&mut self) -> Result<Option<SegmentTranscription>, VadPipelineError> {
        let segment = match self.segmenter.finalize()? {
            Some(segment) => segment,
            None => return Ok(None),
        };

        let transcription = self.transcribe_segment(&segment).await?;
        Ok(Some(SegmentTranscription {
            segment,
            transcription,
        }))
    }

    async fn transcribe_segments(
        &self,
        segments: Vec<SpeechSegment>,
    ) -> Result<Vec<SegmentTranscription>, VadPipelineError> {
        // Run all segment transcriptions concurrently. For a single-backend STT provider,
        // the calls will serialise on the provider's internal mutex; for multi-instance
        // providers they may run in parallel.
        let futures: Vec<_> = segments
            .iter()
            .map(|segment| self.transcribe_segment(segment))
            .collect();

        let transcriptions = try_join_all(futures).await?;

        Ok(segments
            .into_iter()
            .zip(transcriptions)
            .map(|(segment, transcription)| SegmentTranscription {
                segment,
                transcription,
            })
            .collect())
    }

    async fn transcribe_segment(
        &self,
        segment: &SpeechSegment,
    ) -> Result<TranscriptionResponse, VadPipelineError> {
        let request = TranscriptionRequest {
            audio: segment.audio.clone(), // cheap Arc clone — no audio data is copied
            language: self.config.language.clone(),
            include_timestamps: self.config.include_timestamps,
        };
        let response = self.stt.transcribe(request).await?;
        Ok(response)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vad::VadOutput;
    use crate::vad::config::SegmenterConfig;
    use crate::vad::error::VadResult;
    use crate::vad::{VadEngine, VadSegmenter};
    use crate::{AudioData, STTResult};
    use async_trait::async_trait;

    /// Mean absolute amplitude above which `MockVad` reports a high speech probability.
    const SPEECH_LEVEL: f32 = 0.2;

    /// Amplitude-threshold stand-in for a real VAD model.
    struct MockVad;

    impl VadEngine for MockVad {
        fn sample_rate(&self) -> u32 {
            16_000
        }

        fn reset(&mut self) {}

        fn compute(&mut self, samples: &[f32]) -> VadResult<VadOutput> {
            let total: f32 = samples.iter().map(|sample| sample.abs()).sum();
            let mean = total / samples.len() as f32;
            let probability = if mean > SPEECH_LEVEL { 0.9 } else { 0.1 };
            Ok(VadOutput { probability })
        }
    }

    /// STT stand-in whose "transcript" is the number of samples it received.
    struct MockStt;

    #[async_trait]
    impl STTSpeechProvider for MockStt {
        async fn transcribe(
            &self,
            request: TranscriptionRequest,
        ) -> STTResult<TranscriptionResponse> {
            let sample_count = request.audio.samples.len();
            Ok(TranscriptionResponse {
                text: format!("{} samples", sample_count),
                timestamps: None,
                duration_ms: 1,
            })
        }
    }

    /// Loud audio alone yields no transcription; the silence that follows it
    /// makes the pipeline emit exactly one.
    #[tokio::test]
    async fn pipeline_emits_transcriptions() {
        let segmenter_config = SegmenterConfig::default()
            .with_window_ms(100)
            .with_min_speech_ms(100)
            .with_min_silence_ms(100);
        let segmenter = VadSegmenter::new(MockVad, segmenter_config).unwrap();
        let mut pipeline = VadSttPipeline::new(segmenter, MockStt, VadSttConfig::default());

        // Wraps a sample buffer as mono 16 kHz audio.
        let to_audio = |samples| AudioData {
            samples,
            sample_rate: 16_000,
            channels: 1,
        };

        // Two windows of loud signal: nothing is emitted yet.
        let speech = to_audio(vec![0.8; pipeline.window_samples() * 2]);
        let during_speech = pipeline.process_audio(&speech).await.unwrap();
        assert_eq!(during_speech.len(), 0);

        // Two windows of silence: one transcribed segment comes out.
        let quiet = to_audio(vec![0.0; pipeline.window_samples() * 2]);
        let after_silence = pipeline.process_audio(&quiet).await.unwrap();
        assert_eq!(after_silence.len(), 1);
        assert!(after_silence[0].transcription.text.contains("samples"));
    }
}