rust_aec/
decoder.rs

1use crate::bitreader::BitReader;
2use crate::error::AecError;
3use crate::params::{AecFlags, AecParams};
4
/// Tells [`Decoder::decode`] whether the caller may still supply more input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Flush {
    /// Counterpart of `AEC_NO_FLUSH`: more input may arrive later, so decoding
    /// can be resumed after it has been pushed.
    NoFlush,
    /// Counterpart of `AEC_FLUSH`: the caller guarantees that no further input
    /// will be supplied.
    Flush,
}
12
/// Outcome of a single [`Decoder::decode`] call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeStatus {
    /// The decoder cannot make progress until more input is pushed.
    NeedInput,
    /// The caller's output buffer is full; call again with more output space.
    NeedOutput,
    /// All `output_samples` samples have been decoded.
    Finished,
}
22
23/// Streaming AEC decoder (Rust-idiomatic, modeled after libaec's `aec_stream`).
24///
25/// This type allows chunked input and chunked output:
26///
27/// - call [`Decoder::push_input`] to append more bytes
28/// - call [`Decoder::decode`] to write decoded bytes into a caller buffer
29///
30/// Notes:
31/// - Output is **packed sample bytes** (same as [`decode_into`]).
32/// - You must know `output_samples` up front (same as one-shot API).
33pub struct Decoder {
34    params: AecParams,
35    bytes_per_sample: usize,
36    id_len: usize,
37    preprocess: bool,
38
39    output_samples: usize,
40    samples_written: usize,
41
42    // Predictor state (only used with preprocessing enabled).
43    predictor_x: Option<i64>,
44    sample_index_within_rsi: u64,
45    block_index_within_rsi: u32,
46
47    // Input bitstream.
48    reader: StreamBitReader,
49
50    // Pending output from a partially-flushed decoded block.
51    pending: Vec<u8>,
52    pending_pos: usize,
53
54    // Pending repeated coded values (used for zero-run etc.).
55    pending_repeat: Option<PendingRepeat>,
56
57    total_in: usize,
58    total_out: usize,
59}
60
/// A coded value scheduled to be emitted several more times.
#[derive(Clone, Debug)]
struct PendingRepeat {
    /// The coded value passed to `emit_coded_value` for every repetition.
    coded_value: u32,
    /// How many samples of this run are still to be emitted.
    remaining: usize,
}
66
67impl Decoder {
68    pub fn new(params: AecParams, output_samples: usize) -> Result<Self, AecError> {
69        validate_params(params)?;
70        let bytes_per_sample = bytes_per_sample(params)?;
71        let id_len = id_len(params)?;
72
73        Ok(Self {
74            params,
75            bytes_per_sample,
76            id_len,
77            preprocess: params.flags.contains(AecFlags::DATA_PREPROCESS),
78            output_samples,
79            samples_written: 0,
80            predictor_x: None,
81            sample_index_within_rsi: 0,
82            block_index_within_rsi: 0,
83            reader: StreamBitReader::new(),
84            pending: Vec::new(),
85            pending_pos: 0,
86            pending_repeat: None,
87            total_in: 0,
88            total_out: 0,
89        })
90    }
91
92    /// Append more bytes to the input buffer.
93    pub fn push_input(&mut self, input: &[u8]) {
94        self.reader.push(input);
95    }
96
97    /// Total number of input bytes consumed so far.
98    pub fn total_in(&self) -> usize {
99        self.total_in
100    }
101
102    /// Total number of output bytes produced so far.
103    pub fn total_out(&self) -> usize {
104        self.total_out
105    }
106
107    /// Bytes currently buffered and available for reading.
108    pub fn avail_in(&self) -> usize {
109        self.reader.avail_bytes()
110    }
111
112    /// Decode into `out` and return (written_bytes, status).
113    pub fn decode(&mut self, out: &mut [u8], flush: Flush) -> Result<(usize, DecodeStatus), AecError> {
114        if self.samples_written >= self.output_samples {
115            return Ok((0, DecodeStatus::Finished));
116        }
117
118        let mut written: usize = 0;
119
120        // Fast-path: flush any pending bytes first.
121        written += self.flush_pending(out, written);
122        if written >= out.len() {
123            self.total_out += written;
124            return Ok((written, DecodeStatus::NeedOutput));
125        }
126
127        // Then flush any pending repeat-run.
128        if let Some(status) = self.flush_repeat(out, &mut written)? {
129            self.total_out += written;
130            return Ok((written, status));
131        }
132
133        // Decode blocks/runs until output is full or decoding completes.
134        while written < out.len() {
135            if self.samples_written >= self.output_samples {
136                self.total_out += written;
137                return Ok((written, DecodeStatus::Finished));
138            }
139
140            // Ensure predictor state is reset at RSI boundary when preprocessing is enabled.
141            if self.preprocess && self.block_index_within_rsi == 0 {
142                self.predictor_x = None;
143            }
144
145            // If we don't have enough input to decode the next unit, request more.
146            let snapshot = self.snapshot();
147            match self.decode_next_unit() {
148                Ok(()) => {
149                    // Compaction: count consumed whole bytes.
150                    let consumed = self.reader.compact_consumed_bytes();
151                    self.total_in += consumed;
152
153                    // Flush any newly produced pending output/repeat.
154                    written += self.flush_pending(out, written);
155                    if written >= out.len() {
156                        self.total_out += written;
157                        return Ok((written, DecodeStatus::NeedOutput));
158                    }
159
160                    if let Some(status) = self.flush_repeat(out, &mut written)? {
161                        self.total_out += written;
162                        return Ok((written, status));
163                    }
164
165                    // Otherwise, loop and decode more.
166                }
167                Err(AecError::UnexpectedEof { .. }) | Err(AecError::UnexpectedEofDuringDecode { .. }) => {
168                    // Restore state and request more input unless flushing.
169                    self.restore(snapshot);
170                    self.total_out += written;
171                    return match flush {
172                        Flush::NoFlush => Ok((written, DecodeStatus::NeedInput)),
173                        Flush::Flush => Err(AecError::UnexpectedEofDuringDecode {
174                            bit_pos: self.reader.bits_read_total(),
175                            samples_written: self.samples_written,
176                        }),
177                    };
178                }
179                Err(e) => {
180                    self.restore(snapshot);
181                    return Err(e);
182                }
183            }
184        }
185
186        self.total_out += written;
187        Ok((written, DecodeStatus::NeedOutput))
188    }
189
190    fn flush_pending(&mut self, out: &mut [u8], written: usize) -> usize {
191        if self.pending_pos >= self.pending.len() {
192            self.pending.clear();
193            self.pending_pos = 0;
194            return 0;
195        }
196
197        let available = out.len().saturating_sub(written);
198        let remaining = self.pending.len().saturating_sub(self.pending_pos);
199        let to_copy = available.min(remaining);
200
201        out[written..written + to_copy]
202            .copy_from_slice(&self.pending[self.pending_pos..self.pending_pos + to_copy]);
203        self.pending_pos += to_copy;
204        to_copy
205    }
206
207    fn flush_repeat(&mut self, out: &mut [u8], written: &mut usize) -> Result<Option<DecodeStatus>, AecError> {
208        let Some(rep) = self.pending_repeat.as_mut() else {
209            return Ok(None);
210        };
211
212        while *written < out.len() && rep.remaining > 0 {
213            if self.samples_written >= self.output_samples {
214                self.pending_repeat = None;
215                return Ok(Some(DecodeStatus::Finished));
216            }
217
218            // Write exactly one sample (packed bytes).
219            let out_start = *written;
220            let out_end = out_start + self.bytes_per_sample;
221            if out_end > out.len() {
222                return Ok(Some(DecodeStatus::NeedOutput));
223            }
224
225            // Use the same semantics as emit_coded_value(): preprocessing applies here.
226            let mut tmp = OutBuf::new(&mut out[out_start..out_end], self.bytes_per_sample);
227            tmp.pos = 0;
228            emit_coded_value(
229                &mut tmp,
230                &mut self.predictor_x,
231                self.params,
232                self.bytes_per_sample,
233                rep.coded_value,
234                &mut self.sample_index_within_rsi,
235                usize::MAX,
236            )?;
237            *written += self.bytes_per_sample;
238            self.samples_written += 1;
239            rep.remaining -= 1;
240        }
241
242        if rep.remaining == 0 {
243            self.pending_repeat = None;
244        }
245
246        if *written >= out.len() {
247            return Ok(Some(DecodeStatus::NeedOutput));
248        }
249        Ok(None)
250    }
251
252    fn snapshot(&self) -> Snapshot {
253        Snapshot {
254            predictor_x: self.predictor_x,
255            sample_index_within_rsi: self.sample_index_within_rsi,
256            block_index_within_rsi: self.block_index_within_rsi,
257            samples_written: self.samples_written,
258            reader: self.reader.clone(),
259            pending: self.pending.clone(),
260            pending_pos: self.pending_pos,
261            pending_repeat: self.pending_repeat.clone(),
262        }
263    }
264
265    fn restore(&mut self, s: Snapshot) {
266        self.predictor_x = s.predictor_x;
267        self.sample_index_within_rsi = s.sample_index_within_rsi;
268        self.block_index_within_rsi = s.block_index_within_rsi;
269        self.samples_written = s.samples_written;
270        self.reader = s.reader;
271        self.pending = s.pending;
272        self.pending_pos = s.pending_pos;
273        self.pending_repeat = s.pending_repeat;
274    }
275
276    fn decode_next_unit(&mut self) -> Result<(), AecError> {
277        // Ensure no pending bytes: we decode one unit into pending.
278        if self.pending_pos < self.pending.len() {
279            return Ok(());
280        }
281
282        // Build a small output buffer for a single block.
283        let mut block_out: Vec<u8> = vec![0u8; self.bytes_per_sample * (self.params.block_size as usize)];
284        let mut out = OutBuf::new(&mut block_out, self.bytes_per_sample);
285
286        // Start-of-RSI predictor reset.
287        if self.preprocess && self.block_index_within_rsi == 0 {
288            self.predictor_x = None;
289        }
290
291        let at_rsi_start = self.preprocess && self.block_index_within_rsi == 0;
292        let ref_pending = at_rsi_start;
293        let mut reference_sample_consumed = false;
294
295        // Read block option id.
296        let id = self.reader.read_bits_u32(self.id_len)?;
297        let max_id = (1u32 << self.id_len) - 1;
298
299        // Helper to consume the RSI reference sample.
300        let mut consume_reference = |this: &mut Self, out: &mut OutBuf<'_>| -> Result<(), AecError> {
301            let ref_raw = this.reader.read_bits_u32(this.params.bits_per_sample as usize)?;
302            let ref_val = if this.params.flags.contains(AecFlags::DATA_SIGNED) {
303                sign_extend(ref_raw, this.params.bits_per_sample)
304            } else {
305                ref_raw as i64
306            };
307            write_sample(out, ref_val, this.params)?;
308            this.predictor_x = Some(ref_val);
309            reference_sample_consumed = true;
310            this.sample_index_within_rsi += 1;
311            Ok(())
312        };
313
314        let remaining_total_samples = self.output_samples.saturating_sub(self.samples_written);
315        let max_samples_this_block = (self.params.block_size as usize).min(remaining_total_samples);
316
317        if id == 0 {
318            // Low-entropy family.
319            let selector = self.reader.read_bit()?;
320
321            // For low-entropy blocks, selector comes before optional RSI reference.
322            if ref_pending {
323                consume_reference(self, &mut out)?;
324                self.samples_written += 1;
325            }
326
327            // Remaining capacity after the optional reference sample.
328            let remaining_total_samples = self.output_samples.saturating_sub(self.samples_written);
329
330            let mut remaining_in_block = self.params.block_size as usize;
331            if reference_sample_consumed {
332                remaining_in_block = remaining_in_block.saturating_sub(1);
333            }
334
335            if !selector {
336                // Zero-block run: do not materialize huge output; schedule repeats.
337                let fs = read_unary_stream(&mut self.reader)?;
338                let mut z_blocks = fs + 1;
339                const ROS: u32 = 5;
340                if z_blocks == ROS {
341                    let b = self.block_index_within_rsi;
342                    let fill1 = self.params.rsi.saturating_sub(b);
343                    let fill2 = 64u32.saturating_sub(b % 64);
344                    z_blocks = fill1.min(fill2);
345                } else if z_blocks > ROS {
346                    z_blocks = z_blocks.saturating_sub(1);
347                }
348
349                let mut zeros_samples = (z_blocks as usize)
350                    .checked_mul(self.params.block_size as usize)
351                    .ok_or(AecError::InvalidInput("zero-run overflow"))?;
352                if reference_sample_consumed {
353                    zeros_samples = zeros_samples.saturating_sub(1);
354                }
355
356                // Limit to remaining total samples (reference already counted in `samples_written`).
357                zeros_samples = zeros_samples.min(remaining_total_samples);
358
359                // Emit any already-written reference sample into pending bytes.
360                let produced_len = out.len();
361                drop(out);
362                self.pending = block_out[..produced_len].to_vec();
363                self.pending_pos = 0;
364
365                // Schedule coded-value repeats (coded_value = 0).
366                if zeros_samples > 0 {
367                    self.pending_repeat = Some(PendingRepeat { coded_value: 0, remaining: zeros_samples });
368                }
369
370                // Advance block counter by z_blocks.
371                self.block_index_within_rsi = self.block_index_within_rsi.saturating_add(z_blocks);
372                if self.block_index_within_rsi >= self.params.rsi {
373                    self.block_index_within_rsi %= self.params.rsi;
374                    if self.params.flags.contains(AecFlags::PAD_RSI) {
375                        self.reader.align_to_byte();
376                    }
377                    self.sample_index_within_rsi = 0;
378                }
379
380                // We do not increment samples_written here; repeats are accounted for in flush.
381                return Ok(());
382            }
383
384            // Second Extension option.
385            let mut produced_samples = 0usize;
386            while remaining_in_block > 0 && produced_samples < max_samples_this_block.saturating_sub(reference_sample_consumed as usize) {
387                let m = read_unary_stream(&mut self.reader)?;
388                if m > 90 {
389                    return Err(AecError::InvalidInput("Second Extension unary symbol too large"));
390                }
391                let (a, b) = second_extension_pair(m);
392
393                // Emit up to two values.
394                if produced_samples < max_samples_this_block.saturating_sub(reference_sample_consumed as usize) {
395                    emit_coded_value(
396                        &mut out,
397                        &mut self.predictor_x,
398                        self.params,
399                        self.bytes_per_sample,
400                        a,
401                        &mut self.sample_index_within_rsi,
402                        usize::MAX,
403                    )?;
404                    produced_samples += 1;
405                    self.samples_written += 1;
406                }
407
408                if remaining_in_block > 0 {
409                    remaining_in_block = remaining_in_block.saturating_sub(1);
410                }
411                if produced_samples < max_samples_this_block.saturating_sub(reference_sample_consumed as usize) {
412                    emit_coded_value(
413                        &mut out,
414                        &mut self.predictor_x,
415                        self.params,
416                        self.bytes_per_sample,
417                        b,
418                        &mut self.sample_index_within_rsi,
419                        usize::MAX,
420                    )?;
421                    produced_samples += 1;
422                    self.samples_written += 1;
423                }
424                if remaining_in_block > 0 {
425                    remaining_in_block = remaining_in_block.saturating_sub(1);
426                }
427            }
428        } else if id == max_id {
429            // Uncompressed block.
430            if ref_pending {
431                consume_reference(self, &mut out)?;
432                self.samples_written += 1;
433            }
434
435            let mut remaining_in_block = self.params.block_size as usize;
436            if reference_sample_consumed {
437                remaining_in_block = remaining_in_block.saturating_sub(1);
438            }
439
440            for _ in 0..remaining_in_block {
441                if self.samples_written >= self.output_samples {
442                    break;
443                }
444                let v = self.reader.read_bits_u32(self.params.bits_per_sample as usize)?;
445                emit_coded_value(
446                    &mut out,
447                    &mut self.predictor_x,
448                    self.params,
449                    self.bytes_per_sample,
450                    v,
451                    &mut self.sample_index_within_rsi,
452                    usize::MAX,
453                )?;
454                self.samples_written += 1;
455            }
456        } else {
457            // Rice split.
458            let k = (id - 1) as usize;
459            if ref_pending {
460                consume_reference(self, &mut out)?;
461                self.samples_written += 1;
462            }
463
464            let mut remaining_in_block = self.params.block_size as usize;
465            if reference_sample_consumed {
466                remaining_in_block = remaining_in_block.saturating_sub(1);
467            }
468            let n = remaining_in_block.min(self.output_samples.saturating_sub(self.samples_written));
469            let mut tmp: Vec<u32> = vec![0u32; n];
470
471            for i in 0..n {
472                let q = read_unary_stream(&mut self.reader)?;
473                tmp[i] = (q as u32)
474                    .checked_shl(k as u32)
475                    .ok_or(AecError::InvalidInput("rice shift overflow"))?;
476            }
477            if k > 0 {
478                for i in 0..n {
479                    let rem = self.reader.read_bits_u32(k)?;
480                    tmp[i] |= rem;
481                }
482            }
483            for v in tmp {
484                if self.samples_written >= self.output_samples {
485                    break;
486                }
487                emit_coded_value(
488                    &mut out,
489                    &mut self.predictor_x,
490                    self.params,
491                    self.bytes_per_sample,
492                    v,
493                    &mut self.sample_index_within_rsi,
494                    usize::MAX,
495                )?;
496                self.samples_written += 1;
497            }
498        }
499
500        // Commit block output.
501        let produced_len = out.len();
502        drop(out);
503        self.pending = block_out[..produced_len].to_vec();
504        self.pending_pos = 0;
505
506        // Advance block counter.
507        self.block_index_within_rsi = self.block_index_within_rsi.saturating_add(1);
508        if self.preprocess && self.block_index_within_rsi >= self.params.rsi {
509            self.block_index_within_rsi = 0;
510            self.sample_index_within_rsi = 0;
511            if self.params.flags.contains(AecFlags::PAD_RSI) {
512                self.reader.align_to_byte();
513            }
514        }
515
516        Ok(())
517    }
518}
519
520#[derive(Clone)]
521struct Snapshot {
522    predictor_x: Option<i64>,
523    sample_index_within_rsi: u64,
524    block_index_within_rsi: u32,
525    samples_written: usize,
526    reader: StreamBitReader,
527    pending: Vec<u8>,
528    pending_pos: usize,
529    pending_repeat: Option<PendingRepeat>,
530}
531
532/// Streaming-capable bit reader backed by an internal buffer.
533///
534/// It allows appending input incrementally and compacting consumed bytes.
535#[derive(Debug, Clone)]
536struct StreamBitReader {
537    buf: Vec<u8>,
538    bit_pos: usize,
539    total_bytes_dropped: usize,
540}
541
542impl StreamBitReader {
543    fn new() -> Self {
544        Self { buf: Vec::new(), bit_pos: 0, total_bytes_dropped: 0 }
545    }
546
547    fn push(&mut self, data: &[u8]) {
548        self.buf.extend_from_slice(data);
549    }
550
551    fn avail_bytes(&self) -> usize {
552        self.buf.len().saturating_sub(self.bit_pos / 8)
553    }
554
555    fn bits_read_total(&self) -> usize {
556        self.total_bytes_dropped * 8 + self.bit_pos
557    }
558
559    fn align_to_byte(&mut self) {
560        let rem = self.bit_pos % 8;
561        if rem != 0 {
562            self.bit_pos += 8 - rem;
563        }
564    }
565
566    fn read_bit(&mut self) -> Result<bool, AecError> {
567        Ok(self.read_bits_u32(1)? != 0)
568    }
569
570    fn read_bits_u32(&mut self, nbits: usize) -> Result<u32, AecError> {
571        if nbits == 0 {
572            return Ok(0);
573        }
574        if nbits > 32 {
575            return Err(AecError::InvalidInput("read_bits_u32 supports up to 32 bits"));
576        }
577
578        let mut out: u32 = 0;
579        for _ in 0..nbits {
580            let byte_idx = self.bit_pos / 8;
581            let bit_in_byte = self.bit_pos % 8;
582            let byte = *self
583                .buf
584                .get(byte_idx)
585                .ok_or(AecError::UnexpectedEof { bit_pos: self.bits_read_total() })?;
586            let bit = (byte >> (7 - bit_in_byte)) & 1;
587            out = (out << 1) | (bit as u32);
588            self.bit_pos += 1;
589        }
590        Ok(out)
591    }
592
593    fn compact_consumed_bytes(&mut self) -> usize {
594        let bytes = self.bit_pos / 8;
595        if bytes == 0 {
596            return 0;
597        }
598        self.buf.drain(0..bytes);
599        self.bit_pos -= bytes * 8;
600        self.total_bytes_dropped += bytes;
601        bytes
602    }
603}
604
605fn read_unary_stream(r: &mut StreamBitReader) -> Result<u32, AecError> {
606    let mut count: u32 = 0;
607    loop {
608        let bit = r.read_bit()?;
609        if bit {
610            return Ok(count);
611        }
612        count = count.saturating_add(1);
613        if count > 1_000_000 {
614            return Err(AecError::InvalidInput("unary run too long"));
615        }
616    }
617}
618
/// Write cursor over a caller-provided byte slice that receives packed samples.
struct OutBuf<'a> {
    /// Destination bytes.
    buf: &'a mut [u8],
    /// Number of bytes written so far.
    pos: usize,
    /// Width of one packed sample in bytes.
    bytes_per_sample: usize,
}

impl<'a> OutBuf<'a> {
    /// Wrap `buf` with the cursor at its start.
    fn new(buf: &'a mut [u8], bytes_per_sample: usize) -> Self {
        OutBuf {
            buf,
            bytes_per_sample,
            pos: 0,
        }
    }

    /// Bytes written so far.
    fn len(&self) -> usize {
        let written = self.pos;
        written
    }

    /// Total size of the destination slice in bytes.
    fn capacity(&self) -> usize {
        let dest: &[u8] = self.buf;
        dest.len()
    }

    /// Whole samples written so far.
    fn samples_written(&self) -> usize {
        let (bytes, width) = (self.pos, self.bytes_per_sample);
        bytes / width
    }
}
642
643pub fn decode(input: &[u8], params: AecParams, output_samples: usize) -> Result<Vec<u8>, AecError> {
644    validate_params(params)?;
645
646    let bytes_per_sample = bytes_per_sample(params)?;
647    let output_bytes = output_samples
648        .checked_mul(bytes_per_sample)
649        .ok_or(AecError::InvalidInput("output too large"))?;
650
651    let mut out = vec![0u8; output_bytes];
652    decode_into(input, params, output_samples, &mut out)?;
653    Ok(out)
654}
655
656pub fn decode_into(
657    input: &[u8],
658    params: AecParams,
659    output_samples: usize,
660    output: &mut [u8],
661) -> Result<(), AecError> {
662    validate_params(params)?;
663
664    let trace_sample: Option<usize> = std::env::var("RUST_AEC_TRACE_SAMPLE")
665        .ok()
666        .and_then(|v| v.parse::<usize>().ok());
667
668    let bytes_per_sample = bytes_per_sample(params)?;
669    let output_bytes = output_samples
670        .checked_mul(bytes_per_sample)
671        .ok_or(AecError::InvalidInput("output too large"))?;
672
673    if output.len() != output_bytes {
674        return Err(AecError::InvalidInput("output buffer has wrong length"));
675    }
676
677    let mut out = OutBuf::new(output, bytes_per_sample);
678    let mut r = BitReader::new(input);
679
680    let id_len = id_len(params)?;
681
682    let preprocess = params.flags.contains(AecFlags::DATA_PREPROCESS);
683
684    let mut sample_index_within_rsi: u64 = 0;
685    let mut block_index_within_rsi: u32 = 0;
686
687    // Predictor state (only used with preprocessing enabled).
688    let mut predictor_x: Option<i64> = None;
689
690    while out.len() < output_bytes {
691        // Start of RSI interval.
692        if preprocess && block_index_within_rsi == 0 {
693            predictor_x = None;
694        }
695
696        let at_rsi_start = preprocess && block_index_within_rsi == 0;
697        let ref_pending = at_rsi_start;
698        let mut reference_sample_consumed = false;
699
700        let block_start_sample = out.samples_written();
701
702        // Read block option id.
703        let id = match r.read_bits_u32(id_len) {
704            Ok(v) => v,
705            Err(AecError::UnexpectedEof { bit_pos }) => {
706                return Err(AecError::UnexpectedEofDuringDecode {
707                    bit_pos,
708                    samples_written: out.samples_written(),
709                });
710            }
711            Err(e) => return Err(e),
712        };
713
714        let max_id = (1u32 << id_len) - 1;
715
716        // How many *coded values* does this block contribute? (set per mode; for split/SE/zero
717        // it's typically block_size - ref, but uncompressed reads full block_size raw samples).
718        let mut remaining_in_block: usize;
719
720        // Helper: consume the RSI reference sample (when preprocessing is enabled).
721        let mut consume_reference = |r: &mut BitReader, out: &mut OutBuf<'_>| -> Result<(), AecError> {
722            let ref_raw = match r.read_bits_u32(params.bits_per_sample as usize) {
723                Ok(v) => v,
724                Err(AecError::UnexpectedEof { bit_pos }) => {
725                    return Err(AecError::UnexpectedEofDuringDecode {
726                        bit_pos,
727                        samples_written: out.samples_written(),
728                    });
729                }
730                Err(e) => return Err(e),
731            };
732            let ref_val = if params.flags.contains(AecFlags::DATA_SIGNED) {
733                sign_extend(ref_raw, params.bits_per_sample)
734            } else {
735                ref_raw as i64
736            };
737
738            write_sample(out, ref_val, params)?;
739            predictor_x = Some(ref_val);
740            reference_sample_consumed = true;
741            sample_index_within_rsi += 1;
742            Ok(())
743        };
744
745        if id == 0 {
746            // Low-entropy family.
747            let selector = match r.read_bit() {
748                Ok(v) => v,
749                Err(AecError::UnexpectedEof { bit_pos }) => {
750                    return Err(AecError::UnexpectedEofDuringDecode {
751                        bit_pos,
752                        samples_written: out.samples_written(),
753                    });
754                }
755                Err(e) => return Err(e),
756            };
757
758            if let Some(ts) = trace_sample {
759                let block_end = block_start_sample + params.block_size as usize;
760                if (block_start_sample..block_end).contains(&ts) {
761                    eprintln!(
762                        "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id=0 mode=LE selector={} block_samples=[{}, {})",
763                        r.bits_read(),
764                        selector,
765                        block_start_sample,
766                        block_end
767                    );
768                }
769            }
770
771            // For low-entropy blocks, the selector bit comes BEFORE the optional RSI reference.
772            if ref_pending {
773                consume_reference(&mut r, &mut out)?;
774                if out.len() >= output_bytes {
775                    break;
776                }
777            }
778
779            remaining_in_block = params.block_size as usize;
780            if reference_sample_consumed {
781                remaining_in_block = remaining_in_block.saturating_sub(1);
782            }
783
784            if !selector {
785                // Zero-block run.
786                let fs = match read_unary(&mut r) {
787                    Ok(v) => v,
788                    Err(AecError::UnexpectedEof { bit_pos }) => {
789                        return Err(AecError::UnexpectedEofDuringDecode {
790                            bit_pos,
791                            samples_written: out.samples_written(),
792                        });
793                    }
794                    Err(e) => return Err(e),
795                };
796                let mut z_blocks = fs + 1;
797
798                const ROS: u32 = 5;
799
800                if z_blocks == ROS {
801                    // Fill-to-boundary; bounded by RSI.
802                    let b = block_index_within_rsi;
803                    let fill1 = params.rsi.saturating_sub(b);
804                    let fill2 = 64u32.saturating_sub(b % 64);
805                    z_blocks = fill1.min(fill2);
806                } else if z_blocks > ROS {
807                    z_blocks = z_blocks.saturating_sub(1);
808                }
809
810                let mut zeros_samples = z_blocks
811                    .checked_mul(params.block_size)
812                    .ok_or(AecError::InvalidInput("zero-run overflow"))? as usize;
813
814                // If we already emitted the reference sample for the first block, the zero-run
815                // covers the whole blocks, but the first sample is already accounted for.
816                if reference_sample_consumed {
817                    zeros_samples = zeros_samples.saturating_sub(1);
818                }
819
820                if let Some(ts) = trace_sample {
821                    let total_samples = (z_blocks as usize)
822                        .checked_mul(params.block_size as usize)
823                        .unwrap_or(usize::MAX);
824                    let run_end = block_start_sample.saturating_add(total_samples);
825                    if (block_start_sample..run_end).contains(&ts) {
826                        eprintln!(
827                            "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id=0 mode=ZRUN fs={} z_blocks={} run_samples=[{}, {})",
828                            r.bits_read(),
829                            fs,
830                            z_blocks,
831                            block_start_sample,
832                            run_end
833                        );
834                    }
835                }
836
837                emit_repeated_value(
838                    &mut out,
839                    &mut predictor_x,
840                    params,
841                    bytes_per_sample,
842                    0,
843                    zeros_samples,
844                    &mut sample_index_within_rsi,
845                    output_bytes,
846                )?;
847
848                // Advance block counter by z_blocks.
849                // We have already consumed the current block header as part of the run.
850                block_index_within_rsi = block_index_within_rsi.saturating_add(z_blocks);
851                if block_index_within_rsi >= params.rsi {
852                    block_index_within_rsi %= params.rsi;
853                    if params.flags.contains(AecFlags::PAD_RSI) {
854                        r.align_to_byte();
855                    }
856                    sample_index_within_rsi = 0;
857                }
858
859                continue;
860            }
861
862            // Second Extension option.
863            emit_second_extension(
864                &mut r,
865                &mut out,
866                &mut predictor_x,
867                params,
868                bytes_per_sample,
869                remaining_in_block,
870                reference_sample_consumed,
871                &mut sample_index_within_rsi,
872                output_bytes,
873            )?;
874        } else if id == max_id {
875            // Uncompressed block.
876            if let Some(ts) = trace_sample {
877                let block_end = block_start_sample + params.block_size as usize;
878                if (block_start_sample..block_end).contains(&ts) {
879                    eprintln!(
880                        "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id={} mode=UNCOMP block_samples=[{}, {})",
881                        r.bits_read(),
882                        id,
883                        block_start_sample,
884                        block_end
885                    );
886                }
887            }
888            if ref_pending {
889                // For uncompressed blocks, the reference sample is the first raw sample.
890                consume_reference(&mut r, &mut out)?;
891                if out.len() >= output_bytes {
892                    break;
893                }
894                remaining_in_block = params.block_size as usize - 1;
895            } else {
896                remaining_in_block = params.block_size as usize;
897            }
898
899            for _ in 0..remaining_in_block {
900                let v = match r.read_bits_u32(params.bits_per_sample as usize) {
901                    Ok(v) => v,
902                    Err(AecError::UnexpectedEof { bit_pos }) => {
903                        return Err(AecError::UnexpectedEofDuringDecode {
904                            bit_pos,
905                            samples_written: out.samples_written(),
906                        });
907                    }
908                    Err(e) => return Err(e),
909                };
910                emit_coded_value(
911                    &mut out,
912                    &mut predictor_x,
913                    params,
914                    bytes_per_sample,
915                    v,
916                    &mut sample_index_within_rsi,
917                    output_bytes,
918                )?;
919                if out.len() >= output_bytes {
920                    break;
921                }
922            }
923        } else {
924            // Rice "split" option: decode all fundamental sequences first, then all k-bit
925            // binary parts (this matches libaec's bitstream layout).
926            let k = (id - 1) as usize;
927
928            if let Some(ts) = trace_sample {
929                let block_end = block_start_sample + params.block_size as usize;
930                if (block_start_sample..block_end).contains(&ts) {
931                    eprintln!(
932                        "TRACE sample={ts} rsi_block={block_index_within_rsi} bits={} id={} mode=SPLIT k={} block_samples=[{}, {})",
933                        r.bits_read(),
934                        id,
935                        k,
936                        block_start_sample,
937                        block_end
938                    );
939                }
940            }
941
942            if ref_pending {
943                consume_reference(&mut r, &mut out)?;
944                if out.len() >= output_bytes {
945                    break;
946                }
947            }
948
949            remaining_in_block = params.block_size as usize;
950            if reference_sample_consumed {
951                remaining_in_block = remaining_in_block.saturating_sub(1);
952            }
953
954            let n = remaining_in_block;
955            let mut tmp: Vec<u32> = vec![0u32; n];
956
957            // If tracing is enabled and the trace sample falls within the coded portion of this
958            // block, record the quotient/remainder at that offset.
959            let trace_offset_in_block: Option<usize> = trace_sample.and_then(|ts| {
960                let coded_start = out.samples_written();
961                if ts >= coded_start && ts < coded_start + n {
962                    Some(ts - coded_start)
963                } else {
964                    None
965                }
966            });
967            let mut trace_q: Option<u32> = None;
968            let mut trace_rem: Option<u32> = None;
969
970            for i in 0..n {
971                let q = match read_unary(&mut r) {
972                    Ok(v) => v,
973                    Err(AecError::UnexpectedEof { bit_pos }) => {
974                        return Err(AecError::UnexpectedEofDuringDecode {
975                            bit_pos,
976                            samples_written: out.samples_written(),
977                        });
978                    }
979                    Err(e) => return Err(e),
980                };
981                if trace_offset_in_block == Some(i) {
982                    trace_q = Some(q);
983                }
984                tmp[i] = (q as u32)
985                    .checked_shl(k as u32)
986                    .ok_or(AecError::InvalidInput("rice shift overflow"))?;
987            }
988
989            if k > 0 {
990                for i in 0..n {
991                    let rem_bitpos_before = if trace_offset_in_block
992                        .map(|off| i + 2 >= off && i <= off + 2)
993                        .unwrap_or(false)
994                    {
995                        Some(r.bits_read())
996                    } else {
997                        None
998                    };
999
1000                    let rem = match r.read_bits_u32(k) {
1001                        Ok(v) => v,
1002                        Err(AecError::UnexpectedEof { bit_pos }) => {
1003                            return Err(AecError::UnexpectedEofDuringDecode {
1004                                bit_pos,
1005                                samples_written: out.samples_written(),
1006                            });
1007                        }
1008                        Err(e) => return Err(e),
1009                    };
1010
1011                    if let (Some(off), Some(bitpos)) = (trace_offset_in_block, rem_bitpos_before) {
1012                        if i + 2 >= off && i <= off + 2 {
1013                            eprintln!(
1014                                "TRACE rem i={} (off={}) bitpos={} bits={:0width$b} rem={}",
1015                                i,
1016                                off,
1017                                bitpos,
1018                                rem,
1019                                rem,
1020                                width = k
1021                            );
1022                        }
1023                    }
1024
1025                    if trace_offset_in_block == Some(i) {
1026                        trace_rem = Some(rem);
1027                    }
1028                    tmp[i] |= rem;
1029                }
1030            }
1031
1032            if let Some(off) = trace_offset_in_block {
1033                let d = tmp[off];
1034                let w_start = off.saturating_sub(2);
1035                let w_end = (off + 3).min(n);
1036                let window = tmp[w_start..w_end].to_vec();
1037                eprintln!(
1038                    "TRACE split-detail sample={} rsi_block={} id={} k={} off={} q={:?} rem={:?} d={} window[{}..{}]={:?}",
1039                    trace_sample.unwrap_or(0),
1040                    block_index_within_rsi,
1041                    id,
1042                    k,
1043                    off,
1044                    trace_q,
1045                    trace_rem,
1046                    d
1047                    ,
1048                    w_start,
1049                    w_end,
1050                    window
1051                );
1052            }
1053
1054            for v in tmp {
1055                emit_coded_value(
1056                    &mut out,
1057                    &mut predictor_x,
1058                    params,
1059                    bytes_per_sample,
1060                    v,
1061                    &mut sample_index_within_rsi,
1062                    output_bytes,
1063                )?;
1064                if out.len() >= output_bytes {
1065                    break;
1066                }
1067            }
1068        }
1069
1070        // Next block.
1071        block_index_within_rsi = block_index_within_rsi.saturating_add(1);
1072        if preprocess && block_index_within_rsi >= params.rsi {
1073            block_index_within_rsi = 0;
1074            sample_index_within_rsi = 0;
1075            if params.flags.contains(AecFlags::PAD_RSI) {
1076                r.align_to_byte();
1077            }
1078        }
1079    }
1080
1081    Ok(())
1082}
1083
1084fn validate_params(params: AecParams) -> Result<(), AecError> {
1085    if !(1..=32).contains(&params.bits_per_sample) {
1086        return Err(AecError::InvalidInput("bits_per_sample must be 1..=32"));
1087    }
1088    if params.block_size == 0 {
1089        return Err(AecError::InvalidInput("block_size must be > 0"));
1090    }
1091    if params.rsi == 0 {
1092        return Err(AecError::InvalidInput("rsi must be > 0"));
1093    }
1094
1095    // Common AEC block sizes; keep permissive but avoid pathological values.
1096    if ![8u32, 16, 32, 64].contains(&params.block_size) {
1097        return Err(AecError::Unsupported("block_size must be one of 8,16,32,64"));
1098    }
1099
1100    Ok(())
1101}
1102
1103fn bytes_per_sample(params: AecParams) -> Result<usize, AecError> {
1104    let bps = params.bits_per_sample;
1105
1106    let b = match bps {
1107        1..=8 => 1,
1108        9..=16 => 2,
1109        17..=24 => {
1110            if params.flags.contains(AecFlags::DATA_3BYTE) {
1111                3
1112            } else {
1113                4
1114            }
1115        }
1116        25..=32 => 4,
1117        _ => return Err(AecError::InvalidInput("invalid bits_per_sample")),
1118    };
1119
1120    Ok(b)
1121}
1122
1123fn id_len(params: AecParams) -> Result<usize, AecError> {
1124    let bps = params.bits_per_sample;
1125
1126    let mut id_len = if bps > 16 { 5 } else if bps > 8 { 4 } else { 3 };
1127
1128    if params.flags.contains(AecFlags::RESTRICTED) && bps <= 4 {
1129        id_len = if bps <= 2 { 1 } else { 2 };
1130    }
1131
1132    Ok(id_len)
1133}
1134
1135fn read_unary(r: &mut BitReader<'_>) -> Result<u32, AecError> {
1136    let mut count: u32 = 0;
1137    loop {
1138        let bit = r.read_bit()?;
1139        if bit {
1140            return Ok(count);
1141        }
1142        count = count.saturating_add(1);
1143        // Safety guard against pathological/corrupt inputs.
1144        // Valid streams can have unary lengths larger than 90 (Second Extension is the main
1145        // mode that constrains it to <= 90), so we only cap at a very large value.
1146        if count > 1_000_000 {
1147            return Err(AecError::InvalidInput("unary run too long"));
1148        }
1149    }
1150}
1151
1152fn emit_coded_value(
1153    out: &mut OutBuf<'_>,
1154    predictor_x: &mut Option<i64>,
1155    params: AecParams,
1156    _bytes_per_sample: usize,
1157    v: u32,
1158    sample_index_within_rsi: &mut u64,
1159    output_bytes: usize,
1160) -> Result<(), AecError> {
1161    if out.len() >= output_bytes {
1162        return Ok(());
1163    }
1164
1165    if params.flags.contains(AecFlags::DATA_PREPROCESS) {
1166        let x_prev = predictor_x.ok_or(AecError::InvalidInput("missing reference sample"))?;
1167        let x_next = inverse_preprocess_step(x_prev, v, params);
1168        write_sample(out, x_next, params)?;
1169        *predictor_x = Some(x_next);
1170        *sample_index_within_rsi += 1;
1171        return Ok(());
1172    }
1173
1174    // No preprocessing: v is the sample value (raw n-bit field).
1175    write_sample(out, v as i64, params)?;
1176    *sample_index_within_rsi += 1;
1177    Ok(())
1178}
1179
1180fn emit_repeated_value(
1181    out: &mut OutBuf<'_>,
1182    predictor_x: &mut Option<i64>,
1183    params: AecParams,
1184    bytes_per_sample: usize,
1185    v: u32,
1186    count: usize,
1187    sample_index_within_rsi: &mut u64,
1188    output_bytes: usize,
1189) -> Result<(), AecError> {
1190    for _ in 0..count {
1191        if out.len() >= output_bytes {
1192            break;
1193        }
1194        emit_coded_value(
1195            out,
1196            predictor_x,
1197            params,
1198            bytes_per_sample,
1199            v,
1200            sample_index_within_rsi,
1201            output_bytes,
1202        )?;
1203    }
1204    Ok(())
1205}
1206
1207fn emit_second_extension(
1208    r: &mut BitReader<'_>,
1209    out: &mut OutBuf<'_>,
1210    predictor_x: &mut Option<i64>,
1211    params: AecParams,
1212    bytes_per_sample: usize,
1213    mut remaining_in_block: usize,
1214    reference_sample_consumed: bool,
1215    sample_index_within_rsi: &mut u64,
1216    output_bytes: usize,
1217) -> Result<(), AecError> {
1218    // Second Extension yields pairs (a,b) aligned to even sample indices.
1219    // If we started at an odd sample index because sample 0 was the reference,
1220    // emit only the second element from the first symbol.
1221    let mut need_odd_first = reference_sample_consumed;
1222
1223    while remaining_in_block > 0 && out.len() < output_bytes {
1224        let m = read_unary(r)?;
1225        if m > 90 {
1226            return Err(AecError::InvalidInput("Second Extension unary symbol too large"));
1227        }
1228
1229        let (a, b) = second_extension_pair(m);
1230
1231        if need_odd_first {
1232            // Only emit the odd-index element.
1233            emit_coded_value(
1234                out,
1235                predictor_x,
1236                params,
1237                bytes_per_sample,
1238                b,
1239                sample_index_within_rsi,
1240                output_bytes,
1241            )?;
1242            remaining_in_block = remaining_in_block.saturating_sub(1);
1243            need_odd_first = false;
1244            continue;
1245        }
1246
1247        // Emit a (even index)
1248        emit_coded_value(
1249            out,
1250            predictor_x,
1251            params,
1252            bytes_per_sample,
1253            a,
1254            sample_index_within_rsi,
1255            output_bytes,
1256        )?;
1257        remaining_in_block = remaining_in_block.saturating_sub(1);
1258        if remaining_in_block == 0 || out.len() >= output_bytes {
1259            break;
1260        }
1261
1262        // Emit b (odd index)
1263        emit_coded_value(
1264            out,
1265            predictor_x,
1266            params,
1267            bytes_per_sample,
1268            b,
1269            sample_index_within_rsi,
1270            output_bytes,
1271        )?;
1272        remaining_in_block = remaining_in_block.saturating_sub(1);
1273    }
1274
1275    Ok(())
1276}
1277
/// Maps a Second Extension symbol `m` to its sample pair `(a, b)`.
///
/// Pairs are enumerated by increasing sum `s = a + b` (0 through 12) and,
/// within a sum, by increasing `b`; `m` is the position in that enumeration,
/// so `m = s * (s + 1) / 2 + b`. Valid symbols are `0..=90`.
///
/// Out-of-range symbols yield `(0, 0)`; the caller is expected to reject them
/// beforehand.
fn second_extension_pair(m: u32) -> (u32, u32) {
    // Peel off whole sum-groups (group `sum` holds `sum + 1` pairs) until the
    // remaining offset falls inside the current group.
    let mut offset = m;
    for sum in 0u32..=12 {
        if offset <= sum {
            return (sum - offset, offset);
        }
        offset -= sum + 1;
    }

    // m is validated by caller; fallback is harmless.
    (0, 0)
}
1293
1294fn inverse_preprocess_step(x_prev: i64, d: u32, params: AecParams) -> i64 {
1295    let n = params.bits_per_sample;
1296
1297    // Match libaec inverse preprocessing exactly (see vendor/libaec.../src/decode.c).
1298    // The coded value `d` is mapped to a signed delta using the LSB as sign, but the
1299    // application of that delta is bounded; if it would cross the selected boundary,
1300    // a reflection mapping is used instead.
1301    let delta: i64 = ((d >> 1) as i64) ^ (!(((d & 1) as i64) - 1));
1302    let half_d: i64 = ((d >> 1) + (d & 1)) as i64;
1303
1304    if params.flags.contains(AecFlags::DATA_SIGNED) {
1305        // signed_max matches libaec state->xmax for signed data.
1306        let signed_max: i64 = (1i64 << (n - 1)) - 1;
1307        let data = x_prev;
1308
1309        if data < 0 {
1310            if half_d <= signed_max + data + 1 {
1311                data + delta
1312            } else {
1313                (d as i64) - signed_max - 1
1314            }
1315        } else {
1316            if half_d <= signed_max - data {
1317                data + delta
1318            } else {
1319                signed_max - (d as i64)
1320            }
1321        }
1322    } else {
1323        let unsigned_max: u64 = (1u64 << n) - 1;
1324        let data_u: u64 = x_prev as u64;
1325
1326        // med is a single bit (the MSB) for unsigned samples.
1327        let med: u64 = unsigned_max / 2 + 1;
1328        let mask: u64 = if (data_u & med) != 0 { unsigned_max } else { 0 };
1329
1330        if (half_d as u64) <= (mask ^ data_u) {
1331            (x_prev + delta) as i64
1332        } else {
1333            (mask ^ (d as u64)) as i64
1334        }
1335    }
1336}
1337
1338fn write_sample(out: &mut OutBuf<'_>, value: i64, params: AecParams) -> Result<(), AecError> {
1339    let n = params.bits_per_sample as u32;
1340    let mask: u64 = if n == 32 { u64::MAX } else { (1u64 << n) - 1 };
1341
1342    let raw_u = if params.flags.contains(AecFlags::DATA_SIGNED) {
1343        (value as i64 as u64) & mask
1344    } else {
1345        (value.max(0) as u64) & mask
1346    };
1347
1348    let bytes_per_sample = out.bytes_per_sample;
1349    if out.pos.checked_add(bytes_per_sample).ok_or(AecError::InvalidInput("output too large"))? > out.capacity() {
1350        return Err(AecError::InvalidInput("output buffer too small"));
1351    }
1352
1353    let msb = params.flags.contains(AecFlags::MSB);
1354    if msb {
1355        for i in (0..bytes_per_sample).rev() {
1356            out.buf[out.pos] = ((raw_u >> (i * 8)) & 0xff) as u8;
1357            out.pos += 1;
1358        }
1359    } else {
1360        for i in 0..bytes_per_sample {
1361            out.buf[out.pos] = ((raw_u >> (i * 8)) & 0xff) as u8;
1362            out.pos += 1;
1363        }
1364    }
1365
1366    Ok(())
1367}
1368
/// Interprets the low `bits` bits of `raw` as a two's-complement integer and
/// widens it to `i64`.
///
/// Bits of `raw` above position `bits - 1` are ignored. The function is total:
///
/// - `bits == 0` describes an empty field and yields `0`;
/// - `bits >= 32` uses all 32 bits of `raw`.
///
/// Previously `bits == 0` produced a shift by 32 and `bits > 32` underflowed
/// the shift computation, both of which panic in debug builds.
fn sign_extend(raw: u32, bits: u8) -> i64 {
    match bits {
        0 => 0,
        1..=31 => {
            // Move the field's sign bit into bit 31, then shift back
            // arithmetically so it is replicated across the upper bits.
            let shift = 32 - u32::from(bits);
            i64::from(((raw << shift) as i32) >> shift)
        }
        _ => i64::from(raw as i32),
    }
}