use std::path::{Path, PathBuf};
use crate::{TranscriptionEngine, TranscriptionResult};
use super::model::MoonshineModel;
/// Sample rate (Hz) assumed for input audio; used to turn a sample count into seconds.
const SAMPLE_RATE: u32 = 16_000;
/// Moonshine checkpoints this engine knows how to run.
///
/// Each variant fixes the architecture constants (`num_layers`,
/// `num_key_value_heads`, `head_dim`) and the default decoding budget
/// (`token_rate`). The suffixed variants (`TinyAr`, `TinyZh`, ...) are
/// presumably language-specific checkpoints — TODO confirm against the model list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ModelVariant {
    #[default]
    Tiny,
    TinyAr,
    TinyZh,
    TinyJa,
    TinyKo,
    TinyUk,
    TinyVi,
    Base,
    BaseEs,
}

impl ModelVariant {
    /// Number of model layers: 6 for every `Tiny*` variant, 8 for `Base`/`BaseEs`.
    pub fn num_layers(&self) -> usize {
        match self {
            ModelVariant::Tiny
            | ModelVariant::TinyAr
            | ModelVariant::TinyZh
            | ModelVariant::TinyJa
            | ModelVariant::TinyKo
            | ModelVariant::TinyUk
            | ModelVariant::TinyVi => 6,
            ModelVariant::Base | ModelVariant::BaseEs => 8,
        }
    }

    /// Number of key/value heads; 8 for every variant.
    pub fn num_key_value_heads(&self) -> usize {
        8
    }

    /// Per-head dimension: 36 for `Tiny*` variants, 52 for `Base`/`BaseEs`.
    pub fn head_dim(&self) -> usize {
        match self {
            ModelVariant::Tiny
            | ModelVariant::TinyAr
            | ModelVariant::TinyZh
            | ModelVariant::TinyJa
            | ModelVariant::TinyKo
            | ModelVariant::TinyUk
            | ModelVariant::TinyVi => 36,
            ModelVariant::Base | ModelVariant::BaseEs => 52,
        }
    }

    /// Tokens budgeted per second of audio when the caller gives no explicit
    /// `max_length`: 6 for `Tiny`/`Base`/`BaseEs`, 8 for `TinyUk`, and 13 for
    /// the remaining `Tiny*` variants.
    pub fn token_rate(&self) -> usize {
        match self {
            ModelVariant::Tiny | ModelVariant::Base | ModelVariant::BaseEs => 6,
            ModelVariant::TinyUk => 8,
            ModelVariant::TinyAr
            | ModelVariant::TinyZh
            | ModelVariant::TinyJa
            | ModelVariant::TinyKo
            | ModelVariant::TinyVi => 13,
        }
    }
}
/// Parameters used when loading a Moonshine model.
#[derive(Debug, Clone, Default)]
pub struct MoonshineModelParams {
    /// Which Moonshine variant to load; passed through to `MoonshineModel::new`.
    pub variant: ModelVariant,
}

impl MoonshineModelParams {
    /// Load parameters selecting [`ModelVariant::Tiny`].
    pub fn tiny() -> Self {
        Self::variant(ModelVariant::Tiny)
    }

    /// Load parameters selecting [`ModelVariant::Base`].
    pub fn base() -> Self {
        Self::variant(ModelVariant::Base)
    }

    /// Load parameters selecting an arbitrary `variant`.
    pub fn variant(variant: ModelVariant) -> Self {
        MoonshineModelParams { variant }
    }
}
/// Per-call options for Moonshine transcription.
#[derive(Clone, Debug, Default)]
pub struct MoonshineInferenceParams {
    /// Upper bound on generated tokens. When `None`, the engine derives a limit
    /// from the audio duration and the loaded variant's token rate.
    pub max_length: Option<usize>,
}
/// Speech-to-text engine backed by a Moonshine model.
///
/// Starts with no model loaded; a model must be loaded before transcribing.
pub struct MoonshineEngine {
    /// Path the current model was loaded from; `None` while unloaded.
    loaded_model_path: Option<PathBuf>,
    /// The loaded model; `None` while unloaded.
    model: Option<MoonshineModel>,
    /// Variant selected by the most recent load; drives the default token budget.
    variant: ModelVariant,
}
impl MoonshineEngine {
pub fn new() -> Self {
Self {
loaded_model_path: None,
model: None,
variant: ModelVariant::default(),
}
}
}
impl Default for MoonshineEngine {
fn default() -> Self {
Self::new()
}
}
impl Drop for MoonshineEngine {
fn drop(&mut self) {
self.unload_model();
}
}
impl TranscriptionEngine for MoonshineEngine {
    type InferenceParams = MoonshineInferenceParams;
    type ModelParams = MoonshineModelParams;

    /// Loads the Moonshine model at `model_path`, replacing any model that is
    /// currently loaded.
    ///
    /// The previous model is unloaded first, so it is released before the new
    /// one is built. `model`, `variant` and `loaded_model_path` are updated
    /// only after `MoonshineModel::new` succeeds. A failed load therefore
    /// leaves the engine unloaded rather than partially reconfigured.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `MoonshineModel::new`.
    fn load_model_with_params(
        &mut self,
        model_path: &Path,
        params: Self::ModelParams,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.unload_model();

        let variant = params.variant;
        let model = MoonshineModel::new(model_path, variant)?;

        // Commit state only once the model exists.
        self.model = Some(model);
        self.variant = variant;
        self.loaded_model_path = Some(model_path.to_path_buf());

        log::info!("Loaded Moonshine {:?} model from {:?}", variant, model_path);
        Ok(())
    }

    /// Drops the loaded model (if any) and clears the recorded model path.
    ///
    /// Does nothing when no model is loaded, so repeated calls are harmless.
    fn unload_model(&mut self) {
        if self.model.is_some() {
            log::debug!("Unloading Moonshine model");
            self.model = None;
            self.loaded_model_path = None;
        }
    }

    /// Transcribes `samples` with the loaded model.
    ///
    /// The samples are assumed to be at `SAMPLE_RATE` (16 kHz); that rate
    /// converts the sample count into a duration. When `params.max_length` is
    /// `None`, the token budget is `ceil(duration_secs * token_rate)` for the
    /// loaded variant. Empty input returns an empty transcript without invoking
    /// the model.
    ///
    /// # Errors
    ///
    /// Returns `MoonshineError::ModelNotLoaded` if no model is loaded.
    /// Also propagates errors from token generation and decoding.
    fn transcribe_samples(
        &mut self,
        samples: Vec<f32>,
        params: Option<Self::InferenceParams>,
    ) -> Result<TranscriptionResult, Box<dyn std::error::Error>> {
        let token_rate = self.variant.token_rate();
        let model = self
            .model
            .as_mut()
            .ok_or(super::model::MoonshineError::ModelNotLoaded)?;

        // Zero-length audio has nothing to decode and would otherwise produce a
        // zero token budget; skip the encoder/decoder entirely.
        if samples.is_empty() {
            log::debug!("Transcribing 0 samples; returning empty transcript");
            return Ok(TranscriptionResult {
                text: String::new(),
                segments: None,
            });
        }

        let audio_duration_sec = samples.len() as f32 / SAMPLE_RATE as f32;
        let max_length = params
            .unwrap_or_default()
            .max_length
            .unwrap_or_else(|| (audio_duration_sec * token_rate as f32).ceil() as usize);

        log::debug!(
            "Transcribing {} samples ({:.2}s), max_length={}",
            samples.len(),
            audio_duration_sec,
            max_length
        );

        let tokens = model.generate(&samples, max_length)?;
        let text = model.decode_tokens(&tokens)?;
        Ok(TranscriptionResult {
            text,
            segments: None,
        })
    }
}