blivedm_plugins/
tts.rs

1use base64::{Engine as _, engine::general_purpose};
2use client::models::BiliMessage;
3use client::scheduler::{EventHandler, EventContext};
4use log::{debug, error, info, warn};
5use rodio::{Decoder, OutputStream, Sink};
6use serde::{Deserialize, Serialize};
7use std::io::Cursor;
8use std::process::Command;
9use std::sync::mpsc::{self, Sender};
10use std::thread;
11use std::thread::JoinHandle;
12
13#[derive(Serialize, Debug)]
14struct TtsRequest {
15    text: String,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    voice: Option<String>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    backend: Option<String>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    quality: Option<String>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    format: Option<String>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    sample_rate: Option<u32>,
26}
27
28#[derive(Deserialize, Debug)]
29struct TtsResponse {
30    audio_data: String,
31    metadata: TtsMetadata,
32    #[allow(dead_code)]
33    cached: bool,
34}
35
36#[derive(Deserialize, Debug)]
37struct TtsMetadata {
38    #[allow(dead_code)]
39    backend: String,
40    #[allow(dead_code)]
41    #[serde(skip_serializing_if = "Option::is_none")]
42    voice: Option<String>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    duration: Option<f64>,
45    #[allow(dead_code)]
46    #[serde(skip_serializing_if = "Option::is_none")]
47    sample_rate: Option<u32>,
48    #[allow(dead_code)]
49    #[serde(skip_serializing_if = "Option::is_none")]
50    format: Option<String>,
51    #[allow(dead_code)]
52    #[serde(skip_serializing_if = "Option::is_none")]
53    size_bytes: Option<u64>,
54}
55
/// Selects how text is turned into speech.
#[derive(Clone, Debug)]
pub enum TtsMode {
    /// Synthesize through a TTS REST server and play the returned audio locally.
    RestApi {
        /// Base URL of the TTS server, e.g. `"http://localhost:8000"`.
        /// Requests are posted to its `/tts` path.
        server_url: String,
        /// Voice ID to request, e.g. `"zh-CN-XiaoxiaoNeural"`.
        voice: Option<String>,
        /// Backend to request, e.g. `"edge"`, `"xtts"` or `"piper"`.
        backend: Option<String>,
        /// Audio quality to request: `"low"`, `"medium"` or `"high"`.
        quality: Option<String>,
        /// Audio format to request, e.g. `"wav"`.
        format: Option<String>,
        /// Sample rate to request.
        sample_rate: Option<u32>,
        /// Playback volume, intended range 0.0 to 1.0. `None` plays at 1.0.
        /// The value is passed to the audio sink without range checking.
        volume: Option<f32>,
    },
    /// Speak by running a local command-line TTS program once per message.
    Command {
        /// Program to run, e.g. `"say"` on macOS or `"espeak-ng"` on Linux.
        tts_command: String,
        /// Extra arguments placed before the message text, e.g. `["-v", "SinJi"]`.
        tts_args: Vec<String>,
    },
}
84
/// Event handler that reads danmaku messages aloud, one after another.
///
/// Two modes are supported (see [`TtsMode`]):
/// 1. REST API mode: each message is posted to a TTS server, the base64-encoded
///    audio in the response is decoded and played on the default output device.
/// 2. Command mode: a local command-line TTS program (such as `say` on macOS or
///    `espeak-ng` on Linux) is run for each message.
///
/// Incoming messages are pushed onto a channel and consumed by a single
/// background thread, which is what keeps playback sequential.
pub struct TtsHandler {
    /// Configuration this handler was created with. The worker thread uses
    /// its own clone, so this copy is currently never read.
    #[allow(dead_code)]
    mode: TtsMode,
    /// Sending half of the queue feeding the worker thread.
    sender: Sender<String>,
    /// Handle of the worker thread; stored but never joined.
    _worker_handle: JoinHandle<()>,
}
102
103impl TtsHandler {
104    /// Create a new TTS handler with the specified mode
105    pub fn new(mode: TtsMode) -> Self {
106        let (sender, receiver) = mpsc::channel::<String>();
107
108        // Clone the mode for the worker thread
109        let mode_clone = mode.clone();
110
111        // Spawn worker thread to process TTS queue sequentially
112        let worker_handle = thread::spawn(move || match &mode_clone {
113            TtsMode::RestApi { .. } => {
114                Self::run_rest_api_worker(receiver, mode_clone);
115            }
116            TtsMode::Command { .. } => {
117                Self::run_command_worker(receiver, mode_clone);
118            }
119        });
120
121        TtsHandler {
122            mode,
123            sender,
124            _worker_handle: worker_handle,
125        }
126    }
127
128    /// Create a new TTS handler with REST API using default Chinese voice settings
129    pub fn new_rest_api_default(server_url: String) -> Self {
130        Self::new_rest_api_default_with_volume(server_url, 1.0)
131    }
132
133    /// Create a new TTS handler with REST API using default Chinese voice settings and custom volume
134    pub fn new_rest_api_default_with_volume(server_url: String, volume: f32) -> Self {
135        let mode = TtsMode::RestApi {
136            server_url,
137            voice: Some("zh-CN-XiaoxiaoNeural".to_string()),
138            backend: Some("edge".to_string()),
139            quality: Some("medium".to_string()),
140            format: Some("wav".to_string()),
141            sample_rate: Some(22050),
142            volume: Some(volume),
143        };
144        Self::new(mode)
145    }
146
147    /// Create a new TTS handler with REST API and custom configuration
148    pub fn new_rest_api(
149        server_url: String,
150        voice: Option<String>,
151        backend: Option<String>,
152        quality: Option<String>,
153        format: Option<String>,
154        sample_rate: Option<u32>,
155    ) -> Self {
156        Self::new_rest_api_with_volume(
157            server_url,
158            voice,
159            backend,
160            quality,
161            format,
162            sample_rate,
163            None,
164        )
165    }
166
167    /// Create a new TTS handler with REST API and custom configuration including volume
168    pub fn new_rest_api_with_volume(
169        server_url: String,
170        voice: Option<String>,
171        backend: Option<String>,
172        quality: Option<String>,
173        format: Option<String>,
174        sample_rate: Option<u32>,
175        volume: Option<f32>,
176    ) -> Self {
177        let mode = TtsMode::RestApi {
178            server_url,
179            voice,
180            backend,
181            quality,
182            format,
183            sample_rate,
184            volume,
185        };
186        Self::new(mode)
187    }
188
189    /// Create a new TTS handler with command-line TTS
190    pub fn new_command(tts_command: String, tts_args: Vec<String>) -> Self {
191        let mode = TtsMode::Command {
192            tts_command,
193            tts_args,
194        };
195        Self::new(mode)
196    }
197
198    /// Worker thread for REST API TTS processing
199    fn run_rest_api_worker(receiver: std::sync::mpsc::Receiver<String>, mode: TtsMode) {
200        if let TtsMode::RestApi {
201            server_url,
202            voice,
203            backend,
204            quality,
205            format,
206            sample_rate,
207            volume,
208        } = mode
209        {
210            // Create a tokio runtime for HTTP requests
211            let rt = tokio::runtime::Runtime::new().unwrap();
212            let client = reqwest::Client::new();
213
214            // Initialize audio output stream (this will be reused for all audio playback)
215            let (_stream, stream_handle) = OutputStream::try_default().unwrap();
216
217            while let Ok(message) = receiver.recv() {
218                let request = TtsRequest {
219                    text: message,
220                    voice: voice.clone(),
221                    backend: backend.clone(),
222                    quality: quality.clone(),
223                    format: format.clone(),
224                    sample_rate,
225                };
226
227                // Make HTTP request to TTS service
228                rt.block_on(async {
229                    match client
230                        .post(&format!("{}/tts", server_url))
231                        .header("Content-Type", "application/json")
232                        .json(&request)
233                        .send()
234                        .await
235                    {
236                        Ok(response) => {
237                            if response.status().is_success() {
238                                match response.json::<TtsResponse>().await {
239                                    Ok(tts_response) => {
240                                        info!("TTS generated successfully");
241
242                                        // Decode base64 audio data and play it
243                                        match general_purpose::STANDARD
244                                            .decode(&tts_response.audio_data)
245                                        {
246                                            Ok(audio_bytes) => {
247                                                // Create a cursor from the audio bytes
248                                                let cursor = Cursor::new(audio_bytes);
249
250                                                // Create a decoder for the audio format
251                                                match Decoder::new(cursor) {
252                                                    Ok(source) => {
253                                                        // Create a new sink for this audio
254                                                        let sink =
255                                                            Sink::try_new(&stream_handle).unwrap();
256
257                                                        // Set volume if specified (default to 1.0 if not set)
258                                                        let audio_volume = volume.unwrap_or(1.0);
259                                                        sink.set_volume(audio_volume);
260
261                                                        // Append the audio source to the sink
262                                                        sink.append(source);
263
264                                                        // Wait for the audio to finish playing
265                                                        sink.sleep_until_end();
266
267                                                        debug!("Audio playback completed");
268                                                    }
269                                                    Err(e) => error!(
270                                                        "Failed to decode audio format: {}",
271                                                        e
272                                                    ),
273                                                }
274                                            }
275                                            Err(e) => {
276                                                error!("Failed to decode base64 audio data: {}", e)
277                                            }
278                                        }
279                                    }
280                                    Err(e) => error!("Failed to parse TTS response: {}", e),
281                                }
282                            } else {
283                                warn!("TTS request failed with status: {}", response.status());
284                            }
285                        }
286                        Err(e) => error!("Failed to send TTS request: {}", e),
287                    }
288                });
289            }
290        }
291    }
292
293    /// Worker thread for command-line TTS processing
294    fn run_command_worker(receiver: std::sync::mpsc::Receiver<String>, mode: TtsMode) {
295        if let TtsMode::Command {
296            tts_command,
297            tts_args,
298        } = mode
299        {
300            while let Ok(message) = receiver.recv() {
301                let mut command = Command::new(&tts_command);
302                for arg in &tts_args {
303                    command.arg(arg);
304                }
305
306                // Execute TTS command and wait for it to complete
307                match command.arg(&message).status() {
308                    Ok(status) => {
309                        if status.success() {
310                            debug!("TTS command completed successfully");
311                        } else {
312                            warn!("TTS command failed with status: {}", status);
313                        }
314                    }
315                    Err(e) => error!("Failed to execute TTS command: {}", e),
316                }
317            }
318        }
319    }
320
321    /// Legacy method - kept for backward compatibility
322    #[deprecated(note = "Use new_rest_api_default instead")]
323    pub fn new_default(server_url: String) -> Self {
324        Self::new_rest_api_default(server_url)
325    }
326}
327
328impl EventHandler for TtsHandler {
329    fn handle(&self, msg: &BiliMessage, _context: &EventContext) {
330        if let BiliMessage::Danmu { user, text } = msg {
331            let message = format!("{}说:{}", user, text);
332            // Send message to the queue for sequential processing
333            let _ = self.sender.send(message);
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use client::models::BiliMessage;
    use client::scheduler::EventHandler;

    /// Builds a danmaku message from borrowed strings.
    fn danmu(user: &str, text: &str) -> BiliMessage {
        BiliMessage::Danmu {
            user: user.to_string(),
            text: text.to_string(),
        }
    }

    /// Builds the event context shared by every test.
    fn test_context() -> EventContext {
        EventContext {
            cookies: None,
            room_id: 12345,
        }
    }

    #[test]
    fn test_tts_handler_danmu() {
        // No TTS server is expected at this URL. The request is attempted on
        // the worker thread; this test only checks that queuing works.
        let handler = TtsHandler::new_rest_api_default("http://localhost:8000".to_string());

        let msg = danmu("测试用户", "您好,欢迎来到直播间。");
        handler.handle(&msg, &test_context());
    }

    #[test]
    fn test_tts_handler_custom_config() {
        let handler = TtsHandler::new_rest_api(
            "http://localhost:8000".to_string(),
            Some("zh-CN-XiaoxiaoNeural".to_string()),
            Some("edge".to_string()),
            Some("high".to_string()),
            Some("wav".to_string()),
            Some(44100),
        );

        let msg = danmu("test_user", "hello world");
        handler.handle(&msg, &test_context());
    }

    #[test]
    fn test_tts_handler_sequential_processing() {
        use std::time::Duration;

        let handler = TtsHandler::new_rest_api_default("http://localhost:8000".to_string());

        // Queue several messages back to back.
        let messages = [
            ("User1", "First message"),
            ("User2", "Second message"),
            ("User3", "Third message"),
        ];
        for (user, text) in messages {
            handler.handle(&danmu(user, text), &test_context());
        }

        // Leave the worker thread a moment to start on the queue.
        std::thread::sleep(Duration::from_millis(100));

        // Passing means nothing panicked on this thread; ordering itself is
        // a property of the single-worker design.
    }

    #[test]
    fn test_tts_handler_command_mode() {
        // `echo` stands in for a real TTS program.
        let handler = TtsHandler::new_command("echo".to_string(), vec![]);

        let msg = danmu("test_user", "test message");
        handler.handle(&msg, &test_context());

        // Leave the worker thread a moment to run the command.
        std::thread::sleep(std::time::Duration::from_millis(50));
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_tts_handler_macos_voice() {
        let handler = TtsHandler::new_command(
            "say".to_string(),
            vec!["-v".to_string(), "Mei-Jia".to_string()],
        );

        let msg = danmu("用户", "你好");
        handler.handle(&msg, &test_context());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_tts_handler_linux_voice() {
        let handler = TtsHandler::new_command(
            "espeak-ng".to_string(),
            vec!["-v".to_string(), "cmn".to_string()],
        );

        let msg = danmu("用户", "你好");
        handler.handle(&msg, &test_context());
    }

    #[test]
    fn test_tts_request_serialization() {
        let request = TtsRequest {
            text: "Hello world".to_string(),
            voice: Some("zh-CN-XiaoxiaoNeural".to_string()),
            backend: Some("edge".to_string()),
            quality: Some("medium".to_string()),
            format: Some("wav".to_string()),
            sample_rate: Some(22050),
        };

        let json = serde_json::to_string(&request).unwrap();
        for expected in ["Hello world", "zh-CN-XiaoxiaoNeural", "edge"] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn test_tts_handler_with_volume() {
        let msg = danmu("test_user", "volume test");

        // Default settings with a custom volume.
        let handler =
            TtsHandler::new_rest_api_default_with_volume("http://localhost:8000".to_string(), 0.5);
        handler.handle(&msg, &test_context());

        // Fully custom settings, volume included.
        let handler_custom = TtsHandler::new_rest_api_with_volume(
            "http://localhost:8000".to_string(),
            Some("zh-CN-XiaoxiaoNeural".to_string()),
            Some("edge".to_string()),
            Some("high".to_string()),
            Some("wav".to_string()),
            Some(44100),
            Some(0.8),
        );
        handler_custom.handle(&msg, &test_context());
    }
}