audio_loudness_batch_normalize/
save.rs

1use std::{fs::File, num::NonZero, path::Path};
2
3use log::warn;
4use symphonia::core::{audio::SignalSpec, errors::Error as SymphoniaError};
5use vorbis_rs::VorbisEncoderBuilder;
6
7use crate::{
8    convert_buffer_to_planar_f32,
9    error::{Error, WritingError},
10};
11
12/// Helper for `decode_apply_gain_and_save`: collects all processed samples and saves them to a WAV file.
13pub fn stream_to_wav_writer(
14    format: &mut dyn symphonia::core::formats::FormatReader,
15    decoder: &mut dyn symphonia::core::codecs::Decoder,
16    final_linear_gain: f64,
17    output_path: &Path,
18    spec: SignalSpec,
19) -> Result<(), Error> {
20    // First, collect all the processed samples into a buffer.
21    let samples = collect_gained_samples(format, decoder, final_linear_gain, output_path)?;
22
23    // Define the specification for the output WAV file.
24    let hound_spec = hound::WavSpec {
25        channels: spec.channels.count() as u16,
26        sample_rate: spec.rate,
27        bits_per_sample: 32,
28        sample_format: hound::SampleFormat::Float,
29    };
30
31    // Write the collected samples to the WAV file in one go.
32    save_as_wav(output_path, hound_spec, &samples).map_err(|e| Error::Writing {
33        path: output_path.to_path_buf(),
34        source: e,
35    })
36}
37
38/// Saves audio data as a WAV file.
39///
40/// # Arguments
41/// * `path` - Output file path
42/// * `spec` - The WAV specification (channels, sample rate, etc.)
43/// * `samples` - Interleaved audio samples in 32-bit float format
44///
45/// # Returns
46/// Result indicating success or a WritingError
47pub fn save_as_wav(path: &Path, spec: hound::WavSpec, samples: &[f32]) -> Result<(), WritingError> {
48    let mut writer = hound::WavWriter::create(path, spec).map_err(WritingError::Wav)?;
49    for &sample in samples {
50        writer.write_sample(sample).map_err(WritingError::Wav)?;
51    }
52    writer.finalize().map_err(WritingError::Wav)
53}
54
55pub fn stream_to_ogg_writer(
56    format: &mut dyn symphonia::core::formats::FormatReader,
57    decoder: &mut dyn symphonia::core::codecs::Decoder,
58    final_linear_gain: f64,
59    output_path: &Path,
60    spec: SignalSpec,
61) -> Result<(), Error> {
62    // OGG saving still requires collecting all samples due to vorbis-rs API.
63    let samples =
64        collect_gained_samples(&mut *format, &mut *decoder, final_linear_gain, output_path)?;
65    save_as_ogg(output_path, spec.channels.count(), spec.rate, &samples).map_err(|e| {
66        Error::Writing {
67            path: output_path.to_path_buf(),
68            source: e,
69        }
70    })?;
71    Ok(())
72}
73
74/// Saves audio data as an Ogg Vorbis file
75///
76/// # Arguments
77/// * `path` - Output file path
78/// * `channels` - Number of audio channels
79/// * `sample_rate` - Sample rate in Hz
80/// * `samples` - Interleaved audio samples in 32-bit float format
81///
82/// # Returns
83/// Result indicating success or a WritingError
84///
85/// # Panics
86/// Panics if the number of channels is zero
87pub fn save_as_ogg(
88    path: &Path,
89    channels: usize,
90    sample_rate: u32,
91    samples: &[f32],
92) -> Result<(), WritingError> {
93    assert!(channels > 0, "channels could not be zero");
94    // Open the output file
95    let output_file = File::create(path)?;
96
97    // Initialize the Vorbis encoder
98    let mut encoder = VorbisEncoderBuilder::new(
99        NonZero::new(sample_rate).unwrap(),
100        NonZero::new(channels as u8).unwrap(),
101        output_file,
102    )?
103    .build()?;
104
105    // Convert interleaved samples to planar format
106    let mut planar_samples: Vec<Vec<f32>> = vec![Vec::new(); channels];
107    for (i, &sample) in samples.iter().enumerate() {
108        let channel = i % channels;
109        planar_samples[channel].push(sample);
110        if channel * sample_rate as usize * 2 == i + 1 {
111            encoder.encode_audio_block(&planar_samples)?;
112            planar_samples = vec![Vec::new(); channels];
113        }
114    }
115    encoder.encode_audio_block(&planar_samples)?;
116    encoder.finish()?;
117    Ok(())
118}
119
120/// Helper for `decode_apply_gain_and_save`: collects all processed samples into
121/// a Vec. Used for formats that don't support streaming writes with the current
122/// libraries (e.g., Ogg).
123fn collect_gained_samples(
124    format: &mut dyn symphonia::core::formats::FormatReader,
125    decoder: &mut dyn symphonia::core::codecs::Decoder,
126    final_linear_gain: f64,
127    output_path: &Path, // for error context
128) -> Result<Vec<f32>, Error> {
129    let mut all_samples_interleaved = Vec::new();
130    loop {
131        match format.next_packet() {
132            Ok(packet) => match decoder.decode(&packet) {
133                Ok(decoded) => {
134                    let planar =
135                        convert_buffer_to_planar_f32(&decoded).map_err(|e| Error::Processing {
136                            path: output_path.to_path_buf(),
137                            source: e,
138                        })?;
139
140                    if planar.is_empty() || planar[0].is_empty() {
141                        continue;
142                    }
143
144                    let num_frames = planar[0].len();
145                    let num_channels = planar.len();
146
147                    for frame_idx in 0..num_frames {
148                        for channel_idx in 0..num_channels {
149                            let sample = planar[channel_idx][frame_idx];
150                            let processed_sample = sample * final_linear_gain as f32;
151                            all_samples_interleaved.push(processed_sample);
152                        }
153                    }
154                }
155                Err(SymphoniaError::DecodeError(e)) => {
156                    warn!("Decode error during sample collection: {e}")
157                }
158                Err(e) => {
159                    return Err(Error::Processing {
160                        path: output_path.to_path_buf(),
161                        source: e.into(),
162                    });
163                }
164            },
165            Err(SymphoniaError::IoError(ref e))
166                if e.kind() == std::io::ErrorKind::UnexpectedEof =>
167            {
168                break;
169            }
170            Err(e) => {
171                return Err(Error::Processing {
172                    path: output_path.to_path_buf(),
173                    source: e.into(),
174                });
175            }
176        }
177    }
178    Ok(all_samples_interleaved)
179}