Skip to main content

openai_core/
audio_helpers.rs

1//! 本地音频播放与录制辅助能力。
2
3use std::path::{Path, PathBuf};
4use std::process::Stdio;
5use std::time::Duration;
6
7use bytes::Bytes;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::process::Command;
10
11use crate::error::{Error, Result};
12use crate::files::UploadSource;
13
/// Sample rate passed to ffmpeg's `-ar` when recording, unless overridden.
const DEFAULT_SAMPLE_RATE: u32 = 24000;
/// Channel count passed to ffmpeg's `-ac` when recording, unless overridden.
const DEFAULT_CHANNELS: u32 = 1;
16
17/// 表示可用于本地播放的音频输入。
18#[derive(Debug, Clone)]
19pub enum AudioPlaybackInput {
20    /// 直接播放本地文件。
21    Path(PathBuf),
22    /// 通过 stdin 管道播放内存中的音频字节。
23    Bytes(Bytes),
24    /// 通过上传源中的字节进行播放。
25    UploadSource(UploadSource),
26}
27
28impl AudioPlaybackInput {
29    /// 从路径创建播放输入。
30    pub fn path(path: impl Into<PathBuf>) -> Self {
31        Self::Path(path.into())
32    }
33
34    /// 从字节创建播放输入。
35    pub fn bytes(bytes: impl Into<Bytes>) -> Self {
36        Self::Bytes(bytes.into())
37    }
38
39    /// 从上传源创建播放输入。
40    pub fn upload(source: UploadSource) -> Self {
41        Self::UploadSource(source)
42    }
43}
44
45impl From<PathBuf> for AudioPlaybackInput {
46    fn from(value: PathBuf) -> Self {
47        Self::Path(value)
48    }
49}
50
51impl From<&Path> for AudioPlaybackInput {
52    fn from(value: &Path) -> Self {
53        Self::Path(value.to_path_buf())
54    }
55}
56
57impl From<Vec<u8>> for AudioPlaybackInput {
58    fn from(value: Vec<u8>) -> Self {
59        Self::Bytes(Bytes::from(value))
60    }
61}
62
63impl From<Bytes> for AudioPlaybackInput {
64    fn from(value: Bytes) -> Self {
65        Self::Bytes(value)
66    }
67}
68
69impl From<UploadSource> for AudioPlaybackInput {
70    fn from(value: UploadSource) -> Self {
71        Self::UploadSource(value)
72    }
73}
74
/// Tunable parameters for [`record_audio`].
#[derive(Clone, Debug)]
pub struct RecordAudioOptions {
    /// Capture device index or name; `"0"` is used when this is `None`.
    pub device: Option<String>,
    /// Maximum recording time. When it elapses the capture process is killed.
    pub timeout: Option<Duration>,
    /// Sample rate forwarded to ffmpeg's `-ar` (default `24000`).
    pub sample_rate: u32,
    /// Channel count forwarded to ffmpeg's `-ac` (default `1`).
    pub channels: u32,
    /// Overrides the platform's default ffmpeg input format (`-f`).
    pub provider: Option<String>,
    /// File name attached to the returned upload source (default `audio.wav`).
    pub filename: String,
    /// Recording program to execute (default `ffmpeg`).
    pub program: String,
}
93
94impl Default for RecordAudioOptions {
95    fn default() -> Self {
96        Self {
97            device: None,
98            timeout: None,
99            sample_rate: DEFAULT_SAMPLE_RATE,
100            channels: DEFAULT_CHANNELS,
101            provider: None,
102            filename: "audio.wav".into(),
103            program: "ffmpeg".into(),
104        }
105    }
106}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
109struct CommandSpec {
110    program: String,
111    args: Vec<String>,
112    stdin: Option<Bytes>,
113}
114
115/// 使用系统中的 `ffplay` 播放音频。
116///
117/// 当输入为字节或上传源时,会通过 `stdin` 管道向播放器传输数据。
118///
119/// # Errors
120///
121/// 当本地播放器不存在、启动失败或退出码非零时返回错误。
122pub async fn play_audio(input: impl Into<AudioPlaybackInput>) -> Result<()> {
123    let spec = build_play_audio_command(input.into(), "ffplay");
124    run_play_command(spec).await
125}
126
127/// 使用系统中的 `ffmpeg` 录制一段音频并返回统一上传源。
128///
129/// # Errors
130///
131/// 当当前平台缺少默认采集 provider、命令执行失败或录制超时时返回错误。
132pub async fn record_audio(options: RecordAudioOptions) -> Result<UploadSource> {
133    let spec = build_record_audio_command(&options, std::env::consts::OS)?;
134    let bytes = run_record_command(spec, options.timeout).await?;
135    Ok(UploadSource::from_bytes(bytes, options.filename).with_mime_type("audio/wav"))
136}
137
138fn build_play_audio_command(input: AudioPlaybackInput, program: &str) -> CommandSpec {
139    match input {
140        AudioPlaybackInput::Path(path) => CommandSpec {
141            program: program.into(),
142            args: vec![
143                "-autoexit".into(),
144                "-nodisp".into(),
145                "-i".into(),
146                path.to_string_lossy().into_owned(),
147            ],
148            stdin: None,
149        },
150        AudioPlaybackInput::Bytes(bytes) => CommandSpec {
151            program: program.into(),
152            args: vec![
153                "-autoexit".into(),
154                "-nodisp".into(),
155                "-i".into(),
156                "pipe:0".into(),
157            ],
158            stdin: Some(bytes),
159        },
160        AudioPlaybackInput::UploadSource(source) => CommandSpec {
161            program: program.into(),
162            args: vec![
163                "-autoexit".into(),
164                "-nodisp".into(),
165                "-i".into(),
166                "pipe:0".into(),
167            ],
168            stdin: Some(source.bytes().clone()),
169        },
170    }
171}
172
173fn build_record_audio_command(options: &RecordAudioOptions, platform: &str) -> Result<CommandSpec> {
174    let provider = if let Some(provider) = &options.provider {
175        provider.clone()
176    } else {
177        default_recording_provider(platform)
178            .ok_or_else(|| {
179                Error::InvalidConfig(format!("当前平台 `{platform}` 不支持默认录音 provider"))
180            })?
181            .into()
182    };
183    let device = options.device.as_deref().unwrap_or("0");
184
185    Ok(CommandSpec {
186        program: options.program.clone(),
187        args: vec![
188            "-f".into(),
189            provider,
190            "-i".into(),
191            format!(":{device}"),
192            "-ar".into(),
193            options.sample_rate.to_string(),
194            "-ac".into(),
195            options.channels.to_string(),
196            "-f".into(),
197            "wav".into(),
198            "pipe:1".into(),
199        ],
200        stdin: None,
201    })
202}
203
204async fn run_play_command(spec: CommandSpec) -> Result<()> {
205    let mut command = Command::new(&spec.program);
206    command.args(&spec.args);
207    command.stdout(Stdio::null()).stderr(Stdio::null());
208    if spec.stdin.is_some() {
209        command.stdin(Stdio::piped());
210    } else {
211        command.stdin(Stdio::null());
212    }
213
214    let mut child = command
215        .spawn()
216        .map_err(|error| Error::InvalidConfig(format!("启动 `{}` 失败: {error}", spec.program)))?;
217
218    if let Some(bytes) = spec.stdin {
219        let mut stdin = child
220            .stdin
221            .take()
222            .ok_or_else(|| Error::InvalidConfig(format!("`{}` 未暴露 stdin 管道", spec.program)))?;
223        stdin.write_all(&bytes).await.map_err(|error| {
224            Error::InvalidConfig(format!("向 `{}` 写入音频失败: {error}", spec.program))
225        })?;
226        stdin.shutdown().await.map_err(|error| {
227            Error::InvalidConfig(format!("关闭 `{}` stdin 失败: {error}", spec.program))
228        })?;
229    }
230
231    let status = child.wait().await.map_err(|error| {
232        Error::InvalidConfig(format!("等待 `{}` 退出失败: {error}", spec.program))
233    })?;
234    if status.success() {
235        Ok(())
236    } else {
237        Err(Error::InvalidConfig(format!(
238            "`{}` 退出失败,状态码: {status}",
239            spec.program
240        )))
241    }
242}
243
244async fn run_record_command(spec: CommandSpec, timeout: Option<Duration>) -> Result<Bytes> {
245    let mut command = Command::new(&spec.program);
246    command.args(&spec.args);
247    command.stdin(Stdio::null());
248    command.stdout(Stdio::piped());
249    command.stderr(Stdio::null());
250
251    let mut child = command
252        .spawn()
253        .map_err(|error| Error::InvalidConfig(format!("启动 `{}` 失败: {error}", spec.program)))?;
254    let mut stdout = child
255        .stdout
256        .take()
257        .ok_or_else(|| Error::InvalidConfig(format!("`{}` 未暴露 stdout 管道", spec.program)))?;
258    let read_stdout = tokio::spawn(async move {
259        let mut buffer = Vec::new();
260        stdout.read_to_end(&mut buffer).await.map(|_| buffer)
261    });
262
263    let status = if let Some(timeout) = timeout {
264        tokio::select! {
265            status = child.wait() => {
266                status.map_err(|error| Error::InvalidConfig(format!("等待 `{}` 退出失败: {error}", spec.program)))?
267            }
268            _ = tokio::time::sleep(timeout) => {
269                let _ = child.start_kill();
270                let _ = child.wait().await;
271                return Err(Error::Timeout);
272            }
273        }
274    } else {
275        child.wait().await.map_err(|error| {
276            Error::InvalidConfig(format!("等待 `{}` 退出失败: {error}", spec.program))
277        })?
278    };
279
280    let bytes = read_stdout
281        .await
282        .map_err(|error| {
283            Error::InvalidConfig(format!("读取 `{}` 输出失败: {error}", spec.program))
284        })?
285        .map_err(|error| {
286            Error::InvalidConfig(format!("读取 `{}` 输出失败: {error}", spec.program))
287        })?;
288
289    if status.success() {
290        Ok(Bytes::from(bytes))
291    } else {
292        Err(Error::InvalidConfig(format!(
293            "`{}` 退出失败,状态码: {status}",
294            spec.program
295        )))
296    }
297}
298
/// Maps a `std::env::consts::OS`-style platform name to ffmpeg's default capture input
/// format (the `-f` value), or `None` when there is no built-in default.
fn default_recording_provider(platform: &str) -> Option<&'static str> {
    const ALSA_PLATFORMS: [&str; 6] = [
        "linux", "android", "freebsd", "haiku", "netbsd", "openbsd",
    ];

    if platform == "windows" {
        return Some("dshow");
    }
    if platform == "macos" {
        return Some("avfoundation");
    }
    ALSA_PLATFORMS.contains(&platform).then_some("alsa")
}
307
#[cfg(test)]
mod tests {
    use super::{
        AudioPlaybackInput, RecordAudioOptions, build_play_audio_command,
        build_record_audio_command, default_recording_provider,
    };
    use crate::files::UploadSource;
    use bytes::Bytes;

    #[test]
    fn test_should_build_play_command_for_path_input() {
        let spec = build_play_audio_command(AudioPlaybackInput::path("/tmp/sample.wav"), "ffplay");
        assert_eq!(spec.program, "ffplay");
        assert_eq!(
            spec.args,
            vec!["-autoexit", "-nodisp", "-i", "/tmp/sample.wav"]
        );
        assert!(spec.stdin.is_none());
    }

    #[test]
    fn test_should_build_play_command_for_bytes_input() {
        let spec = build_play_audio_command(
            AudioPlaybackInput::bytes(Bytes::from_static(b"wav")),
            "ffplay",
        );
        assert_eq!(spec.args, vec!["-autoexit", "-nodisp", "-i", "pipe:0"]);
        assert_eq!(spec.stdin, Some(Bytes::from_static(b"wav")));
    }

    #[test]
    fn test_should_build_play_command_for_upload_source_input() {
        let source =
            UploadSource::from_bytes(Bytes::from_static(b"wav"), String::from("sample.wav"))
                .with_mime_type("audio/wav");
        let spec = build_play_audio_command(AudioPlaybackInput::upload(source), "ffplay");
        assert_eq!(spec.args, vec!["-autoexit", "-nodisp", "-i", "pipe:0"]);
        assert_eq!(spec.stdin, Some(Bytes::from_static(b"wav")));
    }

    #[test]
    fn test_should_build_record_command_with_platform_defaults() {
        let spec = build_record_audio_command(&RecordAudioOptions::default(), "linux").unwrap();
        assert_eq!(spec.program, "ffmpeg");
        assert_eq!(
            spec.args,
            vec![
                "-f", "alsa", "-i", ":0", "-ar", "24000", "-ac", "1", "-f", "wav", "pipe:1"
            ]
        );
        assert!(spec.stdin.is_none());
    }

    #[test]
    fn test_should_honor_explicit_record_overrides() {
        let options = RecordAudioOptions {
            device: Some("default".into()),
            sample_rate: 16_000,
            channels: 2,
            provider: Some("pulse".into()),
            program: "/usr/local/bin/ffmpeg".into(),
            ..RecordAudioOptions::default()
        };
        // An explicit provider must bypass the platform lookup, even on a platform
        // that has no default provider.
        let spec = build_record_audio_command(&options, "dragonfly").unwrap();
        assert_eq!(spec.program, "/usr/local/bin/ffmpeg");
        assert_eq!(
            spec.args,
            vec![
                "-f", "pulse", "-i", ":default", "-ar", "16000", "-ac", "2", "-f", "wav",
                "pipe:1"
            ]
        );
        assert!(spec.stdin.is_none());
    }

    #[test]
    fn test_should_fail_when_platform_has_no_default_provider() {
        let error =
            build_record_audio_command(&RecordAudioOptions::default(), "dragonfly").unwrap_err();
        assert!(matches!(error, crate::Error::InvalidConfig(_)));
    }

    #[test]
    fn test_should_map_platform_provider() {
        assert_eq!(default_recording_provider("macos"), Some("avfoundation"));
        assert_eq!(default_recording_provider("windows"), Some("dshow"));
        for platform in ["linux", "android", "freebsd", "haiku", "netbsd", "openbsd"] {
            assert_eq!(default_recording_provider(platform), Some("alsa"), "{platform}");
        }
        assert_eq!(default_recording_provider("dragonfly"), None);
    }
}