use super::{AsrConfig, LanguageDetection, Segment, Transcription};
use crate::speech::{SpeechError, SpeechResult};
/// Detects the spoken language from an encoder output tensor.
///
/// The input is validated against its declared shape, after which a fixed
/// placeholder result is returned: the values in `encoder_output` are never
/// inspected, so every valid call reports `"en"` at `0.85` with the same list
/// of alternatives.
///
/// # Arguments
///
/// * `encoder_output` - Flattened encoder activations.
/// * `encoder_shape` - Dimensions of `encoder_output`; must be
///   `[batch, frames, hidden_dim]`.
///
/// # Errors
///
/// Returns [`SpeechError::InvalidAudio`] when `encoder_shape` does not have
/// exactly three dimensions, or when the element count it describes differs
/// from `encoder_output.len()` (a count too large for `usize` is treated as a
/// mismatch).
pub fn detect_language(
    encoder_output: &[f32],
    encoder_shape: &[usize],
) -> SpeechResult<LanguageDetection> {
    if encoder_shape.len() != 3 {
        return Err(SpeechError::InvalidAudio(
            "encoder_shape must be [batch, frames, hidden_dim]".to_string(),
        ));
    }

    // Overflow-checked product: a plain `product()` panics in debug builds and
    // silently wraps in release builds, which could let a bogus shape pass.
    // A zero dimension always means zero elements, so handle it up front
    // rather than letting an earlier pair of huge dimensions overflow first.
    let expected_len = if encoder_shape.contains(&0) {
        Some(0)
    } else {
        encoder_shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    };

    // `None` (overflow) can never equal a real slice length, so it is reported
    // through the same mismatch error.
    if expected_len != Some(encoder_output.len()) {
        return Err(SpeechError::InvalidAudio(format!(
            "encoder_output length {} doesn't match shape {:?}",
            encoder_output.len(),
            encoder_shape
        )));
    }

    // Placeholder distribution; the scores sum to 1.0.
    Ok(LanguageDetection::new("en", 0.85)
        .with_alternative("de", 0.05)
        .with_alternative("fr", 0.04)
        .with_alternative("es", 0.03)
        .with_alternative("it", 0.02)
        .with_alternative("unknown", 0.01))
}
/// Language codes accepted by [`is_language_supported`].
///
/// Entries are lowercase codes, mostly two letters (`"haw"` is the one
/// three-letter entry).
///
/// NOTE(review): the set and its ordering look like the Whisper multilingual
/// language list — confirm before relying on that.
pub const SUPPORTED_LANGUAGES: &[&str] = &[
"en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", "ca", "nl", "ar", "sv", "it",
"id", "hi", "fi", "vi", "he", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no", "th", "ur",
"hr", "bg", "lt", "la", "mi", "ml", "cy", "sk", "te", "fa", "lv", "bn", "sr", "az", "sl", "kn",
"et", "mk", "br", "eu", "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", "gl", "mr", "pa", "si",
"km", "sn", "yo", "so", "af", "oc", "ka", "be", "tg", "sd", "gu", "am", "yi", "lo", "uz", "fo",
"ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", "mg", "as", "tt", "haw", "ln",
"ha", "ba", "jw", "su",
];
/// Reports whether `code` is one of the entries in [`SUPPORTED_LANGUAGES`].
///
/// The comparison is exact: it is case-sensitive and performs no trimming or
/// normalization of `code`.
#[must_use]
pub fn is_language_supported(code: &str) -> bool {
    SUPPORTED_LANGUAGES.iter().any(|&known| known == code)
}
impl Transcription {
    /// Creates a transcription populated with the type's `Default` values.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds a transcription from `segments`.
    ///
    /// The full text is the segment texts joined with a single space, in the
    /// order given. `language` and `cross_attention_weights` are left unset
    /// and `processing_time_ms` is zero.
    #[must_use]
    pub fn from_segments(segments: Vec<Segment>) -> Self {
        let mut text = String::new();
        for (index, segment) in segments.iter().enumerate() {
            // Separator goes between segments only, never before the first.
            if index > 0 {
                text.push(' ');
            }
            text.push_str(segment.text.as_str());
        }

        Self {
            language: None,
            processing_time_ms: 0,
            cross_attention_weights: None,
            text,
            segments,
        }
    }

    /// Returns the `end_ms` of the last segment, or `0` when there are no
    /// segments.
    #[must_use]
    pub fn duration_ms(&self) -> u64 {
        match self.segments.last() {
            Some(final_segment) => final_segment.end_ms,
            None => 0,
        }
    }

    /// Counts the whitespace-separated words in the transcription text.
    #[must_use]
    pub fn word_count(&self) -> usize {
        let words = self.text.split_whitespace();
        words.count()
    }

    /// Returns `true` when the transcription text is empty.
    ///
    /// Only `text` is consulted; the segment list is not examined.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.text.as_str().is_empty()
    }
}
/// Interface implemented by speech-recognition model backends used by
/// [`AsrSession`].
///
/// `AsrSession::transcribe` calls `encode`, then `decode`, then
/// `tokens_to_text`, in that order.
pub trait AsrModel {
/// Returns an identifier for this model.
fn model_id(&self) -> &str;
/// Returns the language codes this model reports as supported, if any.
///
/// NOTE(review): presumably `None` means the model places no restriction on
/// language — confirm with implementors.
fn supported_languages(&self) -> Option<&[&str]>;
/// Encodes mel features into a flattened encoder output.
///
/// `AsrSession::transcribe` passes `mel_shape` as `[n_mels, n_frames]` with
/// `mel.len()` equal to their product.
fn encode(&self, mel: &[f32], mel_shape: &[usize]) -> SpeechResult<Vec<f32>>;
/// Decodes a sequence of token ids from `encoder_output` under `config`.
fn decode(&self, encoder_output: &[f32], config: &AsrConfig) -> SpeechResult<Vec<u32>>;
/// Converts token ids into text.
fn tokens_to_text(&self, tokens: &[u32]) -> SpeechResult<String>;
}
/// Pairs an [`AsrModel`] with the [`AsrConfig`] used when transcribing with it.
#[derive(Debug)]
pub struct AsrSession<M: AsrModel> {
/// Model backend that performs encoding, decoding and token-to-text conversion.
model: M,
/// Configuration handed to the model's `decode` call.
config: AsrConfig,
}
impl<M: AsrModel> AsrSession<M> {
    /// Creates a session from `model` and `config`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `config.validate()` reports; no session is
    /// built in that case.
    pub fn new(model: M, config: AsrConfig) -> SpeechResult<Self> {
        config.validate()?;
        Ok(Self { model, config })
    }

    /// Creates a session from `model` using `AsrConfig::default()`.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`AsrSession::new`].
    pub fn with_default_config(model: M) -> SpeechResult<Self> {
        Self::new(model, AsrConfig::default())
    }

    /// Returns a reference to the underlying model.
    #[must_use]
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Returns a reference to the session configuration.
    #[must_use]
    pub fn config(&self) -> &AsrConfig {
        &self.config
    }

    /// Transcribes a mel spectrogram into text.
    ///
    /// Runs the model's `encode`, `decode` and `tokens_to_text` steps in
    /// sequence and wraps the resulting text in a single segment starting at
    /// `0` ms. The returned `language` is copied from the session
    /// configuration; `processing_time_ms` is always `0` and
    /// `cross_attention_weights` is always `None`.
    ///
    /// # Arguments
    ///
    /// * `mel` - Flattened mel features.
    /// * `mel_shape` - Dimensions of `mel`; must be `[n_mels, n_frames]`.
    ///
    /// # Errors
    ///
    /// Returns [`SpeechError::InvalidAudio`] when `mel_shape` does not have
    /// exactly two dimensions, when `n_mels * n_frames` overflows `usize`, or
    /// when that product differs from `mel.len()`. Errors from the model's
    /// `encode`, `decode` or `tokens_to_text` are propagated unchanged.
    pub fn transcribe(&self, mel: &[f32], mel_shape: &[usize]) -> SpeechResult<Transcription> {
        // Destructure instead of indexing so the rank check and the element
        // access cannot drift apart.
        let (n_mels, n_frames) = match mel_shape {
            &[n_mels, n_frames] => (n_mels, n_frames),
            _ => {
                return Err(SpeechError::InvalidAudio(
                    "mel_shape must be [n_mels, n_frames]".to_string(),
                ));
            }
        };

        // A plain `*` panics in debug builds and wraps in release builds; a
        // wrapped product could accidentally equal `mel.len()`.
        let expected_len = n_mels.checked_mul(n_frames).ok_or_else(|| {
            SpeechError::InvalidAudio(format!(
                "mel_shape {:?} is too large: element count overflows usize",
                mel_shape
            ))
        })?;
        if mel.len() != expected_len {
            return Err(SpeechError::InvalidAudio(format!(
                "mel length {} doesn't match shape {:?} (expected {})",
                mel.len(),
                mel_shape,
                expected_len
            )));
        }

        let encoder_output = self.model.encode(mel, mel_shape)?;
        let tokens = self.model.decode(&encoder_output, &self.config)?;
        let text = self.model.tokens_to_text(&tokens)?;

        // NOTE(review): the 10 / 16 scaling from frame count to milliseconds
        // is kept as-is — confirm it against the feature extractor's frame
        // rate. Saturating multiply guards the `n_mels == 0` case, where
        // `n_frames` is unconstrained by `mel.len()`.
        let duration_ms = (n_frames as u64).saturating_mul(10) / 16;

        Ok(Transcription {
            // `text` is needed twice: once as the full text, once as the
            // sole segment's text.
            text: text.clone(),
            segments: vec![Segment::new(text, 0, duration_ms)],
            language: self.config.language.clone(),
            processing_time_ms: 0,
            cross_attention_weights: None,
        })
    }
}
/// First-in, first-out buffer of transcription segments that can be drained
/// as an [`Iterator`].
///
/// Segments are yielded in the order they were pushed. Iteration returns
/// `None` whenever the buffer is currently empty, independent of whether
/// [`StreamingTranscription::finish`] has been called; use
/// [`StreamingTranscription::is_complete`] to tell "empty for now" apart from
/// "finished".
#[derive(Debug)]
pub struct StreamingTranscription {
    /// Segments pushed but not yet yielded, oldest at the front.
    ///
    /// A `VecDeque` keeps front removal O(1); removing index 0 of a `Vec`
    /// shifts every remaining element on each `next()` call.
    pending: std::collections::VecDeque<Segment>,
    /// Set once `finish` has been called.
    complete: bool,
}

impl StreamingTranscription {
    /// Creates an empty, unfinished stream.
    #[must_use]
    pub fn new() -> Self {
        Self {
            pending: std::collections::VecDeque::new(),
            complete: false,
        }
    }

    /// Appends `segment` to the back of the buffer.
    pub fn push(&mut self, segment: Segment) {
        self.pending.push_back(segment);
    }

    /// Marks the stream as complete.
    ///
    /// Segments already buffered remain available through iteration.
    pub fn finish(&mut self) {
        self.complete = true;
    }

    /// Returns `true` once [`StreamingTranscription::finish`] has been called.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        self.complete
    }
}

impl Default for StreamingTranscription {
    /// Equivalent to [`StreamingTranscription::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Iterator for StreamingTranscription {
    type Item = Segment;

    /// Removes and returns the oldest buffered segment, or `None` when the
    /// buffer is currently empty.
    fn next(&mut self) -> Option<Self::Item> {
        self.pending.pop_front()
    }

    /// Reports the number of segments currently buffered.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let buffered = self.pending.len();
        (buffered, Some(buffered))
    }
}