1use anyhow::{Result, anyhow};
2use audio_codec::{PcmBuf, samples_to_bytes};
3use futures::StreamExt;
4use serde::{Deserialize, Serialize};
5use std::{
6 collections::HashMap,
7 path::Path,
8 sync::{
9 Mutex,
10 atomic::{AtomicUsize, Ordering},
11 },
12 time::Duration,
13 u32,
14};
15use tokio::{
16 fs::File,
17 io::{AsyncSeekExt, AsyncWriteExt},
18 select,
19 sync::mpsc::UnboundedReceiver,
20};
21use tokio_stream::wrappers::IntervalStream;
22use tokio_util::sync::CancellationToken;
23use tracing::{info, warn};
24
25use crate::media::{AudioFrame, Samples};
26
27#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
28#[serde(rename_all = "lowercase")]
29pub enum RecorderFormat {
30 Wav,
31 Pcm,
32 Pcmu,
33 Pcma,
34 G722,
35}
36
37impl RecorderFormat {
38 pub fn extension(&self) -> &'static str {
39 "wav"
40 }
41
42 pub fn is_supported(&self) -> bool {
43 true
44 }
45
46 pub fn effective(&self) -> RecorderFormat {
47 *self
48 }
49}
50
51impl Default for RecorderFormat {
52 fn default() -> Self {
53 RecorderFormat::Wav
54 }
55}
56
57#[derive(Debug, Deserialize, Serialize, Clone)]
58#[serde(rename_all = "camelCase")]
59#[serde(default)]
60pub struct RecorderOption {
61 #[serde(default)]
62 pub recorder_file: String,
63 #[serde(default)]
64 pub samplerate: u32,
65 #[serde(default)]
66 pub ptime: u32,
67 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub format: Option<RecorderFormat>,
69}
70
71impl RecorderOption {
72 pub fn new(recorder_file: String) -> Self {
73 Self {
74 recorder_file,
75 ..Default::default()
76 }
77 }
78
79 pub fn resolved_format(&self, default: RecorderFormat) -> RecorderFormat {
80 self.format.unwrap_or(default).effective()
81 }
82
83 pub fn ensure_path_extension(&mut self, fallback_format: RecorderFormat) {
84 let effective_format = self.format.unwrap_or(fallback_format).effective();
85 self.format = Some(effective_format);
86
87 if self.recorder_file.is_empty() {
88 return;
89 }
90
91 let extension = effective_format.extension();
92 if !self
93 .recorder_file
94 .to_lowercase()
95 .ends_with(&format!(".{}", extension.to_lowercase()))
96 {
97 self.recorder_file = format!("{}.{}", self.recorder_file, extension);
98 }
99 }
100}
101
102impl Default for RecorderOption {
103 fn default() -> Self {
104 Self {
105 recorder_file: "".to_string(),
106 samplerate: 16000,
107 ptime: 200,
108 format: None,
109 }
110 }
111}
112
113pub struct Recorder {
114 session_id: String,
115 option: RecorderOption,
116 samples_written: AtomicUsize,
117 cancel_token: CancellationToken,
118 channel_idx: AtomicUsize,
119 channels: Mutex<HashMap<String, usize>>,
120 stereo_buf: Mutex<PcmBuf>,
121 mono_buf: Mutex<PcmBuf>,
122}
123
124impl Recorder {
125 pub fn new(
126 cancel_token: CancellationToken,
127 session_id: String,
128 option: RecorderOption,
129 ) -> Self {
130 Self {
131 session_id,
132 option,
133 samples_written: AtomicUsize::new(0),
134 cancel_token,
135 channel_idx: AtomicUsize::new(0),
136 channels: Mutex::new(HashMap::new()),
137 stereo_buf: Mutex::new(Vec::new()),
138 mono_buf: Mutex::new(Vec::new()),
139 }
140 }
141
142 async fn update_wav_header(&self, file: &mut File, payload_type: Option<u8>) -> Result<()> {
143 let total = self.samples_written.load(Ordering::SeqCst);
144
145 let (format_tag, sample_rate, channels, bits_per_sample, data_size): (
146 u16,
147 u32,
148 u16,
149 u16,
150 usize,
151 ) = match payload_type {
152 Some(pt) => {
153 let (tag, rate, chan): (u16, u32, u16) = match pt {
154 0 => (0x0007, 8000, 1), 8 => (0x0006, 8000, 1), 9 => (0x0064, 16000, 1), 10 => (0x0001, 44100, 2), 11 => (0x0001, 44100, 1), _ => (0x0001, 16000, 1), };
161 let bits: u16 = match pt {
162 9 => 4,
163 0 | 8 => 8,
164 _ => 16,
165 };
166 (tag, rate, chan, bits, total)
167 }
168 None => (0x0001, self.option.samplerate, 2, 16, total),
169 };
170
171 let mut header_buf = Vec::new();
172 header_buf.extend_from_slice(b"RIFF");
173 let file_size = data_size + 36;
174 header_buf.extend_from_slice(&(file_size as u32).to_le_bytes());
175 header_buf.extend_from_slice(b"WAVE");
176
177 header_buf.extend_from_slice(b"fmt ");
178 header_buf.extend_from_slice(&16u32.to_le_bytes());
179 header_buf.extend_from_slice(&format_tag.to_le_bytes());
180 header_buf.extend_from_slice(&(channels as u16).to_le_bytes());
181 header_buf.extend_from_slice(&sample_rate.to_le_bytes());
182
183 let bytes_per_sec: u32 = match format_tag {
184 0x0064 => 8000, _ => sample_rate * (channels as u32) * (bits_per_sample as u32 / 8),
186 };
187 header_buf.extend_from_slice(&bytes_per_sec.to_le_bytes());
188
189 let block_align: u16 = match format_tag {
190 0x0064 | 0x0007 | 0x0006 => 1 * channels,
191 _ => (bits_per_sample / 8) * channels,
192 };
193 header_buf.extend_from_slice(&block_align.to_le_bytes());
194 header_buf.extend_from_slice(&bits_per_sample.to_le_bytes());
195
196 header_buf.extend_from_slice(b"data");
197 header_buf.extend_from_slice(&(data_size as u32).to_le_bytes());
198
199 file.seek(std::io::SeekFrom::Start(0)).await?;
200 file.write_all(&header_buf).await?;
201 file.seek(std::io::SeekFrom::End(0)).await?;
202
203 Ok(())
204 }
205
206 pub async fn process_recording(
207 &self,
208 file_path: &Path,
209 mut receiver: UnboundedReceiver<AudioFrame>,
210 ) -> Result<()> {
211 let first_frame = match receiver.recv().await {
212 Some(f) => f,
213 None => return Ok(()),
214 };
215
216 if let Samples::RTP { .. } = first_frame.samples {
217 return self
218 .process_recording_rtp(file_path, receiver, first_frame)
219 .await;
220 }
221
222 let _requested_format = self.option.format.unwrap_or(RecorderFormat::Wav);
223
224 self.process_recording_wav(file_path, receiver, first_frame)
225 .await
226 }
227
228 fn ensure_parent_dir(&self, file_path: &Path) -> Result<()> {
229 if let Some(parent) = file_path.parent() {
230 if !parent.exists() {
231 if let Err(e) = std::fs::create_dir_all(parent) {
232 warn!(
233 "Failed to create recording file parent directory: {} {}",
234 e,
235 file_path.display()
236 );
237 return Err(anyhow!("Failed to create recording file parent directory"));
238 }
239 }
240 }
241 Ok(())
242 }
243
244 async fn create_output_file(&self, file_path: &Path) -> Result<File> {
245 self.ensure_parent_dir(file_path)?;
246 match File::create(file_path).await {
247 Ok(file) => {
248 info!(
249 session_id = self.session_id,
250 "recorder: created recording file: {}",
251 file_path.display()
252 );
253 Ok(file)
254 }
255 Err(e) => {
256 warn!(
257 "Failed to create recording file: {} {}",
258 e,
259 file_path.display()
260 );
261 Err(anyhow!("Failed to create recording file"))
262 }
263 }
264 }
265
266 async fn process_recording_rtp(
267 &self,
268 file_path: &Path,
269 mut receiver: UnboundedReceiver<AudioFrame>,
270 first_frame: AudioFrame,
271 ) -> Result<()> {
272 let (payload_type, mut file) =
273 if let Samples::RTP { payload_type, .. } = &first_frame.samples {
274 let file = self.create_output_file(file_path).await?;
275 (*payload_type, file)
276 } else {
277 return Err(anyhow!("Invalid frame type for RTP recording"));
278 };
279
280 self.update_wav_header(&mut file, Some(payload_type))
281 .await?;
282
283 if let Samples::RTP { payload, .. } = first_frame.samples {
284 file.write_all(&payload).await?;
285 self.samples_written
286 .fetch_add(payload.len(), Ordering::SeqCst);
287 }
288
289 loop {
290 match receiver.recv().await {
291 Some(frame) => {
292 if let Samples::RTP { payload, .. } = frame.samples {
293 file.write_all(&payload).await?;
294 self.samples_written
295 .fetch_add(payload.len(), Ordering::SeqCst);
296 }
297 }
298 None => break,
299 }
300 }
301
302 self.update_wav_header(&mut file, Some(payload_type))
303 .await?;
304
305 file.sync_all().await?;
306
307 Ok(())
308 }
309
310 async fn process_recording_wav(
311 &self,
312 file_path: &Path,
313 mut receiver: UnboundedReceiver<AudioFrame>,
314 first_frame: AudioFrame,
315 ) -> Result<()> {
316 let mut file = self.create_output_file(file_path).await?;
317 self.update_wav_header(&mut file, None).await?;
318
319 self.append_frame(first_frame).await.ok();
320
321 let chunk_size = (self.option.samplerate / 1000 * self.option.ptime) as usize;
322 info!(
323 session_id = self.session_id,
324 format = "wav",
325 "Recording to {} ptime: {}ms chunk_size: {}",
326 file_path.display(),
327 self.option.ptime,
328 chunk_size
329 );
330
331 let mut interval = IntervalStream::new(tokio::time::interval(Duration::from_millis(
332 self.option.ptime as u64,
333 )));
334 loop {
335 select! {
336 Some(frame) = receiver.recv() => {
337 self.append_frame(frame).await.ok();
338 }
339 _ = interval.next() => {
340 let (mono_buf, stereo_buf) = self.pop(chunk_size).await;
341 self.process_buffers(&mut file, mono_buf, stereo_buf).await?;
342 self.update_wav_header(&mut file, None).await?;
343 }
344 _ = self.cancel_token.cancelled() => {
345 self.flush_buffers(&mut file).await?;
346 self.update_wav_header(&mut file, None).await?;
347 return Ok(());
348 }
349 }
350 }
351 }
352
353 fn get_channel_index(&self, track_id: &str) -> usize {
354 let mut channels = self.channels.lock().unwrap();
355 if let Some(&channel_idx) = channels.get(track_id) {
356 channel_idx % 2
357 } else {
358 let new_idx = self.channel_idx.fetch_add(1, Ordering::SeqCst);
359 channels.insert(track_id.to_string(), new_idx);
360 info!(
361 session_id = self.session_id,
362 "Assigned channel {} to track: {}",
363 new_idx % 2,
364 track_id
365 );
366 new_idx % 2
367 }
368 }
369
370 async fn append_frame(&self, frame: AudioFrame) -> Result<()> {
371 let buffer = match frame.samples {
372 Samples::PCM { samples } => samples,
373 _ => return Ok(()), };
375
376 if buffer.is_empty() {
377 return Ok(());
378 }
379
380 let channel_idx = self.get_channel_index(&frame.track_id);
381 match channel_idx {
382 0 => {
383 let mut mono_buf = self.mono_buf.lock().unwrap();
384 mono_buf.extend(buffer.iter());
385 }
386 1 => {
387 let mut stereo_buf = self.stereo_buf.lock().unwrap();
388 stereo_buf.extend(buffer.iter());
389 }
390 _ => {}
391 }
392
393 Ok(())
394 }
395
396 pub(crate) fn extract_samples(buffer: &mut PcmBuf, extract_size: usize) -> PcmBuf {
397 if extract_size > 0 && !buffer.is_empty() {
398 let take_size = extract_size.min(buffer.len());
399 buffer.drain(..take_size).collect()
400 } else {
401 Vec::new()
402 }
403 }
404
405 async fn pop(&self, chunk_size: usize) -> (PcmBuf, PcmBuf) {
406 let mut mono_buf = self.mono_buf.lock().unwrap();
407 let mut stereo_buf = self.stereo_buf.lock().unwrap();
408
409 let safe_chunk_size = chunk_size.min(16000 * 10);
410
411 let mono_result = if mono_buf.len() >= safe_chunk_size {
412 Self::extract_samples(&mut mono_buf, safe_chunk_size)
413 } else if !mono_buf.is_empty() {
414 let available_len = mono_buf.len();
415 let mut result = Self::extract_samples(&mut mono_buf, available_len);
416 if chunk_size != usize::MAX {
417 result.resize(safe_chunk_size, 0);
418 }
419 result
420 } else {
421 if chunk_size != usize::MAX {
422 vec![0; safe_chunk_size]
423 } else {
424 Vec::new()
425 }
426 };
427
428 let stereo_result = if stereo_buf.len() >= safe_chunk_size {
429 Self::extract_samples(&mut stereo_buf, safe_chunk_size)
430 } else if !stereo_buf.is_empty() {
431 let available_len = stereo_buf.len();
432 let mut result = Self::extract_samples(&mut stereo_buf, available_len);
433 if chunk_size != usize::MAX {
434 result.resize(safe_chunk_size, 0);
435 }
436 result
437 } else {
438 if chunk_size != usize::MAX {
439 vec![0; safe_chunk_size]
440 } else {
441 Vec::new()
442 }
443 };
444
445 if chunk_size == usize::MAX {
446 let max_len = mono_result.len().max(stereo_result.len());
447 let mut mono_final = mono_result;
448 let mut stereo_final = stereo_result;
449 mono_final.resize(max_len, 0);
450 stereo_final.resize(max_len, 0);
451 (mono_final, stereo_final)
452 } else {
453 (mono_result, stereo_result)
454 }
455 }
456
457 pub fn stop_recording(&self) -> Result<()> {
458 self.cancel_token.cancel();
459 Ok(())
460 }
461
462 pub(crate) fn mix_buffers(mono_buf: &PcmBuf, stereo_buf: &PcmBuf) -> Vec<i16> {
463 assert_eq!(
464 mono_buf.len(),
465 stereo_buf.len(),
466 "Buffer lengths must be equal after pop()"
467 );
468
469 let len = mono_buf.len();
470 let mut mix_buff = Vec::with_capacity(len * 2);
471
472 for i in 0..len {
473 mix_buff.push(mono_buf[i]);
474 mix_buff.push(stereo_buf[i]);
475 }
476
477 mix_buff
478 }
479
480 async fn write_audio_data(
481 &self,
482 file: &mut File,
483 mono_buf: &PcmBuf,
484 stereo_buf: &PcmBuf,
485 ) -> Result<usize> {
486 let max_len = mono_buf.len().max(stereo_buf.len());
487 if max_len == 0 {
488 return Ok(0);
489 }
490
491 let mix_buff = Self::mix_buffers(mono_buf, stereo_buf);
492
493 file.seek(std::io::SeekFrom::End(0)).await?;
494 file.write_all(&samples_to_bytes(&mix_buff)).await?;
495
496 Ok(max_len)
497 }
498
499 async fn process_buffers(
500 &self,
501 file: &mut File,
502 mono_buf: PcmBuf,
503 stereo_buf: PcmBuf,
504 ) -> Result<()> {
505 if mono_buf.is_empty() && stereo_buf.is_empty() {
506 return Ok(());
507 }
508 let samples_written = self.write_audio_data(file, &mono_buf, &stereo_buf).await?;
509 if samples_written > 0 {
510 self.samples_written
511 .fetch_add(samples_written * 4, Ordering::SeqCst);
512 }
513 Ok(())
514 }
515
516 async fn flush_buffers(&self, file: &mut File) -> Result<()> {
517 loop {
518 let (mono_buf, stereo_buf) = self.pop(usize::MAX).await;
519
520 if mono_buf.is_empty() && stereo_buf.is_empty() {
521 break;
522 }
523
524 let samples_written = self.write_audio_data(file, &mono_buf, &stereo_buf).await?;
525 if samples_written > 0 {
526 self.samples_written
527 .fetch_add(samples_written * 4, Ordering::SeqCst);
528 }
529 }
530
531 Ok(())
532 }
533}