1use serde::{Deserialize, Serialize};
141use std::path::{Path, PathBuf};
142use thiserror::Error;
143
144
145use ffmpeg_sidecar::{command::FfmpegCommand};
146use hound::WavReader;
147use rubato::{
148 Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
149};
150
/// Every failure the audio helpers in this file can report.
///
/// The `Display` text of each variant is generated by `thiserror` from the
/// `#[error(...)]` attribute attached to it.
#[derive(Debug, Error)]
pub enum AudioError {
    // ---- Generic I/O ----
    /// An underlying [`std::io::Error`]; `#[from]` supplies the `From`
    /// conversion so `?` can be used on I/O results.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    // ---- FFmpeg ----
    /// FFmpeg is not available; the payload is a free-form detail message.
    #[error("FFmpeg not available: {0}")]
    FfmpegNotAvailable(String),
    /// An FFmpeg run did not complete successfully.
    #[error("FFmpeg execution failed: {0}")]
    FfmpegExecution(String),
    /// FFmpeg was configured incorrectly.
    #[error("FFmpeg configuration error: {0}")]
    FfmpegConfig(String),

    // ---- Formats and codecs ----
    /// `format` is not handled; `supported` lists what is.
    #[error("Format not supported: {format}, supported formats: {supported}")]
    FormatNotSupported { format: String, supported: String },
    /// Decoding failed for the given `reason`.
    #[error("Decode failed: {reason}")]
    DecodeError { reason: String },
    /// Encoding failed for the given `reason`.
    #[error("Encode failed: {reason}")]
    EncodeError { reason: String },
    /// The audio file is corrupted or malformed.
    #[error("Audio file corrupted or malformed: {0}")]
    CorruptedFile(String),

    // ---- Parameter validation ----
    /// A sample rate differs from the one that was required.
    #[error("Sample rate mismatch: expected {expected}, got {actual}")]
    SampleRateMismatch { expected: u32, actual: u32 },
    /// A channel count differs from the one that was required.
    #[error("Channel count mismatch: expected {expected}, got {actual}")]
    ChannelMismatch { expected: u16, actual: u16 },
    /// `rate` lies outside the `min`-`max` range reported in the message.
    #[error("Invalid sample rate: {rate}, must be between {min}-{max}")]
    InvalidSampleRate { rate: u32, min: u32, max: u32 },
    /// `channels` lies outside the `min`-`max` range reported in the message.
    #[error("Invalid channel count: {channels}, must be between {min}-{max}")]
    InvalidChannelCount { channels: u16, min: u16, max: u16 },
    /// Some other parameter is invalid.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// A buffer size is not greater than `min`.
    #[error("Invalid buffer size: {size}, must be greater than {min}")]
    InvalidBufferSize { size: usize, min: usize },

    // ---- File system ----
    /// The path does not exist; the payload is the displayed path.
    #[error("File not found: {0}")]
    FileNotFound(String),
    /// The path exists but is not a file; the payload is the displayed path.
    #[error("Path is not a file: {0}")]
    NotAFile(String),
    /// Access to the path was denied.
    #[error("Permission denied: {0}")]
    PermissionDenied(String),
    /// There is not enough disk space.
    #[error("Insufficient disk space: {0}")]
    InsufficientSpace(String),

    // ---- Processing ----
    /// The resampler could not be set up or used.
    #[error("Resampling failed: {0}")]
    ResampleError(String),
    /// A processing step failed.
    #[error("Audio processing failed: {0}")]
    ProcessingError(String),
    /// Memory was exhausted.
    #[error("Out of memory: {0}")]
    OutOfMemory(String),
    /// An operation timed out.
    #[error("Operation timeout: {0}")]
    Timeout(String),

    // ---- Fallback ----
    /// Anything that fits none of the variants above.
    #[error("Unknown error: {0}")]
    Other(String),
}
212
213#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
236pub enum AudioFormat {
237 Wav,
242
243 Mp3,
248
249 Flac,
254
255 M4a,
260
261 Ogg,
266}
267
268impl AudioFormat {
269 pub fn from_extension(ext: &str) -> Option<Self> {
271 match ext.to_lowercase().as_str() {
272 "wav" => Some(AudioFormat::Wav),
273 "mp3" => Some(AudioFormat::Mp3),
274 "flac" => Some(AudioFormat::Flac),
275 "m4a" => Some(AudioFormat::M4a),
276 "ogg" => Some(AudioFormat::Ogg),
277 _ => None,
278 }
279 }
280
281 pub fn extension(&self) -> &'static str {
283 match self {
284 AudioFormat::Wav => "wav",
285 AudioFormat::Mp3 => "mp3",
286 AudioFormat::Flac => "flac",
287 AudioFormat::M4a => "m4a",
288 AudioFormat::Ogg => "ogg",
289 }
290 }
291
292 pub fn is_whisper_native(&self) -> bool {
294 matches!(self, AudioFormat::Wav)
295 }
296}
297
298#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct AudioConfig {
320 pub sample_rate: u32,
329
330 pub channels: u16,
337
338 pub bit_depth: u16,
346}
347
348impl Default for AudioConfig {
349 fn default() -> Self {
350 Self {
351 sample_rate: 16000, channels: 1, bit_depth: 16, }
355 }
356}
357
358impl AudioConfig {
359 pub fn new(sample_rate: u32, channels: u16, bit_depth: u16) -> Self {
361 Self {
362 sample_rate,
363 channels,
364 bit_depth,
365 }
366 }
367
368 pub fn whisper_optimized() -> Self {
370 Self::default()
371 }
372
373 pub fn is_whisper_compatible(&self) -> bool {
375 self.sample_rate == 16000 && self.channels == 1
376 }
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct AudioMeta {
381 pub sample_rate: u32,
383 pub channels: u16,
385 pub duration_ms: Option<u64>,
387 pub format: Option<String>,
389}
390
/// Handle to a WAV file returned by `ensure_whisper_compatible`.
#[derive(Debug, Clone)]
pub struct CompatibleWav {
    /// Location of the WAV file on disk.
    pub path: PathBuf,
}
396
/// Output of `resample`: the samples together with the rate they are now at.
#[derive(Debug, Clone)]
pub struct Resampled {
    /// Samples of the single processed channel.
    pub samples: Vec<f32>,
    /// Sample rate of `samples`, in Hz.
    pub sample_rate: u32,
}
404
405pub fn probe<P: AsRef<std::path::Path>>(input: P) -> Result<AudioMeta, AudioError> {
450 let path = input.as_ref();
451 if !path.exists() {
452 return Err(AudioError::FileNotFound(format!("{}", path.display())));
453 }
454 if path.is_dir() {
455 return Err(AudioError::NotAFile(format!("{}", path.display())));
456 }
457
458 let ext = path
460 .extension()
461 .and_then(|e| e.to_str())
462 .unwrap_or("")
463 .to_lowercase();
464 if ext == "wav" {
465 let reader = WavReader::open(path).map_err(|e| AudioError::DecodeError {
466 reason: format!("打开 WAV 失败: {e}"),
467 })?;
468 let spec = reader.spec();
469 let total_samples = reader.duration();
471 let frames = if spec.channels > 0 {
472 total_samples as u64 / spec.channels as u64
473 } else {
474 0
475 };
476 let duration_ms = if spec.sample_rate > 0 {
477 Some(frames * 1000 / spec.sample_rate as u64)
478 } else {
479 None
480 };
481 return Ok(AudioMeta {
482 sample_rate: spec.sample_rate,
483 channels: spec.channels,
484 duration_ms,
485 format: Some("wav".into()),
486 });
487 } else if !ext.is_empty() {
488 return Err(AudioError::FormatNotSupported {
489 format: ext,
490 supported: "wav".to_string(),
491 });
492 }
493
494 Err(AudioError::FormatNotSupported {
496 format: "unknown".to_string(),
497 supported: "wav, mp3, flac, m4a".to_string(),
498 })
499}
500
501
502pub fn ensure_whisper_compatible<P: AsRef<Path>>(
566 input: P,
567 output: Option<PathBuf>,
568) -> Result<CompatibleWav, AudioError> {
569 let in_path = input.as_ref();
570
571 if !in_path.exists() {
573 return Err(AudioError::FileNotFound(format!("{}", in_path.display())));
574 }
575 if in_path.is_dir() {
576 return Err(AudioError::NotAFile(format!("{}", in_path.display())));
577 }
578
579 let out_path = if let Some(p) = output {
581 p
582 } else {
583 let mut temp = std::env::temp_dir();
584 let file_stem = in_path
585 .file_stem()
586 .and_then(|s| s.to_str())
587 .unwrap_or("audio");
588 temp.push(format!("{file_stem}_mono16k.wav"));
589 temp
590 };
591
592 let filter = "aformat=sample_fmts=s16:channel_layouts=mono:sample_rates=16000";
594
595 let status = FfmpegCommand::new()
596 .input(in_path.to_string_lossy())
597 .args(["-filter:a", filter])
598 .overwrite()
599 .output(out_path.to_string_lossy())
600 .spawn()?
601 .wait()?;
602
603 if !status.success() {
604 return Err(AudioError::FfmpegExecution(
605 "FFmpeg conversion failed".to_string(),
606 ));
607 }
608
609 let reader = WavReader::open(&out_path).map_err(|e| AudioError::DecodeError {
611 reason: format!("Failed to verify output WAV: {e}"),
612 })?;
613 let spec = reader.spec();
614
615 if spec.sample_rate != 16000 {
616 return Err(AudioError::SampleRateMismatch {
617 expected: 16000,
618 actual: spec.sample_rate,
619 });
620 }
621
622 if spec.channels != 1 {
623 return Err(AudioError::ChannelMismatch {
624 expected: 1,
625 actual: spec.channels,
626 });
627 }
628
629 if spec.bits_per_sample != 16 {
630 return Err(AudioError::FormatNotSupported {
631 format: format!("{} bit PCM", spec.bits_per_sample),
632 supported: "16 bit PCM".to_string(),
633 });
634 }
635
636 Ok(CompatibleWav { path: out_path })
637}
638
639
640pub fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Resampled, AudioError> {
641 if from_rate == 0 {
642 return Err(AudioError::InvalidSampleRate {
643 rate: from_rate,
644 min: 1,
645 max: 192000,
646 });
647 }
648 if to_rate == 0 {
649 return Err(AudioError::InvalidSampleRate {
650 rate: to_rate,
651 min: 1,
652 max: 192000,
653 });
654 }
655 if samples.is_empty() || from_rate == to_rate {
656 return Ok(Resampled {
657 samples: samples.to_vec(),
658 sample_rate: to_rate,
659 });
660 }
661
662 let ratio = to_rate as f64 / from_rate as f64;
664
665 let params = SincInterpolationParameters {
667 sinc_len: 256,
668 f_cutoff: 0.95,
669 interpolation: SincInterpolationType::Linear,
670 oversampling_factor: 256,
671 window: WindowFunction::BlackmanHarris2,
672 };
673
674 let mut resampler = SincFixedIn::<f32>::new(
676 ratio,
677 2.0, params,
679 samples.len(),
680 1, )
682 .map_err(|e| AudioError::ResampleError(format!("创建重采样器失败: {e}")))?;
683
684 let input_data = vec![samples.to_vec()];
686
687 let output_data = resampler
689 .process(&input_data, None)
690 .map_err(|e| AudioError::ProcessingError(format!("重采样失败: {e}")))?;
691
692 let output_samples = output_data
694 .into_iter()
695 .next()
696 .ok_or_else(|| AudioError::ProcessingError("重采样输出为空".into()))?;
697
698 Ok(Resampled {
699 samples: output_samples,
700 sample_rate: to_rate,
701 })
702}
703
704pub struct StreamingResampler {
707 resampler: Option<SincFixedIn<f32>>,
709 from_rate: u32,
711 to_rate: u32,
713 buffer: Vec<f32>,
715 chunk_size: usize,
717}
718
719impl StreamingResampler {
720 pub fn new(from_rate: u32, to_rate: u32) -> Result<Self, AudioError> {
722 if from_rate == 0 {
723 return Err(AudioError::InvalidSampleRate {
724 rate: from_rate,
725 min: 1,
726 max: 192000,
727 });
728 }
729 if to_rate == 0 {
730 return Err(AudioError::InvalidSampleRate {
731 rate: to_rate,
732 min: 1,
733 max: 192000,
734 });
735 }
736
737 let chunk_size = 1024;
738
739 if from_rate == to_rate {
740 return Ok(Self {
742 resampler: None,
743 from_rate,
744 to_rate,
745 buffer: Vec::new(),
746 chunk_size,
747 });
748 }
749
750 let ratio = to_rate as f64 / from_rate as f64;
751
752 let params = SincInterpolationParameters {
754 sinc_len: 256,
755 f_cutoff: 0.95,
756 interpolation: SincInterpolationType::Linear,
757 oversampling_factor: 256,
758 window: WindowFunction::BlackmanHarris2,
759 };
760
761 let resampler = SincFixedIn::<f32>::new(
763 ratio, 2.0, params, chunk_size, 1, )
767 .map_err(|e| AudioError::ResampleError(format!("创建重采样器失败: {e}")))?;
768
769 Ok(Self {
770 resampler: Some(resampler),
771 from_rate,
772 to_rate,
773 buffer: Vec::new(),
774 chunk_size,
775 })
776 }
777
778 pub fn process_chunk(&mut self, input: &[f32]) -> Result<Vec<f32>, AudioError> {
780 if input.is_empty() {
781 return Ok(Vec::new());
782 }
783
784 if self.from_rate == self.to_rate {
785 return Ok(input.to_vec());
786 }
787
788 let resampler = self
789 .resampler
790 .as_mut()
791 .ok_or_else(|| AudioError::ProcessingError("重采样器未初始化".into()))?;
792
793 self.buffer.extend_from_slice(input);
795
796 let mut output = Vec::new();
797
798 while self.buffer.len() >= self.chunk_size {
800 let chunk: Vec<f32> = self.buffer.drain(0..self.chunk_size).collect();
802
803 let input_data = vec![chunk];
805
806 let output_data = resampler
808 .process(&input_data, None)
809 .map_err(|e| AudioError::ProcessingError(format!("重采样失败: {e}")))?;
810
811 if let Some(channel_output) = output_data.into_iter().next() {
813 output.extend(channel_output);
814 }
815 }
816
817 Ok(output)
818 }
819
820 pub fn finalize(&mut self) -> Result<Vec<f32>, AudioError> {
822 if self.from_rate == self.to_rate {
823 let remaining = self.buffer.clone();
825 self.buffer.clear();
826 return Ok(remaining);
827 }
828
829 if let Some(resampler) = self.resampler.as_mut() {
830 let mut output = Vec::new();
831
832 if !self.buffer.is_empty() {
834 let mut padded_buffer = self.buffer.clone();
836 padded_buffer.resize(self.chunk_size, 0.0);
837
838 let input_data = vec![padded_buffer];
839 let output_data = resampler
840 .process(&input_data, None)
841 .map_err(|e| AudioError::ProcessingError(format!("处理剩余样本失败: {e}")))?;
842
843 if let Some(channel_output) = output_data.into_iter().next() {
844 output.extend(channel_output);
845 }
846
847 self.buffer.clear();
848 }
849
850 let empty_input: Option<&[Vec<f32>]> = None;
852 let final_output = resampler
853 .process_partial(empty_input, None)
854 .map_err(|e| AudioError::ProcessingError(format!("完成流式重采样失败: {e}")))?;
855
856 if let Some(channel_output) = final_output.into_iter().next() {
857 output.extend(channel_output);
858 }
859
860 Ok(output)
861 } else {
862 Ok(Vec::new())
863 }
864 }
865}
866
#[cfg(test)]
mod tests {
    use super::*;
    use hound::WavReader;

    /// Probing a path that does not exist must report `FileNotFound`.
    #[test]
    fn test_probe_stub() {
        let err = probe("/tmp/nonexist.wav").expect_err("应返回错误");
        match err {
            AudioError::FileNotFound(_) => {}
            _ => panic!("应为 FileNotFound 错误"),
        }
    }

    /// Halving the rate yields a non-empty output of plausible length.
    #[test]
    fn test_resample_ratio() {
        let input: Vec<f32> = (0..160).map(|i| (i as f32).sin()).collect();
        let out = resample(&input, 16000, 8000).unwrap();
        assert_eq!(out.sample_rate, 8000);
        assert!(!out.samples.is_empty(), "Resampled output should not be empty");

        let ratio = 8000.0 / 16000.0;
        let expected_min = (input.len() as f64 * ratio * 0.1) as usize;
        let expected_max = (input.len() as f64 * ratio * 2.0) as usize;
        assert!(
            out.samples.len() >= expected_min && out.samples.len() <= expected_max,
            "Output length {} not in expected range [{}, {}]",
            out.samples.len(),
            expected_min,
            expected_max
        );
    }

    /// A 440 Hz tone keeps a zero-crossing count in a (loose) expected range.
    #[test]
    fn test_resample_quality() {
        let sample_rate = 16000;
        let freq = 440.0;
        let duration = 1.0;
        let num_samples = (sample_rate as f64 * duration) as usize;
        let input: Vec<f32> = (0..num_samples)
            .map(|i| (2.0 * std::f32::consts::PI * i as f32 * freq / sample_rate as f32).sin())
            .collect();

        let out = resample(&input, sample_rate as u32, 8000).unwrap();
        assert_eq!(out.sample_rate, 8000);

        let mut zero_crossings = 0;
        for i in 1..out.samples.len() {
            if out.samples[i - 1] * out.samples[i] <= 0.0 {
                zero_crossings += 1;
            }
        }

        log::debug!("Zero crossings: {zero_crossings}, expected around 440");
        assert!((zero_crossings as f64 - 440.0).abs() < 500.0);
    }

    /// End-to-end conversion of the fixture; skipped when the fixture is absent.
    #[test]
    fn test_ensure_whisper_compatible_on_fixture() {
        let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let root_dir = crate_dir.parent().expect("audio crate has parent");
        let input = root_dir.join("fixtures/audio/jfk.wav");
        if !input.exists() {
            log::warn!("Skipping: missing test audio {}", input.display());
            return;
        }

        let out = ensure_whisper_compatible(&input, None).expect("Conversion should succeed");
        assert!(out.path.exists(), "Output file should exist");

        let reader = WavReader::open(&out.path).expect("Should be able to open output WAV");
        let spec = reader.spec();
        assert_eq!(spec.sample_rate, 16000);
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.bits_per_sample, 16);

        // Best-effort cleanup of the generated file.
        let _ = std::fs::remove_file(&out.path);
    }

    /// Probing the fixture returns sane WAV metadata; skipped when absent.
    #[test]
    fn test_probe_wav_on_fixture() {
        let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let root_dir = crate_dir.parent().expect("audio crate has parent");
        let input = root_dir.join("fixtures/audio/jfk.wav");
        if !input.exists() {
            log::warn!("跳过: 缺少测试音频 {}", input.display());
            return;
        }
        let meta = probe(&input).expect("应能探测 WAV 元数据");
        assert_eq!(meta.format.as_deref(), Some("wav"));
        assert_eq!(meta.channels, 1);
        assert!(meta.sample_rate > 0);
        assert!(meta.duration_ms.unwrap_or(0) > 0);
    }

    /// Missing inputs and directories are rejected before FFmpeg is involved.
    #[test]
    fn test_ensure_whisper_compatible_errors() {
        let missing = std::path::PathBuf::from("/tmp/__definitely_missing_audio__.wav");
        let err = ensure_whisper_compatible(&missing, None).expect_err("Should return error");

        match err {
            AudioError::FileNotFound(_) | AudioError::FfmpegNotAvailable(_) => {}
            _ => panic!("Should be FileNotFound or FfmpegNotAvailable error"),
        }

        let crate_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let err2 = ensure_whisper_compatible(&crate_dir, None).expect_err("Should return error");

        match err2 {
            AudioError::NotAFile(_) | AudioError::FfmpegNotAvailable(_) => {}
            _ => panic!("Should be NotAFile or FfmpegNotAvailable error"),
        }
    }

    /// A zero rate on either side is rejected with `InvalidSampleRate`.
    #[test]
    fn test_resample_invalid_rate() {
        let input: Vec<f32> = vec![0.0, 1.0, 0.0];
        let err = resample(&input, 0, 16000).expect_err("应返回错误");
        match err {
            AudioError::InvalidSampleRate { .. } => {}
            _ => panic!("应为 InvalidSampleRate 错误"),
        }
        let err2 = resample(&input, 16000, 0).expect_err("应返回错误");
        match err2 {
            AudioError::InvalidSampleRate { .. } => {}
            _ => panic!("应为 InvalidSampleRate 错误"),
        }
    }

    /// Streaming upsampling stays close to the one-shot result.
    #[test]
    fn test_streaming_resampler_upsample_matches_batch() {
        let from = 16000u32;
        let to = 32000u32;
        let input: Vec<f32> = (0..1000).map(|i| i as f32 / 1000.0).collect();

        let batch = resample(&input, from, to).unwrap().samples;

        let mut sr = StreamingResampler::new(from, to).unwrap();
        let mut stream_out = Vec::new();
        for chunk in input.chunks(123) {
            let y = sr.process_chunk(chunk).unwrap();
            stream_out.extend(y);
        }
        stream_out.extend(sr.finalize().unwrap());

        let diff = (batch.len() as isize - stream_out.len() as isize).abs();
        log::debug!(
            "Length difference: {}, batch: {}, stream: {}",
            diff,
            batch.len(),
            stream_out.len()
        );
        assert!(diff <= 2500);

        let n = batch.len().min(stream_out.len());
        let mut mse = 0.0f64;
        for i in 0..n {
            let d = batch[i] - stream_out[i];
            mse += (d as f64).powi(2);
        }
        mse /= n.max(1) as f64;
        assert!(mse < 1e-6, "MSE too large: {mse}");
    }

    /// Streaming downsampling produces roughly as many samples as the one-shot path.
    #[test]
    fn test_streaming_resampler_downsample_length() {
        let from = 16000u32;
        let to = 8000u32;
        let input: Vec<f32> = (0..4000).map(|i| ((i as f32) * 0.01).sin()).collect();

        let batch = resample(&input, from, to).unwrap().samples;

        let mut sr = StreamingResampler::new(from, to).unwrap();
        let mut stream_out: Vec<f32> = Vec::new();
        for chunk in input.chunks(777) {
            // Unwrap the `Result`: extending with the `Result` itself would
            // collect whole vectors (and silently drop errors), so `len()`
            // would count successful calls rather than samples.
            stream_out.extend(sr.process_chunk(chunk).unwrap());
        }
        stream_out.extend(sr.finalize().unwrap());

        let diff = (batch.len() as isize - stream_out.len() as isize).abs();
        log::debug!(
            "Length difference: {}, batch: {}, stream: {}",
            diff,
            batch.len(),
            stream_out.len()
        );
        assert!(diff <= 2000);
    }

    /// Very high, above-limit, low and identical rates are all accepted.
    #[test]
    fn test_extreme_sample_rates() {
        let input: Vec<f32> = vec![0.0, 1.0, 0.0, -1.0];

        let result = resample(&input, 192000, 16000);
        assert!(result.is_ok(), "192kHz 到 16kHz 重采样应该成功");

        let result = resample(&input, 200000, 16000);
        assert!(
            result.is_ok(),
            "200kHz 到 16kHz 重采样应该成功(虽然超过文档上限但实际可能工作)"
        );

        let result = resample(&input, 8000, 16000);
        assert!(result.is_ok(), "8kHz 到 16kHz 重采样应该成功");

        let result = resample(&input, 16000, 16000);
        assert!(result.is_ok(), "16kHz 到 16kHz 重采样应该成功");
        assert_eq!(result.unwrap().samples, input, "相同采样率应该返回原始样本");
    }

    /// Smoke test for downsampling, upsampling and the pass-through path.
    #[test]
    fn test_basic_resampling_functionality() {
        let input: Vec<f32> = (0..1000).map(|i| (i as f32 * 0.01).sin()).collect();

        let result = resample(&input, 16000, 8000);
        assert!(result.is_ok(), "降采样应该成功");
        let downsampled = result.unwrap();
        assert!(!downsampled.samples.is_empty(), "降采样应该产生非空输出");
        assert_eq!(downsampled.sample_rate, 8000, "输出采样率应该正确");

        let result = resample(&input, 8000, 16000);
        assert!(result.is_ok(), "升采样应该成功");
        let upsampled = result.unwrap();
        assert!(!upsampled.samples.is_empty(), "升采样应该产生非空输出");
        assert_eq!(upsampled.sample_rate, 16000, "输出采样率应该正确");

        let result = resample(&input, 16000, 16000);
        assert!(result.is_ok(), "相同采样率重采样应该成功");
        let same_rate = result.unwrap();
        assert_eq!(same_rate.samples, input, "相同采样率应该返回原始样本");
        assert_eq!(same_rate.sample_rate, 16000, "输出采样率应该正确");

        log::info!(
            "基本重采样功能测试通过 - 降采样: {} -> {} 样本, 升采样: {} -> {} 样本",
            input.len(),
            downsampled.samples.len(),
            input.len(),
            upsampled.samples.len()
        );
    }
}