use crate::error::{AudioError, AudioResult};
use crate::onnx::execution_provider::OnnxExecutionProvider;
/// Model family used for ONNX speech-to-text.
///
/// The backend decides which of [`OnnxSttConfig`]'s selector fields
/// (`whisper_size`, `distil_variant`, `moonshine_variant`) is consulted by
/// [`OnnxSttConfig::model_id`]. Defaults to [`SttBackend::Whisper`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SttBackend {
    /// Whisper checkpoints; the size is chosen by [`WhisperModelSize`].
    #[default]
    Whisper,
    /// Distil-Whisper checkpoints; the variant is chosen by [`DistilWhisperVariant`].
    DistilWhisper,
    /// Moonshine checkpoints; the variant is chosen by [`MoonshineVariant`].
    Moonshine,
}
/// Checkpoint size for the [`SttBackend::Whisper`] backend.
///
/// Each variant maps to an `onnx-community/whisper-*` model id in
/// [`OnnxSttConfig::model_id`]. Defaults to [`WhisperModelSize::Base`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WhisperModelSize {
    /// `onnx-community/whisper-tiny`.
    Tiny,
    /// `onnx-community/whisper-base` (the default).
    #[default]
    Base,
    /// `onnx-community/whisper-small`.
    Small,
    /// `onnx-community/whisper-medium`.
    Medium,
    /// `onnx-community/whisper-large-v3-turbo`.
    LargeV3Turbo,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistilWhisperVariant {
#[default]
DistilSmallEn,
DistilMediumEn,
DistilLargeV3,
}
/// Checkpoint selection for the [`SttBackend::Moonshine`] backend.
///
/// Each variant maps to a `usefulsensors/moonshine-*-onnx` model id in
/// [`OnnxSttConfig::model_id`]. Defaults to [`MoonshineVariant::Tiny`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MoonshineVariant {
    /// `usefulsensors/moonshine-tiny-onnx` (the default).
    #[default]
    Tiny,
    /// `usefulsensors/moonshine-base-onnx`.
    Base,
}
/// Validated configuration for ONNX-based speech-to-text.
///
/// Construct it with [`OnnxSttConfig::builder`]; the builder's `build` method
/// performs validation (e.g. `beam_size` must be in `1..=10`).
#[derive(Debug, Clone)]
pub struct OnnxSttConfig {
    /// Model family; selects which of the three selector fields below is used.
    pub backend: SttBackend,
    /// Whisper checkpoint size, consulted by `model_id` when `backend` is `Whisper`.
    pub whisper_size: WhisperModelSize,
    /// Distil-Whisper checkpoint, consulted by `model_id` when `backend` is `DistilWhisper`.
    pub distil_variant: DistilWhisperVariant,
    /// Moonshine checkpoint, consulted by `model_id` when `backend` is `Moonshine`.
    pub moonshine_variant: MoonshineVariant,
    /// Beam size for decoding; the builder enforces `1..=10` and defaults to 1.
    pub beam_size: u8,
    /// Decoding temperature; the builder defaults to `0.0`.
    pub temperature: f32,
    /// ONNX execution provider; the builder defaults to `OnnxExecutionProvider::auto_detect()`.
    pub execution_provider: OnnxExecutionProvider,
    /// Optional language hint, `None` unless set on the builder.
    /// NOTE(review): `None` presumably means "auto-detect language" — confirm in the decoder.
    pub language: Option<String>,
}
impl OnnxSttConfig {
    /// Returns a builder initialised with [`OnnxSttConfigBuilder::default`].
    #[must_use]
    pub fn builder() -> OnnxSttConfigBuilder {
        OnnxSttConfigBuilder::default()
    }

    /// Returns the `owner/name` model identifier for the selected backend and variant.
    ///
    /// Only the selector field matching `backend` is consulted; the other two
    /// are ignored. The returned string is a literal, so it is `'static` and
    /// may outlive `self`.
    #[must_use]
    pub fn model_id(&self) -> &'static str {
        match self.backend {
            SttBackend::Whisper => match self.whisper_size {
                WhisperModelSize::Tiny => "onnx-community/whisper-tiny",
                WhisperModelSize::Base => "onnx-community/whisper-base",
                WhisperModelSize::Small => "onnx-community/whisper-small",
                WhisperModelSize::Medium => "onnx-community/whisper-medium",
                WhisperModelSize::LargeV3Turbo => "onnx-community/whisper-large-v3-turbo",
            },
            SttBackend::DistilWhisper => match self.distil_variant {
                DistilWhisperVariant::DistilSmallEn => "distil-whisper/distil-small.en",
                DistilWhisperVariant::DistilMediumEn => "distil-whisper/distil-medium.en",
                DistilWhisperVariant::DistilLargeV3 => "distil-whisper/distil-large-v3",
            },
            SttBackend::Moonshine => match self.moonshine_variant {
                MoonshineVariant::Tiny => "usefulsensors/moonshine-tiny-onnx",
                MoonshineVariant::Base => "usefulsensors/moonshine-base-onnx",
            },
        }
    }
}
/// Builder for [`OnnxSttConfig`].
///
/// Obtain one via [`OnnxSttConfig::builder`] (or `Default`), chain the setter
/// methods, then call `build` to validate the settings and produce the config.
#[derive(Debug, Clone)]
pub struct OnnxSttConfigBuilder {
    // Model family; picks which of the three selector fields below matters.
    backend: SttBackend,
    // Whisper checkpoint size (relevant for `SttBackend::Whisper`).
    whisper_size: WhisperModelSize,
    // Distil-Whisper checkpoint (relevant for `SttBackend::DistilWhisper`).
    distil_variant: DistilWhisperVariant,
    // Moonshine checkpoint (relevant for `SttBackend::Moonshine`).
    moonshine_variant: MoonshineVariant,
    // Beam size; range-checked in `build`.
    beam_size: u8,
    // Decoding temperature.
    temperature: f32,
    // ONNX execution provider; `Default` fills it via `auto_detect()`.
    execution_provider: OnnxExecutionProvider,
    // Optional language hint; `None` unless `language(..)` is called.
    language: Option<String>,
}
impl Default for OnnxSttConfigBuilder {
    /// Starting point for the builder.
    ///
    /// Every model selector is at its `#[default]` variant. Other defaults:
    /// `beam_size` 1, `temperature` 0.0, no language, and the execution
    /// provider chosen by `OnnxExecutionProvider::auto_detect()`.
    fn default() -> Self {
        // The only non-trivial default: probe for an available execution provider.
        let execution_provider = OnnxExecutionProvider::auto_detect();

        Self {
            backend: Default::default(),
            whisper_size: Default::default(),
            distil_variant: Default::default(),
            moonshine_variant: Default::default(),
            beam_size: 1,
            temperature: 0.0,
            execution_provider,
            language: None,
        }
    }
}
impl OnnxSttConfigBuilder {
    /// Selects the speech-to-text model family.
    #[must_use]
    pub fn stt_backend(mut self, backend: SttBackend) -> Self {
        self.backend = backend;
        self
    }

    /// Sets the Whisper checkpoint size, used when the backend is
    /// [`SttBackend::Whisper`].
    #[must_use]
    pub fn model_size(mut self, size: WhisperModelSize) -> Self {
        self.whisper_size = size;
        self
    }

    /// Sets the Distil-Whisper checkpoint, used when the backend is
    /// [`SttBackend::DistilWhisper`].
    #[must_use]
    pub fn distil_variant(mut self, variant: DistilWhisperVariant) -> Self {
        self.distil_variant = variant;
        self
    }

    /// Sets the Moonshine checkpoint, used when the backend is
    /// [`SttBackend::Moonshine`].
    #[must_use]
    pub fn moonshine_variant(mut self, variant: MoonshineVariant) -> Self {
        self.moonshine_variant = variant;
        self
    }

    /// Sets the beam size. It must lie in `1..=10`; this is checked by
    /// [`build`](Self::build), not here.
    #[must_use]
    pub fn beam_size(mut self, size: u8) -> Self {
        self.beam_size = size;
        self
    }

    /// Sets the decoding temperature. It must be finite and non-negative; this is
    /// checked by [`build`](Self::build), not here.
    #[must_use]
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Overrides the execution provider (the default comes from
    /// `OnnxExecutionProvider::auto_detect()`).
    #[must_use]
    pub fn execution_provider(mut self, ep: OnnxExecutionProvider) -> Self {
        self.execution_provider = ep;
        self
    }

    /// Sets the language hint; the value is stored verbatim.
    #[must_use]
    pub fn language(mut self, lang: impl Into<String>) -> Self {
        self.language = Some(lang.into());
        self
    }

    /// Validates the accumulated settings and produces an [`OnnxSttConfig`].
    ///
    /// # Errors
    ///
    /// Returns `AudioError::Stt` with provider `"ONNX"` when:
    /// - `beam_size` is outside `1..=10`, or
    /// - `temperature` is NaN, infinite, or negative.
    pub fn build(self) -> AudioResult<OnnxSttConfig> {
        if !(1..=10).contains(&self.beam_size) {
            return Err(AudioError::Stt {
                provider: "ONNX".into(),
                message: format!("beam_size must be 1..=10, got {}", self.beam_size),
            });
        }
        // NaN fails every ordered comparison, so `< 0.0` alone would let it
        // through; `is_finite` rejects NaN and ±infinity before the sign check.
        if !self.temperature.is_finite() || self.temperature < 0.0 {
            return Err(AudioError::Stt {
                provider: "ONNX".into(),
                message: format!(
                    "temperature must be finite and >= 0.0, got {}",
                    self.temperature
                ),
            });
        }
        Ok(OnnxSttConfig {
            backend: self.backend,
            whisper_size: self.whisper_size,
            distil_variant: self.distil_variant,
            moonshine_variant: self.moonshine_variant,
            beam_size: self.beam_size,
            temperature: self.temperature,
            execution_provider: self.execution_provider,
            language: self.language,
        })
    }
}