Skip to main content

dct_io/
lib.rs

1#![forbid(unsafe_code)]
2//! Read and write quantized DCT coefficients in baseline JPEG files.
3//!
4//! This crate provides direct access to the quantized DCT coefficients stored
5//! in the entropy-coded data of a baseline JPEG. It is useful for
6//! steganography, watermarking, forensic analysis, and JPEG-domain signal
7//! processing where you need to read or modify coefficients without fully
8//! decoding the image to pixel values.
9//!
10//! # What this crate does
11//!
12//! JPEG compresses images by dividing them into 8×8 pixel blocks, applying a
13//! Discrete Cosine Transform (DCT) to each block, quantizing the resulting
14//! coefficients, and then entropy-coding them with Huffman coding. This crate
15//! parses the entropy-coded stream, decodes the Huffman symbols, reconstructs
16//! the quantized coefficient values, and lets you read or modify them before
17//! re-encoding everything back into a valid JPEG byte stream.
18//!
19//! # What this crate does NOT do
20//!
21//! - It does not decode pixel values (no IDCT, no dequantisation).
22//! - It does not support progressive JPEG (SOF2), lossless JPEG (SOF3), or
23//!   arithmetic coding (SOF9). Passing such files returns an error.
24//! - It does not support JPEG 2000.
25//!
26//! # Supported JPEG variants
27//!
28//! - Baseline DCT (SOF0) — the most common variant
29//! - Extended sequential DCT (SOF1) — treated identically to SOF0
30//! - Grayscale (1 component) and colour (3 components, typically YCbCr)
31//! - All standard chroma subsampling ratios (4:4:4, 4:2:2, 4:2:0, etc.)
32//! - EXIF and JFIF headers
33//! - Restart markers (DRI / RST0–RST7)
34//!
35//! # Example
36//!
37//! ```no_run
38//! use dct_io::{read_coefficients, write_coefficients};
39//!
40//! let jpeg = std::fs::read("photo.jpg").unwrap();
41//!
42//! let mut coeffs = read_coefficients(&jpeg).unwrap();
43//!
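//! // Flip the magnitude LSB of every eligible AC coefficient (|v| >= 2) in the
//! // first component. This keeps each value non-zero and inside its JPEG size
//! // category, so the zero-run structure and Huffman symbols are unchanged.
//! for block in &mut coeffs.components[0].blocks {
//!     for coeff in block[1..].iter_mut() {
//!         if coeff.abs() >= 2 {
//!             *coeff = coeff.signum() * (coeff.abs() ^ 1);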
49//!         }
50//!     }
51//! }
52//!
53//! let modified = write_coefficients(&jpeg, &coeffs).unwrap();
54//! std::fs::write("photo_modified.jpg", modified).unwrap();
55//! ```
56
57// ── Public error type ─────────────────────────────────────────────────────────
58
/// Errors returned by this crate.
///
/// Every public entry point reports failure through this type. The `String`
/// payloads carry a short human-readable description intended for
/// diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DctError {
    /// The input does not start with a JPEG SOI marker (`0xFF 0xD8`).
    NotJpeg,

    /// The input was truncated mid-marker or mid-entropy-stream.
    Truncated,

    /// The entropy-coded data contains an invalid Huffman symbol or an
    /// unexpected structure. Malformed marker segments (bad length fields,
    /// out-of-range table selectors, invalid Huffman table definitions) are
    /// reported with this variant as well.
    CorruptEntropy,

    /// The JPEG uses a feature this crate does not support (e.g. progressive
    /// scan, lossless, or arithmetic coding).
    Unsupported(String),

    /// A required marker or table is missing from the JPEG (e.g. no SOF, no
    /// SOS, or a scan references a Huffman table that was not defined).
    Missing(String),

    /// The `JpegCoefficients` passed to [`write_coefficients`] is not
    /// compatible with the JPEG (wrong number of components, wrong block
    /// count, wrong component index).
    Incompatible(String),
}

impl core::fmt::Display for DctError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            DctError::NotJpeg => f.write_str("not a JPEG file"),
            DctError::Truncated => f.write_str("truncated JPEG data"),
            DctError::CorruptEntropy => f.write_str("corrupt or malformed JPEG entropy stream"),
            DctError::Unsupported(s) => write!(f, "unsupported JPEG variant: {}", s),
            DctError::Missing(s) => write!(f, "missing required JPEG structure: {}", s),
            DctError::Incompatible(s) => {
                write!(f, "coefficient data is incompatible with this JPEG: {}", s)
            }
        }
    }
}

impl std::error::Error for DctError {}
102
103// ── Public types ──────────────────────────────────────────────────────────────
104
/// Metadata for a single image component, as read from the SOF marker.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentInfo {
    /// Component identifier as written in the SOF marker (conventionally
    /// 1=Y, 2=Cb, 3=Cr in YCbCr; 1=Y in grayscale).
    pub id: u8,
    /// Horizontal sampling factor.
    pub h_samp: u8,
    /// Vertical sampling factor.
    pub v_samp: u8,
    /// Number of 8×8 DCT blocks this component contributes to the image.
    pub block_count: usize,
}

/// Image metadata extracted from a JPEG without decoding the entropy stream.
///
/// Obtained from [`inspect`]. Cheaper than [`read_coefficients`] when you
/// only need dimensions, component count, or block counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JpegInfo {
    /// Image width in pixels.
    pub width: u16,
    /// Image height in pixels.
    pub height: u16,
    /// Per-component metadata, in SOF order (typically Y, Cb, Cr).
    pub components: Vec<ComponentInfo>,
}
131
/// Quantized DCT coefficients for a single component (Y, Cb, or Cr).
///
/// Each element of `blocks` is one 8×8 DCT block, stored in the JPEG zigzag
/// scan order:
/// - Index 0: DC coefficient (top-left of the frequency matrix).
/// - Indices 1–63: AC coefficients in zigzag order.
///
/// The values are the quantized coefficients exactly as they appear in the
/// JPEG bitstream. They have **not** been dequantized; multiply by the
/// quantization table to recover the pre-quantized DCT values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentCoefficients {
    /// Component identifier as written in the JPEG SOF marker.
    pub id: u8,
    /// All 8×8 blocks for this component, in raster scan order (left-to-right,
    /// top-to-bottom). Each block contains exactly 64 `i16` values.
    pub blocks: Vec<[i16; 64]>,
}

/// Quantized DCT coefficients for all components in a JPEG image.
///
/// Returned by [`read_coefficients`] and accepted by [`write_coefficients`].
/// Implements `PartialEq` so a decoded set can be compared against a
/// modified copy, or against a re-read of the re-encoded output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JpegCoefficients {
    /// One entry per component, in the order they appear in the JPEG SOF
    /// marker (typically Y, Cb, Cr for colour images).
    pub components: Vec<ComponentCoefficients>,
}
160
161// ── Public API ────────────────────────────────────────────────────────────────
162
163/// Decode the quantized DCT coefficients from a baseline JPEG.
164///
165/// Returns [`JpegCoefficients`] containing all blocks for all components.
166/// Does not dequantize or apply IDCT; values are the raw quantized integers.
167///
168/// # Errors
169///
170/// Returns [`DctError`] if the input is not a supported baseline JPEG, if
171/// required markers are missing, or if the entropy stream is corrupt.
172#[must_use = "returns the decoded coefficients or an error; ignoring it discards the result"]
173pub fn read_coefficients(jpeg: &[u8]) -> Result<JpegCoefficients, DctError> {
174    let mut parser = JpegParser::new(jpeg)?;
175    parser.parse()?;
176    parser.decode_coefficients()
177}
178
179/// Re-encode a JPEG with modified DCT coefficients.
180///
181/// Takes the original JPEG bytes and a [`JpegCoefficients`] (typically
182/// obtained from [`read_coefficients`] and then modified), and produces a new
183/// JPEG byte stream with the updated coefficients re-encoded using the same
184/// Huffman tables as the original.
185///
186/// The output is a valid JPEG. All non-entropy-coded segments (EXIF, ICC
187/// profile, quantization tables, etc.) are preserved verbatim.
188///
189/// # Safety note
190///
191/// The output is only as valid as the input JPEG's Huffman tables permit.
192/// If you set a coefficient to a value whose (run, category) symbol does not
193/// exist in the original Huffman table, encoding will return
194/// [`DctError::CorruptEntropy`]. Stick to modifying the LSB of coefficients
195/// with `|v| >= 2` (JSteg-style) to stay safely within the table.
196///
197/// # Errors
198///
199/// Returns [`DctError::Incompatible`] if `coeffs` has a different number of
200/// components, a different block count, or mismatched component IDs compared
201/// to the original JPEG.
202/// Returns [`DctError`] for any parse or encoding failure.
203#[must_use = "returns the re-encoded JPEG bytes or an error; ignoring it discards the result"]
204pub fn write_coefficients(jpeg: &[u8], coeffs: &JpegCoefficients) -> Result<Vec<u8>, DctError> {
205    let mut parser = JpegParser::new(jpeg)?;
206    parser.parse()?;
207    parser.encode_coefficients(jpeg, coeffs)
208}
209
210/// Return the number of 8×8 DCT blocks per component in a JPEG.
211///
212/// The returned `Vec` has one entry per component (in SOF order). Useful
213/// for determining how many blocks are available before calling
214/// [`read_coefficients`].
215///
216/// # Errors
217///
218/// Returns [`DctError`] if the input is not a supported baseline JPEG.
219#[must_use = "returns block counts or an error; ignoring it discards the result"]
220pub fn block_count(jpeg: &[u8]) -> Result<Vec<usize>, DctError> {
221    let mut parser = JpegParser::new(jpeg)?;
222    parser.parse()?;
223    parser.block_counts()
224}
225
226/// Inspect a JPEG and return image metadata without decoding the entropy stream.
227///
228/// Much cheaper than [`read_coefficients`] when you only need the image
229/// dimensions, component layout, or block counts.
230///
231/// # Errors
232///
233/// Returns [`DctError`] if the input is not a supported baseline JPEG.
234#[must_use = "returns image metadata or an error; ignoring it discards the result"]
235pub fn inspect(jpeg: &[u8]) -> Result<JpegInfo, DctError> {
236    let mut parser = JpegParser::new(jpeg)?;
237    parser.parse()?;
238    let counts = parser.block_counts()?;
239    Ok(JpegInfo {
240        width: parser.image_width,
241        height: parser.image_height,
242        components: parser
243            .frame_components
244            .iter()
245            .enumerate()
246            .map(|(i, fc)| ComponentInfo {
247                id: fc.id,
248                h_samp: fc.h_samp,
249                v_samp: fc.v_samp,
250                block_count: counts[i],
251            })
252            .collect(),
253    })
254}
255
256/// Count the number of AC coefficients with `|v| >= 2` across all components.
257///
258/// These are the coefficients that can be modified without altering zero-run
259/// lengths or EOB positions — the eligible positions for JSteg-style LSB
260/// embedding. Decodes all coefficients internally; use
261/// [`JpegCoefficients::eligible_ac_count`] to avoid decoding twice.
262///
263/// # Errors
264///
265/// Returns [`DctError`] if the input is not a supported baseline JPEG.
266#[must_use = "returns the eligible AC coefficient count or an error; ignoring it discards the result"]
267pub fn eligible_ac_count(jpeg: &[u8]) -> Result<usize, DctError> {
268    Ok(read_coefficients(jpeg)?.eligible_ac_count())
269}
270
271impl JpegCoefficients {
272    /// Count the number of AC coefficients with `|v| >= 2` across all
273    /// components.
274    ///
275    /// Modifying only these coefficients preserves the zero-run structure of
276    /// the entropy stream, keeping the output a valid JPEG that is
277    /// perceptually indistinguishable from the original.
278    ///
279    /// # Example
280    ///
281    /// ```no_run
282    /// use dct_io::read_coefficients;
283    ///
284    /// let jpeg = std::fs::read("photo.jpg").unwrap();
285    /// let coeffs = read_coefficients(&jpeg).unwrap();
286    /// println!("Eligible AC positions: {}", coeffs.eligible_ac_count());
287    /// ```
288    #[must_use]
289    pub fn eligible_ac_count(&self) -> usize {
290        self.components
291            .iter()
292            .flat_map(|c| c.blocks.iter())
293            .flat_map(|b| b[1..].iter())
294            .filter(|&&v| v.abs() >= 2)
295            .count()
296    }
297}
298
299// ── Internal constants ────────────────────────────────────────────────────────
300
/// Zigzag scan order.
///
/// Entry `k` is the natural (row-major) position `row * 8 + col` of the
/// `k`-th coefficient in JPEG zigzag order. For example, `ZIGZAG[0] == 0` is
/// the DC term and `ZIGZAG[2] == 8` is the first coefficient of row 1.
#[rustfmt::skip]
const ZIGZAG: [u8; 64] = [
    // zigzag positions 0..=15
     0,  1,  8, 16,  9,  2,  3, 10, 17, 24, 32, 25, 18, 11,  4,  5,
    // zigzag positions 16..=31
    12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13,  6,  7, 14, 21, 28,
    // zigzag positions 32..=47
    35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51,
    // zigzag positions 48..=63
    58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63,
];
314
/// Maximum number of MCUs we are willing to decode (safety cap).
///
/// 2^20 MCUs is about 67 megapixels when an MCU is 8×8 pixels (grayscale or
/// 4:4:4) and about 268 megapixels when it is 16×16 pixels (4:2:0).
const MAX_MCU_COUNT: usize = 1_048_576;
317
318// ── Value-category helper ─────────────────────────────────────────────────────
319
/// JPEG value category: the number of bits needed to represent `abs(v)`.
///
/// Category 0 is special: it is used for a zero DC difference and, on the AC
/// side, for the EOB symbol. The result is capped at 15 because `|i16::MIN|`
/// would need 16 bits, for which no JPEG category exists.
#[inline]
fn category(value: i16) -> u8 {
    if value == 0 {
        return 0;
    }
    let abs = value.unsigned_abs();
    let cat = (16u32 - abs.leading_zeros()) as u8;
    cat.min(15)
}

/// Encode `value` into its JPEG (category, magnitude bits) representation.
///
/// Returns `(cat, bits, bit_count)`, where `bit_count == cat` and `bits`
/// fits in `bit_count` bits. Positive values are emitted as-is; negative
/// values as `value + 2^cat - 1` (the one's-complement form used by JPEG).
///
/// `i16::MIN` has no valid category, so it is saturated to `-i16::MAX`.
/// Emitting it as-is would yield bits wider than `bit_count`.
#[inline]
fn encode_value(value: i16) -> (u8, u16, u8) {
    let value = value.max(-i16::MAX);
    let cat = category(value);
    if cat == 0 {
        return (0, 0, 0);
    }
    let bits = if value > 0 {
        value as u16
    } else {
        // Computed in i32: in i16, `(1 << 15) - 1` overflows for every
        // category-15 value (-32767..=-16384), panicking in debug builds.
        ((1i32 << cat) - 1 + i32::from(value)) as u16
    };
    (cat, bits, cat)
}
350
351// ── Huffman table ─────────────────────────────────────────────────────────────
352
353/// A single Huffman table (DC or AC, for one component class).
354///
355/// Decoding uses a flat 65 536-entry lookup table indexed by the top 16 bits
356/// of the bit-stream. Each entry packs `(symbol << 8) | code_len` as a `u16`,
357/// with 0 meaning "no code with this prefix". This gives O(1) decode with no
358/// branch on the hot path.
359///
360/// Encoding uses a flat 256-entry array keyed by symbol (u8). Each entry is
361/// `(code, code_length)`; a length of 0 means the symbol is not in this table.
362#[derive(Clone)]
363struct HuffTable {
364    /// 16-bit LUT: index = top 16 stream bits → `(symbol << 8) | len`, 0 = invalid.
365    lut: Vec<u16>,
366    /// Encode table: `encode[symbol] = (code, code_length)`, len 0 = absent.
367    encode: [(u16, u8); 256],
368}
369
370impl std::fmt::Debug for HuffTable {
371    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372        let entries = self.encode.iter().filter(|e| e.1 > 0).count();
373        f.debug_struct("HuffTable")
374            .field("encode_entries", &entries)
375            .finish()
376    }
377}
378
379impl HuffTable {
380    /// Build a Huffman table from the DHT segment payload.
381    ///
382    /// `counts` is the 16-byte array of code counts per length (1..=16).
383    /// `symbols` is the flat list of symbols in canonical order.
384    fn from_jpeg(counts: &[u8; 16], symbols: &[u8]) -> Result<Self, DctError> {
385        let mut encode = [(0u16, 0u8); 256];
386        let mut lut = vec![0u16; 65536];
387        let mut code: u16 = 0;
388        let mut sym_idx = 0usize;
389
390        for len in 1u8..=16u8 {
391            let count = counts[(len - 1) as usize] as usize;
392            for _ in 0..count {
393                if sym_idx >= symbols.len() {
394                    return Err(DctError::CorruptEntropy);
395                }
396                // Guard against a malformed DHT where the canonical code would
397                // overflow 16 bits or index outside our LUT. Use u32 for the
398                // shift so `len == 16` does not itself overflow.
399                if (code as u32) >= (1u32 << len) {
400                    return Err(DctError::CorruptEntropy);
401                }
402                let sym = symbols[sym_idx];
403                sym_idx += 1;
404                encode[sym as usize] = (code, len);
405
406                // Fill all 16-bit keys whose top `len` bits equal `code`.
407                // Each such key represents a stream where the Huffman prefix
408                // is followed by arbitrary suffix bits.
409                let spread = 1usize << (16 - len);
410                let base = (code as usize) << (16 - len);
411                let entry = ((sym as u16) << 8) | (len as u16);
412                lut[base..base + spread].fill(entry);
413
414                code += 1;
415            }
416            code <<= 1;
417        }
418
419        Ok(HuffTable { lut, encode })
420    }
421}
422
423// ── Bit reader ────────────────────────────────────────────────────────────────
424
425struct BitReader<'a> {
426    data: &'a [u8],
427    pos: usize,
428    buf: u64,
429    bits: u8,
430}
431
432impl<'a> BitReader<'a> {
433    fn new(data: &'a [u8]) -> Self {
434        BitReader {
435            data,
436            pos: 0,
437            buf: 0,
438            bits: 0,
439        }
440    }
441
442    /// Fill `buf` from the entropy stream, skipping byte stuffing (0xFF 0x00)
443    /// and stopping at any marker (0xFF 0xD0–0xD9 or any non-0x00 after 0xFF).
444    fn refill(&mut self) {
445        while self.bits <= 56 {
446            if self.pos >= self.data.len() {
447                break;
448            }
449            let byte = self.data[self.pos];
450            if byte == 0xFF {
451                if self.pos + 1 >= self.data.len() {
452                    break;
453                }
454                let next = self.data[self.pos + 1];
455                if next == 0x00 {
456                    // Byte stuffing — consume both, emit 0xFF.
457                    self.pos += 2;
458                    self.buf = (self.buf << 8) | 0xFF;
459                    self.bits += 8;
460                } else {
461                    // Marker — stop refilling.
462                    break;
463                }
464            } else {
465                self.pos += 1;
466                self.buf = (self.buf << 8) | (byte as u64);
467                self.bits += 8;
468            }
469        }
470    }
471
472    /// Peek at the top `n` bits without consuming them.
473    fn peek(&mut self, n: u8) -> Result<u16, DctError> {
474        if self.bits < n {
475            self.refill();
476        }
477        if self.bits < n {
478            return Err(DctError::Truncated);
479        }
480        Ok(((self.buf >> (self.bits - n)) & ((1u64 << n) - 1)) as u16)
481    }
482
483    /// Consume `n` bits.
484    fn consume(&mut self, n: u8) {
485        debug_assert!(self.bits >= n);
486        self.bits -= n;
487        self.buf &= (1u64 << self.bits) - 1;
488    }
489
490    /// Read `n` bits and return them as a `u16`.
491    fn read_bits(&mut self, n: u8) -> Result<u16, DctError> {
492        if n == 0 {
493            return Ok(0);
494        }
495        let v = self.peek(n)?;
496        self.consume(n);
497        Ok(v)
498    }
499
500    /// Decode the next Huffman symbol using the 16-bit LUT.
501    ///
502    /// Forms a 16-bit key from the top bits of the buffer (right-padded with
503    /// zeros if fewer than 16 bits are available). The LUT maps this key
504    /// directly to `(symbol, code_length)` in a single indexed read.
505    fn decode_huffman(&mut self, table: &HuffTable) -> Result<u8, DctError> {
506        if self.bits < 16 {
507            self.refill();
508        }
509        // Build the 16-bit key: top `min(bits, 16)` stream bits left-aligned.
510        let key = if self.bits >= 16 {
511            ((self.buf >> (self.bits - 16)) & 0xFFFF) as u16
512        } else {
513            // Fewer than 16 bits available — pad the right with zeros.
514            // The LUT entry for any short code covers all possible suffixes,
515            // so zero-padding is safe as long as len <= self.bits.
516            ((self.buf << (16 - self.bits)) & 0xFFFF) as u16
517        };
518
519        let entry = table.lut[key as usize];
520        let len = (entry & 0xFF) as u8;
521        let sym = (entry >> 8) as u8;
522
523        if len == 0 {
524            return Err(DctError::CorruptEntropy);
525        }
526        if self.bits < len {
527            return Err(DctError::Truncated);
528        }
529        self.consume(len);
530        Ok(sym)
531    }
532
533    /// Skip any restart marker at the current position and reset DC predictor.
534    /// Returns `true` if a restart marker was consumed.
535    fn sync_restart(&mut self) -> bool {
536        // Discard any remaining bits in the current byte.
537        self.bits = 0;
538        self.buf = 0;
539        // Check for a single 0xFF followed by RST0–RST7 (0xD0–0xD7).
540        if self.pos + 1 < self.data.len()
541            && self.data[self.pos] == 0xFF
542            && (0xD0..=0xD7).contains(&self.data[self.pos + 1])
543        {
544            self.pos += 2;
545            return true;
546        }
547        false
548    }
549}
550
551// ── Bit writer ────────────────────────────────────────────────────────────────
552
/// MSB-first bit writer producing JPEG entropy-coded bytes.
///
/// Every emitted 0xFF data byte is followed by a stuffed 0x00 so that it
/// cannot be mistaken for a marker.
struct BitWriter {
    /// Encoded output bytes.
    out: Vec<u8>,
    /// Pending bits that do not yet form a full byte; only the low `bits` are valid.
    buf: u64,
    /// Number of pending bits in `buf` (always < 8 between calls).
    bits: u8,
}

impl BitWriter {
    fn with_capacity(cap: usize) -> Self {
        BitWriter {
            out: Vec::with_capacity(cap),
            buf: 0,
            bits: 0,
        }
    }

    /// Write the low `n` bits of `value`, most-significant bit first.
    ///
    /// Bits of `value` at or above position `n` are ignored.
    fn write_bits(&mut self, value: u16, n: u8) {
        if n == 0 {
            return;
        }
        // Mask to `n` bits: unmasked high bits would be OR-ed into bits that
        // are already queued in `buf`, silently corrupting earlier output.
        let masked = u64::from(value) & ((1u64 << n) - 1);
        self.buf = (self.buf << n) | masked;
        self.bits += n;
        while self.bits >= 8 {
            self.bits -= 8;
            let byte = ((self.buf >> self.bits) & 0xFF) as u8;
            self.out.push(byte);
            if byte == 0xFF {
                self.out.push(0x00); // Byte stuffing.
            }
            self.buf &= (1u64 << self.bits) - 1;
        }
    }

    /// Flush any remaining bits (padded with 1-bits per the JPEG spec).
    fn flush(&mut self) {
        if self.bits > 0 {
            let pad = 8 - self.bits;
            let byte = (((self.buf << pad) | ((1u64 << pad) - 1)) & 0xFF) as u8;
            self.out.push(byte);
            if byte == 0xFF {
                self.out.push(0x00);
            }
            self.bits = 0;
            self.buf = 0;
        }
    }

    /// Flush pending bits, then emit restart marker `0xFF 0xD0 + (n % 8)`
    /// directly into the output. The marker is not byte-stuffed because it
    /// is not entropy data.
    fn write_restart_marker(&mut self, n: u8) {
        self.flush();
        self.out.push(0xFF);
        self.out.push(0xD0 | (n & 0x07));
    }
}
608
609// ── Internal JPEG parser ──────────────────────────────────────────────────────
610
/// One component entry of the SOF marker segment.
#[derive(Debug, Clone)]
struct FrameComponent {
    /// Component identifier; SOS headers refer to components by this value.
    id: u8,
    /// Horizontal sampling factor (high nibble of the sampling byte).
    h_samp: u8,
    /// Vertical sampling factor (low nibble of the sampling byte).
    v_samp: u8,
    /// Quantization table selector. Kept for completeness, but coefficient
    /// I/O never dequantizes, so nothing reads it; hence the lint allowance.
    #[allow(dead_code)]
    qt_id: u8,
}
620
/// One component entry of the SOS marker segment.
#[derive(Debug, Clone)]
struct ScanComponent {
    /// Position of the component within `frame_components` (not its id).
    comp_idx: usize,
    /// DC Huffman table selector (0–3).
    dc_table: usize,
    /// AC Huffman table selector (0–3).
    ac_table: usize,
}
628
629/// Parsed state accumulated while scanning JPEG markers.
630struct JpegParser<'a> {
631    data: &'a [u8],
632    pos: usize,
633
634    /// Byte offset of the first entropy-coded data byte.
635    entropy_start: usize,
636    /// Byte length of the entropy-coded segment (up to next non-RST marker).
637    entropy_len: usize,
638
639    frame_components: Vec<FrameComponent>,
640    scan_components: Vec<ScanComponent>,
641    dc_tables: [Option<HuffTable>; 4],
642    ac_tables: [Option<HuffTable>; 4],
643    restart_interval: u16,
644    image_width: u16,
645    image_height: u16,
646}
647
648impl<'a> JpegParser<'a> {
649    fn new(data: &'a [u8]) -> Result<Self, DctError> {
650        if data.len() < 2 || data[0] != 0xFF || data[1] != 0xD8 {
651            return Err(DctError::NotJpeg);
652        }
653        Ok(JpegParser {
654            data,
655            pos: 2,
656            entropy_start: 0,
657            entropy_len: 0,
658            frame_components: Vec::new(),
659            scan_components: Vec::new(),
660            dc_tables: [None, None, None, None],
661            ac_tables: [None, None, None, None],
662            restart_interval: 0,
663            image_width: 0,
664            image_height: 0,
665        })
666    }
667
668    /// Read a 2-byte big-endian u16 from `data[pos..]`, advancing `pos`.
669    fn read_u16(&mut self) -> Result<u16, DctError> {
670        if self.pos + 1 >= self.data.len() {
671            return Err(DctError::Truncated);
672        }
673        let v = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
674        self.pos += 2;
675        Ok(v)
676    }
677
678    /// Parse all JPEG markers up to and including SOS. Sets `entropy_start`
679    /// and `entropy_len`.
680    fn parse(&mut self) -> Result<(), DctError> {
681        loop {
682            // Find next marker.
683            if self.pos >= self.data.len() {
684                return Err(DctError::Missing("SOS marker".into()));
685            }
686            if self.data[self.pos] != 0xFF {
687                return Err(DctError::CorruptEntropy);
688            }
689            // Skip 0xFF padding.
690            while self.pos < self.data.len() && self.data[self.pos] == 0xFF {
691                self.pos += 1;
692            }
693            if self.pos >= self.data.len() {
694                return Err(DctError::Truncated);
695            }
696            let marker = self.data[self.pos];
697            self.pos += 1;
698
699            match marker {
700                0xD8 => {} // SOI — already consumed.
701                0xD9 => return Err(DctError::Missing("SOS before EOI".into())),
702
703                // SOF0 (baseline) and SOF1 (extended sequential) — supported.
704                0xC0 | 0xC1 => self.parse_sof()?,
705
706                // SOF markers we reject with a clear message.
707                0xC2 => return Err(DctError::Unsupported("progressive JPEG (SOF2)".into())),
708                0xC3 => return Err(DctError::Unsupported("lossless JPEG (SOF3)".into())),
709                0xC9 => return Err(DctError::Unsupported("arithmetic coding (SOF9)".into())),
710                0xCA => {
711                    return Err(DctError::Unsupported(
712                        "progressive arithmetic (SOF10)".into(),
713                    ))
714                }
715                0xCB => return Err(DctError::Unsupported("lossless arithmetic (SOF11)".into())),
716
717                0xC4 => self.parse_dht()?,
718                0xDD => self.parse_dri()?,
719
720                0xDA => {
721                    // SOS — parse header, then record entropy start.
722                    self.parse_sos_header()?;
723                    self.entropy_start = self.pos;
724                    self.entropy_len = self.find_entropy_end();
725                    return Ok(());
726                }
727
728                // Any other marker with a length field — skip.
729                _ => {
730                    let len = self.read_u16()? as usize;
731                    if len < 2 {
732                        return Err(DctError::CorruptEntropy);
733                    }
734                    let skip = len - 2;
735                    if self.pos + skip > self.data.len() {
736                        return Err(DctError::Truncated);
737                    }
738                    self.pos += skip;
739                }
740            }
741        }
742    }
743
744    fn parse_sof(&mut self) -> Result<(), DctError> {
745        let len = self.read_u16()? as usize;
746        if len < 8 {
747            return Err(DctError::CorruptEntropy);
748        }
749        let end = self.pos + len - 2;
750        if end > self.data.len() {
751            return Err(DctError::Truncated);
752        }
753        let _precision = self.data[self.pos];
754        self.pos += 1;
755        self.image_height = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
756        self.pos += 2;
757        self.image_width = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
758        self.pos += 2;
759
760        if self.image_width == 0 || self.image_height == 0 {
761            return Err(DctError::Unsupported("zero image dimension".into()));
762        }
763
764        let ncomp = self.data[self.pos] as usize;
765        self.pos += 1;
766
767        if ncomp == 0 || ncomp > 4 {
768            return Err(DctError::Unsupported(format!("{} components", ncomp)));
769        }
770        if self.pos + ncomp * 3 > end {
771            return Err(DctError::Truncated);
772        }
773
774        self.frame_components.clear();
775        for _ in 0..ncomp {
776            let id = self.data[self.pos];
777            let samp = self.data[self.pos + 1];
778            let qt_id = self.data[self.pos + 2];
779            self.pos += 3;
780            let h_samp = samp >> 4;
781            let v_samp = samp & 0x0F;
782            if h_samp == 0 || v_samp == 0 {
783                return Err(DctError::CorruptEntropy);
784            }
785            self.frame_components.push(FrameComponent {
786                id,
787                h_samp,
788                v_samp,
789                qt_id,
790            });
791        }
792        self.pos = end;
793        Ok(())
794    }
795
796    fn parse_dht(&mut self) -> Result<(), DctError> {
797        let len = self.read_u16()? as usize;
798        if len < 2 {
799            return Err(DctError::CorruptEntropy);
800        }
801        let end = self.pos + len - 2;
802        if end > self.data.len() {
803            return Err(DctError::Truncated);
804        }
805
806        while self.pos < end {
807            if self.pos >= self.data.len() {
808                return Err(DctError::Truncated);
809            }
810            let tc_th = self.data[self.pos];
811            self.pos += 1;
812            let tc = (tc_th >> 4) & 0x0F; // 0=DC, 1=AC
813            let th = (tc_th & 0x0F) as usize; // table index 0–3
814
815            if tc > 1 {
816                return Err(DctError::CorruptEntropy);
817            }
818            if th > 3 {
819                return Err(DctError::CorruptEntropy);
820            }
821
822            if self.pos + 16 > end {
823                return Err(DctError::Truncated);
824            }
825            let mut counts = [0u8; 16];
826            counts.copy_from_slice(&self.data[self.pos..self.pos + 16]);
827            self.pos += 16;
828
829            let total: usize = counts.iter().map(|&c| c as usize).sum();
830            // JPEG Huffman symbols are u8, so at most 256 unique symbols per table.
831            if total > 256 {
832                return Err(DctError::CorruptEntropy);
833            }
834            if self.pos + total > end {
835                return Err(DctError::Truncated);
836            }
837            let symbols = &self.data[self.pos..self.pos + total];
838            self.pos += total;
839
840            let table = HuffTable::from_jpeg(&counts, symbols)?;
841            if tc == 0 {
842                self.dc_tables[th] = Some(table);
843            } else {
844                self.ac_tables[th] = Some(table);
845            }
846        }
847
848        self.pos = end;
849        Ok(())
850    }
851
852    fn parse_dri(&mut self) -> Result<(), DctError> {
853        let len = self.read_u16()?;
854        if len != 4 {
855            return Err(DctError::CorruptEntropy);
856        }
857        self.restart_interval = self.read_u16()?;
858        Ok(())
859    }
860
861    fn parse_sos_header(&mut self) -> Result<(), DctError> {
862        let len = self.read_u16()? as usize;
863        if len < 3 {
864            return Err(DctError::CorruptEntropy);
865        }
866        let end = self.pos + len - 2;
867        if end > self.data.len() {
868            return Err(DctError::Truncated);
869        }
870
871        let ns = self.data[self.pos] as usize;
872        self.pos += 1;
873
874        if ns == 0 || ns > self.frame_components.len() {
875            return Err(DctError::CorruptEntropy);
876        }
877        if self.pos + ns * 2 > end {
878            return Err(DctError::Truncated);
879        }
880
881        self.scan_components.clear();
882        for _ in 0..ns {
883            let comp_id = self.data[self.pos];
884            let td_ta = self.data[self.pos + 1];
885            self.pos += 2;
886
887            let dc_table = (td_ta >> 4) as usize;
888            let ac_table = (td_ta & 0x0F) as usize;
889
890            if dc_table > 3 || ac_table > 3 {
891                return Err(DctError::CorruptEntropy);
892            }
893
894            let comp_idx = self
895                .frame_components
896                .iter()
897                .position(|fc| fc.id == comp_id)
898                .ok_or_else(|| DctError::Missing(format!("component id {} in frame", comp_id)))?;
899
900            self.scan_components.push(ScanComponent {
901                comp_idx,
902                dc_table,
903                ac_table,
904            });
905        }
906
907        // Skip Ss, Se, Ah/Al (3 bytes).
908        self.pos = end;
909        Ok(())
910    }
911
912    /// Find the length of the entropy-coded segment by scanning for a marker
913    /// that is not RST0–RST7 (0xD0–0xD7).
914    fn find_entropy_end(&self) -> usize {
915        let mut i = self.entropy_start;
916        while i < self.data.len() {
917            if self.data[i] == 0xFF && i + 1 < self.data.len() {
918                let next = self.data[i + 1];
919                if next == 0x00 {
920                    // Byte stuffing.
921                    i += 2;
922                    continue;
923                }
924                if (0xD0..=0xD7).contains(&next) {
925                    // RST marker inside entropy data — skip it.
926                    i += 2;
927                    continue;
928                }
929                // Real marker — entropy stream ends here.
930                return i - self.entropy_start;
931            }
932            i += 1;
933        }
934        self.data.len() - self.entropy_start
935    }
936
937    // ── MCU geometry helpers ──────────────────────────────────────────────────
938
939    fn max_h_samp(&self) -> u8 {
940        self.frame_components
941            .iter()
942            .map(|c| c.h_samp)
943            .max()
944            .unwrap_or(1)
945    }
946
947    fn max_v_samp(&self) -> u8 {
948        self.frame_components
949            .iter()
950            .map(|c| c.v_samp)
951            .max()
952            .unwrap_or(1)
953    }
954
955    fn mcu_cols(&self) -> usize {
956        let max_h = self.max_h_samp() as usize;
957        (self.image_width as usize + max_h * 8 - 1) / (max_h * 8)
958    }
959
960    fn mcu_rows(&self) -> usize {
961        let max_v = self.max_v_samp() as usize;
962        (self.image_height as usize + max_v * 8 - 1) / (max_v * 8)
963    }
964
965    fn mcu_count(&self) -> Result<usize, DctError> {
966        self.mcu_cols()
967            .checked_mul(self.mcu_rows())
968            .ok_or_else(|| DctError::Unsupported("image dimensions overflow usize".into()))
969    }
970
971    /// Number of 8×8 data units per MCU for each scan component.
972    fn du_per_mcu(&self) -> Vec<usize> {
973        self.scan_components
974            .iter()
975            .map(|sc| {
976                let fc = &self.frame_components[sc.comp_idx];
977                (fc.h_samp as usize) * (fc.v_samp as usize)
978            })
979            .collect()
980    }
981
982    /// Total block count per frame component (after all scan components resolved).
983    fn block_counts(&self) -> Result<Vec<usize>, DctError> {
984        let n_mcu = self.mcu_count()?;
985        let du = self.du_per_mcu();
986        let mut counts = vec![0usize; self.frame_components.len()];
987        for (sc_idx, sc) in self.scan_components.iter().enumerate() {
988            counts[sc.comp_idx] = n_mcu * du[sc_idx];
989        }
990        Ok(counts)
991    }
992
993    // ── Decode ────────────────────────────────────────────────────────────────
994
    /// Decode every data unit of the current scan into quantized coefficients.
    ///
    /// Walks the entropy-coded bytes `entropy_start .. entropy_start +
    /// entropy_len` MCU by MCU. Within each MCU the scan components are
    /// visited in scan order, and each contributes `du_per_mcu()[i]`
    /// consecutive data units. Every decoded block is written into the next
    /// free slot of its frame component's list, so blocks are stored in
    /// stream order. A coefficient at zigzag position `k` is stored at index
    /// `ZIGZAG[k]` of its block.
    ///
    /// The Huffman tables used are the ones collected from DHT segments
    /// (`self.dc_tables` / `self.ac_tables`).
    ///
    /// # Errors
    ///
    /// - [`DctError::Unsupported`] if the MCU count overflows or exceeds
    ///   `MAX_MCU_COUNT`.
    /// - [`DctError::Missing`] if a scan component selects an undefined DC or
    ///   AC table.
    /// - [`DctError::CorruptEntropy`] if a component yields more blocks than
    ///   `block_counts()` allotted to it.
    /// - Any error returned by the bit reader while decoding Huffman symbols
    ///   or magnitude bits.
    fn decode_coefficients(&self) -> Result<JpegCoefficients, DctError> {
        let entropy = &self.data[self.entropy_start..self.entropy_start + self.entropy_len];
        let n_mcu = self.mcu_count()?;

        // Refuse absurd dimensions before allocating per-block storage.
        if n_mcu > MAX_MCU_COUNT {
            return Err(DctError::Unsupported(format!(
                "image too large ({} MCUs; max {})",
                n_mcu, MAX_MCU_COUNT
            )));
        }

        let du = self.du_per_mcu();

        // Pre-allocate output vectors.
        let counts = self.block_counts()?;
        let mut comp_blocks: Vec<Vec<[i16; 64]>> =
            counts.iter().map(|&c| vec![[0i16; 64]; c]).collect();
        // Next free block slot per frame component.
        let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];

        // DC values are coded as differences from the previous block of the
        // same scan component; one predictor per scan component.
        let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
        let mut reader = BitReader::new(entropy);

        let restart_interval = self.restart_interval as usize;

        for mcu_idx in 0..n_mcu {
            // Handle restart markers.
            // Every `restart_interval` MCUs the stream contains an RSTn marker;
            // `sync_restart` presumably realigns the reader past it (TODO
            // confirm in BitReader). All DC predictors restart from zero.
            if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
                reader.sync_restart();
                for p in dc_pred.iter_mut() {
                    *p = 0;
                }
            }

            for (sc_idx, sc) in self.scan_components.iter().enumerate() {
                // NOTE(review): these lookups do not depend on the MCU and
                // could be done once before the MCU loop.
                let dc_table = self.dc_tables[sc.dc_table]
                    .as_ref()
                    .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
                let ac_table = self.ac_tables[sc.ac_table]
                    .as_ref()
                    .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;

                for _du_i in 0..du[sc_idx] {
                    let mut block = [0i16; 64];

                    // DC coefficient.
                    // The Huffman symbol is the size category of the difference;
                    // it is clamped to 15 before reading that many extra bits.
                    let dc_cat = reader.decode_huffman(dc_table)?;
                    let dc_cat = dc_cat.min(15);
                    let dc_bits = reader.read_bits(dc_cat)?;
                    let dc_diff = decode_magnitude(dc_cat, dc_bits);
                    // Saturate so a corrupt stream cannot overflow the predictor.
                    dc_pred[sc_idx] = dc_pred[sc_idx].saturating_add(dc_diff);
                    block[ZIGZAG[0] as usize] = dc_pred[sc_idx];

                    // AC coefficients.
                    // `k` is the zigzag position of the next coefficient (1..=63).
                    let mut k = 1usize;
                    while k < 64 {
                        // Symbol layout: high nibble = zero run, low nibble = size.
                        let rs = reader.decode_huffman(ac_table)?;
                        if rs == 0x00 {
                            // EOB — rest of block is zero.
                            break;
                        }
                        if rs == 0xF0 {
                            // ZRL — 16 zeros.
                            k += 16;
                            continue;
                        }
                        let run = (rs >> 4) as usize;
                        // `rs & 0x0F` is already <= 15; the clamp is a no-op.
                        let cat = (rs & 0x0F).min(15);
                        k += run;
                        // NOTE(review): a run that pushes `k` past 63 means the
                        // stream is corrupt. Breaking here leaves this symbol's
                        // magnitude bits unread, so the next block is decoded
                        // from a misaligned position. Consider returning
                        // `DctError::CorruptEntropy` instead.
                        if k >= 64 {
                            break;
                        }
                        let bits = reader.read_bits(cat)?;
                        let val = decode_magnitude(cat, bits);
                        block[ZIGZAG[k] as usize] = val;
                        k += 1;
                    }

                    // Guard against a scan that produces more blocks than were
                    // allotted (e.g. inconsistent headers) instead of panicking.
                    let block_idx = comp_block_idx[sc.comp_idx];
                    if block_idx >= comp_blocks[sc.comp_idx].len() {
                        return Err(DctError::CorruptEntropy);
                    }
                    comp_blocks[sc.comp_idx][block_idx] = block;
                    comp_block_idx[sc.comp_idx] += 1;
                }
            }
        }

        // One entry per frame component, in frame order, tagged with its id.
        let components = self
            .frame_components
            .iter()
            .zip(comp_blocks)
            .map(|(fc, blocks)| ComponentCoefficients { id: fc.id, blocks })
            .collect();

        Ok(JpegCoefficients { components })
    }
1091
1092    // ── Encode ────────────────────────────────────────────────────────────────
1093
    /// Re-encode `coeffs` and splice the result into a copy of `original`.
    ///
    /// The new entropy-coded data is produced with the Huffman tables and
    /// restart interval parsed from the original file, in the same MCU/block
    /// order `decode_coefficients` reads. Every byte before `entropy_start`
    /// and every byte from the first post-entropy marker onward is copied
    /// verbatim from `original`.
    ///
    /// NOTE(review): all offsets come from `self`, so `original` must be the
    /// same buffer this parser was built from. A shorter buffer would make
    /// the final slicing panic.
    ///
    /// # Errors
    ///
    /// - [`DctError::Incompatible`] if `coeffs` has the wrong number of
    ///   components, a component id mismatch, or a wrong block count.
    /// - [`DctError::Missing`] if a scan component selects an undefined table.
    /// - [`DctError::CorruptEntropy`] if a symbol needed to encode the data
    ///   (a DC category, an AC run/size pair, ZRL or EOB) has no code in the
    ///   original Huffman tables. This can happen when modified coefficients
    ///   require symbols the original encoder never emitted.
    /// - [`DctError::Unsupported`] if the MCU count overflows.
    fn encode_coefficients(
        &self,
        original: &[u8],
        coeffs: &JpegCoefficients,
    ) -> Result<Vec<u8>, DctError> {
        // Validate compatibility.
        if coeffs.components.len() != self.frame_components.len() {
            return Err(DctError::Incompatible(format!(
                "expected {} components, got {}",
                self.frame_components.len(),
                coeffs.components.len()
            )));
        }
        let counts = self.block_counts()?;
        for (i, (cc, &expected)) in coeffs.components.iter().zip(counts.iter()).enumerate() {
            if cc.id != self.frame_components[i].id {
                return Err(DctError::Incompatible(format!(
                    "component {}: expected id {}, got {}",
                    i, self.frame_components[i].id, cc.id
                )));
            }
            if cc.blocks.len() != expected {
                return Err(DctError::Incompatible(format!(
                    "component {}: expected {} blocks, got {}",
                    i,
                    expected,
                    cc.blocks.len()
                )));
            }
        }

        let n_mcu = self.mcu_count()?;
        let du = self.du_per_mcu();

        // The original entropy length is a reasonable capacity estimate.
        let mut writer = BitWriter::with_capacity(self.entropy_len);
        let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
        let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
        let restart_interval = self.restart_interval as usize;
        // RSTn markers cycle through n = 0..=7.
        let mut rst_count: u8 = 0;

        for mcu_idx in 0..n_mcu {
            // Mirror the decoder: emit RSTn between intervals and reset the
            // DC predictors.
            if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
                writer.write_restart_marker(rst_count);
                rst_count = rst_count.wrapping_add(1) & 0x07;
                for p in dc_pred.iter_mut() {
                    *p = 0;
                }
            }

            for (sc_idx, sc) in self.scan_components.iter().enumerate() {
                let dc_table = self.dc_tables[sc.dc_table]
                    .as_ref()
                    .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
                let ac_table = self.ac_tables[sc.ac_table]
                    .as_ref()
                    .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;

                for _du_i in 0..du[sc_idx] {
                    // NOTE(review): direct indexing is safe only while each
                    // frame component appears at most once in the scan. The
                    // block-count check above allows exactly n_mcu * du visits.
                    let block = &coeffs.components[sc.comp_idx].blocks[comp_block_idx[sc.comp_idx]];
                    comp_block_idx[sc.comp_idx] += 1;

                    // DC coefficient.
                    // Differences that do not fit in i16 are clamped by
                    // `saturating_sub` and so would not be represented exactly.
                    let dc_val = block[ZIGZAG[0] as usize];
                    let dc_diff = dc_val.saturating_sub(dc_pred[sc_idx]);
                    dc_pred[sc_idx] = dc_val;
                    let (dc_cat, dc_bits, dc_n) = encode_value(dc_diff);
                    // `encode[symbol]` is (code, code length); length 0 means the
                    // table has no code for this symbol.
                    let (dc_code, dc_code_len) = {
                        let e = dc_table.encode[dc_cat as usize];
                        if e.1 == 0 {
                            return Err(DctError::CorruptEntropy);
                        }
                        e
                    };
                    writer.write_bits(dc_code, dc_code_len);
                    writer.write_bits(dc_bits, dc_n);

                    // AC coefficients.
                    // Find last non-zero AC position in zigzag order.
                    let last_nonzero_zz = (1..64).rev().find(|&i| block[ZIGZAG[i] as usize] != 0);

                    let mut k = 1usize;
                    let mut zero_run = 0usize;

                    if let Some(last_pos) = last_nonzero_zz {
                        // Only walk up to the last non-zero coefficient, so every
                        // ZRL emitted below is followed by a non-zero value.
                        while k <= last_pos {
                            let val = block[ZIGZAG[k] as usize];
                            if val == 0 {
                                zero_run += 1;
                                if zero_run == 16 {
                                    // Emit ZRL.
                                    let (zrl_code, zrl_len) = {
                                        let e = ac_table.encode[0xF0];
                                        if e.1 == 0 {
                                            return Err(DctError::CorruptEntropy);
                                        }
                                        e
                                    };
                                    writer.write_bits(zrl_code, zrl_len);
                                    zero_run = 0;
                                }
                            } else {
                                // Symbol = (zero run << 4) | size category; run < 16 here.
                                let (cat, bits, n) = encode_value(val);
                                let rs = ((zero_run as u8) << 4) | cat;
                                let (ac_code, ac_len) = {
                                    let e = ac_table.encode[rs as usize];
                                    if e.1 == 0 {
                                        return Err(DctError::CorruptEntropy);
                                    }
                                    e
                                };
                                writer.write_bits(ac_code, ac_len);
                                writer.write_bits(bits, n);
                                zero_run = 0;
                            }
                            k += 1;
                        }
                    }
                    // Emit EOB only when there are trailing zeros after the last
                    // non-zero coefficient. If the last non-zero is at position 63,
                    // EOB is unnecessary (libjpeg/libjpeg-turbo behaviour).
                    let needs_eob = last_nonzero_zz.map_or(true, |p| p < 63);
                    if needs_eob {
                        let (eob_code, eob_len) = {
                            let e = ac_table.encode[0x00];
                            if e.1 == 0 {
                                return Err(DctError::CorruptEntropy);
                            }
                            e
                        };
                        writer.write_bits(eob_code, eob_len);
                    }
                }
            }
        }

        // Presumably pads and emits the final partial byte (TODO confirm in BitWriter).
        writer.flush();

        // Reconstruct the full JPEG: everything before entropy data + new
        // entropy data + everything after (from the first post-entropy marker).
        let after_entropy = self.entropy_start + self.entropy_len;
        let mut out = Vec::with_capacity(original.len());
        out.extend_from_slice(&original[..self.entropy_start]);
        out.extend_from_slice(&writer.out);
        out.extend_from_slice(&original[after_entropy..]);
        Ok(out)
    }
1240}
1241
1242// ── Magnitude decode helper ───────────────────────────────────────────────────
1243
/// Decode a JPEG magnitude value from its category and raw bits.
///
/// JPEG codes a coefficient (or DC difference) as a size category `cat`
/// followed by `cat` extra bits (ITU-T T.81 §F.2.2.1, `EXTEND`). If the
/// leading extra bit is 1 the value is the bits themselves. Otherwise the
/// value is negative and equals `bits - (2^cat - 1)`. Category 0 always
/// decodes to 0.
///
/// The arithmetic is done in `i32` because for `cat == 15` the intermediate
/// `2^15` does not fit in `i16`, although every result (down to `-32767`)
/// does. Categories above 15 are never passed by this module's callers
/// (they clamp with `.min(15)`). If one is passed anyway, it is treated as
/// 16 and the result is saturated to the `i16` range instead of panicking.
fn decode_magnitude(cat: u8, bits: u16) -> i16 {
    if cat == 0 {
        return 0;
    }
    let cat = u32::from(cat.min(16));
    let bits = i32::from(bits);
    // If the MSB of `bits` is 1, the value is positive; otherwise negative.
    let value = if bits >= 1i32 << (cat - 1) {
        bits
    } else {
        bits - ((1i32 << cat) - 1)
    };
    value.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
}
1256
1257// ── Tests ─────────────────────────────────────────────────────────────────────
1258
#[cfg(test)]
mod tests {
    // Unit tests for the public read/write/inspect API. Fixtures are generated
    // in memory with the `image` crate's JPEG encoder, so no files are needed.
    use super::*;

    // Build a minimal valid baseline JPEG from raw pixel data using the
    // `image` crate, so our tests do not depend on external fixture files.
    // Pixels follow a deterministic diagonal pattern in 28..=227, giving the
    // encoder plenty of non-zero AC coefficients to work with.
    fn make_jpeg_gray(width: u32, height: u32) -> Vec<u8> {
        use image::{codecs::jpeg::JpegEncoder, GrayImage, ImageEncoder};
        let img = GrayImage::from_fn(width, height, |x, y| {
            image::Luma([(((x * 7 + y * 13) % 200) + 28) as u8])
        });
        let mut buf = Vec::new();
        let enc = JpegEncoder::new_with_quality(&mut buf, 90);
        enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::L8)
            .unwrap();
        buf
    }

    // Colour counterpart of `make_jpeg_gray`: three channels with different
    // diagonal patterns, encoded at quality 85. The chroma subsampling chosen
    // is up to the `image` encoder.
    fn make_jpeg_rgb(width: u32, height: u32) -> Vec<u8> {
        use image::{codecs::jpeg::JpegEncoder, ImageEncoder, RgbImage};
        let img = RgbImage::from_fn(width, height, |x, y| {
            image::Rgb([
                ((x * 11 + y * 3) % 200 + 28) as u8,
                ((x * 5 + y * 17) % 200 + 28) as u8,
                ((x * 3 + y * 7) % 200 + 28) as u8,
            ])
        });
        let mut buf = Vec::new();
        let enc = JpegEncoder::new_with_quality(&mut buf, 85);
        enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
            .unwrap();
        buf
    }

    // ── Error path tests ──────────────────────────────────────────────────────

    #[test]
    fn not_jpeg_returns_error() {
        let result = read_coefficients(b"PNG\x00garbage");
        assert!(matches!(result, Err(DctError::NotJpeg)));
    }

    #[test]
    fn empty_input_returns_error() {
        assert!(matches!(read_coefficients(b""), Err(DctError::NotJpeg)));
    }

    #[test]
    fn truncated_returns_error() {
        // A valid SOI but nothing else.
        assert!(matches!(
            read_coefficients(b"\xFF\xD8\xFF"),
            Err(DctError::Truncated | DctError::Missing(_))
        ));
    }

    #[test]
    fn progressive_jpeg_returns_unsupported() {
        // Craft a minimal JPEG with SOF2 marker.
        let mut data = vec![0xFF, 0xD8]; // SOI
        // APP0 JFIF (minimal)
        data.extend_from_slice(&[0xFF, 0xE0, 0x00, 0x10]);
        // 14 payload bytes: "JFIF\0", version 1.1, units 0, density 1×1,
        // no thumbnail.
        data.extend_from_slice(&[
            0x4A, 0x46, 0x49, 0x46, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00,
        ]);
        // SOF2 marker (progressive)
        // Length 11: precision 8, 16×16, one component (id 1, sampling 1×1,
        // quantisation table 0).
        data.extend_from_slice(&[0xFF, 0xC2, 0x00, 0x0B]);
        data.extend_from_slice(&[0x08, 0x00, 0x10, 0x00, 0x10, 0x01, 0x01, 0x11, 0x00]);
        let result = read_coefficients(&data);
        assert!(matches!(result, Err(DctError::Unsupported(_))));
    }

    #[test]
    fn incompatible_block_count_returns_error() {
        let jpeg = make_jpeg_gray(16, 16);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        // Remove one block to make it incompatible.
        coeffs.components[0].blocks.pop();
        let result = write_coefficients(&jpeg, &coeffs);
        assert!(matches!(result, Err(DctError::Incompatible(_))));
    }

    // ── Roundtrip identity tests ──────────────────────────────────────────────
    // Re-encoding unmodified coefficients with the original Huffman tables must
    // reproduce the encoder's bytes exactly (same EOB/ZRL/RST choices).

    #[test]
    fn roundtrip_identity_gray() {
        let jpeg = make_jpeg_gray(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        // Re-encoding with unmodified coefficients must produce bit-identical output.
        assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_rgb() {
        let jpeg = make_jpeg_rgb(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_non_square() {
        let jpeg = make_jpeg_rgb(48, 16);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded);
    }

    // ── Modification survival test ────────────────────────────────────────────

    #[test]
    fn lsb_modification_survives_roundtrip() {
        let jpeg = make_jpeg_gray(32, 32);
        let mut coeffs = read_coefficients(&jpeg).unwrap();

        let mut modified_count = 0usize;
        // Index 0 of each block is skipped (the DC slot). Only |v| >= 2 is
        // flipped, because `v ^ 1` is zero only for v == 1, so a flipped
        // value never becomes zero.
        for block in &mut coeffs.components[0].blocks {
            for coeff in block[1..].iter_mut() {
                if coeff.abs() >= 2 {
                    *coeff ^= 1;
                    modified_count += 1;
                }
            }
        }
        assert!(
            modified_count > 0,
            "test image had no eligible coefficients"
        );

        let modified_jpeg = write_coefficients(&jpeg, &coeffs).unwrap();

        // Read back and verify the modifications are preserved.
        let coeffs2 = read_coefficients(&modified_jpeg).unwrap();
        assert_eq!(coeffs.components[0].blocks, coeffs2.components[0].blocks);
    }

    // ── block_count tests ─────────────────────────────────────────────────────

    #[test]
    fn block_count_gray_16x16() {
        let jpeg = make_jpeg_gray(16, 16);
        let counts = block_count(&jpeg).unwrap();
        // 16×16 / 8×8 = 4 blocks for the single Y component.
        assert_eq!(counts, vec![4]);
    }

    #[test]
    fn block_count_rgb_32x32() {
        let jpeg = make_jpeg_rgb(32, 32);
        let counts = block_count(&jpeg).unwrap();
        // For 4:2:0 subsampling: Y has 4×(2×2)=16 blocks, Cb/Cr have 4 each.
        // For 4:4:4: all three have 16 blocks.
        // Accept either — exact layout depends on the encoder.
        assert_eq!(counts.len(), 3);
        let total: usize = counts.iter().sum();
        assert!(total > 0);
    }

    // ── Category function tests ───────────────────────────────────────────────

    #[test]
    fn category_values() {
        // Category = number of bits needed for |v| (0 for v == 0).
        assert_eq!(category(0), 0);
        assert_eq!(category(1), 1);
        assert_eq!(category(-1), 1);
        assert_eq!(category(2), 2);
        assert_eq!(category(3), 2);
        assert_eq!(category(4), 3);
        assert_eq!(category(127), 7);
        assert_eq!(category(-128), 8);
        assert_eq!(category(1023), 10);
        assert_eq!(category(i16::MAX), 15); // capped at 15
    }

    // ── Valid output test ─────────────────────────────────────────────────────

    #[test]
    fn output_is_valid_jpeg() {
        let jpeg = make_jpeg_rgb(24, 24);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        // Flip one LSB.
        if let Some(block) = coeffs.components[0].blocks.first_mut() {
            block[1] |= 1;
        }
        let out = write_coefficients(&jpeg, &coeffs).unwrap();
        // Check SOI and EOI markers.
        assert_eq!(&out[..2], &[0xFF, 0xD8], "missing SOI");
        assert_eq!(&out[out.len() - 2..], &[0xFF, 0xD9], "missing EOI");
    }

    // ── inspect() tests ───────────────────────────────────────────────────────

    #[test]
    fn inspect_gray_returns_correct_dimensions() {
        let jpeg = make_jpeg_gray(32, 16);
        let info = inspect(&jpeg).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 16);
        assert_eq!(info.components.len(), 1);
        assert_eq!(info.components[0].block_count, 8); // 4×2 blocks
    }

    #[test]
    fn inspect_rgb_returns_three_components() {
        let jpeg = make_jpeg_rgb(32, 32);
        let info = inspect(&jpeg).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 32);
        assert_eq!(info.components.len(), 3);
        // Total blocks across components must be positive.
        let total: usize = info.components.iter().map(|c| c.block_count).sum();
        assert!(total > 0);
    }

    #[test]
    fn inspect_matches_block_count() {
        // The two header-only entry points must agree on per-component counts.
        let jpeg = make_jpeg_rgb(48, 32);
        let info = inspect(&jpeg).unwrap();
        let counts = block_count(&jpeg).unwrap();
        let info_counts: Vec<usize> = info.components.iter().map(|c| c.block_count).collect();
        assert_eq!(info_counts, counts);
    }

    // ── eligible_ac_count tests ───────────────────────────────────────────────

    #[test]
    fn eligible_ac_count_is_positive() {
        let jpeg = make_jpeg_rgb(32, 32);
        let n = eligible_ac_count(&jpeg).unwrap();
        assert!(n > 0, "natural image should have eligible AC coefficients");
    }

    #[test]
    fn eligible_ac_count_method_matches_free_fn() {
        let jpeg = make_jpeg_gray(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let via_method = coeffs.eligible_ac_count();
        let via_fn = eligible_ac_count(&jpeg).unwrap();
        assert_eq!(via_method, via_fn);
    }

    #[test]
    fn eligible_ac_count_leq_total_ac_count() {
        let jpeg = make_jpeg_rgb(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let eligible = coeffs.eligible_ac_count();
        let total_ac: usize = coeffs
            .components
            .iter()
            .flat_map(|c| c.blocks.iter())
            .map(|_| 63) // 63 AC coefficients per block
            .sum();
        assert!(eligible <= total_ac);
    }

    // ── LUT Huffman decode correctness (regression for the old HashMap version) ─

    #[test]
    fn lut_decode_matches_modification_roundtrip() {
        // A natural image exercises many different Huffman code lengths.
        // If the LUT decode is wrong, the modification roundtrip will fail.
        let jpeg = make_jpeg_rgb(64, 64);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        let mut flipped = 0usize;
        // Same |v| >= 2 rule as above, applied to every component.
        for comp in &mut coeffs.components {
            for block in &mut comp.blocks {
                for coeff in block[1..].iter_mut() {
                    if coeff.abs() >= 2 {
                        *coeff ^= 1;
                        flipped += 1;
                    }
                }
            }
        }
        assert!(flipped > 0);
        let modified = write_coefficients(&jpeg, &coeffs).unwrap();
        let coeffs2 = read_coefficients(&modified).unwrap();
        assert_eq!(coeffs.components.len(), coeffs2.components.len());
        for (c1, c2) in coeffs.components.iter().zip(coeffs2.components.iter()) {
            assert_eq!(c1.blocks, c2.blocks);
        }
    }
}