Skip to main content

vona_mlx/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4use std::{path::PathBuf, sync::Arc};
5use thiserror::Error;
6use vona_core::{
7    AudioInputFrame, AudioOutputFrame, AudioProcessingError, AudioStreamingTranscriber,
8    AudioSynthesisConfig, AudioSynthesizer, AudioTranscriber, BackendCapabilities, BackendError,
9    BackendStep, ControlEvent, ExternalContextEvent, SessionConfig, SpeechToSpeechBackend,
10    StreamingTranscriptKind, StreamingTranscriptUpdate, StreamingTranscriptionConfig,
11    StreamingTranscriptionSession,
12};
13
14#[cfg(feature = "native-mlx")]
15pub type MlxArray = mlx_rs::Array;
16
17#[cfg(not(feature = "native-mlx"))]
18#[derive(Debug, Clone)]
19pub struct MlxArray {
20    samples: Vec<f32>,
21    shape: Vec<i32>,
22}
23
24#[cfg(not(feature = "native-mlx"))]
25impl MlxArray {
26    fn from_samples(samples: &[f32]) -> Result<Self, MlxAudioError> {
27        let len = i32::try_from(samples.len()).map_err(|_| {
28            MlxAudioError::InvalidInput("audio frame is too large for mlx shape".to_string())
29        })?;
30        Ok(Self {
31            samples: samples.to_vec(),
32            shape: vec![len],
33        })
34    }
35
36    pub fn eval(&self) -> Result<(), MlxAudioError> {
37        Ok(())
38    }
39
40    pub fn as_slice<T>(&self) -> &[f32] {
41        let _ = std::marker::PhantomData::<T>;
42        &self.samples
43    }
44
45    pub fn shape(&self) -> &[i32] {
46        &self.shape
47    }
48}
49
50pub const DEFAULT_STT_MODEL_ID: &str = "distil-whisper/distil-large-v3";
51pub const DEFAULT_TTS_MODEL_ID: &str = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16";
52pub const DEFAULT_SAMPLE_RATE_HZ: u32 = 24_000;
53
54#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55pub struct MlxAudioConfig {
56    pub stt_model_id: String,
57    pub tts_model_id: String,
58    pub output_sample_rate_hz: u32,
59}
60
61impl Default for MlxAudioConfig {
62    fn default() -> Self {
63        Self {
64            stt_model_id: DEFAULT_STT_MODEL_ID.to_string(),
65            tts_model_id: DEFAULT_TTS_MODEL_ID.to_string(),
66            output_sample_rate_hz: DEFAULT_SAMPLE_RATE_HZ,
67        }
68    }
69}
70
71impl MlxAudioConfig {
72    pub fn from_env() -> Self {
73        Self {
74            stt_model_id: std::env::var("VONA_MLX_STT_MODEL")
75                .unwrap_or_else(|_| DEFAULT_STT_MODEL_ID.to_string()),
76            tts_model_id: std::env::var("VONA_MLX_TTS_MODEL")
77                .unwrap_or_else(|_| DEFAULT_TTS_MODEL_ID.to_string()),
78            output_sample_rate_hz: std::env::var("VONA_MLX_OUTPUT_SAMPLE_RATE")
79                .ok()
80                .and_then(|value| value.parse().ok())
81                .unwrap_or(DEFAULT_SAMPLE_RATE_HZ),
82        }
83    }
84}
85
86pub trait MlxSpeechModel: Send + Sync {
87    fn transcribe(&self, audio: &MlxArray, sample_rate_hz: u32) -> Result<String, MlxAudioError>;
88    fn synthesize(&self, text: &str, sample_rate_hz: u32) -> Result<MlxArray, MlxAudioError>;
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
92#[serde(rename_all = "snake_case")]
93pub enum MlxModelKind {
94    Speech,
95    WhisperSpeech,
96    Qwen3TtsSpeech,
97    TransformerText,
98    Qwen3NextText,
99}
100
101#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
102pub struct MlxModelLoadRequest {
103    pub model_id: String,
104    pub local_path: Option<PathBuf>,
105    pub kind: MlxModelKind,
106}
107
108impl MlxModelLoadRequest {
109    pub fn local(
110        kind: MlxModelKind,
111        model_id: impl Into<String>,
112        local_path: impl Into<PathBuf>,
113    ) -> Self {
114        Self {
115            model_id: model_id.into(),
116            local_path: Some(local_path.into()),
117            kind,
118        }
119    }
120}
121
122pub enum LoadedMlxModel {
123    Speech(Arc<dyn MlxSpeechModel>),
124    #[cfg(feature = "mlx-models-loader")]
125    TransformerText {
126        model: mlx_models::Model,
127        tokenizer: tokenizers::Tokenizer,
128    },
129    #[cfg(feature = "mlx-models-loader")]
130    Qwen3NextText {
131        model: mlx_models::Qwen3NextCausalLM,
132        tokenizer: tokenizers::Tokenizer,
133    },
134}
135
136pub trait MlxModelLoader: Send + Sync {
137    fn load_model(&self, request: MlxModelLoadRequest) -> Result<LoadedMlxModel, MlxAudioError>;
138}
139
140#[derive(Debug, Clone, Copy, Default)]
141pub struct MlxModelsLoader;
142
143impl MlxModelLoader for MlxModelsLoader {
144    fn load_model(&self, request: MlxModelLoadRequest) -> Result<LoadedMlxModel, MlxAudioError> {
145        load_with_mlx_models(request)
146    }
147}
148
149#[derive(Clone)]
150pub struct MlxAudioEngine {
151    config: MlxAudioConfig,
152    device_label: String,
153    model: Arc<dyn MlxSpeechModel>,
154}
155
156impl MlxAudioEngine {
157    pub fn init() -> Result<Self, MlxAudioError> {
158        Self::with_model(MlxAudioConfig::default(), Arc::new(UnloadedMlxSpeechModel))
159    }
160
161    pub fn from_env() -> Result<Self, MlxAudioError> {
162        Self::with_model(MlxAudioConfig::from_env(), Arc::new(UnloadedMlxSpeechModel))
163    }
164
165    pub fn with_model(
166        config: MlxAudioConfig,
167        model: Arc<dyn MlxSpeechModel>,
168    ) -> Result<Self, MlxAudioError> {
169        let device_label = assert_mlx_gpu_available()?;
170        Ok(Self {
171            config,
172            device_label,
173            model,
174        })
175    }
176
177    pub fn with_loader(
178        config: MlxAudioConfig,
179        loader: &dyn MlxModelLoader,
180        request: MlxModelLoadRequest,
181    ) -> Result<Self, MlxAudioError> {
182        match loader.load_model(request)? {
183            LoadedMlxModel::Speech(model) => Self::with_model(config, model),
184            #[cfg(feature = "mlx-models-loader")]
185            LoadedMlxModel::TransformerText { .. } | LoadedMlxModel::Qwen3NextText { .. } => {
186                Err(MlxAudioError::ModelUnavailable(
187                    "loaded mlx-models text model cannot satisfy Vona speech traits".to_string(),
188                ))
189            }
190        }
191    }
192
193    pub fn config(&self) -> &MlxAudioConfig {
194        &self.config
195    }
196
197    pub fn device_label(&self) -> &str {
198        &self.device_label
199    }
200
201    pub fn audio_array_from_frame(frame: &AudioInputFrame) -> Result<MlxArray, MlxAudioError> {
202        #[cfg(feature = "native-mlx")]
203        {
204            let len = i32::try_from(frame.samples.len()).map_err(|_| {
205                MlxAudioError::InvalidInput("audio frame is too large for mlx shape".to_string())
206            })?;
207            return Ok(MlxArray::from_slice(&frame.samples, &[len]));
208        }
209
210        #[cfg(not(feature = "native-mlx"))]
211        {
212            MlxArray::from_samples(&frame.samples)
213        }
214    }
215}
216
217#[cfg(feature = "mlx-models-loader")]
218fn load_with_mlx_models(request: MlxModelLoadRequest) -> Result<LoadedMlxModel, MlxAudioError> {
219    let local_path = request.local_path.ok_or_else(|| {
220        MlxAudioError::InvalidInput(
221            "mlx-models loader currently requires a local model directory".to_string(),
222        )
223    })?;
224
225    match request.kind {
226        MlxModelKind::Speech | MlxModelKind::WhisperSpeech | MlxModelKind::Qwen3TtsSpeech => {
227            Err(MlxAudioError::ModelUnavailable(
228                "mlx-models 0.1.x does not expose speech model loaders; use a Vona speech-model crate"
229                    .to_string(),
230            ))
231        }
232        MlxModelKind::TransformerText => {
233            let model = mlx_models::transformer::load_model(&local_path)
234                .map_err(|error| MlxAudioError::Inference(error.to_string()))?;
235            let tokenizer = mlx_models::load_tokenizer(&local_path)
236                .map_err(|error| MlxAudioError::Inference(error.to_string()))?;
237            Ok(LoadedMlxModel::TransformerText { model, tokenizer })
238        }
239        MlxModelKind::Qwen3NextText => {
240            let model = mlx_models::qwen3_next::load_qwen3_next_model(&local_path)
241                .map_err(|error| MlxAudioError::Inference(error.to_string()))?;
242            let tokenizer = mlx_models::load_tokenizer(&local_path)
243                .map_err(|error| MlxAudioError::Inference(error.to_string()))?;
244            Ok(LoadedMlxModel::Qwen3NextText { model, tokenizer })
245        }
246    }
247}
248
249#[cfg(not(feature = "mlx-models-loader"))]
250fn load_with_mlx_models(_request: MlxModelLoadRequest) -> Result<LoadedMlxModel, MlxAudioError> {
251    Err(MlxAudioError::Runtime(
252        "enable the mlx-models-loader feature to use mlx-models loading".to_string(),
253    ))
254}
255
256#[cfg(feature = "native-mlx")]
257fn assert_mlx_gpu_available() -> Result<String, MlxAudioError> {
258    use mlx_rs::{Array, Device};
259    use std::panic::AssertUnwindSafe;
260
261    let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
262        let device = Device::gpu();
263        Device::set_default(&device);
264        let probe = Array::from_slice(&[0.0_f32], &[1]);
265        probe.eval()?;
266        mlx_rs::error::Result::Ok(format!("{device}"))
267    }))
268    .map_err(|_| MlxAudioError::Runtime("MLX GPU initialization panicked".to_string()))?;
269
270    result.map_err(|error| MlxAudioError::Runtime(error.to_string()))
271}
272
273#[cfg(not(feature = "native-mlx"))]
274fn assert_mlx_gpu_available() -> Result<String, MlxAudioError> {
275    Ok("native-mlx feature disabled".to_string())
276}
277
278#[cfg(feature = "native-mlx")]
279fn eval_array(array: &MlxArray) -> Result<(), MlxAudioError> {
280    array
281        .eval()
282        .map_err(|error| MlxAudioError::Inference(error.to_string()))
283}
284
285#[cfg(not(feature = "native-mlx"))]
286fn eval_array(array: &MlxArray) -> Result<(), MlxAudioError> {
287    array.eval()
288}
289
290fn samples_from_array(array: &MlxArray) -> Vec<f32> {
291    array.as_slice::<f32>().to_vec()
292}
293
/// Test helper: builds a rank-1 [`MlxArray`] from `samples` for whichever backend is active.
///
/// # Panics
/// If `samples.len()` does not fit in `i32`.
#[cfg(test)]
fn test_array_from_samples(samples: &[f32]) -> MlxArray {
    #[cfg(feature = "native-mlx")]
    {
        let dim = i32::try_from(samples.len()).unwrap();
        MlxArray::from_slice(samples, &[dim])
    }

    #[cfg(not(feature = "native-mlx"))]
    {
        let built = MlxArray::from_samples(samples);
        built.unwrap()
    }
}
307
/// Per-session state for the `SpeechToSpeechBackend` implementation on `MlxAudioEngine`.
#[derive(Debug, Clone)]
pub struct MlxAudioSession {
    /// Configuration passed to `start_session`.
    pub config: SessionConfig,
    /// Context events queued by `inject_event`; drained by `step`.
    pub pending_events: Vec<ExternalContextEvent>,
}

/// Incremental transcription session returned by `start_streaming_transcription` on
/// `MlxAudioEngine`.
///
/// Audio accumulates in `pcm_buffer`. Whole-buffer decodes run as spawned tasks, and their
/// results are compared across passes to emit word-stable partial transcripts.
pub struct MlxStreamingTranscriptionSession {
    /// Engine used for every decode (cloned into background tasks).
    engine: MlxAudioEngine,
    /// Session settings; incoming frames must match its sample rate and channel count.
    config: StreamingTranscriptionConfig,
    /// Buffered samples, trimmed from the front to at most `max_buffer_ms` of audio.
    pcm_buffer: Vec<f32>,
    /// Buffer length when the most recent decode started (shifted when the buffer is trimmed).
    last_inference_samples: usize,
    /// Most recent decode results; at most `stability_passes` (minimum 1) entries.
    recent_hypotheses: Vec<String>,
    /// Text most recently emitted as a partial or final update.
    committed_prefix: String,
    /// Last update produced. NOTE(review): written but not read in this file.
    latest_update: Option<StreamingTranscriptUpdate>,
    /// In-flight background decode, if any.
    pending_decode: Option<tokio::task::JoinHandle<Result<String, AudioProcessingError>>>,
    /// Buffer length covered by the in-flight decode (shifted on trim); 0 when idle.
    pending_decode_samples: usize,
}
325
326#[derive(Debug, Error)]
327pub enum MlxAudioError {
328    #[error("MLX runtime is unavailable: {0}")]
329    Runtime(String),
330    #[error("MLX model is not loaded: {0}")]
331    ModelUnavailable(String),
332    #[error("MLX input is invalid: {0}")]
333    InvalidInput(String),
334    #[error("MLX inference failed: {0}")]
335    Inference(String),
336}
337
338impl From<MlxAudioError> for BackendError {
339    fn from(value: MlxAudioError) -> Self {
340        match value {
341            MlxAudioError::Runtime(message) | MlxAudioError::ModelUnavailable(message) => {
342                BackendError::Start(message)
343            }
344            MlxAudioError::InvalidInput(message) | MlxAudioError::Inference(message) => {
345                BackendError::Step(message)
346            }
347        }
348    }
349}
350
351impl From<MlxAudioError> for AudioProcessingError {
352    fn from(value: MlxAudioError) -> Self {
353        match value {
354            MlxAudioError::Runtime(message) => AudioProcessingError::Runtime(message),
355            MlxAudioError::ModelUnavailable(message) => {
356                AudioProcessingError::ModelUnavailable(message)
357            }
358            MlxAudioError::InvalidInput(message) => AudioProcessingError::InvalidInput(message),
359            MlxAudioError::Inference(message) => AudioProcessingError::Inference(message),
360        }
361    }
362}
363
364struct UnloadedMlxSpeechModel;
365
366impl MlxSpeechModel for UnloadedMlxSpeechModel {
367    fn transcribe(&self, _audio: &MlxArray, _sample_rate_hz: u32) -> Result<String, MlxAudioError> {
368        Err(MlxAudioError::ModelUnavailable(
369            "Distil-Whisper MLX graph loader is not implemented in mlx-models 0.1.x".to_string(),
370        ))
371    }
372
373    fn synthesize(&self, _text: &str, _sample_rate_hz: u32) -> Result<MlxArray, MlxAudioError> {
374        Err(MlxAudioError::ModelUnavailable(
375            "Qwen3-TTS MLX graph loader is not implemented in mlx-models 0.1.x".to_string(),
376        ))
377    }
378}
379
380fn event_text(events: &[ExternalContextEvent]) -> Option<String> {
381    events.iter().find_map(|event| match event.source.as_str() {
382        "vona.plan_result" | "vona.precomputed_reply" => event
383            .spoken_summary
384            .as_ref()
385            .map(|value| value.trim().to_string())
386            .filter(|value| !value.is_empty()),
387        "vona.tts_text" => event
388            .payload
389            .as_str()
390            .map(|value| value.trim().to_string())
391            .filter(|value| !value.is_empty()),
392        _ => None,
393    })
394}
395
396#[async_trait]
397impl AudioTranscriber for MlxAudioEngine {
398    async fn transcribe_audio(
399        &self,
400        input: AudioInputFrame,
401    ) -> Result<String, AudioProcessingError> {
402        let audio = Self::audio_array_from_frame(&input).map_err(AudioProcessingError::from)?;
403        self.model
404            .transcribe(&audio, input.sample_rate_hz)
405            .map_err(AudioProcessingError::from)
406    }
407}
408
409#[async_trait]
410impl AudioStreamingTranscriber for MlxAudioEngine {
411    async fn start_streaming_transcription(
412        &self,
413        config: StreamingTranscriptionConfig,
414    ) -> Result<Box<dyn StreamingTranscriptionSession>, AudioProcessingError> {
415        if config.sample_rate_hz == 0 {
416            return Err(AudioProcessingError::InvalidInput(
417                "streaming transcription sample rate must be non-zero".to_string(),
418            ));
419        }
420        if config.channels == 0 {
421            return Err(AudioProcessingError::InvalidInput(
422                "streaming transcription channel count must be non-zero".to_string(),
423            ));
424        }
425
426        Ok(Box::new(MlxStreamingTranscriptionSession {
427            engine: self.clone(),
428            config,
429            pcm_buffer: Vec::new(),
430            last_inference_samples: 0,
431            recent_hypotheses: Vec::new(),
432            committed_prefix: String::new(),
433            latest_update: None,
434            pending_decode: None,
435            pending_decode_samples: 0,
436        }))
437    }
438}
439
440#[async_trait]
441impl StreamingTranscriptionSession for MlxStreamingTranscriptionSession {
442    async fn push_audio(
443        &mut self,
444        input: AudioInputFrame,
445    ) -> Result<Option<StreamingTranscriptUpdate>, AudioProcessingError> {
446        if input.sample_rate_hz != self.config.sample_rate_hz {
447            return Err(AudioProcessingError::InvalidInput(format!(
448                "streaming transcription sample rate changed from {} to {}",
449                self.config.sample_rate_hz, input.sample_rate_hz
450            )));
451        }
452        if input.channels != self.config.channels {
453            return Err(AudioProcessingError::InvalidInput(format!(
454                "streaming transcription channel count changed from {} to {}",
455                self.config.channels, input.channels
456            )));
457        }
458        if input.samples.is_empty() {
459            return Ok(None);
460        }
461
462        self.pcm_buffer.extend(input.samples);
463        self.enforce_buffer_limit();
464
465        if let Some(update) = self.collect_finished_decode(false).await? {
466            return Ok(Some(update));
467        }
468
469        let min_samples = samples_for_ms(self.config.sample_rate_hz, self.config.min_buffer_ms);
470        if self.pcm_buffer.len() < min_samples {
471            return Ok(None);
472        }
473
474        let step_samples = samples_for_ms(self.config.sample_rate_hz, self.config.step_ms);
475        if self
476            .pcm_buffer
477            .len()
478            .saturating_sub(self.last_inference_samples)
479            < step_samples
480        {
481            return Ok(None);
482        }
483
484        if self.pending_decode.is_none() {
485            self.last_inference_samples = self.pcm_buffer.len();
486            self.pending_decode_samples = self.pcm_buffer.len();
487            self.pending_decode = Some(spawn_mlx_streaming_decode(
488                self.engine.clone(),
489                self.config.sample_rate_hz,
490                self.config.channels,
491                self.pcm_buffer.clone(),
492            ));
493        }
494        Ok(None)
495    }
496
497    async fn finish(&mut self) -> Result<Option<StreamingTranscriptUpdate>, AudioProcessingError> {
498        if self.pcm_buffer.is_empty() {
499            return Ok(None);
500        }
501        if let Some(update) = self.collect_finished_decode(false).await?
502            && !update.text.trim().is_empty()
503        {
504            self.latest_update = Some(update);
505        }
506
507        if self.pending_decode.is_some() {
508            if self.pending_decode_samples == self.pcm_buffer.len() {
509                if let Some(update) = self.collect_finished_decode(true).await? {
510                    return Ok(Some(StreamingTranscriptUpdate {
511                        kind: StreamingTranscriptKind::Final,
512                        text: update.text,
513                        stability_passes: update.stability_passes,
514                        total_audio_ms: self.total_audio_ms(),
515                    }));
516                }
517            } else {
518                self.pending_decode = None;
519                self.pending_decode_samples = 0;
520            }
521        }
522        self.transcribe_current(true).await
523    }
524}
525
526impl MlxStreamingTranscriptionSession {
527    async fn collect_finished_decode(
528        &mut self,
529        wait: bool,
530    ) -> Result<Option<StreamingTranscriptUpdate>, AudioProcessingError> {
531        let should_collect = self
532            .pending_decode
533            .as_ref()
534            .is_some_and(|handle| wait || handle.is_finished());
535        if !should_collect {
536            return Ok(None);
537        }
538
539        let handle = self.pending_decode.take().expect("checked pending decode");
540        self.pending_decode_samples = 0;
541        let transcript = handle.await.map_err(|err| {
542            AudioProcessingError::Runtime(format!("MLX streaming STT task join failed: {err}"))
543        })??;
544        self.accept_transcript(transcript, false)
545    }
546
547    fn enforce_buffer_limit(&mut self) {
548        let max_samples = samples_for_ms(self.config.sample_rate_hz, self.config.max_buffer_ms);
549        if self.pcm_buffer.len() <= max_samples {
550            return;
551        }
552
553        let overflow = self.pcm_buffer.len().saturating_sub(max_samples);
554        self.pcm_buffer.drain(0..overflow);
555        self.last_inference_samples = self.last_inference_samples.saturating_sub(overflow);
556        self.pending_decode_samples = self.pending_decode_samples.saturating_sub(overflow);
557    }
558
559    async fn transcribe_current(
560        &mut self,
561        is_final: bool,
562    ) -> Result<Option<StreamingTranscriptUpdate>, AudioProcessingError> {
563        let transcript = self.decode_current_buffer().await?;
564        self.last_inference_samples = self.pcm_buffer.len();
565        self.accept_transcript(transcript, is_final)
566    }
567
568    async fn decode_current_buffer(&self) -> Result<String, AudioProcessingError> {
569        self.engine
570            .transcribe_audio(AudioInputFrame {
571                sequence: 0,
572                sample_rate_hz: self.config.sample_rate_hz,
573                channels: self.config.channels,
574                samples: self.pcm_buffer.clone(),
575            })
576            .await
577            .map(|value| value.split_whitespace().collect::<Vec<_>>().join(" "))
578    }
579
580    fn accept_transcript(
581        &mut self,
582        transcript: String,
583        is_final: bool,
584    ) -> Result<Option<StreamingTranscriptUpdate>, AudioProcessingError> {
585        self.recent_hypotheses.push(transcript.clone());
586        let max_hypotheses = self.config.stability_passes.max(1) as usize;
587        if self.recent_hypotheses.len() > max_hypotheses {
588            self.recent_hypotheses.remove(0);
589        }
590
591        if is_final {
592            self.committed_prefix = transcript.clone();
593            let update = StreamingTranscriptUpdate {
594                kind: StreamingTranscriptKind::Final,
595                text: transcript,
596                stability_passes: self.recent_hypotheses.len() as u32,
597                total_audio_ms: self.total_audio_ms(),
598            };
599            self.latest_update = Some(update.clone());
600            return Ok(Some(update));
601        }
602
603        if self.recent_hypotheses.len() < max_hypotheses {
604            return Ok(None);
605        }
606
607        let committed_candidate = longest_common_word_prefix(&self.recent_hypotheses);
608        if committed_candidate.is_empty() || committed_candidate == self.committed_prefix {
609            return Ok(None);
610        }
611        if !self.committed_prefix.is_empty()
612            && !committed_candidate.starts_with(&self.committed_prefix)
613        {
614            return Ok(None);
615        }
616
617        self.committed_prefix = committed_candidate.clone();
618        let update = StreamingTranscriptUpdate {
619            kind: StreamingTranscriptKind::Partial,
620            text: committed_candidate,
621            stability_passes: self.recent_hypotheses.len() as u32,
622            total_audio_ms: self.total_audio_ms(),
623        };
624        self.latest_update = Some(update.clone());
625        Ok(Some(update))
626    }
627
628    fn total_audio_ms(&self) -> u64 {
629        ((self.pcm_buffer.len() as u64) * 1000) / self.config.sample_rate_hz as u64
630    }
631}
632
633fn spawn_mlx_streaming_decode(
634    engine: MlxAudioEngine,
635    sample_rate_hz: u32,
636    channels: u16,
637    samples: Vec<f32>,
638) -> tokio::task::JoinHandle<Result<String, AudioProcessingError>> {
639    tokio::spawn(async move {
640        engine
641            .transcribe_audio(AudioInputFrame {
642                sequence: 0,
643                sample_rate_hz,
644                channels,
645                samples,
646            })
647            .await
648            .map(|value| value.split_whitespace().collect::<Vec<_>>().join(" "))
649    })
650}
651
/// Number of samples in `ms` milliseconds of audio at `sample_rate_hz`, rounded down.
///
/// The product is computed in `u64`, so it cannot overflow for any `u32` inputs.
fn samples_for_ms(sample_rate_hz: u32, ms: u32) -> usize {
    let sample_millis = u64::from(sample_rate_hz) * u64::from(ms);
    (sample_millis / 1_000) as usize
}
655
/// Returns the longest run of leading words shared by every hypothesis, joined by single
/// spaces.
///
/// Words are split on whitespace and compared ignoring ASCII case. The returned words keep
/// the spelling of the first hypothesis. An empty slice yields an empty string.
fn longest_common_word_prefix(hypotheses: &[String]) -> String {
    let mut remaining = hypotheses.iter();
    let Some(first) = remaining.next() else {
        return String::new();
    };
    let reference: Vec<&str> = first.split_whitespace().collect();
    let shared = remaining.fold(reference.len(), |limit, hypothesis| {
        reference[..limit]
            .iter()
            .zip(hypothesis.split_whitespace())
            .take_while(|(kept, other)| kept.eq_ignore_ascii_case(other))
            .count()
    });
    reference[..shared].join(" ")
}
675
676#[async_trait]
677impl AudioSynthesizer for MlxAudioEngine {
678    async fn synthesize_audio(
679        &self,
680        text: String,
681        config: AudioSynthesisConfig,
682    ) -> Result<AudioOutputFrame, AudioProcessingError> {
683        let speech = self
684            .model
685            .synthesize(&text, config.sample_rate_hz)
686            .map_err(AudioProcessingError::from)?;
687        eval_array(&speech).map_err(AudioProcessingError::from)?;
688
689        Ok(AudioOutputFrame {
690            sequence: config.sequence,
691            sample_rate_hz: config.sample_rate_hz,
692            channels: config.channels,
693            samples: samples_from_array(&speech),
694            is_filler: false,
695        })
696    }
697}
698
699#[async_trait]
700impl SpeechToSpeechBackend for MlxAudioEngine {
701    type Session = MlxAudioSession;
702
703    fn capabilities(&self) -> BackendCapabilities {
704        BackendCapabilities {
705            supports_context_injection: true,
706            supports_style_conditioning: false,
707            supports_word_timestamps: false,
708            ..BackendCapabilities::default()
709        }
710    }
711
712    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, BackendError> {
713        Ok(MlxAudioSession {
714            config,
715            pending_events: Vec::new(),
716        })
717    }
718
719    async fn step(
720        &self,
721        session: &mut Self::Session,
722        input: AudioInputFrame,
723    ) -> Result<BackendStep, BackendError> {
724        let sequence = input.sequence;
725        let transcript = self
726            .transcribe_audio(input)
727            .await
728            .map_err(|error| BackendError::Step(error.to_string()))?;
729
730        let pending_events = std::mem::take(&mut session.pending_events);
731        let reply_text = event_text(&pending_events).unwrap_or_else(|| transcript.clone());
732        let output_audio = self
733            .synthesize_audio(
734                reply_text,
735                AudioSynthesisConfig {
736                    sequence,
737                    sample_rate_hz: self.config.output_sample_rate_hz,
738                    channels: session.config.channels,
739                },
740            )
741            .await
742            .map_err(|error| BackendError::Step(error.to_string()))?;
743
744        Ok(BackendStep {
745            output_audio: vec![output_audio],
746            control_events: vec![ControlEvent::Diagnostic {
747                message: format!("vona-mlx device {}", self.device_label),
748            }],
749            transcript: Some(transcript),
750            finished: false,
751            debug_payload: Some(json!({
752                "stt_model_id": self.config.stt_model_id,
753                "tts_model_id": self.config.tts_model_id,
754            })),
755        })
756    }
757
758    async fn inject_event(
759        &self,
760        session: &mut Self::Session,
761        event: ExternalContextEvent,
762    ) -> Result<(), BackendError> {
763        session.pending_events.push(event);
764        Ok(())
765    }
766
767    async fn end_session(&self, _session: Self::Session) -> Result<(), BackendError> {
768        Ok(())
769    }
770}
771
/// Unit tests that drive the engine with in-memory stub models, so no MLX device or
/// model files are needed.
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub model: always transcribes to "hello" and synthesizes a fixed three-sample clip.
    struct EchoModel;

    impl MlxSpeechModel for EchoModel {
        fn transcribe(
            &self,
            _audio: &MlxArray,
            _sample_rate_hz: u32,
        ) -> Result<String, MlxAudioError> {
            Ok("hello".to_string())
        }

        fn synthesize(&self, _text: &str, _sample_rate_hz: u32) -> Result<MlxArray, MlxAudioError> {
            Ok(test_array_from_samples(&[0.0_f32, 0.25, -0.25]))
        }
    }

    /// Stub model whose transcript reveals how much audio it saw. Fewer than 3 000 samples
    /// give "stale partial"; anything longer gives "final transcript".
    struct LengthAwareModel;

    impl MlxSpeechModel for LengthAwareModel {
        fn transcribe(
            &self,
            audio: &MlxArray,
            _sample_rate_hz: u32,
        ) -> Result<String, MlxAudioError> {
            if audio.shape()[0] < 3_000 {
                Ok("stale partial".to_string())
            } else {
                Ok("final transcript".to_string())
            }
        }

        fn synthesize(&self, _text: &str, _sample_rate_hz: u32) -> Result<MlxArray, MlxAudioError> {
            Ok(test_array_from_samples(&[0.0]))
        }
    }

    /// A two-sample frame becomes a rank-1 array of length 2.
    #[test]
    fn builds_mlx_audio_array_from_vona_frame() {
        let frame = AudioInputFrame {
            sequence: 0,
            sample_rate_hz: 16_000,
            channels: 1,
            samples: vec![0.1, 0.2],
        };

        let array = MlxAudioEngine::audio_array_from_frame(&frame).unwrap();
        assert_eq!(array.shape(), &[2]);
    }

    /// A `vona.tts_text` event's string payload is returned as the text to speak.
    #[test]
    fn extracts_tts_text_from_events() {
        let events = vec![ExternalContextEvent {
            source: "vona.tts_text".to_string(),
            spoken_summary: None,
            payload: json!("speak this"),
        }];

        assert_eq!(event_text(&events), Some("speak this".to_string()));
    }

    /// Without the `mlx-models-loader` feature, loading fails with a `Runtime` error.
    /// With the feature enabled, the body is compiled out and the test is a no-op.
    #[test]
    fn mlx_models_loader_is_feature_gated() {
        #[cfg(not(feature = "mlx-models-loader"))]
        {
            let request = MlxModelLoadRequest::local(
                MlxModelKind::TransformerText,
                "local-test-model",
                "/tmp/model",
            );
            let result = MlxModelsLoader.load_model(request);
            assert!(matches!(result, Err(MlxAudioError::Runtime(_))));
        }
    }

    /// The common prefix stops at the first differing word ("task" vs "useful").
    #[test]
    fn streaming_common_prefix_is_word_stable() {
        let hypotheses = vec![
            "focus on the first task today".to_string(),
            "focus on the first useful task".to_string(),
        ];

        assert_eq!(
            longest_common_word_prefix(&hypotheses),
            "focus on the first"
        );
    }

    /// One backend step with `EchoModel` returns that model's transcript and synthesized
    /// samples. The engine is built directly, skipping the device probe.
    #[tokio::test]
    async fn injected_model_runs_backend_step() {
        let engine = MlxAudioEngine {
            config: MlxAudioConfig::default(),
            device_label: "test".to_string(),
            model: Arc::new(EchoModel),
        };
        let mut session = engine
            .start_session(SessionConfig::default())
            .await
            .unwrap();
        let step = engine
            .step(
                &mut session,
                AudioInputFrame {
                    sequence: 7,
                    sample_rate_hz: 16_000,
                    channels: 1,
                    samples: vec![0.0, 1.0],
                },
            )
            .await
            .unwrap();

        assert_eq!(step.transcript, Some("hello".to_string()));
        assert_eq!(step.output_audio[0].samples, vec![0.0, 0.25, -0.25]);
    }

    /// `finish` must reflect all buffered audio, not a decode that ran on an earlier,
    /// shorter buffer.
    #[tokio::test]
    async fn streaming_finish_decodes_current_buffer_instead_of_stale_partial() {
        let engine = MlxAudioEngine {
            config: MlxAudioConfig::default(),
            device_label: "test".to_string(),
            model: Arc::new(LengthAwareModel),
        };
        // 50 ms at 16 kHz is 800 samples, for both the step and the minimum buffer.
        let mut session = engine
            .start_streaming_transcription(StreamingTranscriptionConfig {
                sample_rate_hz: 16_000,
                channels: 1,
                step_ms: 50,
                min_buffer_ms: 50,
                max_buffer_ms: 30_000,
                stability_passes: 1,
            })
            .await
            .unwrap();

        // 1 000 samples exceed the 800-sample thresholds, so a background decode of a
        // buffer shorter than 3 000 samples ("stale partial") is started.
        let _ = session
            .push_audio(AudioInputFrame {
                sequence: 1,
                sample_rate_hz: 16_000,
                channels: 1,
                samples: vec![0.0; 1_000],
            })
            .await
            .unwrap();
        // Presumably gives the background decode a chance to run before more audio arrives.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        // The buffer now holds 4 000 samples.
        let _ = session
            .push_audio(AudioInputFrame {
                sequence: 2,
                sample_rate_hz: 16_000,
                channels: 1,
                samples: vec![0.0; 3_000],
            })
            .await
            .unwrap();

        let final_update = session.finish().await.unwrap().unwrap();
        assert_eq!(final_update.kind, StreamingTranscriptKind::Final);
        assert_eq!(final_update.text, "final transcript");
    }
}