ben/decode/
mod.rs

1//! This module contains the main functions for decoding XBEN and BEN files.
2//!
3//! XBEN files are generally transformed back into BEN files, and BEN files
4//! are transformed into a JSONL file with the formatting
5//!
6//! ```json
7//! {"assignment": [...], "sample": #}
8//! ```
9//!
10//! The BEN file format is a bit-packed binary format that is used to store
11//! run-length encoded assignment vectors, and is streamable. Therefore, the
12//! BEN file format works well with the `read` submodule of this module
13//! which is designed to extract a single assignment vector from a BEN file.
14pub mod read;
15
16use byteorder::{BigEndian, ReadBytesExt};
17use serde_json::json;
18use std::fs::File;
19use std::io::{self, BufRead, Read, Write}; // trait imports
20use std::io::{BufReader, Cursor, Error}; // type import
21use std::iter::Peekable;
22use std::path::Path;
23use std::path::PathBuf;
24use xz2::read::XzDecoder;
25
26use crate::utils::rle_to_vec;
27
28use super::encode::translate::*;
29use super::{log, logln, BenVariant};
30
/// A decoded sample record: the full assignment vector paired with its
/// repetition count (the decoders in this module yield `1` for the
/// Standard variant).
pub type MkvRecord = (Vec<u16>, u16);
32
/// Errors that can occur while constructing a `BenDecoder`.
#[derive(Debug)]
pub enum DecoderInitError {
    /// The leading header bytes did not match a known BEN banner. The bytes
    /// that were read are carried along for diagnostics.
    InvalidFileFormat(Vec<u8>),
    /// The underlying reader failed while the header was being read.
    Io(io::Error),
}
38
/// Report whether `h` begins with the six-byte XZ magic number
/// (`FD 37 7A 58 5A 00`).
///
/// This is used to provide a more informative error message when a user
/// tries to decode a compressed .xben file with the `BenDecoder` instead of
/// the `decode_xben_to_ben` function.
fn is_xz_header(h: &[u8]) -> bool {
    const XZ_MAGIC: [u8; 6] = [0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00];
    h.starts_with(&XZ_MAGIC)
}
46
/// Render a byte slice as a hex string for display purposes.
///
/// Every byte becomes two uppercase hex digits; consecutive bytes are
/// separated by a single space. An empty slice yields an empty string.
fn to_hex(bytes: &[u8]) -> String {
    // Two digits plus one separator per byte is an upper bound on the length.
    let mut rendered = String::with_capacity(bytes.len() * 3);
    for (position, byte) in bytes.iter().enumerate() {
        if position != 0 {
            rendered.push(' ');
        }
        rendered.push_str(&format!("{byte:02X}"));
    }
    rendered
}
56
57impl std::fmt::Display for DecoderInitError {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Self::Io(e) => write!(f, "IO error: {e}"),
61            Self::InvalidFileFormat(header) => {
62                if is_xz_header(header) {
63                    write!(
64                        f,
65                        "Invalid file format: Compressed header detected (hex: {}). \
66                     This reader expects an uncompressed .ben file. \
67                     Decompress this file using the BEN cli `ben -m decode <file_name>.xben` tool \
68                     or the `decode_xben_to_ben` function in this library.",
69                        to_hex(header)
70                    )
71                } else {
72                    let lossy = String::from_utf8_lossy(header);
73                    write!(
74                        f,
75                        "Invalid file format. Found header (utf8-lossy: {lossy:?}, hex: {})",
76                        to_hex(header)
77                    )
78                }
79            }
80        }
81    }
82}
83
84impl std::error::Error for DecoderInitError {
85    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
86        match self {
87            DecoderInitError::Io(e) => Some(e),
88            DecoderInitError::InvalidFileFormat(_) => None,
89        }
90    }
91}
92
93impl From<io::Error> for DecoderInitError {
94    fn from(error: io::Error) -> Self {
95        DecoderInitError::Io(error)
96    }
97}
98
99impl From<DecoderInitError> for io::Error {
100    fn from(error: DecoderInitError) -> Self {
101        match error {
102            DecoderInitError::Io(e) => e,
103            DecoderInitError::InvalidFileFormat(msg) => {
104                io::Error::new(io::ErrorKind::InvalidData, format!("{msg:?}"))
105            }
106        }
107    }
108}
109
/// Streaming decoder over an uncompressed BEN stream.
///
/// Implements `Iterator`, yielding one decoded assignment vector together
/// with its repetition count per frame.
pub struct BenDecoder<R: Read> {
    // Source of BEN bytes; `new` has already consumed the 17-byte banner.
    reader: R,
    // Running number of samples emitted by `write_all_jsonl`.
    sample_count: usize,
    // Variant detected from the banner; decides whether each frame carries
    // a trailing repetition count.
    variant: BenVariant,
}
115
/// A single frame from a BEN file.
///
/// Holds the frame header fields together with the still bit-packed payload,
/// so a frame can be inspected or forwarded without decoding it.
#[derive(Clone)]
pub struct BenFrame {
    pub max_val_bits: u8,  // number of bits used for each value
    pub max_len_bits: u8,  // number of bits used for each run-length
    pub count: u16,        // repetition count (1 for Standard)
    pub n_bytes: u32,      // number of bytes used for the raw assignment data
    pub raw_data: Vec<u8>, // raw bit-compressed BEN data
}
125
126impl<R: Read> BenDecoder<R> {
127    /// Create a new BenDecoder from a reader.
128    /// The reader must contain a valid BEN file.
129    /// The first 17 bytes of the file are checked to determine
130    /// the variant of the BEN file.
131    pub fn new(mut reader: R) -> Result<Self, DecoderInitError> {
132        let mut check_buffer = [0u8; 17];
133
134        if let Err(e) = reader.read_exact(&mut check_buffer) {
135            return Err(DecoderInitError::Io(e));
136        }
137
138        match &check_buffer {
139            b"STANDARD BEN FILE" => Ok(BenDecoder {
140                reader,
141                sample_count: 0,
142                variant: BenVariant::Standard,
143            }),
144            b"MKVCHAIN BEN FILE" => Ok(BenDecoder {
145                reader,
146                sample_count: 0,
147                variant: BenVariant::MkvChain,
148            }),
149            _ => Err(DecoderInitError::InvalidFileFormat(check_buffer.to_vec())),
150        }
151    }
152
153    /// Write all decoded assignments to a writer in JSONL format.
154    ///
155    /// Arguments:
156    ///
157    /// * `writer`: A mutable reference to a writer where the JSONL output will be written.
158    fn write_all_jsonl(&mut self, mut writer: impl Write) -> io::Result<()> {
159        while let Some(result_tuple) = self.next() {
160            match result_tuple {
161                Ok((assignment, count)) => {
162                    for _ in 0..count {
163                        self.sample_count += 1;
164                        let line = json!({
165                            "assignment": assignment,
166                            "sample": self.sample_count,
167                        })
168                        .to_string()
169                            + "\n";
170                        writer.write_all(line.as_bytes()).unwrap();
171                    }
172                }
173                Err(e) => {
174                    return Err(e);
175                }
176            }
177        }
178        Ok(())
179    }
180
181    /// Internal helper function that pops a single ben frame from the reader.
182    /// This frame may then either be decoded into an assignment vector
183    /// or returned as-is for further processing.
184    fn pop_frame_from_reader(&mut self) -> Option<io::Result<BenFrame>> {
185        let mut b1 = [0u8; 1];
186        let max_val_bits = match self.reader.read_exact(&mut b1) {
187            Ok(()) => b1[0],
188            Err(e) => {
189                if e.kind() == io::ErrorKind::UnexpectedEof {
190                    // clean EOF before starting a new frame
191                    logln!();
192                    logln!("Done!");
193                    return None;
194                }
195                return Some(Err(e));
196            }
197        };
198
199        let mut b2 = [0u8; 1];
200        if let Err(e) = self.reader.read_exact(&mut b2) {
201            return Some(Err(e));
202        }
203        let max_len_bits = b2[0];
204
205        let n_bytes = match self.reader.read_u32::<BigEndian>() {
206            Ok(n) => n,
207            Err(e) => return Some(Err(e)),
208        };
209
210        let mut raw_assignment = vec![0u8; n_bytes as usize];
211        if let Err(e) = self.reader.read_exact(&mut raw_assignment) {
212            return Some(Err(e));
213        }
214
215        let count = if self.variant == BenVariant::MkvChain {
216            match self.reader.read_u16::<BigEndian>() {
217                Ok(c) => c,
218                Err(e) => return Some(Err(e)),
219            }
220        } else {
221            1
222        };
223
224        Some(Ok(BenFrame {
225            max_val_bits,
226            max_len_bits,
227            n_bytes,
228            raw_data: raw_assignment,
229            count,
230        }))
231    }
232}
233
234/// Helper function to decode a ben frame into an assignment vector.
235fn decode_ben_frame_to_assignment(frame: &BenFrame) -> io::Result<Vec<u16>> {
236    decode_ben_line(
237        Cursor::new(&frame.raw_data),
238        frame.max_val_bits,
239        frame.max_len_bits,
240        frame.n_bytes,
241    )
242    .map(rle_to_vec)
243}
244
245impl<R: Read> Iterator for BenDecoder<R> {
246    type Item = io::Result<MkvRecord>;
247
248    fn next(&mut self) -> Option<io::Result<MkvRecord>> {
249        let ben_frame = match self.pop_frame_from_reader() {
250            Some(Ok(frame)) => frame,
251            Some(Err(e)) => return Some(Err(e)),
252            None => return None,
253        };
254        let assignment = match decode_ben_frame_to_assignment(&ben_frame) {
255            Ok(assgn) => assgn,
256            Err(e) => return Some(Err(e)),
257        };
258        log!(
259            "Decoding sample: {}\r",
260            self.sample_count + ben_frame.count as usize
261        );
262        Some(Ok((assignment, ben_frame.count)))
263    }
264}
265
266pub struct BenFrameDecoeder<R: Read> {
267    inner: BenDecoder<R>,
268}
269
270impl<R: Read> BenFrameDecoeder<R> {
271    pub fn new(reader: R) -> io::Result<Self> {
272        Ok(Self {
273            inner: BenDecoder::new(reader)?,
274        })
275    }
276}
277
278impl<R: Read> Iterator for BenFrameDecoeder<R> {
279    type Item = io::Result<BenFrame>;
280
281    fn next(&mut self) -> Option<Self::Item> {
282        self.inner.pop_frame_from_reader()
283    }
284}
285
impl<R: Read> BenDecoder<R> {
    /// Consume this decoder and iterate raw ben frames instead of decoded assignments.
    ///
    /// The returned iterator continues from the decoder's current position
    /// in the reader.
    pub fn into_frames(self) -> BenFrameDecoeder<R> {
        BenFrameDecoeder { inner: self }
    }
}
292
293impl<R: Read> BenDecoder<R> {
294    /// Count how many samples remain in this BEN stream.
295    /// Consumes the decoder (fast: walks frames only).
296    pub fn count_samples(self) -> io::Result<usize> {
297        let mut total = 0usize;
298        for frame_res in self.into_frames() {
299            let f = frame_res?; // BenFrame
300            total += f.count as usize; // 1 for Standard; >1 for MKVCHAIN
301        }
302        Ok(total)
303    }
304}
305
306/// This function takes a reader containing a single ben32 encoded assignment
307/// vector and decodes it into a full assignment vector of u16s.
308///
309/// # Errors
310///
311/// This function will return an error if the input reader is not a multiple of 4
312/// bytes long since each assignment vector is an run-length encoded as a 32 bit
313/// integer (2 bytes for the value and 2 bytes for the count).
314///
315fn decode_ben32_line<R: BufRead>(mut reader: R, variant: BenVariant) -> io::Result<MkvRecord> {
316    let mut buffer = [0u8; 4];
317    let mut output_vec: Vec<u16> = Vec::new();
318
319    loop {
320        match reader.read_exact(&mut buffer) {
321            Ok(()) => {
322                let encoded = u32::from_be_bytes(buffer);
323                if encoded == 0 {
324                    // Check for separator (all 0s)
325                    break; // Exit loop to process next sample
326                }
327
328                let value = (encoded >> 16) as u16; // High 16 bits
329                let count = (encoded & 0xFFFF) as u16; // Low 16 bits
330
331                // Reconstruct the original data
332                for _ in 0..count {
333                    output_vec.push(value);
334                }
335            }
336            Err(e) => {
337                return Err(e); // Propagate other errors
338            }
339        }
340    }
341
342    let count = if variant == BenVariant::MkvChain {
343        reader
344            .read_u16::<BigEndian>()
345            .expect("Error when reading sample.")
346    } else {
347        1
348    };
349
350    Ok((output_vec, count))
351}
352
353/// This function takes a reader containing a file encoded with the
354/// "ben32" format and decodes it into a JSONL file.
355///
356/// The output JSONL file will have the formatting
357///
358/// ```json
359/// {"assignment": [...], "sample": #}
360/// ```
361///
362/// # Errors
363///
364/// This function will return an error if the input reader contains invalid ben32
365/// data or if the the decode method encounters while trying to extract a single
366/// assignment vector, that error is propagated.
367fn jsonl_decode_ben32<R: BufRead, W: Write>(
368    mut reader: R,
369    mut writer: W,
370    starting_sample: usize,
371    variant: BenVariant,
372) -> io::Result<()> {
373    let mut sample_number = 1;
374    loop {
375        let result = decode_ben32_line(&mut reader, variant);
376        if let Err(e) = result {
377            if e.kind() == io::ErrorKind::UnexpectedEof {
378                return Ok(());
379            }
380            return Err(e);
381        }
382
383        let (output_vec, count) = result.unwrap();
384
385        for _ in 0..count {
386            // Write the reconstructed vector as JSON to the output file
387            let line = json!({
388                "assignment": output_vec,
389                "sample": sample_number + starting_sample,
390            })
391            .to_string()
392                + "\n";
393
394            writer.write_all(line.as_bytes())?;
395            sample_number += 1;
396        }
397    }
398}
399
400/// This function takes a reader containing a file encoded in the XBEN format
401/// and decodes it into a BEN file.
402///
403/// # Errors
404///
405/// This function will return an error if the input reader contains invalid xben
406/// data or if the the decode method encounters while trying to convert the
407/// xben data to ben data.
408pub fn decode_xben_to_ben<R: BufRead, W: Write>(reader: R, mut writer: W) -> io::Result<()> {
409    let mut decoder = XzDecoder::new(reader);
410
411    let mut first_buffer = [0u8; 17];
412
413    if let Err(e) = decoder.read_exact(&mut first_buffer) {
414        return Err(e);
415    }
416
417    let variant = match &first_buffer {
418        b"STANDARD BEN FILE" => {
419            writer.write_all(b"STANDARD BEN FILE")?;
420            BenVariant::Standard
421        }
422        b"MKVCHAIN BEN FILE" => {
423            writer.write_all(b"MKVCHAIN BEN FILE")?;
424            BenVariant::MkvChain
425        }
426        _ => {
427            return Err(Error::new(
428                io::ErrorKind::InvalidData,
429                "Invalid file format",
430            ));
431        }
432    };
433
434    let mut buffer = [0u8; 1048576]; // 1MB buffer
435    let mut overflow: Vec<u8> = Vec::new();
436
437    let mut line_count: usize = 0;
438    while let Ok(count) = decoder.read(&mut buffer) {
439        if count == 0 {
440            break;
441        }
442
443        overflow.extend(&buffer[..count]);
444
445        let mut last_valid_assignment = 0;
446
447        // It is technically faster to read backwards from the last
448        // multiple of 4 smaller than the length of the overflow buffer
449        // but this provides only a minute speedup in almost all cases (maybe a
450        // few seconds). Reading from the front is both safer from a
451        // maintenance perspective and allows for a better progress indicator
452        match variant {
453            BenVariant::Standard => {
454                for i in (3..overflow.len()).step_by(4) {
455                    if overflow[i - 3..=i] == [0, 0, 0, 0] {
456                        last_valid_assignment = i + 1;
457                        line_count += 1;
458                        log!("Decoding sample: {}\r", line_count);
459                    }
460                }
461            }
462            BenVariant::MkvChain => {
463                for i in (3..overflow.len() - 2).step_by(2) {
464                    if overflow[i - 3..=i] == [0, 0, 0, 0] {
465                        last_valid_assignment = i + 3;
466                        let lines = &overflow[i + 1..i + 3];
467                        let n_lines = u16::from_be_bytes([lines[0], lines[1]]);
468                        line_count += n_lines as usize;
469                        log!("Decoding sample: {}\r", line_count);
470                    }
471                }
472            }
473        }
474
475        if last_valid_assignment == 0 {
476            continue;
477        }
478
479        ben32_to_ben_lines(&overflow[0..last_valid_assignment], &mut writer, variant)?;
480        overflow = overflow[last_valid_assignment..].to_vec();
481    }
482    logln!();
483    logln!("Done!");
484    Ok(())
485}
486
487/// This is a convenience function that decodes a general level 9 LZMA2 compressed file.
488///
489/// ```
490/// use ben::encode::xz_compress;
491/// use ben::decode::xz_decompress;
492/// use lipsum::lipsum;
493/// use std::io::{BufReader, BufWriter};
494///
495/// let input = lipsum(100);
496/// let reader = BufReader::new(input.as_bytes());
497/// let mut output_buffer = Vec::new();
498/// let writer = BufWriter::new(&mut output_buffer);
499///
500/// xz_compress(reader, writer, Some(1), Some(1)).unwrap();
501///
502/// let mut recovery_buff = Vec::new();
503/// let recovery_reader = BufWriter::new(&mut recovery_buff);
504/// xz_decompress(output_buffer.as_slice(), recovery_reader).unwrap();
505/// println!("{:?}", output_buffer);
506/// ```
507pub fn xz_decompress<R: BufRead, W: Write>(reader: R, mut writer: W) -> io::Result<()> {
508    let mut decoder = XzDecoder::new(reader);
509    let mut buffer = [0u8; 4096];
510
511    while let Ok(count) = decoder.read(&mut buffer) {
512        if count == 0 {
513            break;
514        }
515        writer.write_all(&buffer[..count])?;
516    }
517
518    Ok(())
519}
520
/// This is a helper function that is designed to read in a single
/// ben encoded line and convert it to a regular run-length encoded
/// assignment vector.
///
/// The payload is a bit stream of `(value, run-length)` pairs, packed most
/// significant bit first, where every value occupies `max_val_bits` bits and
/// every run length occupies `max_len_bits` bits. Pairs whose run length is
/// zero (such as the zero padding at the end of the last byte) are skipped.
///
/// # Arguments
///
/// * `reader` - source of the payload; exactly `n_bytes` bytes are consumed.
/// * `max_val_bits` - width in bits of each value (at most 16).
/// * `max_len_bits` - width in bits of each run length (at most 16).
/// * `n_bytes` - length of the payload in bytes.
///
/// # Errors
///
/// Returns `InvalidData` if either width is larger than 16, or if both
/// widths are zero while the payload is not empty. Returns the reader's
/// error if `n_bytes` bytes cannot be read.
pub fn decode_ben_line<R: Read>(
    mut reader: R,
    max_val_bits: u8,
    max_len_bits: u8,
    n_bytes: u32,
) -> io::Result<Vec<(u16, u16)>> {
    // Values and run lengths are u16, so wider fields cannot be valid and
    // would overflow the shifts below.
    if max_val_bits > 16 || max_len_bits > 16 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "invalid BEN frame: field widths of {max_val_bits} and {max_len_bits} bits exceed 16"
            ),
        ));
    }
    let val_bits = u32::from(max_val_bits);
    let len_bits = u32::from(max_len_bits);
    let pair_bits = val_bits + len_bits;
    // Zero-width pairs never consume input, so a non-empty payload could
    // never be drained.
    if pair_bits == 0 && n_bytes > 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "invalid BEN frame: value and run-length widths are both zero",
        ));
    }

    let mut assign_bits: Vec<u8> = vec![0; n_bytes as usize];
    reader.read_exact(&mut assign_bits)?;

    // Capacity hint only: the number of whole pairs that fit in the payload.
    let n_assignments = if pair_bits == 0 {
        0
    } else {
        usize::try_from(u64::from(n_bytes) * 8 / u64::from(pair_bits)).unwrap_or(0)
    };
    let mut output_rle: Vec<(u16, u16)> = Vec::with_capacity(n_assignments);

    // Unread bits are kept left-aligned in `buffer`; `n_bits_in_buff` counts
    // how many of its leading bits are valid.
    let mut buffer: u32 = 0;
    let mut n_bits_in_buff: u32 = 0;
    // Value already extracted and waiting for its run length.
    let mut pending_val: Option<u16> = None;

    for &byte in &assign_bits {
        // Place the byte directly below the bits already buffered. The
        // explicit shift is independent of the host's endianness.
        buffer |= (u32::from(byte) << 24) >> n_bits_in_buff;
        n_bits_in_buff += 8;

        // Pull out as many fields as the buffered bits allow.
        loop {
            let val = match pending_val {
                Some(val) => val,
                None if n_bits_in_buff >= val_bits => {
                    let val = take_high_bits(&mut buffer, val_bits);
                    n_bits_in_buff -= val_bits;
                    pending_val = Some(val);
                    val
                }
                None => break,
            };

            if n_bits_in_buff < len_bits {
                break;
            }
            let len = take_high_bits(&mut buffer, len_bits);
            n_bits_in_buff -= len_bits;
            pending_val = None;

            // Zero padding can decode to (0, 0) pairs; those carry no data.
            if len > 0 {
                output_rle.push((val, len));
            }
        }
    }

    Ok(output_rle)
}

/// Remove the `width` most significant bits from `buffer` and return them.
/// A zero-width field decodes to 0 and leaves `buffer` untouched.
fn take_high_bits(buffer: &mut u32, width: u32) -> u16 {
    if width == 0 {
        return 0;
    }
    let field = (*buffer >> (32 - width)) as u16;
    *buffer <<= width;
    field
}
604
605/// This function takes a reader containing a file encoded in the BEN format
606/// and decodes it into a JSONL file.
607///
608/// The output JSONL file will have the formatting
609///
610/// ```json
611/// {"assignment": [...], "sample": #}
612/// ```
613///
614/// # Errors
615///
616/// This function will return an error if the input reader contains invalid ben
617/// data or if the the decode method encounters while trying to extract a single
618/// assignment vector, that error is then propagated.
619pub fn decode_ben_to_jsonl<R: Read, W: Write>(reader: R, writer: W) -> io::Result<()> {
620    let mut ben_decoder = BenDecoder::new(reader)?;
621    ben_decoder.write_all_jsonl(writer)
622}
623
624/// This function takes a reader containing a file encoded in the XBEN format
625/// and decodes it into a JSONL file.
626///
627/// The output JSONL file will have the formatting
628///
629/// ```json
630/// {"assignment": [...], "sample": #}
631/// ```
632///
633/// # Errors
634///
635/// This function will return an error if the input reader contains invalid xben
636/// data or if the the decode method encounters while trying to extract a single
637/// assignment vector, that error is then propagated.
638pub fn decode_xben_to_jsonl<R: BufRead, W: Write>(reader: R, mut writer: W) -> io::Result<()> {
639    let mut decoder = XzDecoder::new(reader);
640
641    let mut first_buffer = [0u8; 17];
642
643    if let Err(e) = decoder.read_exact(&mut first_buffer) {
644        return Err(e);
645    }
646
647    let variant = match &first_buffer {
648        b"STANDARD BEN FILE" => BenVariant::Standard,
649        b"MKVCHAIN BEN FILE" => BenVariant::MkvChain,
650        _ => {
651            return Err(Error::new(
652                io::ErrorKind::InvalidData,
653                "Invalid file format",
654            ));
655        }
656    };
657
658    let mut buffer = [0u8; 1 << 20]; // 1MB buffer
659    let mut overflow: Vec<u8> = Vec::new();
660
661    let mut line_count: usize = 0;
662    let mut starting_sample: usize = 0;
663    while let Ok(count) = decoder.read(&mut buffer) {
664        if count == 0 {
665            break;
666        }
667
668        overflow.extend(&buffer[..count]);
669
670        let mut last_valid_assignment = 0;
671
672        // It is technically faster to read backwards from the last
673        // multiple of 4 smaller than the length of the overflow buffer
674        // but this provides only a minute speedup in almost all cases (maybe a
675        // few seconds). Reading from the front is both safer from a
676        // maintenance perspective and allows for a better progress indicator
677        match variant {
678            BenVariant::Standard => {
679                for i in (3..overflow.len()).step_by(4) {
680                    if overflow[i - 3..=i] == [0, 0, 0, 0] {
681                        last_valid_assignment = i + 1;
682                        line_count += 1;
683                        log!("Decoding sample: {}\r", line_count);
684                    }
685                }
686            }
687            BenVariant::MkvChain => {
688                // Need a different step size here because each assignment
689                // vector is no longer guaranteed to be a multiple of 4 bytes
690                // due to the 2-byte repetition count appended at the end
691                for i in (last_valid_assignment + 3..overflow.len().saturating_sub(2)).step_by(2) {
692                    if overflow[i - 3..=i] == [0, 0, 0, 0] {
693                        last_valid_assignment = i + 3;
694                        let lines = &overflow[i + 1..i + 3];
695                        let n_lines = u16::from_be_bytes([lines[0], lines[1]]);
696                        line_count += n_lines as usize;
697                        log!("Decoding sample: {}\r", line_count);
698                    }
699                }
700            }
701        }
702
703        if last_valid_assignment == 0 {
704            continue;
705        }
706
707        jsonl_decode_ben32(
708            &overflow[0..last_valid_assignment],
709            &mut writer,
710            starting_sample,
711            variant,
712        )?;
713        overflow.drain(..last_valid_assignment);
714        starting_sample = line_count;
715    }
716    logln!();
717    logln!("Done!");
718    Ok(())
719}
720
/// Iterator over decoded assignments inside an XBEN stream.
/// Yields `(assignment, count)` where `count` is the repetition count
/// (1 for the Standard variant).
pub struct XBenDecoder<R: Read> {
    // Buffered XZ decompressor over the source reader.
    xz: BufReader<XzDecoder<R>>,
    /// Variant detected from the 17-byte banner.
    pub variant: BenVariant,
    // Decompressed bytes that have not yet been handed out as a frame.
    overflow: Vec<u8>,
    buf: Box<[u8]>, // reusable read buffer
}
729
730impl<R: Read> XBenDecoder<R> {
731    pub fn new(reader: R) -> io::Result<Self> {
732        let xz = XzDecoder::new(reader);
733        let mut xz = BufReader::with_capacity(1 << 20, xz);
734
735        // Read the 17-byte banner to determine variant
736        let mut first = [0u8; 17];
737        xz.read_exact(&mut first)?;
738        let variant = match &first {
739            b"STANDARD BEN FILE" => BenVariant::Standard,
740            b"MKVCHAIN BEN FILE" => BenVariant::MkvChain,
741            _ => {
742                return Err(io::Error::new(
743                    io::ErrorKind::InvalidData,
744                    "Invalid .xben header (expecting STANDARD/MKVCHAIN BEN FILE)",
745                ));
746            }
747        };
748
749        Ok(Self {
750            xz,
751            variant,
752            overflow: Vec::with_capacity(1 << 20),
753            buf: vec![0u8; 1 << 20].into_boxed_slice(),
754        })
755    }
756
757    /// Try to pop one *complete* ben32 frame from `overflow`.
758    ///
759    /// # Arguments
760    ///
761    /// * `overflow` - A byte slice that may contain one or more complete ben32 frames.
762    ///
763    /// # Returns
764    ///
765    /// An Option containing a tuple of:
766    ///
767    /// * the complete frame as a byte slice,
768    /// * the number of bytes consumed from the start of `overflow` to get this frame,
769    fn pop_frame_from_overflow<'a>(&self, overflow: &'a [u8]) -> Option<(&'a [u8], usize, u16)> {
770        match self.variant {
771            BenVariant::Standard => {
772                // Frame ends right after 4 zero bytes
773                // ... [payload] ... 00 00 00 00
774                if overflow.len() < 4 {
775                    return None;
776                }
777                for i in (3..overflow.len()).step_by(4) {
778                    if overflow[i - 3..=i] == [0, 0, 0, 0] {
779                        let end = i + 1;
780                        let frame = &overflow[..end];
781                        // In STANDARD, count is always 1
782                        return Some((frame, end, 1));
783                    }
784                }
785                None
786            }
787            BenVariant::MkvChain => {
788                // ... [payload] ... 00 00 00 00 <n_lines_hi_byte> <n_lines_lo_byte>
789                if overflow.len() < 6 {
790                    return None;
791                }
792                for i in (3..overflow.len().saturating_sub(2)).step_by(2) {
793                    if overflow[i - 3..=i] == [0, 0, 0, 0] {
794                        let count_hi = overflow[i + 1];
795                        let count_lo = overflow[i + 2];
796                        let count = u16::from_be_bytes([count_hi, count_lo]);
797                        let end = i + 3; // inclusive of count bytes
798                        let frame = &overflow[..end];
799                        return Some((frame, end, count));
800                    }
801                }
802                None
803            }
804        }
805    }
806}
807
808/// Helper function to decode a ben32 frame (raw bytes) into an assignment vector.
809fn decode_xben_frame_to_assignment(
810    frame_bytes: &[u8],
811    variant: BenVariant,
812) -> io::Result<Vec<u16>> {
813    let cursor = Cursor::new(frame_bytes);
814    let (assignment, _) = decode_ben32_line(cursor, variant)?;
815    Ok(assignment)
816}
817
818impl<R: Read> Iterator for XBenDecoder<R> {
819    type Item = io::Result<MkvRecord>;
820
821    fn next(&mut self) -> Option<Self::Item> {
822        loop {
823            // If we already have a complete frame in overflow, decode and return it
824            if let Some((frame_bytes, consumed, count)) =
825                self.pop_frame_from_overflow(&self.overflow)
826            {
827                let res = match decode_xben_frame_to_assignment(frame_bytes, self.variant) {
828                    Ok(assignment) => Ok((assignment, count)),
829                    Err(e) => Err(e),
830                };
831                // drop the used bytes
832                self.overflow.drain(..consumed);
833                return Some(res);
834            }
835
836            // Otherwise, read more from the XZ stream
837            let read = match self.xz.read(&mut self.buf) {
838                Ok(0) => {
839                    // EOF: no more data; if there's leftover but not a full frame, report error or stop
840                    if self.overflow.is_empty() {
841                        return None;
842                    } else {
843                        return Some(Err(io::Error::new(
844                            io::ErrorKind::UnexpectedEof,
845                            "truncated .xben stream (partial frame at EOF)",
846                        )));
847                    }
848                }
849                Ok(n) => n,
850                Err(e) => return Some(Err(e)),
851            };
852            self.overflow.extend_from_slice(&self.buf[..read]);
853        }
854    }
855}
856
/// A frame is the raw ben32 bytes plus its repetition count (1 for Standard).
///
/// This is the item type produced by `XBenFrameDecoder`.
pub type Ben32Frame = (Vec<u8>, u16);
859
860/// Iterator over raw ben32 frames inside an XBEN stream.
861///
862/// Yields `(frame_bytes, count)` where `frame_bytes` includes the 4-byte
863/// 0x00_00_00_00 terminator; for `MkvChain` frames it also includes the
864/// 2-byte big-endian repetition count at the end. `count` is the decoded
865/// repetition count (1 for Standard).
866///
867/// Mainly useful for finding an assignment quickly
868pub struct XBenFrameDecoder<R: Read> {
869    inner: XBenDecoder<R>,
870}
871
872impl<R: Read> XBenFrameDecoder<R> {
873    pub fn new(reader: R) -> io::Result<Self> {
874        Ok(Self {
875            inner: XBenDecoder::new(reader)?,
876        })
877    }
878}
879
880impl<R: Read> Iterator for XBenFrameDecoder<R> {
881    type Item = io::Result<Ben32Frame>;
882
883    fn next(&mut self) -> Option<Self::Item> {
884        loop {
885            if let Some((frame, consumed, count)) =
886                self.inner.pop_frame_from_overflow(&self.inner.overflow)
887            {
888                // copy out the frame; caller owns the bytes
889                let out = frame.to_vec();
890                self.inner.overflow.drain(..consumed);
891                return Some(Ok((out, count)));
892            }
893
894            // refill from xz
895            let read = match self.inner.xz.read(&mut self.inner.buf) {
896                Ok(0) => {
897                    if self.inner.overflow.is_empty() {
898                        return None;
899                    } else {
900                        return Some(Err(io::Error::new(
901                            io::ErrorKind::UnexpectedEof,
902                            "truncated .xben stream (partial frame at EOF)",
903                        )));
904                    }
905                }
906                Ok(n) => n,
907                Err(e) => return Some(Err(e)),
908            };
909            self.inner
910                .overflow
911                .extend_from_slice(&self.inner.buf[..read]);
912        }
913    }
914}
915
916impl<R: Read> XBenDecoder<R> {
917    /// Consumes the decoder and iterate raw ben32 frames instead of decoded assignments.
918    pub fn into_frames(self) -> XBenFrameDecoder<R> {
919        XBenFrameDecoder { inner: self }
920    }
921}
922
923impl<R: Read> XBenDecoder<R> {
924    /// Count how many samples remain in this XBEN stream.
925    /// Consumes the decoder (fast: walks frames only).
926    pub fn count_samples(self) -> io::Result<usize> {
927        let mut total = 0usize;
928        for frame_res in self.into_frames() {
929            let (_bytes, cnt) = frame_res?; // raw ben32 bytes + repetition count
930            total += cnt as usize;
931        }
932        Ok(total)
933    }
934}
935
/// A generalized frame object that can be either a BenFrame
/// or a XBEN frame (raw bytes + variant).
#[derive(Clone)]
pub enum Frame {
    /// A frame produced by `BenFrameDecoeder`.
    Ben(BenFrame),
    /// Raw ben32 bytes together with the variant they were encoded with.
    /// The repetition count is carried beside the frame, not inside it.
    XBen(Vec<u8>, BenVariant),
}
943
/// Describes which samples a `SubsampleFrameDecoder` keeps.
/// All sample positions are 1-based.
pub enum Selection {
    /// Explicit sample positions; 1-based and sorted in ascending order.
    Indices(Peekable<std::vec::IntoIter<usize>>),
    /// Every `step`-th sample, starting at the 1-based position `offset`.
    Every { step: usize, offset: usize },
    /// All samples in the inclusive, 1-based range `[start, end]`.
    Range { start: usize, end: usize },
}
949
950/// Decode a Frame (Ben or XBen) into an assignment vector.
951fn decode_frame_to_assignment(frame: &Frame) -> io::Result<Vec<u16>> {
952    match frame {
953        Frame::Ben(f) => decode_ben_frame_to_assignment(f),
954        Frame::XBen(bytes, variant) => decode_xben_frame_to_assignment(bytes, *variant),
955    }
956}
957
/// Iterator adapter that walks `(Frame, count)` pairs and decodes only the
/// frames that contain at least one sample picked by a [`Selection`].
pub struct SubsampleFrameDecoder<I>
where
    I: Iterator<Item = io::Result<(Frame, u16)>>,
{
    /// Source of frames and their repetition counts.
    inner: I,
    /// Rule deciding which 1-based sample positions are kept.
    selection: Selection,
    /// Number of samples consumed so far, i.e. the 1-based position of the
    /// last sample already passed (0 before any frame has been read).
    sample: usize,
}
966
967impl<I> SubsampleFrameDecoder<I>
968where
969    I: Iterator<Item = io::Result<(Frame, u16)>>,
970{
971    pub fn new(inner: I, selection: Selection) -> Self {
972        Self {
973            inner,
974            selection,
975            sample: 0,
976        }
977    }
978
979    /// 1-based indices, in any order (duplicates removed internally).
980    pub fn by_indices<T>(inner: I, indices: T) -> Self
981    where
982        T: IntoIterator<Item = usize>,
983    {
984        let mut v: Vec<usize> = indices.into_iter().collect();
985        v.sort_unstable();
986        v.dedup();
987        Self::new(inner, Selection::Indices(v.into_iter().peekable()))
988    }
989
990    /// Inclusive 1-based range [start, end].
991    pub fn by_range(inner: I, start: usize, end: usize) -> Self {
992        assert!(
993            start >= 1 && end >= start,
994            "range must be 1-based and end >= start"
995        );
996        Self::new(inner, Selection::Range { start, end })
997    }
998
999    /// Every `step` samples starting from 1-based `offset`.
1000    pub fn every(inner: I, step: usize, offset: usize) -> Self {
1001        assert!(step >= 1 && offset >= 1, "step and offset must be >= 1");
1002        Self::new(inner, Selection::Every { step, offset })
1003    }
1004
1005    // Helper function to count how many selected samples are in the interval [lo, hi].
1006    // Both lo and hi are 1-based, inclusive.
1007    fn count_selected_in(&mut self, lo: usize, hi: usize) -> u16 {
1008        match &mut self.selection {
1009            Selection::Indices(iter) => {
1010                let mut taken = 0u16;
1011                while let Some(&next) = iter.peek() {
1012                    if next < lo {
1013                        iter.next();
1014                        continue;
1015                    }
1016                    if next > hi {
1017                        break;
1018                    }
1019                    iter.next();
1020                    taken = taken.saturating_add(1);
1021                }
1022                taken
1023            }
1024            Selection::Every { step, offset } => {
1025                let start = lo.max(*offset);
1026                if start > hi {
1027                    return 0;
1028                }
1029                let r = (start as isize - *offset as isize).rem_euclid(*step as isize) as usize;
1030                let first = start + ((*step - r) % *step);
1031                if first > hi {
1032                    0
1033                } else {
1034                    (1 + (hi - first) / *step) as u16
1035                }
1036            }
1037            Selection::Range { start, end } => {
1038                if hi < *start || lo > *end {
1039                    0
1040                } else {
1041                    let a = lo.max(*start);
1042                    let b = hi.min(*end);
1043                    (b - a + 1) as u16
1044                }
1045            }
1046        }
1047    }
1048}
1049
1050impl<I> Iterator for SubsampleFrameDecoder<I>
1051where
1052    I: Iterator<Item = io::Result<(Frame, u16)>>,
1053{
1054    type Item = io::Result<MkvRecord>; // (Vec<u16>, u16)
1055
1056    fn next(&mut self) -> Option<Self::Item> {
1057        loop {
1058            // early-exit for Range
1059            if let Selection::Range { end, .. } = self.selection {
1060                if self.sample >= end {
1061                    return None;
1062                }
1063            }
1064            // early-exit for Indices
1065            if let Selection::Indices(ref mut it) = self.selection {
1066                if it.peek().is_none() {
1067                    return None;
1068                }
1069            }
1070
1071            let (frame, count) = match self.inner.next()? {
1072                Ok(x) => x,
1073                Err(e) => return Some(Err(e)),
1074            };
1075
1076            let lo = self.sample + 1;
1077            let hi = self.sample + count as usize;
1078            let selected = self.count_selected_in(lo, hi);
1079
1080            // advance regardless
1081            self.sample = hi;
1082
1083            if selected > 0 {
1084                match decode_frame_to_assignment(&frame) {
1085                    Ok(assignment) => return Some(Ok((assignment, selected))),
1086                    Err(e) => return Some(Err(e)),
1087                }
1088            }
1089        }
1090    }
1091}
1092
/// Boxed, `Send` iterator over `(Frame, repetition count)` pairs, as
/// returned by `build_frame_iter`.
pub type FrameIter = Box<dyn Iterator<Item = io::Result<(Frame, u16)>> + Send>;
1094
1095/// Build a frame iterator from a file path and mode ("ben" or "xben")
1096///
1097/// Frame iteration is useful for subsampling since you do not need to decode every frame
1098/// into an assignment vector. Since the BEN standard includes information about the number
1099/// of bytes used to encode each frame, reading through the file and extracting particular
1100/// frames is incredibly fast.
1101///
1102/// # Arguments
1103///
1104/// * `file_path` - A PathBuf pointing to the BEN or XBEN file.
1105/// * `mode` - A string slice indicating the file type: "ben" or "xben".
1106///
1107/// # Returns
1108///
1109/// An io::Result containing a boxed iterator over frames and their repetition counts.
1110pub fn build_frame_iter(file_path: &PathBuf, mode: &str) -> io::Result<FrameIter> {
1111    let file = File::options().read(true).open(file_path)?;
1112    let reader = BufReader::new(file);
1113
1114    match mode {
1115        "ben" => {
1116            // Ben -> BenFrameDecoeder
1117            let frames = BenFrameDecoeder::new(reader)?; // Iterator<Item=io::Result<BenFrame>>
1118            let mapped = frames.map(|res| {
1119                res.map(|f| {
1120                    let cnt = f.count;
1121                    (Frame::Ben(f), cnt)
1122                })
1123            });
1124            Ok(Box::new(mapped))
1125        }
1126        "xben" => {
1127            // XBen -> XBenFrameDecoder (need variant)
1128            let x = XBenDecoder::new(reader)?;
1129            let variant = x.variant;
1130            let frames = x.into_frames(); // Iterator<Item=io::Result<(Vec<u8>, u16)>>
1131            let mapped =
1132                frames.map(move |res| res.map(|(bytes, cnt)| (Frame::XBen(bytes, variant), cnt)));
1133            Ok(Box::new(mapped))
1134        }
1135        _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Unknown mode")),
1136    }
1137}
1138
1139impl<R: Read + Send> BenDecoder<R> {
1140    /// Create a subsample iterator from this decoder that iterates over specific indices.
1141    /// These indices are 1-based.
1142    ///
1143    /// # Arguments
1144    ///
1145    /// * `indices` - A collection of 1-based indices to select.
1146    ///
1147    /// # Returns
1148    ///
1149    /// An io::Result containing a SubsampleFrameDecoder that yields
1150    /// decoded assignments and their repetition counts.
1151    pub fn into_subsample_by_indices<T>(
1152        self,
1153        indices: T,
1154    ) -> SubsampleFrameDecoder<impl Iterator<Item = io::Result<(Frame, u16)>> + Send>
1155    where
1156        T: IntoIterator<Item = usize>,
1157    {
1158        let frames = self.into_frames().map(|res| {
1159            res.map(|f| {
1160                let count = f.count;
1161                (Frame::Ben(f), count)
1162            })
1163        });
1164        SubsampleFrameDecoder::by_indices(frames, indices)
1165    }
1166
1167    /// Create a subsample iterator from this decoder that iterates over a range of samples.
1168    ///
1169    /// # Arguments
1170    ///
1171    /// * `start` - The 1-based start index (inclusive).
1172    /// * `end` - The 1-based end index (inclusive).
1173    ///
1174    /// # Returns
1175    ///
1176    /// An io::Result containing a SubsampleFrameDecoder that yields
1177    /// decoded assignments and their repetition counts.
1178    pub fn into_subsample_by_range(
1179        self,
1180        start: usize,
1181        end: usize,
1182    ) -> SubsampleFrameDecoder<impl Iterator<Item = io::Result<(Frame, u16)>> + Send> {
1183        let frames = self.into_frames().map(|res| {
1184            res.map(|f| {
1185                let cnt = f.count;
1186                (Frame::Ben(f), cnt)
1187            })
1188        });
1189        SubsampleFrameDecoder::by_range(frames, start, end)
1190    }
1191
1192    /// Create a subsample iterator from this decoder that iterates every `step` samples
1193    /// starting from 1-based `offset`.
1194    ///
1195    /// # Arguments
1196    ///
1197    /// * `step` - The step size (must be >= 1).
1198    /// * `offset` - The 1-based offset to start from (must be >= 1).
1199    ///
1200    /// # Returns
1201    ///
1202    /// An io::Result containing a SubsampleFrameDecoder that yields
1203    /// decoded assignments and their repetition counts.
1204    pub fn into_subsample_every(
1205        self,
1206        step: usize,
1207        offset: usize,
1208    ) -> SubsampleFrameDecoder<impl Iterator<Item = io::Result<(Frame, u16)>> + Send> {
1209        let frames = self.into_frames().map(|res| {
1210            res.map(|f| {
1211                let cnt = f.count;
1212                (Frame::Ben(f), cnt)
1213            })
1214        });
1215        SubsampleFrameDecoder::every(frames, step, offset)
1216    }
1217}
1218
1219impl<R: Read + Send> XBenDecoder<R> {
1220    /// Create a subsample iterator from this decoder that iterates over specific indices.
1221    /// These indices are 1-based.
1222    ///
1223    /// # Arguments
1224    ///
1225    /// * `indices` - A collection of 1-based indices to select.
1226    ///
1227    /// # Returns
1228    ///
1229    /// An io::Result containing a SubsampleFrameDecoder that yields
1230    /// decoded assignments and their repetition counts.
1231    pub fn into_subsample_by_indices<T>(
1232        self,
1233        indices: T,
1234    ) -> SubsampleFrameDecoder<impl Iterator<Item = io::Result<(Frame, u16)>> + Send>
1235    where
1236        T: IntoIterator<Item = usize>,
1237    {
1238        let variant = self.variant; // ensure BenVariant: Copy
1239        let frames = self
1240            .into_frames()
1241            .map(move |res| res.map(|(bytes, cnt)| (Frame::XBen(bytes, variant), cnt)));
1242        SubsampleFrameDecoder::by_indices(Box::new(frames), indices)
1243    }
1244
1245    /// Create a subsample iterator from this decoder that iterates over a range of samples.
1246    ///
1247    /// # Arguments
1248    ///
1249    /// * `start` - The 1-based start index (inclusive).
1250    /// * `end` - The 1-based end index (inclusive).
1251    ///
1252    /// # Returns
1253    ///
1254    /// An io::Result containing a SubsampleFrameDecoder that yields
1255    /// decoded assignments and their repetition counts.
1256    pub fn into_subsample_by_range(
1257        self,
1258        start: usize,
1259        end: usize,
1260    ) -> SubsampleFrameDecoder<impl Iterator<Item = io::Result<(Frame, u16)>> + Send> {
1261        let variant = self.variant;
1262        let frames = self
1263            .into_frames()
1264            .map(move |res| res.map(|(bytes, cnt)| (Frame::XBen(bytes, variant), cnt)));
1265        SubsampleFrameDecoder::by_range(Box::new(frames), start, end)
1266    }
1267
1268    /// Create a subsample iterator from this decoder that iterates every `step` samples
1269    /// starting from 1-based `offset`.
1270    ///
1271    /// # Arguments
1272    ///
1273    /// * `step` - The step size (must be >= 1).
1274    /// * `offset` - The 1-based offset to start from (must be >= 1).
1275    ///
1276    /// # Returns
1277    ///
1278    /// An io::Result containing a SubsampleFrameDecoder that yields
1279    /// decoded assignments and their repetition counts.
1280    pub fn into_subsample_every(
1281        self,
1282        step: usize,
1283        offset: usize,
1284    ) -> SubsampleFrameDecoder<impl Iterator<Item = io::Result<(Frame, u16)>> + Send> {
1285        let variant = self.variant;
1286        let frames = self
1287            .into_frames()
1288            .map(move |res| res.map(|(bytes, cnt)| (Frame::XBen(bytes, variant), cnt)));
1289        SubsampleFrameDecoder::every(Box::new(frames), step, offset)
1290    }
1291}
1292
1293pub fn count_samples_from_file(path: &Path, mode: &str) -> io::Result<usize> {
1294    let iter = build_frame_iter(&path.to_path_buf(), mode)?;
1295    let mut total = 0usize;
1296    for item in iter {
1297        let (_frame, cnt) = item?;
1298        total += cnt as usize;
1299    }
1300    Ok(total)
1301}
1302
1303#[cfg(test)]
1304#[path = "tests/decode_tests.rs"]
1305mod tests;