Skip to main content

phasm_core/codec/jpeg/
scan.rs

1// Copyright (c) 2026 Christoph Gaffga
2// SPDX-License-Identifier: GPL-3.0-only
3// https://github.com/cgaffga/phasmcore
4
5//! JPEG scan data encoding and decoding.
6//!
7//! Decodes entropy-coded scan data into [`DctGrid`]s (one per component)
8//! and re-encodes modified grids back to entropy-coded bytes. Handles
9//! interleaved MCU ordering, restart markers, and DC prediction.
10
11use super::bitio::{BitReader, BitWriter};
12use super::dct::DctGrid;
13use super::error::{JpegError, Result};
14use super::frame::FrameInfo;
15use super::huffman::{
16    encode_value, extend_sign, HuffmanDecodeTable, HuffmanEncodeTable,
17};
18use super::marker::SosParams;
19use super::tables::HuffmanSpec;
20use super::zigzag::{NATURAL_TO_ZIGZAG, ZIGZAG_TO_NATURAL};
21
/// Component selector for one scan component, as read from an SOS header.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScanComponent {
    /// Index into `FrameInfo.components`.
    pub comp_idx: usize,
    /// DC Huffman table index (table ID, expected 0..=3).
    pub dc_table: usize,
    /// AC Huffman table index (table ID, expected 0..=3).
    pub ac_table: usize,
}
32
33/// Decode the entropy-coded scan data into DctGrids.
34///
35/// - `data`: full JPEG file bytes
36/// - `scan_start`: byte offset of the first entropy-coded byte (right after SOS header)
37/// - `frame`: parsed frame info
38/// - `scan_components`: component selectors from SOS
39/// - `dc_specs`/`ac_specs`: Huffman table specs, indexed by table ID
40/// - `restart_interval`: from DRI marker, 0 = no restarts
41///
42/// Returns (grids, end_position). `end_position` is the byte after the last scan byte.
43pub fn decode_scan(
44    data: &[u8],
45    scan_start: usize,
46    frame: &FrameInfo,
47    scan_components: &[ScanComponent],
48    dc_specs: &[Option<HuffmanSpec>; 4],
49    ac_specs: &[Option<HuffmanSpec>; 4],
50    restart_interval: u16,
51) -> Result<(Vec<DctGrid>, usize)> {
52    // Build Huffman decode tables
53    let mut dc_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
54    let mut ac_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
55
56    for sc in scan_components {
57        if dc_tables[sc.dc_table].is_none() {
58            let spec = dc_specs[sc.dc_table]
59                .as_ref()
60                .ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
61            dc_tables[sc.dc_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
62        }
63        if ac_tables[sc.ac_table].is_none() {
64            let spec = ac_specs[sc.ac_table]
65                .as_ref()
66                .ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
67            ac_tables[sc.ac_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
68        }
69    }
70
71    // Allocate DctGrids
72    let mut grids: Vec<DctGrid> = Vec::with_capacity(scan_components.len());
73    for sc in scan_components {
74        let bw = frame.blocks_wide(sc.comp_idx);
75        let bt = frame.blocks_tall(sc.comp_idx);
76        grids.push(DctGrid::new(bw, bt));
77    }
78
79    // Initialize DC predictors (i32 to prevent overflow from accumulated diffs)
80    let mut dc_pred = vec![0i32; scan_components.len()];
81
82    let mut reader = BitReader::new(data, scan_start);
83    let mut mcu_count = 0usize;
84
85    for mcu_row in 0..frame.mcus_tall as usize {
86        for mcu_col in 0..frame.mcus_wide as usize {
87            // Check for restart
88            if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval as usize) {
89                reader.byte_align();
90                // Look for RST marker — accept any RST without strict sequence
91                // validation, matching libjpeg/libjpeg-turbo behavior.
92                let _rst = reader.check_restart_marker()?;
93                // Reset DC predictors
94                dc_pred.fill(0);
95            }
96
97            // Decode blocks for each component in this MCU
98            for (sci, sc) in scan_components.iter().enumerate() {
99                let comp = &frame.components[sc.comp_idx];
100                let dc_tab = dc_tables[sc.dc_table].as_ref().unwrap();
101                let ac_tab = ac_tables[sc.ac_table].as_ref().unwrap();
102
103                for v in 0..comp.v_sampling as usize {
104                    for h in 0..comp.h_sampling as usize {
105                        let block_row = mcu_row * (comp.v_sampling as usize) + v;
106                        let block_col = mcu_col * (comp.h_sampling as usize) + h;
107
108                        // Bounds check: skip blocks outside the grid (malformed JPEGs)
109                        let blocks_tall = grids[sci].blocks_tall();
110                        let blocks_wide = grids[sci].blocks_wide();
111                        if block_row >= blocks_tall || block_col >= blocks_wide {
112                            // Still need to consume the entropy-coded data for this block
113                            // to keep the bitstream in sync, so decode but discard.
114                            let dc_size = dc_tab.decode(&mut reader)?;
115                            if dc_size > 0 {
116                                let dc_bits = reader.read_bits(dc_size)?;
117                                let dc_diff = extend_sign(dc_bits, dc_size);
118                                dc_pred[sci] += dc_diff as i32;
119                            }
120                            let mut k = 1;
121                            while k < 64 {
122                                let rs = ac_tab.decode(&mut reader)?;
123                                let run = (rs >> 4) as usize;
124                                let size = rs & 0x0F;
125                                if size == 0 {
126                                    if run == 0 || run != 15 { break; }
127                                    k += 16;
128                                    continue;
129                                }
130                                k += run;
131                                if k >= 64 { return Err(JpegError::HuffmanDecode); }
132                                let _ac_bits = reader.read_bits(size)?;
133                                k += 1;
134                            }
135                            continue;
136                        }
137
138                        let mut zz = [0i16; 64];
139
140                        // Decode DC coefficient
141                        let dc_size = dc_tab.decode(&mut reader)?;
142                        if dc_size > 0 {
143                            let dc_bits = reader.read_bits(dc_size)?;
144                            let dc_diff = extend_sign(dc_bits, dc_size);
145                            dc_pred[sci] += dc_diff as i32;
146                        }
147                        zz[0] = dc_pred[sci].clamp(i16::MIN as i32, i16::MAX as i32) as i16;
148
149                        // Decode AC coefficients
150                        let mut k = 1;
151                        while k < 64 {
152                            let rs = ac_tab.decode(&mut reader)?;
153                            let run = (rs >> 4) as usize;
154                            let size = rs & 0x0F;
155
156                            if size == 0 {
157                                if run == 0 {
158                                    // EOB — remaining ACs are zero
159                                    break;
160                                } else if run == 15 {
161                                    // ZRL — skip 16 zeros
162                                    k += 16;
163                                    continue;
164                                } else {
165                                    break;
166                                }
167                            }
168
169                            k += run;
170                            if k >= 64 {
171                                return Err(JpegError::HuffmanDecode);
172                            }
173                            let ac_bits = reader.read_bits(size)?;
174                            zz[k] = extend_sign(ac_bits, size);
175                            k += 1;
176                        }
177
178                        // Convert zigzag to natural order and store
179                        let block = grids[sci].block_mut(block_row, block_col);
180                        for zi in 0..64 {
181                            block[ZIGZAG_TO_NATURAL[zi]] = zz[zi];
182                        }
183                    }
184                }
185            }
186
187            mcu_count += 1;
188        }
189    }
190
191    // Find end position: align to byte, then skip past any trailing marker
192    let end_pos = reader.position();
193
194    Ok((grids, end_pos))
195}
196
197/// Encode DctGrids back to entropy-coded scan data.
198///
199/// Returns the raw entropy-coded bytes (without SOS header, but including
200/// restart markers if restart_interval > 0).
201pub fn encode_scan(
202    frame: &FrameInfo,
203    scan_components: &[ScanComponent],
204    grids: &[DctGrid],
205    dc_specs: &[Option<HuffmanSpec>; 4],
206    ac_specs: &[Option<HuffmanSpec>; 4],
207    restart_interval: u16,
208) -> Result<Vec<u8>> {
209    encode_scan_with_progress(frame, scan_components, grids, dc_specs, ac_specs, restart_interval, None)
210}
211
212/// Encode scan data with optional per-row progress callback.
213///
214/// The callback is invoked periodically during MCU row encoding (approximately
215/// `JPEG_WRITE_STEPS` times total) so callers can report progress during
216/// what would otherwise be a long blocking operation.
217pub const JPEG_WRITE_STEPS: u32 = 20;
218
219pub fn encode_scan_with_progress(
220    frame: &FrameInfo,
221    scan_components: &[ScanComponent],
222    grids: &[DctGrid],
223    dc_specs: &[Option<HuffmanSpec>; 4],
224    ac_specs: &[Option<HuffmanSpec>; 4],
225    restart_interval: u16,
226    on_progress: Option<&dyn Fn()>,
227) -> Result<Vec<u8>> {
228    // Build Huffman encode tables
229    let mut dc_tables: [Option<HuffmanEncodeTable>; 4] = [None, None, None, None];
230    let mut ac_tables: [Option<HuffmanEncodeTable>; 4] = [None, None, None, None];
231
232    for sc in scan_components {
233        if dc_tables[sc.dc_table].is_none() {
234            let spec = dc_specs[sc.dc_table]
235                .as_ref()
236                .ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
237            dc_tables[sc.dc_table] = Some(HuffmanEncodeTable::build(&spec.bits, &spec.huffval));
238        }
239        if ac_tables[sc.ac_table].is_none() {
240            let spec = ac_specs[sc.ac_table]
241                .as_ref()
242                .ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
243            ac_tables[sc.ac_table] = Some(HuffmanEncodeTable::build(&spec.bits, &spec.huffval));
244        }
245    }
246
247    // Use a byte accumulator so we can insert restart markers between segments
248    let mut output = Vec::new();
249    let mut writer = BitWriter::new();
250    let mut dc_pred = vec![0i32; scan_components.len()];
251    let mut mcu_count = 0usize;
252    let mut restart_count = 0u16;
253    let mcus_tall = frame.mcus_tall as usize;
254    let row_interval = if mcus_tall > 0 { (mcus_tall / JPEG_WRITE_STEPS as usize).max(1) } else { 1 };
255
256    for mcu_row in 0..mcus_tall {
257        if let Some(ref cb) = on_progress
258            && mcu_row > 0 && mcu_row % row_interval == 0 {
259                cb();
260            }
261        for mcu_col in 0..frame.mcus_wide as usize {
262            // Insert restart marker if needed
263            if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval as usize) {
264                // Flush current segment
265                let segment = std::mem::take(&mut writer).flush();
266                output.extend_from_slice(&segment);
267
268                // Write RST marker (not byte-stuffed — markers are outside entropy data)
269                let rst_marker = 0xD0 + (restart_count % 8) as u8;
270                output.push(0xFF);
271                output.push(rst_marker);
272                restart_count += 1;
273
274                // Reset DC predictors
275                dc_pred.fill(0);
276            }
277
278            // Encode blocks for each component in this MCU
279            for (sci, sc) in scan_components.iter().enumerate() {
280                let comp = &frame.components[sc.comp_idx];
281                let dc_tab = dc_tables[sc.dc_table].as_ref().unwrap();
282                let ac_tab = ac_tables[sc.ac_table].as_ref().unwrap();
283
284                for v in 0..comp.v_sampling as usize {
285                    for h in 0..comp.h_sampling as usize {
286                        let block_row = mcu_row * (comp.v_sampling as usize) + v;
287                        let block_col = mcu_col * (comp.h_sampling as usize) + h;
288
289                        // Read block and convert natural → zigzag
290                        let block = grids[sci].block(block_row, block_col);
291                        let mut zz = [0i16; 64];
292                        for ni in 0..64 {
293                            zz[NATURAL_TO_ZIGZAG[ni]] = block[ni];
294                        }
295
296                        // Encode DC
297                        let dc_diff = (zz[0] as i32 - dc_pred[sci]) as i16;
298                        dc_pred[sci] = zz[0] as i32;
299                        let (dc_bits, dc_size) = encode_value(dc_diff);
300                        let (dc_code, dc_code_len) = dc_tab.encode(dc_size)?;
301                        writer.write_bits(dc_code, dc_code_len);
302                        if dc_size > 0 {
303                            writer.write_bits(dc_bits, dc_size);
304                        }
305
306                        // Encode AC
307                        let mut k = 1;
308                        while k < 64 {
309                            // Find run of zeros
310                            let mut run = 0usize;
311                            while k + run < 64 && zz[k + run] == 0 {
312                                run += 1;
313                            }
314
315                            if k + run >= 64 {
316                                // EOB
317                                let (eob_code, eob_len) = ac_tab.encode(0x00)?;
318                                writer.write_bits(eob_code, eob_len);
319                                break;
320                            }
321
322                            // Emit ZRL (16 zeros) as needed
323                            while run >= 16 {
324                                let (zrl_code, zrl_len) = ac_tab.encode(0xF0)?;
325                                writer.write_bits(zrl_code, zrl_len);
326                                run -= 16;
327                                k += 16;
328                            }
329
330                            k += run;
331                            let (ac_bits, ac_size) = encode_value(zz[k]);
332                            let rs = ((run as u8) << 4) | ac_size;
333                            let (ac_code, ac_code_len) = ac_tab.encode(rs)?;
334                            writer.write_bits(ac_code, ac_code_len);
335                            if ac_size > 0 {
336                                writer.write_bits(ac_bits, ac_size);
337                            }
338                            k += 1;
339                        }
340                    }
341                }
342            }
343
344            mcu_count += 1;
345        }
346    }
347
348    // Flush final segment
349    output.extend_from_slice(&writer.flush());
350    Ok(output)
351}
352
353/// Decode a single progressive scan into existing DctGrids.
354///
355/// Progressive JPEG has multiple scans, each contributing partial coefficient data.
356/// This function decodes one scan (identified by its SOS parameters) and accumulates
357/// the results into the provided grids.
358///
359/// The four scan types are:
360/// - **DC first** (Ss=0, Se=0, Ah=0): Initial DC coefficients, shifted left by Al
361/// - **DC refining** (Ss=0, Se=0, Ah>0): One correction bit per DC coefficient
362/// - **AC first** (Ss>0, Ah=0): Initial AC coefficients for a spectral band
363/// - **AC refining** (Ss>0, Ah>0): Correction bits for previously-decoded AC coefficients
364///
365/// Returns the byte position after the scan data.
366#[allow(unused_assignments)]
367pub fn decode_progressive_scan(
368    data: &[u8],
369    scan_start: usize,
370    frame: &FrameInfo,
371    scan_components: &[ScanComponent],
372    dc_specs: &[Option<HuffmanSpec>; 4],
373    ac_specs: &[Option<HuffmanSpec>; 4],
374    restart_interval: u16,
375    params: &SosParams,
376    grids: &mut [DctGrid],
377) -> Result<usize> {
378    let ss = params.ss as usize;
379    let se = params.se as usize;
380    let ah = params.ah;
381    let al = params.al;
382
383    // Validate parameters
384    if ss > 63 || se > 63 || ss > se {
385        return Err(JpegError::InvalidMarkerData("invalid spectral selection"));
386    }
387
388    // Build Huffman decode tables (only for tables actually needed by this scan)
389    let mut dc_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
390    let mut ac_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
391
392    for sc in scan_components {
393        // DC tables needed for DC scans (ss == 0)
394        if ss == 0 && ah == 0 && dc_tables[sc.dc_table].is_none() {
395            let spec = dc_specs[sc.dc_table]
396                .as_ref()
397                .ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
398            dc_tables[sc.dc_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
399        }
400        // AC tables needed for AC scans (ss > 0)
401        if ss > 0 && ac_tables[sc.ac_table].is_none() {
402            let spec = ac_specs[sc.ac_table]
403                .as_ref()
404                .ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
405            ac_tables[sc.ac_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
406        }
407    }
408
409    let mut reader = BitReader::new(data, scan_start);
410    let mut dc_pred = vec![0i32; scan_components.len()];
411    let mut mcu_count = 0usize;
412
413    // For AC scans, track the End-of-Band run counter
414    let mut eob_run: u32 = 0;
415
416    // Determine if this is a non-interleaved scan (single component)
417    // Non-interleaved scans process blocks in raster order within that component.
418    let non_interleaved = scan_components.len() == 1 && (ss > 0 || se > 0 || frame.components.len() == 1);
419
420    if non_interleaved && scan_components.len() == 1 {
421        // Non-interleaved scan: iterate blocks in raster order for the single component
422        let sc = &scan_components[0];
423        let bw = frame.blocks_wide(sc.comp_idx);
424        let bt = frame.blocks_tall(sc.comp_idx);
425
426        // For non-interleaved scans, MCU = 1 block
427        // Restart interval counts blocks, not MCUs
428        let mut block_count = 0usize;
429
430        for block_row in 0..bt {
431            for block_col in 0..bw {
432                // Handle restart markers
433                if restart_interval > 0 && block_count > 0 && block_count.is_multiple_of(restart_interval as usize) {
434                    reader.byte_align();
435                    let _rst = reader.check_restart_marker()?;
436                    dc_pred[0] = 0;
437                    eob_run = 0;
438                }
439
440                let grid = &mut grids[sc.comp_idx];
441
442                if ss == 0 {
443                    // DC scan
444                    if ah == 0 {
445                        // DC first scan
446                        decode_dc_first(&mut reader, &dc_tables, sc, &mut dc_pred[0], al, grid, block_row, block_col)?;
447                    } else {
448                        // DC refining scan
449                        decode_dc_refine(&mut reader, al, grid, block_row, block_col)?;
450                    }
451                }
452
453                if se > 0 {
454                    // AC scan (might also include DC if ss == 0, but for progressive
455                    // AC and DC are always in separate scans)
456                    let ac_start = if ss == 0 { 1 } else { ss };
457                    if ah == 0 {
458                        // AC first scan
459                        decode_ac_first(&mut reader, &ac_tables, sc, al, ac_start, se, &mut eob_run, grid, block_row, block_col)?;
460                    } else {
461                        // AC refining scan
462                        decode_ac_refine(&mut reader, &ac_tables, sc, al, ac_start, se, &mut eob_run, grid, block_row, block_col)?;
463                    }
464                }
465
466                block_count += 1;
467            }
468        }
469    } else {
470        // Interleaved scan (DC scans with multiple components)
471        for mcu_row in 0..frame.mcus_tall as usize {
472            for mcu_col in 0..frame.mcus_wide as usize {
473                // Handle restart markers
474                if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval as usize) {
475                    reader.byte_align();
476                    let _rst = reader.check_restart_marker()?;
477                    dc_pred.fill(0);
478                    eob_run = 0;
479                }
480
481                for (sci, sc) in scan_components.iter().enumerate() {
482                    let comp = &frame.components[sc.comp_idx];
483
484                    for v in 0..comp.v_sampling as usize {
485                        for h in 0..comp.h_sampling as usize {
486                            let block_row = mcu_row * (comp.v_sampling as usize) + v;
487                            let block_col = mcu_col * (comp.h_sampling as usize) + h;
488
489                            let grid = &mut grids[sc.comp_idx];
490
491                            if ss == 0 {
492                                if ah == 0 {
493                                    decode_dc_first(&mut reader, &dc_tables, sc, &mut dc_pred[sci], al, grid, block_row, block_col)?;
494                                } else {
495                                    decode_dc_refine(&mut reader, al, grid, block_row, block_col)?;
496                                }
497                            }
498                            // Interleaved scans are DC-only in progressive JPEG
499                        }
500                    }
501                }
502
503                mcu_count += 1;
504            }
505        }
506    }
507
508    let end_pos = reader.position();
509    Ok(end_pos)
510}
511
512/// DC first scan: Huffman-decode the DC difference and shift left by Al.
513fn decode_dc_first(
514    reader: &mut BitReader,
515    dc_tables: &[Option<HuffmanDecodeTable>; 4],
516    sc: &ScanComponent,
517    dc_pred: &mut i32,
518    al: u8,
519    grid: &mut DctGrid,
520    block_row: usize,
521    block_col: usize,
522) -> Result<()> {
523    let dc_tab = dc_tables[sc.dc_table]
524        .as_ref()
525        .ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
526    let dc_size = dc_tab.decode(reader)?;
527    if dc_size > 0 {
528        let dc_bits = reader.read_bits(dc_size)?;
529        let dc_diff = extend_sign(dc_bits, dc_size);
530        *dc_pred += dc_diff as i32;
531    }
532    // Store DC coefficient shifted left by Al (successive approximation), clamped to i16 range
533    let block = grid.block_mut(block_row, block_col);
534    block[0] = ((*dc_pred).clamp(i16::MIN as i32, i16::MAX as i32) as i16) << al;
535    Ok(())
536}
537
538/// DC refining scan: read one correction bit per DC coefficient.
539fn decode_dc_refine(
540    reader: &mut BitReader,
541    al: u8,
542    grid: &mut DctGrid,
543    block_row: usize,
544    block_col: usize,
545) -> Result<()> {
546    let bit = reader.read_bits(1)?;
547    let block = grid.block_mut(block_row, block_col);
548    // Set bit Al of the DC coefficient
549    if bit != 0 {
550        block[0] |= 1i16 << al;
551    }
552    Ok(())
553}
554
555/// AC first scan: decode AC coefficients for the spectral band [ss..se].
556///
557/// Similar to baseline AC decoding but operates on a sub-range of zigzag positions
558/// and supports EOBn (End of Band) run-length coding across multiple blocks.
559fn decode_ac_first(
560    reader: &mut BitReader,
561    ac_tables: &[Option<HuffmanDecodeTable>; 4],
562    sc: &ScanComponent,
563    al: u8,
564    ss: usize,
565    se: usize,
566    eob_run: &mut u32,
567    grid: &mut DctGrid,
568    block_row: usize,
569    block_col: usize,
570) -> Result<()> {
571    let block = grid.block_mut(block_row, block_col);
572
573    if *eob_run > 0 {
574        // We are in an EOB run — this block's coefficients in [ss..se] are all zero.
575        *eob_run -= 1;
576        return Ok(());
577    }
578
579    let ac_tab = ac_tables[sc.ac_table]
580        .as_ref()
581        .ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
582
583    let mut k = ss;
584    while k <= se {
585        let rs = ac_tab.decode(reader)?;
586        let run = (rs >> 4) as usize;
587        let size = rs & 0x0F;
588
589        if size == 0 {
590            if run == 15 {
591                // ZRL: skip 16 zero positions
592                k += 16;
593                continue;
594            } else {
595                // EOBn: End of Band for 2^run + extra blocks
596                // run=0 means EOB for this block only (EOB0 = 1 block)
597                // run=1..14 means read `run` extra bits to get the EOB run length
598                *eob_run = 1u32 << (run as u32);
599                if run > 0 {
600                    let extra = reader.read_bits(run as u8)? as u32;
601                    *eob_run += extra;
602                }
603                *eob_run -= 1; // Current block is part of the run
604                return Ok(());
605            }
606        }
607
608        k += run;
609        if k > se {
610            return Err(JpegError::HuffmanDecode);
611        }
612
613        let ac_bits = reader.read_bits(size)?;
614        let val = extend_sign(ac_bits, size);
615        // Store in natural order, shifted left by Al
616        block[ZIGZAG_TO_NATURAL[k]] = val << al;
617        k += 1;
618    }
619
620    Ok(())
621}
622
623/// AC refining scan: read correction bits for previously-nonzero coefficients
624/// and new nonzero coefficients in the spectral band [ss..se].
625///
626/// This is the most complex part of progressive JPEG decoding.
627/// The algorithm from ITU-T T.81 Figure G.7 interleaves:
628/// - Correction bits for coefficients that were already nonzero
629/// - New coefficients (with zero-run and EOBn coding)
630fn decode_ac_refine(
631    reader: &mut BitReader,
632    ac_tables: &[Option<HuffmanDecodeTable>; 4],
633    sc: &ScanComponent,
634    al: u8,
635    ss: usize,
636    se: usize,
637    eob_run: &mut u32,
638    grid: &mut DctGrid,
639    block_row: usize,
640    block_col: usize,
641) -> Result<()> {
642    let block = grid.block_mut(block_row, block_col);
643    let p1 = 1i16 << al;  // 1 in the bit position being corrected
644    let m1 = (-1i16) << al; // -1 in the bit position being corrected (= -(1 << al))
645
646    let ac_tab = ac_tables[sc.ac_table]
647        .as_ref()
648        .ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
649
650    let mut k = ss;
651
652    if *eob_run > 0 {
653        // In an EOB run: just apply correction bits to nonzero coefficients
654        while k <= se {
655            let ni = ZIGZAG_TO_NATURAL[k];
656            if block[ni] != 0 {
657                let bit = reader.read_bits(1)?;
658                if bit != 0 {
659                    if block[ni] > 0 {
660                        block[ni] += p1;
661                    } else {
662                        block[ni] += m1;
663                    }
664                }
665            }
666            k += 1;
667        }
668        *eob_run -= 1;
669        return Ok(());
670    }
671
672    while k <= se {
673        let rs = ac_tab.decode(reader)?;
674        let run = (rs >> 4) as usize; // Number of zero-valued coefficients to skip
675        let size = rs & 0x0F;
676
677        if size == 0 {
678            if run == 15 {
679                // ZRL: skip 16 zero-valued positions, applying correction bits to nonzero
680                let mut zeros_to_skip = 16usize;
681                while k <= se && zeros_to_skip > 0 {
682                    let ni = ZIGZAG_TO_NATURAL[k];
683                    if block[ni] != 0 {
684                        // Apply correction bit
685                        let bit = reader.read_bits(1)?;
686                        if bit != 0 {
687                            if block[ni] > 0 {
688                                block[ni] += p1;
689                            } else {
690                                block[ni] += m1;
691                            }
692                        }
693                    } else {
694                        zeros_to_skip -= 1;
695                    }
696                    k += 1;
697                }
698                continue;
699            } else {
700                // EOBn: remaining coefficients in this band get correction bits only
701                *eob_run = 1u32 << (run as u32);
702                if run > 0 {
703                    let extra = reader.read_bits(run as u8)? as u32;
704                    *eob_run += extra;
705                }
706                // Apply correction bits to remaining nonzero coefficients in this block
707                while k <= se {
708                    let ni = ZIGZAG_TO_NATURAL[k];
709                    if block[ni] != 0 {
710                        let bit = reader.read_bits(1)?;
711                        if bit != 0 {
712                            if block[ni] > 0 {
713                                block[ni] += p1;
714                            } else {
715                                block[ni] += m1;
716                            }
717                        }
718                    }
719                    k += 1;
720                }
721                *eob_run -= 1;
722                return Ok(());
723            }
724        } else if size == 1 {
725            // New nonzero coefficient after skipping `run` zero-valued positions
726            // Read the sign bit
727            let sign_bit = reader.read_bits(1)?;
728            let new_val = if sign_bit != 0 { p1 } else { m1 };
729
730            // Skip `run` zero-valued coefficients, applying correction bits to nonzero
731            let mut zeros_to_skip = run;
732            while k <= se {
733                let ni = ZIGZAG_TO_NATURAL[k];
734                if block[ni] != 0 {
735                    // Apply correction bit to existing nonzero
736                    let bit = reader.read_bits(1)?;
737                    if bit != 0 {
738                        if block[ni] > 0 {
739                            block[ni] += p1;
740                        } else {
741                            block[ni] += m1;
742                        }
743                    }
744                } else {
745                    if zeros_to_skip == 0 {
746                        // Place the new coefficient here
747                        block[ni] = new_val;
748                        k += 1;
749                        break;
750                    }
751                    zeros_to_skip -= 1;
752                }
753                k += 1;
754            }
755            continue;
756        } else {
757            // Invalid: size must be 0 or 1 in AC refining scans
758            return Err(JpegError::HuffmanDecode);
759        }
760    }
761
762    Ok(())
763}
764
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_value_dc_diff() {
        // DC difference of 5: category 3 (size=3), bits = 5 = 0b101
        let (bits, size) = encode_value(5);
        assert_eq!(size, 3);
        assert_eq!(bits, 5);

        // DC difference of -3: category 2 (size=2); the magnitude bits are the
        // one's complement of 3 in 2 bits, which extend_sign maps back to -3.
        let (bits, size) = encode_value(-3);
        assert_eq!(size, 2);
        let recovered = extend_sign(bits, size);
        assert_eq!(recovered, -3);
    }

    #[test]
    fn encode_value_zero_has_category_zero() {
        // A zero difference is coded by its category symbol alone.
        let (_, size) = encode_value(0);
        assert_eq!(size, 0);
    }

    #[test]
    fn encode_value_round_trips_through_extend_sign() {
        // Every nonzero value in the 8-bit DC/AC range must use its bit length
        // as category and decode back to itself.
        for v in (-2047i16..=2047).filter(|&v| v != 0) {
            let (bits, size) = encode_value(v);
            let expected_size = (16 - v.unsigned_abs().leading_zeros()) as u8;
            assert_eq!(size, expected_size, "category of {v}");
            assert_eq!(extend_sign(bits, size), v, "round trip of {v}");
        }
    }
}