// libflo_audio/lossless/decoder.rs

1use crate::core::audio_constants::i32_to_f32;
2use crate::core::types::{ChannelData, FloFile};
3use crate::{core::rice, FloResult, Reader};
4
/// Audio decoder for the flo format.
///
/// The decoder is a unit struct with no fields, so an instance can be created,
/// copied and reused freely; every `decode*` call is self-contained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Decoder;
7
8impl Decoder {
9    pub fn new() -> Self {
10        Decoder
11    }
12
13    /// decode flo file to samples
14    pub fn decode(&self, data: &[u8]) -> FloResult<Vec<f32>> {
15        let reader = Reader::new();
16        let file = reader.read(data)?;
17        self.decode_file(&file)
18    }
19
20    /// decode from parsed file
21    pub fn decode_file(&self, file: &FloFile) -> FloResult<Vec<f32>> {
22        let channels = file.header.channels as usize;
23        let mut all_samples: Vec<Vec<i32>> = vec![vec![]; channels];
24
25        for frame in &file.frames {
26            let use_mid_side = channels == 2 && (frame.flags & 0x01) != 0;
27
28            let mut frame_channels: Vec<Vec<i32>> = Vec::with_capacity(channels);
29
30            for ch_data in &frame.channels {
31                let samples = self.decode_channel_int(ch_data, frame.frame_samples as usize)?;
32                frame_channels.push(samples);
33            }
34
35            // mid-side to left-right
36            if use_mid_side && frame_channels.len() == 2 {
37                let (left, right) = self.decode_mid_side(&frame_channels[0], &frame_channels[1]);
38                all_samples[0].extend(left);
39                all_samples[1].extend(right);
40            } else {
41                for (ch_idx, samples) in frame_channels.into_iter().enumerate() {
42                    if ch_idx < channels {
43                        all_samples[ch_idx].extend(samples);
44                    }
45                }
46            }
47        }
48
49        // interleave and convert to f32
50        let max_len = all_samples.iter().map(|v| v.len()).max().unwrap_or(0);
51        let mut interleaved = Vec::with_capacity(max_len * channels);
52
53        // Fast path for stereo (most common case)
54        if channels == 2 && all_samples[0].len() == all_samples[1].len() {
55            let left = &all_samples[0];
56            let right = &all_samples[1];
57            for i in 0..left.len() {
58                interleaved.push(i32_to_f32(left[i]));
59                interleaved.push(i32_to_f32(right[i]));
60            }
61        } else {
62            // General case for mono or mismatched lengths
63            for i in 0..max_len {
64                for ch in 0..channels {
65                    let sample = all_samples[ch].get(i).copied().unwrap_or(0);
66                    interleaved.push(i32_to_f32(sample));
67                }
68            }
69        }
70
71        Ok(interleaved)
72    }
73
74    /// Convert mid-side back to left-right
75    fn decode_mid_side(&self, mid: &[i32], side: &[i32]) -> (Vec<i32>, Vec<i32>) {
76        // FLAC-style: mid = L + R, side = L - R
77        // So: L = (mid + side) / 2, R = (mid - side) / 2
78        let left: Vec<i32> = mid
79            .iter()
80            .zip(side.iter())
81            .map(|(&m, &s)| (m + s) / 2)
82            .collect();
83        let right: Vec<i32> = mid
84            .iter()
85            .zip(side.iter())
86            .map(|(&m, &s)| (m - s) / 2)
87            .collect();
88        (left, right)
89    }
90
91    /// Decode a single channel to integers
92    fn decode_channel_int(
93        &self,
94        ch_data: &ChannelData,
95        frame_samples: usize,
96    ) -> FloResult<Vec<i32>> {
97        let has_coeffs = !ch_data.predictor_coeffs.is_empty();
98        let has_residuals = !ch_data.residuals.is_empty();
99        let shift_bits = ch_data.shift_bits;
100
101        // Check for fixed predictor marker: shift_bits >= 128 means fixed order (128 + order)
102        let is_fixed_predictor = !has_coeffs && has_residuals && shift_bits >= 128;
103
104        if is_fixed_predictor {
105            // Fixed predictor: order stored as (128 + order)
106            let fixed_order = (shift_bits - 128) as usize;
107
108            let residuals =
109                rice::decode_i32(&ch_data.residuals, ch_data.rice_parameter, frame_samples);
110
111            return Ok(self.reconstruct_fixed(fixed_order, &residuals, frame_samples));
112        }
113
114        if has_coeffs {
115            // LPC decoding with stored coefficients
116            let residuals =
117                rice::decode_i32(&ch_data.residuals, ch_data.rice_parameter, frame_samples);
118
119            let order = ch_data.predictor_coeffs.len();
120
121            let samples = self.reconstruct_lpc_int(
122                &ch_data.predictor_coeffs,
123                &residuals,
124                shift_bits,
125                order,
126                frame_samples,
127            );
128
129            return Ok(samples);
130        }
131
132        if has_residuals {
133            // Raw PCM
134            let mut samples = Vec::with_capacity(frame_samples);
135            for chunk in ch_data.residuals.chunks(2) {
136                if chunk.len() == 2 {
137                    samples.push(i16::from_le_bytes([chunk[0], chunk[1]]) as i32);
138                }
139            }
140            while samples.len() < frame_samples {
141                samples.push(0);
142            }
143            return Ok(samples);
144        }
145
146        // Silence
147        Ok(vec![0; frame_samples])
148    }
149
150    /// Reconstruct from LPC prediction
151    /// Optimized version with branch-free inner loop
152    #[inline]
153    fn reconstruct_lpc_int(
154        &self,
155        coeffs: &[i32],
156        residuals: &[i32],
157        shift: u8,
158        order: usize,
159        target_len: usize,
160    ) -> Vec<i32> {
161        let mut samples = Vec::with_capacity(target_len);
162        let actual_len = target_len.min(residuals.len());
163
164        // Warmup samples from residuals (no prediction needed)
165        let warmup_len = order.min(actual_len);
166        samples.extend_from_slice(&residuals[..warmup_len]);
167
168        // Reconstruct remaining samples using LPC prediction
169        // This is the hot loop - keep it simple and predictable
170        for i in order..actual_len {
171            let mut prediction: i64 = 0;
172
173            // Unrolled inner loop for common orders
174            // Access pattern: samples[i-1], samples[i-2], ..., samples[i-order]
175            for j in 0..order {
176                prediction += (coeffs[j] as i64) * (samples[i - j - 1] as i64);
177            }
178
179            samples.push((prediction >> shift) as i32 + residuals[i]);
180        }
181
182        // Pad if needed
183        samples.resize(target_len, 0);
184        samples
185    }
186
187    /// Reconstruct from fixed predictor
188    fn reconstruct_fixed(&self, order: usize, residuals: &[i32], target_len: usize) -> Vec<i32> {
189        let mut samples = Vec::with_capacity(target_len);
190
191        if residuals.is_empty() {
192            return vec![0; target_len];
193        }
194
195        match order {
196            0 => {
197                // No prediction - residuals are samples
198                samples.extend_from_slice(residuals);
199            }
200            1 => {
201                // s[i] = r[i] + s[i-1]
202                samples.push(residuals[0]);
203                for i in 1..residuals.len().min(target_len) {
204                    samples.push(residuals[i].wrapping_add(samples[i - 1]));
205                }
206            }
207            2 => {
208                // s[i] = r[i] + 2*s[i-1] - s[i-2]
209                if !residuals.is_empty() {
210                    samples.push(residuals[0]);
211                }
212                if residuals.len() > 1 {
213                    samples.push(residuals[1].wrapping_add(samples[0]));
214                }
215                for i in 2..residuals.len().min(target_len) {
216                    let pred = (2i64 * samples[i - 1] as i64 - samples[i - 2] as i64) as i32;
217                    samples.push(residuals[i].wrapping_add(pred));
218                }
219            }
220            3 => {
221                // s[i] = r[i] + 3*s[i-1] - 3*s[i-2] + s[i-3]
222                if !residuals.is_empty() {
223                    samples.push(residuals[0]);
224                }
225                if residuals.len() > 1 {
226                    samples.push(residuals[1].wrapping_add(samples[0]));
227                }
228                if residuals.len() > 2 {
229                    let pred = (2i64 * samples[1] as i64 - samples[0] as i64) as i32;
230                    samples.push(residuals[2].wrapping_add(pred));
231                }
232                for i in 3..residuals.len().min(target_len) {
233                    let pred = (3i64 * samples[i - 1] as i64 - 3i64 * samples[i - 2] as i64
234                        + samples[i - 3] as i64) as i32;
235                    samples.push(residuals[i].wrapping_add(pred));
236                }
237            }
238            4 => {
239                // s[i] = r[i] + 4*s[i-1] - 6*s[i-2] + 4*s[i-3] - s[i-4]
240                if !residuals.is_empty() {
241                    samples.push(residuals[0]);
242                }
243                if residuals.len() > 1 {
244                    samples.push(residuals[1].wrapping_add(samples[0]));
245                }
246                if residuals.len() > 2 {
247                    let pred = (2i64 * samples[1] as i64 - samples[0] as i64) as i32;
248                    samples.push(residuals[2].wrapping_add(pred));
249                }
250                if residuals.len() > 3 {
251                    let pred = (3i64 * samples[2] as i64 - 3i64 * samples[1] as i64
252                        + samples[0] as i64) as i32;
253                    samples.push(residuals[3].wrapping_add(pred));
254                }
255                for i in 4..residuals.len().min(target_len) {
256                    let pred = (4i64 * samples[i - 1] as i64 - 6i64 * samples[i - 2] as i64
257                        + 4i64 * samples[i - 3] as i64
258                        - samples[i - 4] as i64) as i32;
259                    samples.push(residuals[i].wrapping_add(pred));
260                }
261            }
262            _ => {
263                // Unknown order, just use residuals
264                samples.extend_from_slice(residuals);
265            }
266        }
267
268        // Pad if needed
269        while samples.len() < target_len {
270            samples.push(0);
271        }
272
273        samples
274    }
275}
276
277impl Default for Decoder {
278    fn default() -> Self {
279        Self::new()
280    }
281}