use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use crate::error::{AudioError, AudioResult};
use crate::frame::AudioFrame;
use crate::registry::LocalModelRegistry;
use crate::traits::{SttOptions, SttProvider, Transcript};
use super::config::MlxSttConfig;
/// Speech-to-text provider backed by a local MLX Whisper model.
///
/// Built via [`MlxSttProvider::new`], which resolves the model through a
/// [`LocalModelRegistry`] and loads the tokenizer from the model directory.
pub struct MlxSttProvider {
    /// Provider configuration (model id, sample rate, ...).
    config: MlxSttConfig,
    /// Directory the model was resolved/downloaded to. Currently only used
    /// for diagnostics in `transcribe`'s error message — hence the allow —
    /// until MLX inference is implemented.
    #[allow(dead_code)]
    model_path: std::path::PathBuf,
    /// Tokenizer loaded from `<model_path>/tokenizer.json`; likewise unused
    /// until inference lands.
    #[allow(dead_code)] tokenizer: tokenizers::Tokenizer,
}
impl MlxSttProvider {
    /// Test-only constructor: placeholder model path and an empty BPE
    /// tokenizer — no registry lookup, no file I/O.
    #[doc(hidden)]
    pub fn with_dummy() -> Self {
        Self {
            config: MlxSttConfig::default(),
            model_path: std::path::PathBuf::from("/tmp/model"),
            tokenizer: tokenizers::Tokenizer::new(tokenizers::models::bpe::BPE::default()),
        }
    }

    /// Resolves `config.model_id` through `registry` (downloading if needed)
    /// and loads the tokenizer from the resolved model directory.
    ///
    /// # Errors
    /// Fails if the model cannot be fetched or `tokenizer.json` cannot be
    /// loaded/parsed.
    pub async fn new(config: MlxSttConfig, registry: &LocalModelRegistry) -> AudioResult<Self> {
        let model_path = registry.get_or_download(&config.model_id).await?;
        let tokenizer = Self::load_tokenizer(&model_path)?;
        Ok(Self { config, model_path, tokenizer })
    }

    /// Convenience constructor: default Whisper config against the default
    /// local model registry.
    pub async fn default_whisper() -> AudioResult<Self> {
        let registry = LocalModelRegistry::default();
        // FIX: the original read `®istry` — a mojibake of `&registry`
        // (`&r` mis-decoded as `®`) that fails to compile.
        Self::new(MlxSttConfig::default(), &registry).await
    }

    /// Loads `tokenizer.json` from the model directory, mapping any
    /// tokenizer error into an `AudioError::Stt` tagged with provider "MLX".
    fn load_tokenizer(model_path: &std::path::Path) -> AudioResult<tokenizers::Tokenizer> {
        let tokenizer_path = model_path.join("tokenizer.json");
        tokenizers::Tokenizer::from_file(&tokenizer_path).map_err(|e| AudioError::Stt {
            provider: "MLX".into(),
            message: format!("failed to load tokenizer: {e}"),
        })
    }
}
#[async_trait]
impl SttProvider for MlxSttProvider {
    /// Batch transcription entry point.
    ///
    /// Validates input and computes the log-mel spectrogram, then returns an
    /// explicit "not yet implemented" error (MLX inference is not wired up),
    /// so callers fail loudly rather than receive garbage transcripts.
    async fn transcribe(&self, audio: &AudioFrame, _opts: &SttOptions) -> AudioResult<Transcript> {
        // Reject empty input up front, before allocating/converting the
        // sample buffer (the original checked only after the full collect).
        if audio.samples().is_empty() {
            return Err(AudioError::Stt {
                provider: "MLX".into(),
                message: "empty audio input".into(),
            });
        }
        // Normalize samples into [-1.0, 1.0) floats for the mel front end.
        // NOTE(review): the /32768.0 divisor assumes 16-bit PCM samples —
        // confirm against `AudioFrame::samples`.
        let samples: Vec<f32> = audio.samples().iter().map(|&s| s as f32 / 32768.0).collect();
        let mel = super::mel::compute_log_mel_spectrogram(&samples, self.config.sample_rate)?;
        // NOTE(review): the message interpolates n_frames then n_mels into
        // "{}×{}"; confirm the intended row×column order with the mel module.
        Err(AudioError::Stt {
            provider: "MLX".into(),
            message: format!(
                "MLX Whisper inference is not yet implemented — use a cloud STT provider instead. \
                 mel spectrogram {}×{} frames. Model at: {}",
                mel.n_frames,
                mel.n_mels,
                self.model_path.display()
            ),
        })
    }

    /// Streaming transcription is unimplemented; returns an explicit error
    /// so callers do not silently receive an empty stream.
    async fn transcribe_stream(
        &self,
        _audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
        _opts: &SttOptions,
    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
        Err(AudioError::Stt {
            provider: "MLX".into(),
            message: "streaming transcription not yet implemented".into(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `transcribe_stream` must fail loudly with a provider-tagged STT error
    /// until streaming support exists — never hang or yield an empty stream.
    #[tokio::test]
    async fn transcribe_stream_returns_explicit_unimplemented_error() {
        // Use the shared `with_dummy` test constructor instead of duplicating
        // the dummy field initialization inline (keeps the two in sync).
        let provider = MlxSttProvider::with_dummy();
        let result = provider
            .transcribe_stream(Box::pin(futures::stream::empty()), &SttOptions::default())
            .await;
        match result {
            Err(AudioError::Stt { provider, message }) => {
                assert_eq!(provider, "MLX");
                assert!(message.contains("not yet implemented"));
            }
            Err(err) => panic!("unexpected audio error: {err}"),
            Ok(_) => panic!("expected explicit STT error"),
        }
    }
}