use super::{SpeechError, SpeechResult};
/// Decoding configuration for automatic speech recognition.
///
/// Start from [`AsrConfig::default`], adjust with the `with_*` builders or
/// struct-update syntax, and call [`AsrConfig::validate`] before use.
#[derive(Debug, Clone)]
pub struct AsrConfig {
    /// Language hint; `None` leaves it unspecified
    /// (NOTE(review): presumably means auto-detect — confirm with the decoder).
    pub language: Option<String>,
    /// Beam width; must be at least 1 (default 5).
    pub beam_size: usize,
    /// Sampling temperature; must be a non-negative number (default 0.0).
    pub temperature: f32,
    /// Whether per-word timestamps are requested (default `false`).
    pub word_timestamps: bool,
    /// Maximum segment length; must be a positive number (default 30.0).
    /// NOTE(review): unit is presumably seconds — confirm.
    pub max_segment_length: f32,
    /// Blank-suppression flag (default `true`); semantics are defined by the consumer.
    pub suppress_blank: bool,
    /// Conditioning-on-previous-output flag (default `true`); semantics are defined by the consumer.
    pub condition_on_previous: bool,
}

impl Default for AsrConfig {
    fn default() -> Self {
        Self {
            language: None,
            beam_size: 5,
            temperature: 0.0,
            word_timestamps: false,
            max_segment_length: 30.0,
            suppress_blank: true,
            condition_on_previous: true,
        }
    }
}

impl AsrConfig {
    /// Sets the language hint.
    #[must_use]
    pub fn with_language(mut self, lang: impl Into<String>) -> Self {
        self.language = Some(lang.into());
        self
    }

    /// Enables per-word timestamps.
    #[must_use]
    pub fn with_word_timestamps(mut self) -> Self {
        self.word_timestamps = true;
        self
    }

    /// Sets the beam width; a requested width of 0 is raised to 1.
    #[must_use]
    pub fn with_beam_size(mut self, size: usize) -> Self {
        self.beam_size = size.max(1);
        self
    }

    /// Checks that the configuration values are usable.
    ///
    /// # Errors
    ///
    /// Returns [`SpeechError::InvalidConfig`] when `beam_size` is 0, when
    /// `temperature` is negative or NaN, or when `max_segment_length` is
    /// not strictly positive (including NaN).
    pub fn validate(&self) -> SpeechResult<()> {
        if self.beam_size == 0 {
            return Err(SpeechError::InvalidConfig(
                "beam_size must be >= 1".to_string(),
            ));
        }
        // NaN compares false against everything, so a plain `< 0.0` check
        // would silently accept it; reject it explicitly.
        if self.temperature.is_nan() || self.temperature < 0.0 {
            return Err(SpeechError::InvalidConfig(
                "temperature must be >= 0.0".to_string(),
            ));
        }
        if self.max_segment_length.is_nan() || self.max_segment_length <= 0.0 {
            return Err(SpeechError::InvalidConfig(
                "max_segment_length must be > 0".to_string(),
            ));
        }
        Ok(())
    }
}
/// A contiguous span of transcribed text together with its time bounds.
#[derive(Debug, Clone, PartialEq)]
pub struct Segment {
    /// Text transcribed for this span.
    pub text: String,
    /// Start of the span, in milliseconds.
    pub start_ms: u64,
    /// End of the span, in milliseconds.
    pub end_ms: u64,
    /// Confidence score; [`Segment::new`] initialises it to `1.0`.
    pub confidence: f32,
    /// Per-word timings; empty unless populated by the caller.
    pub words: Vec<WordTiming>,
}

impl Segment {
    /// Builds a segment covering `start_ms..end_ms` with confidence `1.0`
    /// and no word timings.
    #[must_use]
    pub fn new(text: impl Into<String>, start_ms: u64, end_ms: u64) -> Self {
        let text: String = text.into();
        let words = Vec::new();
        Self {
            text,
            start_ms,
            end_ms,
            confidence: 1.0,
            words,
        }
    }

    /// Length of the span in milliseconds; `0` if `end_ms` precedes `start_ms`.
    #[must_use]
    pub fn duration_ms(&self) -> u64 {
        let (begin, finish) = (self.start_ms, self.end_ms);
        finish.saturating_sub(begin)
    }

    /// Length of the span in seconds, derived from [`Segment::duration_ms`].
    #[must_use]
    pub fn duration_secs(&self) -> f32 {
        const MILLIS_PER_SECOND: f32 = 1000.0;
        let millis = self.duration_ms() as f32;
        millis / MILLIS_PER_SECOND
    }
}
/// Timing and confidence for a single word (see `Segment::words`).
#[derive(Debug, Clone, PartialEq)]
pub struct WordTiming {
    /// The word text.
    pub word: String,
    /// Start time of the word, in milliseconds.
    pub start_ms: u64,
    /// End time of the word, in milliseconds.
    pub end_ms: u64,
    /// Confidence score for this word.
    /// NOTE(review): range presumably 0.0..=1.0 — not enforced here.
    pub confidence: f32,
}
/// Result of a transcription run.
///
/// `Default` yields empty text, no segments, no language, zero processing
/// time and no cross-attention weights.
#[derive(Debug, Clone, Default)]
pub struct Transcription {
    /// Transcribed text.
    /// NOTE(review): presumably the concatenation of `segments` — confirm with the producer.
    pub text: String,
    /// Time-stamped segments of the transcription.
    pub segments: Vec<Segment>,
    /// Language associated with the transcription, if known.
    pub language: Option<String>,
    /// Processing time, in milliseconds.
    pub processing_time_ms: u64,
    /// Cross-attention weights, when available; `None` otherwise.
    pub cross_attention_weights: Option<CrossAttentionWeights>,
}
/// Cross-attention weights stored as a flat row-major array of shape
/// `[n_layers, n_tokens, n_frames]`: the frames for a given `(layer, token)`
/// pair are contiguous.
///
/// The constructors guarantee `weights.len() == n_layers * n_tokens * n_frames`
/// and that this product fits in `usize`. All indexing below relies on that.
#[derive(Debug, Clone)]
pub struct CrossAttentionWeights {
    weights: Vec<f32>,
    n_layers: usize,
    n_tokens: usize,
    n_frames: usize,
}

impl CrossAttentionWeights {
    /// Number of elements for the given shape, or `None` if any partial
    /// product overflows `usize`.
    fn element_count(n_layers: usize, n_tokens: usize, n_frames: usize) -> Option<usize> {
        n_layers.checked_mul(n_tokens)?.checked_mul(n_frames)
    }

    /// Wraps a flat weight buffer with the given shape.
    ///
    /// # Errors
    ///
    /// Returns `SpeechError::InvalidAudio` if the shape's element count
    /// overflows `usize`, or if `weights.len()` does not equal
    /// `n_layers * n_tokens * n_frames`.
    pub fn new(
        weights: Vec<f32>,
        n_layers: usize,
        n_tokens: usize,
        n_frames: usize,
    ) -> Result<Self, SpeechError> {
        // Checked multiplication: an unchecked product panics in debug builds
        // and wraps in release builds, which could let a bogus shape pass and
        // later cause out-of-bounds slicing in `get_attention`.
        let expected_len = Self::element_count(n_layers, n_tokens, n_frames).ok_or_else(|| {
            SpeechError::InvalidAudio(format!(
                "Cross-attention shape [{}, {}, {}] overflows usize",
                n_layers, n_tokens, n_frames
            ))
        })?;
        if weights.len() != expected_len {
            return Err(SpeechError::InvalidAudio(format!(
                "Cross-attention weight length {} doesn't match shape [{}, {}, {}] (expected {})",
                weights.len(),
                n_layers,
                n_tokens,
                n_frames,
                expected_len
            )));
        }
        Ok(Self {
            weights,
            n_layers,
            n_tokens,
            n_frames,
        })
    }

    /// Creates an all-zero weight tensor of the given shape.
    ///
    /// # Panics
    ///
    /// Panics if `n_layers * n_tokens * n_frames` overflows `usize`.
    #[must_use]
    pub fn zeros(n_layers: usize, n_tokens: usize, n_frames: usize) -> Self {
        let len = Self::element_count(n_layers, n_tokens, n_frames)
            .expect("cross-attention shape overflows usize");
        Self {
            weights: vec![0.0; len],
            n_layers,
            n_tokens,
            n_frames,
        }
    }

    /// Returns `(n_layers, n_tokens, n_frames)`.
    #[must_use]
    pub fn shape(&self) -> (usize, usize, usize) {
        (self.n_layers, self.n_tokens, self.n_frames)
    }

    /// Returns the `n_frames` weights for `(layer, token)`, or `None` if
    /// either index is out of range.
    #[must_use]
    pub fn get_attention(&self, layer: usize, token: usize) -> Option<&[f32]> {
        if layer >= self.n_layers || token >= self.n_tokens {
            return None;
        }
        // Cannot overflow or go out of bounds: the index is below the
        // element count validated at construction.
        let start = (layer * self.n_tokens + token) * self.n_frames;
        let end = start + self.n_frames;
        Some(&self.weights[start..end])
    }

    /// Index of the frame with the highest attention for `token`, summed
    /// across all layers.
    ///
    /// Returns `None` when `token` is out of range or when there are no frames
    /// or no layers to take evidence from. Ties resolve to the last maximal
    /// frame.
    #[must_use]
    pub fn peak_frame(&self, token: usize) -> Option<usize> {
        if token >= self.n_tokens || self.n_frames == 0 || self.n_layers == 0 {
            return None;
        }
        // Summing instead of averaging: dividing by the positive layer count
        // does not change which frame is maximal.
        let mut summed = vec![0.0f32; self.n_frames];
        for attn in (0..self.n_layers).filter_map(|layer| self.get_attention(layer, token)) {
            for (acc, &w) in summed.iter_mut().zip(attn) {
                *acc += w;
            }
        }
        summed
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
    }

    /// Shannon entropy (natural log, i.e. nats) of the weights for
    /// `(layer, token)`. Entries at or below `1e-10` are skipped.
    ///
    /// The row is not renormalised. NOTE(review): the result is only a true
    /// entropy if the row already sums to 1 — confirm upstream.
    #[must_use]
    pub fn attention_entropy(&self, layer: usize, token: usize) -> Option<f32> {
        let attn = self.get_attention(layer, token)?;
        let mut entropy = 0.0f32;
        for &p in attn {
            if p > 1e-10 {
                entropy -= p * p.ln();
            }
        }
        Some(entropy)
    }

    /// Heuristic sanity check: `true` if there are no weights, or if the
    /// population standard deviation of all weights exceeds `0.01`.
    /// A near-constant tensor (e.g. from [`CrossAttentionWeights::zeros`]) or
    /// one containing NaN is reported as unhealthy.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        if self.weights.is_empty() {
            return true;
        }
        let n = self.weights.len() as f32;
        let mean: f32 = self.weights.iter().sum::<f32>() / n;
        let variance: f32 = self.weights.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        let std = variance.sqrt();
        std > 0.01
    }

    /// Borrows the flat weight buffer (layer-major, then token, then frame).
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.weights
    }
}
/// Language-identification result: a primary language with its confidence
/// plus any alternative candidates.
///
/// Every stored confidence lies in `0.0..=1.0`. Out-of-range values are
/// clamped, and NaN is stored as `0.0`.
#[derive(Debug, Clone)]
pub struct LanguageDetection {
    language: String,
    confidence: f32,
    alternatives: Vec<(String, f32)>,
}

impl LanguageDetection {
    /// Maps an arbitrary confidence into `0.0..=1.0`.
    ///
    /// `f32::clamp` passes NaN through. A NaN would later make the
    /// `top_languages` comparator inconsistent, so it is stored as `0.0`.
    /// This also turns `-0.0` into `+0.0`.
    fn sanitize_confidence(confidence: f32) -> f32 {
        if confidence.is_nan() || confidence <= 0.0 {
            0.0
        } else {
            confidence.min(1.0)
        }
    }

    /// Creates a detection for `language` with the given (sanitised) confidence.
    #[must_use]
    pub fn new(language: impl Into<String>, confidence: f32) -> Self {
        Self {
            language: language.into(),
            confidence: Self::sanitize_confidence(confidence),
            alternatives: Vec::new(),
        }
    }

    /// Appends an alternative candidate with a (sanitised) confidence.
    #[must_use]
    pub fn with_alternative(mut self, language: impl Into<String>, confidence: f32) -> Self {
        self.alternatives
            .push((language.into(), Self::sanitize_confidence(confidence)));
        self
    }

    /// The primary language.
    #[must_use]
    pub fn language(&self) -> &str {
        &self.language
    }

    /// Confidence of the primary language, in `0.0..=1.0`.
    #[must_use]
    pub fn confidence(&self) -> f32 {
        self.confidence
    }

    /// Alternative candidates, in insertion order.
    #[must_use]
    pub fn alternatives(&self) -> &[(String, f32)] {
        &self.alternatives
    }

    /// Whether the primary confidence is at least `threshold`.
    #[must_use]
    pub fn is_confident(&self, threshold: f32) -> bool {
        self.confidence >= threshold
    }

    /// Up to `n` candidates (the primary language included), sorted by
    /// descending confidence.
    ///
    /// The sort is stable, so equal confidences keep the primary language
    /// first and the alternatives in insertion order.
    #[must_use]
    pub fn top_languages(&self, n: usize) -> Vec<(&str, f32)> {
        let mut all: Vec<(&str, f32)> = Vec::with_capacity(1 + self.alternatives.len());
        all.push((self.language.as_str(), self.confidence));
        all.extend(self.alternatives.iter().map(|(l, c)| (l.as_str(), *c)));
        // Confidences are never NaN (see `sanitize_confidence`), so this is a
        // total order and the `Equal` fallback is unreachable.
        all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        all.truncate(n);
        all
    }
}

impl Default for LanguageDetection {
    /// English with full confidence and no alternatives.
    fn default() -> Self {
        Self::new("en", 1.0)
    }
}
mod language_detect;
pub use language_detect::*;
#[cfg(test)]
mod tests;