Skip to main content

hematite/ui/
voice.rs

1use crate::agent::inference::InferenceEvent;
2#[cfg(feature = "embedded-voice-assets")]
3use kokoros::tts::koko::TTSKoko;
4#[cfg(feature = "embedded-voice-assets")]
5use rodio::OutputStream;
6use rodio::Sink;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::mpsc;
9use std::sync::Arc;
10use tokio::sync::mpsc as tokio_mpsc;
11
12/// Manages the local Text-to-Speech pipeline.
13/// Uses the all-Rust `kokoros` engine for streaming synthesis.
14pub struct VoiceManager {
15    sender: mpsc::SyncSender<String>,
16    enabled: Arc<AtomicBool>,
17    available: Arc<AtomicBool>,
18    cancelled: Arc<AtomicBool>, // Immediate abort flag
19    sink: Arc<tokio::sync::Mutex<Option<Sink>>>,
20    /// Currently active voice ID — updated live by /voice command.
21    current_voice: Arc<std::sync::Mutex<String>>,
22    /// Speech speed multiplier (0.5–2.0). Read at synthesis time.
23    current_speed: Arc<std::sync::Mutex<f32>>,
24    /// Output volume (0.0–3.0). Applied to the rodio Sink.
25    current_volume: Arc<std::sync::Mutex<f32>>,
26}
27
28impl VoiceManager {
29    pub fn new(event_tx: tokio_mpsc::Sender<InferenceEvent>) -> Self {
30        let cfg = crate::agent::config::load_config();
31        let initial_voice = crate::agent::config::effective_voice(&cfg);
32        let initial_speed = crate::agent::config::effective_voice_speed(&cfg);
33        let initial_volume = crate::agent::config::effective_voice_volume(&cfg);
34        // Large buffer so tokens arriving during model load (~30-60s) aren't dropped.
35        let (tx, rx) = mpsc::sync_channel::<String>(1024);
36        let enabled = Arc::new(AtomicBool::new(true));
37        let available = Arc::new(AtomicBool::new(cfg!(feature = "embedded-voice-assets")));
38        let cancelled = Arc::new(AtomicBool::new(false));
39        let enabled_ctx = enabled.clone();
40        #[cfg(not(feature = "embedded-voice-assets"))]
41        let available_ctx = available.clone();
42        let _cancelled_ctx = cancelled.clone();
43        let sink_shared = Arc::new(tokio::sync::Mutex::new(None));
44        let current_voice = Arc::new(std::sync::Mutex::new(initial_voice));
45        let current_speed = Arc::new(std::sync::Mutex::new(initial_speed));
46        let current_volume = Arc::new(std::sync::Mutex::new(initial_volume));
47        let _voice_synth = Arc::clone(&current_voice);
48        let _speed_synth = Arc::clone(&current_speed);
49        let _volume_synth = Arc::clone(&current_volume);
50        let sink_manager_clone = Arc::clone(&sink_shared);
51
52        // Dedicated thread for voice synthesis and playback
53        // This solves the 'rodio::OutputStream is not Send' issue.
54        let _ = std::thread::Builder::new()
55            .name("VoiceManager".into())
56            .stack_size(32 * 1024 * 1024) // 32MB Stack for deep ONNX graph optimization
57            .spawn(move || {
58                #[cfg(not(feature = "embedded-voice-assets"))]
59                {
60                    enabled_ctx.store(false, Ordering::SeqCst);
61                    available_ctx.store(false, Ordering::SeqCst);
62                    let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(
63                        "Voice Engine: Disabled in crates.io/source build (use packaged releases for baked-in voice).".into(),
64                    ));
65                    while rx.recv().is_ok() {}
66                    return;
67                }
68
69                #[cfg(feature = "embedded-voice-assets")]
70                {
71                let mut _stream: Option<OutputStream> = None;
72
73                let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(
74                    "Voice Engine: Initializing Audio Pipeline...".into(),
75                ));
76                let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(
77                    "Voice Engine: Activating Baked-In Weights...".into(),
78                ));
79
80                // --- STATIC BAKE: Include weights in binary ---
81                const MODEL_BYTES: &[u8] =
82                    include_bytes!("../../.hematite/assets/voice/kokoro-v1.0.onnx");
83                const VOICES_BYTES: &[u8] =
84                    include_bytes!("../../.hematite/assets/voice/voices.bin");
85
86                let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(
87                    "Voice Engine: Loading voice model...".into(),
88                ));
89
90                // Catch panics from ONNX Runtime init (e.g. API version mismatch with system DLL)
91                let tts_result = std::panic::catch_unwind(|| {
92                    TTSKoko::new_from_memory(MODEL_BYTES, VOICES_BYTES)
93                });
94
95                let tts = match tts_result {
96                    Ok(Ok(engine)) => {
97                        enabled_ctx.store(true, Ordering::SeqCst);
98                        if let Ok((s, handle)) = OutputStream::try_default() {
99                            _stream = Some(s);
100                            if let Ok(new_sink) = Sink::try_new(&handle) {
101                                let mut lock = sink_shared.blocking_lock();
102                                *lock = Some(new_sink);
103                            }
104                            let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(
105                                "Voice Engine: Vibrant & Ready ✅".into(),
106                            ));
107                        } else {
108                            let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(
109                                "Voice Engine: ERROR - No audio device found ❌".into(),
110                            ));
111                        }
112                        Some(engine)
113                    }
114                    Ok(Err(e)) => {
115                        let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(format!(
116                            "Voice Engine: ERROR - {} ❌",
117                            e
118                        )));
119                        None
120                    }
121                    Err(panic_val) => {
122                        let msg = panic_val
123                            .downcast_ref::<String>()
124                            .map(|s| s.as_str())
125                            .or_else(|| panic_val.downcast_ref::<&str>().copied())
126                            .unwrap_or("unknown panic");
127                        let _ = event_tx.blocking_send(InferenceEvent::VoiceStatus(format!(
128                            "Voice Engine: CRASH - {} ❌",
129                            msg
130                        )));
131                        None
132                    }
133                };
134
135                // Stage 2: Background Synthesizer
136                let (synth_tx, mut synth_rx) = tokio_mpsc::channel::<String>(64);
137                let tts_shared = Arc::new(tokio::sync::Mutex::new(tts));
138                let tts_synth_clone = Arc::clone(&tts_shared);
139                let sink_synth_clone = Arc::clone(&sink_shared);
140                let event_tx_synth = event_tx.clone();
141
142                std::thread::spawn(move || {
143                    let rt = tokio::runtime::Builder::new_current_thread()
144                        .enable_all()
145                        .build()
146                        .unwrap();
147
148                    rt.block_on(async {
149                        while let Some(to_speak) = synth_rx.recv().await {
150                            let mut engine_opt = tts_synth_clone.lock().await;
151                            if let Some(ref mut engine) = *engine_opt {
152                                let voice_id = _voice_synth
153                                    .lock()
154                                    .map(|v| v.clone())
155                                    .unwrap_or_else(|_| "af_sky".to_string());
156                                let speed = _speed_synth.lock().map(|v| *v).unwrap_or(1.0);
157                                let volume = _volume_synth.lock().map(|v| *v).unwrap_or(1.0);
158                                let res = engine.tts_raw_audio_streaming(
159                                    &to_speak,
160                                    "en-us",
161                                    &voice_id,
162                                    speed,
163                                    None,
164                                    None,
165                                    None,
166                                    None,
167                                    |chunk| {
168                                        if _cancelled_ctx.load(Ordering::SeqCst) {
169                                            return Err(Box::new(std::io::Error::new(
170                                                std::io::ErrorKind::Interrupted,
171                                                "Silenced",
172                                            )));
173                                        }
174                                        if !chunk.is_empty() {
175                                            if let Ok(mut snk_opt) = sink_synth_clone.try_lock() {
176                                                if let Some(ref mut snk) = *snk_opt {
177                                                    snk.set_volume(volume);
178                                                    let source = rodio::buffer::SamplesBuffer::new(
179                                                        1, 24000, chunk,
180                                                    );
181                                                    snk.append(source);
182                                                    snk.play();
183                                                }
184                                            }
185                                        }
186                                        Ok(())
187                                    },
188                                );
189                                if let Err(e) = res {
190                                    let err_str = e.to_string();
191                                    if err_str != "Silenced" && !err_str.contains("Expand node") && !err_str.contains("invalid expand shape") {
192                                        let _ = event_tx_synth
193                                            .send(InferenceEvent::VoiceStatus(format!(
194                                                "Audio Pipeline: Synthesis Error - {}",
195                                                err_str
196                                            )))
197                                            .await;
198                                    }
199                                }
200                            }
201                            drop(engine_opt);
202                        }
203                    });
204                });
205
206                // Stage 1: Token Collector — builds tokens into sentences, then forwards to Stage 2.
207                // Runs after model load. Tokens that arrived during load are buffered in the 1024-cap channel.
208                let mut sentence_buffer = String::new();
209                let mut last_activity = std::time::Instant::now();
210
211                loop {
212                    let timeout = std::time::Duration::from_millis(150);
213                    let result = rx.recv_timeout(timeout);
214
215                    let token = match result {
216                        Ok(t) => {
217                            last_activity = std::time::Instant::now();
218                            Some(t)
219                        }
220                        Err(mpsc::RecvTimeoutError::Timeout) => {
221                            if !sentence_buffer.is_empty() && last_activity.elapsed() > timeout {
222                                None
223                            } else {
224                                continue;
225                            }
226                        }
227                        Err(mpsc::RecvTimeoutError::Disconnected) => break,
228                    };
229
230                    if let Some(ref text) = token {
231                        if !enabled_ctx.load(Ordering::Relaxed) || text == "\x03" {
232                            sentence_buffer.clear();
233                            continue;
234                        }
235                        if text == "\x04" {
236                            if !sentence_buffer.is_empty() {
237                                let to_speak = sentence_buffer.trim().to_string();
238                                sentence_buffer.clear();
239                                let _ = synth_tx.blocking_send(to_speak);
240                            }
241                            continue;
242                        }
243                        sentence_buffer.push_str(text);
244                    }
245
246                    let to_speak = sentence_buffer.trim().to_string();
247                    let has_punctuation = to_speak.ends_with('.')
248                        || to_speak.ends_with('!')
249                        || to_speak.ends_with('?')
250                        || to_speak.ends_with(':')
251                        || to_speak.ends_with('\n');
252
253                    let is_word_boundary = token
254                        .as_ref()
255                        .map(|t| t.starts_with(' ') || t.starts_with('\n') || t.starts_with('\t'))
256                        .unwrap_or(true);
257
258                    let is_done = token.is_none();
259
260                    if (!to_speak.is_empty() && has_punctuation && is_word_boundary)
261                        || (is_done && !to_speak.is_empty())
262                    {
263                        sentence_buffer.clear();
264                        let _ = synth_tx.blocking_send(to_speak);
265                    }
266                }
267                }
268            });
269
270        Self {
271            sender: tx,
272            enabled,
273            available,
274            cancelled,
275            sink: sink_manager_clone,
276            current_voice,
277            current_speed,
278            current_volume,
279        }
280    }
281
282    pub fn speak(&self, text: String) {
283        if self.enabled.load(Ordering::Relaxed) {
284            // New utterance: reset cancellation
285            self.cancelled.store(false, Ordering::SeqCst);
286            let _ = self.sender.try_send(text);
287        }
288    }
289
290    /// Forces a flush of the current sentence buffer.
291    pub fn stop(&self) {
292        self.cancelled.store(true, Ordering::SeqCst);
293        let _ = self.sender.try_send("\x03".to_string());
294        if let Ok(mut lock) = self.sink.try_lock() {
295            if let Some(sink) = lock.as_mut() {
296                sink.stop();
297                sink.pause();
298                sink.play();
299            }
300        }
301    }
302
303    pub fn flush(&self) {
304        if self.enabled.load(Ordering::Relaxed) {
305            let _ = self.sender.try_send("\x04".to_string());
306        }
307    }
308
309    pub fn toggle(&self) -> bool {
310        if !self.available.load(Ordering::Relaxed) {
311            self.enabled.store(false, Ordering::Relaxed);
312            return false;
313        }
314        let current = self.enabled.load(Ordering::Relaxed);
315        let next = !current;
316        self.enabled.store(next, Ordering::Relaxed);
317        next
318    }
319
320    pub fn is_enabled(&self) -> bool {
321        self.available.load(Ordering::Relaxed) && self.enabled.load(Ordering::Relaxed)
322    }
323
324    pub fn is_available(&self) -> bool {
325        self.available.load(Ordering::Relaxed)
326    }
327
328    /// Change the active voice. Takes effect on the next spoken sentence.
329    pub fn set_voice(&self, voice_id: &str) {
330        if let Ok(mut v) = self.current_voice.lock() {
331            *v = voice_id.to_string();
332        }
333    }
334
335    pub fn current_voice_id(&self) -> String {
336        self.current_voice
337            .lock()
338            .map(|v| v.clone())
339            .unwrap_or_else(|_| "af_sky".to_string())
340    }
341
342    pub fn set_speed(&self, speed: f32) {
343        if let Ok(mut v) = self.current_speed.lock() {
344            *v = speed.clamp(0.5, 2.0);
345        }
346    }
347
348    pub fn set_volume(&self, volume: f32) {
349        if let Ok(mut v) = self.current_volume.lock() {
350            *v = volume.clamp(0.0, 3.0);
351        }
352    }
353}
354
355/// All voices baked into voices.bin, grouped for display.
356pub const VOICE_LIST: &[(&str, &str)] = &[
357    ("af_alloy", "American Female — Alloy"),
358    ("af_aoede", "American Female — Aoede"),
359    ("af_bella", "American Female — Bella ⭐"),
360    ("af_heart", "American Female — Heart ⭐"),
361    ("af_jessica", "American Female — Jessica"),
362    ("af_kore", "American Female — Kore"),
363    ("af_nicole", "American Female — Nicole"),
364    ("af_nova", "American Female — Nova"),
365    ("af_river", "American Female — River"),
366    ("af_sarah", "American Female — Sarah"),
367    ("af_sky", "American Female — Sky (default)"),
368    ("am_adam", "American Male   — Adam"),
369    ("am_echo", "American Male   — Echo"),
370    ("am_eric", "American Male   — Eric"),
371    ("am_fenrir", "American Male   — Fenrir"),
372    ("am_liam", "American Male   — Liam"),
373    ("am_michael", "American Male   — Michael ⭐"),
374    ("am_onyx", "American Male   — Onyx"),
375    ("am_puck", "American Male   — Puck"),
376    ("bf_alice", "British Female  — Alice"),
377    ("bf_emma", "British Female  — Emma ⭐"),
378    ("bf_isabella", "British Female  — Isabella"),
379    ("bf_lily", "British Female  — Lily"),
380    ("bm_daniel", "British Male    — Daniel"),
381    ("bm_fable", "British Male    — Fable ⭐"),
382    ("bm_george", "British Male    — George ⭐"),
383    ("bm_lewis", "British Male    — Lewis"),
384    ("ef_dora", "Spanish Female  — Dora"),
385    ("em_alex", "Spanish Male    — Alex"),
386    ("ff_siwis", "French Female   — Siwis"),
387    ("hf_alpha", "Hindi Female    — Alpha"),
388    ("hf_beta", "Hindi Female    — Beta"),
389    ("hm_omega", "Hindi Male      — Omega"),
390    ("hm_psi", "Hindi Male      — Psi"),
391    ("if_sara", "Italian Female  — Sara"),
392    ("im_nicola", "Italian Male    — Nicola"),
393    ("jf_alpha", "Japanese Female — Alpha"),
394    ("jf_gongitsune", "Japanese Female — Gongitsune"),
395    ("jf_nezumi", "Japanese Female — Nezumi"),
396    ("jf_tebukuro", "Japanese Female — Tebukuro"),
397    ("jm_kumo", "Japanese Male   — Kumo"),
398    ("zf_xiaobei", "Chinese Female  — Xiaobei"),
399    ("zf_xiaoni", "Chinese Female  — Xiaoni"),
400    ("zf_xiaoxiao", "Chinese Female  — Xiaoxiao"),
401    ("zf_xiaoyi", "Chinese Female  — Xiaoyi"),
402    ("zm_yunjian", "Chinese Male    — Yunjian"),
403    ("zm_yunxi", "Chinese Male    — Yunxi"),
404    ("zm_yunxia", "Chinese Male    — Yunxia"),
405    ("zm_yunyang", "Chinese Male    — Yunyang"),
406];