stt-cli 0.2.1

Speech-to-text CLI using the Groq and OpenAI APIs
// --- Transcription Service ---
// This Tower service accepts WAV audio requests and returns transcription results.
// It wraps the configured TranscriptionProvider and converts any provider error into
// PipelineError::TranscriptionError.

use std::sync::Arc;
use std::task::{Context, Poll};
use std::pin::Pin;
use std::future::Future;

use tokio::sync::Mutex;
use tower::Service;

use crate::pipeline::types::{WavAudioRequest, AudioResponse, ProcessedData, PipelineError};
use crate::providers::TranscriptionProvider;

/// Tower `Service` that transcribes WAV audio through a shared [`TranscriptionProvider`].
///
/// Cloning is cheap: every clone shares the same provider through the inner `Arc`, so all
/// clones contend for the same mutex.
#[derive(Clone)]
pub struct TranscriptionService {
    /// Provider shared between all clones of this service; it is locked while a request
    /// is being transcribed.
    provider: Arc<Mutex<Box<dyn TranscriptionProvider + Send + Sync>>>,
}

// Manual `Debug`: the boxed provider trait object is not required to implement `Debug`,
// so `#[derive(Debug)]` is not possible. The provider is omitted (and never locked) when
// formatting, which keeps `{:?}` non-blocking.
impl std::fmt::Debug for TranscriptionService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TranscriptionService").finish_non_exhaustive()
    }
}

impl TranscriptionService {
    /// Create a new TranscriptionService with the given provider.
    pub fn new(provider: Arc<Mutex<Box<dyn TranscriptionProvider + Send + Sync>>>) -> Self {
        Self { provider }
    }
}

impl Service<WavAudioRequest> for TranscriptionService {
    type Response = AudioResponse;
    type Error = PipelineError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Always ready to accept new requests.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: WavAudioRequest) -> Self::Future {
        let provider = self.provider.clone();
        let timestamp = req.timestamp;
        let wav_data = req.wav_data;
        Box::pin(async move {
            let guard = provider.lock().await;
            match guard.transcribe(&wav_data).await {
                Ok(text) => Ok(AudioResponse {
                    original_timestamp: timestamp,
                    result_data: ProcessedData::Transcription(text),
                }),
                Err(e) => Err(PipelineError::TranscriptionError(e.to_string())),
            }
        })
    }
}

// #[cfg(test)]
// mod tests {
//     use super::*;
//     use crate::pipeline::types::{WavAudioRequest, AudioResponse, ProcessedData, PipelineError};
//     use crate::providers::TranscriptionProvider;
//     use std::sync::Arc;
//     use tokio::sync::Mutex;
//     use tower::Service;
//     use futures::executor::block_on;

//     struct MockTranscriptionProvider {
//         result: Option<String>,
//         fail: bool,
//     }

//     #[async_trait::async_trait]
//     impl TranscriptionProvider for MockTranscriptionProvider {
//         fn name(&self) -> &'static str {
//             "MockTranscriptionProvider"
//         }
//         fn min_chunk_duration(&self) -> std::time::Duration {
//             std::time::Duration::from_secs(1)
//         }
//         async fn transcribe(&self, _wav_data: &[u8]) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
//             if self.fail {
//                 Err("mock provider error".into())
//             } else {
//                 Ok(self.result.clone().unwrap_or_else(|| "mocked transcription".to_string()))
//             }
//         }
//     }

//     fn make_service(provider: MockTranscriptionProvider) -> TranscriptionService {
//         let boxed: Box<dyn TranscriptionProvider + Send + Sync> = Box::new(provider);
//         TranscriptionService::new(Arc::new(Mutex::new(boxed)))
//     }

//     #[tokio::test]
//     async fn test_transcription_service_success() {
//         // Test normal transcription path with valid WAV data
//         let provider = MockTranscriptionProvider { result: Some("hello world".to_string()), fail: false };
//         let mut service = make_service(provider);
//         let req = WavAudioRequest { wav_data: vec![1, 2, 3], timestamp: std::time::SystemTime::now() };
//         service.poll_ready(&mut std::task::Context::from_waker(futures::task::noop_waker_ref())).unwrap();
//         let resp = service.call(req).await.unwrap();
//         match resp.result_data {
//             ProcessedData::Transcription(text) => assert_eq!(text, "hello world"),
//         }
//     }

//     #[tokio::test]
//     async fn test_transcription_service_empty_audio() {
//         // Test transcription with empty WAV data (should still return mocked transcription)
//         let provider = MockTranscriptionProvider { result: Some("empty input".to_string()), fail: false };
//         let mut service = make_service(provider);
//         let req = WavAudioRequest { wav_data: vec![], timestamp: std::time::SystemTime::now() };
//         service.poll_ready(&mut std::task::Context::from_waker(futures::task::noop_waker_ref())).unwrap();
//         let resp = service.call(req).await.unwrap();
//         match resp.result_data {
//             ProcessedData::Transcription(text) => assert_eq!(text, "empty input"),
//         }
//     }

//     #[tokio::test]
//     async fn test_transcription_service_provider_error() {
//         // Test that provider error is mapped to PipelineError::Transcription
//         let provider = MockTranscriptionProvider { result: None, fail: true };
//         let mut service = make_service(provider);
//         let req = WavAudioRequest { wav_data: vec![1, 2, 3], timestamp: std::time::SystemTime::now() };
//         service.poll_ready(&mut std::task::Context::from_waker(futures::task::noop_waker_ref())).unwrap();
//         let err = service.call(req).await.unwrap_err();
//         match err {
//             PipelineError::Transcription(msg) => assert!(msg.contains("mock provider error")),
//             _ => panic!("Expected PipelineError::Transcription, got {:?}", err),
//         }
//     }

//     // Optionally, test concurrency if desired
//     #[tokio::test]
//     async fn test_transcription_service_concurrent_calls() {
//         // Service should be usable from multiple tasks concurrently
//         let provider = MockTranscriptionProvider { result: Some("concurrent".to_string()), fail: false };
//         let service = Arc::new(Mutex::new(make_service(provider)));
//         let req = WavAudioRequest { wav_data: vec![1, 2, 3], timestamp: std::time::SystemTime::now() };
//         let handles: Vec<_> = (0..4).map(|_| {
//             let service = service.clone();
//             let req = req.clone();
//             tokio::spawn(async move {
//                 let mut svc = service.lock().await;
//                 svc.poll_ready(&mut std::task::Context::from_waker(futures::task::noop_waker_ref())).unwrap();
//                 let resp = svc.call(req).await.unwrap();
//                 match resp.result_data {
//                     ProcessedData::Transcription(text) => assert_eq!(text, "concurrent"),
//                 }
//             })
//         }).collect();
//         for h in handles {
//             h.await.unwrap();
//         }
//     }
// }