// mofa_plugins/tts.rs

1//! Text-to-Speech (TTS) Plugin Module
2//!
3//! Provides TTS capabilities using a generic TTS engine interface.
4//! This module is designed to work with multiple TTS backends.
5//!
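//! Note: Kokoro TTS support lives in the `kokoro_wrapper` module and is compiled only
//! when the `kokoro` feature is enabled; without it the plugin falls back to the
//! placeholder `MockTTSEngine`. An earlier, commented-out `KokoroEngine` sketch further
//! down was disabled because of compatibility issues with the `ort` (ONNX Runtime)
//! dependency. The plugin structure works with any `TTSEngine` backend.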
9
10// Model cache and download modules
11pub mod cache;
12pub mod model_downloader;
13
14// Kokoro TTS wrapper (available with kokoro feature)
15#[cfg(feature = "kokoro")]
16pub mod kokoro_wrapper;
17
18use crate::{AgentPlugin, PluginContext, PluginMetadata, PluginResult, PluginState, PluginType};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tracing::{debug, info, warn};
23
24// ============================================================================
25// Voice Information
26// ============================================================================
27
28/// Voice metadata information
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct VoiceInfo {
31    /// Voice identifier
32    pub id: String,
33    /// Human-readable voice name
34    pub name: String,
35    /// Language code (e.g., "en-US", "zh-CN")
36    pub language: String,
37}
38
39impl VoiceInfo {
40    pub fn new(id: &str, name: &str, language: &str) -> Self {
41        Self {
42            id: id.to_string(),
43            name: name.to_string(),
44            language: language.to_string(),
45        }
46    }
47}
48
49// ============================================================================
50// TTS Plugin Configuration
51// ============================================================================
52
53/// TTS plugin configuration
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct TTSPluginConfig {
56    /// Default voice to use for synthesis
57    pub default_voice: String,
58    /// Model version to use ("v1.0" or "v1.1" for Kokoro")
59    pub model_version: String,
60    /// Streaming chunk size in bytes
61    pub streaming_chunk_size: usize,
62    /// Hugging Face model URL (e.g., "hexgrad/Kokoro-82M")
63    #[serde(default = "default_model_url")]
64    pub model_url: String,
65    /// Custom cache directory path (defaults to ~/.mofa/models/tts/)
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub cache_dir: Option<String>,
68    /// Enable automatic model download if not found in cache
69    #[serde(default = "default_auto_download")]
70    pub auto_download: bool,
71    /// Expected model checksum for validation (MD5 hex string)
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub model_checksum: Option<String>,
74    /// Download timeout in seconds
75    #[serde(default = "default_download_timeout")]
76    pub download_timeout: u64,
77}
78
/// Serde default for `TTSPluginConfig::model_url`: the Kokoro-82M repository on Hugging Face.
fn default_model_url() -> String {
    String::from("hexgrad/Kokoro-82M")
}
82
/// Serde default for `TTSPluginConfig::auto_download`: download missing models automatically.
fn default_auto_download() -> bool {
    const AUTO_DOWNLOAD_BY_DEFAULT: bool = true;
    AUTO_DOWNLOAD_BY_DEFAULT
}
86
/// Serde default for `TTSPluginConfig::download_timeout`, in seconds (ten minutes).
fn default_download_timeout() -> u64 {
    let minutes: u64 = 10;
    minutes * 60
}
90
91impl Default for TTSPluginConfig {
92    fn default() -> Self {
93        Self {
94            default_voice: "default".to_string(),
95            model_version: "v1.1".to_string(),
96            streaming_chunk_size: 4096,
97            model_url: default_model_url(),
98            cache_dir: None,
99            auto_download: true,
100            model_checksum: None,
101            download_timeout: 600,
102        }
103    }
104}
105
106// ============================================================================
107// TTS Engine Trait
108// ============================================================================
109
110/// Abstract TTS engine trait for extensibility
111#[async_trait::async_trait]
112pub trait TTSEngine: Send + Sync {
113    /// Synthesize text to audio data
114    async fn synthesize(&self, text: &str, voice: &str) -> PluginResult<Vec<u8>>;
115
116    /// Synthesize with streaming callback for long texts
117    async fn synthesize_stream(
118        &self,
119        text: &str,
120        voice: &str,
121        callback: Box<dyn Fn(Vec<u8>) + Send + Sync>,
122    ) -> PluginResult<()>;
123
124    /// List available voices
125    async fn list_voices(&self) -> PluginResult<Vec<VoiceInfo>>;
126
127    /// Get engine name
128    fn name(&self) -> &str;
129
130    /// Get as Any for downcasting to engine-specific types
131    ///
132    /// This allows accessing engine-specific methods like stream_receiver()
133    /// on KokoroTTS after downcasting.
134    fn as_any(&self) -> &dyn std::any::Any;
135}
136
137// ============================================================================
138// Mock TTS Engine (Placeholder)
139// ============================================================================
140
141/// A mock TTS engine for testing and development.
142///
143/// This engine generates placeholder WAV audio data. It's used when
144/// a real TTS engine is not available or for testing purposes.
145pub struct MockTTSEngine {
146    config: TTSPluginConfig,
147    voices: Vec<VoiceInfo>,
148}
149
150impl MockTTSEngine {
151    /// Create a new mock TTS engine
152    pub fn new(config: TTSPluginConfig) -> Self {
153        let voices = vec![
154            VoiceInfo::new("default", "Default Voice", "en-US"),
155            VoiceInfo::new("af_heart", "Heart (Female)", "en-US"),
156            VoiceInfo::new("am_michael", "Michael (Male)", "en-US"),
157            VoiceInfo::new("bf_emma", "Emma (Female)", "en-US"),
158            VoiceInfo::new("bm_george", "George (Male)", "en-US"),
159            VoiceInfo::new("zh_female", "Chinese Female", "zh-CN"),
160        ];
161
162        Self { config, voices }
163    }
164}
165
166#[async_trait::async_trait]
167impl TTSEngine for MockTTSEngine {
168    async fn synthesize(&self, text: &str, voice: &str) -> PluginResult<Vec<u8>> {
169        debug!(
170            "[MockTTS] Synthesizing text with voice '{}': {}",
171            voice, text
172        );
173
174        // Generate a placeholder WAV file
175        let sample_rate = 24000u32;
176        let duration_sec = (text.len() as f32 / 15.0).ceil() as u32; // Rough estimate
177        let num_samples = sample_rate * duration_sec;
178        let data_size = num_samples * 2; // 16-bit samples
179
180        // Generate WAV file header
181        let mut wav_data = Vec::new();
182
183        // RIFF header
184        wav_data.extend_from_slice(b"RIFF");
185        wav_data.extend_from_slice(&(36 + data_size).to_le_bytes());
186        wav_data.extend_from_slice(b"WAVE");
187
188        // fmt chunk
189        wav_data.extend_from_slice(b"fmt ");
190        wav_data.extend_from_slice(&16u32.to_le_bytes()); // chunk size
191        wav_data.extend_from_slice(&1u16.to_le_bytes()); // audio format (PCM)
192        wav_data.extend_from_slice(&1u16.to_le_bytes()); // num channels (mono)
193        wav_data.extend_from_slice(&sample_rate.to_le_bytes());
194        wav_data.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate
195        wav_data.extend_from_slice(&2u16.to_le_bytes()); // block align
196        wav_data.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
197
198        // data chunk
199        wav_data.extend_from_slice(b"data");
200        wav_data.extend_from_slice(&data_size.to_le_bytes());
201
202        // Add silence (zeros) for audio data
203        wav_data.resize(wav_data.len() + data_size as usize, 0);
204
205        Ok(wav_data)
206    }
207
208    async fn synthesize_stream(
209        &self,
210        text: &str,
211        voice: &str,
212        callback: Box<dyn Fn(Vec<u8>) + Send + Sync>,
213    ) -> PluginResult<()> {
214        debug!(
215            "[MockTTS] Stream synthesizing text with voice '{}': {}",
216            voice, text
217        );
218
219        // Split text into chunks for streaming
220        let chunk_size = self.config.streaming_chunk_size;
221        let chunks: Vec<&str> = text
222            .as_bytes()
223            .chunks(chunk_size)
224            .map(|c| std::str::from_utf8(c).unwrap_or(""))
225            .collect();
226
227        for chunk in chunks {
228            if chunk.is_empty() {
229                continue;
230            }
231            let audio = self.synthesize(chunk, voice).await?;
232            callback(audio);
233            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
234        }
235
236        Ok(())
237    }
238
239    async fn list_voices(&self) -> PluginResult<Vec<VoiceInfo>> {
240        Ok(self.voices.clone())
241    }
242
243    fn name(&self) -> &str {
244        "MockTTS"
245    }
246
247    fn as_any(&self) -> &dyn std::any::Any {
248        self
249    }
250}
251
252// ============================================================================
253// Kokoro TTS Engine (Prepared but Disabled)
254// ============================================================================
255//
// The following code structure was prepared for an in-file Kokoro TTS integration.
// It is disabled due to compatibility issues with the ort (ONNX Runtime)
// dependency version used by kokoro-tts.
//
// Kokoro support is now provided by the `kokoro_wrapper` module (behind the `kokoro`
// feature), which `TTSPlugin::init_plugin` already uses. To revive this sketch instead:
// 1. Uncomment the kokoro-tts dependency in Cargo.toml
// 2. Uncomment the KokoroEngine implementation below
// 3. Update TTSPlugin::init_plugin to use KokoroEngine
264
265/*
266use std::io::Cursor;
267
268/// Kokoro TTS engine implementation
269///
270/// This is a placeholder showing how the Kokoro TTS engine would be integrated.
271/// The actual implementation would depend on the kokoro-tts crate API.
272pub struct KokoroEngine {
273    config: TTSPluginConfig,
274    voices: Vec<VoiceInfo>,
275    // kokoro: kokoro_tts::Kokoro,  // Would be initialized when dependencies are fixed
276}
277
278impl KokoroEngine {
279    /// Create a new Kokoro TTS engine
280    pub async fn new(config: TTSPluginConfig) -> PluginResult<Self> {
281        info!(
282            "Initializing Kokoro TTS engine with model version {}",
283            config.model_version
284        );
285
286        // TODO: Initialize actual Kokoro TTS instance
287        // let kokoro = kokoro_tts::Kokoro::new()
288        //     .await
289        //     .map_err(|e| anyhow::anyhow!("Failed to initialize Kokoro: {}", e))?;
290
291        let voices = vec![
292            VoiceInfo::new("default", "Default Voice", "en-US"),
293            VoiceInfo::new("af_heart", "Heart (Female)", "en-US"),
294            VoiceInfo::new("am_michael", "Michael (Male)", "en-US"),
295            VoiceInfo::new("bf_emma", "Emma (Female)", "en-US"),
296            VoiceInfo::new("bm_george", "George (Male)", "en-US"),
297            // Add more Kokoro voices as available
298        ];
299
300        Ok(Self {
301            config,
302            voices,
303            // kokoro,
304        })
305    }
306
307    /// Get the default voice
308    pub fn default_voice(&self) -> &str {
309        &self.config.default_voice
310    }
311}
312
313#[async_trait::async_trait]
314impl TTSEngine for KokoroEngine {
315    async fn synthesize(&self, text: &str, voice: &str) -> PluginResult<Vec<u8>> {
316        debug!("Synthesizing text with voice '{}': {}", voice, text);
317
318        // TODO: Replace with actual Kokoro TTS synthesis
319        // let audio = self.kokoro
320        //     .synthesize(text, voice)
321        //     .await
322        //     .map_err(|e| anyhow::anyhow!("Synthesis failed: {}", e))?;
323
324        // For now, return placeholder
325        Ok(vec![])
326    }
327
328    async fn synthesize_stream(
329        &self,
330        text: &str,
331        voice: &str,
332        callback: Box<dyn Fn(Vec<u8>) + Send + Sync>,
333    ) -> PluginResult<()> {
334        debug!("Stream synthesizing text with voice '{}': {}", voice, text);
335
336        // TODO: Replace with actual Kokoro streaming synthesis
337        // let mut stream = self.kokoro
338        //     .synthesize_stream(text, voice)
339        //     .await
340        //     .map_err(|e| anyhow::anyhow!("Stream synthesis failed: {}", e))?;
341
342        // while let Some(chunk) = stream.next().await {
343        //     callback(chunk?);
344        // }
345
346        Ok(())
347    }
348
349    async fn list_voices(&self) -> PluginResult<Vec<VoiceInfo>> {
350        Ok(self.voices.clone())
351    }
352
353    fn name(&self) -> &str {
354        "Kokoro"
355    }
356}
357*/
358
359// ============================================================================
360// Audio Playback Helper (Optional - Requires rodio feature)
361// ============================================================================
362
363#[cfg(feature = "rodio")]
364use rodio::{Decoder, OutputStream, Sink};
365#[cfg(feature = "rodio")]
366use std::io::Cursor;
367
/// Settings for local audio playback.
#[derive(Debug, Clone)]
pub struct AudioPlaybackConfig {
    /// Whether playback is turned on.
    pub enabled: bool,
    /// Output volume, from 0.0 (silent) to 1.0 (full).
    pub volume: f32,
}

impl Default for AudioPlaybackConfig {
    /// Playback enabled at 80% volume.
    fn default() -> Self {
        let volume = 0.8;
        Self {
            volume,
            enabled: true,
        }
    }
}
385
/// Decode `audio_data` and play it on the default output device, blocking until
/// playback finishes (requires the `rodio` feature).
///
/// # Arguments
///
/// * `audio_data` - encoded audio bytes (WAV)
///
/// # Errors
///
/// Returns an error if the default output device cannot be opened, a playback sink
/// cannot be created, or `audio_data` cannot be decoded.
///
/// # Platform Support
///
/// - **macOS**: Works out of the box (Core Audio)
/// - **Linux**: Requires `libasound2-dev` (ALSA)
/// - **Windows**: Works out of the box (WASAPI)
#[cfg(feature = "rodio")]
pub fn play_audio(audio_data: Vec<u8>) -> PluginResult<()> {
    info!("Playing {} bytes of audio using rodio", audio_data.len());

    // `_output` must stay bound (not `_`) so the output stream outlives playback.
    let (_output, handle) = OutputStream::try_default()
        .map_err(|err| anyhow::anyhow!("Failed to get audio output: {}", err))?;
    let sink =
        Sink::try_new(&handle).map_err(|err| anyhow::anyhow!("Failed to create sink: {}", err))?;
    let decoded = Decoder::new(Cursor::new(audio_data))
        .map_err(|err| anyhow::anyhow!("Failed to decode audio: {}", err))?;

    sink.append(decoded);
    sink.sleep_until_end();
    Ok(())
}
421
422/// Play audio data synchronously (fallback when rodio feature is not enabled)
423///
424/// When rodio is not enabled, this simulates playback with a delay.
425/// Enable the rodio feature for actual audio playback.
426#[cfg(not(feature = "rodio"))]
427pub fn play_audio(audio_data: Vec<u8>) -> PluginResult<()> {
428    debug!(
429        "Playing {} bytes of audio (placeholder - rodio not enabled)",
430        audio_data.len()
431    );
432
433    warn!(
434        "Audio playback is simulated. Enable the 'rodio' feature in Cargo.toml \
435         for actual audio playback support."
436    );
437
438    // Simulate playback delay based on audio size
439    let delay_ms = std::cmp::min(500, audio_data.len() as u64 / 100);
440    std::thread::sleep(std::time::Duration::from_millis(delay_ms));
441
442    Ok(())
443}
444
445/// Play audio asynchronously
446pub async fn play_audio_async(audio_data: Vec<u8>) -> PluginResult<()> {
447    tokio::task::spawn_blocking(move || play_audio(audio_data))
448        .await
449        .map_err(|e| anyhow::anyhow!("Playback task failed: {}", e))?
450}
451
452// ============================================================================
453// TTS Plugin
454// ============================================================================
455
456/// TTS Plugin implementing AgentPlugin
457pub struct TTSPlugin {
458    metadata: PluginMetadata,
459    state: PluginState,
460    config: TTSPluginConfig,
461    engine: Option<Arc<dyn TTSEngine>>,
462    synthesis_count: u64,
463    total_chars_synthesized: u64,
464    last_audio_data: Vec<u8>,
465    /// Model cache manager
466    model_cache: Option<cache::ModelCache>,
467    /// Hugging Face download client
468    hf_client: Option<model_downloader::HFHubClient>,
469}
470
471impl TTSPlugin {
472    /// Create a new TTS plugin
473    pub fn new(plugin_id: &str) -> Self {
474        let metadata = PluginMetadata::new(plugin_id, "TTS Plugin", PluginType::Tool)
475            .with_description("Text-to-Speech plugin with support for multiple TTS engines")
476            .with_capability("text_to_speech")
477            .with_capability("audio_synthesis")
478            .with_capability("streaming_synthesis")
479            .with_capability("model_download");
480
481        Self {
482            metadata,
483            state: PluginState::Unloaded,
484            config: TTSPluginConfig::default(),
485            engine: None,
486            synthesis_count: 0,
487            total_chars_synthesized: 0,
488            last_audio_data: Vec::new(),
489            model_cache: None,
490            hf_client: None,
491        }
492    }
493
494    /// Create a TTS plugin with engine and voice (便捷方法)
495    ///
496    /// # 参数
497    ///
498    /// - `plugin_id`: 插件ID
499    /// - `engine`: TTS 引擎
500    /// - `default_voice`: 默认音色,如 `"zf_090"`,默认为 `"default"`
501    ///
502    /// # 示例
503    ///
504    /// ```rust,ignore
505    /// use mofa_plugins::TTSPlugin;
506    ///
507    /// // 使用默认音色
508    /// let plugin = TTSPlugin::with_engine("tts", kokoro_engine, None);
509    ///
510    /// // 指定音色
511    /// let plugin = TTSPlugin::with_engine("tts", kokoro_engine, Some("zf_090"));
512    /// ```
513    pub fn with_engine<E: TTSEngine + 'static>(
514        plugin_id: &str,
515        engine: E,
516        default_voice: Option<&str>,
517    ) -> Self {
518        let mut plugin = Self::new(plugin_id);
519        plugin.engine = Some(Arc::new(engine));
520        if let Some(voice) = default_voice {
521            plugin.config.default_voice = voice.to_string();
522        }
523        plugin
524    }
525
526    /// Set the plugin configuration
527    pub fn with_config(mut self, config: TTSPluginConfig) -> Self {
528        self.config = config;
529        self
530    }
531
532    /// Set a custom TTS engine (链式调用版本,用于已有实例)
533    pub fn with_engine_ref<E: TTSEngine + 'static>(mut self, engine: E) -> Self {
534        self.engine = Some(Arc::new(engine));
535        self
536    }
537
538    /// Set the default voice
539    pub fn with_voice(mut self, voice: &str) -> Self {
540        self.config.default_voice = voice.to_string();
541        self
542    }
543
544    /// Get the TTS engine
545    pub fn engine(&self) -> Option<&Arc<dyn TTSEngine>> {
546        self.engine.as_ref()
547    }
548
549    /// Synthesize text to audio and play it
550    pub async fn synthesize_and_play(&mut self, text: &str) -> PluginResult<()> {
551        let engine = self
552            .engine
553            .as_ref()
554            .ok_or_else(|| anyhow::anyhow!("TTS engine not initialized"))?;
555
556        self.synthesis_count += 1;
557        self.total_chars_synthesized += text.len() as u64;
558
559        let voice = self.config.default_voice.as_str();
560        let audio = engine.synthesize(text, voice).await?;
561        play_audio_async(audio).await?;
562        Ok(())
563    }
564
565    /// Synthesize text to audio data (no playback)
566    pub async fn synthesize_to_audio(&mut self, text: &str) -> PluginResult<Vec<u8>> {
567        let engine = self
568            .engine
569            .as_ref()
570            .ok_or_else(|| anyhow::anyhow!("TTS engine not initialized"))?;
571
572        self.synthesis_count += 1;
573        self.total_chars_synthesized += text.len() as u64;
574
575        let voice = self.config.default_voice.as_str();
576        engine.synthesize(text, voice).await
577    }
578
579    /// Stream synthesize text with callback
580    pub async fn synthesize_streaming(
581        &mut self,
582        text: &str,
583        callback: Box<dyn Fn(Vec<u8>) + Send + Sync>,
584    ) -> PluginResult<()> {
585        let engine = self
586            .engine
587            .as_ref()
588            .ok_or_else(|| anyhow::anyhow!("TTS engine not initialized"))?;
589
590        self.synthesis_count += 1;
591        self.total_chars_synthesized += text.len() as u64;
592
593        let voice = self.config.default_voice.as_str();
594        engine.synthesize_stream(text, voice, callback).await
595    }
596
597    /// Stream synthesize text with f32 callback (native format)
598    ///
599    /// This method is more efficient for KokoroTTS as it uses the native f32 format
600    /// without the overhead of f32 -> i16 -> u8 conversion.
601    ///
602    /// # Arguments
603    /// - `text`: The text to synthesize
604    /// - `callback`: Function to call with each audio chunk (Vec<f32>)
605    ///
606    /// # Example
607    /// ```rust,ignore
608    /// plugin.synthesize_streaming_f32("Hello", Box::new(|audio_f32| {
609    ///     // audio_f32 is Vec<f32> with values in [-1.0, 1.0]
610    ///     sink.append(SamplesBuffer::new(1, 24000, audio_f32));
611    /// })).await?;
612    /// ```
613    pub async fn synthesize_streaming_f32(
614        &mut self,
615        text: &str,
616        callback: Box<dyn Fn(Vec<f32>) + Send + Sync>,
617    ) -> PluginResult<()> {
618        let engine = self
619            .engine
620            .as_ref()
621            .ok_or_else(|| anyhow::anyhow!("TTS engine not initialized"))?;
622
623        self.synthesis_count += 1;
624        self.total_chars_synthesized += text.len() as u64;
625
626        let voice = self.config.default_voice.as_str();
627
628        #[cfg(feature = "kokoro")]
629        {
630            // Try to downcast to KokoroTTS for native f32 streaming
631            if let Some(kokoro) = engine.as_any().downcast_ref::<kokoro_wrapper::KokoroTTS>() {
632                return kokoro
633                    .synthesize_stream_f32(text, voice, callback)
634                    .await
635                    .map_err(|e| anyhow::anyhow!("F32 streaming failed: {}", e));
636            }
637        }
638
639        // Fallback: synthesize to bytes first, then convert to f32
640        // We can't directly pass callback to synthesize_stream because of ownership issues
641        let audio_bytes = engine.synthesize(text, voice).await?;
642
643        // Convert bytes to f32 and call callback
644        let audio_i16: Vec<i16> = audio_bytes
645            .chunks_exact(2)
646            .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
647            .collect();
648        let audio_f32: Vec<f32> = audio_i16
649            .iter()
650            .map(|&s| s as f32 / i16::MAX as f32)
651            .collect();
652        callback(audio_f32);
653
654        Ok(())
655    }
656
657    /// List available voices
658    pub async fn list_voices(&self) -> PluginResult<Vec<VoiceInfo>> {
659        let engine = self
660            .engine
661            .as_ref()
662            .ok_or_else(|| anyhow::anyhow!("TTS engine not initialized"))?;
663        engine.list_voices().await
664    }
665
666    /// Set the default voice
667    pub fn set_default_voice(&mut self, voice: &str) {
668        self.config.default_voice = voice.to_string();
669    }
670
671    /// Get synthesis statistics
672    pub fn stats(&self) -> HashMap<String, serde_json::Value> {
673        let mut stats = HashMap::new();
674        stats.insert(
675            "synthesis_count".to_string(),
676            serde_json::json!(self.synthesis_count),
677        );
678        stats.insert(
679            "total_chars".to_string(),
680            serde_json::json!(self.total_chars_synthesized),
681        );
682        stats.insert(
683            "default_voice".to_string(),
684            serde_json::json!(self.config.default_voice),
685        );
686        stats.insert(
687            "model_version".to_string(),
688            serde_json::json!(self.config.model_version),
689        );
690        stats.insert(
691            "model_url".to_string(),
692            serde_json::json!(self.config.model_url),
693        );
694        stats.insert(
695            "auto_download".to_string(),
696            serde_json::json!(self.config.auto_download),
697        );
698        // Add cache info if available
699        if let Some(cache) = &self.model_cache
700            && let Some(cache_dir) = cache.cache_dir().to_str()
701        {
702            stats.insert("cache_dir".to_string(), serde_json::json!(cache_dir));
703        }
704        if let Some(engine) = &self.engine {
705            stats.insert("engine".to_string(), serde_json::json!(engine.name()));
706        }
707        stats
708    }
709
710    /// Get the last synthesized audio data
711    pub fn last_audio(&self) -> Vec<u8> {
712        self.last_audio_data.clone()
713    }
714}
715
716#[async_trait::async_trait]
717impl AgentPlugin for TTSPlugin {
718    fn metadata(&self) -> &PluginMetadata {
719        &self.metadata
720    }
721
722    fn state(&self) -> PluginState {
723        self.state.clone()
724    }
725
726    async fn load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
727        self.state = PluginState::Loading;
728        info!("Loading TTS plugin: {}", self.metadata.id);
729
730        // Load configuration from context
731        if let Some(default_voice) = ctx.config.get_string("default_voice") {
732            self.config.default_voice = default_voice;
733        }
734        if let Some(model_version) = ctx.config.get_string("model_version") {
735            self.config.model_version = model_version;
736        }
737        if let Some(model_url) = ctx.config.get_string("model_url") {
738            self.config.model_url = model_url;
739        }
740        if let Some(cache_dir) = ctx.config.get_string("cache_dir") {
741            self.config.cache_dir = Some(cache_dir);
742        }
743        if let Some(auto_download) = ctx.config.get_bool("auto_download") {
744            self.config.auto_download = auto_download;
745        }
746        if let Some(checksum) = ctx.config.get_string("model_checksum") {
747            self.config.model_checksum = Some(checksum);
748        }
749
750        // Initialize model cache
751        let cache_dir = self.config.cache_dir.as_ref().map(std::path::PathBuf::from);
752        self.model_cache = Some(
753            cache::ModelCache::new(cache_dir)
754                .map_err(|e| anyhow::anyhow!("Failed to initialize model cache: {}", e))?,
755        );
756
757        // Initialize Hugging Face client
758        self.hf_client = Some(model_downloader::HFHubClient::new());
759
760        self.state = PluginState::Loaded;
761        Ok(())
762    }
763
764    async fn init_plugin(&mut self) -> PluginResult<()> {
765        info!("Initializing TTS plugin: {}", self.metadata.id);
766
767        // Check if we need to download the model
768        let cache = self
769            .model_cache
770            .as_ref()
771            .ok_or_else(|| anyhow::anyhow!("Model cache not initialized"))?;
772
773        let hf_client = self
774            .hf_client
775            .as_ref()
776            .ok_or_else(|| anyhow::anyhow!("HF client not initialized"))?;
777
778        // Check if model exists in cache
779        let model_exists = cache.exists(&self.config.model_url).await;
780
781        if !model_exists && self.config.auto_download {
782            info!(
783                "Model not found in cache, initiating download: {}",
784                self.config.model_url
785            );
786
787            let download_config = model_downloader::DownloadConfig {
788                model_id: self.config.model_url.clone(),
789                filename: "kokoro-v0_19.onnx".to_string(),
790                checksum: self.config.model_checksum.clone(),
791                timeout_secs: self.config.download_timeout,
792                max_retries: 3,
793                progress_callback: Some(Box::new(|downloaded, total| {
794                    let progress = if total > 0 {
795                        format!("{:.1}%", (downloaded as f64 / total as f64) * 100.0)
796                    } else {
797                        format!("{} bytes", downloaded)
798                    };
799                    info!("Download progress: {}", progress);
800                })),
801            };
802
803            hf_client
804                .download_model(download_config, cache)
805                .await
806                .map_err(|e| anyhow::anyhow!("Failed to download model: {}", e))?;
807        } else if !model_exists {
808            // Auto-download disabled, fail with clear error
809            return Err(anyhow::anyhow!(
810                "Model '{}' not found in cache and auto_download is disabled. \
811                Please enable auto_download or manually download the model to: {:?}",
812                self.config.model_url,
813                cache.model_path(&self.config.model_url)
814            ));
815        }
816
817        // Validate cached model
818        if let Some(expected_checksum) = &self.config.model_checksum
819            && !cache
820                .validate(&self.config.model_url, Some(expected_checksum))
821                .await
822                .map_err(|e| anyhow::anyhow!("Failed to validate model: {}", e))?
823        {
824            return Err(anyhow::anyhow!(
825                "Model validation failed. The cached model may be corrupted. \
826                    Try deleting the cache and re-downloading."
827            ));
828        }
829
830        // Initialize engine with downloaded/cached model
831        if self.engine.is_none() {
832            #[cfg(feature = "kokoro")]
833            {
834                // Try to initialize Kokoro engine
835                let model_path = cache.model_path(&self.config.model_url);
836                let voice_path_buf = model_path
837                    .parent()
838                    .map(|p| p.join("voices-v1.1-zh.bin"))
839                    .unwrap_or_else(|| std::path::PathBuf::from("voices-v1.1-zh.bin"));
840                let voice_path = voice_path_buf.to_str().unwrap_or("voices-v1.1-zh.bin");
841
842                let model_path_str = model_path
843                    .to_str()
844                    .ok_or_else(|| anyhow::anyhow!("Invalid model path"))?;
845
846                info!(
847                    "Initializing Kokoro TTS engine with model: {}, voices: {}",
848                    model_path_str, voice_path
849                );
850
851                match kokoro_wrapper::KokoroTTS::new(model_path_str, voice_path).await {
852                    Ok(engine) => {
853                        self.engine = Some(Arc::new(engine));
854                        info!("Kokoro TTS engine initialized successfully");
855                    }
856                    Err(e) => {
857                        warn!(
858                            "Failed to initialize Kokoro engine: {}, falling back to mock engine",
859                            e
860                        );
861                        let engine = MockTTSEngine::new(self.config.clone());
862                        self.engine = Some(Arc::new(engine));
863                    }
864                }
865            }
866
867            #[cfg(not(feature = "kokoro"))]
868            {
869                // Fallback to mock engine when feature not enabled
870                warn!("Kokoro feature not enabled, using mock engine");
871                let engine = MockTTSEngine::new(self.config.clone());
872                self.engine = Some(Arc::new(engine));
873            }
874        }
875
876        Ok(())
877    }
878
879    async fn start(&mut self) -> PluginResult<()> {
880        self.state = PluginState::Running;
881        info!("TTS plugin {} started", self.metadata.id);
882        Ok(())
883    }
884
885    async fn stop(&mut self) -> PluginResult<()> {
886        self.state = PluginState::Paused;
887        info!("TTS plugin {} stopped", self.metadata.id);
888        Ok(())
889    }
890
891    async fn unload(&mut self) -> PluginResult<()> {
892        self.engine = None;
893        self.state = PluginState::Unloaded;
894        info!("TTS plugin {} unloaded", self.metadata.id);
895        Ok(())
896    }
897
898    async fn execute(&mut self, input: String) -> PluginResult<String> {
899        // Parse input as JSON command
900        let command: TTSCommand = serde_json::from_str(&input)
901            .map_err(|e| anyhow::anyhow!("Invalid TTS command format: {}", e))?;
902
903        match command.action.as_str() {
904            "speak" | "synthesize" => {
905                let text = command
906                    .text
907                    .ok_or_else(|| anyhow::anyhow!("Missing 'text' parameter"))?;
908
909                if command.play.unwrap_or(true) {
910                    self.synthesize_and_play(&text).await?;
911                    Ok(format!("Played: {}", text))
912                } else {
913                    let audio = self.synthesize_to_audio(&text).await?;
914                    // Store the audio data for later retrieval
915                    self.last_audio_data = audio.clone();
916                    Ok(format!("Generated {} bytes of audio", audio.len()))
917                }
918            }
919            "list_voices" => {
920                let voices = self.list_voices().await?;
921                let json = serde_json::to_string(&voices)?;
922                Ok(json)
923            }
924            "set_voice" => {
925                let voice = command
926                    .voice
927                    .ok_or_else(|| anyhow::anyhow!("Missing 'voice' parameter"))?;
928                self.set_default_voice(&voice);
929                Ok(format!("Default voice set to: {}", voice))
930            }
931            "stats" => {
932                let stats = self.stats();
933                let json = serde_json::to_string(&stats)?;
934                Ok(json)
935            }
936            _ => Err(anyhow::anyhow!("Unknown action: {}", command.action)),
937        }
938    }
939
940    fn stats(&self) -> HashMap<String, serde_json::Value> {
941        self.stats()
942    }
943
944    fn as_any(&self) -> &dyn std::any::Any {
945        self
946    }
947
948    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
949        self
950    }
951
952    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any> {
953        self
954    }
955}
956
957// ============================================================================
958// TTS Command Types
959// ============================================================================
960
961/// TTS command structure for execute()
962#[derive(Debug, Clone, Serialize, Deserialize)]
963pub struct TTSCommand {
964    /// Action to perform: "speak", "list_voices", "set_voice", "stats"
965    pub action: String,
966    /// Text to synthesize (for "speak" action)
967    pub text: Option<String>,
968    /// Voice to use (for "speak" or "set_voice" action)
969    pub voice: Option<String>,
970    /// Whether to play audio (for "speak" action)
971    pub play: Option<bool>,
972}
973
974// ============================================================================
975// Tool Executor for TTS
976// ============================================================================
977
978use crate::ToolDefinition;
979use crate::ToolExecutor;
980
981/// Text-to-Speech tool executor
982pub struct TextToSpeechTool {
983    plugin_id: String,
984    definition: ToolDefinition,
985}
986
987impl TextToSpeechTool {
988    pub fn new(plugin_id: &str) -> Self {
989        Self {
990            plugin_id: plugin_id.to_string(),
991            definition: ToolDefinition {
992                name: "text_to_speech".to_string(),
993                description: "Convert text to speech using the TTS plugin engine".to_string(),
994                parameters: serde_json::json!({
995                    "type": "object",
996                    "properties": {
997                        "text": {
998                            "type": "string",
999                            "description": "The text to synthesize to speech"
1000                        },
1001                        "voice": {
1002                            "type": "string",
1003                            "description": "Voice ID to use (optional, uses default if not specified)",
1004                            "default": "default"
1005                        },
1006                        "play": {
1007                            "type": "boolean",
1008                            "description": "Whether to play the audio (true) or return audio data (false)",
1009                            "default": true
1010                        }
1011                    },
1012                    "required": ["text"]
1013                }),
1014                requires_confirmation: false,
1015            },
1016        }
1017    }
1018}
1019
1020#[async_trait::async_trait]
1021impl ToolExecutor for TextToSpeechTool {
1022    fn definition(&self) -> &ToolDefinition {
1023        &self.definition
1024    }
1025
1026    async fn execute(&self, arguments: serde_json::Value) -> PluginResult<serde_json::Value> {
1027        let text = arguments
1028            .get("text")
1029            .and_then(|v| v.as_str())
1030            .ok_or_else(|| anyhow::anyhow!("Missing 'text' parameter"))?;
1031
1032        let voice = arguments.get("voice").and_then(|v| v.as_str());
1033        let play = arguments
1034            .get("play")
1035            .and_then(|v| v.as_bool())
1036            .unwrap_or(true);
1037
1038        let command = if let Some(voice) = voice {
1039            TTSCommand {
1040                action: "speak".to_string(),
1041                text: Some(text.to_string()),
1042                voice: Some(voice.to_string()),
1043                play: Some(play),
1044            }
1045        } else {
1046            TTSCommand {
1047                action: "speak".to_string(),
1048                text: Some(text.to_string()),
1049                voice: None,
1050                play: Some(play),
1051            }
1052        };
1053
1054        let input = serde_json::to_string(&command)?;
1055        Ok(serde_json::json!({
1056            "success": true,
1057            "message": format!("TTS command prepared for: {}", text),
1058            "command": input
1059        }))
1060    }
1061
1062    fn validate(&self, arguments: &serde_json::Value) -> PluginResult<()> {
1063        if !arguments.is_object() {
1064            return Err(anyhow::anyhow!("Arguments must be an object"));
1065        }
1066        if arguments.get("text").and_then(|v| v.as_str()).is_none() {
1067            return Err(anyhow::anyhow!("Missing required parameter: text"));
1068        }
1069        Ok(())
1070    }
1071}
1072
1073// ============================================================================
1074// Tests
1075// ============================================================================
1076
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a loaded and started plugin backed by the mock engine.
    ///
    /// Using the mock engine keeps tests from triggering a model download.
    async fn started_plugin_with_mock_engine() -> TTSPlugin {
        let mut plugin = TTSPlugin::new("test_tts");
        let ctx = PluginContext::new("test_agent");
        plugin.load(&ctx).await.unwrap();
        plugin.engine = Some(Arc::new(MockTTSEngine::new(TTSPluginConfig::default())));
        plugin.start().await.unwrap();
        plugin
    }

    /// Serializes a command with the given fields into the JSON `execute()` expects.
    fn command_json(
        action: &str,
        text: Option<&str>,
        voice: Option<&str>,
        play: Option<bool>,
    ) -> String {
        let command = TTSCommand {
            action: action.to_string(),
            text: text.map(str::to_string),
            voice: voice.map(str::to_string),
            play,
        };
        serde_json::to_string(&command).unwrap()
    }

    #[tokio::test]
    async fn test_mock_tts_engine_creation() {
        let engine = MockTTSEngine::new(TTSPluginConfig::default());
        assert_eq!(engine.name(), "MockTTS");
    }

    #[tokio::test]
    async fn test_mock_tts_list_voices() {
        let engine = MockTTSEngine::new(TTSPluginConfig::default());
        let voices = engine.list_voices().await.unwrap();

        assert!(!voices.is_empty());
        assert!(voices.iter().any(|voice| voice.id == "default"));
    }

    #[tokio::test]
    async fn test_tts_plugin_creation() {
        let plugin = TTSPlugin::new("test_tts");
        assert_eq!(plugin.plugin_id(), "test_tts");
        assert_eq!(plugin.state(), PluginState::Unloaded);
    }

    #[tokio::test]
    async fn test_tts_plugin_lifecycle() {
        let mut plugin = TTSPlugin::new("test_tts");
        let ctx = PluginContext::new("test_agent");

        plugin.load(&ctx).await.unwrap();
        assert_eq!(plugin.state(), PluginState::Loaded);

        // Swap in the mock engine so no model download happens.
        plugin.engine = Some(Arc::new(MockTTSEngine::new(TTSPluginConfig::default())));

        plugin.start().await.unwrap();
        assert_eq!(plugin.state(), PluginState::Running);

        plugin.stop().await.unwrap();
        assert_eq!(plugin.state(), PluginState::Paused);

        plugin.unload().await.unwrap();
        assert_eq!(plugin.state(), PluginState::Unloaded);
    }

    #[tokio::test]
    async fn test_tts_execute_speak_command() {
        let mut plugin = started_plugin_with_mock_engine().await;

        // `play: false` avoids real audio playback during tests.
        let input = command_json("speak", Some("Hello, world!"), None, Some(false));
        let outcome = plugin.execute(input).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tts_execute_list_voices() {
        let mut plugin = started_plugin_with_mock_engine().await;

        let input = command_json("list_voices", None, None, None);
        let raw = plugin.execute(input).await.unwrap();

        let voices: Vec<VoiceInfo> = serde_json::from_str(&raw).unwrap();
        assert!(!voices.is_empty());
    }

    #[tokio::test]
    async fn test_tts_stats() {
        let plugin = TTSPlugin::new("test_tts");
        let stats = plugin.stats();

        assert_eq!(stats.get("synthesis_count"), Some(&serde_json::json!(0)));
        assert_eq!(stats.get("total_chars"), Some(&serde_json::json!(0)));
    }

    #[test]
    fn test_voice_info_creation() {
        let voice = VoiceInfo::new("test", "Test Voice", "en-US");
        assert_eq!(voice.id, "test");
        assert_eq!(voice.name, "Test Voice");
        assert_eq!(voice.language, "en-US");
    }

    #[test]
    fn test_tts_command_serialization() {
        let json = command_json("speak", Some("Hello"), Some("default"), Some(true));
        let parsed: TTSCommand = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.action, "speak");
        assert_eq!(parsed.text, Some("Hello".to_string()));
        assert_eq!(parsed.voice, Some("default".to_string()));
        assert_eq!(parsed.play, Some(true));
    }
}