Skip to main content

j_cli/command/
voice.rs

1use crate::config::YamlConfig;
2use crate::constants::voice as vc;
3use crate::{error, info};
4use colored::Colorize;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, Ordering};
8
9/// 获取模型期望的最小文件大小(MB),用于完整性校验
10fn expected_min_size_mb(model_size: &str) -> u64 {
11    match model_size {
12        "tiny" => 70,
13        "base" => 130,
14        "small" => 450,
15        "medium" => 1400,
16        "large" => 2900,
17        _ => 50,
18    }
19}
20
21/// 语音转文字命令入口
22///
23/// - action 为空:录音 → Whisper 转写 → 输出文字
24/// - action 为 "download":下载指定模型
25/// - copy: 转写结果复制到剪贴板
26/// - model_size: 指定模型大小 (tiny/base/small/medium/large)
27pub fn handle_voice(action: &str, copy: bool, model_size: Option<&str>, _config: &YamlConfig) {
28    let model = model_size.unwrap_or(vc::DEFAULT_MODEL);
29
30    // 验证模型大小
31    if !vc::MODEL_SIZES.contains(&model) {
32        error!(
33            "不支持的模型大小: {},可选: {}",
34            model,
35            vc::MODEL_SIZES.join(", ")
36        );
37        return;
38    }
39
40    if action == vc::ACTION_DOWNLOAD {
41        // 下载模型
42        download_model(model);
43        return;
44    }
45
46    if !action.is_empty() {
47        error!("未知操作: {},可用操作: download", action);
48        crate::usage!("voice [-c] [-m <model>] 或 voice download [-m <model>]");
49        return;
50    }
51
52    // 检查模型是否存在
53    let model_path = get_model_path(model);
54    if !model_path.exists() {
55        error!("模型文件不存在: {}", model_path.display());
56        info!(
57            "💡 请先下载模型: {} 或 {}",
58            format!("j voice download -m {}", model).cyan(),
59            format!("j voice download").cyan()
60        );
61        info!(
62            "💡 也可以手动下载模型放到: {}",
63            model_path.display().to_string().cyan()
64        );
65        return;
66    }
67
68    // 检查模型文件完整性(文件大小是否达到期望最小值)
69    let file_size_mb = std::fs::metadata(&model_path)
70        .map(|m| m.len() / 1024 / 1024)
71        .unwrap_or(0);
72    let min_size = expected_min_size_mb(model);
73    if file_size_mb < min_size {
74        error!(
75            "模型文件不完整: {} ({} MB,期望至少 {} MB)",
76            model_path.display(),
77            file_size_mb,
78            min_size
79        );
80        info!(
81            "💡 请删除后重新下载: {} && {}",
82            format!("rm {}", model_path.display()).cyan(),
83            format!("j voice download -m {}", model).cyan()
84        );
85        return;
86    }
87
88    // 开始录音
89    info!("🎙️  按 {} 开始录音...", "回车".green().bold());
90    wait_for_enter();
91
92    info!("🔴 录音中... 按 {} 结束录音", "回车".red().bold());
93
94    let recording_path = get_recording_path();
95    match record_audio(&recording_path) {
96        Ok(()) => {
97            info!("✅ 录音完成,开始转写...");
98        }
99        Err(e) => {
100            error!("[handle_voice] 录音失败: {}", e);
101            return;
102        }
103    }
104
105    // Whisper 转写
106    match transcribe(&model_path, &recording_path) {
107        Ok(text) => {
108            let text = text.trim().to_string();
109            if text.is_empty() {
110                info!("⚠️  未识别到语音内容");
111            } else {
112                println!();
113                info!("📝 转写结果:");
114                println!("{}", text);
115
116                if copy {
117                    copy_to_clipboard(&text);
118                }
119            }
120        }
121        Err(e) => {
122            error!("[handle_voice] 转写失败: {}", e);
123        }
124    }
125
126    // 清理临时录音文件
127    let _ = std::fs::remove_file(&recording_path);
128}
129
130/// 获取模型文件路径: ~/.jdata/voice/model/ggml-<size>.bin
131fn get_model_path(model_size: &str) -> PathBuf {
132    let model_file = vc::MODEL_FILE_TEMPLATE.replace("{}", model_size);
133    let voice_dir = YamlConfig::data_dir()
134        .join(vc::VOICE_DIR)
135        .join(vc::MODEL_DIR);
136    let _ = std::fs::create_dir_all(&voice_dir);
137    voice_dir.join(model_file)
138}
139
140/// 获取临时录音文件路径: ~/.jdata/voice/recording.wav
141fn get_recording_path() -> PathBuf {
142    let voice_dir = YamlConfig::data_dir().join(vc::VOICE_DIR);
143    let _ = std::fs::create_dir_all(&voice_dir);
144    voice_dir.join(vc::RECORDING_FILE)
145}
146
147/// 等待用户按回车
148fn wait_for_enter() {
149    let mut input = String::new();
150    let _ = std::io::stdin().read_line(&mut input);
151}
152
153/// 录音:使用 cpal 捕获麦克风音频,保存为 WAV 文件
154/// 使用设备默认配置录音,然后重采样到 16kHz 单声道(Whisper 要求)
155/// 用户按回车结束录音
156fn record_audio(output_path: &PathBuf) -> Result<(), String> {
157    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
158
159    let host = cpal::default_host();
160    let device = host
161        .default_input_device()
162        .ok_or_else(|| "未找到麦克风设备,请检查音频输入设备".to_string())?;
163
164    // 获取设备支持的默认输入配置
165    let supported_config = device
166        .default_input_config()
167        .map_err(|e| format!("获取设备默认输入配置失败: {}", e))?;
168
169    let device_sample_rate = supported_config.sample_rate();
170    let device_channels = supported_config.channels();
171
172    let config = cpal::StreamConfig {
173        channels: device_channels,
174        sample_rate: supported_config.sample_rate(),
175        buffer_size: cpal::BufferSize::Default,
176    };
177
178    // 用于在录音线程和主线程之间共享数据
179    let recording = Arc::new(AtomicBool::new(true));
180    let recording_clone = recording.clone();
181
182    // 收集原始 f32 音频采样数据(设备原始采样率和声道数)
183    let raw_samples: Arc<std::sync::Mutex<Vec<f32>>> = Arc::new(std::sync::Mutex::new(Vec::new()));
184    let raw_samples_clone = raw_samples.clone();
185
186    let err_flag: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));
187    let err_flag_clone = err_flag.clone();
188
189    // 创建音频输入流
190    let stream = device
191        .build_input_stream(
192            &config,
193            move |data: &[f32], _: &cpal::InputCallbackInfo| {
194                if !recording_clone.load(Ordering::Relaxed) {
195                    return;
196                }
197                let mut buf = raw_samples_clone.lock().unwrap();
198                buf.extend_from_slice(data);
199            },
200            move |err| {
201                let mut flag = err_flag_clone.lock().unwrap();
202                *flag = Some(format!("录音流错误: {}", err));
203            },
204            None,
205        )
206        .map_err(|e| format!("创建录音流失败: {}", e))?;
207
208    stream.play().map_err(|e| format!("启动录音失败: {}", e))?;
209
210    // 等待用户按回车结束录音
211    wait_for_enter();
212
213    // 停止录音
214    recording.store(false, Ordering::Relaxed);
215    // 给录音流一点时间完成最后的数据收集
216    std::thread::sleep(std::time::Duration::from_millis(100));
217    drop(stream);
218
219    // 检查是否有错误
220    if let Some(err) = err_flag.lock().unwrap().take() {
221        return Err(err);
222    }
223
224    let raw_data = raw_samples.lock().unwrap();
225    if raw_data.is_empty() {
226        return Err("未录到任何音频数据".to_string());
227    }
228
229    // 步骤 1: 多声道转单声道(取各声道均值)
230    let mono_samples: Vec<f32> = if device_channels > 1 {
231        raw_data
232            .chunks(device_channels as usize)
233            .map(|frame| frame.iter().sum::<f32>() / device_channels as f32)
234            .collect()
235    } else {
236        raw_data.clone()
237    };
238
239    // 步骤 2: 重采样到 16kHz(如果设备采样率不是 16kHz)
240    let target_rate = vc::SAMPLE_RATE;
241    let resampled: Vec<f32> = if device_sample_rate != target_rate {
242        resample(&mono_samples, device_sample_rate, target_rate)
243    } else {
244        mono_samples
245    };
246
247    // 步骤 3: 转换为 i16 并写入 WAV
248    let i16_samples: Vec<i16> = resampled
249        .iter()
250        .map(|&s| {
251            let clamped = s.clamp(-1.0, 1.0);
252            (clamped * i16::MAX as f32) as i16
253        })
254        .collect();
255
256    if i16_samples.is_empty() {
257        return Err("重采样后无音频数据".to_string());
258    }
259
260    let duration_secs = i16_samples.len() as f64 / target_rate as f64;
261    info!(
262        "📊 录音时长: {:.1}s (设备: {}Hz {}ch → 重采样到 {}Hz 单声道)",
263        duration_secs, device_sample_rate, device_channels, target_rate
264    );
265
266    let spec = hound::WavSpec {
267        channels: vc::CHANNELS,
268        sample_rate: target_rate,
269        bits_per_sample: vc::BITS_PER_SAMPLE,
270        sample_format: hound::SampleFormat::Int,
271    };
272
273    let mut writer = hound::WavWriter::create(output_path, spec)
274        .map_err(|e| format!("创建 WAV 文件失败: {}", e))?;
275
276    for &sample in i16_samples.iter() {
277        writer
278            .write_sample(sample)
279            .map_err(|e| format!("写入音频数据失败: {}", e))?;
280    }
281
282    writer
283        .finalize()
284        .map_err(|e| format!("完成 WAV 文件写入失败: {}", e))?;
285
286    Ok(())
287}
288
289/// 线性插值重采样
290/// 将 source_rate 的音频数据重采样到 target_rate
291fn resample(samples: &[f32], source_rate: u32, target_rate: u32) -> Vec<f32> {
292    if samples.is_empty() || source_rate == target_rate {
293        return samples.to_vec();
294    }
295
296    let ratio = source_rate as f64 / target_rate as f64;
297    let output_len = (samples.len() as f64 / ratio) as usize;
298    let mut output = Vec::with_capacity(output_len);
299
300    for i in 0..output_len {
301        let src_idx = i as f64 * ratio;
302        let idx_floor = src_idx as usize;
303        let frac = (src_idx - idx_floor as f64) as f32;
304
305        let sample = if idx_floor + 1 < samples.len() {
306            samples[idx_floor] * (1.0 - frac) + samples[idx_floor + 1] * frac
307        } else if idx_floor < samples.len() {
308            samples[idx_floor]
309        } else {
310            0.0
311        };
312
313        output.push(sample);
314    }
315
316    output
317}
318
319/// 使用 Whisper 模型转写音频文件
320fn transcribe(model_path: &PathBuf, audio_path: &PathBuf) -> Result<String, String> {
321    use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
322
323    // 临时抑制 whisper.cpp C 库的 stderr 调试输出
324    let _stderr_guard = suppress_stderr();
325
326    // 加载模型
327    let ctx = WhisperContext::new_with_params(
328        model_path.to_str().unwrap_or(""),
329        WhisperContextParameters::default(),
330    )
331    .map_err(|e| format!("加载 Whisper 模型失败: {}", e))?;
332
333    let mut state = ctx
334        .create_state()
335        .map_err(|e| format!("创建 Whisper 状态失败: {}", e))?;
336
337    // 读取 WAV 文件并转换为 f32 采样
338    let reader =
339        hound::WavReader::open(audio_path).map_err(|e| format!("读取 WAV 文件失败: {}", e))?;
340
341    let samples: Vec<f32> = reader
342        .into_samples::<i16>()
343        .filter_map(|s| s.ok())
344        .map(|s| s as f32 / i16::MAX as f32)
345        .collect();
346
347    if samples.is_empty() {
348        return Err("音频文件为空".to_string());
349    }
350
351    // 配置 Whisper 转写参数
352    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
353
354    // 设置语言为中文
355    params.set_language(Some("zh"));
356    // 不打印进度
357    params.set_print_progress(false);
358    // 不打印特殊 token
359    params.set_print_special(false);
360    // 不打印实时结果
361    params.set_print_realtime(false);
362    // 单段模式(适合短音频)
363    params.set_single_segment(false);
364    // 线程数
365    params.set_n_threads(4);
366
367    // 执行转写
368    state
369        .full(params, &samples)
370        .map_err(|e| format!("Whisper 转写失败: {}", e))?;
371
372    // 提取转写结果
373    let num_segments = state.full_n_segments();
374    let mut result = String::new();
375
376    for i in 0..num_segments {
377        if let Some(segment) = state.get_segment(i) {
378            if let Ok(text) = segment.to_str_lossy() {
379                result.push_str(&text);
380            }
381        }
382    }
383
384    Ok(result)
385}
386
387/// 下载 Whisper 模型
388fn download_model(model_size: &str) {
389    let model_path = get_model_path(model_size);
390
391    if model_path.exists() {
392        let file_size = std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0);
393        let file_size_mb = file_size / 1024 / 1024;
394        let min_size = expected_min_size_mb(model_size);
395
396        if file_size_mb < min_size {
397            // 文件存在但不完整(可能是之前下载中断)
398            info!(
399                "⚠️  模型文件不完整: {} ({} MB,期望至少 {} MB)",
400                model_path.display(),
401                file_size_mb,
402                min_size
403            );
404            info!("🔄 删除不完整文件,重新下载...");
405            let _ = std::fs::remove_file(&model_path);
406        } else {
407            info!(
408                "✅ 模型已存在: {} ({:.1} MB)",
409                model_path.display(),
410                file_size as f64 / 1024.0 / 1024.0
411            );
412            info!("💡 如需重新下载,请先删除模型文件");
413            return;
414        }
415    }
416
417    let url = vc::MODEL_URL_TEMPLATE.replace("{}", model_size);
418
419    info!("📥 下载 Whisper {} 模型...", model_size.cyan().bold());
420    info!("   URL: {}", url.dimmed());
421    info!("   保存到: {}", model_path.display().to_string().dimmed());
422    println!();
423
424    // 使用 curl 下载(避免引入额外的 HTTP 依赖)
425    let status = std::process::Command::new("curl")
426        .args([
427            "-L",             // 跟随重定向
428            "--progress-bar", // 进度条
429            "-o",
430            model_path.to_str().unwrap_or(""),
431            &url,
432        ])
433        .stdin(std::process::Stdio::inherit())
434        .stdout(std::process::Stdio::inherit())
435        .stderr(std::process::Stdio::inherit())
436        .status();
437
438    match status {
439        Ok(s) if s.success() => {
440            let file_size = std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0);
441            let file_size_mb = file_size / 1024 / 1024;
442            let min_size = expected_min_size_mb(model_size);
443            if file_size_mb < min_size {
444                error!(
445                    "下载的文件不完整 ({} MB,期望至少 {} MB)",
446                    file_size_mb, min_size
447                );
448                error!(
449                    "请检查网络连接,或手动下载模型文件到: {}",
450                    model_path.display()
451                );
452                error!(
453                    "手动下载链接: {}",
454                    vc::MODEL_URL_TEMPLATE.replace("{}", model_size)
455                );
456                let _ = std::fs::remove_file(&model_path);
457                return;
458            }
459            println!();
460            info!(
461                "✅ 模型下载完成: {} ({:.1} MB)",
462                model_size.green().bold(),
463                file_size as f64 / 1024.0 / 1024.0
464            );
465        }
466        Ok(_) => {
467            error!("模型下载失败,请检查网络连接");
468            let _ = std::fs::remove_file(&model_path);
469        }
470        Err(e) => {
471            error!(
472                "[download_model] 执行 curl 失败: {},请确保系统安装了 curl",
473                e
474            );
475        }
476    }
477}
478
479/// 复制文字到系统剪贴板 (macOS: pbcopy)
480fn copy_to_clipboard(text: &str) {
481    use std::io::Write;
482
483    let mut child = match std::process::Command::new("pbcopy")
484        .stdin(std::process::Stdio::piped())
485        .spawn()
486    {
487        Ok(c) => c,
488        Err(e) => {
489            error!("[copy_to_clipboard] 无法调用 pbcopy: {}", e);
490            return;
491        }
492    };
493
494    if let Some(mut stdin) = child.stdin.take() {
495        let _ = stdin.write_all(text.as_bytes());
496    }
497
498    match child.wait() {
499        Ok(_) => info!("📋 已复制到剪贴板"),
500        Err(e) => error!("[copy_to_clipboard] pbcopy 执行失败: {}", e),
501    }
502}
503
504/// 临时抑制 stderr 输出(用于屏蔽 whisper.cpp C 库的调试日志)
505/// 返回一个 guard,drop 时自动恢复 stderr
506fn suppress_stderr() -> StderrGuard {
507    use std::os::unix::io::AsRawFd;
508
509    let stderr_fd = std::io::stderr().as_raw_fd();
510    // 备份原始 stderr fd
511    let saved_fd = unsafe { libc::dup(stderr_fd) };
512    // 打开 /dev/null
513    let devnull = std::fs::OpenOptions::new()
514        .write(true)
515        .open("/dev/null")
516        .ok();
517    if let Some(ref devnull_file) = devnull {
518        unsafe {
519            libc::dup2(devnull_file.as_raw_fd(), stderr_fd);
520        }
521    }
522
523    StderrGuard {
524        saved_fd,
525        stderr_fd,
526        _devnull: devnull,
527    }
528}
529
530/// stderr 重定向 guard,drop 时恢复原始 stderr
531struct StderrGuard {
532    saved_fd: i32,
533    stderr_fd: i32,
534    _devnull: Option<std::fs::File>,
535}
536
537impl Drop for StderrGuard {
538    fn drop(&mut self) {
539        if self.saved_fd >= 0 {
540            unsafe {
541                libc::dup2(self.saved_fd, self.stderr_fd);
542                libc::close(self.saved_fd);
543            }
544        }
545    }
546}