libflo_audio/lossy/
encoder.rs

1use super::mdct::{BlockSize, Mdct, WindowType};
2use super::psychoacoustic::{PsychoacousticModel, NUM_BARK_BANDS};
3use crate::core::{ChannelData, Frame, FrameType, ResidualEncoding, I16_MAX_F32, I16_MIN_F32};
4
/// Transform lossy encoder
///
/// Runs an MDCT over overlapping blocks, applies a per-channel
/// psychoacoustic model, and quantizes coefficients into i16 values.
pub struct TransformEncoder {
    /// Sample rate in Hz
    sample_rate: u32,
    /// Number of interleaved channels
    channels: u8,
    /// MDCT processor (shared across channels; holds per-channel state)
    mdct: Mdct,
    /// Psychoacoustic model (one per channel)
    psy_models: Vec<PsychoacousticModel>,
    /// Quality setting (0.0 = lowest, 1.0 = transparent); always clamped
    /// to this range by `new`/`set_quality`
    quality: f32,
    /// Block size (fixed to `BlockSize::Long` by `new`)
    block_size: BlockSize,
}
20
/// Encoded frame data
///
/// Output of [`TransformEncoder::encode_frame`]; the payload that
/// `serialize_frame` turns into bytes.
#[derive(Debug, Clone)]
pub struct TransformFrame {
    /// Quantized MDCT coefficients per channel (as i16)
    pub coefficients: Vec<Vec<i16>>,
    /// Scale factors per Bark band per channel
    pub scale_factors: Vec<Vec<f32>>,
    /// Block size used
    pub block_size: BlockSize,
    /// Number of samples this frame represents (after overlap-add)
    pub num_samples: usize,
}
33
34impl TransformEncoder {
35    /// Create a new transform encoder
36    pub fn new(sample_rate: u32, channels: u8, quality: f32) -> Self {
37        let block_size = BlockSize::Long; // 2048 samples
38        let fft_size = block_size.samples();
39
40        let mdct = Mdct::new(channels as usize, WindowType::Vorbis);
41        let psy_models: Vec<_> = (0..channels)
42            .map(|_| PsychoacousticModel::new(sample_rate, fft_size))
43            .collect();
44
45        Self {
46            sample_rate,
47            channels,
48            mdct,
49            psy_models,
50            quality: quality.clamp(0.0, 1.0),
51            block_size,
52        }
53    }
54
55    /// Set quality (0.0-1.0)
56    pub fn set_quality(&mut self, quality: f32) {
57        self.quality = quality.clamp(0.0, 1.0);
58    }
59
60    /// Encode a frame of audio
61    /// Input: interleaved samples for one frame (block_size * channels)
62    /// Returns encoded frame
63    pub fn encode_frame(&mut self, samples: &[f32]) -> TransformFrame {
64        let block_samples = self.block_size.samples();
65        let num_coeffs = self.block_size.coefficients();
66        let hop_size = num_coeffs; // 50% overlap
67
68        // Deinterleave channels
69        let mut channel_data: Vec<Vec<f32>> = (0..self.channels as usize)
70            .map(|_| Vec::with_capacity(samples.len() / self.channels as usize))
71            .collect();
72
73        for (i, &s) in samples.iter().enumerate() {
74            channel_data[i % self.channels as usize].push(s);
75        }
76
77        let mut all_coefficients = Vec::with_capacity(self.channels as usize);
78        let mut all_scale_factors = Vec::with_capacity(self.channels as usize);
79
80        for (ch, data) in channel_data.iter().enumerate() {
81            // Pad to block size if needed
82            let mut frame_data = data.clone();
83            if frame_data.len() < block_samples {
84                frame_data.resize(block_samples, 0.0);
85            }
86
87            // MDCT transform
88            let coeffs = self.mdct.forward(&frame_data, self.block_size);
89
90            // Psychoacoustic analysis
91            let smr = self.psy_models[ch].calculate_smr(&coeffs);
92
93            // Quantize based on perceptual importance
94            let (quantized, scale_factors) = self.quantize_coefficients(&coeffs, &smr);
95
96            all_coefficients.push(quantized);
97            all_scale_factors.push(scale_factors);
98        }
99
100        TransformFrame {
101            coefficients: all_coefficients,
102            scale_factors: all_scale_factors,
103            block_size: self.block_size,
104            num_samples: hop_size,
105        }
106    }
107
108    /// Quantize MDCT coefficients based on SMR
109    pub fn quantize_coefficients(&self, coeffs: &[f32], smr: &[f32]) -> (Vec<i16>, Vec<f32>) {
110        // Calculate scale factors per Bark band
111        let mut band_max = [0.0f32; NUM_BARK_BANDS];
112        let freq_resolution = self.sample_rate as f32 / self.block_size.samples() as f32;
113
114        for (k, &c) in coeffs.iter().enumerate() {
115            let freq = (k as f32 + 0.5) * freq_resolution;
116            let band = PsychoacousticModel::freq_to_bark_band(freq);
117            band_max[band] = band_max[band].max(c.abs());
118        }
119
120        // Calculate scale factors (to fit i16 range without clipping)
121        let mut scale_factors = vec![1.0f32; NUM_BARK_BANDS];
122        for (sf, &max_val) in scale_factors.iter_mut().zip(band_max.iter()) {
123            if max_val > 1e-10 {
124                // Use 30000 as max to leave some headroom
125                *sf = 30000.0 / max_val;
126            }
127        }
128
129        // Quality-dependent masking threshold
130        let smr_threshold = if self.quality >= 0.99 {
131            -100.0 // At max quality, keep essentially everything
132        } else {
133            // Exponential decay from 0 dB at quality=0 to -60 dB at quality=1
134            let t = (1.0 - self.quality).max(0.001);
135            -60.0 * (1.0 - t.powf(0.5))
136        };
137
138        // Quantize
139        let mut quantized = vec![0i16; coeffs.len()];
140
141        for (k, (q, &c)) in quantized.iter_mut().zip(coeffs.iter()).enumerate() {
142            let freq = (k as f32 + 0.5) * freq_resolution;
143            let band = PsychoacousticModel::freq_to_bark_band(freq);
144
145            if smr[k] > smr_threshold {
146                // Above masking threshold, quantize with appropriate precision
147                let scaled = c * scale_factors[band];
148                *q = scaled.round().clamp(I16_MIN_F32, I16_MAX_F32) as i16;
149            }
150            // else: below threshold, leave as 0
151        }
152
153        (quantized, scale_factors)
154    }
155
156    /// Reset encoder state
157    pub fn reset(&mut self) {
158        self.mdct.reset();
159        for model in &mut self.psy_models {
160            model.reset();
161        }
162    }
163
164    /// Encode audio samples to flo™ file format
165    ///
166    /// This produces a complete flo™ file with transform-based frames
167    pub fn encode_to_flo(&mut self, samples: &[f32], metadata: &[u8]) -> crate::FloResult<Vec<u8>> {
168        let block_samples = self.block_size.samples();
169        let hop_size = self.block_size.coefficients(); // 50% overlap (N = block_samples/2)
170
171        // For proper MDCT overlap-add reconstruction, we need:
172        // - A priming frame at the start (silence) to initialize overlap buffer
173        // - Proper number of frames to cover all samples
174        let num_samples_per_channel = samples.len() / self.channels as usize;
175
176        // Add hop_size samples of pre-roll (zeros) at start for proper reconstruction
177        let pre_roll = hop_size;
178        let total_samples = num_samples_per_channel + pre_roll;
179        let num_hops = total_samples.div_ceil(hop_size);
180        let total_samples_needed = (num_hops + 1) * hop_size;
181
182        // Create padded buffer with pre-roll zeros at start
183        let mut padded = vec![0.0f32; total_samples_needed * self.channels as usize];
184
185        // Copy original samples after pre-roll
186        for ch in 0..self.channels as usize {
187            for i in 0..num_samples_per_channel.min(total_samples_needed - pre_roll) {
188                let src_idx = i * self.channels as usize + ch;
189                let dst_idx = (i + pre_roll) * self.channels as usize + ch;
190                if src_idx < samples.len() && dst_idx < padded.len() {
191                    padded[dst_idx] = samples[src_idx];
192                }
193            }
194        }
195
196        // Encode frames
197        let mut encoded_frames: Vec<Frame> = Vec::new();
198
199        // Process overlapping blocks
200        for hop_idx in 0..num_hops {
201            let start = hop_idx * hop_size * self.channels as usize;
202            let end = start + block_samples * self.channels as usize;
203
204            if end > padded.len() {
205                break;
206            }
207
208            let frame_samples = &padded[start..end];
209            let transform_frame = self.encode_frame(frame_samples);
210
211            // Serialize the transform frame
212            let frame_data = serialize_frame(&transform_frame);
213
214            // Create a flo Frame with transform type
215            let mut flo_frame = Frame::new(FrameType::Transform as u8, hop_size as u32);
216            flo_frame.channels.push(ChannelData {
217                predictor_coeffs: vec![],
218                shift_bits: 0,
219                residual_encoding: ResidualEncoding::Raw,
220                rice_parameter: 0,
221                residuals: frame_data,
222            });
223
224            encoded_frames.push(flo_frame);
225        }
226
227        // Write using the standard Writer
228        let writer = crate::Writer::new();
229        writer.write_ex(
230            self.sample_rate,
231            self.channels,
232            16,                                          // bit_depth for lossy
233            5,    // compression level (not used for transform)
234            true, // is_lossy
235            ((self.quality * 4.0).round() as u8).min(4), // quality as 0-4
236            &encoded_frames,
237            metadata,
238        )
239    }
240}
241
242/// Serialize a transform frame to bytes (optimized)
243pub fn serialize_frame(frame: &TransformFrame) -> Vec<u8> {
244    let mut data = Vec::new();
245
246    // Block size (1 byte)
247    data.push(match frame.block_size {
248        BlockSize::Long => 0,
249        BlockSize::Short => 1,
250        BlockSize::Start => 2,
251        BlockSize::Stop => 3,
252    });
253
254    // Number of channels (1 byte)
255    data.push(frame.coefficients.len() as u8);
256
257    // Scale factors per channel (25 bands * 2 bytes * channels)
258    // Encode as log scale u16 instead of f32 to save space
259    for sf in &frame.scale_factors {
260        for &s in sf {
261            // Convert to log scale: log2(sf) * 256 + 32768
262            let log_sf = if s > 1e-10 {
263                ((s.log2() * 256.0) + 32768.0).clamp(0.0, 65535.0) as u16
264            } else {
265                0
266            };
267            data.extend_from_slice(&log_sf.to_le_bytes());
268        }
269    }
270
271    // Coefficients per channel (sparse encoding for mostly-zeros)
272    for quantized in &frame.coefficients {
273        let encoded = serialize_sparse(quantized);
274        let len = encoded.len() as u32;
275        data.extend_from_slice(&len.to_le_bytes());
276        data.extend_from_slice(&encoded);
277    }
278
279    data
280}
281
282/// Encode coefficients using sparse run-length encoding
283/// Format: [zero_count_varint] [non_zero_count] [values...]
284pub fn serialize_sparse(coeffs: &[i16]) -> Vec<u8> {
285    let mut output = Vec::new();
286    let mut i = 0;
287
288    while i < coeffs.len() {
289        // Count leading zeros
290        let zero_start = i;
291        while i < coeffs.len() && coeffs[i] == 0 {
292            i += 1;
293        }
294        let zero_count = i - zero_start;
295
296        // Count non-zeros (up to 255)
297        let non_zero_start = i;
298        while i < coeffs.len() && coeffs[i] != 0 && (i - non_zero_start) < 255 {
299            i += 1;
300        }
301        let non_zero_count = i - non_zero_start;
302
303        // Encode run: [zero_count_varint] [non_zero_count] [values...]
304        encode_varint(&mut output, zero_count as u32);
305        output.push(non_zero_count as u8);
306
307        // Write non-zero values as i16 LE
308        for j in non_zero_start..non_zero_start + non_zero_count {
309            output.extend_from_slice(&coeffs[j].to_le_bytes());
310        }
311    }
312
313    output
314}
315
/// Encode a u32 as LEB128-style varint (1-5 bytes): seven payload bits per
/// byte, least-significant group first, high bit set on all but the last.
fn encode_varint(output: &mut Vec<u8>, mut value: u32) {
    // Emit continuation bytes while more than seven bits remain.
    while value >= 0x80 {
        output.push((value as u8 & 0x7F) | 0x80);
        value >>= 7;
    }
    // Final byte has its continuation bit clear.
    output.push(value as u8);
}