Skip to main content

dct_io/
lib.rs

1#![forbid(unsafe_code)]
2//! Read and write quantized DCT coefficients in baseline JPEG files.
3//!
4//! This crate provides direct access to the quantized DCT coefficients stored
5//! in the entropy-coded data of a baseline JPEG. It is useful for
6//! steganography, watermarking, forensic analysis, and JPEG-domain signal
7//! processing where you need to read or modify coefficients without fully
8//! decoding the image to pixel values.
9//!
10//! # What this crate does
11//!
12//! JPEG compresses images by dividing them into 8×8 pixel blocks, applying a
13//! Discrete Cosine Transform (DCT) to each block, quantizing the resulting
14//! coefficients, and then entropy-coding them with Huffman coding. This crate
15//! parses the entropy-coded stream, decodes the Huffman symbols, reconstructs
16//! the quantized coefficient values, and lets you read or modify them before
17//! re-encoding everything back into a valid JPEG byte stream.
18//!
19//! # What this crate does NOT do
20//!
21//! - It does not decode pixel values (no IDCT, no dequantisation).
//! - It does not support progressive JPEG (SOF2), lossless JPEG (SOF3),
//!   hierarchical JPEG (SOF5–SOF7, SOF13–SOF15), or arithmetic coding
//!   (SOF9–SOF11). Passing such files returns an error.
24//! - It does not support JPEG 2000.
25//!
26//! # Supported JPEG variants
27//!
28//! - Baseline DCT (SOF0) — the most common variant
29//! - Extended sequential DCT (SOF1) — treated identically to SOF0
30//! - Grayscale (1 component) and colour (3 components, typically YCbCr)
31//! - All standard chroma subsampling ratios (4:4:4, 4:2:2, 4:2:0, etc.)
32//! - EXIF and JFIF headers
33//! - Restart markers (DRI / RST0–RST7)
34//!
35//! # Example
36//!
37//! ```no_run
38//! use dct_io::{read_coefficients, write_coefficients};
39//!
40//! let jpeg = std::fs::read("photo.jpg").unwrap();
41//!
42//! let mut coeffs = read_coefficients(&jpeg).unwrap();
43//!
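//! // Flip the LSB of every eligible AC coefficient (|v| >= 2) in the first
//! // component. 0 and ±1 are skipped: flipping the LSB of 1 would produce 0
//! // and change the zero-run structure of the entropy stream.
//! for block in &mut coeffs.components[0].blocks {
//!     for coeff in block[1..].iter_mut() {
//!         if coeff.abs() >= 2 {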
48//!             *coeff ^= 1;
49//!         }
50//!     }
51//! }
52//!
53//! let modified = write_coefficients(&jpeg, &coeffs).unwrap();
54//! std::fs::write("photo_modified.jpg", modified).unwrap();
55//! ```
56
57use thiserror::Error;
58
59// ── Public error type ─────────────────────────────────────────────────────────
60
61/// Errors returned by this crate.
62#[derive(Debug, Error)]
63pub enum DctError {
64    /// The input does not start with a JPEG SOI marker (`0xFF 0xD8`).
65    #[error("not a JPEG file")]
66    NotJpeg,
67
68    /// The input was truncated mid-marker or mid-entropy-stream.
69    #[error("truncated JPEG data")]
70    Truncated,
71
72    /// The entropy-coded data contains an invalid Huffman symbol or an
73    /// unexpected structure.
74    #[error("corrupt or malformed JPEG entropy stream")]
75    CorruptEntropy,
76
77    /// The JPEG uses a feature this crate does not support (e.g. progressive
78    /// scan, lossless, or arithmetic coding).
79    #[error("unsupported JPEG variant: {0}")]
80    Unsupported(String),
81
82    /// A required marker or table is missing from the JPEG (e.g. no SOF, no
83    /// SOS, or a scan references a Huffman table that was not defined).
84    #[error("missing required JPEG structure: {0}")]
85    Missing(String),
86
87    /// The `JpegCoefficients` passed to [`write_coefficients`] is not
88    /// compatible with the JPEG (wrong number of components, wrong block
89    /// count, wrong component index).
90    #[error("coefficient data is incompatible with this JPEG: {0}")]
91    Incompatible(String),
92}
93
94// ── Public types ──────────────────────────────────────────────────────────────
95
/// Metadata for a single image component, as read from the SOF marker.
///
/// Implements `PartialEq`/`Eq` so callers can compare layouts, e.g. to check
/// that two JPEGs share the same component structure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentInfo {
    /// Component identifier (1=Y, 2=Cb, 3=Cr in YCbCr; 1=Y in grayscale).
    pub id: u8,
    /// Horizontal sampling factor.
    pub h_samp: u8,
    /// Vertical sampling factor.
    pub v_samp: u8,
    /// Number of 8×8 DCT blocks this component contributes to the image.
    pub block_count: usize,
}

/// Image metadata extracted from a JPEG without decoding the entropy stream.
///
/// Obtained from [`inspect`]. Cheaper than [`read_coefficients`] when you
/// only need dimensions, component count, or block counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JpegInfo {
    /// Image width in pixels.
    pub width: u16,
    /// Image height in pixels.
    pub height: u16,
    /// Per-component metadata, in SOF order (typically Y, Cb, Cr).
    pub components: Vec<ComponentInfo>,
}
122
/// Quantized DCT coefficients for a single component (Y, Cb, or Cr).
///
/// Each element of `blocks` is one 8×8 DCT block, stored in the JPEG zigzag
/// scan order:
/// - Index 0: DC coefficient (top-left of the frequency matrix).
/// - Indices 1–63: AC coefficients in zigzag order.
///
/// The values are the quantized coefficients exactly as they appear in the
/// JPEG bitstream. They have **not** been dequantized; multiply by the
/// quantization table to recover the pre-quantized DCT values.
///
/// Implements `PartialEq`/`Eq` so modified coefficients can be compared
/// against the originals.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComponentCoefficients {
    /// Component identifier as written in the JPEG SOF marker.
    pub id: u8,
    /// All 8×8 blocks for this component, in raster scan order (left-to-right,
    /// top-to-bottom). Each block contains exactly 64 `i16` values.
    pub blocks: Vec<[i16; 64]>,
}

/// Quantized DCT coefficients for all components in a JPEG image.
///
/// Returned by [`read_coefficients`] and accepted by [`write_coefficients`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JpegCoefficients {
    /// One entry per component, in the order they appear in the JPEG SOF
    /// marker (typically Y, Cb, Cr for colour images).
    pub components: Vec<ComponentCoefficients>,
}
151
152// ── Public API ────────────────────────────────────────────────────────────────
153
154/// Decode the quantized DCT coefficients from a baseline JPEG.
155///
156/// Returns [`JpegCoefficients`] containing all blocks for all components.
157/// Does not dequantize or apply IDCT; values are the raw quantized integers.
158///
159/// # Errors
160///
161/// Returns [`DctError`] if the input is not a supported baseline JPEG, if
162/// required markers are missing, or if the entropy stream is corrupt.
163#[must_use = "returns the decoded coefficients or an error; ignoring it discards the result"]
164pub fn read_coefficients(jpeg: &[u8]) -> Result<JpegCoefficients, DctError> {
165    let mut parser = JpegParser::new(jpeg)?;
166    parser.parse()?;
167    parser.decode_coefficients()
168}
169
170/// Re-encode a JPEG with modified DCT coefficients.
171///
172/// Takes the original JPEG bytes and a [`JpegCoefficients`] (typically
173/// obtained from [`read_coefficients`] and then modified), and produces a new
174/// JPEG byte stream with the updated coefficients re-encoded using the same
175/// Huffman tables as the original.
176///
177/// The output is a valid JPEG. All non-entropy-coded segments (EXIF, ICC
178/// profile, quantization tables, etc.) are preserved verbatim.
179///
180/// # Safety note
181///
182/// The output is only as valid as the input JPEG's Huffman tables permit.
183/// If you set a coefficient to a value whose (run, category) symbol does not
184/// exist in the original Huffman table, encoding will return
185/// [`DctError::CorruptEntropy`]. Stick to modifying the LSB of coefficients
186/// with `|v| >= 2` (JSteg-style) to stay safely within the table.
187///
188/// # Errors
189///
190/// Returns [`DctError::Incompatible`] if `coeffs` has a different number of
191/// components, a different block count, or mismatched component IDs compared
192/// to the original JPEG.
193/// Returns [`DctError`] for any parse or encoding failure.
194#[must_use = "returns the re-encoded JPEG bytes or an error; ignoring it discards the result"]
195pub fn write_coefficients(jpeg: &[u8], coeffs: &JpegCoefficients) -> Result<Vec<u8>, DctError> {
196    let mut parser = JpegParser::new(jpeg)?;
197    parser.parse()?;
198    parser.encode_coefficients(jpeg, coeffs)
199}
200
201/// Return the number of 8×8 DCT blocks per component in a JPEG.
202///
203/// The returned `Vec` has one entry per component (in SOF order). Useful
204/// for determining how many blocks are available before calling
205/// [`read_coefficients`].
206///
207/// # Errors
208///
209/// Returns [`DctError`] if the input is not a supported baseline JPEG.
210#[must_use = "returns block counts or an error; ignoring it discards the result"]
211pub fn block_count(jpeg: &[u8]) -> Result<Vec<usize>, DctError> {
212    let mut parser = JpegParser::new(jpeg)?;
213    parser.parse()?;
214    parser.block_counts()
215}
216
217/// Inspect a JPEG and return image metadata without decoding the entropy stream.
218///
219/// Much cheaper than [`read_coefficients`] when you only need the image
220/// dimensions, component layout, or block counts.
221///
222/// # Errors
223///
224/// Returns [`DctError`] if the input is not a supported baseline JPEG.
225#[must_use = "returns image metadata or an error; ignoring it discards the result"]
226pub fn inspect(jpeg: &[u8]) -> Result<JpegInfo, DctError> {
227    let mut parser = JpegParser::new(jpeg)?;
228    parser.parse()?;
229    let counts = parser.block_counts()?;
230    Ok(JpegInfo {
231        width: parser.image_width,
232        height: parser.image_height,
233        components: parser
234            .frame_components
235            .iter()
236            .enumerate()
237            .map(|(i, fc)| ComponentInfo {
238                id: fc.id,
239                h_samp: fc.h_samp,
240                v_samp: fc.v_samp,
241                block_count: counts[i],
242            })
243            .collect(),
244    })
245}
246
247/// Count the number of AC coefficients with `|v| >= 2` across all components.
248///
249/// These are the coefficients that can be modified without altering zero-run
250/// lengths or EOB positions — the eligible positions for JSteg-style LSB
251/// embedding. Decodes all coefficients internally; use
252/// [`JpegCoefficients::eligible_ac_count`] to avoid decoding twice.
253///
254/// # Errors
255///
256/// Returns [`DctError`] if the input is not a supported baseline JPEG.
257#[must_use = "returns the eligible AC coefficient count or an error; ignoring it discards the result"]
258pub fn eligible_ac_count(jpeg: &[u8]) -> Result<usize, DctError> {
259    Ok(read_coefficients(jpeg)?.eligible_ac_count())
260}
261
262impl JpegCoefficients {
263    /// Count the number of AC coefficients with `|v| >= 2` across all
264    /// components.
265    ///
266    /// Modifying only these coefficients preserves the zero-run structure of
267    /// the entropy stream, keeping the output a valid JPEG that is
268    /// perceptually indistinguishable from the original.
269    ///
270    /// # Example
271    ///
272    /// ```no_run
273    /// use dct_io::read_coefficients;
274    ///
275    /// let jpeg = std::fs::read("photo.jpg").unwrap();
276    /// let coeffs = read_coefficients(&jpeg).unwrap();
277    /// println!("Eligible AC positions: {}", coeffs.eligible_ac_count());
278    /// ```
279    #[must_use]
280    pub fn eligible_ac_count(&self) -> usize {
281        self.components
282            .iter()
283            .flat_map(|c| c.blocks.iter())
284            .flat_map(|b| b[1..].iter())
285            .filter(|&&v| v.abs() >= 2)
286            .count()
287    }
288}
289
290// ── Internal constants ────────────────────────────────────────────────────────
291
/// Zigzag position → natural (row-major) position inside an 8×8 block.
///
/// `ZIGZAG[k]` is `row * 8 + col` of the `k`-th coefficient in JPEG zigzag
/// scan order. `ZIGZAG[0]` is the DC term at (0, 0), and `ZIGZAG[63]` is the
/// highest-frequency term at (7, 7).
#[rustfmt::skip]
const ZIGZAG: [u8; 64] = [
    // k = 0..=15
     0,  1,  8, 16,  9,  2,  3, 10, 17, 24, 32, 25, 18, 11,  4,  5,
    // k = 16..=31
    12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13,  6,  7, 14, 21, 28,
    // k = 32..=47
    35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51,
    // k = 48..=63
    58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63,
];
305
/// Maximum number of MCUs we are willing to decode (safety cap).
///
/// 2^20 MCUs is about 67 megapixels when an MCU is a single 8×8 block
/// (grayscale or 4:4:4), and about 268 megapixels at 4:2:0 (16×16-pixel MCUs).
const MAX_MCU_COUNT: usize = 1_048_576;
308
309// ── Value-category helper ─────────────────────────────────────────────────────
310
/// JPEG value category: the number of bits needed to represent `abs(v)`.
/// Category 0 is special (used for the zero DC difference and the EOB symbol).
/// Capped at 15 to guard against malformed input.
#[inline]
fn category(value: i16) -> u8 {
    if value == 0 {
        return 0;
    }
    // `unsigned_abs` is total: it maps i16::MIN to 32768 instead of overflowing.
    let magnitude = value.unsigned_abs();
    let cat = (u16::BITS - magnitude.leading_zeros()) as u8;
    cat.min(15)
}

/// Encode `value` into its (category, magnitude bits) JPEG representation.
/// Returns `(cat, bits, bit_count)`; `bits` always fits in `bit_count` bits.
///
/// Positive values are written as-is. A negative value `v` of category `c` is
/// written as `(2^c - 1) + v`, i.e. its one's complement in `c` bits.
///
/// `i16::MIN` needs 16 magnitude bits, beyond the category-15 cap, so it is
/// saturated to `-i16::MAX` rather than emitting a bit pattern wider than
/// `bit_count`.
#[inline]
fn encode_value(value: i16) -> (u8, u16, u8) {
    let value = value.max(-i16::MAX);
    let cat = category(value);
    if cat == 0 {
        return (0, 0, 0);
    }
    // Work in 32 bits: in i16, `(1 << 15) - 1` overflows for category 15.
    let mask = (1u32 << cat) - 1;
    let bits = if value > 0 {
        value as u32
    } else {
        // |value| lies in [2^(cat-1), 2^cat - 1], so this is non-negative.
        (mask as i32 + i32::from(value)) as u32
    };
    (cat, (bits & mask) as u16, cat)
}
341
342// ── Huffman table ─────────────────────────────────────────────────────────────
343
/// One Huffman table (DC or AC) from a DHT segment, held in two forms.
///
/// * Decoding: `lut` has 65 536 entries, one for every possible 16-bit window
///   at the head of the bitstream. Each entry packs
///   `(symbol << 8) | code_length`, and 0 marks a window that starts with no
///   valid code. A single indexed read therefore resolves the next symbol.
/// * Encoding: `encode` is indexed by symbol value and holds
///   `(code, code_length)`. A length of 0 means the symbol is absent.
#[derive(Clone)]
struct HuffTable {
    /// Decode LUT: top 16 stream bits → `(symbol << 8) | len`; 0 = invalid.
    lut: Vec<u16>,
    /// Encode table: `encode[symbol] = (code, code_length)`; length 0 = absent.
    encode: [(u16, u8); 256],
}
360
361impl std::fmt::Debug for HuffTable {
362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        let entries = self.encode.iter().filter(|e| e.1 > 0).count();
364        f.debug_struct("HuffTable")
365            .field("encode_entries", &entries)
366            .finish()
367    }
368}
369
370impl HuffTable {
371    /// Build a Huffman table from the DHT segment payload.
372    ///
373    /// `counts` is the 16-byte array of code counts per length (1..=16).
374    /// `symbols` is the flat list of symbols in canonical order.
375    fn from_jpeg(counts: &[u8; 16], symbols: &[u8]) -> Result<Self, DctError> {
376        let mut encode = [(0u16, 0u8); 256];
377        let mut lut = vec![0u16; 65536];
378        let mut code: u16 = 0;
379        let mut sym_idx = 0usize;
380
381        for len in 1u8..=16u8 {
382            let count = counts[(len - 1) as usize] as usize;
383            for _ in 0..count {
384                if sym_idx >= symbols.len() {
385                    return Err(DctError::CorruptEntropy);
386                }
387                // Guard against a malformed DHT where the canonical code would
388                // overflow 16 bits or index outside our LUT. Use u32 for the
389                // shift so `len == 16` does not itself overflow.
390                if (code as u32) >= (1u32 << len) {
391                    return Err(DctError::CorruptEntropy);
392                }
393                let sym = symbols[sym_idx];
394                sym_idx += 1;
395                encode[sym as usize] = (code, len);
396
397                // Fill all 16-bit keys whose top `len` bits equal `code`.
398                // Each such key represents a stream where the Huffman prefix
399                // is followed by arbitrary suffix bits.
400                let spread = 1usize << (16 - len);
401                let base = (code as usize) << (16 - len);
402                let entry = ((sym as u16) << 8) | (len as u16);
403                lut[base..base + spread].fill(entry);
404
405                code += 1;
406            }
407            code <<= 1;
408        }
409
410        Ok(HuffTable { lut, encode })
411    }
412}
413
414// ── Bit reader ────────────────────────────────────────────────────────────────
415
/// MSB-first bit reader over JPEG entropy-coded bytes.
///
/// Bits are held right-aligned in `buf`. The low `bits` bits are the
/// not-yet-consumed stream bits, with the oldest bit most significant.
struct BitReader<'a> {
    /// Entropy-coded bytes, still containing byte stuffing and any markers.
    data: &'a [u8],
    /// Index of the next byte of `data` to load into `buf`.
    pos: usize,
    /// Bit buffer; only the low `bits` bits are meaningful.
    buf: u64,
    /// Number of valid bits currently held in `buf`.
    bits: u8,
}
422
423impl<'a> BitReader<'a> {
424    fn new(data: &'a [u8]) -> Self {
425        BitReader { data, pos: 0, buf: 0, bits: 0 }
426    }
427
428    /// Fill `buf` from the entropy stream, skipping byte stuffing (0xFF 0x00)
429    /// and stopping at any marker (0xFF 0xD0–0xD9 or any non-0x00 after 0xFF).
430    fn refill(&mut self) {
431        while self.bits <= 56 {
432            if self.pos >= self.data.len() {
433                break;
434            }
435            let byte = self.data[self.pos];
436            if byte == 0xFF {
437                if self.pos + 1 >= self.data.len() {
438                    break;
439                }
440                let next = self.data[self.pos + 1];
441                if next == 0x00 {
442                    // Byte stuffing — consume both, emit 0xFF.
443                    self.pos += 2;
444                    self.buf = (self.buf << 8) | 0xFF;
445                    self.bits += 8;
446                } else {
447                    // Marker — stop refilling.
448                    break;
449                }
450            } else {
451                self.pos += 1;
452                self.buf = (self.buf << 8) | (byte as u64);
453                self.bits += 8;
454            }
455        }
456    }
457
458    /// Peek at the top `n` bits without consuming them.
459    fn peek(&mut self, n: u8) -> Result<u16, DctError> {
460        if self.bits < n {
461            self.refill();
462        }
463        if self.bits < n {
464            return Err(DctError::Truncated);
465        }
466        Ok(((self.buf >> (self.bits - n)) & ((1u64 << n) - 1)) as u16)
467    }
468
469    /// Consume `n` bits.
470    fn consume(&mut self, n: u8) {
471        debug_assert!(self.bits >= n);
472        self.bits -= n;
473        self.buf &= (1u64 << self.bits) - 1;
474    }
475
476    /// Read `n` bits and return them as a `u16`.
477    fn read_bits(&mut self, n: u8) -> Result<u16, DctError> {
478        if n == 0 {
479            return Ok(0);
480        }
481        let v = self.peek(n)?;
482        self.consume(n);
483        Ok(v)
484    }
485
486    /// Decode the next Huffman symbol using the 16-bit LUT.
487    ///
488    /// Forms a 16-bit key from the top bits of the buffer (right-padded with
489    /// zeros if fewer than 16 bits are available). The LUT maps this key
490    /// directly to `(symbol, code_length)` in a single indexed read.
491    fn decode_huffman(&mut self, table: &HuffTable) -> Result<u8, DctError> {
492        if self.bits < 16 {
493            self.refill();
494        }
495        // Build the 16-bit key: top `min(bits, 16)` stream bits left-aligned.
496        let key = if self.bits >= 16 {
497            ((self.buf >> (self.bits - 16)) & 0xFFFF) as u16
498        } else {
499            // Fewer than 16 bits available — pad the right with zeros.
500            // The LUT entry for any short code covers all possible suffixes,
501            // so zero-padding is safe as long as len <= self.bits.
502            ((self.buf << (16 - self.bits)) & 0xFFFF) as u16
503        };
504
505        let entry = table.lut[key as usize];
506        let len = (entry & 0xFF) as u8;
507        let sym = (entry >> 8) as u8;
508
509        if len == 0 {
510            return Err(DctError::CorruptEntropy);
511        }
512        if self.bits < len {
513            return Err(DctError::Truncated);
514        }
515        self.consume(len);
516        Ok(sym)
517    }
518
519    /// Skip any restart marker at the current position and reset DC predictor.
520    /// Returns `true` if a restart marker was consumed.
521    fn sync_restart(&mut self) -> bool {
522        // Discard any remaining bits in the current byte.
523        self.bits = 0;
524        self.buf = 0;
525        // Check for a single 0xFF followed by RST0–RST7 (0xD0–0xD7).
526        if self.pos + 1 < self.data.len()
527            && self.data[self.pos] == 0xFF
528            && (0xD0..=0xD7).contains(&self.data[self.pos + 1])
529        {
530            self.pos += 2;
531            return true;
532        }
533        false
534    }
535}
536
537// ── Bit writer ────────────────────────────────────────────────────────────────
538
/// Accumulates MSB-first bits into JPEG entropy-coded bytes.
///
/// Byte stuffing (`0xFF` → `0xFF 0x00`) is applied as each byte is completed.
struct BitWriter {
    /// Completed output bytes.
    out: Vec<u8>,
    /// Pending bits that do not yet form a whole byte; only the low `bits`
    /// bits are meaningful.
    buf: u64,
    /// Number of pending bits in `buf` (always < 8 between calls).
    bits: u8,
}

impl BitWriter {
    /// Create an empty writer whose output buffer has room for `cap` bytes.
    fn with_capacity(cap: usize) -> Self {
        BitWriter { out: Vec::with_capacity(cap), buf: 0, bits: 0 }
    }

    /// Append the low `n` bits of `value`, most significant first.
    ///
    /// Any bits of `value` above position `n` are ignored. Without the mask
    /// they would be OR-ed into the buffer and overwrite bits already queued.
    fn write_bits(&mut self, value: u16, n: u8) {
        if n == 0 {
            return;
        }
        let masked = u64::from(value) & ((1u64 << n) - 1);
        self.buf = (self.buf << n) | masked;
        self.bits += n;
        while self.bits >= 8 {
            self.bits -= 8;
            let byte = ((self.buf >> self.bits) & 0xFF) as u8;
            self.out.push(byte);
            if byte == 0xFF {
                self.out.push(0x00); // Byte stuffing.
            }
            self.buf &= (1u64 << self.bits) - 1;
        }
    }

    /// Flush any remaining bits, padding the final byte with 1-bits per the
    /// JPEG spec. The padded byte is stuffed if it comes out as `0xFF`.
    fn flush(&mut self) {
        if self.bits > 0 {
            let pad = 8 - self.bits;
            let byte = (((self.buf << pad) | ((1u64 << pad) - 1)) & 0xFF) as u8;
            self.out.push(byte);
            if byte == 0xFF {
                self.out.push(0x00);
            }
            self.bits = 0;
            self.buf = 0;
        }
    }

    /// Flush pending bits, then emit restart marker `0xFF 0xD0 | (n & 7)`.
    /// Markers are not entropy data, so they are written without stuffing.
    fn write_restart_marker(&mut self, n: u8) {
        self.flush();
        self.out.push(0xFF);
        self.out.push(0xD0 | (n & 0x07));
    }
}
590
591// ── Internal JPEG parser ──────────────────────────────────────────────────────
592
/// One component entry from the SOF (start-of-frame) segment.
#[derive(Debug, Clone)]
struct FrameComponent {
    /// Component identifier; SOS entries refer to components by this value.
    id: u8,
    /// Horizontal sampling factor (high nibble of the SOF sampling byte).
    h_samp: u8,
    /// Vertical sampling factor (low nibble of the SOF sampling byte).
    v_samp: u8,
    /// Quantization table selector. Stored for completeness only; the crate
    /// does not dequantise, so the field is never read (hence the allowance).
    #[allow(dead_code)]
    qt_id: u8,
}
602
/// One component entry from the SOS (start-of-scan) header.
#[derive(Debug, Clone)]
struct ScanComponent {
    /// Position of the referenced component within `frame_components`.
    comp_idx: usize,
    /// Selector (0–3) of the DC Huffman table used by this component.
    dc_table: usize,
    /// Selector (0–3) of the AC Huffman table used by this component.
    ac_table: usize,
}
610
611/// Parsed state accumulated while scanning JPEG markers.
612struct JpegParser<'a> {
613    data: &'a [u8],
614    pos: usize,
615
616    /// Byte offset of the first entropy-coded data byte.
617    entropy_start: usize,
618    /// Byte length of the entropy-coded segment (up to next non-RST marker).
619    entropy_len: usize,
620
621    frame_components: Vec<FrameComponent>,
622    scan_components: Vec<ScanComponent>,
623    dc_tables: [Option<HuffTable>; 4],
624    ac_tables: [Option<HuffTable>; 4],
625    restart_interval: u16,
626    image_width: u16,
627    image_height: u16,
628}
629
630impl<'a> JpegParser<'a> {
631    fn new(data: &'a [u8]) -> Result<Self, DctError> {
632        if data.len() < 2 || data[0] != 0xFF || data[1] != 0xD8 {
633            return Err(DctError::NotJpeg);
634        }
635        Ok(JpegParser {
636            data,
637            pos: 2,
638            entropy_start: 0,
639            entropy_len: 0,
640            frame_components: Vec::new(),
641            scan_components: Vec::new(),
642            dc_tables: [None, None, None, None],
643            ac_tables: [None, None, None, None],
644            restart_interval: 0,
645            image_width: 0,
646            image_height: 0,
647        })
648    }
649
650    /// Read a 2-byte big-endian u16 from `data[pos..]`, advancing `pos`.
651    fn read_u16(&mut self) -> Result<u16, DctError> {
652        if self.pos + 1 >= self.data.len() {
653            return Err(DctError::Truncated);
654        }
655        let v = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
656        self.pos += 2;
657        Ok(v)
658    }
659
660    /// Parse all JPEG markers up to and including SOS. Sets `entropy_start`
661    /// and `entropy_len`.
662    fn parse(&mut self) -> Result<(), DctError> {
663        loop {
664            // Find next marker.
665            if self.pos >= self.data.len() {
666                return Err(DctError::Missing("SOS marker".into()));
667            }
668            if self.data[self.pos] != 0xFF {
669                return Err(DctError::CorruptEntropy);
670            }
671            // Skip 0xFF padding.
672            while self.pos < self.data.len() && self.data[self.pos] == 0xFF {
673                self.pos += 1;
674            }
675            if self.pos >= self.data.len() {
676                return Err(DctError::Truncated);
677            }
678            let marker = self.data[self.pos];
679            self.pos += 1;
680
681            match marker {
682                0xD8 => {} // SOI — already consumed.
683                0xD9 => return Err(DctError::Missing("SOS before EOI".into())),
684
685                // SOF0 (baseline) and SOF1 (extended sequential) — supported.
686                0xC0 | 0xC1 => self.parse_sof()?,
687
688                // SOF markers we reject with a clear message.
689                0xC2 => return Err(DctError::Unsupported("progressive JPEG (SOF2)".into())),
690                0xC3 => return Err(DctError::Unsupported("lossless JPEG (SOF3)".into())),
691                0xC9 => return Err(DctError::Unsupported("arithmetic coding (SOF9)".into())),
692                0xCA => return Err(DctError::Unsupported("progressive arithmetic (SOF10)".into())),
693                0xCB => return Err(DctError::Unsupported("lossless arithmetic (SOF11)".into())),
694
695                0xC4 => self.parse_dht()?,
696                0xDD => self.parse_dri()?,
697
698                0xDA => {
699                    // SOS — parse header, then record entropy start.
700                    self.parse_sos_header()?;
701                    self.entropy_start = self.pos;
702                    self.entropy_len = self.find_entropy_end();
703                    return Ok(());
704                }
705
706                // Any other marker with a length field — skip.
707                _ => {
708                    let len = self.read_u16()? as usize;
709                    if len < 2 {
710                        return Err(DctError::CorruptEntropy);
711                    }
712                    let skip = len - 2;
713                    if self.pos + skip > self.data.len() {
714                        return Err(DctError::Truncated);
715                    }
716                    self.pos += skip;
717                }
718            }
719        }
720    }
721
722    fn parse_sof(&mut self) -> Result<(), DctError> {
723        let len = self.read_u16()? as usize;
724        if len < 8 {
725            return Err(DctError::CorruptEntropy);
726        }
727        let end = self.pos + len - 2;
728        if end > self.data.len() {
729            return Err(DctError::Truncated);
730        }
731        let _precision = self.data[self.pos];
732        self.pos += 1;
733        self.image_height = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
734        self.pos += 2;
735        self.image_width = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]);
736        self.pos += 2;
737
738        if self.image_width == 0 || self.image_height == 0 {
739            return Err(DctError::Unsupported("zero image dimension".into()));
740        }
741
742        let ncomp = self.data[self.pos] as usize;
743        self.pos += 1;
744
745        if ncomp == 0 || ncomp > 4 {
746            return Err(DctError::Unsupported(format!("{ncomp} components")));
747        }
748        if self.pos + ncomp * 3 > end {
749            return Err(DctError::Truncated);
750        }
751
752        self.frame_components.clear();
753        for _ in 0..ncomp {
754            let id = self.data[self.pos];
755            let samp = self.data[self.pos + 1];
756            let qt_id = self.data[self.pos + 2];
757            self.pos += 3;
758            let h_samp = samp >> 4;
759            let v_samp = samp & 0x0F;
760            if h_samp == 0 || v_samp == 0 {
761                return Err(DctError::CorruptEntropy);
762            }
763            self.frame_components.push(FrameComponent {
764                id,
765                h_samp,
766                v_samp,
767                qt_id,
768            });
769        }
770        self.pos = end;
771        Ok(())
772    }
773
774    fn parse_dht(&mut self) -> Result<(), DctError> {
775        let len = self.read_u16()? as usize;
776        if len < 2 {
777            return Err(DctError::CorruptEntropy);
778        }
779        let end = self.pos + len - 2;
780        if end > self.data.len() {
781            return Err(DctError::Truncated);
782        }
783
784        while self.pos < end {
785            if self.pos >= self.data.len() {
786                return Err(DctError::Truncated);
787            }
788            let tc_th = self.data[self.pos];
789            self.pos += 1;
790            let tc = (tc_th >> 4) & 0x0F; // 0=DC, 1=AC
791            let th = (tc_th & 0x0F) as usize; // table index 0–3
792
793            if tc > 1 {
794                return Err(DctError::CorruptEntropy);
795            }
796            if th > 3 {
797                return Err(DctError::CorruptEntropy);
798            }
799
800            if self.pos + 16 > end {
801                return Err(DctError::Truncated);
802            }
803            let mut counts = [0u8; 16];
804            counts.copy_from_slice(&self.data[self.pos..self.pos + 16]);
805            self.pos += 16;
806
807            let total: usize = counts.iter().map(|&c| c as usize).sum();
808            // JPEG Huffman symbols are u8, so at most 256 unique symbols per table.
809            if total > 256 {
810                return Err(DctError::CorruptEntropy);
811            }
812            if self.pos + total > end {
813                return Err(DctError::Truncated);
814            }
815            let symbols = &self.data[self.pos..self.pos + total];
816            self.pos += total;
817
818            let table = HuffTable::from_jpeg(&counts, symbols)?;
819            if tc == 0 {
820                self.dc_tables[th] = Some(table);
821            } else {
822                self.ac_tables[th] = Some(table);
823            }
824        }
825
826        self.pos = end;
827        Ok(())
828    }
829
830    fn parse_dri(&mut self) -> Result<(), DctError> {
831        let len = self.read_u16()?;
832        if len != 4 {
833            return Err(DctError::CorruptEntropy);
834        }
835        self.restart_interval = self.read_u16()?;
836        Ok(())
837    }
838
839    fn parse_sos_header(&mut self) -> Result<(), DctError> {
840        let len = self.read_u16()? as usize;
841        if len < 3 {
842            return Err(DctError::CorruptEntropy);
843        }
844        let end = self.pos + len - 2;
845        if end > self.data.len() {
846            return Err(DctError::Truncated);
847        }
848
849        let ns = self.data[self.pos] as usize;
850        self.pos += 1;
851
852        if ns == 0 || ns > self.frame_components.len() {
853            return Err(DctError::CorruptEntropy);
854        }
855        if self.pos + ns * 2 > end {
856            return Err(DctError::Truncated);
857        }
858
859        self.scan_components.clear();
860        for _ in 0..ns {
861            let comp_id = self.data[self.pos];
862            let td_ta = self.data[self.pos + 1];
863            self.pos += 2;
864
865            let dc_table = (td_ta >> 4) as usize;
866            let ac_table = (td_ta & 0x0F) as usize;
867
868            if dc_table > 3 || ac_table > 3 {
869                return Err(DctError::CorruptEntropy);
870            }
871
872            let comp_idx = self
873                .frame_components
874                .iter()
875                .position(|fc| fc.id == comp_id)
876                .ok_or_else(|| {
877                    DctError::Missing(format!("component id {comp_id} in frame"))
878                })?;
879
880            self.scan_components.push(ScanComponent { comp_idx, dc_table, ac_table });
881        }
882
883        // Skip Ss, Se, Ah/Al (3 bytes).
884        self.pos = end;
885        Ok(())
886    }
887
888    /// Find the length of the entropy-coded segment by scanning for a marker
889    /// that is not RST0–RST7 (0xD0–0xD7).
890    fn find_entropy_end(&self) -> usize {
891        let mut i = self.entropy_start;
892        while i < self.data.len() {
893            if self.data[i] == 0xFF && i + 1 < self.data.len() {
894                let next = self.data[i + 1];
895                if next == 0x00 {
896                    // Byte stuffing.
897                    i += 2;
898                    continue;
899                }
900                if (0xD0..=0xD7).contains(&next) {
901                    // RST marker inside entropy data — skip it.
902                    i += 2;
903                    continue;
904                }
905                // Real marker — entropy stream ends here.
906                return i - self.entropy_start;
907            }
908            i += 1;
909        }
910        self.data.len() - self.entropy_start
911    }
912
913    // ── MCU geometry helpers ──────────────────────────────────────────────────
914
915    fn max_h_samp(&self) -> u8 {
916        self.frame_components.iter().map(|c| c.h_samp).max().unwrap_or(1)
917    }
918
919    fn max_v_samp(&self) -> u8 {
920        self.frame_components.iter().map(|c| c.v_samp).max().unwrap_or(1)
921    }
922
923    fn mcu_cols(&self) -> usize {
924        let max_h = self.max_h_samp() as usize;
925        (self.image_width as usize + max_h * 8 - 1) / (max_h * 8)
926    }
927
928    fn mcu_rows(&self) -> usize {
929        let max_v = self.max_v_samp() as usize;
930        (self.image_height as usize + max_v * 8 - 1) / (max_v * 8)
931    }
932
933    fn mcu_count(&self) -> Result<usize, DctError> {
934        self.mcu_cols()
935            .checked_mul(self.mcu_rows())
936            .ok_or_else(|| DctError::Unsupported("image dimensions overflow usize".into()))
937    }
938
939    /// Number of 8×8 data units per MCU for each scan component.
940    fn du_per_mcu(&self) -> Vec<usize> {
941        self.scan_components
942            .iter()
943            .map(|sc| {
944                let fc = &self.frame_components[sc.comp_idx];
945                (fc.h_samp as usize) * (fc.v_samp as usize)
946            })
947            .collect()
948    }
949
950    /// Total block count per frame component (after all scan components resolved).
951    fn block_counts(&self) -> Result<Vec<usize>, DctError> {
952        let n_mcu = self.mcu_count()?;
953        let du = self.du_per_mcu();
954        let mut counts = vec![0usize; self.frame_components.len()];
955        for (sc_idx, sc) in self.scan_components.iter().enumerate() {
956            counts[sc.comp_idx] = n_mcu * du[sc_idx];
957        }
958        Ok(counts)
959    }
960
961    // ── Decode ────────────────────────────────────────────────────────────────
962
963    fn decode_coefficients(&self) -> Result<JpegCoefficients, DctError> {
964        let entropy = &self.data[self.entropy_start..self.entropy_start + self.entropy_len];
965        let n_mcu = self.mcu_count()?;
966
967        if n_mcu > MAX_MCU_COUNT {
968            return Err(DctError::Unsupported(format!(
969                "image too large ({n_mcu} MCUs; max {MAX_MCU_COUNT})"
970            )));
971        }
972
973        let du = self.du_per_mcu();
974
975        // Pre-allocate output vectors.
976        let counts = self.block_counts()?;
977        let mut comp_blocks: Vec<Vec<[i16; 64]>> =
978            counts.iter().map(|&c| vec![[0i16; 64]; c]).collect();
979        let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
980
981        let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
982        let mut reader = BitReader::new(entropy);
983
984        let restart_interval = self.restart_interval as usize;
985
986        for mcu_idx in 0..n_mcu {
987            // Handle restart markers.
988            if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
989                reader.sync_restart();
990                for p in dc_pred.iter_mut() {
991                    *p = 0;
992                }
993            }
994
995            for (sc_idx, sc) in self.scan_components.iter().enumerate() {
996                let dc_table = self.dc_tables[sc.dc_table]
997                    .as_ref()
998                    .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
999                let ac_table = self.ac_tables[sc.ac_table]
1000                    .as_ref()
1001                    .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
1002
1003                for _du_i in 0..du[sc_idx] {
1004                    let mut block = [0i16; 64];
1005
1006                    // DC coefficient.
1007                    let dc_cat = reader.decode_huffman(dc_table)?;
1008                    let dc_cat = dc_cat.min(15);
1009                    let dc_bits = reader.read_bits(dc_cat)?;
1010                    let dc_diff = decode_magnitude(dc_cat, dc_bits);
1011                    dc_pred[sc_idx] = dc_pred[sc_idx].saturating_add(dc_diff);
1012                    block[ZIGZAG[0] as usize] = dc_pred[sc_idx];
1013
1014                    // AC coefficients.
1015                    let mut k = 1usize;
1016                    while k < 64 {
1017                        let rs = reader.decode_huffman(ac_table)?;
1018                        if rs == 0x00 {
1019                            // EOB — rest of block is zero.
1020                            break;
1021                        }
1022                        if rs == 0xF0 {
1023                            // ZRL — 16 zeros.
1024                            k += 16;
1025                            continue;
1026                        }
1027                        let run = (rs >> 4) as usize;
1028                        let cat = (rs & 0x0F).min(15);
1029                        k += run;
1030                        if k >= 64 {
1031                            break;
1032                        }
1033                        let bits = reader.read_bits(cat)?;
1034                        let val = decode_magnitude(cat, bits);
1035                        block[ZIGZAG[k] as usize] = val;
1036                        k += 1;
1037                    }
1038
1039                    let block_idx = comp_block_idx[sc.comp_idx];
1040                    if block_idx >= comp_blocks[sc.comp_idx].len() {
1041                        return Err(DctError::CorruptEntropy);
1042                    }
1043                    comp_blocks[sc.comp_idx][block_idx] = block;
1044                    comp_block_idx[sc.comp_idx] += 1;
1045                }
1046            }
1047        }
1048
1049        let components = self
1050            .frame_components
1051            .iter()
1052            .zip(comp_blocks)
1053            .map(|(fc, blocks)| ComponentCoefficients { id: fc.id, blocks })
1054            .collect();
1055
1056        Ok(JpegCoefficients { components })
1057    }
1058
1059    // ── Encode ────────────────────────────────────────────────────────────────
1060
    /// Re-encode `coeffs` into a complete JPEG byte stream.
    ///
    /// The coefficients are Huffman-coded with this scan's own DC/AC tables,
    /// MCU layout and restart interval. The new entropy-coded segment is then
    /// spliced between `original[..entropy_start]` and
    /// `original[entropy_start + entropy_len..]`, so all bytes before and after
    /// the scan data are copied through unchanged.
    ///
    /// NOTE(review): `original` is assumed to be the same bytes this parser was
    /// built from, because `entropy_start`/`entropy_len` are used as offsets
    /// into it. Confirm at the call site.
    ///
    /// # Errors
    ///
    /// - [`DctError::Incompatible`] if `coeffs` differs from this image in any
    ///   of these ways:
    ///   - a different number of components;
    ///   - a different component id at some index;
    ///   - a different block count for some component.
    /// - [`DctError::Missing`] if a scan component references an undefined
    ///   Huffman table.
    /// - [`DctError::CorruptEntropy`] if a symbol that must be written has a
    ///   zero code length in the relevant table. Presumably this means the
    ///   table defines no code for it, e.g. a modified coefficient needing a
    ///   run/size symbol the original image never used.
    /// - Any error from `block_counts` / `mcu_count` (e.g. MCU count overflow).
    fn encode_coefficients(
        &self,
        original: &[u8],
        coeffs: &JpegCoefficients,
    ) -> Result<Vec<u8>, DctError> {
        // Validate compatibility.
        if coeffs.components.len() != self.frame_components.len() {
            return Err(DctError::Incompatible(format!(
                "expected {} components, got {}",
                self.frame_components.len(),
                coeffs.components.len()
            )));
        }
        let counts = self.block_counts()?;
        for (i, (cc, &expected)) in coeffs.components.iter().zip(counts.iter()).enumerate() {
            if cc.id != self.frame_components[i].id {
                return Err(DctError::Incompatible(format!(
                    "component {i}: expected id {}, got {}",
                    self.frame_components[i].id, cc.id
                )));
            }
            if cc.blocks.len() != expected {
                return Err(DctError::Incompatible(format!(
                    "component {i}: expected {expected} blocks, got {}",
                    cc.blocks.len()
                )));
            }
        }

        let n_mcu = self.mcu_count()?;
        let du = self.du_per_mcu();

        // Size the output buffer like the original entropy segment.
        let mut writer = BitWriter::with_capacity(self.entropy_len);
        // One DC predictor per scan component, mirroring the decoder.
        let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
        let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
        let restart_interval = self.restart_interval as usize;
        // Index n of the next RSTn marker; cycles through 0–7.
        let mut rst_count: u8 = 0;

        for mcu_idx in 0..n_mcu {
            // Same restart schedule as the decoder: emit RSTn and reset all DC
            // predictors every `restart_interval` MCUs.
            if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
                writer.write_restart_marker(rst_count);
                rst_count = rst_count.wrapping_add(1) & 0x07;
                for p in dc_pred.iter_mut() {
                    *p = 0;
                }
            }

            for (sc_idx, sc) in self.scan_components.iter().enumerate() {
                let dc_table = self.dc_tables[sc.dc_table]
                    .as_ref()
                    .ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
                let ac_table = self.ac_tables[sc.ac_table]
                    .as_ref()
                    .ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;

                for _du_i in 0..du[sc_idx] {
                    // Blocks are consumed in the same order the decoder stored
                    // them; the count check above keeps this index in range.
                    let block = &coeffs.components[sc.comp_idx].blocks
                        [comp_block_idx[sc.comp_idx]];
                    comp_block_idx[sc.comp_idx] += 1;

                    // DC coefficient.
                    // Encoded as the difference from this component's previous
                    // DC value.
                    // NOTE(review): saturating_sub clamps a difference outside
                    // the i16 range, so the decoder would rebuild a different
                    // DC. Such extreme values are assumed not to occur; confirm.
                    let dc_val = block[ZIGZAG[0] as usize];
                    let dc_diff = dc_val.saturating_sub(dc_pred[sc_idx]);
                    dc_pred[sc_idx] = dc_val;
                    let (dc_cat, dc_bits, dc_n) = encode_value(dc_diff);
                    // Table entries are (code, code length); a length of 0
                    // means no code is available for that symbol.
                    let (dc_code, dc_code_len) = {
                        let e = dc_table.encode[dc_cat as usize];
                        if e.1 == 0 { return Err(DctError::CorruptEntropy); }
                        e
                    };
                    writer.write_bits(dc_code, dc_code_len);
                    writer.write_bits(dc_bits, dc_n);

                    // AC coefficients.
                    // Find last non-zero AC position in zigzag order.
                    let last_nonzero_zz = (1..64)
                        .rev()
                        .find(|&i| block[ZIGZAG[i] as usize] != 0);

                    let mut k = 1usize;
                    let mut zero_run = 0usize;

                    // Only positions up to the last non-zero coefficient are
                    // walked; any trailing zeros are covered by EOB below.
                    if let Some(last_pos) = last_nonzero_zz {
                        while k <= last_pos {
                            let val = block[ZIGZAG[k] as usize];
                            if val == 0 {
                                zero_run += 1;
                                if zero_run == 16 {
                                    // Emit ZRL.
                                    let (zrl_code, zrl_len) = {
                                        let e = ac_table.encode[0xF0];
                                        if e.1 == 0 { return Err(DctError::CorruptEntropy); }
                                        e
                                    };
                                    writer.write_bits(zrl_code, zrl_len);
                                    zero_run = 0;
                                }
                            } else {
                                // zero_run is < 16 here (a ZRL resets it at 16),
                                // so it fits in the symbol's high nibble.
                                let (cat, bits, n) = encode_value(val);
                                let rs = ((zero_run as u8) << 4) | cat;
                                let (ac_code, ac_len) = {
                                    let e = ac_table.encode[rs as usize];
                                    if e.1 == 0 { return Err(DctError::CorruptEntropy); }
                                    e
                                };
                                writer.write_bits(ac_code, ac_len);
                                writer.write_bits(bits, n);
                                zero_run = 0;
                            }
                            k += 1;
                        }
                    }
                    // Emit EOB only when there are trailing zeros after the last
                    // non-zero coefficient. If the last non-zero is at position 63,
                    // EOB is unnecessary (libjpeg/libjpeg-turbo behaviour).
                    let needs_eob = last_nonzero_zz.map_or(true, |p| p < 63);
                    if needs_eob {
                        let (eob_code, eob_len) = {
                            let e = ac_table.encode[0x00];
                            if e.1 == 0 { return Err(DctError::CorruptEntropy); }
                            e
                        };
                        writer.write_bits(eob_code, eob_len);
                    }
                }
            }
        }

        writer.flush();

        // Reconstruct the full JPEG: everything before entropy data + new
        // entropy data + everything after (from the first post-entropy marker).
        let after_entropy = self.entropy_start + self.entropy_len;
        let mut out = Vec::with_capacity(original.len());
        out.extend_from_slice(&original[..self.entropy_start]);
        out.extend_from_slice(&writer.out);
        out.extend_from_slice(&original[after_entropy..]);
        Ok(out)
    }
1200}
1201
1202// ── Magnitude decode helper ───────────────────────────────────────────────────
1203
/// Decode a JPEG magnitude value from its category and raw bits.
///
/// A value `v` of category `cat` (`2^(cat-1) <= |v| < 2^cat`) is transmitted
/// as `cat` extra bits:
///
/// - Positive values are sent as-is, so their top bit is set.
/// - Negative values are sent as `v + 2^cat - 1`, so their top bit is clear.
/// - Category 0 is the value 0 and has no extra bits.
///
/// The arithmetic is done in `i32`. In `i16`, `1i16 << 15` is `i16::MIN`, and
/// the negative-branch expression overflowed for category 15. That was a panic
/// in debug builds, and correct only by wraparound in release.
///
/// `cat` is expected to be at most 15 (the decoder clamps it). For every such
/// category the result fits in `i16`.
fn decode_magnitude(cat: u8, bits: u16) -> i16 {
    if cat == 0 {
        return 0;
    }
    let bits = i32::from(bits);
    // Smallest magnitude in this category; also the value of the top bit.
    let threshold = 1i32 << (cat - 1);
    let value = if bits >= threshold {
        bits
    } else {
        bits - (threshold << 1) + 1
    };
    value as i16
}
1216
1217// ── Tests ─────────────────────────────────────────────────────────────────────
1218
#[cfg(test)]
mod tests {
    use super::*;

    // Build a minimal valid baseline JPEG from raw pixel data using the
    // `image` crate, so our tests do not depend on external fixture files.
    fn make_jpeg_gray(width: u32, height: u32) -> Vec<u8> {
        use image::{GrayImage, ImageEncoder, codecs::jpeg::JpegEncoder};
        let img = GrayImage::from_fn(width, height, |x, y| {
            image::Luma([(((x * 7 + y * 13) % 200) + 28) as u8])
        });
        let mut buf = Vec::new();
        let enc = JpegEncoder::new_with_quality(&mut buf, 90);
        enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::L8)
            .unwrap();
        buf
    }

    // Colour counterpart of `make_jpeg_gray`. The encoder decides the colour
    // space and chroma subsampling of the result.
    fn make_jpeg_rgb(width: u32, height: u32) -> Vec<u8> {
        use image::{ImageEncoder, RgbImage, codecs::jpeg::JpegEncoder};
        let img = RgbImage::from_fn(width, height, |x, y| {
            image::Rgb([
                ((x * 11 + y * 3) % 200 + 28) as u8,
                ((x * 5 + y * 17) % 200 + 28) as u8,
                ((x * 3 + y * 7) % 200 + 28) as u8,
            ])
        });
        let mut buf = Vec::new();
        let enc = JpegEncoder::new_with_quality(&mut buf, 85);
        enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
            .unwrap();
        buf
    }

    // ── Error path tests ──────────────────────────────────────────────────────

    #[test]
    fn not_jpeg_returns_error() {
        let result = read_coefficients(b"PNG\x00garbage");
        assert!(matches!(result, Err(DctError::NotJpeg)));
    }

    #[test]
    fn empty_input_returns_error() {
        assert!(matches!(read_coefficients(b""), Err(DctError::NotJpeg)));
    }

    #[test]
    fn truncated_returns_error() {
        // A valid SOI but nothing else.
        assert!(matches!(
            read_coefficients(b"\xFF\xD8\xFF"),
            Err(DctError::Truncated | DctError::Missing(_))
        ));
    }

    #[test]
    fn progressive_jpeg_returns_unsupported() {
        // Craft a minimal JPEG with SOF2 marker.
        let mut data = vec![0xFF, 0xD8]; // SOI
        // APP0 JFIF (minimal)
        data.extend_from_slice(&[0xFF, 0xE0, 0x00, 0x10]);
        data.extend_from_slice(&[
            0x4A, 0x46, 0x49, 0x46, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00,
        ]);
        // SOF2 marker (progressive)
        data.extend_from_slice(&[0xFF, 0xC2, 0x00, 0x0B]);
        data.extend_from_slice(&[0x08, 0x00, 0x10, 0x00, 0x10, 0x01, 0x01, 0x11, 0x00]);
        let result = read_coefficients(&data);
        assert!(matches!(result, Err(DctError::Unsupported(_))));
    }

    #[test]
    fn incompatible_block_count_returns_error() {
        let jpeg = make_jpeg_gray(16, 16);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        // Remove one block to make it incompatible.
        coeffs.components[0].blocks.pop();
        let result = write_coefficients(&jpeg, &coeffs);
        assert!(matches!(result, Err(DctError::Incompatible(_))));
    }

    // ── Roundtrip identity tests ──────────────────────────────────────────────

    #[test]
    fn roundtrip_identity_gray() {
        let jpeg = make_jpeg_gray(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        // Re-encoding with unmodified coefficients must produce bit-identical output.
        assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_rgb() {
        let jpeg = make_jpeg_rgb(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
    }

    #[test]
    fn roundtrip_identity_non_square() {
        let jpeg = make_jpeg_rgb(48, 16);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded);
    }

    #[test]
    fn roundtrip_identity_partial_mcus() {
        // Dimensions that are not multiples of the MCU size force padded edge
        // blocks, which must also survive an unmodified roundtrip bit-for-bit.
        let jpeg = make_jpeg_rgb(17, 9);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
        assert_eq!(jpeg, reencoded);
    }

    // ── Modification survival test ────────────────────────────────────────────

    #[test]
    fn lsb_modification_survives_roundtrip() {
        let jpeg = make_jpeg_gray(32, 32);
        let mut coeffs = read_coefficients(&jpeg).unwrap();

        let mut modified_count = 0usize;
        for block in &mut coeffs.components[0].blocks {
            for coeff in block[1..].iter_mut() {
                if coeff.abs() >= 2 {
                    *coeff ^= 1;
                    modified_count += 1;
                }
            }
        }
        assert!(modified_count > 0, "test image had no eligible coefficients");

        let modified_jpeg = write_coefficients(&jpeg, &coeffs).unwrap();

        // Read back and verify the modifications are preserved.
        let coeffs2 = read_coefficients(&modified_jpeg).unwrap();
        assert_eq!(coeffs.components[0].blocks, coeffs2.components[0].blocks);
    }

    // ── block_count tests ─────────────────────────────────────────────────────

    #[test]
    fn block_count_gray_16x16() {
        let jpeg = make_jpeg_gray(16, 16);
        let counts = block_count(&jpeg).unwrap();
        // 16×16 / 8×8 = 4 blocks for the single Y component.
        assert_eq!(counts, vec![4]);
    }

    #[test]
    fn block_count_rgb_32x32() {
        let jpeg = make_jpeg_rgb(32, 32);
        let counts = block_count(&jpeg).unwrap();
        // For 4:2:0 subsampling: Y has 4×(2×2)=16 blocks, Cb/Cr have 4 each.
        // For 4:4:4: all three have 16 blocks.
        // Accept either — exact layout depends on the encoder.
        assert_eq!(counts.len(), 3);
        let total: usize = counts.iter().sum();
        assert!(total > 0);
    }

    // ── Category / magnitude function tests ───────────────────────────────────

    #[test]
    fn category_values() {
        assert_eq!(category(0), 0);
        assert_eq!(category(1), 1);
        assert_eq!(category(-1), 1);
        assert_eq!(category(2), 2);
        assert_eq!(category(3), 2);
        assert_eq!(category(4), 3);
        assert_eq!(category(127), 7);
        assert_eq!(category(-128), 8);
        assert_eq!(category(1023), 10);
        // 2^14 <= 32767 < 2^15, so the largest i16 needs exactly 15 bits.
        assert_eq!(category(i16::MAX), 15);
    }

    #[test]
    fn decode_magnitude_values() {
        assert_eq!(decode_magnitude(0, 0), 0);
        assert_eq!(decode_magnitude(1, 0), -1);
        assert_eq!(decode_magnitude(1, 1), 1);
        assert_eq!(decode_magnitude(2, 0), -3);
        assert_eq!(decode_magnitude(2, 1), -2);
        assert_eq!(decode_magnitude(2, 2), 2);
        assert_eq!(decode_magnitude(2, 3), 3);
        assert_eq!(decode_magnitude(11, 0), -2047);
        assert_eq!(decode_magnitude(11, 2047), 2047);
    }

    // ── Valid output test ─────────────────────────────────────────────────────

    #[test]
    fn output_is_valid_jpeg() {
        let jpeg = make_jpeg_rgb(24, 24);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        // Flip one LSB.
        if let Some(block) = coeffs.components[0].blocks.first_mut() {
            block[1] |= 1;
        }
        let out = write_coefficients(&jpeg, &coeffs).unwrap();
        // Check SOI and EOI markers.
        assert_eq!(&out[..2], &[0xFF, 0xD8], "missing SOI");
        assert_eq!(&out[out.len() - 2..], &[0xFF, 0xD9], "missing EOI");
    }

    // ── inspect() tests ───────────────────────────────────────────────────────

    #[test]
    fn inspect_gray_returns_correct_dimensions() {
        let jpeg = make_jpeg_gray(32, 16);
        let info = inspect(&jpeg).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 16);
        assert_eq!(info.components.len(), 1);
        assert_eq!(info.components[0].block_count, 8); // 4×2 blocks
    }

    #[test]
    fn inspect_rgb_returns_three_components() {
        let jpeg = make_jpeg_rgb(32, 32);
        let info = inspect(&jpeg).unwrap();
        assert_eq!(info.width, 32);
        assert_eq!(info.height, 32);
        assert_eq!(info.components.len(), 3);
        // Total blocks across components must be positive.
        let total: usize = info.components.iter().map(|c| c.block_count).sum();
        assert!(total > 0);
    }

    #[test]
    fn inspect_matches_block_count() {
        let jpeg = make_jpeg_rgb(48, 32);
        let info = inspect(&jpeg).unwrap();
        let counts = block_count(&jpeg).unwrap();
        let info_counts: Vec<usize> = info.components.iter().map(|c| c.block_count).collect();
        assert_eq!(info_counts, counts);
    }

    // ── eligible_ac_count tests ───────────────────────────────────────────────

    #[test]
    fn eligible_ac_count_is_positive() {
        let jpeg = make_jpeg_rgb(32, 32);
        let n = eligible_ac_count(&jpeg).unwrap();
        assert!(n > 0, "natural image should have eligible AC coefficients");
    }

    #[test]
    fn eligible_ac_count_method_matches_free_fn() {
        let jpeg = make_jpeg_gray(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let via_method = coeffs.eligible_ac_count();
        let via_fn = eligible_ac_count(&jpeg).unwrap();
        assert_eq!(via_method, via_fn);
    }

    #[test]
    fn eligible_ac_count_leq_total_ac_count() {
        let jpeg = make_jpeg_rgb(32, 32);
        let coeffs = read_coefficients(&jpeg).unwrap();
        let eligible = coeffs.eligible_ac_count();
        // 63 AC coefficients per block.
        let total_ac: usize = coeffs.components.iter().map(|c| c.blocks.len() * 63).sum();
        assert!(eligible <= total_ac);
    }

    // ── LUT Huffman decode correctness (regression for the old HashMap version) ─

    #[test]
    fn lut_decode_matches_modification_roundtrip() {
        // A natural image exercises many different Huffman code lengths.
        // If the LUT decode is wrong, the modification roundtrip will fail.
        let jpeg = make_jpeg_rgb(64, 64);
        let mut coeffs = read_coefficients(&jpeg).unwrap();
        let mut flipped = 0usize;
        for comp in &mut coeffs.components {
            for block in &mut comp.blocks {
                for coeff in block[1..].iter_mut() {
                    if coeff.abs() >= 2 {
                        *coeff ^= 1;
                        flipped += 1;
                    }
                }
            }
        }
        assert!(flipped > 0);
        let modified = write_coefficients(&jpeg, &coeffs).unwrap();
        let coeffs2 = read_coefficients(&modified).unwrap();
        assert_eq!(coeffs.components.len(), coeffs2.components.len());
        for (c1, c2) in coeffs.components.iter().zip(coeffs2.components.iter()) {
            assert_eq!(c1.blocks, c2.blocks);
        }
    }
}