use std::path::{Path, PathBuf};
use crate::{TranscriptionEngine, TranscriptionResult};
use super::streaming_model::StreamingModel;
/// Samples per second assumed when converting a sample count into a duration
/// (in seconds) for the transcription debug log.
const SAMPLE_RATE: u32 = 16_000;
/// Options supplied when a streaming model is loaded.
#[derive(Clone, Debug)]
pub struct StreamingModelParams {
    /// Rate remembered by the engine at load time and handed to the model's
    /// `generate` call on every transcription.
    /// NOTE(review): presumably bounds generated tokens relative to audio
    /// duration — confirm against `StreamingModel::generate`.
    pub max_tokens_per_second: f32,
    /// Thread count forwarded unchanged to `StreamingModel::new`.
    /// NOTE(review): `0` presumably lets the runtime pick — confirm.
    pub num_threads: usize,
}

impl Default for StreamingModelParams {
    /// Returns the stock settings: `6.5` tokens per second and `0` threads.
    fn default() -> Self {
        const DEFAULT_MAX_TOKENS_PER_SECOND: f32 = 6.5;
        const DEFAULT_NUM_THREADS: usize = 0;

        StreamingModelParams {
            max_tokens_per_second: DEFAULT_MAX_TOKENS_PER_SECOND,
            num_threads: DEFAULT_NUM_THREADS,
        }
    }
}
/// Per-call options for a single transcription request.
#[derive(Clone, Debug, Default)]
pub struct StreamingInferenceParams {
    /// Optional override handed to the model's `generate` call as its
    /// max-token override; `None` (the default) passes no override.
    /// NOTE(review): the unit is presumably tokens — confirm against
    /// `StreamingModel::generate`.
    pub max_length: Option<usize>,
}
/// Transcription engine that owns an optionally loaded [`StreamingModel`].
pub struct MoonshineStreamingEngine {
    /// The loaded model; `None` until a load succeeds and again after unload.
    model: Option<StreamingModel>,
    /// Path the current model was loaded from; set and cleared alongside
    /// `model`.
    loaded_model_path: Option<PathBuf>,
    /// Rate passed to the model on each transcription; overwritten by every
    /// load request.
    max_tokens_per_second: f32,
}

impl MoonshineStreamingEngine {
    /// Creates an engine with no model loaded.
    ///
    /// `max_tokens_per_second` starts at the value provided by
    /// `StreamingModelParams::default()`.
    pub fn new() -> Self {
        let defaults = StreamingModelParams::default();
        MoonshineStreamingEngine {
            model: None,
            loaded_model_path: None,
            max_tokens_per_second: defaults.max_tokens_per_second,
        }
    }
}
impl Default for MoonshineStreamingEngine {
fn default() -> Self {
Self::new()
}
}
impl Drop for MoonshineStreamingEngine {
    /// Runs the engine's `unload_model` on drop, so a still-loaded model is
    /// released through the same path (and with the same debug log) as an
    /// explicit unload.
    fn drop(&mut self) {
        Self::unload_model(self);
    }
}
impl TranscriptionEngine for MoonshineStreamingEngine {
type InferenceParams = StreamingInferenceParams;
type ModelParams = StreamingModelParams;
fn load_model_with_params(
&mut self,
model_path: &Path,
params: Self::ModelParams,
) -> Result<(), Box<dyn std::error::Error>> {
self.unload_model();
self.max_tokens_per_second = params.max_tokens_per_second;
self.model = Some(StreamingModel::new(model_path, params.num_threads)?);
self.loaded_model_path = Some(model_path.to_path_buf());
log::info!(
"Loaded Moonshine streaming model from {:?}",
model_path
);
Ok(())
}
fn unload_model(&mut self) {
if self.model.is_some() {
log::debug!("Unloading Moonshine streaming model");
self.model = None;
self.loaded_model_path = None;
}
}
fn transcribe_samples(
&mut self,
samples: Vec<f32>,
params: Option<Self::InferenceParams>,
) -> Result<TranscriptionResult, Box<dyn std::error::Error>> {
let model = self
.model
.as_mut()
.ok_or("Streaming model not loaded")?;
let max_tokens_override = params.and_then(|p| p.max_length);
log::debug!(
"Transcribing {} samples ({:.2}s) with streaming model",
samples.len(),
samples.len() as f32 / SAMPLE_RATE as f32,
);
let tokens = model.generate(&samples, self.max_tokens_per_second, max_tokens_override)?;
let text = model.decode_tokens(&tokens)?;
Ok(TranscriptionResult {
text,
segments: None,
})
}
}