// libflo_audio/lossless/encoder.rs

1use crate::core::audio_constants::f32_to_i32;
2use crate::core::{ChannelData, Frame, FrameType, ResidualEncoding};
3use crate::{core::rice, FloResult, Writer};
4
5use super::lpc::{
6    autocorr_int, calc_residuals_int, fixed_predictor_residuals, levinson_durbin_int,
7};
8
/// Configuration for the lossless flo encoder.
///
/// Values are created with `Encoder::new` and can be tuned with
/// `Encoder::with_compression`.
#[derive(Debug, Clone)]
pub struct Encoder {
    /// Sample rate in Hz; the encoder also uses it as the number of samples
    /// per channel in one frame.
    sample_rate: u32,
    /// Number of interleaved channels in the input samples.
    channels: u8,
    /// Bit depth handed to the container writer.
    bit_depth: u8,
    /// Compression level; `with_compression` clamps it to at most 9.
    compression_level: u8,
}
15
16impl Encoder {
17    pub fn new(sample_rate: u32, channels: u8, bit_depth: u8) -> Self {
18        Encoder {
19            sample_rate,
20            channels,
21            bit_depth,
22            compression_level: 5,
23        }
24    }
25
26    pub fn with_compression(mut self, level: u8) -> Self {
27        self.compression_level = level.min(9);
28        self
29    }
30
31    /// encode samples to flo format
32    pub fn encode(&self, samples: &[f32], metadata: &[u8]) -> FloResult<Vec<u8>> {
33        let samples_per_frame = self.sample_rate as usize;
34        let frames = self.encode_frames(samples, samples_per_frame);
35
36        let writer = Writer::new();
37        writer.write(
38            self.sample_rate,
39            self.channels,
40            self.bit_depth,
41            self.compression_level,
42            &frames,
43            metadata,
44        )
45    }
46
47    fn encode_frames(&self, samples: &[f32], samples_per_frame: usize) -> Vec<Frame> {
48        let total_samples = samples.len() / self.channels as usize;
49        let num_frames = total_samples.div_ceil(samples_per_frame);
50
51        let mut frames = Vec::with_capacity(num_frames);
52
53        for frame_idx in 0..num_frames {
54            let start = frame_idx * samples_per_frame * self.channels as usize;
55            let end =
56                ((frame_idx + 1) * samples_per_frame * self.channels as usize).min(samples.len());
57
58            let frame_samples = &samples[start..end];
59            let frame = self.encode_frame(frame_samples);
60            frames.push(frame);
61        }
62
63        frames
64    }
65
66    fn encode_frame(&self, samples: &[f32]) -> Frame {
67        let num_samples = samples.len() / self.channels as usize;
68
69        // Check for silence
70        if samples.iter().all(|&s| s.abs() < 1e-7) {
71            let mut frame = Frame::new(FrameType::Silence as u8, num_samples as u32);
72            for _ in 0..self.channels {
73                frame.channels.push(ChannelData::new_silence());
74            }
75            return frame;
76        }
77
78        // Convert to integer domain
79        let samples_i32: Vec<i32> = samples.iter().map(|&s| f32_to_i32(s)).collect();
80
81        // Deinterleave channels
82        let mut channel_data: Vec<Vec<i32>> = (0..self.channels as usize)
83            .map(|ch| {
84                samples_i32
85                    .iter()
86                    .skip(ch)
87                    .step_by(self.channels as usize)
88                    .copied()
89                    .collect()
90            })
91            .collect();
92
93        // Apply mid-side coding for stereo (if it helps)
94        let use_mid_side = self.channels == 2 && self.should_use_mid_side(&channel_data);
95        if use_mid_side {
96            let (mid, side) = self.to_mid_side(&channel_data[0], &channel_data[1]);
97            channel_data[0] = mid;
98            channel_data[1] = side;
99        }
100
101        // Encode each channel
102        let lpc_order = self.lpc_order_from_level();
103        let mut encoded_channels = Vec::with_capacity(self.channels as usize);
104        let mut all_raw = true;
105
106        for ch_samples in &channel_data {
107            let (ch_data, order_used) = self.encode_channel_int(ch_samples, lpc_order);
108            if order_used > 0 {
109                all_raw = false;
110            }
111            encoded_channels.push(ch_data);
112        }
113
114        // Determine frame type
115        let frame_type = if all_raw {
116            FrameType::Raw
117        } else {
118            FrameType::from_order(lpc_order)
119        };
120
121        let mut frame = Frame::new(frame_type as u8, num_samples as u32);
122        // Set mid-side flag if used
123        if use_mid_side {
124            frame.flags |= 0x01; // Bit 0 = mid-side coding
125        }
126        frame.channels = encoded_channels;
127        frame
128    }
129
130    /// Check if mid-side coding would help
131    fn should_use_mid_side(&self, channels: &[Vec<i32>]) -> bool {
132        if channels.len() != 2 {
133            return false;
134        }
135
136        let left = &channels[0];
137        let right = &channels[1];
138
139        // Calculate variance of L-R vs L and R separately
140        let mut var_l: i64 = 0;
141        let mut var_r: i64 = 0;
142        let mut var_side: i64 = 0;
143
144        for (&l, &r) in left.iter().zip(right.iter()) {
145            var_l += (l as i64) * (l as i64);
146            var_r += (r as i64) * (r as i64);
147            let side = l - r;
148            var_side += (side as i64) * (side as i64);
149        }
150
151        // If side channel has less energy, mid-side helps
152        var_side < (var_l + var_r) / 2
153    }
154
155    /// Convert stereo to mid-side
156    fn to_mid_side(&self, left: &[i32], right: &[i32]) -> (Vec<i32>, Vec<i32>) {
157        // FLAC-style: mid = L + R, side = L - R
158        // This preserves all bits - no rounding
159        let mid: Vec<i32> = left
160            .iter()
161            .zip(right.iter())
162            .map(|(&l, &r)| l + r)
163            .collect();
164        let side: Vec<i32> = left
165            .iter()
166            .zip(right.iter())
167            .map(|(&l, &r)| l - r)
168            .collect();
169        (mid, side)
170    }
171
172    /// Encode a single channel using integer LPC
173    fn encode_channel_int(&self, samples: &[i32], max_order: usize) -> (ChannelData, usize) {
174        if samples.is_empty() {
175            return (ChannelData::new_silence(), 0);
176        }
177
178        // Try different encoding strategies and pick the smallest
179        let mut best_data: Option<ChannelData> = None;
180        let mut best_size = usize::MAX;
181        let mut best_order = 0;
182
183        // Strategy 1: Raw PCM (baseline)
184        let raw = self.encode_raw(samples);
185        let raw_size = raw.residuals.len();
186        if raw_size < best_size {
187            best_size = raw_size;
188            best_data = Some(raw);
189            best_order = 0;
190        }
191
192        // Strategy 2: Fixed predictors (order 0-4, very fast)
193        for order in 0..=4.min(max_order) {
194            if let Some((data, size)) = self.try_fixed_predictor(samples, order) {
195                if size < best_size {
196                    best_size = size;
197                    best_data = Some(data);
198                    best_order = order;
199                }
200            }
201        }
202
203        // Strategy 3: LPC predictors (if compression level allows)
204        if self.compression_level >= 3 && max_order > 4 {
205            for order in 5..=max_order {
206                if let Some((data, size)) = self.try_lpc_predictor(samples, order) {
207                    if size < best_size {
208                        best_size = size;
209                        best_data = Some(data);
210                        best_order = order;
211                    }
212                }
213            }
214        }
215
216        (best_data.unwrap(), best_order)
217    }
218
219    /// Encode as raw PCM
220    fn encode_raw(&self, samples: &[i32]) -> ChannelData {
221        let raw_bytes: Vec<u8> = samples
222            .iter()
223            .flat_map(|&s| (s as i16).to_le_bytes().to_vec())
224            .collect();
225        ChannelData::new_raw(raw_bytes)
226    }
227
228    /// Try fixed predictor
229    fn try_fixed_predictor(&self, samples: &[i32], order: usize) -> Option<(ChannelData, usize)> {
230        if order > 4 {
231            return None;
232        }
233
234        let residuals = fixed_predictor_residuals(samples, order);
235
236        // Find optimal Rice parameter
237        let k = rice::estimate_rice_parameter_i32(&residuals);
238        let encoded = rice::encode_i32(&residuals, k);
239
240        // For fixed predictors: store negative order to distinguish from LPC
241        // predictor_coeffs is empty, shift_bits stores (128 + order) as marker
242        let ch_data = ChannelData {
243            predictor_coeffs: vec![],        // Empty = fixed predictor
244            shift_bits: (128 + order) as u8, // Marker: 128-132 = fixed order 0-4
245            residual_encoding: ResidualEncoding::Rice,
246            rice_parameter: k,
247            residuals: encoded.clone(),
248        };
249
250        Some((ch_data, encoded.len()))
251    }
252
253    /// Try LPC predictor with given order
254    fn try_lpc_predictor(&self, samples: &[i32], order: usize) -> Option<(ChannelData, usize)> {
255        if samples.len() <= order {
256            return None;
257        }
258
259        // Calculate autocorrelation in integer domain
260        let autocorr = autocorr_int(samples, order);
261
262        // Levinson-Durbin for LPC coefficients (in fixed-point)
263        let (coeffs_fp, shift) = levinson_durbin_int(&autocorr, order)?;
264
265        // Calculate residuals using integer arithmetic
266        let residuals = calc_residuals_int(samples, &coeffs_fp, shift, order);
267
268        // Check if residuals are reasonable (not exploding)
269        let max_res = residuals.iter().map(|&r| r.abs()).max().unwrap_or(0);
270        if max_res > 1_000_000 {
271            return None; // Unstable, skip this order
272        }
273
274        // Encode residuals
275        let k = rice::estimate_rice_parameter_i32(&residuals);
276        let encoded = rice::encode_i32(&residuals, k);
277
278        let ch_data = ChannelData {
279            predictor_coeffs: coeffs_fp,
280            shift_bits: shift,
281            residual_encoding: ResidualEncoding::Rice,
282            rice_parameter: k,
283            residuals: encoded.clone(),
284        };
285
286        Some((ch_data, encoded.len()))
287    }
288
289    fn lpc_order_from_level(&self) -> usize {
290        match self.compression_level {
291            0 => 0, // Only fixed predictors
292            1 => 2,
293            2 => 4,
294            3 => 4,
295            4 => 6,
296            5 => 8,
297            6 => 8,
298            7 => 10,
299            8 => 12,
300            _ => 12,
301        }
302    }
303}
304
305impl Default for Encoder {
306    fn default() -> Self {
307        Encoder::new(44100, 1, 16)
308    }
309}