Skip to main content

phasm_core/codec/jpeg/
mod.rs

1// Copyright (c) 2026 Christoph Gaffga
2// SPDX-License-Identifier: GPL-3.0-only
3// https://github.com/cgaffga/phasmcore
4
5//! Pure-Rust JPEG coefficient codec (zero external dependencies).
6//!
7//! Reads and writes baseline and progressive JPEG files, providing direct access
8//! to quantized DCT coefficients without any pixel-domain processing. This is
9//! the foundation for steganographic embedding, which operates entirely in
10//! the DCT domain.
11//!
12//! Supports:
13//! - Baseline sequential DCT (SOF0), 8-bit precision
14//! - Progressive DCT (SOF2) — read-only (always writes baseline)
15//! - YCbCr, grayscale, and arbitrary component counts
16//! - Chroma subsampling: 4:2:0, 4:2:2, 4:4:4
17//! - Restart markers (DRI/RST)
18//! - Byte-for-byte round-trip for unmodified baseline images
19//! - Huffman table rebuild for modified coefficients
20//!
21//! Does NOT support:
22//! - Arithmetic coding (SOF9+) -- rejected at parse time
23//! - 12-bit precision -- rejected at parse time
24
25pub mod error;
26pub mod zigzag;
27pub mod dct;
28pub mod bitio;
29pub mod tables;
30pub mod huffman;
31pub mod frame;
32pub mod marker;
33pub mod scan;
34pub mod pixels;
35
36use dct::DctGrid;
37use error::{JpegError, Result};
38use frame::FrameInfo;
39use huffman::encode_value;
40use marker::{MarkerSegment, iterate_markers, iterate_markers_all, parse_sos, parse_sos_params, parse_dri};
41use scan::ScanComponent;
42use tables::{HuffmanSpec, parse_dqt, parse_dht};
43use zigzag::NATURAL_TO_ZIGZAG;
44
/// A decoded JPEG image providing access to quantized DCT coefficients.
///
/// Created by parsing a JPEG byte stream with [`JpegImage::from_bytes`].
/// After modifying DCT coefficients (e.g., for steganographic embedding),
/// call [`JpegImage::to_bytes`] to re-encode. If coefficient modifications
/// introduce symbols not present in the original Huffman tables, call
/// [`JpegImage::rebuild_huffman_tables`] first.
///
/// All table arrays have exactly four slots; a slot is `None` when no table
/// with that ID has been defined.
#[derive(Clone)]
pub struct JpegImage {
    /// Frame information (dimensions, components, sampling factors).
    frame: FrameInfo,
    /// DCT coefficient grids, one per component in scan order.
    grids: Vec<DctGrid>,
    /// Quantization tables, indexed by table ID (0–3).
    quant_tables: [Option<dct::QuantTable>; 4],
    /// DC Huffman table specs, indexed by table ID (0–3).
    dc_huff_specs: [Option<HuffmanSpec>; 4],
    /// AC Huffman table specs, indexed by table ID (0–3).
    ac_huff_specs: [Option<HuffmanSpec>; 4],
    /// Scan component selectors (component index + table IDs).
    scan_components: Vec<ScanComponent>,
    /// Restart interval (0 = no restarts).
    restart_interval: u16,
    /// Raw marker segments in original order (for header preservation).
    /// Includes all markers between SOI and SOS (exclusive) except SOI itself.
    /// Each segment stores only its payload; the `0xFF` prefix, marker byte and
    /// two-byte length are regenerated when the image is serialized.
    raw_segments: Vec<MarkerSegment>,
    /// Raw SOS header data (for exact reconstruction).
    /// This is the SOS payload only, without the marker and length bytes.
    sos_data: Vec<u8>,
}
74
75impl JpegImage {
76    /// Parse a JPEG file from bytes.
77    ///
78    /// Supports both baseline (SOF0) and progressive (SOF2) JPEG.
79    /// Progressive images are decoded by accumulating all scans, then the
80    /// coefficients are stored exactly as in baseline — `to_bytes()` always
81    /// writes baseline output.
82    pub fn from_bytes(data: &[u8]) -> Result<Self> {
83        // First pass: quick check if this is progressive by scanning for SOF2
84        // We use iterate_markers_all which handles multiple SOS markers.
85        let is_progressive = Self::check_progressive(data);
86
87        if is_progressive {
88            Self::from_bytes_progressive(data)
89        } else {
90            Self::from_bytes_baseline(data)
91        }
92    }
93
94    /// Quick check: does this JPEG contain a SOF2 marker?
95    fn check_progressive(data: &[u8]) -> bool {
96        // Scan for 0xFF 0xC2 in the header area (before any SOS)
97        let mut pos = 2; // skip SOI
98        while pos + 1 < data.len() {
99            if data[pos] == 0xFF {
100                let m = data[pos + 1];
101                if m == marker::SOF2 {
102                    return true;
103                }
104                if m == marker::SOS {
105                    return false; // Reached scan data without finding SOF2
106                }
107                if m == 0x00 || m == 0xFF || (0xD0..=0xD7).contains(&m) || m == marker::SOI || m == marker::EOI {
108                    pos += 2;
109                    continue;
110                }
111                // Skip segment
112                if pos + 3 < data.len() {
113                    let len = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
114                    if len < 2 || pos + 2 + len > data.len() {
115                        break; // Malformed segment length
116                    }
117                    pos += 2 + len;
118                } else {
119                    break;
120                }
121            } else {
122                pos += 1;
123            }
124        }
125        false
126    }
127
128    /// Parse a baseline (SOF0) JPEG file.
129    fn from_bytes_baseline(data: &[u8]) -> Result<Self> {
130        let (entries, scan_start) = iterate_markers(data)?;
131
132        let mut frame_info: Option<FrameInfo> = None;
133        let mut quant_tables: [Option<dct::QuantTable>; 4] = [None, None, None, None];
134        let mut dc_huff_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
135        let mut ac_huff_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
136        let mut restart_interval: u16 = 0;
137        let mut raw_segments = Vec::new();
138        let mut sos_data = Vec::new();
139        let mut scan_components = Vec::new();
140
141        for entry in &entries {
142            match entry.marker {
143                marker::SOI => {}
144                marker::EOI => {}
145                marker::DQT => {
146                    raw_segments.push(MarkerSegment {
147                        marker: entry.marker,
148                        data: entry.data.clone(),
149                    });
150                    let tables = parse_dqt(&entry.data)?;
151                    for (id, qt) in tables {
152                        quant_tables[id as usize] = Some(qt);
153                    }
154                }
155                marker::DHT => {
156                    raw_segments.push(MarkerSegment {
157                        marker: entry.marker,
158                        data: entry.data.clone(),
159                    });
160                    let specs = parse_dht(&entry.data)?;
161                    for spec in specs {
162                        let id = spec.id as usize;
163                        if spec.class == 0 {
164                            dc_huff_specs[id] = Some(spec);
165                        } else {
166                            ac_huff_specs[id] = Some(spec);
167                        }
168                    }
169                }
170                marker::SOF0 => {
171                    raw_segments.push(MarkerSegment {
172                        marker: entry.marker,
173                        data: entry.data.clone(),
174                    });
175                    frame_info = Some(frame::parse_sof(&entry.data)?);
176                }
177                marker::DRI => {
178                    raw_segments.push(MarkerSegment {
179                        marker: entry.marker,
180                        data: entry.data.clone(),
181                    });
182                    restart_interval = parse_dri(&entry.data)?;
183                }
184                marker::SOS => {
185                    sos_data = entry.data.clone();
186                    let selectors = parse_sos(&entry.data)?;
187                    let fi = frame_info
188                        .as_ref()
189                        .ok_or(JpegError::InvalidMarkerData("SOS before SOF"))?;
190
191                    for (comp_id, dc_id, ac_id) in selectors {
192                        let comp_idx = fi
193                            .components
194                            .iter()
195                            .position(|c| c.id == comp_id)
196                            .ok_or(JpegError::UnknownComponentId(comp_id))?;
197                        scan_components.push(ScanComponent {
198                            comp_idx,
199                            dc_table: dc_id as usize,
200                            ac_table: ac_id as usize,
201                        });
202                    }
203                }
204                _ => {
205                    raw_segments.push(MarkerSegment {
206                        marker: entry.marker,
207                        data: entry.data.clone(),
208                    });
209                }
210            }
211        }
212
213        let fi = frame_info.ok_or(JpegError::InvalidMarkerData("no SOF marker found"))?;
214
215        let (grids, _end_pos) = scan::decode_scan(
216            data,
217            scan_start,
218            &fi,
219            &scan_components,
220            &dc_huff_specs,
221            &ac_huff_specs,
222            restart_interval,
223        )?;
224
225        Ok(Self {
226            frame: fi,
227            grids,
228            quant_tables,
229            dc_huff_specs,
230            ac_huff_specs,
231            scan_components,
232            restart_interval,
233            raw_segments,
234            sos_data,
235        })
236    }
237
238    /// Parse a progressive (SOF2) JPEG file.
239    ///
240    /// Progressive JPEG files have multiple SOS markers, each contributing
241    /// partial coefficient data. We accumulate all scans into DctGrids,
242    /// then store the result as if it were a baseline image.
243    fn from_bytes_progressive(data: &[u8]) -> Result<Self> {
244        let (entries, scan_starts) = iterate_markers_all(data)?;
245
246        let mut frame_info: Option<FrameInfo> = None;
247        let mut quant_tables: [Option<dct::QuantTable>; 4] = [None, None, None, None];
248        let mut dc_huff_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
249        let mut ac_huff_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
250        let mut restart_interval: u16 = 0;
251        let mut raw_segments = Vec::new();
252
253        // Collect all SOS entries with their scan start positions
254        struct ScanInfo {
255            components: Vec<ScanComponent>,
256            params: marker::SosParams,
257            scan_start: usize,
258            #[allow(dead_code)]
259            sos_data: Vec<u8>,
260        }
261        let mut scans: Vec<ScanInfo> = Vec::new();
262        let mut sos_index = 0usize;
263
264        for entry in &entries {
265            match entry.marker {
266                marker::SOI => {}
267                marker::EOI => {}
268                marker::DQT => {
269                    // Only preserve DQT/DRI in raw_segments (first occurrence)
270                    raw_segments.push(MarkerSegment {
271                        marker: entry.marker,
272                        data: entry.data.clone(),
273                    });
274                    let tables = parse_dqt(&entry.data)?;
275                    for (id, qt) in tables {
276                        quant_tables[id as usize] = Some(qt);
277                    }
278                }
279                marker::DHT => {
280                    // For progressive, DHT markers can appear between scans.
281                    // We accumulate all Huffman tables (later tables override earlier ones
282                    // with the same ID, which is the correct behavior).
283                    // Don't preserve DHT in raw_segments — we'll rebuild them.
284                    let specs = parse_dht(&entry.data)?;
285                    for spec in specs {
286                        let id = spec.id as usize;
287                        if spec.class == 0 {
288                            dc_huff_specs[id] = Some(spec);
289                        } else {
290                            ac_huff_specs[id] = Some(spec);
291                        }
292                    }
293                }
294                marker::SOF2 => {
295                    raw_segments.push(MarkerSegment {
296                        // Store as SOF0 for baseline output
297                        marker: marker::SOF0,
298                        data: entry.data.clone(),
299                    });
300                    frame_info = Some(frame::parse_sof_ext(&entry.data, true)?);
301                }
302                marker::SOF0 => {
303                    raw_segments.push(MarkerSegment {
304                        marker: entry.marker,
305                        data: entry.data.clone(),
306                    });
307                    frame_info = Some(frame::parse_sof(&entry.data)?);
308                }
309                marker::DRI => {
310                    raw_segments.push(MarkerSegment {
311                        marker: entry.marker,
312                        data: entry.data.clone(),
313                    });
314                    restart_interval = parse_dri(&entry.data)?;
315                }
316                marker::SOS => {
317                    let selectors = parse_sos(&entry.data)?;
318                    let params = parse_sos_params(&entry.data)?;
319                    let fi = frame_info
320                        .as_ref()
321                        .ok_or(JpegError::InvalidMarkerData("SOS before SOF"))?;
322
323                    let mut components = Vec::new();
324                    for (comp_id, dc_id, ac_id) in selectors {
325                        let comp_idx = fi
326                            .components
327                            .iter()
328                            .position(|c| c.id == comp_id)
329                            .ok_or(JpegError::UnknownComponentId(comp_id))?;
330                        components.push(ScanComponent {
331                            comp_idx,
332                            dc_table: dc_id as usize,
333                            ac_table: ac_id as usize,
334                        });
335                    }
336
337                    if sos_index < scan_starts.len() {
338                        scans.push(ScanInfo {
339                            components,
340                            params,
341                            scan_start: scan_starts[sos_index],
342                            sos_data: entry.data.clone(),
343                        });
344                        sos_index += 1;
345                    }
346                }
347                _ => {
348                    raw_segments.push(MarkerSegment {
349                        marker: entry.marker,
350                        data: entry.data.clone(),
351                    });
352                }
353            }
354        }
355
356        let fi = frame_info.ok_or(JpegError::InvalidMarkerData("no SOF marker found"))?;
357
358        // Allocate DctGrids for all components (initialized to zero)
359        let mut grids: Vec<DctGrid> = Vec::with_capacity(fi.components.len());
360        for comp_idx in 0..fi.components.len() {
361            let bw = fi.blocks_wide(comp_idx);
362            let bt = fi.blocks_tall(comp_idx);
363            grids.push(DctGrid::new(bw, bt));
364        }
365
366        // Snapshot the Huffman specs before processing scans, since progressive
367        // JPEG can define new DHT tables between scans.
368        // We already accumulated all DHTs above, which works for most files.
369        // However, some encoders define DHT tables incrementally before each scan.
370        // To handle this correctly, we need to re-parse DHTs in scan order.
371        // Let's re-parse by walking entries again, updating specs as we go.
372        let mut scan_dc_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
373        let mut scan_ac_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
374        let mut scan_idx = 0usize;
375
376        for entry in &entries {
377            match entry.marker {
378                marker::DHT => {
379                    let specs = parse_dht(&entry.data)?;
380                    for spec in specs {
381                        let id = spec.id as usize;
382                        if spec.class == 0 {
383                            scan_dc_specs[id] = Some(spec);
384                        } else {
385                            scan_ac_specs[id] = Some(spec);
386                        }
387                    }
388                }
389                marker::SOS => {
390                    if scan_idx < scans.len() {
391                        let scan = &scans[scan_idx];
392                        scan::decode_progressive_scan(
393                            data,
394                            scan.scan_start,
395                            &fi,
396                            &scan.components,
397                            &scan_dc_specs,
398                            &scan_ac_specs,
399                            restart_interval,
400                            &scan.params,
401                            &mut grids,
402                        )?;
403                        scan_idx += 1;
404                    }
405                }
406                _ => {}
407            }
408        }
409
410        // For baseline output, we need to build a single SOS header.
411        // Use all components in component order, with table IDs from the
412        // frame components (standard convention: luma=table 0, chroma=table 1).
413        let mut final_scan_components = Vec::new();
414        let mut final_sos_data = Vec::new();
415        final_sos_data.push(fi.components.len() as u8);
416
417        for (comp_idx, comp) in fi.components.iter().enumerate() {
418            // Use table ID 0 for luminance (first component), 1 for chrominance
419            let table_id = if comp_idx == 0 { 0usize } else { 1usize };
420            final_scan_components.push(ScanComponent {
421                comp_idx,
422                dc_table: table_id,
423                ac_table: table_id,
424            });
425            final_sos_data.push(comp.id);
426            final_sos_data.push(((table_id as u8) << 4) | (table_id as u8));
427        }
428        // Append baseline SOS parameters: Ss=0, Se=63, Ah=0, Al=0
429        final_sos_data.push(0);  // Ss
430        final_sos_data.push(63); // Se
431        final_sos_data.push(0);  // Ah=0, Al=0
432
433        // Build a minimal but complete set of baseline Huffman tables.
434        // We set the specs to None first, then rebuild from the coefficient data.
435        // This ensures the tables match the actual coefficient values.
436        let final_dc_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
437        let final_ac_specs: [Option<HuffmanSpec>; 4] = [None, None, None, None];
438
439        // Create the image with placeholder Huffman tables, then rebuild them.
440        let mut img = Self {
441            frame: FrameInfo { is_progressive: false, ..fi },
442            grids,
443            quant_tables,
444            dc_huff_specs: final_dc_specs,
445            ac_huff_specs: final_ac_specs,
446            scan_components: final_scan_components,
447            restart_interval,
448            raw_segments,
449            sos_data: final_sos_data,
450        };
451
452        // Rebuild Huffman tables from the actual coefficient data so they
453        // encode correctly as baseline. This also inserts DHT segments into
454        // raw_segments.
455        img.rebuild_huffman_tables();
456
457        Ok(img)
458    }
459
    /// Encode the (possibly modified) image back to JPEG bytes.
    ///
    /// Equivalent to calling [`JpegImage::to_bytes_with_progress`] without a
    /// progress callback.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`JpegImage::to_bytes_with_progress`] reports.
    pub fn to_bytes(&self) -> Result<Vec<u8>> {
        self.to_bytes_with_progress(None)
    }
464
465    /// Serialize this JPEG image to bytes, with an optional progress callback
466    /// that fires approximately [`scan::JPEG_WRITE_STEPS`] times during scan
467    /// encoding.
468    pub fn to_bytes_with_progress(&self, on_progress: Option<&dyn Fn()>) -> Result<Vec<u8>> {
469        let mut out = Vec::new();
470
471        // SOI
472        out.push(0xFF);
473        out.push(marker::SOI);
474
475        // Write all preserved header segments in original order
476        for seg in &self.raw_segments {
477            out.push(0xFF);
478            out.push(seg.marker);
479            let length = (seg.data.len() + 2) as u16;
480            out.push((length >> 8) as u8);
481            out.push(length as u8);
482            out.extend_from_slice(&seg.data);
483        }
484
485        // Write SOS header
486        out.push(0xFF);
487        out.push(marker::SOS);
488        let sos_length = (self.sos_data.len() + 2) as u16;
489        out.push((sos_length >> 8) as u8);
490        out.push(sos_length as u8);
491        out.extend_from_slice(&self.sos_data);
492
493        // Re-encode scan data
494        let scan_bytes = scan::encode_scan_with_progress(
495            &self.frame,
496            &self.scan_components,
497            &self.grids,
498            &self.dc_huff_specs,
499            &self.ac_huff_specs,
500            self.restart_interval,
501            on_progress,
502        )?;
503        out.extend_from_slice(&scan_bytes);
504
505        // EOI
506        out.push(0xFF);
507        out.push(marker::EOI);
508
509        Ok(out)
510    }
511
    /// Get a reference to the DCT coefficient grid for a component.
    ///
    /// Component index is in scan order (typically 0=Y, 1=Cb, 2=Cr).
    ///
    /// # Panics
    ///
    /// Panics if `component` is not smaller than the number of grids (see
    /// [`JpegImage::num_components`]).
    pub fn dct_grid(&self, component: usize) -> &DctGrid {
        &self.grids[component]
    }
517
    /// Get a mutable reference to the DCT coefficient grid for a component.
    ///
    /// The grid is indexed the same way as in [`JpegImage::dct_grid`].
    ///
    /// # Panics
    ///
    /// Panics if `component` is not smaller than the number of grids (see
    /// [`JpegImage::num_components`]).
    pub fn dct_grid_mut(&mut self, component: usize) -> &mut DctGrid {
        &mut self.grids[component]
    }
522
    /// Get the frame information.
    ///
    /// Returns a shared reference to the stored [`FrameInfo`].
    pub fn frame_info(&self) -> &FrameInfo {
        &self.frame
    }
527
528    /// Get a quantization table by ID.
529    pub fn quant_table(&self, id: usize) -> Option<&dct::QuantTable> {
530        self.quant_tables[id].as_ref()
531    }
532
    /// Number of components in the scan.
    ///
    /// This is the number of DCT coefficient grids held by the image, i.e. the
    /// exclusive upper bound for the index accepted by
    /// [`JpegImage::dct_grid`] and [`JpegImage::dct_grid_mut`].
    pub fn num_components(&self) -> usize {
        self.grids.len()
    }
537
538    /// Rebuild Huffman tables from the current coefficient data.
539    ///
540    /// Call this after modifying DCT coefficients to ensure the Huffman tables
541    /// can encode all symbols present in the modified data. This replaces the
542    /// DHT segments in `raw_segments` and updates `dc_huff_specs`/`ac_huff_specs`.
543    pub fn rebuild_huffman_tables(&mut self) {
544        // Collect symbol frequencies per table.
545        let mut dc_freq: [Vec<u32>; 4] = [vec![], vec![], vec![], vec![]];
546        let mut ac_freq: [Vec<u32>; 4] = [vec![], vec![], vec![], vec![]];
547
548        for sc in &self.scan_components {
549            if dc_freq[sc.dc_table].is_empty() {
550                dc_freq[sc.dc_table] = vec![0u32; 256];
551            }
552            if ac_freq[sc.ac_table].is_empty() {
553                ac_freq[sc.ac_table] = vec![0u32; 256];
554            }
555        }
556
557        // Count symbols by simulating the scan encoding.
558        // Must match encode_scan exactly, including restart interval DC pred resets.
559        let mut dc_pred = vec![0i16; self.scan_components.len()];
560        let mut mcu_count = 0usize;
561
562        for mcu_row in 0..self.frame.mcus_tall as usize {
563            for mcu_col in 0..self.frame.mcus_wide as usize {
564                // Reset DC predictors at restart boundaries (must match encode_scan)
565                if self.restart_interval > 0
566                    && mcu_count > 0
567                    && mcu_count.is_multiple_of(self.restart_interval as usize)
568                {
569                    dc_pred.fill(0);
570                }
571
572                for (sci, sc) in self.scan_components.iter().enumerate() {
573                    let comp = &self.frame.components[sc.comp_idx];
574                    for v in 0..comp.v_sampling as usize {
575                        for h in 0..comp.h_sampling as usize {
576                            let br = mcu_row * comp.v_sampling as usize + v;
577                            let bc = mcu_col * comp.h_sampling as usize + h;
578                            let block = self.grids[sci].block(br, bc);
579                            let mut zz = [0i16; 64];
580                            for ni in 0..64 {
581                                zz[NATURAL_TO_ZIGZAG[ni]] = block[ni];
582                            }
583
584                            // DC symbol
585                            let dc_diff = zz[0] - dc_pred[sci];
586                            dc_pred[sci] = zz[0];
587                            let (_, dc_size) = encode_value(dc_diff);
588                            dc_freq[sc.dc_table][dc_size as usize] += 1;
589
590                            // AC symbols
591                            let mut k = 1;
592                            while k < 64 {
593                                let mut run = 0usize;
594                                while k + run < 64 && zz[k + run] == 0 {
595                                    run += 1;
596                                }
597                                if k + run >= 64 {
598                                    // EOB
599                                    ac_freq[sc.ac_table][0x00] += 1;
600                                    break;
601                                }
602                                while run >= 16 {
603                                    ac_freq[sc.ac_table][0xF0] += 1;
604                                    run -= 16;
605                                    k += 16;
606                                }
607                                k += run;
608                                let (_, ac_size) = encode_value(zz[k]);
609                                let rs = ((run as u8) << 4) | ac_size;
610                                ac_freq[sc.ac_table][rs as usize] += 1;
611                                k += 1;
612                            }
613                        }
614                    }
615                }
616
617                mcu_count += 1;
618            }
619        }
620
621        // Build Huffman specs from frequency counts and update state.
622        for (id, freq) in dc_freq.iter().enumerate() {
623            if freq.is_empty() {
624                continue;
625            }
626            let spec = build_huffman_spec(0, id as u8, freq);
627            self.dc_huff_specs[id] = Some(spec);
628        }
629        for (id, freq) in ac_freq.iter().enumerate() {
630            if freq.is_empty() {
631                continue;
632            }
633            let spec = build_huffman_spec(1, id as u8, freq);
634            self.ac_huff_specs[id] = Some(spec);
635        }
636
637        // Replace DHT segments in raw_segments.
638        self.raw_segments.retain(|s| s.marker != marker::DHT);
639
640        // Find the position just before SOF0 to insert DHT segments.
641        let sof_pos = self
642            .raw_segments
643            .iter()
644            .position(|s| s.marker == marker::SOF0)
645            .unwrap_or(self.raw_segments.len());
646
647        // Build new DHT data: combine all tables into one segment.
648        let mut dht_data = Vec::new();
649        for id in 0..4 {
650            if let Some(spec) = &self.dc_huff_specs[id] {
651                dht_data.push((spec.class << 4) | (spec.id & 0x0F));
652                dht_data.extend_from_slice(&spec.bits);
653                dht_data.extend_from_slice(&spec.huffval);
654            }
655        }
656        for id in 0..4 {
657            if let Some(spec) = &self.ac_huff_specs[id] {
658                dht_data.push((spec.class << 4) | (spec.id & 0x0F));
659                dht_data.extend_from_slice(&spec.bits);
660                dht_data.extend_from_slice(&spec.huffval);
661            }
662        }
663
664        self.raw_segments.insert(
665            sof_pos,
666            MarkerSegment {
667                marker: marker::DHT,
668                data: dht_data,
669            },
670        );
671    }
672
    /// Replace a quantization table by ID and rebuild the DQT marker segments.
    ///
    /// Call this after modifying DCT coefficients to reflect new quantization
    /// (e.g., for recompression simulation). Updates both the internal table
    /// and the raw DQT segments so that `to_bytes()` produces correct output.
    ///
    /// # Panics
    ///
    /// Panics if `id` is 4 or greater, since only four table slots exist.
    pub fn set_quant_table(&mut self, id: usize, qt: dct::QuantTable) {
        self.quant_tables[id] = Some(qt);
        // Keep the serialized DQT segment in sync with the in-memory tables.
        self.rebuild_dqt_segments();
    }
682
683    /// Rebuild DQT marker segments from internal quantization table state.
684    ///
685    /// Removes all existing DQT entries from `raw_segments` and inserts fresh
686    /// ones before the SOF0 marker (matching the standard JPEG header order).
687    fn rebuild_dqt_segments(&mut self) {
688        use zigzag::ZIGZAG_TO_NATURAL;
689
690        // Remove old DQT segments.
691        self.raw_segments.retain(|s| s.marker != marker::DQT);
692
693        // Build new DQT data: one segment containing all defined tables.
694        // DQT stores values in zigzag order. Our internal tables are in
695        // natural (row-major) order. For each zigzag index zi, we need
696        // the natural index: ni = ZIGZAG_TO_NATURAL[zi].
697        let mut dqt_data = Vec::new();
698        for id in 0..4u8 {
699            if let Some(qt) = &self.quant_tables[id as usize] {
700                // precision_and_id: precision=0 (8-bit) for values ≤255
701                let precision: u8 = if qt.values.iter().all(|&v| v <= 255) { 0 } else { 1 };
702                dqt_data.push((precision << 4) | id);
703                for zi in 0..64 {
704                    let ni = ZIGZAG_TO_NATURAL[zi];
705                    if precision == 0 {
706                        dqt_data.push(qt.values[ni] as u8);
707                    } else {
708                        dqt_data.extend_from_slice(&qt.values[ni].to_be_bytes());
709                    }
710                }
711            }
712        }
713
714        // Insert before SOF0 (same position strategy as DHT rebuild).
715        let sof_pos = self
716            .raw_segments
717            .iter()
718            .position(|s| s.marker == marker::SOF0)
719            .unwrap_or(self.raw_segments.len());
720
721        self.raw_segments.insert(
722            sof_pos,
723            MarkerSegment {
724                marker: marker::DQT,
725                data: dqt_data,
726            },
727        );
728    }
729}
730
731/// Build an optimal Huffman spec from symbol frequency counts.
732///
733/// Implements JPEG Annex K (Figures K.1–K.4) with the libjpeg pseudo-symbol
734/// technique: a dummy symbol 256 with frequency 1 is added before tree
735/// construction. This guarantees:
736/// - No real symbol gets the all-ones codeword.
737/// - The Kraft inequality is strictly satisfied after code-length limiting.
738/// - Output tables are fully compatible with libjpeg/libjpeg-turbo.
739fn build_huffman_spec(class: u8, id: u8, freq: &[u32]) -> HuffmanSpec {
740    // Collect symbols with nonzero frequency (u16 to accommodate pseudo-symbol 256).
741    let mut symbols: Vec<(u16, u32)> = freq
742        .iter()
743        .enumerate()
744        .filter(|&(_, &f)| f > 0)
745        .map(|(sym, &f)| (sym as u16, f))
746        .collect();
747
748    if symbols.is_empty() {
749        // Need at least one symbol. Use symbol 0 (EOB for AC, size-0 for DC).
750        symbols.push((0, 1));
751    }
752
753    // If only one real symbol, we still need a valid Huffman code (1-bit code).
754    if symbols.len() == 1 {
755        let sym = symbols[0].0 as u8;
756        return HuffmanSpec {
757            class,
758            id,
759            bits: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
760            huffval: vec![sym],
761        };
762    }
763
764    // Add pseudo-symbol 256 with frequency 1 (libjpeg technique).
765    // This symbol will get the longest code, preventing any real symbol from
766    // receiving the all-ones codeword and providing a Kraft inequality safety
767    // margin after Annex K.3 code-length limiting.
768    symbols.push((256, 1));
769
770    let n = symbols.len(); // includes pseudo-symbol
771
772    // Sort ascending by (frequency, symbol) for the tree-building merge.
773    // Higher symbol number breaks ties → pseudo-symbol 256 sorts last among
774    // freq-1 symbols, ensuring it gets the longest code.
775    symbols.sort_by_key(|&(sym, f)| (f, sym));
776
777    // Build Huffman tree using two-queue merge (standard algorithm).
778    let total_nodes = 2 * n - 1;
779    let mut parent = vec![0usize; total_nodes];
780    let mut next_internal = n;
781
782    let mut q1: std::collections::VecDeque<(u64, usize)> = symbols
783        .iter()
784        .enumerate()
785        .map(|(idx, &(_, f))| (f as u64, idx))
786        .collect();
787    let mut q2: std::collections::VecDeque<(u64, usize)> = std::collections::VecDeque::new();
788
789    let pick_min = |q1: &mut std::collections::VecDeque<(u64, usize)>,
790                    q2: &mut std::collections::VecDeque<(u64, usize)>|
791     -> (u64, usize) {
792        match (q1.front(), q2.front()) {
793            (Some(&a), Some(&b)) => {
794                if a.0 <= b.0 {
795                    q1.pop_front().unwrap()
796                } else {
797                    q2.pop_front().unwrap()
798                }
799            }
800            (Some(_), None) => q1.pop_front().unwrap(),
801            (None, Some(_)) => q2.pop_front().unwrap(),
802            (None, None) => unreachable!(),
803        }
804    };
805
806    for _ in 0..(n - 1) {
807        let (f1, idx1) = pick_min(&mut q1, &mut q2);
808        let (f2, idx2) = pick_min(&mut q1, &mut q2);
809        parent[idx1] = next_internal;
810        parent[idx2] = next_internal;
811        q2.push_back((f1 + f2, next_internal));
812        next_internal += 1;
813    }
814
815    // Compute code lengths by walking from each leaf to the root.
816    let root = total_nodes - 1;
817    let mut code_lengths = vec![0u8; n];
818    for i in 0..n {
819        let mut depth = 0u8;
820        let mut node = i;
821        while node != root {
822            node = parent[node];
823            depth += 1;
824        }
825        code_lengths[i] = depth;
826    }
827
828    // Limit code lengths to 16 bits (JPEG Annex K.3 — Adjust_BITS procedure).
829    let max_len = code_lengths.iter().copied().max().unwrap_or(0) as usize;
830
831    let mut bits_count = vec![0u32; max_len + 1];
832    for &len in &code_lengths {
833        bits_count[len as usize] += 1;
834    }
835
836    if max_len > 16 {
837        let mut i = max_len;
838        while i > 16 {
839            while bits_count[i] > 0 {
840                // Find a donor level j (j <= i-2) that has codes to split.
841                let mut j = i - 2;
842                while j > 0 && bits_count[j] == 0 {
843                    j -= 1;
844                }
845                debug_assert!(j > 0, "Annex K.3: no donor found (pseudo-symbol should prevent this)");
846                if j == 0 {
847                    // Safety fallback (should never happen with pseudo-symbol).
848                    bits_count[16] += bits_count[i];
849                    bits_count[i] = 0;
850                    break;
851                }
852                bits_count[i] -= 2;
853                bits_count[i - 1] += 1;
854                bits_count[j + 1] += 2;
855                bits_count[j] -= 1;
856            }
857            i -= 1;
858        }
859
860        // Reassign code_lengths from the adjusted bits_count[].
861        // Longest codes go to least-frequent symbols (lowest indices).
862        let mut pos = 0;
863        for len in (1..=16u8).rev() {
864            let count = bits_count[len as usize] as usize;
865            for _ in 0..count {
866                code_lengths[pos] = len;
867                pos += 1;
868            }
869        }
870    }
871
872    // Build bits[] and huffval[] arrays, excluding pseudo-symbol 256.
873    // Sort by (code_length, symbol_value) for canonical Huffman ordering.
874    let mut sym_len: Vec<(u16, u8)> = symbols
875        .iter()
876        .zip(code_lengths.iter())
877        .map(|(&(sym, _), &len)| (sym, len))
878        .collect();
879    sym_len.sort_by_key(|&(sym, len)| (len, sym));
880
881    let mut bits = [0u8; 16];
882    let mut huffval = Vec::with_capacity(n);
883    for &(sym, len) in &sym_len {
884        // Skip pseudo-symbol 256 — it served its purpose in tree construction.
885        if sym == 256 {
886            continue;
887        }
888        if len > 0 && len <= 16 {
889            bits[(len - 1) as usize] += 1;
890            huffval.push(sym as u8);
891        }
892    }
893
894    HuffmanSpec {
895        class,
896        id,
897        bits,
898        huffval,
899    }
900}