use super::config::ParakeetConfig;
use super::error::Result;
use super::stt::{ParakeetBackend, validate_language};
use crate::{
ModelInfo, STTModelsProvider, STTProvider, STTResult, STTSpeechProvider, TextChunk,
TranscriptionRequest, TranscriptionResponse,
};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Speech-to-text provider that drives a [`ParakeetBackend`].
///
/// The backend sits behind an `Arc`-shared async mutex, which is what lets the
/// `&self` methods on this type obtain the mutable access the backend calls need.
pub struct Parakeet {
    /// Configuration the provider was constructed with; its `model_variant`
    /// is consulted for capability and language queries.
    config: ParakeetConfig,
    /// Backend shared behind an async mutex.
    backend: Arc<Mutex<ParakeetBackend>>,
}
impl Parakeet {
    /// Creates a provider by building a [`ParakeetBackend`] from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ParakeetBackend::new`.
    pub fn new(config: ParakeetConfig) -> Result<Self> {
        // The backend only borrows the config while it is being built, so the
        // config can still be moved into the struct afterwards.
        let backend = Arc::new(Mutex::new(ParakeetBackend::new(&config)?));
        Ok(Self { config, backend })
    }

    /// Returns the configuration this provider was created with.
    pub fn config(&self) -> &ParakeetConfig {
        &self.config
    }

    /// Locks the backend and calls its `reset` method.
    pub async fn reset(&self) {
        // The guard is a temporary, so the lock is released at the end of
        // this statement.
        self.backend.lock().await.reset();
    }

    /// Locks the backend and passes `audio_chunk` to its `transcribe_chunk`.
    ///
    /// # Errors
    ///
    /// Returns the backend's error converted into the `STTResult` error type.
    pub async fn process_chunk(&self, audio_chunk: Vec<f32>) -> STTResult<TextChunk> {
        let mut guard = self.backend.lock().await;
        match guard.transcribe_chunk(audio_chunk).await {
            Ok(text) => Ok(text),
            Err(err) => Err(err.into()),
        }
    }
}
// Empty impl: `Parakeet` defines no `STTProvider` items itself, so anything the
// trait offers must come from its default implementations.
impl STTProvider for Parakeet {}
#[async_trait]
impl STTSpeechProvider for Parakeet {
async fn transcribe(&self, request: TranscriptionRequest) -> STTResult<TranscriptionResponse> {
validate_language(request.language.as_deref(), &self.config.model_variant)
.map_err(crate::error::STTError::from)?;
let mut backend = self.backend.lock().await;
backend.transcribe(request).await.map_err(Into::into)
}
async fn transcribe_stream<'a>(
&'a self,
request: TranscriptionRequest,
) -> STTResult<Pin<Box<dyn Stream<Item = STTResult<TextChunk>> + Send + 'a>>> {
if !self.config.model_variant.supports_streaming() {
return Err(crate::error::STTError::StreamingNotSupported(format!(
"{} does not support streaming",
self.config.model_variant
)));
}
let audio = request.audio;
let backend = self.backend.clone();
{
let mut b = backend.lock().await;
b.reset();
}
let chunk_size = self.config.model_variant.chunk_size();
let stream = futures::stream::unfold(0usize, move |offset| {
let audio = audio.clone(); let backend = backend.clone();
async move {
let samples = &audio.samples;
if offset >= samples.len() {
return None;
}
let end = (offset + chunk_size).min(samples.len());
let chunk = if end - offset == chunk_size {
samples[offset..end].to_vec()
} else {
let mut padded = Vec::with_capacity(chunk_size);
padded.extend_from_slice(&samples[offset..end]);
padded.resize(chunk_size, 0.0);
padded
};
let next_offset = offset + chunk_size;
let mut b = backend.lock().await;
let result = b.transcribe_chunk(chunk).await.map_err(Into::into);
Some((result, next_offset))
}
});
Ok(Box::pin(stream))
}
fn supports_streaming(&self) -> bool {
self.config.model_variant.supports_streaming()
}
fn supports_timestamps(&self) -> bool {
self.config.model_variant.supports_timestamps()
}
}
#[async_trait]
impl STTModelsProvider for Parakeet {
    /// Lists the available models: a single-element vector holding the
    /// currently configured model.
    async fn list_models(&self) -> STTResult<Vec<ModelInfo>> {
        let current = self.get_current_model();
        Ok(vec![current])
    }

    /// Builds a [`ModelInfo`] from the configured model variant's id, display
    /// string, description and supported languages.
    fn get_current_model(&self) -> ModelInfo {
        let model = &self.config.model_variant;
        let id = model.id().to_string();
        let name = model.to_string();
        let description = Some(model.description().to_string());
        let languages = model.supported_languages();
        ModelInfo {
            id,
            name,
            description,
            languages,
        }
    }

    /// Returns the languages supported by the configured model variant.
    fn supported_languages(&self) -> Vec<String> {
        let model = &self.config.model_variant;
        model.supported_languages()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::parakeet::ModelVariant;

    /// Tries to build a provider for `variant` from `model_dir`.
    ///
    /// Yields `None` when construction fails, which lets each test return
    /// early instead of failing.
    fn load_provider(variant: ModelVariant, model_dir: &'static str) -> Option<Parakeet> {
        Parakeet::new(ParakeetConfig::new(variant, model_dir)).ok()
    }

    #[test]
    fn test_provider_settings() {
        let provider = match load_provider(ModelVariant::TDT, "./models/tdt") {
            Some(provider) => provider,
            None => return,
        };

        assert_eq!(provider.supported_sample_rate(), 16000);
        assert_eq!(provider.supported_channels(), 1);
        assert!(provider.supports_timestamps());
        assert!(!provider.supports_streaming());
    }

    #[test]
    fn test_nemotron_streaming_support() {
        let provider = match load_provider(ModelVariant::Nemotron, "./models/nemotron") {
            Some(provider) => provider,
            None => return,
        };

        assert!(provider.supports_streaming());
        assert!(!provider.supports_timestamps());
    }

    #[test]
    fn test_eou_streaming_and_detection_support() {
        let provider = match load_provider(ModelVariant::EOU, "./models/eou") {
            Some(provider) => provider,
            None => return,
        };

        assert!(provider.supports_streaming());
        assert!(!provider.supports_timestamps());

        let variant = &provider.config.model_variant;
        assert!(variant.supports_eou_detection());
        assert_eq!(variant.supported_languages(), vec!["en"]);
    }
}