libflo_audio/streaming/
decoder.rs

1use crate::core::audio_constants::i32_to_f32;
2use crate::core::{rice, ChannelData, FloResult, Frame, FrameType, Header, TocEntry};
3use crate::lossless::Decoder as LosslessDecoder;
4use crate::lossy::{deserialize_frame, TransformDecoder};
5use crate::{Reader, ResidualEncoding, MAGIC};
6
7use super::types::{DecoderState, StreamingAudioInfo};
8
9pub struct StreamingDecoder {
10    /// incoming data buffer
11    buffer: Vec<u8>,
12    /// current state
13    state: DecoderState,
14    /// parsed header
15    header: Option<Header>,
16    /// toc entries
17    toc: Vec<TocEntry>,
18    /// current frame being decoded
19    current_frame: usize,
20    /// where data chunk starts
21    data_offset: usize,
22    /// lossy decoder when needed
23    lossy_decoder: Option<TransformDecoder>,
24    /// is lossy?
25    is_lossy: bool,
26    /// skipped preroll frame?
27    skipped_preroll: bool,
28}
29
30impl StreamingDecoder {
31    /// new streaming decoder
32    pub fn new() -> Self {
33        Self {
34            buffer: Vec::with_capacity(64 * 1024),
35            state: DecoderState::WaitingForHeader,
36            header: None,
37            toc: Vec::new(),
38            current_frame: 0,
39            data_offset: 0,
40            lossy_decoder: None,
41            is_lossy: false,
42            skipped_preroll: false,
43        }
44    }
45
46    /// current state
47    pub fn state(&self) -> DecoderState {
48        self.state
49    }
50
51    /// audio info if we have the header
52    pub fn info(&self) -> Option<StreamingAudioInfo> {
53        self.header.as_ref().map(|h| StreamingAudioInfo {
54            sample_rate: h.sample_rate,
55            channels: h.channels,
56            bit_depth: h.bit_depth,
57            total_frames: h.total_frames,
58            is_lossy: self.is_lossy,
59        })
60    }
61
62    /// how many frames ready to decode
63    pub fn frames_available(&self) -> usize {
64        if self.state != DecoderState::Ready {
65            return 0;
66        }
67        self.count_complete_frames()
68    }
69
70    /// feed more data, returns true if new frames available
71    pub fn feed(&mut self, data: &[u8]) -> FloResult<bool> {
72        if self.state == DecoderState::Error || self.state == DecoderState::Finished {
73            return Ok(false);
74        }
75
76        self.buffer.extend_from_slice(data);
77        self.try_advance_state()
78    }
79
80    /// decode next frame, or None if nothing ready
81    pub fn next_frame(&mut self) -> FloResult<Option<Vec<f32>>> {
82        if self.state != DecoderState::Ready {
83            return Ok(None);
84        }
85
86        let header = match self.header.as_ref() {
87            Some(h) => h.clone(),
88            None => return Err("No header".to_string()),
89        };
90
91        if self.current_frame >= self.toc.len() {
92            self.state = DecoderState::Finished;
93            return Ok(None);
94        }
95
96        let toc_entry = &self.toc[self.current_frame];
97        let frame_start = self.data_offset + toc_entry.byte_offset as usize;
98        let frame_end = frame_start + toc_entry.frame_size as usize;
99
100        if frame_end > self.buffer.len() {
101            return Ok(None);
102        }
103
104        let frame_data = &self.buffer[frame_start..frame_end];
105        let frame = self.parse_frame(frame_data, header.channels)?;
106
107        self.current_frame += 1;
108        let samples = self.decode_frame(&frame, &header)?;
109
110        Ok(Some(samples))
111    }
112
113    /// decode everything we have
114    pub fn decode_available(&mut self) -> FloResult<Vec<f32>> {
115        if self.state != DecoderState::Ready {
116            return Ok(Vec::new());
117        }
118
119        let samples = self.decode_with_standard_decoder()?;
120        self.state = DecoderState::Finished;
121        Ok(samples)
122    }
123
124    /// reset for reuse
125    pub fn reset(&mut self) {
126        self.buffer.clear();
127        self.state = DecoderState::WaitingForHeader;
128        self.header = None;
129        self.toc.clear();
130        self.current_frame = 0;
131        self.data_offset = 0;
132        self.lossy_decoder = None;
133        self.is_lossy = false;
134        self.skipped_preroll = false;
135    }
136
137    /// bytes buffered
138    pub fn buffered_bytes(&self) -> usize {
139        self.buffer.len()
140    }
141
142    /// frames ready to decode
143    pub fn available_frames(&self) -> usize {
144        if self.state != DecoderState::Ready {
145            return 0;
146        }
147        self.count_complete_frames()
148            .saturating_sub(self.current_frame)
149    }
150
151    /// current frame index
152    pub fn current_frame_index(&self) -> usize {
153        self.current_frame
154    }
155
156    // internal stuff
157
158    fn try_advance_state(&mut self) -> FloResult<bool> {
159        match self.state {
160            DecoderState::WaitingForHeader => {
161                if self.try_parse_header()? {
162                    self.state = DecoderState::WaitingForToc;
163                    return self.try_advance_state();
164                }
165            }
166            DecoderState::WaitingForToc => {
167                if self.try_parse_toc()? {
168                    self.state = DecoderState::Ready;
169                    return Ok(true);
170                }
171            }
172            DecoderState::Ready => {
173                return Ok(self.count_complete_frames() > self.current_frame);
174            }
175            _ => {}
176        }
177        Ok(false)
178    }
179
180    fn try_parse_header(&mut self) -> FloResult<bool> {
181        // need at least 70 bytes
182        if self.buffer.len() < 70 {
183            return Ok(false);
184        }
185
186        if self.buffer[0..4] != MAGIC {
187            self.state = DecoderState::Error;
188            return Err("Invalid flo file: bad magic".to_string());
189        }
190
191        let header = Header {
192            version_major: self.buffer[4],
193            version_minor: self.buffer[5],
194            flags: u16::from_le_bytes([self.buffer[6], self.buffer[7]]),
195            sample_rate: u32::from_le_bytes([
196                self.buffer[8],
197                self.buffer[9],
198                self.buffer[10],
199                self.buffer[11],
200            ]),
201            channels: self.buffer[12],
202            bit_depth: self.buffer[13],
203            total_frames: u64::from_le_bytes([
204                self.buffer[14],
205                self.buffer[15],
206                self.buffer[16],
207                self.buffer[17],
208                self.buffer[18],
209                self.buffer[19],
210                self.buffer[20],
211                self.buffer[21],
212            ]),
213            compression_level: self.buffer[22],
214            data_crc32: u32::from_le_bytes([
215                self.buffer[26],
216                self.buffer[27],
217                self.buffer[28],
218                self.buffer[29],
219            ]),
220            header_size: u64::from_le_bytes([
221                self.buffer[30],
222                self.buffer[31],
223                self.buffer[32],
224                self.buffer[33],
225                self.buffer[34],
226                self.buffer[35],
227                self.buffer[36],
228                self.buffer[37],
229            ]),
230            toc_size: u64::from_le_bytes([
231                self.buffer[38],
232                self.buffer[39],
233                self.buffer[40],
234                self.buffer[41],
235                self.buffer[42],
236                self.buffer[43],
237                self.buffer[44],
238                self.buffer[45],
239            ]),
240            data_size: u64::from_le_bytes([
241                self.buffer[46],
242                self.buffer[47],
243                self.buffer[48],
244                self.buffer[49],
245                self.buffer[50],
246                self.buffer[51],
247                self.buffer[52],
248                self.buffer[53],
249            ]),
250            extra_size: u64::from_le_bytes([
251                self.buffer[54],
252                self.buffer[55],
253                self.buffer[56],
254                self.buffer[57],
255                self.buffer[58],
256                self.buffer[59],
257                self.buffer[60],
258                self.buffer[61],
259            ]),
260            meta_size: u64::from_le_bytes([
261                self.buffer[62],
262                self.buffer[63],
263                self.buffer[64],
264                self.buffer[65],
265                self.buffer[66],
266                self.buffer[67],
267                self.buffer[68],
268                self.buffer[69],
269            ]),
270        };
271
272        self.is_lossy = (header.flags & 0x01) != 0;
273        if self.is_lossy {
274            self.lossy_decoder = Some(TransformDecoder::new(header.sample_rate, header.channels));
275        }
276
277        self.header = Some(header);
278        Ok(true)
279    }
280
281    fn try_parse_toc(&mut self) -> FloResult<bool> {
282        let header = self.header.as_ref().ok_or("No header")?;
283        let toc_start = 70;
284        let toc_end = toc_start + header.toc_size as usize;
285
286        if self.buffer.len() < toc_end {
287            return Ok(false);
288        }
289
290        if header.toc_size >= 4 {
291            let num_entries = u32::from_le_bytes([
292                self.buffer[toc_start],
293                self.buffer[toc_start + 1],
294                self.buffer[toc_start + 2],
295                self.buffer[toc_start + 3],
296            ]) as usize;
297
298            let entries_start = toc_start + 4;
299            for i in 0..num_entries {
300                let offset = entries_start + i * 20;
301                if offset + 20 > self.buffer.len() {
302                    return Ok(false);
303                }
304
305                self.toc.push(TocEntry {
306                    frame_index: u32::from_le_bytes([
307                        self.buffer[offset],
308                        self.buffer[offset + 1],
309                        self.buffer[offset + 2],
310                        self.buffer[offset + 3],
311                    ]),
312                    byte_offset: u64::from_le_bytes([
313                        self.buffer[offset + 4],
314                        self.buffer[offset + 5],
315                        self.buffer[offset + 6],
316                        self.buffer[offset + 7],
317                        self.buffer[offset + 8],
318                        self.buffer[offset + 9],
319                        self.buffer[offset + 10],
320                        self.buffer[offset + 11],
321                    ]),
322                    frame_size: u32::from_le_bytes([
323                        self.buffer[offset + 12],
324                        self.buffer[offset + 13],
325                        self.buffer[offset + 14],
326                        self.buffer[offset + 15],
327                    ]),
328                    timestamp_ms: u32::from_le_bytes([
329                        self.buffer[offset + 16],
330                        self.buffer[offset + 17],
331                        self.buffer[offset + 18],
332                        self.buffer[offset + 19],
333                    ]),
334                });
335            }
336        }
337
338        self.data_offset = toc_end;
339        Ok(true)
340    }
341
342    fn count_complete_frames(&self) -> usize {
343        let mut count = 0;
344        for entry in &self.toc {
345            let frame_end =
346                self.data_offset + entry.byte_offset as usize + entry.frame_size as usize;
347            if frame_end <= self.buffer.len() {
348                count += 1;
349            } else {
350                break;
351            }
352        }
353        count
354    }
355
356    fn parse_frame(&self, data: &[u8], channels: u8) -> FloResult<Frame> {
357        if data.len() < 6 {
358            return Err("Frame too small".to_string());
359        }
360
361        let frame_type_byte = data[0];
362        let frame_samples = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
363        let flags = data[5];
364
365        let frame_type = FrameType::from(frame_type_byte);
366        let mut frame = Frame::new(frame_type_byte, frame_samples);
367        frame.flags = flags;
368
369        let num_channels = if frame_type == FrameType::Transform {
370            1
371        } else {
372            channels as usize
373        };
374
375        let mut pos = 6;
376        for _ in 0..num_channels {
377            if pos + 4 > data.len() {
378                return Err("Frame truncated".to_string());
379            }
380
381            let ch_size =
382                u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]])
383                    as usize;
384            pos += 4;
385
386            if pos + ch_size > data.len() {
387                return Err("Channel data truncated".to_string());
388            }
389
390            let ch_data = &data[pos..pos + ch_size];
391            pos += ch_size;
392
393            let channel = match frame_type {
394                FrameType::Silence => ChannelData::new_silence(),
395                FrameType::Raw | FrameType::Transform => ChannelData {
396                    predictor_coeffs: vec![],
397                    shift_bits: 0,
398                    residual_encoding: ResidualEncoding::Raw,
399                    rice_parameter: 0,
400                    residuals: ch_data.to_vec(),
401                },
402                _ => self.parse_alpc_channel(ch_data, frame_type)?,
403            };
404
405            frame.channels.push(channel);
406        }
407
408        Ok(frame)
409    }
410
411    fn parse_alpc_channel(&self, data: &[u8], _frame_type: FrameType) -> FloResult<ChannelData> {
412        if data.is_empty() {
413            return Ok(ChannelData::new_silence());
414        }
415
416        let order = data[0] as usize;
417        if order > 12 {
418            return Err("Invalid LPC order".to_string());
419        }
420
421        let coeff_bytes = order * 4;
422        let min_size = 1 + coeff_bytes + 2; // order + coeffs + shift + encoding
423        if data.len() < min_size {
424            return Err("ALPC channel too small".to_string());
425        }
426
427        // Read coefficients
428        let mut coefficients = Vec::with_capacity(order);
429        for i in 0..order {
430            let offset = 1 + i * 4;
431            let coeff = i32::from_le_bytes([
432                data[offset],
433                data[offset + 1],
434                data[offset + 2],
435                data[offset + 3],
436            ]);
437            coefficients.push(coeff);
438        }
439
440        let mut pos = 1 + coeff_bytes;
441
442        // Read shift_bits
443        let shift_bits = data[pos];
444        pos += 1;
445
446        // Read residual encoding
447        let residual_encoding_byte = data[pos];
448        let residual_encoding = ResidualEncoding::from(residual_encoding_byte);
449        pos += 1;
450
451        // Read rice parameter (only for Rice encoding)
452        let rice_parameter = if residual_encoding == ResidualEncoding::Rice {
453            if pos >= data.len() {
454                return Err("Missing rice parameter".to_string());
455            }
456            let rp = data[pos];
457            pos += 1;
458            rp
459        } else {
460            0
461        };
462
463        // Rest is residuals
464        let residuals = data[pos..].to_vec();
465
466        Ok(ChannelData {
467            predictor_coeffs: coefficients,
468            shift_bits,
469            residual_encoding,
470            rice_parameter,
471            residuals,
472        })
473    }
474
475    fn decode_frame(&mut self, frame: &Frame, header: &Header) -> FloResult<Vec<f32>> {
476        let frame_type = FrameType::from(frame.frame_type);
477
478        // Handle Transform (lossy) frames
479        if frame_type == FrameType::Transform {
480            if frame.channels.is_empty() {
481                return Ok(Vec::new());
482            }
483
484            let frame_data = &frame.channels[0].residuals;
485            if let Some(transform_frame) = deserialize_frame(frame_data) {
486                let decoder = self.lossy_decoder.get_or_insert_with(|| {
487                    TransformDecoder::new(header.sample_rate, header.channels)
488                });
489                let samples = decoder.decode_frame(&transform_frame);
490
491                // Skip first frame (preroll) for lossy
492                if !self.skipped_preroll {
493                    self.skipped_preroll = true;
494                    return Ok(Vec::new());
495                }
496
497                return Ok(samples);
498            }
499            return Ok(Vec::new());
500        }
501
502        // Handle lossless frames (Silence, Raw, ALPC variants)
503        let channels = header.channels as usize;
504        let frame_samples = frame.frame_samples as usize;
505        let use_mid_side = channels == 2 && (frame.flags & 0x01) != 0;
506
507        let mut frame_channels: Vec<Vec<i32>> = Vec::with_capacity(channels);
508
509        for ch_data in &frame.channels {
510            let samples = self.decode_channel_int(ch_data, frame_samples)?;
511            frame_channels.push(samples);
512        }
513
514        // Convert mid-side back to left-right if needed
515        let mut all_samples: Vec<Vec<i32>> = vec![vec![]; channels];
516        if use_mid_side && frame_channels.len() == 2 {
517            let (left, right) = self.decode_mid_side(&frame_channels[0], &frame_channels[1]);
518            all_samples[0] = left;
519            all_samples[1] = right;
520        } else {
521            for (ch_idx, samples) in frame_channels.into_iter().enumerate() {
522                if ch_idx < channels {
523                    all_samples[ch_idx] = samples;
524                }
525            }
526        }
527
528        // Interleave and convert to f32
529        let max_len = all_samples.iter().map(|v| v.len()).max().unwrap_or(0);
530        let mut interleaved = Vec::with_capacity(max_len * channels);
531
532        for i in 0..max_len {
533            for ch in 0..channels {
534                let sample = all_samples[ch].get(i).copied().unwrap_or(0);
535                interleaved.push(i32_to_f32(sample));
536            }
537        }
538
539        Ok(interleaved)
540    }
541
542    /// Decode a single channel to integers
543    fn decode_channel_int(
544        &self,
545        ch_data: &ChannelData,
546        frame_samples: usize,
547    ) -> FloResult<Vec<i32>> {
548        let has_coeffs = !ch_data.predictor_coeffs.is_empty();
549        let has_residuals = !ch_data.residuals.is_empty();
550        let shift_bits = ch_data.shift_bits;
551
552        // Check for fixed predictor marker: shift_bits >= 128 means fixed order (128 + order)
553        let is_fixed_predictor = !has_coeffs && has_residuals && shift_bits >= 128;
554
555        if is_fixed_predictor {
556            let fixed_order = (shift_bits - 128) as usize;
557            let residuals =
558                rice::decode_i32(&ch_data.residuals, ch_data.rice_parameter, frame_samples);
559            return Ok(self.reconstruct_fixed(fixed_order, &residuals, frame_samples));
560        }
561
562        if has_coeffs {
563            // LPC decoding with stored coefficients
564            // Decode residuals based on encoding type
565            let residuals = match ch_data.residual_encoding {
566                ResidualEncoding::Rice => {
567                    rice::decode_i32(&ch_data.residuals, ch_data.rice_parameter, frame_samples)
568                }
569                ResidualEncoding::Raw | ResidualEncoding::Golomb => {
570                    // Raw residuals as i16 (Golomb not implemented, fallback to raw)
571                    let mut res = Vec::with_capacity(frame_samples);
572                    for chunk in ch_data.residuals.chunks(2) {
573                        if chunk.len() == 2 {
574                            res.push(i16::from_le_bytes([chunk[0], chunk[1]]) as i32);
575                        }
576                    }
577                    while res.len() < frame_samples {
578                        res.push(0);
579                    }
580                    res
581                }
582            };
583
584            let order = ch_data.predictor_coeffs.len();
585            let samples = self.reconstruct_lpc_int(
586                &ch_data.predictor_coeffs,
587                &residuals,
588                shift_bits,
589                order,
590                frame_samples,
591            );
592            return Ok(samples);
593        }
594
595        if has_residuals {
596            // Raw PCM (no prediction)
597            let mut samples = Vec::with_capacity(frame_samples);
598            for chunk in ch_data.residuals.chunks(2) {
599                if chunk.len() == 2 {
600                    samples.push(i16::from_le_bytes([chunk[0], chunk[1]]) as i32);
601                }
602            }
603            while samples.len() < frame_samples {
604                samples.push(0);
605            }
606            return Ok(samples);
607        }
608
609        // Silence
610        Ok(vec![0; frame_samples])
611    }
612
613    /// Convert mid-side back to left-right
614    fn decode_mid_side(&self, mid: &[i32], side: &[i32]) -> (Vec<i32>, Vec<i32>) {
615        let left: Vec<i32> = mid
616            .iter()
617            .zip(side.iter())
618            .map(|(&m, &s)| (m + s) / 2)
619            .collect();
620        let right: Vec<i32> = mid
621            .iter()
622            .zip(side.iter())
623            .map(|(&m, &s)| (m - s) / 2)
624            .collect();
625        (left, right)
626    }
627
628    /// Reconstruct from LPC prediction
629    fn reconstruct_lpc_int(
630        &self,
631        coeffs: &[i32],
632        residuals: &[i32],
633        shift: u8,
634        order: usize,
635        target_len: usize,
636    ) -> Vec<i32> {
637        let mut samples = Vec::with_capacity(target_len);
638
639        // Warmup samples from residuals
640        for i in 0..order.min(residuals.len()) {
641            samples.push(residuals[i]);
642        }
643
644        // Reconstruct remaining
645        for i in order..target_len.min(residuals.len()) {
646            let mut prediction: i64 = 0;
647            for (j, &coeff) in coeffs.iter().enumerate() {
648                if i > j {
649                    prediction += (coeff as i64) * (samples[i - j - 1] as i64);
650                }
651            }
652            prediction >>= shift;
653            samples.push(prediction as i32 + residuals[i]);
654        }
655
656        while samples.len() < target_len {
657            samples.push(0);
658        }
659
660        samples
661    }
662
663    /// Reconstruct from fixed predictor
664    fn reconstruct_fixed(&self, order: usize, residuals: &[i32], target_len: usize) -> Vec<i32> {
665        let mut samples = Vec::with_capacity(target_len);
666
667        if residuals.is_empty() {
668            return vec![0; target_len];
669        }
670
671        match order {
672            0 => samples.extend_from_slice(residuals),
673            1 => {
674                samples.push(residuals[0]);
675                for i in 1..residuals.len().min(target_len) {
676                    samples.push(residuals[i].wrapping_add(samples[i - 1]));
677                }
678            }
679            2 => {
680                if !residuals.is_empty() {
681                    samples.push(residuals[0]);
682                }
683                if residuals.len() > 1 {
684                    samples.push(residuals[1].wrapping_add(samples[0]));
685                }
686                for i in 2..residuals.len().min(target_len) {
687                    let pred = (2i64 * samples[i - 1] as i64 - samples[i - 2] as i64) as i32;
688                    samples.push(residuals[i].wrapping_add(pred));
689                }
690            }
691            3 => {
692                if !residuals.is_empty() {
693                    samples.push(residuals[0]);
694                }
695                if residuals.len() > 1 {
696                    samples.push(residuals[1].wrapping_add(samples[0]));
697                }
698                if residuals.len() > 2 {
699                    let pred = (2i64 * samples[1] as i64 - samples[0] as i64) as i32;
700                    samples.push(residuals[2].wrapping_add(pred));
701                }
702                for i in 3..residuals.len().min(target_len) {
703                    let pred = (3i64 * samples[i - 1] as i64 - 3i64 * samples[i - 2] as i64
704                        + samples[i - 3] as i64) as i32;
705                    samples.push(residuals[i].wrapping_add(pred));
706                }
707            }
708            4 => {
709                if !residuals.is_empty() {
710                    samples.push(residuals[0]);
711                }
712                if residuals.len() > 1 {
713                    samples.push(residuals[1].wrapping_add(samples[0]));
714                }
715                if residuals.len() > 2 {
716                    let pred = (2i64 * samples[1] as i64 - samples[0] as i64) as i32;
717                    samples.push(residuals[2].wrapping_add(pred));
718                }
719                if residuals.len() > 3 {
720                    let pred = (3i64 * samples[2] as i64 - 3i64 * samples[1] as i64
721                        + samples[0] as i64) as i32;
722                    samples.push(residuals[3].wrapping_add(pred));
723                }
724                for i in 4..residuals.len().min(target_len) {
725                    let pred = (4i64 * samples[i - 1] as i64 - 6i64 * samples[i - 2] as i64
726                        + 4i64 * samples[i - 3] as i64
727                        - samples[i - 4] as i64) as i32;
728                    samples.push(residuals[i].wrapping_add(pred));
729                }
730            }
731            _ => samples.extend_from_slice(residuals),
732        }
733
734        while samples.len() < target_len {
735            samples.push(0);
736        }
737
738        samples
739    }
740
741    fn decode_with_standard_decoder(&self) -> FloResult<Vec<f32>> {
742        let reader = Reader::new();
743        let file = reader.read(&self.buffer)?;
744
745        let is_transform = file
746            .frames
747            .iter()
748            .any(|f| f.frame_type == (FrameType::Transform as u8));
749
750        if is_transform {
751            let mut decoder = TransformDecoder::new(file.header.sample_rate, file.header.channels);
752            let mut all_samples = Vec::new();
753            let mut frame_count = 0;
754
755            for frame in &file.frames {
756                if frame.channels.is_empty() {
757                    continue;
758                }
759                let frame_data = &frame.channels[0].residuals;
760                if let Some(transform_frame) = deserialize_frame(frame_data) {
761                    let samples = decoder.decode_frame(&transform_frame);
762                    if frame_count > 0 {
763                        all_samples.extend(samples);
764                    }
765                    frame_count += 1;
766                }
767            }
768            Ok(all_samples)
769        } else {
770            let decoder = LosslessDecoder::new();
771            decoder.decode_file(&file)
772        }
773    }
774}
775
776impl Default for StreamingDecoder {
777    fn default() -> Self {
778        Self::new()
779    }
780}