Skip to main content

candle_examples/
audio.rs

1use candle::{Result, Tensor};
2
3// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
4pub fn normalize_loudness(
5    wav: &Tensor,
6    sample_rate: u32,
7    loudness_compressor: bool,
8) -> Result<Tensor> {
9    let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
10    if energy < 2e-3 {
11        return Ok(wav.clone());
12    }
13    let wav_array = wav.to_vec1::<f32>()?;
14    let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
15    meter.push(wav_array.into_iter());
16    let power = meter.as_100ms_windows();
17    let loudness = match crate::bs1770::gated_mean(power) {
18        None => return Ok(wav.clone()),
19        Some(gp) => gp.loudness_lkfs() as f64,
20    };
21    let delta_loudness = -14. - loudness;
22    let gain = 10f64.powf(delta_loudness / 20.);
23    let wav = (wav * gain)?;
24    if loudness_compressor {
25        wav.tanh()
26    } else {
27        Ok(wav)
28    }
29}
30
/// Decodes the first audio track with a known codec from the file at `path` into `f32` PCM.
///
/// Only channel 0 is kept: multi-channel audio is reduced to its first channel, not down-mixed.
/// Samples of every integer/float buffer format are converted to `f32` through symphonia's
/// `FromSample`.
///
/// Returns `(samples, sample_rate)`. The sample rate is taken from the track's codec parameters
/// and falls back to the rate of the first decoded buffer. It is `0` only if neither is known,
/// e.g. when no packet of the track could be decoded.
///
/// # Errors
/// Fails if the file cannot be opened or probed, or if it has no track with a known codec. It
/// also fails if the codec is unsupported, or if reading or decoding fails with anything other
/// than end-of-stream, a track-list reset, or a recoverable corrupt-packet decode error.
#[cfg(feature = "symphonia")]
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};
    use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
    use symphonia::core::conv::FromSample;
    use symphonia::core::errors::Error as SymphoniaError;

    /// Appends channel 0 of `data` to `samples`, converting each sample to `f32`.
    fn conv<T>(
        samples: &mut Vec<f32>,
        data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
    ) where
        T: symphonia::core::sample::Sample,
        f32: symphonia::core::conv::FromSample<T>,
    {
        samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
    }

    let path = path.as_ref();

    // Open the media source and wrap it in a media source stream.
    let src = std::fs::File::open(path).map_err(candle::Error::wrap)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());

    // Give the prober the file extension (when there is one) as a format hint.
    let mut hint = symphonia::core::probe::Hint::new();
    if let Some(ext) = path.extension().and_then(|ext| ext.to_str()) {
        hint.with_extension(ext);
    }

    // Use the default options for metadata and format readers.
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();

    // Probe the media source and take the instantiated format reader.
    let probed = symphonia::default::get_probe()
        .format(&hint, mss, &fmt_opts, &meta_opts)
        .map_err(candle::Error::wrap)?;
    let mut format = probed.format;

    // Find the first audio track with a known (decodable) codec.
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?;

    // Create a decoder for the track using the default decoder options.
    let dec_opts: DecoderOptions = Default::default();
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &dec_opts)
        .map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?;
    let track_id = track.id;
    // May be missing from the codec parameters; filled from the first decoded buffer if so.
    let mut sample_rate = track.codec_params.sample_rate;
    let mut pcm_data = Vec::new();

    // The decode loop.
    loop {
        let packet = match format.next_packet() {
            Ok(packet) => packet,
            // Format readers report the end of the stream as an `UnexpectedEof` I/O error.
            Err(SymphoniaError::IoError(err))
                if err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                break
            }
            // The track list changed (e.g. a chained stream): stop after the selected track.
            Err(SymphoniaError::ResetRequired) => break,
            // Any other failure is a genuine read error; don't return silently truncated audio.
            Err(err) => return Err(candle::Error::wrap(err)),
        };

        // Consume any new metadata that has been read since the last packet.
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }

        // If the packet does not belong to the selected track, skip over it.
        if packet.track_id() != track_id {
            continue;
        }

        let decoded = match decoder.decode(&packet) {
            Ok(decoded) => decoded,
            // A corrupt packet is recoverable: skip it and keep decoding the rest.
            Err(SymphoniaError::DecodeError(_)) => continue,
            Err(err) => return Err(candle::Error::wrap(err)),
        };
        if sample_rate.is_none() {
            sample_rate = Some(decoded.spec().rate);
        }
        match decoded {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate.unwrap_or(0)))
}
110
/// Resamples mono `f32` PCM from `sr_in` Hz to `sr_out` Hz with rubato's synchronous FFT
/// resampler (one channel, requesting 1024-frame input chunks).
///
/// The resampler's latency is compensated. The leading `output_delay()` frames are dropped,
/// and the stream is flushed with silence so the tail of the signal is not lost. The result is
/// therefore time-aligned with the input and holds `ceil(pcm_in.len() * sr_out / sr_in)`
/// samples. Empty input yields an empty output.
///
/// # Errors
/// Returns an error if either sample rate is zero, or if rubato fails to build the resampler or
/// to process a chunk.
#[cfg(feature = "rubato")]
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
    use rubato::Resampler;

    // Reject zero rates before any arithmetic divides by them.
    if sr_in == 0 || sr_out == 0 {
        return Err(candle::Error::Msg(format!(
            "cannot resample from {sr_in}Hz to {sr_out}Hz: sample rates must be non-zero"
        )));
    }
    if pcm_in.is_empty() {
        return Ok(Vec::new());
    }

    // Number of output samples covering the same duration as the input.
    let expected_len =
        (pcm_in.len() as f64 * f64::from(sr_out) / f64::from(sr_in)).ceil() as usize;

    let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
        .map_err(candle::Error::wrap)?;
    // Output frames emitted before the first input sample shows up in the output.
    let delay = resampler.output_delay();
    let target_len = delay + expected_len;
    let mut output_buffer = resampler.output_buffer_allocate(true);
    let mut pcm_out = Vec::with_capacity(target_len + output_buffer[0].len());

    // Process every full input chunk.
    let mut pos_in = 0;
    while pcm_in.len() - pos_in >= resampler.input_frames_next() {
        let (in_len, out_len) = resampler
            .process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
            .map_err(candle::Error::wrap)?;
        pos_in += in_len;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    // Process the leftover frames; rubato zero-pads them to a full chunk.
    if pos_in < pcm_in.len() {
        let (_in_len, out_len) = resampler
            .process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
            .map_err(candle::Error::wrap)?;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    // Flush with silence until the delayed tail of the signal has been emitted.
    while pcm_out.len() < target_len {
        let (_in_len, out_len) = resampler
            .process_partial_into_buffer(None::<&[&[f32]]>, &mut output_buffer, None)
            .map_err(candle::Error::wrap)?;
        if out_len == 0 {
            break;
        }
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    // Drop the warm-up frames and any padding beyond the end of the signal.
    pcm_out.truncate(target_len);
    pcm_out.drain(..delay.min(pcm_out.len()));
    Ok(pcm_out)
}