1use crate::core::audio_constants::i32_to_f32;
2use crate::core::{rice, ChannelData, FloResult, Frame, FrameType, Header, TocEntry};
3use crate::lossless::Decoder as LosslessDecoder;
4use crate::lossy::{deserialize_frame, TransformDecoder};
5use crate::{Reader, ResidualEncoding, MAGIC};
6
7use super::types::{DecoderState, StreamingAudioInfo};
8
9pub struct StreamingDecoder {
10 buffer: Vec<u8>,
12 state: DecoderState,
14 header: Option<Header>,
16 toc: Vec<TocEntry>,
18 current_frame: usize,
20 data_offset: usize,
22 lossy_decoder: Option<TransformDecoder>,
24 is_lossy: bool,
26 skipped_preroll: bool,
28}
29
30impl StreamingDecoder {
31 pub fn new() -> Self {
33 Self {
34 buffer: Vec::with_capacity(64 * 1024),
35 state: DecoderState::WaitingForHeader,
36 header: None,
37 toc: Vec::new(),
38 current_frame: 0,
39 data_offset: 0,
40 lossy_decoder: None,
41 is_lossy: false,
42 skipped_preroll: false,
43 }
44 }
45
46 pub fn state(&self) -> DecoderState {
48 self.state
49 }
50
51 pub fn info(&self) -> Option<StreamingAudioInfo> {
53 self.header.as_ref().map(|h| StreamingAudioInfo {
54 sample_rate: h.sample_rate,
55 channels: h.channels,
56 bit_depth: h.bit_depth,
57 total_frames: h.total_frames,
58 is_lossy: self.is_lossy,
59 })
60 }
61
62 pub fn frames_available(&self) -> usize {
64 if self.state != DecoderState::Ready {
65 return 0;
66 }
67 self.count_complete_frames()
68 }
69
70 pub fn feed(&mut self, data: &[u8]) -> FloResult<bool> {
72 if self.state == DecoderState::Error || self.state == DecoderState::Finished {
73 return Ok(false);
74 }
75
76 self.buffer.extend_from_slice(data);
77 self.try_advance_state()
78 }
79
80 pub fn next_frame(&mut self) -> FloResult<Option<Vec<f32>>> {
82 if self.state != DecoderState::Ready {
83 return Ok(None);
84 }
85
86 let header = match self.header.as_ref() {
87 Some(h) => h.clone(),
88 None => return Err("No header".to_string()),
89 };
90
91 if self.current_frame >= self.toc.len() {
92 self.state = DecoderState::Finished;
93 return Ok(None);
94 }
95
96 let toc_entry = &self.toc[self.current_frame];
97 let frame_start = self.data_offset + toc_entry.byte_offset as usize;
98 let frame_end = frame_start + toc_entry.frame_size as usize;
99
100 if frame_end > self.buffer.len() {
101 return Ok(None);
102 }
103
104 let frame_data = &self.buffer[frame_start..frame_end];
105 let frame = self.parse_frame(frame_data, header.channels)?;
106
107 self.current_frame += 1;
108 let samples = self.decode_frame(&frame, &header)?;
109
110 Ok(Some(samples))
111 }
112
113 pub fn decode_available(&mut self) -> FloResult<Vec<f32>> {
115 if self.state != DecoderState::Ready {
116 return Ok(Vec::new());
117 }
118
119 let samples = self.decode_with_standard_decoder()?;
120 self.state = DecoderState::Finished;
121 Ok(samples)
122 }
123
124 pub fn reset(&mut self) {
126 self.buffer.clear();
127 self.state = DecoderState::WaitingForHeader;
128 self.header = None;
129 self.toc.clear();
130 self.current_frame = 0;
131 self.data_offset = 0;
132 self.lossy_decoder = None;
133 self.is_lossy = false;
134 self.skipped_preroll = false;
135 }
136
137 pub fn buffered_bytes(&self) -> usize {
139 self.buffer.len()
140 }
141
142 pub fn available_frames(&self) -> usize {
144 if self.state != DecoderState::Ready {
145 return 0;
146 }
147 self.count_complete_frames()
148 .saturating_sub(self.current_frame)
149 }
150
151 pub fn current_frame_index(&self) -> usize {
153 self.current_frame
154 }
155
156 fn try_advance_state(&mut self) -> FloResult<bool> {
159 match self.state {
160 DecoderState::WaitingForHeader => {
161 if self.try_parse_header()? {
162 self.state = DecoderState::WaitingForToc;
163 return self.try_advance_state();
164 }
165 }
166 DecoderState::WaitingForToc => {
167 if self.try_parse_toc()? {
168 self.state = DecoderState::Ready;
169 return Ok(true);
170 }
171 }
172 DecoderState::Ready => {
173 return Ok(self.count_complete_frames() > self.current_frame);
174 }
175 _ => {}
176 }
177 Ok(false)
178 }
179
180 fn try_parse_header(&mut self) -> FloResult<bool> {
181 if self.buffer.len() < 70 {
183 return Ok(false);
184 }
185
186 if self.buffer[0..4] != MAGIC {
187 self.state = DecoderState::Error;
188 return Err("Invalid flo file: bad magic".to_string());
189 }
190
191 let header = Header {
192 version_major: self.buffer[4],
193 version_minor: self.buffer[5],
194 flags: u16::from_le_bytes([self.buffer[6], self.buffer[7]]),
195 sample_rate: u32::from_le_bytes([
196 self.buffer[8],
197 self.buffer[9],
198 self.buffer[10],
199 self.buffer[11],
200 ]),
201 channels: self.buffer[12],
202 bit_depth: self.buffer[13],
203 total_frames: u64::from_le_bytes([
204 self.buffer[14],
205 self.buffer[15],
206 self.buffer[16],
207 self.buffer[17],
208 self.buffer[18],
209 self.buffer[19],
210 self.buffer[20],
211 self.buffer[21],
212 ]),
213 compression_level: self.buffer[22],
214 data_crc32: u32::from_le_bytes([
215 self.buffer[26],
216 self.buffer[27],
217 self.buffer[28],
218 self.buffer[29],
219 ]),
220 header_size: u64::from_le_bytes([
221 self.buffer[30],
222 self.buffer[31],
223 self.buffer[32],
224 self.buffer[33],
225 self.buffer[34],
226 self.buffer[35],
227 self.buffer[36],
228 self.buffer[37],
229 ]),
230 toc_size: u64::from_le_bytes([
231 self.buffer[38],
232 self.buffer[39],
233 self.buffer[40],
234 self.buffer[41],
235 self.buffer[42],
236 self.buffer[43],
237 self.buffer[44],
238 self.buffer[45],
239 ]),
240 data_size: u64::from_le_bytes([
241 self.buffer[46],
242 self.buffer[47],
243 self.buffer[48],
244 self.buffer[49],
245 self.buffer[50],
246 self.buffer[51],
247 self.buffer[52],
248 self.buffer[53],
249 ]),
250 extra_size: u64::from_le_bytes([
251 self.buffer[54],
252 self.buffer[55],
253 self.buffer[56],
254 self.buffer[57],
255 self.buffer[58],
256 self.buffer[59],
257 self.buffer[60],
258 self.buffer[61],
259 ]),
260 meta_size: u64::from_le_bytes([
261 self.buffer[62],
262 self.buffer[63],
263 self.buffer[64],
264 self.buffer[65],
265 self.buffer[66],
266 self.buffer[67],
267 self.buffer[68],
268 self.buffer[69],
269 ]),
270 };
271
272 self.is_lossy = (header.flags & 0x01) != 0;
273 if self.is_lossy {
274 self.lossy_decoder = Some(TransformDecoder::new(header.sample_rate, header.channels));
275 }
276
277 self.header = Some(header);
278 Ok(true)
279 }
280
281 fn try_parse_toc(&mut self) -> FloResult<bool> {
282 let header = self.header.as_ref().ok_or("No header")?;
283 let toc_start = 70;
284 let toc_end = toc_start + header.toc_size as usize;
285
286 if self.buffer.len() < toc_end {
287 return Ok(false);
288 }
289
290 if header.toc_size >= 4 {
291 let num_entries = u32::from_le_bytes([
292 self.buffer[toc_start],
293 self.buffer[toc_start + 1],
294 self.buffer[toc_start + 2],
295 self.buffer[toc_start + 3],
296 ]) as usize;
297
298 let entries_start = toc_start + 4;
299 for i in 0..num_entries {
300 let offset = entries_start + i * 20;
301 if offset + 20 > self.buffer.len() {
302 return Ok(false);
303 }
304
305 self.toc.push(TocEntry {
306 frame_index: u32::from_le_bytes([
307 self.buffer[offset],
308 self.buffer[offset + 1],
309 self.buffer[offset + 2],
310 self.buffer[offset + 3],
311 ]),
312 byte_offset: u64::from_le_bytes([
313 self.buffer[offset + 4],
314 self.buffer[offset + 5],
315 self.buffer[offset + 6],
316 self.buffer[offset + 7],
317 self.buffer[offset + 8],
318 self.buffer[offset + 9],
319 self.buffer[offset + 10],
320 self.buffer[offset + 11],
321 ]),
322 frame_size: u32::from_le_bytes([
323 self.buffer[offset + 12],
324 self.buffer[offset + 13],
325 self.buffer[offset + 14],
326 self.buffer[offset + 15],
327 ]),
328 timestamp_ms: u32::from_le_bytes([
329 self.buffer[offset + 16],
330 self.buffer[offset + 17],
331 self.buffer[offset + 18],
332 self.buffer[offset + 19],
333 ]),
334 });
335 }
336 }
337
338 self.data_offset = toc_end;
339 Ok(true)
340 }
341
342 fn count_complete_frames(&self) -> usize {
343 let mut count = 0;
344 for entry in &self.toc {
345 let frame_end =
346 self.data_offset + entry.byte_offset as usize + entry.frame_size as usize;
347 if frame_end <= self.buffer.len() {
348 count += 1;
349 } else {
350 break;
351 }
352 }
353 count
354 }
355
356 fn parse_frame(&self, data: &[u8], channels: u8) -> FloResult<Frame> {
357 if data.len() < 6 {
358 return Err("Frame too small".to_string());
359 }
360
361 let frame_type_byte = data[0];
362 let frame_samples = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
363 let flags = data[5];
364
365 let frame_type = FrameType::from(frame_type_byte);
366 let mut frame = Frame::new(frame_type_byte, frame_samples);
367 frame.flags = flags;
368
369 let num_channels = if frame_type == FrameType::Transform {
370 1
371 } else {
372 channels as usize
373 };
374
375 let mut pos = 6;
376 for _ in 0..num_channels {
377 if pos + 4 > data.len() {
378 return Err("Frame truncated".to_string());
379 }
380
381 let ch_size =
382 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]])
383 as usize;
384 pos += 4;
385
386 if pos + ch_size > data.len() {
387 return Err("Channel data truncated".to_string());
388 }
389
390 let ch_data = &data[pos..pos + ch_size];
391 pos += ch_size;
392
393 let channel = match frame_type {
394 FrameType::Silence => ChannelData::new_silence(),
395 FrameType::Raw | FrameType::Transform => ChannelData {
396 predictor_coeffs: vec![],
397 shift_bits: 0,
398 residual_encoding: ResidualEncoding::Raw,
399 rice_parameter: 0,
400 residuals: ch_data.to_vec(),
401 },
402 _ => self.parse_alpc_channel(ch_data, frame_type)?,
403 };
404
405 frame.channels.push(channel);
406 }
407
408 Ok(frame)
409 }
410
411 fn parse_alpc_channel(&self, data: &[u8], _frame_type: FrameType) -> FloResult<ChannelData> {
412 if data.is_empty() {
413 return Ok(ChannelData::new_silence());
414 }
415
416 let order = data[0] as usize;
417 if order > 12 {
418 return Err("Invalid LPC order".to_string());
419 }
420
421 let coeff_bytes = order * 4;
422 let min_size = 1 + coeff_bytes + 2; if data.len() < min_size {
424 return Err("ALPC channel too small".to_string());
425 }
426
427 let mut coefficients = Vec::with_capacity(order);
429 for i in 0..order {
430 let offset = 1 + i * 4;
431 let coeff = i32::from_le_bytes([
432 data[offset],
433 data[offset + 1],
434 data[offset + 2],
435 data[offset + 3],
436 ]);
437 coefficients.push(coeff);
438 }
439
440 let mut pos = 1 + coeff_bytes;
441
442 let shift_bits = data[pos];
444 pos += 1;
445
446 let residual_encoding_byte = data[pos];
448 let residual_encoding = ResidualEncoding::from(residual_encoding_byte);
449 pos += 1;
450
451 let rice_parameter = if residual_encoding == ResidualEncoding::Rice {
453 if pos >= data.len() {
454 return Err("Missing rice parameter".to_string());
455 }
456 let rp = data[pos];
457 pos += 1;
458 rp
459 } else {
460 0
461 };
462
463 let residuals = data[pos..].to_vec();
465
466 Ok(ChannelData {
467 predictor_coeffs: coefficients,
468 shift_bits,
469 residual_encoding,
470 rice_parameter,
471 residuals,
472 })
473 }
474
475 fn decode_frame(&mut self, frame: &Frame, header: &Header) -> FloResult<Vec<f32>> {
476 let frame_type = FrameType::from(frame.frame_type);
477
478 if frame_type == FrameType::Transform {
480 if frame.channels.is_empty() {
481 return Ok(Vec::new());
482 }
483
484 let frame_data = &frame.channels[0].residuals;
485 if let Some(transform_frame) = deserialize_frame(frame_data) {
486 let decoder = self.lossy_decoder.get_or_insert_with(|| {
487 TransformDecoder::new(header.sample_rate, header.channels)
488 });
489 let samples = decoder.decode_frame(&transform_frame);
490
491 if !self.skipped_preroll {
493 self.skipped_preroll = true;
494 return Ok(Vec::new());
495 }
496
497 return Ok(samples);
498 }
499 return Ok(Vec::new());
500 }
501
502 let channels = header.channels as usize;
504 let frame_samples = frame.frame_samples as usize;
505 let use_mid_side = channels == 2 && (frame.flags & 0x01) != 0;
506
507 let mut frame_channels: Vec<Vec<i32>> = Vec::with_capacity(channels);
508
509 for ch_data in &frame.channels {
510 let samples = self.decode_channel_int(ch_data, frame_samples)?;
511 frame_channels.push(samples);
512 }
513
514 let mut all_samples: Vec<Vec<i32>> = vec![vec![]; channels];
516 if use_mid_side && frame_channels.len() == 2 {
517 let (left, right) = self.decode_mid_side(&frame_channels[0], &frame_channels[1]);
518 all_samples[0] = left;
519 all_samples[1] = right;
520 } else {
521 for (ch_idx, samples) in frame_channels.into_iter().enumerate() {
522 if ch_idx < channels {
523 all_samples[ch_idx] = samples;
524 }
525 }
526 }
527
528 let max_len = all_samples.iter().map(|v| v.len()).max().unwrap_or(0);
530 let mut interleaved = Vec::with_capacity(max_len * channels);
531
532 for i in 0..max_len {
533 for ch in 0..channels {
534 let sample = all_samples[ch].get(i).copied().unwrap_or(0);
535 interleaved.push(i32_to_f32(sample));
536 }
537 }
538
539 Ok(interleaved)
540 }
541
542 fn decode_channel_int(
544 &self,
545 ch_data: &ChannelData,
546 frame_samples: usize,
547 ) -> FloResult<Vec<i32>> {
548 let has_coeffs = !ch_data.predictor_coeffs.is_empty();
549 let has_residuals = !ch_data.residuals.is_empty();
550 let shift_bits = ch_data.shift_bits;
551
552 let is_fixed_predictor = !has_coeffs && has_residuals && shift_bits >= 128;
554
555 if is_fixed_predictor {
556 let fixed_order = (shift_bits - 128) as usize;
557 let residuals =
558 rice::decode_i32(&ch_data.residuals, ch_data.rice_parameter, frame_samples);
559 return Ok(self.reconstruct_fixed(fixed_order, &residuals, frame_samples));
560 }
561
562 if has_coeffs {
563 let residuals = match ch_data.residual_encoding {
566 ResidualEncoding::Rice => {
567 rice::decode_i32(&ch_data.residuals, ch_data.rice_parameter, frame_samples)
568 }
569 ResidualEncoding::Raw | ResidualEncoding::Golomb => {
570 let mut res = Vec::with_capacity(frame_samples);
572 for chunk in ch_data.residuals.chunks(2) {
573 if chunk.len() == 2 {
574 res.push(i16::from_le_bytes([chunk[0], chunk[1]]) as i32);
575 }
576 }
577 while res.len() < frame_samples {
578 res.push(0);
579 }
580 res
581 }
582 };
583
584 let order = ch_data.predictor_coeffs.len();
585 let samples = self.reconstruct_lpc_int(
586 &ch_data.predictor_coeffs,
587 &residuals,
588 shift_bits,
589 order,
590 frame_samples,
591 );
592 return Ok(samples);
593 }
594
595 if has_residuals {
596 let mut samples = Vec::with_capacity(frame_samples);
598 for chunk in ch_data.residuals.chunks(2) {
599 if chunk.len() == 2 {
600 samples.push(i16::from_le_bytes([chunk[0], chunk[1]]) as i32);
601 }
602 }
603 while samples.len() < frame_samples {
604 samples.push(0);
605 }
606 return Ok(samples);
607 }
608
609 Ok(vec![0; frame_samples])
611 }
612
613 fn decode_mid_side(&self, mid: &[i32], side: &[i32]) -> (Vec<i32>, Vec<i32>) {
615 let left: Vec<i32> = mid
616 .iter()
617 .zip(side.iter())
618 .map(|(&m, &s)| (m + s) / 2)
619 .collect();
620 let right: Vec<i32> = mid
621 .iter()
622 .zip(side.iter())
623 .map(|(&m, &s)| (m - s) / 2)
624 .collect();
625 (left, right)
626 }
627
628 fn reconstruct_lpc_int(
630 &self,
631 coeffs: &[i32],
632 residuals: &[i32],
633 shift: u8,
634 order: usize,
635 target_len: usize,
636 ) -> Vec<i32> {
637 let mut samples = Vec::with_capacity(target_len);
638
639 for i in 0..order.min(residuals.len()) {
641 samples.push(residuals[i]);
642 }
643
644 for i in order..target_len.min(residuals.len()) {
646 let mut prediction: i64 = 0;
647 for (j, &coeff) in coeffs.iter().enumerate() {
648 if i > j {
649 prediction += (coeff as i64) * (samples[i - j - 1] as i64);
650 }
651 }
652 prediction >>= shift;
653 samples.push(prediction as i32 + residuals[i]);
654 }
655
656 while samples.len() < target_len {
657 samples.push(0);
658 }
659
660 samples
661 }
662
663 fn reconstruct_fixed(&self, order: usize, residuals: &[i32], target_len: usize) -> Vec<i32> {
665 let mut samples = Vec::with_capacity(target_len);
666
667 if residuals.is_empty() {
668 return vec![0; target_len];
669 }
670
671 match order {
672 0 => samples.extend_from_slice(residuals),
673 1 => {
674 samples.push(residuals[0]);
675 for i in 1..residuals.len().min(target_len) {
676 samples.push(residuals[i].wrapping_add(samples[i - 1]));
677 }
678 }
679 2 => {
680 if !residuals.is_empty() {
681 samples.push(residuals[0]);
682 }
683 if residuals.len() > 1 {
684 samples.push(residuals[1].wrapping_add(samples[0]));
685 }
686 for i in 2..residuals.len().min(target_len) {
687 let pred = (2i64 * samples[i - 1] as i64 - samples[i - 2] as i64) as i32;
688 samples.push(residuals[i].wrapping_add(pred));
689 }
690 }
691 3 => {
692 if !residuals.is_empty() {
693 samples.push(residuals[0]);
694 }
695 if residuals.len() > 1 {
696 samples.push(residuals[1].wrapping_add(samples[0]));
697 }
698 if residuals.len() > 2 {
699 let pred = (2i64 * samples[1] as i64 - samples[0] as i64) as i32;
700 samples.push(residuals[2].wrapping_add(pred));
701 }
702 for i in 3..residuals.len().min(target_len) {
703 let pred = (3i64 * samples[i - 1] as i64 - 3i64 * samples[i - 2] as i64
704 + samples[i - 3] as i64) as i32;
705 samples.push(residuals[i].wrapping_add(pred));
706 }
707 }
708 4 => {
709 if !residuals.is_empty() {
710 samples.push(residuals[0]);
711 }
712 if residuals.len() > 1 {
713 samples.push(residuals[1].wrapping_add(samples[0]));
714 }
715 if residuals.len() > 2 {
716 let pred = (2i64 * samples[1] as i64 - samples[0] as i64) as i32;
717 samples.push(residuals[2].wrapping_add(pred));
718 }
719 if residuals.len() > 3 {
720 let pred = (3i64 * samples[2] as i64 - 3i64 * samples[1] as i64
721 + samples[0] as i64) as i32;
722 samples.push(residuals[3].wrapping_add(pred));
723 }
724 for i in 4..residuals.len().min(target_len) {
725 let pred = (4i64 * samples[i - 1] as i64 - 6i64 * samples[i - 2] as i64
726 + 4i64 * samples[i - 3] as i64
727 - samples[i - 4] as i64) as i32;
728 samples.push(residuals[i].wrapping_add(pred));
729 }
730 }
731 _ => samples.extend_from_slice(residuals),
732 }
733
734 while samples.len() < target_len {
735 samples.push(0);
736 }
737
738 samples
739 }
740
741 fn decode_with_standard_decoder(&self) -> FloResult<Vec<f32>> {
742 let reader = Reader::new();
743 let file = reader.read(&self.buffer)?;
744
745 let is_transform = file
746 .frames
747 .iter()
748 .any(|f| f.frame_type == (FrameType::Transform as u8));
749
750 if is_transform {
751 let mut decoder = TransformDecoder::new(file.header.sample_rate, file.header.channels);
752 let mut all_samples = Vec::new();
753 let mut frame_count = 0;
754
755 for frame in &file.frames {
756 if frame.channels.is_empty() {
757 continue;
758 }
759 let frame_data = &frame.channels[0].residuals;
760 if let Some(transform_frame) = deserialize_frame(frame_data) {
761 let samples = decoder.decode_frame(&transform_frame);
762 if frame_count > 0 {
763 all_samples.extend(samples);
764 }
765 frame_count += 1;
766 }
767 }
768 Ok(all_samples)
769 } else {
770 let decoder = LosslessDecoder::new();
771 decoder.decode_file(&file)
772 }
773 }
774}
775
776impl Default for StreamingDecoder {
777 fn default() -> Self {
778 Self::new()
779 }
780}