1use crate::config::YamlConfig;
2use crate::constants::voice as vc;
3use crate::{error, info};
4use colored::Colorize;
5use std::io::Write;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9
/// Returns the smallest size, in MB, that a complete model file of the given
/// `model_size` is expected to have. Files below this size are treated by the
/// callers in this module as incomplete downloads.
///
/// Unknown model sizes fall back to a 50 MB threshold.
fn expected_min_size_mb(model_size: &str) -> u64 {
    const THRESHOLDS: [(&str, u64); 5] = [
        ("tiny", 70),
        ("base", 130),
        ("small", 450),
        ("medium", 1400),
        ("large", 2900),
    ];

    THRESHOLDS
        .iter()
        .find(|(name, _)| *name == model_size)
        .map_or(50, |&(_, min_mb)| min_mb)
}
23
24struct RawModeGuard;
26
27impl RawModeGuard {
28 fn enter() -> Result<Self, String> {
29 crossterm::terminal::enable_raw_mode().map_err(|e| format!("启用 raw mode 失败: {}", e))?;
30 Ok(Self)
31 }
32}
33
34impl Drop for RawModeGuard {
35 fn drop(&mut self) {
36 let _ = crossterm::terminal::disable_raw_mode();
37 }
38}
39
40fn start_recording_stream(
43 recording: Arc<AtomicBool>,
44 raw_samples: Arc<std::sync::Mutex<Vec<f32>>>,
45) -> Result<(cpal::Stream, u32, u16), String> {
46 use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
47
48 let host = cpal::default_host();
49 let device = host
50 .default_input_device()
51 .ok_or_else(|| "未找到麦克风设备,请检查音频输入设备".to_string())?;
52
53 let supported_config = device
54 .default_input_config()
55 .map_err(|e| format!("获取设备默认输入配置失败: {}", e))?;
56
57 let sample_rate = supported_config.sample_rate();
58 let channels = supported_config.channels();
59
60 let config = cpal::StreamConfig {
61 channels,
62 sample_rate,
63 buffer_size: cpal::BufferSize::Default,
64 };
65
66 let stream = device
67 .build_input_stream(
68 &config,
69 move |data: &[f32], _: &cpal::InputCallbackInfo| {
70 if !recording.load(Ordering::Relaxed) {
71 return;
72 }
73 let mut buf = raw_samples.lock().unwrap();
74 buf.extend_from_slice(data);
75 },
76 move |err| {
77 eprintln!("录音流错误: {}", err);
78 },
79 None,
80 )
81 .map_err(|e| format!("创建录音流失败: {}", e))?;
82
83 stream.play().map_err(|e| format!("启动录音失败: {}", e))?;
84
85 Ok((stream, sample_rate, channels))
86}
87
88fn process_raw_audio(raw_data: &[f32], sample_rate: u32, channels: u16) -> Vec<f32> {
90 let mono: Vec<f32> = if channels > 1 {
92 raw_data
93 .chunks(channels as usize)
94 .map(|frame| frame.iter().sum::<f32>() / channels as f32)
95 .collect()
96 } else {
97 raw_data.to_vec()
98 };
99
100 let target_rate = vc::SAMPLE_RATE;
102 if sample_rate != target_rate {
103 resample(&mono, sample_rate, target_rate)
104 } else {
105 mono
106 }
107}
108
109fn transcribe_from_samples(model_path: &PathBuf, samples: &[f32]) -> Result<String, String> {
111 use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
112
113 if samples.is_empty() {
114 return Err("音频数据为空".to_string());
115 }
116
117 let _stderr_guard = suppress_stderr();
118
119 let ctx = WhisperContext::new_with_params(
120 model_path.to_str().unwrap_or(""),
121 WhisperContextParameters::default(),
122 )
123 .map_err(|e| format!("加载 Whisper 模型失败: {}", e))?;
124
125 let mut state = ctx
126 .create_state()
127 .map_err(|e| format!("创建 Whisper 状态失败: {}", e))?;
128
129 let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
130 params.set_language(Some("zh"));
131 params.set_print_progress(false);
132 params.set_print_special(false);
133 params.set_print_realtime(false);
134 params.set_single_segment(false);
135 params.set_n_threads(4);
136
137 state
138 .full(params, samples)
139 .map_err(|e| format!("Whisper 转写失败: {}", e))?;
140
141 let num_segments = state.full_n_segments();
142 let mut result = String::new();
143 for i in 0..num_segments {
144 if let Some(segment) = state.get_segment(i) {
145 if let Ok(text) = segment.to_str_lossy() {
146 result.push_str(&text);
147 }
148 }
149 }
150
151 Ok(result)
152}
153
154fn detect_best_model() -> Option<&'static str> {
156 for &size in vc::MODEL_PRIORITY {
157 let path = get_model_path(size);
158 if path.exists() {
159 let file_size_mb = std::fs::metadata(&path)
160 .map(|m| m.len() / 1024 / 1024)
161 .unwrap_or(0);
162 if file_size_mb >= expected_min_size_mb(size) {
163 return Some(size);
164 }
165 }
166 }
167 None
168}
169
170fn wait_for_stop_key(recording: &AtomicBool) -> bool {
173 use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
174
175 loop {
176 if !recording.load(Ordering::Relaxed) {
177 return false;
178 }
179 if event::poll(std::time::Duration::from_millis(100)).unwrap_or(false) {
180 if let Ok(Event::Key(KeyEvent {
181 code, modifiers, ..
182 })) = event::read()
183 {
184 match code {
185 KeyCode::Enter => return true,
186 KeyCode::Char('c') if modifiers.contains(KeyModifiers::CONTROL) => {
187 return true;
188 }
189 _ => {}
190 }
191 }
192 }
193 }
194}
195
196fn wait_for_ctrl_v_stop(recording: &AtomicBool) -> bool {
198 use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
199
200 loop {
201 if !recording.load(Ordering::Relaxed) {
202 return false;
203 }
204 if event::poll(std::time::Duration::from_millis(100)).unwrap_or(false) {
205 if let Ok(Event::Key(KeyEvent {
206 code, modifiers, ..
207 })) = event::read()
208 {
209 match code {
210 KeyCode::Char('v') if modifiers.contains(KeyModifiers::CONTROL) => {
211 return true;
212 }
213 KeyCode::Char('c') if modifiers.contains(KeyModifiers::CONTROL) => {
214 return true;
215 }
216 _ => {}
217 }
218 }
219 }
220 }
221}
222
223fn record_and_transcribe_streaming(model_path: &PathBuf) -> Result<String, String> {
228 let recording = Arc::new(AtomicBool::new(true));
229 let raw_samples: Arc<std::sync::Mutex<Vec<f32>>> = Arc::new(std::sync::Mutex::new(Vec::new()));
230
231 let (stream, sample_rate, channels) =
232 start_recording_stream(recording.clone(), raw_samples.clone())?;
233
234 let streaming_recording = recording.clone();
236 let streaming_samples = raw_samples.clone();
237 let streaming_model = model_path.clone();
238 let streaming_sr = sample_rate;
239 let streaming_ch = channels;
240 let displayed_len = Arc::new(std::sync::Mutex::new(0usize));
241 let displayed_len_clone = displayed_len.clone();
242
243 let transcribe_handle = std::thread::spawn(move || {
244 let interval = std::time::Duration::from_secs(vc::STREAMING_INTERVAL_SECS);
245 let min_samples = (vc::MIN_AUDIO_SECS as usize) * (streaming_sr as usize);
246
247 while streaming_recording.load(Ordering::Relaxed) {
248 std::thread::sleep(interval);
249
250 if !streaming_recording.load(Ordering::Relaxed) {
251 break;
252 }
253
254 let raw_data = streaming_samples.lock().unwrap().clone();
255 if raw_data.len() < min_samples * (streaming_ch as usize) {
257 continue;
258 }
259
260 let processed = process_raw_audio(&raw_data, streaming_sr, streaming_ch);
261 if processed.is_empty() {
262 continue;
263 }
264
265 if let Ok(text) = transcribe_from_samples(&streaming_model, &processed) {
266 let text = text.trim().to_string();
267 let mut prev_len = displayed_len_clone.lock().unwrap();
268 if text.len() > *prev_len {
269 let new_part = &text[*prev_len..];
270 print!("{}", new_part);
271 let _ = std::io::stdout().flush();
272 *prev_len = text.len();
273 }
274 }
275 }
276 });
277
278 let _raw_guard = RawModeGuard::enter()?;
280 wait_for_stop_key(&recording);
281 drop(_raw_guard);
282
283 recording.store(false, Ordering::Relaxed);
285 std::thread::sleep(std::time::Duration::from_millis(100));
286 drop(stream);
287
288 let _ = transcribe_handle.join();
289
290 let raw_data = raw_samples.lock().unwrap();
292 if raw_data.is_empty() {
293 return Err("未录到任何音频数据".to_string());
294 }
295
296 let processed = process_raw_audio(&raw_data, sample_rate, channels);
297 let duration_secs = processed.len() as f64 / vc::SAMPLE_RATE as f64;
298
299 println!();
301 info!(
302 "📊 录音时长: {:.1}s (设备: {}Hz {}ch → 16kHz 单声道)",
303 duration_secs, sample_rate, channels
304 );
305
306 if processed.is_empty() || duration_secs < vc::MIN_AUDIO_SECS as f64 {
307 return Err("录音时间过短".to_string());
308 }
309
310 let prev_len = *displayed_len.lock().unwrap();
312 let final_text = transcribe_from_samples(model_path, &processed)?;
313 let final_text = final_text.trim().to_string();
314
315 if final_text.len() != prev_len {
317 }
319
320 Ok(final_text)
321}
322
323pub fn handle_voice(action: &str, copy: bool, model_size: Option<&str>, _config: &YamlConfig) {
332 let model = if let Some(m) = model_size {
334 m.to_string()
335 } else if let Some(best) = detect_best_model() {
336 info!("🔍 自动检测到模型: {}", best.cyan().bold());
337 best.to_string()
338 } else {
339 vc::DEFAULT_MODEL.to_string()
340 };
341
342 if !vc::MODEL_SIZES.contains(&model.as_str()) {
344 error!(
345 "不支持的模型大小: {},可选: {}",
346 model,
347 vc::MODEL_SIZES.join(", ")
348 );
349 return;
350 }
351
352 if action == vc::ACTION_DOWNLOAD {
353 download_model(&model);
354 return;
355 }
356
357 if !action.is_empty() {
358 error!("未知操作: {},可用操作: download", action);
359 crate::usage!("voice [-c] [-m <model>] 或 voice download [-m <model>]");
360 return;
361 }
362
363 let model_path = get_model_path(&model);
365 if !model_path.exists() {
366 error!("模型文件不存在: {}", model_path.display());
367 info!(
368 "💡 请先下载模型: {} 或 {}",
369 format!("j voice download -m {}", model).cyan(),
370 format!("j voice download").cyan()
371 );
372 info!(
373 "💡 也可以手动下载模型放到: {}",
374 model_path.display().to_string().cyan()
375 );
376 return;
377 }
378
379 let file_size_mb = std::fs::metadata(&model_path)
381 .map(|m| m.len() / 1024 / 1024)
382 .unwrap_or(0);
383 let min_size = expected_min_size_mb(&model);
384 if file_size_mb < min_size {
385 error!(
386 "模型文件不完整: {} ({} MB,期望至少 {} MB)",
387 model_path.display(),
388 file_size_mb,
389 min_size
390 );
391 info!(
392 "💡 请删除后重新下载: {} && {}",
393 format!("rm {}", model_path.display()).cyan(),
394 format!("j voice download -m {}", model).cyan()
395 );
396 return;
397 }
398
399 info!(
400 "🎙️ 按 {} 开始录音,录音中按 {} 或 {} 结束",
401 "回车".green().bold(),
402 "回车".red().bold(),
403 "Ctrl+C".red().bold()
404 );
405
406 {
408 let _raw_guard = match RawModeGuard::enter() {
409 Ok(g) => g,
410 Err(e) => {
411 error!("[handle_voice] {}", e);
412 return;
413 }
414 };
415 use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
416 loop {
417 if event::poll(std::time::Duration::from_millis(100)).unwrap_or(false) {
418 if let Ok(Event::Key(KeyEvent {
419 code, modifiers, ..
420 })) = event::read()
421 {
422 match code {
423 KeyCode::Enter => break,
424 KeyCode::Char('c') if modifiers.contains(KeyModifiers::CONTROL) => {
425 return;
426 }
427 _ => {}
428 }
429 }
430 }
431 }
432 }
433
434 println!();
435 info!(
436 "🔴 录音中... 按 {} 或 {} 结束录音",
437 "回车".red().bold(),
438 "Ctrl+C".red().bold()
439 );
440
441 match record_and_transcribe_streaming(&model_path) {
442 Ok(text) => {
443 let text = text.trim().to_string();
444 if text.is_empty() {
445 info!("⚠️ 未识别到语音内容");
446 } else {
447 println!();
448 info!("📝 转写结果:");
449 println!("{}", text);
450
451 if copy {
452 copy_to_clipboard(&text);
453 }
454 }
455 }
456 Err(e) => {
457 error!("[handle_voice] {}", e);
458 }
459 }
460}
461
462pub fn do_voice_record_for_interactive() -> String {
467 let model = if let Some(best) = detect_best_model() {
468 info!("🔍 自动检测到模型: {}", best.cyan().bold());
469 best.to_string()
470 } else {
471 vc::DEFAULT_MODEL.to_string()
472 };
473
474 let model_path = get_model_path(&model);
475 if !model_path.exists() {
476 error!("模型文件不存在: {}", model_path.display());
477 info!("💡 请先下载模型: {}", format!("j voice download").cyan());
478 return String::new();
479 }
480
481 let file_size_mb = std::fs::metadata(&model_path)
482 .map(|m| m.len() / 1024 / 1024)
483 .unwrap_or(0);
484 if file_size_mb < expected_min_size_mb(&model) {
485 error!("模型文件不完整,请重新下载");
486 return String::new();
487 }
488
489 info!(
490 "🔴 录音中... 按 {} 或 {} 结束",
491 "Ctrl+V".red().bold(),
492 "Ctrl+C".red().bold()
493 );
494
495 let recording = Arc::new(AtomicBool::new(true));
496 let raw_samples: Arc<std::sync::Mutex<Vec<f32>>> = Arc::new(std::sync::Mutex::new(Vec::new()));
497
498 let (stream, sample_rate, channels) =
499 match start_recording_stream(recording.clone(), raw_samples.clone()) {
500 Ok(r) => r,
501 Err(e) => {
502 error!("[voice] {}", e);
503 return String::new();
504 }
505 };
506
507 let streaming_recording = recording.clone();
509 let streaming_samples = raw_samples.clone();
510 let streaming_model = model_path.clone();
511 let streaming_sr = sample_rate;
512 let streaming_ch = channels;
513 let displayed_len = Arc::new(std::sync::Mutex::new(0usize));
514 let displayed_len_clone = displayed_len.clone();
515
516 let transcribe_handle = std::thread::spawn(move || {
517 let interval = std::time::Duration::from_secs(vc::STREAMING_INTERVAL_SECS);
518 let min_samples = (vc::MIN_AUDIO_SECS as usize) * (streaming_sr as usize);
519
520 while streaming_recording.load(Ordering::Relaxed) {
521 std::thread::sleep(interval);
522 if !streaming_recording.load(Ordering::Relaxed) {
523 break;
524 }
525
526 let raw_data = streaming_samples.lock().unwrap().clone();
527 if raw_data.len() < min_samples * (streaming_ch as usize) {
528 continue;
529 }
530
531 let processed = process_raw_audio(&raw_data, streaming_sr, streaming_ch);
532 if processed.is_empty() {
533 continue;
534 }
535
536 if let Ok(text) = transcribe_from_samples(&streaming_model, &processed) {
537 let text = text.trim().to_string();
538 let mut prev_len = displayed_len_clone.lock().unwrap();
539 if text.len() > *prev_len {
540 let new_part = &text[*prev_len..];
541 print!("{}", new_part);
543 let _ = std::io::stdout().flush();
544 *prev_len = text.len();
545 }
546 }
547 }
548 });
549
550 let raw_result = RawModeGuard::enter();
552 if let Err(e) = &raw_result {
553 error!("[voice] {}", e);
554 recording.store(false, Ordering::Relaxed);
555 let _ = transcribe_handle.join();
556 drop(stream);
557 return String::new();
558 }
559 let _raw_guard = raw_result.unwrap();
560 wait_for_ctrl_v_stop(&recording);
561 drop(_raw_guard);
562
563 recording.store(false, Ordering::Relaxed);
565 std::thread::sleep(std::time::Duration::from_millis(100));
566 drop(stream);
567
568 let _ = transcribe_handle.join();
569
570 let raw_data = raw_samples.lock().unwrap();
572 if raw_data.is_empty() {
573 println!();
574 info!("⚠️ 未录到音频数据");
575 return String::new();
576 }
577
578 let processed = process_raw_audio(&raw_data, sample_rate, channels);
579 let duration_secs = processed.len() as f64 / vc::SAMPLE_RATE as f64;
580
581 println!();
582 info!("📊 录音时长: {:.1}s", duration_secs);
583
584 if processed.is_empty() || duration_secs < vc::MIN_AUDIO_SECS as f64 {
585 info!("⚠️ 录音时间过短");
586 return String::new();
587 }
588
589 info!("✅ 转写中...");
590 match transcribe_from_samples(&model_path, &processed) {
591 Ok(text) => {
592 let text = text.trim().to_string();
593 if text.is_empty() {
594 info!("⚠️ 未识别到语音内容");
595 } else {
596 info!("📝 {}", &text);
597 }
598 text
599 }
600 Err(e) => {
601 error!("[voice] 转写失败: {}", e);
602 String::new()
603 }
604 }
605}
606
607fn get_model_path(model_size: &str) -> PathBuf {
611 let model_file = vc::MODEL_FILE_TEMPLATE.replace("{}", model_size);
612 let voice_dir = YamlConfig::data_dir()
613 .join(vc::VOICE_DIR)
614 .join(vc::MODEL_DIR);
615 let _ = std::fs::create_dir_all(&voice_dir);
616 voice_dir.join(model_file)
617}
618
/// Resamples `samples` from `source_rate` Hz to `target_rate` Hz by linear
/// interpolation between neighbouring input samples.
///
/// Returns a copy of the input when it is empty or both rates are equal.
/// Otherwise, returns an empty vector if either rate is zero; this avoids a
/// division that would overflow the output-length computation. No low-pass
/// filter is applied before downsampling, so content above the target Nyquist
/// frequency aliases.
fn resample(samples: &[f32], source_rate: u32, target_rate: u32) -> Vec<f32> {
    if samples.is_empty() || source_rate == target_rate {
        return samples.to_vec();
    }
    if source_rate == 0 || target_rate == 0 {
        return Vec::new();
    }

    // Input samples advanced per output sample.
    let ratio = f64::from(source_rate) / f64::from(target_rate);
    let output_len = (samples.len() as f64 / ratio) as usize;

    (0..output_len)
        .map(|i| {
            let src_pos = i as f64 * ratio;
            let idx = src_pos as usize;
            let frac = (src_pos - idx as f64) as f32;
            match (samples.get(idx), samples.get(idx + 1)) {
                (Some(&a), Some(&b)) => a * (1.0 - frac) + b * frac,
                // Last input sample: nothing to interpolate towards.
                (Some(&a), None) => a,
                _ => 0.0,
            }
        })
        .collect()
}
647
648fn download_model(model_size: &str) {
650 let model_path = get_model_path(model_size);
651
652 if model_path.exists() {
653 let file_size = std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0);
654 let file_size_mb = file_size / 1024 / 1024;
655 let min_size = expected_min_size_mb(model_size);
656
657 if file_size_mb < min_size {
658 info!(
659 "⚠️ 模型文件不完整: {} ({} MB,期望至少 {} MB)",
660 model_path.display(),
661 file_size_mb,
662 min_size
663 );
664 info!("🔄 删除不完整文件,重新下载...");
665 let _ = std::fs::remove_file(&model_path);
666 } else {
667 info!(
668 "✅ 模型已存在: {} ({:.1} MB)",
669 model_path.display(),
670 file_size as f64 / 1024.0 / 1024.0
671 );
672 info!("💡 如需重新下载,请先删除模型文件");
673 return;
674 }
675 }
676
677 let url = vc::MODEL_URL_TEMPLATE.replace("{}", model_size);
678
679 info!("📥 下载 Whisper {} 模型...", model_size.cyan().bold());
680 info!(" URL: {}", url.dimmed());
681 info!(" 保存到: {}", model_path.display().to_string().dimmed());
682 println!();
683
684 let status = std::process::Command::new("curl")
685 .args([
686 "-L",
687 "--progress-bar",
688 "-o",
689 model_path.to_str().unwrap_or(""),
690 &url,
691 ])
692 .stdin(std::process::Stdio::inherit())
693 .stdout(std::process::Stdio::inherit())
694 .stderr(std::process::Stdio::inherit())
695 .status();
696
697 match status {
698 Ok(s) if s.success() => {
699 let file_size = std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0);
700 let file_size_mb = file_size / 1024 / 1024;
701 let min_size = expected_min_size_mb(model_size);
702 if file_size_mb < min_size {
703 error!(
704 "下载的文件不完整 ({} MB,期望至少 {} MB)",
705 file_size_mb, min_size
706 );
707 error!(
708 "请检查网络连接,或手动下载模型文件到: {}",
709 model_path.display()
710 );
711 error!(
712 "手动下载链接: {}",
713 vc::MODEL_URL_TEMPLATE.replace("{}", model_size)
714 );
715 let _ = std::fs::remove_file(&model_path);
716 return;
717 }
718 println!();
719 info!(
720 "✅ 模型下载完成: {} ({:.1} MB)",
721 model_size.green().bold(),
722 file_size as f64 / 1024.0 / 1024.0
723 );
724 }
725 Ok(_) => {
726 error!("模型下载失败,请检查网络连接");
727 let _ = std::fs::remove_file(&model_path);
728 }
729 Err(e) => {
730 error!(
731 "[download_model] 执行 curl 失败: {},请确保系统安装了 curl",
732 e
733 );
734 }
735 }
736}
737
738fn copy_to_clipboard(text: &str) {
740 let mut child = match std::process::Command::new("pbcopy")
741 .stdin(std::process::Stdio::piped())
742 .spawn()
743 {
744 Ok(c) => c,
745 Err(e) => {
746 error!("[copy_to_clipboard] 无法调用 pbcopy: {}", e);
747 return;
748 }
749 };
750
751 if let Some(mut stdin) = child.stdin.take() {
752 let _ = stdin.write_all(text.as_bytes());
753 }
754
755 match child.wait() {
756 Ok(_) => info!("📋 已复制到剪贴板"),
757 Err(e) => error!("[copy_to_clipboard] pbcopy 执行失败: {}", e),
758 }
759}
760
761fn suppress_stderr() -> StderrGuard {
763 use std::os::unix::io::AsRawFd;
764
765 let stderr_fd = std::io::stderr().as_raw_fd();
766 let saved_fd = unsafe { libc::dup(stderr_fd) };
767 let devnull = std::fs::OpenOptions::new()
768 .write(true)
769 .open("/dev/null")
770 .ok();
771 if let Some(ref devnull_file) = devnull {
772 unsafe {
773 libc::dup2(devnull_file.as_raw_fd(), stderr_fd);
774 }
775 }
776
777 StderrGuard {
778 saved_fd,
779 stderr_fd,
780 _devnull: devnull,
781 }
782}
783
784struct StderrGuard {
785 saved_fd: i32,
786 stderr_fd: i32,
787 _devnull: Option<std::fs::File>,
788}
789
790impl Drop for StderrGuard {
791 fn drop(&mut self) {
792 if self.saved_fd >= 0 {
793 unsafe {
794 libc::dup2(self.saved_fd, self.stderr_fd);
795 libc::close(self.saved_fd);
796 }
797 }
798 }
799}