1use base64::{Engine as _, engine::general_purpose};
2use client::models::BiliMessage;
3use client::scheduler::{EventHandler, EventContext};
4use log::{debug, error, info, warn};
5use rodio::{Decoder, OutputStream, Sink};
6use serde::{Deserialize, Serialize};
7use std::io::Cursor;
8use std::process::Command;
9use std::sync::mpsc::{self, Sender};
10use std::thread;
11use std::thread::JoinHandle;
12
13#[derive(Serialize, Debug)]
14struct TtsRequest {
15 text: String,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 voice: Option<String>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 backend: Option<String>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 quality: Option<String>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 format: Option<String>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 sample_rate: Option<u32>,
26}
27
28#[derive(Deserialize, Debug)]
29struct TtsResponse {
30 audio_data: String,
31 metadata: TtsMetadata,
32 #[allow(dead_code)]
33 cached: bool,
34}
35
36#[derive(Deserialize, Debug)]
37struct TtsMetadata {
38 #[allow(dead_code)]
39 backend: String,
40 #[allow(dead_code)]
41 #[serde(skip_serializing_if = "Option::is_none")]
42 voice: Option<String>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 duration: Option<f64>,
45 #[allow(dead_code)]
46 #[serde(skip_serializing_if = "Option::is_none")]
47 sample_rate: Option<u32>,
48 #[allow(dead_code)]
49 #[serde(skip_serializing_if = "Option::is_none")]
50 format: Option<String>,
51 #[allow(dead_code)]
52 #[serde(skip_serializing_if = "Option::is_none")]
53 size_bytes: Option<u64>,
54}
55
/// Strategy a [`TtsHandler`] uses to turn text into speech.
#[derive(Clone, Debug)]
pub enum TtsMode {
    /// Synthesize through an HTTP TTS server and play the returned audio locally.
    RestApi {
        /// Base URL of the server; requests go to `<server_url>/tts`.
        server_url: String,
        /// Voice to request; omitted from the request when `None`.
        voice: Option<String>,
        /// Backend to request; omitted from the request when `None`.
        backend: Option<String>,
        /// Quality level to request; omitted from the request when `None`.
        quality: Option<String>,
        /// Audio format to request; omitted from the request when `None`.
        format: Option<String>,
        /// Sample rate to request; omitted from the request when `None`.
        sample_rate: Option<u32>,
        /// Playback volume multiplier; `None` plays at `1.0`.
        volume: Option<f32>,
    },
    /// Run an external program (e.g. `say` or `espeak-ng`) once per message.
    Command {
        /// Program to execute.
        tts_command: String,
        /// Arguments placed before the message text, which is always passed last.
        tts_args: Vec<String>,
    },
}
84
85pub struct TtsHandler {
94 #[allow(dead_code)]
96 mode: TtsMode,
97 sender: Sender<String>,
99 _worker_handle: JoinHandle<()>,
101}
102
103impl TtsHandler {
104 pub fn new(mode: TtsMode) -> Self {
106 let (sender, receiver) = mpsc::channel::<String>();
107
108 let mode_clone = mode.clone();
110
111 let worker_handle = thread::spawn(move || match &mode_clone {
113 TtsMode::RestApi { .. } => {
114 Self::run_rest_api_worker(receiver, mode_clone);
115 }
116 TtsMode::Command { .. } => {
117 Self::run_command_worker(receiver, mode_clone);
118 }
119 });
120
121 TtsHandler {
122 mode,
123 sender,
124 _worker_handle: worker_handle,
125 }
126 }
127
128 pub fn new_rest_api_default(server_url: String) -> Self {
130 Self::new_rest_api_default_with_volume(server_url, 1.0)
131 }
132
133 pub fn new_rest_api_default_with_volume(server_url: String, volume: f32) -> Self {
135 let mode = TtsMode::RestApi {
136 server_url,
137 voice: Some("zh-CN-XiaoxiaoNeural".to_string()),
138 backend: Some("edge".to_string()),
139 quality: Some("medium".to_string()),
140 format: Some("wav".to_string()),
141 sample_rate: Some(22050),
142 volume: Some(volume),
143 };
144 Self::new(mode)
145 }
146
147 pub fn new_rest_api(
149 server_url: String,
150 voice: Option<String>,
151 backend: Option<String>,
152 quality: Option<String>,
153 format: Option<String>,
154 sample_rate: Option<u32>,
155 ) -> Self {
156 Self::new_rest_api_with_volume(
157 server_url,
158 voice,
159 backend,
160 quality,
161 format,
162 sample_rate,
163 None,
164 )
165 }
166
167 pub fn new_rest_api_with_volume(
169 server_url: String,
170 voice: Option<String>,
171 backend: Option<String>,
172 quality: Option<String>,
173 format: Option<String>,
174 sample_rate: Option<u32>,
175 volume: Option<f32>,
176 ) -> Self {
177 let mode = TtsMode::RestApi {
178 server_url,
179 voice,
180 backend,
181 quality,
182 format,
183 sample_rate,
184 volume,
185 };
186 Self::new(mode)
187 }
188
189 pub fn new_command(tts_command: String, tts_args: Vec<String>) -> Self {
191 let mode = TtsMode::Command {
192 tts_command,
193 tts_args,
194 };
195 Self::new(mode)
196 }
197
198 fn run_rest_api_worker(receiver: std::sync::mpsc::Receiver<String>, mode: TtsMode) {
200 if let TtsMode::RestApi {
201 server_url,
202 voice,
203 backend,
204 quality,
205 format,
206 sample_rate,
207 volume,
208 } = mode
209 {
210 let rt = tokio::runtime::Runtime::new().unwrap();
212 let client = reqwest::Client::new();
213
214 let (_stream, stream_handle) = OutputStream::try_default().unwrap();
216
217 while let Ok(message) = receiver.recv() {
218 let request = TtsRequest {
219 text: message,
220 voice: voice.clone(),
221 backend: backend.clone(),
222 quality: quality.clone(),
223 format: format.clone(),
224 sample_rate,
225 };
226
227 rt.block_on(async {
229 match client
230 .post(&format!("{}/tts", server_url))
231 .header("Content-Type", "application/json")
232 .json(&request)
233 .send()
234 .await
235 {
236 Ok(response) => {
237 if response.status().is_success() {
238 match response.json::<TtsResponse>().await {
239 Ok(tts_response) => {
240 info!("TTS generated successfully");
241
242 match general_purpose::STANDARD
244 .decode(&tts_response.audio_data)
245 {
246 Ok(audio_bytes) => {
247 let cursor = Cursor::new(audio_bytes);
249
250 match Decoder::new(cursor) {
252 Ok(source) => {
253 let sink =
255 Sink::try_new(&stream_handle).unwrap();
256
257 let audio_volume = volume.unwrap_or(1.0);
259 sink.set_volume(audio_volume);
260
261 sink.append(source);
263
264 sink.sleep_until_end();
266
267 debug!("Audio playback completed");
268 }
269 Err(e) => error!(
270 "Failed to decode audio format: {}",
271 e
272 ),
273 }
274 }
275 Err(e) => {
276 error!("Failed to decode base64 audio data: {}", e)
277 }
278 }
279 }
280 Err(e) => error!("Failed to parse TTS response: {}", e),
281 }
282 } else {
283 warn!("TTS request failed with status: {}", response.status());
284 }
285 }
286 Err(e) => error!("Failed to send TTS request: {}", e),
287 }
288 });
289 }
290 }
291 }
292
293 fn run_command_worker(receiver: std::sync::mpsc::Receiver<String>, mode: TtsMode) {
295 if let TtsMode::Command {
296 tts_command,
297 tts_args,
298 } = mode
299 {
300 while let Ok(message) = receiver.recv() {
301 let mut command = Command::new(&tts_command);
302 for arg in &tts_args {
303 command.arg(arg);
304 }
305
306 match command.arg(&message).status() {
308 Ok(status) => {
309 if status.success() {
310 debug!("TTS command completed successfully");
311 } else {
312 warn!("TTS command failed with status: {}", status);
313 }
314 }
315 Err(e) => error!("Failed to execute TTS command: {}", e),
316 }
317 }
318 }
319 }
320
321 #[deprecated(note = "Use new_rest_api_default instead")]
323 pub fn new_default(server_url: String) -> Self {
324 Self::new_rest_api_default(server_url)
325 }
326}
327
328impl EventHandler for TtsHandler {
329 fn handle(&self, msg: &BiliMessage, _context: &EventContext) {
330 if let BiliMessage::Danmu { user, text } = msg {
331 let message = format!("{}说:{}", user, text);
332 let _ = self.sender.send(message);
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    // Smoke tests: each builds a handler and queues message(s) for it. Nothing
    // is asserted about synthesis or playback. The REST variants point at
    // localhost:8000 and only exercise construction and queueing.
    use super::*;
    use client::models::BiliMessage;
    use client::scheduler::EventHandler;
    use std::time::Duration;

    const LOCAL_SERVER: &str = "http://localhost:8000";

    /// Builds a danmu message from borrowed strings.
    fn danmu(user: &str, text: &str) -> BiliMessage {
        BiliMessage::Danmu {
            user: user.to_string(),
            text: text.to_string(),
        }
    }

    /// Anonymous event context for a fixed test room.
    fn room_context() -> EventContext {
        EventContext {
            cookies: None,
            room_id: 12345,
        }
    }

    #[test]
    fn test_tts_handler_danmu() {
        let handler = TtsHandler::new_rest_api_default(LOCAL_SERVER.to_string());
        handler.handle(&danmu("测试用户", "您好,欢迎来到直播间。"), &room_context());
    }

    #[test]
    fn test_tts_handler_custom_config() {
        let handler = TtsHandler::new_rest_api(
            LOCAL_SERVER.to_string(),
            Some("zh-CN-XiaoxiaoNeural".to_string()),
            Some("edge".to_string()),
            Some("high".to_string()),
            Some("wav".to_string()),
            Some(44100),
        );
        handler.handle(&danmu("test_user", "hello world"), &room_context());
    }

    #[test]
    fn test_tts_handler_sequential_processing() {
        let handler = TtsHandler::new_rest_api_default(LOCAL_SERVER.to_string());
        let context = room_context();

        let queued = [
            ("User1", "First message"),
            ("User2", "Second message"),
            ("User3", "Third message"),
        ];
        queued
            .iter()
            .for_each(|&(user, text)| handler.handle(&danmu(user, text), &context));

        // Let the worker start draining the queue before the handler is dropped.
        std::thread::sleep(Duration::from_millis(100));
    }

    #[test]
    fn test_tts_handler_command_mode() {
        let handler = TtsHandler::new_command("echo".to_string(), vec![]);
        handler.handle(&danmu("test_user", "test message"), &room_context());

        std::thread::sleep(Duration::from_millis(50));
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_tts_handler_macos_voice() {
        let voice_args = vec!["-v".to_string(), "Mei-Jia".to_string()];
        let handler = TtsHandler::new_command("say".to_string(), voice_args);
        handler.handle(&danmu("用户", "你好"), &room_context());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_tts_handler_linux_voice() {
        let voice_args = vec!["-v".to_string(), "cmn".to_string()];
        let handler = TtsHandler::new_command("espeak-ng".to_string(), voice_args);
        handler.handle(&danmu("用户", "你好"), &room_context());
    }

    #[test]
    fn test_tts_request_serialization() {
        let request = TtsRequest {
            text: "Hello world".to_string(),
            voice: Some("zh-CN-XiaoxiaoNeural".to_string()),
            backend: Some("edge".to_string()),
            quality: Some("medium".to_string()),
            format: Some("wav".to_string()),
            sample_rate: Some(22050),
        };

        let json = serde_json::to_string(&request).unwrap();
        for expected in ["Hello world", "zh-CN-XiaoxiaoNeural", "edge"] {
            assert!(json.contains(expected), "missing {:?} in {}", expected, json);
        }
    }

    #[test]
    fn test_tts_handler_with_volume() {
        let msg = danmu("test_user", "volume test");

        let handler =
            TtsHandler::new_rest_api_default_with_volume(LOCAL_SERVER.to_string(), 0.5);
        handler.handle(&msg, &room_context());

        let handler_custom = TtsHandler::new_rest_api_with_volume(
            LOCAL_SERVER.to_string(),
            Some("zh-CN-XiaoxiaoNeural".to_string()),
            Some("edge".to_string()),
            Some("high".to_string()),
            Some("wav".to_string()),
            Some(44100),
            Some(0.8),
        );
        handler_custom.handle(&msg, &room_context());
    }
}