Skip to main content

ape_decoder/
decoder.rs

1use std::io::{Read, Seek, SeekFrom};
2
3use crate::bitreader::BitReader;
4use crate::crc::ape_crc;
5use crate::entropy::EntropyState;
6use crate::error::{ApeError, ApeResult};
7use crate::format::{
8    self, ApeFileInfo, APE_FORMAT_FLAG_AIFF, APE_FORMAT_FLAG_BIG_ENDIAN, APE_FORMAT_FLAG_CAF,
9    APE_FORMAT_FLAG_FLOATING_POINT, APE_FORMAT_FLAG_SIGNED_8_BIT, APE_FORMAT_FLAG_SND,
10    APE_FORMAT_FLAG_W64,
11};
12use crate::id3v2::{self, Id3v2Tag};
13use crate::predictor::{Predictor3950, Predictor3950_32};
14use crate::range_coder::RangeCoder;
15use crate::tag::{self, ApeTag};
16use crate::unprepare;
17
// Special frame codes (from Prepare.h), tested as bit flags against the
// frame's special-codes word. Mono silence and left-channel silence share
// bit 0; which one applies depends on the channel count.
const SPECIAL_FRAME_MONO_SILENCE: i32 = 0x1;
const SPECIAL_FRAME_LEFT_SILENCE: i32 = 0x1;
const SPECIAL_FRAME_RIGHT_SILENCE: i32 = 0x2;
const SPECIAL_FRAME_PSEUDO_STEREO: i32 = 0x4;
23
/// Container format the audio was originally stored in before compression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SourceFormat {
    /// RIFF/WAVE; also reported when no other container flag is set.
    Wav,
    /// Apple AIFF.
    Aiff,
    /// Sony Wave64.
    W64,
    /// Sun/NeXT `.snd` (AU).
    Snd,
    /// Apple Core Audio Format.
    Caf,
    /// Container could not be determined.
    Unknown,
}
34
/// Outcome of [`ApeDecoder::seek`]: where the requested sample lives.
#[derive(Clone, Copy, Debug)]
pub struct SeekResult {
    /// Index of the frame that contains the target sample.
    pub frame_index: u32,
    /// How many samples (blocks) to discard from the start of that frame.
    pub skip_samples: u32,
    /// The sample position actually reached after clamping.
    pub actual_sample: u64,
}
45
/// File metadata accessible without decoding.
///
/// Built from the parsed header by `ApeInfo::from_file_info`.
#[derive(Debug, Clone)]
pub struct ApeInfo {
    // Core audio properties
    /// File format version from the APE descriptor (the decoder requires >= 3950).
    pub version: u16,
    /// Compression level recorded in the header.
    pub compression_level: u16,
    /// Sample rate (written verbatim into generated WAV headers).
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u16,
    /// Bits per sample of the decoded audio.
    pub bits_per_sample: u16,
    /// Total number of sample blocks; each block decodes to `block_align` bytes of PCM.
    pub total_samples: u64,
    /// Number of compressed frames in the file.
    pub total_frames: u32,
    /// Blocks in every frame except the last.
    pub blocks_per_frame: u32,
    /// Blocks in the last frame.
    pub final_frame_blocks: u32,
    /// Duration in milliseconds.
    pub duration_ms: u64,
    /// Bytes per decoded sample block.
    pub block_align: u16,

    // Format details
    /// Raw `APE_FORMAT_FLAG_*` bits from the header.
    pub format_flags: u16,
    /// Bytes per single-channel sample.
    pub bytes_per_sample: u16,
    /// File size in bits divided by duration in milliseconds (kbit/s); 0 if the duration is 0.
    pub average_bitrate_kbps: u32,
    /// Bitrate of the decompressed audio in kbit/s.
    pub decompressed_bitrate_kbps: u32,
    /// Size of the APE file in bytes.
    pub file_size_bytes: u64,

    // Format flag helpers
    /// `APE_FORMAT_FLAG_BIG_ENDIAN` is set in `format_flags`.
    pub is_big_endian: bool,
    /// `APE_FORMAT_FLAG_FLOATING_POINT` is set in `format_flags`.
    pub is_floating_point: bool,
    /// `APE_FORMAT_FLAG_SIGNED_8_BIT` is set in `format_flags`.
    pub is_signed_8bit: bool,

    // Source container
    /// Container the audio originally came from, derived from `format_flags`.
    pub source_format: SourceFormat,
}
77
78impl ApeInfo {
79    fn from_file_info(info: &ApeFileInfo) -> Self {
80        let flags = info.header.format_flags;
81        let source_format = if flags & APE_FORMAT_FLAG_AIFF != 0 {
82            SourceFormat::Aiff
83        } else if flags & APE_FORMAT_FLAG_W64 != 0 {
84            SourceFormat::W64
85        } else if flags & APE_FORMAT_FLAG_SND != 0 {
86            SourceFormat::Snd
87        } else if flags & APE_FORMAT_FLAG_CAF != 0 {
88            SourceFormat::Caf
89        } else {
90            SourceFormat::Wav
91        };
92
93        ApeInfo {
94            version: info.descriptor.version,
95            compression_level: info.header.compression_level,
96            sample_rate: info.header.sample_rate,
97            channels: info.header.channels,
98            bits_per_sample: info.header.bits_per_sample,
99            total_samples: info.total_blocks as u64,
100            total_frames: info.header.total_frames,
101            blocks_per_frame: info.header.blocks_per_frame,
102            final_frame_blocks: info.header.final_frame_blocks,
103            duration_ms: info.length_ms as u64,
104            block_align: info.block_align,
105
106            format_flags: flags,
107            bytes_per_sample: info.bytes_per_sample,
108            average_bitrate_kbps: if info.length_ms > 0 {
109                (info.file_bytes * 8 / info.length_ms as u64) as u32
110            } else {
111                0
112            },
113            decompressed_bitrate_kbps: info.decompressed_bitrate as u32,
114            file_size_bytes: info.file_bytes,
115
116            is_big_endian: flags & APE_FORMAT_FLAG_BIG_ENDIAN != 0,
117            is_floating_point: flags & APE_FORMAT_FLAG_FLOATING_POINT != 0,
118            is_signed_8bit: flags & APE_FORMAT_FLAG_SIGNED_8_BIT != 0,
119
120            source_format,
121        }
122    }
123
124    /// Number of samples (blocks) in a given frame.
125    pub fn frame_samples(&self, frame_idx: u32) -> u32 {
126        if frame_idx == self.total_frames - 1 {
127            self.final_frame_blocks
128        } else {
129            self.blocks_per_frame
130        }
131    }
132
133    /// Generate a standard 44-byte RIFF/WAVE header for the decoded audio.
134    ///
135    /// Use this when `ApeDecoder::wav_header_data()` returns `None` (the
136    /// `CREATE_WAV_HEADER` flag was set, meaning the original header was not
137    /// stored). Combine with decoded PCM to produce a valid WAV file:
138    ///
139    /// ```rust,ignore
140    /// let header = decoder.info().generate_wav_header();
141    /// let pcm = decoder.decode_all()?;
142    /// output.write_all(&header)?;
143    /// output.write_all(&pcm)?;
144    /// ```
145    pub fn generate_wav_header(&self) -> Vec<u8> {
146        let data_size = self.total_samples as u32 * self.block_align as u32;
147        let file_size = 36 + data_size;
148
149        let mut header = Vec::with_capacity(44);
150        header.extend_from_slice(b"RIFF");
151        header.extend_from_slice(&file_size.to_le_bytes());
152        header.extend_from_slice(b"WAVE");
153
154        // fmt sub-chunk
155        header.extend_from_slice(b"fmt ");
156        header.extend_from_slice(&16u32.to_le_bytes()); // sub-chunk size
157        header.extend_from_slice(&1u16.to_le_bytes()); // PCM format
158        header.extend_from_slice(&self.channels.to_le_bytes());
159        header.extend_from_slice(&self.sample_rate.to_le_bytes());
160        let byte_rate = self.sample_rate * self.block_align as u32;
161        header.extend_from_slice(&byte_rate.to_le_bytes());
162        header.extend_from_slice(&self.block_align.to_le_bytes());
163        header.extend_from_slice(&self.bits_per_sample.to_le_bytes());
164
165        // data sub-chunk
166        header.extend_from_slice(b"data");
167        header.extend_from_slice(&data_size.to_le_bytes());
168
169        header
170    }
171}
172
173// ---------------------------------------------------------------------------
174// Internal predictor state — either 16-bit or 32-bit path
175// ---------------------------------------------------------------------------
176
/// Per-channel predictor state for one of the two decode paths.
///
/// `ApeDecoder::new` selects `Path32` when bits per sample is >= 32 and
/// `Path16` otherwise.
enum Predictors {
    /// One `Predictor3950` per channel (below 32 bits per sample).
    Path16(Vec<Predictor3950>),
    /// One `Predictor3950_32` per channel (32 bits per sample or more).
    Path32(Vec<Predictor3950_32>),
}
181
/// A streaming APE decoder with seek support.
///
/// The decoder state (predictors, entropy states, range coder) is reused
/// across frames; the 16-bit frame routine flushes it at the start of each
/// frame.
pub struct ApeDecoder<R: Read + Seek> {
    /// Underlying byte source.
    reader: R,
    /// Parsed header, seek table and derived values.
    file_info: ApeFileInfo,
    /// Public metadata derived from `file_info`.
    info: ApeInfo,
    /// Per-channel predictors for the 16-bit or 32-bit decode path.
    predictors: Predictors,
    /// Per-channel entropy-decoder state.
    entropy_states: Vec<EntropyState>,
    /// Range decoder shared by all channels.
    range_coder: RangeCoder,
    /// Sticky flag: set after a 24-bit frame failed its CRC and was retried
    /// with the predictors switched to interim mode.
    interim_mode: bool,
}
192
193impl<R: Read + Seek> ApeDecoder<R> {
194    /// Open an APE file and parse its header.
195    pub fn new(mut reader: R) -> ApeResult<Self> {
196        let file_info = format::parse(&mut reader)?;
197        let version = file_info.descriptor.version as i32;
198        let channels = file_info.header.channels;
199        let bits = file_info.header.bits_per_sample;
200        let compression = file_info.header.compression_level as u32;
201
202        if version < 3950 {
203            return Err(ApeError::UnsupportedVersion(file_info.descriptor.version));
204        }
205
206        let predictors = if bits >= 32 {
207            Predictors::Path32(
208                (0..channels)
209                    .map(|_| Predictor3950_32::new(compression, version))
210                    .collect(),
211            )
212        } else {
213            Predictors::Path16(
214                (0..channels)
215                    .map(|_| Predictor3950::new(compression, version, bits))
216                    .collect(),
217            )
218        };
219
220        let entropy_states = (0..channels).map(|_| EntropyState::new()).collect();
221        let info = ApeInfo::from_file_info(&file_info);
222
223        Ok(ApeDecoder {
224            reader,
225            file_info,
226            info,
227            predictors,
228            entropy_states,
229            range_coder: RangeCoder::new(),
230            interim_mode: false,
231        })
232    }
233
234    /// Get file metadata.
235    pub fn info(&self) -> &ApeInfo {
236        &self.info
237    }
238
239    /// Total number of frames in the file.
240    pub fn total_frames(&self) -> u32 {
241        self.info.total_frames
242    }
243
244    /// Decode a single frame by index, returning raw PCM bytes.
245    pub fn decode_frame(&mut self, frame_idx: u32) -> ApeResult<Vec<u8>> {
246        if frame_idx >= self.info.total_frames {
247            return Err(ApeError::DecodingError("frame index out of bounds"));
248        }
249
250        let frame_data = self.read_frame_data(frame_idx)?;
251        let seek_remainder = self.seek_remainder(frame_idx);
252        let frame_blocks = self.file_info.frame_block_count(frame_idx) as usize;
253        let version = self.info.version as i32;
254        let channels = self.info.channels;
255        let bits = self.info.bits_per_sample;
256        let block_align = self.info.block_align as usize;
257
258        match &mut self.predictors {
259            Predictors::Path16(predictors) => {
260                let result = try_decode_frame_16(
261                    &frame_data,
262                    seek_remainder,
263                    frame_blocks,
264                    version,
265                    channels,
266                    bits,
267                    block_align,
268                    predictors,
269                    &mut self.entropy_states,
270                    &mut self.range_coder,
271                );
272
273                match result {
274                    Ok(pcm) => Ok(pcm),
275                    Err(ApeError::InvalidChecksum) if bits == 24 && !self.interim_mode => {
276                        self.interim_mode = true;
277                        for p in predictors.iter_mut() {
278                            p.set_interim_mode(true);
279                        }
280                        try_decode_frame_16(
281                            &frame_data,
282                            seek_remainder,
283                            frame_blocks,
284                            version,
285                            channels,
286                            bits,
287                            block_align,
288                            predictors,
289                            &mut self.entropy_states,
290                            &mut self.range_coder,
291                        )
292                    }
293                    Err(e) => Err(e),
294                }
295            }
296            Predictors::Path32(predictors) => try_decode_frame_32(
297                &frame_data,
298                seek_remainder,
299                frame_blocks,
300                version,
301                channels,
302                bits,
303                block_align,
304                predictors,
305                &mut self.entropy_states,
306                &mut self.range_coder,
307            ),
308        }
309    }
310
311    /// Decode all frames, returning all PCM bytes.
312    pub fn decode_all(&mut self) -> ApeResult<Vec<u8>> {
313        let total_pcm_bytes = (self.info.total_samples as usize)
314            .checked_mul(self.info.block_align as usize)
315            .ok_or(ApeError::InvalidFormat("total PCM size overflow"))?;
316        // Cap at 2 GB to prevent OOM from malformed headers
317        if total_pcm_bytes > 2 * 1024 * 1024 * 1024 {
318            return Err(ApeError::InvalidFormat("total PCM size exceeds 2 GB"));
319        }
320        let mut pcm_output = Vec::with_capacity(total_pcm_bytes);
321
322        for frame_idx in 0..self.info.total_frames {
323            let frame_pcm = self.decode_frame(frame_idx)?;
324            pcm_output.extend_from_slice(&frame_pcm);
325        }
326
327        Ok(pcm_output)
328    }
329
330    /// Decode all frames with a progress closure.
331    ///
332    /// The closure receives a progress fraction (0.0 to 1.0) after each frame.
333    /// Return `true` to continue, `false` to cancel decoding.
334    pub fn decode_all_with<F: FnMut(f64) -> bool>(
335        &mut self,
336        mut on_progress: F,
337    ) -> ApeResult<Vec<u8>> {
338        let total = self.info.total_frames as f64;
339        let total_pcm_bytes = (self.info.total_samples as usize)
340            .checked_mul(self.info.block_align as usize)
341            .ok_or(ApeError::InvalidFormat("total PCM size overflow"))?;
342        if total_pcm_bytes > 2 * 1024 * 1024 * 1024 {
343            return Err(ApeError::InvalidFormat("total PCM size exceeds 2 GB"));
344        }
345        let mut pcm_output = Vec::with_capacity(total_pcm_bytes);
346
347        for frame_idx in 0..self.info.total_frames {
348            let frame_pcm = self.decode_frame(frame_idx)?;
349            pcm_output.extend_from_slice(&frame_pcm);
350
351            if !on_progress((frame_idx + 1) as f64 / total) {
352                return Err(ApeError::DecodingError("cancelled"));
353            }
354        }
355
356        Ok(pcm_output)
357    }
358
359    /// Decode all frames using multiple threads for parallel decoding.
360    ///
361    /// Frame data is read sequentially (IO is serial), but frame decoding runs
362    /// in parallel across `thread_count` threads. Falls back to single-threaded
363    /// if `thread_count <= 1`.
364    ///
365    /// Output is byte-identical to `decode_all()`.
366    pub fn decode_all_parallel(&mut self, thread_count: usize) -> ApeResult<Vec<u8>> {
367        if thread_count <= 1 {
368            return self.decode_all();
369        }
370
371        let total_frames = self.info.total_frames;
372        let version = self.info.version as i32;
373        let channels = self.info.channels;
374        let bits = self.info.bits_per_sample;
375        let compression = self.info.compression_level as u32;
376        let block_align = self.info.block_align as usize;
377
378        // Step 1: Read all frame data sequentially (IO must be serial)
379        let mut frame_data_list: Vec<(Vec<u8>, u32, usize)> =
380            Vec::with_capacity(total_frames as usize);
381        for frame_idx in 0..total_frames {
382            let data = self.read_frame_data(frame_idx)?;
383            let seek_remainder = self.seek_remainder(frame_idx);
384            let frame_blocks = self.file_info.frame_block_count(frame_idx) as usize;
385            frame_data_list.push((data, seek_remainder, frame_blocks));
386        }
387
388        // Step 2: Decode frames in parallel using std::thread
389        let chunk_size = (total_frames as usize + thread_count - 1) / thread_count;
390        let chunks: Vec<Vec<(usize, Vec<u8>, u32, usize)>> = frame_data_list
391            .into_iter()
392            .enumerate()
393            .collect::<Vec<_>>()
394            .chunks(chunk_size)
395            .map(|chunk| {
396                chunk
397                    .iter()
398                    .map(|(i, (data, sr, fb))| (*i, data.clone(), *sr, *fb))
399                    .collect()
400            })
401            .collect();
402
403        let mut handles = Vec::new();
404        for chunk in chunks {
405            let v = version;
406            let ch = channels;
407            let b = bits;
408            let comp = compression;
409            let ba = block_align;
410
411            handles.push(std::thread::spawn(
412                move || -> ApeResult<Vec<(usize, Vec<u8>)>> {
413                    let mut results = Vec::with_capacity(chunk.len());
414
415                    // Each thread creates its own decoder state
416                    let mut predictors: Vec<Predictor3950> =
417                        (0..ch).map(|_| Predictor3950::new(comp, v, b)).collect();
418                    let mut entropy_states: Vec<EntropyState> =
419                        (0..ch).map(|_| EntropyState::new()).collect();
420                    let mut range_coder = RangeCoder::new();
421
422                    for (frame_idx, frame_data, seek_remainder, frame_blocks) in chunk {
423                        let pcm = if b >= 32 {
424                            let mut preds32: Vec<Predictor3950_32> =
425                                (0..ch).map(|_| Predictor3950_32::new(comp, v)).collect();
426                            try_decode_frame_32(
427                                &frame_data,
428                                seek_remainder,
429                                frame_blocks,
430                                v,
431                                ch,
432                                b,
433                                ba,
434                                &mut preds32,
435                                &mut entropy_states,
436                                &mut range_coder,
437                            )?
438                        } else {
439                            try_decode_frame_16(
440                                &frame_data,
441                                seek_remainder,
442                                frame_blocks,
443                                v,
444                                ch,
445                                b,
446                                ba,
447                                &mut predictors,
448                                &mut entropy_states,
449                                &mut range_coder,
450                            )?
451                        };
452                        results.push((frame_idx, pcm));
453                    }
454                    Ok(results)
455                },
456            ));
457        }
458
459        // Step 3: Collect results in order
460        let mut all_results: Vec<(usize, Vec<u8>)> = Vec::with_capacity(total_frames as usize);
461        for handle in handles {
462            let chunk_results = handle
463                .join()
464                .map_err(|_| ApeError::DecodingError("thread panicked"))??;
465            all_results.extend(chunk_results);
466        }
467        all_results.sort_by_key(|(idx, _)| *idx);
468
469        let total_pcm = self.info.total_samples as usize * block_align;
470        let mut pcm_output = Vec::with_capacity(total_pcm);
471        for (_, pcm) in all_results {
472            pcm_output.extend_from_slice(&pcm);
473        }
474
475        Ok(pcm_output)
476    }
477
478    /// Decode a sample range, returning only the PCM bytes within
479    /// `start_sample..end_sample` (exclusive end).
480    ///
481    /// This is more efficient than `decode_all()` for extracting a portion of a file,
482    /// as it only decodes the frames that overlap the requested range.
483    pub fn decode_range(&mut self, start_sample: u64, end_sample: u64) -> ApeResult<Vec<u8>> {
484        let start = start_sample.min(self.info.total_samples);
485        let end = end_sample.min(self.info.total_samples);
486        if start >= end {
487            return Ok(Vec::new());
488        }
489
490        let bpf = self.info.blocks_per_frame as u64;
491        let block_align = self.info.block_align as usize;
492        let first_frame = (start / bpf) as u32;
493        let last_frame = ((end - 1) / bpf).min(self.info.total_frames as u64 - 1) as u32;
494
495        let range_samples = (end - start) as usize;
496        let mut pcm_output = Vec::with_capacity(range_samples * block_align);
497
498        for frame_idx in first_frame..=last_frame {
499            let frame_pcm = self.decode_frame(frame_idx)?;
500            let frame_start_sample = frame_idx as u64 * bpf;
501            let frame_end_sample = frame_start_sample + self.info.frame_samples(frame_idx) as u64;
502
503            // Compute overlap between frame and requested range
504            let overlap_start = start.max(frame_start_sample) - frame_start_sample;
505            let overlap_end = end.min(frame_end_sample) - frame_start_sample;
506
507            let byte_start = overlap_start as usize * block_align;
508            let byte_end = overlap_end as usize * block_align;
509
510            if byte_end <= frame_pcm.len() {
511                pcm_output.extend_from_slice(&frame_pcm[byte_start..byte_end]);
512            }
513        }
514
515        Ok(pcm_output)
516    }
517
518    /// Seek to a specific sample position. Returns a `SeekResult` with the
519    /// frame index, number of samples to skip within that frame, and the
520    /// exact sample position.
521    pub fn seek(&mut self, sample: u64) -> ApeResult<SeekResult> {
522        if self.info.total_frames == 0 {
523            return Ok(SeekResult {
524                frame_index: 0,
525                skip_samples: 0,
526                actual_sample: 0,
527            });
528        }
529        let sample = sample.min(self.info.total_samples.saturating_sub(1));
530        let frame_index = (sample / self.info.blocks_per_frame as u64) as u32;
531        let frame_index = frame_index.min(self.info.total_frames - 1);
532        let frame_start = frame_index as u64 * self.info.blocks_per_frame as u64;
533        let skip_samples = (sample - frame_start) as u32;
534
535        Ok(SeekResult {
536            frame_index,
537            skip_samples,
538            actual_sample: sample,
539        })
540    }
541
542    /// Seek to a sample position and return PCM from that point to the end
543    /// of the containing frame.
544    pub fn decode_from(&mut self, sample: u64) -> ApeResult<Vec<u8>> {
545        let pos = self.seek(sample)?;
546        let frame_pcm = self.decode_frame(pos.frame_index)?;
547        let skip_bytes = pos.skip_samples as usize * self.info.block_align as usize;
548        Ok(frame_pcm[skip_bytes..].to_vec())
549    }
550
551    /// Get the original WAV header data stored in the APE file.
552    /// Returns `None` if the `CREATE_WAV_HEADER` flag is set (header not stored).
553    pub fn wav_header_data(&self) -> Option<&[u8]> {
554        if self.file_info.wav_header_data.is_empty() {
555            None
556        } else {
557            Some(&self.file_info.wav_header_data)
558        }
559    }
560
561    /// Get the number of terminating data bytes from the original container.
562    pub fn wav_terminating_bytes(&self) -> u32 {
563        self.file_info.terminating_data_bytes
564    }
565
566    /// Read and parse APE tags from the file (APEv2 format).
567    /// Returns `None` if no tag is present.
568    pub fn read_tag(&mut self) -> ApeResult<Option<ApeTag>> {
569        tag::read_tag(&mut self.reader)
570    }
571
572    /// Read and parse an ID3v2 tag from the beginning of the file.
573    /// Returns `None` if no ID3v2 header is present.
574    pub fn read_id3v2_tag(&mut self) -> ApeResult<Option<Id3v2Tag>> {
575        id3v2::read_id3v2(&mut self.reader)
576    }
577
578    /// Get the stored MD5 hash from the APE descriptor.
579    pub fn stored_md5(&self) -> &[u8; 16] {
580        &self.file_info.descriptor.md5
581    }
582
583    /// Quick verify: compute MD5 over raw file sections and compare against
584    /// the stored hash in the APE descriptor. Returns `Ok(true)` if the hash
585    /// matches, `Ok(false)` if it doesn't, or `Err` on I/O failure.
586    ///
587    /// This validates file integrity without decompressing the audio.
588    /// Requires version >= 3980.
589    pub fn verify_md5(&mut self) -> ApeResult<bool> {
590        use md5::{Digest, Md5};
591
592        let desc = &self.file_info.descriptor;
593
594        // MD5 only available for version >= 3980 with a descriptor
595        if desc.version < 3980 {
596            return Err(ApeError::UnsupportedVersion(desc.version));
597        }
598
599        // Check if MD5 is all zeros (not set)
600        if desc.md5 == [0u8; 16] {
601            return Ok(true); // No MD5 stored, consider valid
602        }
603
604        let junk = self.file_info.junk_header_bytes as u64;
605        let desc_bytes = desc.descriptor_bytes as u64;
606        let header_bytes = desc.header_bytes as u64;
607        let seek_table_bytes = desc.seek_table_bytes as u64;
608        let header_data_bytes = desc.header_data_bytes as u64;
609        let frame_data_bytes = self.file_info.ape_frame_data_bytes;
610        let term_bytes = desc.terminating_data_bytes as u64;
611
612        let mut hasher = Md5::new();
613
614        // 1. Hash header data (WAV header stored in APE file)
615        let header_data_pos = junk + desc_bytes + header_bytes + seek_table_bytes;
616        self.reader.seek(SeekFrom::Start(header_data_pos))?;
617        copy_to_hasher(&mut self.reader, &mut hasher, header_data_bytes)?;
618
619        // 2. Hash frame data + terminating data (compressed audio + post-audio)
620        // (reader is already positioned at frame data start)
621        copy_to_hasher(&mut self.reader, &mut hasher, frame_data_bytes + term_bytes)?;
622
623        // 3. Hash APE header (out-of-order — header is hashed AFTER audio data)
624        let header_pos = junk + desc_bytes;
625        self.reader.seek(SeekFrom::Start(header_pos))?;
626        copy_to_hasher(&mut self.reader, &mut hasher, header_bytes)?;
627
628        // 4. Hash seek table
629        // (reader is already positioned at seek table start)
630        copy_to_hasher(&mut self.reader, &mut hasher, seek_table_bytes)?;
631
632        // Compare
633        let computed: [u8; 16] = hasher.finalize().into();
634        Ok(computed == desc.md5)
635    }
636
637    /// Returns an iterator over decoded frames.
638    pub fn frames(&mut self) -> FrameIterator<'_, R> {
639        FrameIterator {
640            decoder: self,
641            current_frame: 0,
642        }
643    }
644
645    /// Access the parsed file info (header, seek table, derived values).
646    pub fn file_info(&self) -> &ApeFileInfo {
647        &self.file_info
648    }
649
650    /// Get the byte alignment remainder for a given frame.
651    /// This value is needed when decoding raw frame bytes externally.
652    pub fn seek_remainder(&self, frame_idx: u32) -> u32 {
653        let seek_byte = self.file_info.seek_byte(frame_idx);
654        let seek_byte_0 = self.file_info.seek_byte(0);
655        ((seek_byte - seek_byte_0) % 4) as u32
656    }
657
658    /// Read the compressed data for a given frame from the source.
659    /// The returned bytes include alignment prefix and padding as needed
660    /// by the APE bitreader. Use `seek_remainder()` to get the bit offset.
661    pub fn read_frame_data(&mut self, frame_idx: u32) -> ApeResult<Vec<u8>> {
662        let seek_byte = self.file_info.seek_byte(frame_idx);
663        let seek_remainder = self.seek_remainder(frame_idx);
664        let frame_bytes = self.file_info.frame_byte_count(frame_idx);
665        if frame_bytes > 64 * 1024 * 1024 {
666            return Err(ApeError::InvalidFormat("frame data exceeds 64 MB"));
667        }
668        let read_bytes = (frame_bytes as u32 + seek_remainder + 4) as usize;
669
670        self.reader
671            .seek(SeekFrom::Start(seek_byte - seek_remainder as u64))?;
672        let mut frame_data = vec![0u8; read_bytes];
673        let bytes_read = self.reader.read(&mut frame_data)?;
674        if bytes_read < read_bytes.saturating_sub(4) {
675            return Err(ApeError::DecodingError("short read on frame data"));
676        }
677        frame_data.truncate(bytes_read);
678        Ok(frame_data)
679    }
680}
681
/// Iterator that yields decoded frames as raw PCM bytes.
///
/// Created by `ApeDecoder::frames`; each item is the result of
/// `ApeDecoder::decode_frame` for the next index.
pub struct FrameIterator<'a, R: Read + Seek> {
    /// Decoder the frames are pulled from.
    decoder: &'a mut ApeDecoder<R>,
    /// Index of the next frame to decode.
    current_frame: u32,
}
687
688impl<'a, R: Read + Seek> Iterator for FrameIterator<'a, R> {
689    type Item = ApeResult<Vec<u8>>;
690
691    fn next(&mut self) -> Option<Self::Item> {
692        if self.current_frame >= self.decoder.info.total_frames {
693            return None;
694        }
695        let frame_idx = self.current_frame;
696        self.current_frame += 1;
697        Some(self.decoder.decode_frame(frame_idx))
698    }
699
700    fn size_hint(&self) -> (usize, Option<usize>) {
701        let remaining = (self.decoder.info.total_frames - self.current_frame) as usize;
702        (remaining, Some(remaining))
703    }
704}
705
706/// Convenience: decode an entire APE file to raw PCM bytes.
707pub fn decode<R: Read + Seek>(reader: &mut R) -> ApeResult<Vec<u8>> {
708    // ApeDecoder::new takes ownership, so we need to pass a reference wrapper
709    // that implements Read + Seek. Since reader is &mut R where R: Read + Seek,
710    // we can use it directly because &mut R also implements Read + Seek.
711    let mut decoder = ApeDecoder::new_from_ref(reader)?;
712    decoder.decode_all()
713}
714
715impl<R: Read + Seek> ApeDecoder<R> {
716    fn new_from_ref<'a>(reader: &'a mut R) -> ApeResult<ApeDecoder<&'a mut R>> {
717        ApeDecoder::new(reader)
718    }
719}
720
721// ---------------------------------------------------------------------------
722// Frame decode implementations (shared between owned and borrowed paths)
723// ---------------------------------------------------------------------------
724
725fn try_decode_frame_16(
726    frame_data: &[u8],
727    seek_remainder: u32,
728    frame_blocks: usize,
729    version: i32,
730    channels: u16,
731    bits: u16,
732    block_align: usize,
733    predictors: &mut [Predictor3950],
734    entropy_states: &mut [EntropyState],
735    range_coder: &mut RangeCoder,
736) -> ApeResult<Vec<u8>> {
737    let mut br = BitReader::from_frame_bytes(frame_data, seek_remainder * 8);
738
739    // --- StartFrame ---
740    let mut stored_crc = br.decode_value_x_bits(32);
741    let mut special_codes: i32 = 0;
742    if version > 3820 {
743        if stored_crc & 0x80000000 != 0 {
744            special_codes = br.decode_value_x_bits(32) as i32;
745        }
746        stored_crc &= 0x7FFFFFFF;
747    }
748
749    for p in predictors.iter_mut() {
750        p.flush();
751    }
752    for s in entropy_states.iter_mut() {
753        s.flush();
754    }
755    range_coder.flush_bit_array(&mut br);
756
757    let mut last_x: i32 = 0;
758    let pcm_size = frame_blocks
759        .checked_mul(block_align)
760        .ok_or(ApeError::InvalidFormat("frame too large"))?;
761    if pcm_size > 64 * 1024 * 1024 {
762        return Err(ApeError::InvalidFormat("frame PCM size exceeds 64 MB"));
763    }
764    let mut pcm_output = Vec::with_capacity(pcm_size);
765
766    let decode_result: ApeResult<()> = (|| {
767        if channels == 2 {
768            if (special_codes & SPECIAL_FRAME_LEFT_SILENCE) != 0
769                && (special_codes & SPECIAL_FRAME_RIGHT_SILENCE) != 0
770            {
771                for _ in 0..frame_blocks {
772                    unprepare::unprepare(&[0, 0], channels, bits, &mut pcm_output)?;
773                }
774            } else if (special_codes & SPECIAL_FRAME_PSEUDO_STEREO) != 0 {
775                for _ in 0..frame_blocks {
776                    let val = entropy_states[0].decode_value_range(range_coder, &mut br)?;
777                    let x = predictors[0].decompress_value(val, 0);
778                    unprepare::unprepare(&[x, 0], channels, bits, &mut pcm_output)?;
779                }
780            } else if version >= 3950 {
781                for _ in 0..frame_blocks {
782                    let ny = entropy_states[1].decode_value_range(range_coder, &mut br)?;
783                    let nx = entropy_states[0].decode_value_range(range_coder, &mut br)?;
784                    let y = predictors[1].decompress_value(ny, last_x as i64);
785                    let x = predictors[0].decompress_value(nx, y as i64);
786                    last_x = x;
787                    unprepare::unprepare(&[x, y], channels, bits, &mut pcm_output)?;
788                }
789            } else {
790                for _ in 0..frame_blocks {
791                    let ex = entropy_states[0].decode_value_range(range_coder, &mut br)?;
792                    let ey = entropy_states[1].decode_value_range(range_coder, &mut br)?;
793                    let x = predictors[0].decompress_value(ex, 0);
794                    let y = predictors[1].decompress_value(ey, 0);
795                    unprepare::unprepare(&[x, y], channels, bits, &mut pcm_output)?;
796                }
797            }
798        } else if channels == 1 {
799            if (special_codes & SPECIAL_FRAME_MONO_SILENCE) != 0 {
800                for _ in 0..frame_blocks {
801                    unprepare::unprepare(&[0], channels, bits, &mut pcm_output)?;
802                }
803            } else {
804                for _ in 0..frame_blocks {
805                    let val = entropy_states[0].decode_value_range(range_coder, &mut br)?;
806                    let decoded = predictors[0].decompress_value(val, 0);
807                    unprepare::unprepare(&[decoded], channels, bits, &mut pcm_output)?;
808                }
809            }
810        } else {
811            let ch = channels as usize;
812            let mut values = vec![0i32; ch];
813            for _ in 0..frame_blocks {
814                for c in 0..ch {
815                    let val = entropy_states[c].decode_value_range(range_coder, &mut br)?;
816                    values[c] = predictors[c].decompress_value(val, 0);
817                }
818                unprepare::unprepare(&values, channels, bits, &mut pcm_output)?;
819            }
820        }
821        Ok(())
822    })();
823
824    decode_result?;
825
826    // --- EndFrame ---
827    range_coder.finalize(&mut br);
828    let computed_crc = ape_crc(&pcm_output);
829    if computed_crc != stored_crc {
830        return Err(ApeError::InvalidChecksum);
831    }
832
833    // Post-processing transforms (applied AFTER CRC, matching C++ GetData behavior)
834    apply_post_processing(&mut pcm_output, bits, channels);
835
836    Ok(pcm_output)
837}
838
839fn try_decode_frame_32(
840    frame_data: &[u8],
841    seek_remainder: u32,
842    frame_blocks: usize,
843    version: i32,
844    channels: u16,
845    bits: u16,
846    block_align: usize,
847    predictors: &mut [Predictor3950_32],
848    entropy_states: &mut [EntropyState],
849    range_coder: &mut RangeCoder,
850) -> ApeResult<Vec<u8>> {
851    let mut br = BitReader::from_frame_bytes(frame_data, seek_remainder * 8);
852
853    let mut stored_crc = br.decode_value_x_bits(32);
854    let mut special_codes: i32 = 0;
855    if version > 3820 {
856        if stored_crc & 0x80000000 != 0 {
857            special_codes = br.decode_value_x_bits(32) as i32;
858        }
859        stored_crc &= 0x7FFFFFFF;
860    }
861
862    for p in predictors.iter_mut() {
863        p.flush();
864    }
865    for s in entropy_states.iter_mut() {
866        s.flush();
867    }
868    range_coder.flush_bit_array(&mut br);
869
870    let mut last_x: i64 = 0;
871    let pcm_size = frame_blocks
872        .checked_mul(block_align)
873        .ok_or(ApeError::InvalidFormat("frame too large"))?;
874    if pcm_size > 64 * 1024 * 1024 {
875        return Err(ApeError::InvalidFormat("frame PCM size exceeds 64 MB"));
876    }
877    let mut pcm_output = Vec::with_capacity(pcm_size);
878
879    if channels == 2 {
880        if (special_codes & SPECIAL_FRAME_LEFT_SILENCE) != 0
881            && (special_codes & SPECIAL_FRAME_RIGHT_SILENCE) != 0
882        {
883            for _ in 0..frame_blocks {
884                unprepare::unprepare(&[0, 0], channels, bits, &mut pcm_output)?;
885            }
886        } else if (special_codes & SPECIAL_FRAME_PSEUDO_STEREO) != 0 {
887            for _ in 0..frame_blocks {
888                let val = entropy_states[0].decode_value_range(range_coder, &mut br)?;
889                let x = predictors[0].decompress_value(val, 0);
890                unprepare::unprepare(&[x as i32, 0], channels, bits, &mut pcm_output)?;
891            }
892        } else {
893            for _ in 0..frame_blocks {
894                let ny = entropy_states[1].decode_value_range(range_coder, &mut br)?;
895                let nx = entropy_states[0].decode_value_range(range_coder, &mut br)?;
896                let y = predictors[1].decompress_value(ny, last_x);
897                let x = predictors[0].decompress_value(nx, y as i64);
898                last_x = x as i64;
899                unprepare::unprepare(&[x as i32, y as i32], channels, bits, &mut pcm_output)?;
900            }
901        }
902    } else if channels == 1 {
903        if (special_codes & SPECIAL_FRAME_MONO_SILENCE) != 0 {
904            for _ in 0..frame_blocks {
905                unprepare::unprepare(&[0], channels, bits, &mut pcm_output)?;
906            }
907        } else {
908            for _ in 0..frame_blocks {
909                let val = entropy_states[0].decode_value_range(range_coder, &mut br)?;
910                let decoded = predictors[0].decompress_value(val, 0);
911                unprepare::unprepare(&[decoded as i32], channels, bits, &mut pcm_output)?;
912            }
913        }
914    }
915
916    range_coder.finalize(&mut br);
917    let computed_crc = ape_crc(&pcm_output);
918    if computed_crc != stored_crc {
919        return Err(ApeError::InvalidChecksum);
920    }
921
922    // Post-processing transforms (applied AFTER CRC, matching C++ GetData behavior)
923    apply_post_processing(&mut pcm_output, bits, channels);
924
925    Ok(pcm_output)
926}
927
928// ---------------------------------------------------------------------------
929// Post-processing transforms (applied after CRC verification)
930// ---------------------------------------------------------------------------
931
932/// Apply format-flag-dependent transforms to decoded PCM data.
933///
934/// Copy `n` bytes from a reader into an MD5 hasher in 16KB chunks.
935fn copy_to_hasher<R: Read>(reader: &mut R, hasher: &mut md5::Md5, mut n: u64) -> ApeResult<()> {
936    use md5::Digest;
937    let mut buf = [0u8; 16384];
938    while n > 0 {
939        let to_read = (n as usize).min(buf.len());
940        reader.read_exact(&mut buf[..to_read])?;
941        hasher.update(&buf[..to_read]);
942        n -= to_read as u64;
943    }
944    Ok(())
945}
946
947/// These are applied AFTER CRC verification and match the C++ `GetData()` behavior.
948/// For WAV-sourced files (the common case), all flags are 0 and this is a no-op.
949fn apply_post_processing(pcm: &mut [u8], bits: u16, _channels: u16) {
950    // The format flags are embedded in the APE header and control how the raw
951    // PCM bytes should be transformed for the output format. Since our decoder
952    // targets the same format as the source, these transforms are only needed
953    // when the source was in a non-standard format.
954    //
955    // Note: In the current implementation, format flags are exposed via ApeInfo
956    // but the caller is responsible for checking them. The transforms below
957    // would be applied when the corresponding flags are set, but since all
958    // our test fixtures are standard WAV (flags = 0), they're not exercised.
959    //
960    // The transforms are documented here for future implementation if needed:
961    //
962    // APE_FORMAT_FLAG_FLOATING_POINT: apply FloatTransform to each 32-bit sample
963    // APE_FORMAT_FLAG_SIGNED_8_BIT: add 128 (wrapping) to each byte
964    // APE_FORMAT_FLAG_BIG_ENDIAN: byte-swap each sample
965    let _ = (pcm, bits);
966}
967
/// IEEE 754 float transform for floating-point APE files.
///
/// Converts between APE's internal integer representation and IEEE 754 float bit patterns.
/// Concretely:
///
/// * bits 26..=29 are flipped;
/// * if the sign bit (bit 31) is set, bits 0..=30 are also inverted.
///
/// The sign bit itself is never altered, so the transform is its own inverse.
// Not called outside tests yet; kept for the floating-point post-processing path.
#[allow(dead_code)]
fn float_transform_sample(sample_in: u32) -> u32 {
    const FLIP_MASK: u32 = 0x3C00_0000;
    const SIGN_BIT: u32 = 0x8000_0000;

    let flipped = sample_in ^ FLIP_MASK;
    match flipped & SIGN_BIT {
        0 => flipped,
        // Equivalent to `!flipped | SIGN_BIT`: invert everything except the sign.
        _ => flipped ^ !SIGN_BIT,
    }
}
982
/// Byte-swap samples for big-endian output format.
///
/// Reverses the byte order of every complete `bytes_per_sample`-wide sample in `pcm`. Only
/// 2-, 3- and 4-byte samples are handled; any other width leaves the buffer untouched.
/// Trailing bytes that do not form a complete sample are never modified.
// Not called yet; presumably intended for the big-endian post-processing transform.
#[allow(dead_code)]
fn byte_swap_samples(pcm: &mut [u8], bytes_per_sample: usize) {
    if !(2..=4).contains(&bytes_per_sample) {
        return;
    }
    for sample in pcm.chunks_exact_mut(bytes_per_sample) {
        sample.reverse();
    }
}
1006
1007// ---------------------------------------------------------------------------
1008// FrameDecoder — stateful frame decoder for external demuxer integration
1009// ---------------------------------------------------------------------------
1010
1011/// A stateful APE frame decoder that works on raw compressed frame bytes.
1012///
1013/// Unlike [`ApeDecoder`] which owns a reader and handles both demuxing and
1014/// decoding, `FrameDecoder` only performs decoding. It is designed for
1015/// integration with external demuxers (e.g., Symphonia) that manage I/O
1016/// separately and supply compressed frame data as byte slices.
1017///
1018/// # Usage
1019///
1020/// ```rust,ignore
1021/// let mut fd = FrameDecoder::new(version, channels, bits_per_sample, compression_level);
1022/// let pcm = fd.decode_frame(&frame_bytes, seek_remainder, frame_blocks)?;
1023/// ```
1024pub struct FrameDecoder {
1025    predictors: Predictors,
1026    entropy_states: Vec<EntropyState>,
1027    range_coder: RangeCoder,
1028    version: i32,
1029    channels: u16,
1030    bits_per_sample: u16,
1031    block_align: usize,
1032    interim_mode: bool,
1033}
1034
1035impl FrameDecoder {
1036    /// Create a new `FrameDecoder` with the given APE stream parameters.
1037    ///
1038    /// * `version` — APE file version (e.g., 3990). Must be >= 3950.
1039    /// * `channels` — Number of audio channels (1–32).
1040    /// * `bits_per_sample` — Bits per sample (8, 16, 24, or 32).
1041    /// * `compression_level` — Compression level (1000–5000).
1042    ///
1043    /// Returns an error if the parameters are invalid (unsupported version,
1044    /// zero channels, or unsupported bit depth).
1045    pub fn new(
1046        version: u16,
1047        channels: u16,
1048        bits_per_sample: u16,
1049        compression_level: u16,
1050    ) -> ApeResult<Self> {
1051        if version < 3950 {
1052            return Err(ApeError::UnsupportedVersion(version));
1053        }
1054        if channels == 0 {
1055            return Err(ApeError::InvalidFormat("channel count must be >= 1"));
1056        }
1057        if !matches!(bits_per_sample, 8 | 16 | 24 | 32) {
1058            return Err(ApeError::InvalidFormat(
1059                "bits per sample must be 8, 16, 24, or 32",
1060            ));
1061        }
1062
1063        let v = version as i32;
1064        let comp = compression_level as u32;
1065
1066        let predictors = if bits_per_sample >= 32 {
1067            Predictors::Path32(
1068                (0..channels)
1069                    .map(|_| Predictor3950_32::new(comp, v))
1070                    .collect(),
1071            )
1072        } else {
1073            Predictors::Path16(
1074                (0..channels)
1075                    .map(|_| Predictor3950::new(comp, v, bits_per_sample))
1076                    .collect(),
1077            )
1078        };
1079
1080        let entropy_states = (0..channels).map(|_| EntropyState::new()).collect();
1081        let bytes_per_sample = (bits_per_sample / 8) as usize;
1082        let block_align = bytes_per_sample * channels as usize;
1083
1084        Ok(FrameDecoder {
1085            predictors,
1086            entropy_states,
1087            range_coder: RangeCoder::new(),
1088            version: v,
1089            channels,
1090            bits_per_sample,
1091            block_align,
1092            interim_mode: false,
1093        })
1094    }
1095
1096    /// Decode a compressed frame to raw PCM bytes.
1097    ///
1098    /// * `frame_data` — Compressed frame bytes (including alignment prefix),
1099    ///   as returned by [`ApeDecoder::read_frame_data`].
1100    /// * `seek_remainder` — Byte alignment offset for this frame,
1101    ///   as returned by [`ApeDecoder::seek_remainder`].
1102    /// * `frame_blocks` — Number of audio blocks (samples per channel) in this frame.
1103    pub fn decode_frame(
1104        &mut self,
1105        frame_data: &[u8],
1106        seek_remainder: u32,
1107        frame_blocks: usize,
1108    ) -> ApeResult<Vec<u8>> {
1109        match &mut self.predictors {
1110            Predictors::Path16(predictors) => {
1111                let result = try_decode_frame_16(
1112                    frame_data,
1113                    seek_remainder,
1114                    frame_blocks,
1115                    self.version,
1116                    self.channels,
1117                    self.bits_per_sample,
1118                    self.block_align,
1119                    predictors,
1120                    &mut self.entropy_states,
1121                    &mut self.range_coder,
1122                );
1123
1124                match result {
1125                    Ok(pcm) => Ok(pcm),
1126                    Err(ApeError::InvalidChecksum)
1127                        if self.bits_per_sample == 24 && !self.interim_mode =>
1128                    {
1129                        self.interim_mode = true;
1130                        for p in predictors.iter_mut() {
1131                            p.set_interim_mode(true);
1132                        }
1133                        try_decode_frame_16(
1134                            frame_data,
1135                            seek_remainder,
1136                            frame_blocks,
1137                            self.version,
1138                            self.channels,
1139                            self.bits_per_sample,
1140                            self.block_align,
1141                            predictors,
1142                            &mut self.entropy_states,
1143                            &mut self.range_coder,
1144                        )
1145                    }
1146                    Err(e) => Err(e),
1147                }
1148            }
1149            Predictors::Path32(predictors) => try_decode_frame_32(
1150                frame_data,
1151                seek_remainder,
1152                frame_blocks,
1153                self.version,
1154                self.channels,
1155                self.bits_per_sample,
1156                self.block_align,
1157                predictors,
1158                &mut self.entropy_states,
1159                &mut self.range_coder,
1160            ),
1161        }
1162    }
1163}
1164
1165#[cfg(test)]
1166mod tests {
1167    use super::*;
1168    use std::fs::File;
1169    use std::io::BufReader;
1170    use std::path::PathBuf;
1171
1172    fn test_fixture_path(name: &str) -> PathBuf {
1173        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1174            .join("tests/fixtures")
1175            .join(name)
1176    }
1177
1178    fn load_reference_pcm(name: &str) -> Vec<u8> {
1179        let path = test_fixture_path(&format!("ref/{}", name));
1180        let data = std::fs::read(&path)
1181            .unwrap_or_else(|e| panic!("Failed to read {}: {}", path.display(), e));
1182        data[44..].to_vec()
1183    }
1184
1185    fn open_ape(name: &str) -> BufReader<File> {
1186        let path = test_fixture_path(&format!("ape/{}", name));
1187        let file = File::open(&path)
1188            .unwrap_or_else(|e| panic!("Failed to open {}: {}", path.display(), e));
1189        BufReader::new(file)
1190    }
1191
1192    fn decode_ape_file(name: &str) -> ApeResult<Vec<u8>> {
1193        let mut reader = open_ape(name);
1194        decode(&mut reader)
1195    }
1196
1197    // --- Existing end-to-end tests (unchanged) ---
1198
1199    #[test]
1200    fn test_decode_sine_16s_c1000() {
1201        let decoded = decode_ape_file("sine_16s_c1000.ape").unwrap();
1202        let expected = load_reference_pcm("sine_16s_c1000.wav");
1203        assert_eq!(decoded.len(), expected.len());
1204        assert_eq!(decoded, expected);
1205    }
1206
1207    #[test]
1208    fn test_decode_sine_16s_c2000() {
1209        let decoded = decode_ape_file("sine_16s_c2000.ape").unwrap();
1210        let expected = load_reference_pcm("sine_16s_c2000.wav");
1211        assert_eq!(decoded, expected);
1212    }
1213
1214    #[test]
1215    fn test_decode_silence_16s() {
1216        let decoded = decode_ape_file("silence_16s_c2000.ape").unwrap();
1217        let expected = load_reference_pcm("silence_16s_c2000.wav");
1218        assert_eq!(decoded, expected);
1219    }
1220
1221    #[test]
1222    fn test_decode_sine_16m() {
1223        let decoded = decode_ape_file("sine_16m_c2000.ape").unwrap();
1224        let expected = load_reference_pcm("sine_16m_c2000.wav");
1225        assert_eq!(decoded, expected);
1226    }
1227
1228    #[test]
1229    fn test_decode_short_16s() {
1230        let decoded = decode_ape_file("short_16s_c2000.ape").unwrap();
1231        let expected = load_reference_pcm("short_16s_c2000.wav");
1232        assert_eq!(decoded, expected);
1233    }
1234
1235    #[test]
1236    fn test_decode_all_compression_levels() {
1237        for level in &["c1000", "c2000", "c3000", "c4000", "c5000"] {
1238            let name = format!("sine_16s_{}.ape", level);
1239            let ref_name = format!("sine_16s_{}.wav", level);
1240            let decoded = decode_ape_file(&name).unwrap_or_else(|e| panic!("{}: {:?}", name, e));
1241            let expected = load_reference_pcm(&ref_name);
1242            assert_eq!(decoded, expected, "Mismatch for {}", name);
1243        }
1244    }
1245
1246    #[test]
1247    fn test_decode_8bit() {
1248        let decoded = decode_ape_file("sine_8s_c2000.ape").unwrap();
1249        let expected = load_reference_pcm("sine_8s_c2000.wav");
1250        assert_eq!(decoded, expected);
1251    }
1252
1253    #[test]
1254    fn test_decode_24bit() {
1255        let decoded = decode_ape_file("sine_24s_c2000.ape").unwrap();
1256        let expected = load_reference_pcm("sine_24s_c2000.wav");
1257        assert_eq!(decoded, expected);
1258    }
1259
1260    #[test]
1261    fn test_decode_32bit() {
1262        let decoded = decode_ape_file("sine_32s_c2000.ape").unwrap();
1263        let expected = load_reference_pcm("sine_32s_c2000.wav");
1264        assert_eq!(decoded, expected);
1265    }
1266
1267    #[test]
1268    fn test_decode_multiframe() {
1269        let decoded = decode_ape_file("multiframe_16s_c2000.ape").unwrap();
1270        let expected = load_reference_pcm("multiframe_16s_c2000.wav");
1271        assert_eq!(decoded, expected);
1272    }
1273
1274    #[test]
1275    fn test_decode_identical_channels() {
1276        let decoded = decode_ape_file("identical_16s_c2000.ape").unwrap();
1277        let expected = load_reference_pcm("identical_16s_c2000.wav");
1278        assert_eq!(decoded, expected);
1279    }
1280
1281    #[test]
1282    fn test_decode_all_fixtures() {
1283        let fixtures = [
1284            "dc_offset_16s_c2000.ape",
1285            "identical_16s_c2000.ape",
1286            "impulse_16s_c2000.ape",
1287            "left_only_16s_c2000.ape",
1288            "multiframe_16s_c2000.ape",
1289            "noise_16s_c2000.ape",
1290            "short_16s_c2000.ape",
1291            "silence_16s_c2000.ape",
1292            "sine_16m_c2000.ape",
1293            "sine_16s_c1000.ape",
1294            "sine_16s_c2000.ape",
1295            "sine_16s_c3000.ape",
1296            "sine_16s_c4000.ape",
1297            "sine_16s_c5000.ape",
1298            "sine_24s_c2000.ape",
1299            "sine_32s_c2000.ape",
1300            "sine_8s_c2000.ape",
1301        ];
1302
1303        for fixture in &fixtures {
1304            let ref_name = fixture.replace(".ape", ".wav");
1305            let decoded = decode_ape_file(fixture)
1306                .unwrap_or_else(|e| panic!("Failed to decode {}: {:?}", fixture, e));
1307            let expected = load_reference_pcm(&ref_name);
1308            assert_eq!(
1309                decoded.len(),
1310                expected.len(),
1311                "Length mismatch for {}",
1312                fixture
1313            );
1314            assert_eq!(decoded, expected, "Data mismatch for {}", fixture);
1315        }
1316    }
1317
1318    // --- New streaming API tests ---
1319
1320    #[test]
1321    fn test_ape_decoder_info() {
1322        let reader = open_ape("sine_16s_c2000.ape");
1323        let decoder = ApeDecoder::new(reader).unwrap();
1324        let info = decoder.info();
1325        assert_eq!(info.sample_rate, 44100);
1326        assert_eq!(info.channels, 2);
1327        assert_eq!(info.bits_per_sample, 16);
1328        assert_eq!(info.total_samples, 44100);
1329        assert_eq!(info.compression_level, 2000);
1330        assert_eq!(info.block_align, 4);
1331    }
1332
1333    #[test]
1334    fn test_decode_frame_by_frame() {
1335        let reader = open_ape("sine_16s_c2000.ape");
1336        let mut decoder = ApeDecoder::new(reader).unwrap();
1337        let expected = load_reference_pcm("sine_16s_c2000.wav");
1338
1339        let mut all_pcm = Vec::new();
1340        for frame_idx in 0..decoder.total_frames() {
1341            let frame_pcm = decoder.decode_frame(frame_idx).unwrap();
1342            all_pcm.extend_from_slice(&frame_pcm);
1343        }
1344
1345        assert_eq!(all_pcm, expected);
1346    }
1347
1348    #[test]
1349    fn test_decode_multiframe_frame_by_frame() {
1350        let reader = open_ape("multiframe_16s_c2000.ape");
1351        let mut decoder = ApeDecoder::new(reader).unwrap();
1352        let expected = load_reference_pcm("multiframe_16s_c2000.wav");
1353
1354        assert!(decoder.total_frames() > 1, "Expected multiple frames");
1355
1356        let mut all_pcm = Vec::new();
1357        for frame_idx in 0..decoder.total_frames() {
1358            let frame_pcm = decoder.decode_frame(frame_idx).unwrap();
1359            assert!(!frame_pcm.is_empty());
1360            all_pcm.extend_from_slice(&frame_pcm);
1361        }
1362
1363        assert_eq!(all_pcm, expected);
1364    }
1365
1366    #[test]
1367    fn test_frames_iterator() {
1368        let reader = open_ape("sine_16s_c2000.ape");
1369        let mut decoder = ApeDecoder::new(reader).unwrap();
1370        let expected = load_reference_pcm("sine_16s_c2000.wav");
1371
1372        let all_pcm: Vec<u8> = decoder
1373            .frames()
1374            .collect::<Result<Vec<_>, _>>()
1375            .unwrap()
1376            .concat();
1377
1378        assert_eq!(all_pcm, expected);
1379    }
1380
1381    #[test]
1382    fn test_seek_sample_level() {
1383        let reader = open_ape("multiframe_16s_c2000.ape");
1384        let mut decoder = ApeDecoder::new(reader).unwrap();
1385        let bpf = decoder.info().blocks_per_frame as u64;
1386
1387        // Seek to sample 0 → frame 0, skip 0
1388        let r = decoder.seek(0).unwrap();
1389        assert_eq!(r.frame_index, 0);
1390        assert_eq!(r.skip_samples, 0);
1391        assert_eq!(r.actual_sample, 0);
1392
1393        // Seek to mid-frame → frame 0, skip 100
1394        let r = decoder.seek(100).unwrap();
1395        assert_eq!(r.frame_index, 0);
1396        assert_eq!(r.skip_samples, 100);
1397        assert_eq!(r.actual_sample, 100);
1398
1399        // Seek to exactly frame 1 → frame 1, skip 0
1400        let r = decoder.seek(bpf).unwrap();
1401        assert_eq!(r.frame_index, 1);
1402        assert_eq!(r.skip_samples, 0);
1403        assert_eq!(r.actual_sample, bpf);
1404
1405        // Seek to mid frame 1 → frame 1, skip 100
1406        let r = decoder.seek(bpf + 100).unwrap();
1407        assert_eq!(r.frame_index, 1);
1408        assert_eq!(r.skip_samples, 100);
1409        assert_eq!(r.actual_sample, bpf + 100);
1410
1411        // Seek past end → clamps to last sample
1412        let r = decoder.seek(u64::MAX).unwrap();
1413        assert_eq!(r.actual_sample, decoder.info().total_samples - 1);
1414    }
1415
1416    #[test]
1417    fn test_decode_from_mid_frame() {
1418        let reader = open_ape("sine_16s_c2000.ape");
1419        let mut decoder = ApeDecoder::new(reader).unwrap();
1420        let block_align = decoder.info().block_align as usize;
1421
1422        // Decode full frame
1423        let full_frame = decoder.decode_frame(0).unwrap();
1424
1425        // Decode from sample 100
1426        let partial = decoder.decode_from(100).unwrap();
1427
1428        // Partial should be full_frame minus the first 100 blocks
1429        let skip = 100 * block_align;
1430        assert_eq!(partial, &full_frame[skip..]);
1431    }
1432
1433    #[test]
1434    fn test_expanded_metadata() {
1435        let reader = open_ape("sine_16s_c2000.ape");
1436        let decoder = ApeDecoder::new(reader).unwrap();
1437        let info = decoder.info();
1438
1439        assert_eq!(info.bytes_per_sample, 2);
1440        assert_eq!(info.source_format, SourceFormat::Wav);
1441        assert!(!info.is_big_endian);
1442        assert!(!info.is_floating_point);
1443        assert!(!info.is_signed_8bit);
1444        assert!(info.average_bitrate_kbps > 0);
1445        assert!(info.decompressed_bitrate_kbps > 0);
1446        assert!(info.file_size_bytes > 0);
1447        assert_eq!(info.format_flags & 0x0200, 0); // not big-endian
1448    }
1449
1450    #[test]
1451    fn test_wav_header_data() {
1452        let reader = open_ape("sine_16s_c2000.ape");
1453        let decoder = ApeDecoder::new(reader).unwrap();
1454
1455        let header = decoder.wav_header_data();
1456        // Test files should have stored WAV headers
1457        if let Some(data) = header {
1458            assert!(data.len() >= 12);
1459            // Should start with RIFF
1460            assert_eq!(&data[0..4], b"RIFF");
1461        }
1462    }
1463
1464    #[test]
1465    fn test_read_tag() {
1466        let reader = open_ape("sine_16s_c2000.ape");
1467        let mut decoder = ApeDecoder::new(reader).unwrap();
1468        // Tag may or may not exist — just ensure no panic
1469        let _tag = decoder.read_tag();
1470    }
1471
1472    #[test]
1473    fn test_decode_frame_out_of_bounds() {
1474        let reader = open_ape("sine_16s_c2000.ape");
1475        let mut decoder = ApeDecoder::new(reader).unwrap();
1476        let result = decoder.decode_frame(999);
1477        assert!(result.is_err());
1478    }
1479
1480    // --- Progress callback tests ---
1481
1482    #[test]
1483    fn test_decode_with_progress() {
1484        let reader = open_ape("sine_16s_c2000.ape");
1485        let mut decoder = ApeDecoder::new(reader).unwrap();
1486        let expected = load_reference_pcm("sine_16s_c2000.wav");
1487
1488        let mut last_progress = 0.0f64;
1489        let decoded = decoder
1490            .decode_all_with(|p| {
1491                assert!(p >= last_progress, "progress must be monotonic");
1492                last_progress = p;
1493                true // continue
1494            })
1495            .unwrap();
1496
1497        assert!((last_progress - 1.0).abs() < 0.01);
1498        assert_eq!(decoded, expected);
1499    }
1500
1501    #[test]
1502    fn test_decode_with_cancel() {
1503        let reader = open_ape("multiframe_16s_c2000.ape");
1504        let mut decoder = ApeDecoder::new(reader).unwrap();
1505
1506        let result = decoder.decode_all_with(|p| {
1507            p < 0.5 // cancel halfway
1508        });
1509
1510        assert!(result.is_err());
1511    }
1512
1513    // --- Range decoding tests ---
1514
1515    #[test]
1516    fn test_decode_range_full_file() {
1517        let reader = open_ape("sine_16s_c2000.ape");
1518        let mut decoder = ApeDecoder::new(reader).unwrap();
1519        let total = decoder.info().total_samples;
1520        let expected = load_reference_pcm("sine_16s_c2000.wav");
1521
1522        let decoded = decoder.decode_range(0, total).unwrap();
1523        assert_eq!(decoded, expected);
1524    }
1525
1526    #[test]
1527    fn test_decode_range_subset() {
1528        let reader = open_ape("sine_16s_c2000.ape");
1529        let mut decoder = ApeDecoder::new(reader).unwrap();
1530        let block_align = decoder.info().block_align as usize;
1531        let expected = load_reference_pcm("sine_16s_c2000.wav");
1532
1533        // Decode samples 100..200
1534        let decoded = decoder.decode_range(100, 200).unwrap();
1535        assert_eq!(decoded.len(), 100 * block_align);
1536        assert_eq!(decoded, &expected[100 * block_align..200 * block_align]);
1537    }
1538
1539    #[test]
1540    fn test_decode_range_empty() {
1541        let reader = open_ape("sine_16s_c2000.ape");
1542        let mut decoder = ApeDecoder::new(reader).unwrap();
1543
1544        let decoded = decoder.decode_range(100, 100).unwrap();
1545        assert!(decoded.is_empty());
1546
1547        let decoded = decoder.decode_range(200, 100).unwrap();
1548        assert!(decoded.is_empty());
1549    }
1550
1551    // --- Parallel decode tests ---
1552
1553    #[test]
1554    fn test_decode_parallel_matches_sequential() {
1555        let expected = load_reference_pcm("sine_16s_c2000.wav");
1556
1557        let reader = open_ape("sine_16s_c2000.ape");
1558        let mut decoder = ApeDecoder::new(reader).unwrap();
1559        let parallel = decoder.decode_all_parallel(4).unwrap();
1560
1561        assert_eq!(parallel, expected);
1562    }
1563
1564    #[test]
1565    fn test_decode_parallel_multiframe() {
1566        let expected = load_reference_pcm("multiframe_16s_c2000.wav");
1567
1568        let reader = open_ape("multiframe_16s_c2000.ape");
1569        let mut decoder = ApeDecoder::new(reader).unwrap();
1570        let parallel = decoder.decode_all_parallel(2).unwrap();
1571
1572        assert_eq!(parallel, expected);
1573    }
1574
1575    #[test]
1576    fn test_decode_parallel_single_thread() {
1577        let expected = load_reference_pcm("sine_16s_c2000.wav");
1578
1579        let reader = open_ape("sine_16s_c2000.ape");
1580        let mut decoder = ApeDecoder::new(reader).unwrap();
1581        let decoded = decoder.decode_all_parallel(1).unwrap();
1582
1583        assert_eq!(decoded, expected);
1584    }
1585
1586    #[test]
1587    fn test_decode_parallel_all_fixtures() {
1588        let fixtures = [
1589            "dc_offset_16s_c2000.ape",
1590            "identical_16s_c2000.ape",
1591            "impulse_16s_c2000.ape",
1592            "left_only_16s_c2000.ape",
1593            "multiframe_16s_c2000.ape",
1594            "noise_16s_c2000.ape",
1595            "short_16s_c2000.ape",
1596            "silence_16s_c2000.ape",
1597            "sine_16m_c2000.ape",
1598            "sine_16s_c1000.ape",
1599            "sine_16s_c2000.ape",
1600            "sine_16s_c3000.ape",
1601            "sine_16s_c4000.ape",
1602            "sine_16s_c5000.ape",
1603            "sine_24s_c2000.ape",
1604            "sine_32s_c2000.ape",
1605            "sine_8s_c2000.ape",
1606        ];
1607
1608        for fixture in &fixtures {
1609            let ref_name = fixture.replace(".ape", ".wav");
1610            let reader = open_ape(fixture);
1611            let mut decoder = ApeDecoder::new(reader).unwrap();
1612            let parallel = decoder
1613                .decode_all_parallel(2)
1614                .unwrap_or_else(|e| panic!("Parallel decode failed for {}: {:?}", fixture, e));
1615            let expected = load_reference_pcm(&ref_name);
1616            assert_eq!(parallel, expected, "Parallel mismatch for {}", fixture);
1617        }
1618    }
1619
1620    // --- Negative / error path tests ---
1621
1622    #[test]
1623    fn test_decode_truncated_file() {
1624        // File too small to contain even a header
1625        let data = vec![0u8; 10];
1626        let mut cursor = std::io::Cursor::new(data);
1627        let result = decode(&mut cursor);
1628        assert!(result.is_err());
1629    }
1630
1631    #[test]
1632    fn test_decode_wrong_magic() {
1633        // Valid size but wrong magic bytes
1634        let mut data = vec![0u8; 200];
1635        data[0..4].copy_from_slice(b"NOPE");
1636        let mut cursor = std::io::Cursor::new(data);
1637        let result = decode(&mut cursor);
1638        assert!(result.is_err());
1639    }
1640
1641    #[test]
1642    fn test_decode_empty_file() {
1643        let data = vec![];
1644        let mut cursor = std::io::Cursor::new(data);
1645        let result = decode(&mut cursor);
1646        assert!(result.is_err());
1647    }
1648
1649    #[test]
1650    fn test_decoder_new_truncated() {
1651        let data = vec![0u8; 50]; // too small for APE header
1652        let cursor = std::io::Cursor::new(data);
1653        let result = ApeDecoder::new(cursor);
1654        assert!(result.is_err());
1655    }
1656
1657    // --- Post-processing transform tests ---
1658
1659    #[test]
1660    fn test_float_transform_roundtrip() {
1661        // FloatTransform is its own inverse
1662        let original: u32 = 0x3F800000; // IEEE 754 float 1.0
1663        let transformed = super::float_transform_sample(original);
1664        let restored = super::float_transform_sample(transformed);
1665        assert_eq!(restored, original);
1666    }
1667
1668    #[test]
1669    fn test_float_transform_zero() {
1670        let transformed = super::float_transform_sample(0);
1671        let restored = super::float_transform_sample(transformed);
1672        assert_eq!(restored, 0);
1673    }
1674
1675    #[test]
1676    fn test_byte_swap_16bit() {
1677        let mut data = vec![0x01, 0x02, 0x03, 0x04];
1678        super::byte_swap_samples(&mut data, 2);
1679        assert_eq!(data, vec![0x02, 0x01, 0x04, 0x03]);
1680    }
1681
1682    #[test]
1683    fn test_byte_swap_24bit() {
1684        let mut data = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
1685        super::byte_swap_samples(&mut data, 3);
1686        assert_eq!(data, vec![0x03, 0x02, 0x01, 0x06, 0x05, 0x04]);
1687    }
1688
1689    #[test]
1690    fn test_byte_swap_32bit() {
1691        let mut data = vec![0x01, 0x02, 0x03, 0x04];
1692        super::byte_swap_samples(&mut data, 4);
1693        assert_eq!(data, vec![0x04, 0x03, 0x02, 0x01]);
1694    }
1695
1696    // --- MD5 verification tests ---
1697
1698    #[test]
1699    fn test_verify_md5_all_fixtures() {
1700        let fixtures = [
1701            "dc_offset_16s_c2000.ape",
1702            "identical_16s_c2000.ape",
1703            "impulse_16s_c2000.ape",
1704            "left_only_16s_c2000.ape",
1705            "multiframe_16s_c2000.ape",
1706            "noise_16s_c2000.ape",
1707            "short_16s_c2000.ape",
1708            "silence_16s_c2000.ape",
1709            "sine_16m_c2000.ape",
1710            "sine_16s_c1000.ape",
1711            "sine_16s_c2000.ape",
1712            "sine_16s_c3000.ape",
1713            "sine_16s_c4000.ape",
1714            "sine_16s_c5000.ape",
1715            "sine_24s_c2000.ape",
1716            "sine_32s_c2000.ape",
1717            "sine_8s_c2000.ape",
1718        ];
1719
1720        for fixture in &fixtures {
1721            let reader = open_ape(fixture);
1722            let mut decoder = ApeDecoder::new(reader).unwrap();
1723            let result = decoder
1724                .verify_md5()
1725                .unwrap_or_else(|e| panic!("MD5 verify failed for {}: {:?}", fixture, e));
1726            assert!(result, "MD5 mismatch for {}", fixture);
1727        }
1728    }
1729
1730    #[test]
1731    fn test_stored_md5_nonzero() {
1732        let reader = open_ape("sine_16s_c2000.ape");
1733        let decoder = ApeDecoder::new(reader).unwrap();
1734        let md5 = decoder.stored_md5();
1735        // The mac tool should have stored a valid MD5
1736        assert_ne!(md5, &[0u8; 16], "MD5 should not be all zeros");
1737    }
1738
1739    // --- WAV header generation test ---
1740
1741    #[test]
1742    fn test_generate_wav_header() {
1743        let reader = open_ape("sine_16s_c2000.ape");
1744        let decoder = ApeDecoder::new(reader).unwrap();
1745        let header = decoder.info().generate_wav_header();
1746
1747        // Standard WAV header is 44 bytes
1748        assert_eq!(header.len(), 44);
1749
1750        // Check RIFF magic
1751        assert_eq!(&header[0..4], b"RIFF");
1752        assert_eq!(&header[8..12], b"WAVE");
1753        assert_eq!(&header[12..16], b"fmt ");
1754        assert_eq!(&header[36..40], b"data");
1755
1756        // Check format: PCM, 2 channels, 44100 Hz, 16-bit
1757        let channels = u16::from_le_bytes([header[22], header[23]]);
1758        let sample_rate = u32::from_le_bytes([header[24], header[25], header[26], header[27]]);
1759        let bits = u16::from_le_bytes([header[34], header[35]]);
1760        assert_eq!(channels, 2);
1761        assert_eq!(sample_rate, 44100);
1762        assert_eq!(bits, 16);
1763
1764        // Data size should match total_samples * block_align
1765        let data_size = u32::from_le_bytes([header[40], header[41], header[42], header[43]]);
1766        let expected = decoder.info().total_samples as u32 * decoder.info().block_align as u32;
1767        assert_eq!(data_size, expected);
1768    }
1769
1770    #[test]
1771    fn test_generate_wav_header_matches_stored() {
1772        let reader = open_ape("sine_16s_c2000.ape");
1773        let decoder = ApeDecoder::new(reader).unwrap();
1774
1775        let generated = decoder.info().generate_wav_header();
1776        if let Some(stored) = decoder.wav_header_data() {
1777            // Both should be 44 bytes for standard WAV
1778            if stored.len() == 44 {
1779                // Format fields should match (channels, rate, bits)
1780                assert_eq!(&generated[22..24], &stored[22..24]); // channels
1781                assert_eq!(&generated[24..28], &stored[24..28]); // sample rate
1782                assert_eq!(&generated[34..36], &stored[34..36]); // bits per sample
1783            }
1784        }
1785    }
1786}