Skip to main content

dicom_toolkit_codec/jpeg_ls/
scan.rs

1//! Core JPEG-LS scan decoder and encoder.
2//!
3//! Processes a scan line-by-line using the JPEG-LS algorithm:
4//! prediction → context modeling → Golomb-Rice decoding/encoding.
5
6use dicom_toolkit_core::error::{DcmError, DcmResult};
7
8use super::bitstream::{BitReader, BitWriter};
9use super::context::{JlsContext, RunModeContext};
10use super::golomb;
11use super::params::DerivedTraits;
12use super::prediction::{
13    apply_sign, bitwise_sign, build_quantization_lut, compute_context_id, get_predicted_value,
14    quantize_from_lut, sign, J,
15};
16
/// Number of regular-mode contexts.
///
/// Three quantized gradients give 9 * 9 * 9 = 729 combinations; merging every
/// context with its sign-mirrored twin leaves 364 pairs plus the all-zero
/// context, i.e. (729 + 1) / 2 = 365.
const NUM_CONTEXTS: usize = (9 * 9 * 9 + 1) / 2;
19
20/// Scan decoder: decodes a single JPEG-LS scan to pixel data.
21pub struct ScanDecoder<'a> {
22    reader: BitReader<'a>,
23    traits: DerivedTraits,
24    #[allow(dead_code)]
25    t1: i32,
26    #[allow(dead_code)]
27    t2: i32,
28    #[allow(dead_code)]
29    t3: i32,
30    width: usize,
31    height: usize,
32    contexts: Vec<JlsContext>,
33    run_contexts: [RunModeContext; 2],
34    run_index: usize,
35    quant_lut: Vec<i8>,
36    quant_range: i32,
37}
38
39impl<'a> ScanDecoder<'a> {
40    pub fn new(
41        data: &'a [u8],
42        traits: DerivedTraits,
43        t1: i32,
44        t2: i32,
45        t3: i32,
46        width: usize,
47        height: usize,
48    ) -> Self {
49        let a_init = (traits.range + 32).max(2) / 64;
50        let a_init = a_init.max(2);
51
52        let contexts = vec![JlsContext::new(a_init); NUM_CONTEXTS];
53        let run_contexts = [
54            RunModeContext::new(a_init, traits.reset),
55            RunModeContext::new(a_init, traits.reset),
56        ];
57
58        let quant_lut = build_quantization_lut(traits.bpp, t1, t2, t3, traits.near);
59        let quant_range = 1i32 << traits.bpp;
60
61        Self {
62            reader: BitReader::new(data),
63            traits,
64            t1,
65            t2,
66            t3,
67            width,
68            height,
69            contexts,
70            run_contexts,
71            run_index: 0,
72            quant_lut,
73            quant_range,
74        }
75    }
76
77    /// Decode the entire scan, returning pixel data as i32 values.
78    pub fn decode(&mut self) -> DcmResult<Vec<i32>> {
79        let w = self.width;
80        let h = self.height;
81        let stride = w + 2; // 1 extra on each side for edge pixels
82
83        // Line buffers: previous and current (with 1-pixel padding on each side).
84        let mut prev_line = vec![0i32; stride];
85        let mut curr_line = vec![0i32; stride];
86
87        let mut output = Vec::with_capacity(w * h);
88
89        for _line in 0..h {
90            // Edge initialization: left edge of current = first pixel of previous.
91            curr_line[0] = prev_line[1];
92            // Right edge of previous = last pixel of previous.
93            prev_line[w + 1] = prev_line[w];
94
95            self.decode_line(&prev_line, &mut curr_line, w)?;
96
97            // Copy decoded pixels to output (indices 1..=w).
98            for val in curr_line.iter().take(w + 1).skip(1) {
99                output.push(*val);
100            }
101
102            // Swap lines.
103            std::mem::swap(&mut prev_line, &mut curr_line);
104        }
105
106        Ok(output)
107    }
108
109    /// Decode a single scan line.
110    fn decode_line(&mut self, prev: &[i32], curr: &mut [i32], width: usize) -> DcmResult<()> {
111        let mut index = 0usize;
112        // Rb (above) and Rd (above-right) are tracked across pixels.
113        let mut rb = prev[index]; // prev[0] for index=0
114        let mut rd = prev[index + 1]; // prev[1]
115
116        while index < width {
117            let ra = curr[index]; // left pixel (curr[0] for first pixel = edge)
118            let rc = rb;
119            rb = rd;
120            rd = prev[index + 2]; // above-right
121
122            let d1 = rd - rb;
123            let d2 = rb - rc;
124            let d3 = rc - ra;
125
126            let q1 = quantize_from_lut(&self.quant_lut, d1, self.quant_range);
127            let q2 = quantize_from_lut(&self.quant_lut, d2, self.quant_range);
128            let q3 = quantize_from_lut(&self.quant_lut, d3, self.quant_range);
129
130            let qs = compute_context_id(q1, q2, q3);
131
132            if qs != 0 {
133                // Regular mode.
134                let val = self.do_regular_decode(qs, ra, rb, rc)?;
135                curr[index + 1] = val;
136                index += 1;
137            } else {
138                // Run mode.
139                let count = self.do_run_mode_decode(curr, prev, index, width)?;
140                index += count;
141                if index < width {
142                    rb = prev[index];
143                    rd = prev[index + 1];
144                }
145            }
146        }
147
148        Ok(())
149    }
150
151    /// Decode a regular-mode sample.
152    fn do_regular_decode(&mut self, qs: i32, ra: i32, rb: i32, rc: i32) -> DcmResult<i32> {
153        let sign_val = bitwise_sign(qs);
154        let ctx_idx = apply_sign(qs, sign_val) as usize;
155        let ctx = &mut self.contexts[ctx_idx];
156
157        let k = ctx.get_golomb();
158        let px = self.traits.correct_prediction(
159            get_predicted_value(ra, rb, rc) + apply_sign(ctx.c as i32, sign_val),
160        );
161
162        // Decode the error value.
163        let mapped_err =
164            golomb::decode_mapped_value(&mut self.reader, k, self.traits.limit, self.traits.qbpp)?;
165        let mut err_val = golomb::unmap_err_val(mapped_err);
166
167        if err_val.abs() > 65535 {
168            return Err(DcmError::DecompressionError {
169                reason: "JPEG-LS: error value out of range".into(),
170            });
171        }
172
173        // Apply error correction for lossless mode.
174        if self.traits.near == 0 {
175            err_val ^= ctx.get_error_correction(k);
176        }
177
178        ctx.update_variables(err_val, self.traits.near, self.traits.reset);
179        let err_val = apply_sign(err_val, sign_val);
180
181        Ok(self.traits.compute_reconstructed(px, err_val))
182    }
183
184    /// Decode run mode: decode run length + optional run interruption sample.
185    fn do_run_mode_decode(
186        &mut self,
187        curr: &mut [i32],
188        prev: &[i32],
189        start_index: usize,
190        width: usize,
191    ) -> DcmResult<usize> {
192        let ra = curr[start_index]; // left pixel
193
194        // Decode run pixels.
195        let run_length = self.decode_run_pixels(ra, curr, start_index, width)?;
196        let end_index = start_index + run_length;
197
198        if end_index == width {
199            return Ok(run_length);
200        }
201
202        // Run interruption: decode the interruption sample.
203        let rb = prev[end_index + 1]; // above pixel at interruption point
204        let val = self.decode_ri_pixel(ra, rb)?;
205        curr[end_index + 1] = val;
206        self.run_index = self.run_index.saturating_sub(1);
207
208        Ok(run_length + 1)
209    }
210
211    /// Decode run-length encoded pixels.
212    fn decode_run_pixels(
213        &mut self,
214        ra: i32,
215        curr: &mut [i32],
216        start: usize,
217        width: usize,
218    ) -> DcmResult<usize> {
219        let max_run = width - start;
220        let mut count = 0usize;
221
222        while self.reader.read_bit()? {
223            let j_val = J[self.run_index] as usize;
224            let run_len = (1usize << j_val).min(max_run - count);
225            count += run_len;
226
227            if run_len == (1usize << j_val) {
228                self.run_index = (self.run_index + 1).min(31);
229            }
230
231            if count == max_run {
232                break;
233            }
234        }
235
236        if count < max_run {
237            // Incomplete run — read remaining length.
238            let j_val = J[self.run_index];
239            if j_val > 0 {
240                count += self.reader.read_value(j_val)? as usize;
241            }
242        }
243
244        if count > max_run {
245            return Err(DcmError::DecompressionError {
246                reason: "JPEG-LS: run length exceeds line width".into(),
247            });
248        }
249
250        // Fill pixels with Ra.
251        for i in 0..count {
252            curr[start + 1 + i] = ra;
253        }
254
255        Ok(count)
256    }
257
258    /// Decode a run-interruption pixel.
259    fn decode_ri_pixel(&mut self, ra: i32, rb: i32) -> DcmResult<i32> {
260        let ctx_idx = if (ra - rb).abs() <= self.traits.near {
261            1
262        } else {
263            0
264        };
265
266        let err_val = self.decode_ri_error(ctx_idx)?;
267
268        if ctx_idx == 1 {
269            Ok(self.traits.compute_reconstructed(ra, err_val))
270        } else {
271            Ok(self
272                .traits
273                .compute_reconstructed(rb, err_val * sign(rb - ra)))
274        }
275    }
276
277    /// Decode a run-interruption error value.
278    fn decode_ri_error(&mut self, ctx_idx: usize) -> DcmResult<i32> {
279        let ctx = &self.run_contexts[ctx_idx];
280        let k = ctx.get_golomb();
281        let limit = self.traits.limit - J[self.run_index] - 1;
282
283        let em_err_val = golomb::decode_mapped_value(&mut self.reader, k, limit, self.traits.qbpp)?;
284
285        let ctx = &mut self.run_contexts[ctx_idx];
286        let ri_type = ctx.ri_type;
287        let err_val = compute_ri_err_val(em_err_val + ri_type, k, ctx);
288        ctx.update_variables(err_val, em_err_val);
289
290        Ok(err_val)
291    }
292}
293
294/// Compute the actual error value from the mapped run-interruption error value.
295fn compute_ri_err_val(temp: i32, k: i32, ctx: &RunModeContext) -> i32 {
296    let map = temp & 1;
297    let err_abs = (temp + map) / 2;
298
299    let condition = if k != 0 || 2 * ctx.nn >= ctx.n { 1 } else { 0 };
300
301    if condition == map {
302        -err_abs
303    } else {
304        err_abs
305    }
306}
307
308// ── Scan encoder ──────────────────────────────────────────────────────────────
309
310/// Scan encoder: encodes pixel data into a JPEG-LS scan.
311pub struct ScanEncoder {
312    writer: BitWriter,
313    traits: DerivedTraits,
314    #[allow(dead_code)]
315    t1: i32,
316    #[allow(dead_code)]
317    t2: i32,
318    #[allow(dead_code)]
319    t3: i32,
320    width: usize,
321    height: usize,
322    contexts: Vec<JlsContext>,
323    run_contexts: [RunModeContext; 2],
324    run_index: usize,
325    quant_lut: Vec<i8>,
326    quant_range: i32,
327}
328
329impl ScanEncoder {
330    pub fn new(
331        traits: DerivedTraits,
332        t1: i32,
333        t2: i32,
334        t3: i32,
335        width: usize,
336        height: usize,
337    ) -> Self {
338        let a_init = (traits.range + 32).max(2) / 64;
339        let a_init = a_init.max(2);
340
341        let contexts = vec![JlsContext::new(a_init); NUM_CONTEXTS];
342        let run_contexts = [
343            RunModeContext::new(a_init, traits.reset),
344            RunModeContext::new(a_init, traits.reset),
345        ];
346
347        let quant_lut = build_quantization_lut(traits.bpp, t1, t2, t3, traits.near);
348        let quant_range = 1i32 << traits.bpp;
349
350        Self {
351            writer: BitWriter::new(),
352            traits,
353            t1,
354            t2,
355            t3,
356            width,
357            height,
358            contexts,
359            run_contexts,
360            run_index: 0,
361            quant_lut,
362            quant_range,
363        }
364    }
365
366    /// Encode pixel data (i32 values) into a JPEG-LS scan bitstream.
367    pub fn encode(&mut self, pixels: &[i32]) -> DcmResult<Vec<u8>> {
368        let w = self.width;
369        let h = self.height;
370
371        if pixels.len() != w * h {
372            return Err(DcmError::CompressionError {
373                reason: format!("expected {} pixels, got {}", w * h, pixels.len()),
374            });
375        }
376
377        let stride = w + 2;
378        let mut prev_line = vec![0i32; stride];
379        let mut curr_line = vec![0i32; stride];
380
381        for line in 0..h {
382            // Load pixel data into current line (indices 1..=w).
383            for i in 0..w {
384                curr_line[i + 1] = pixels[line * w + i];
385            }
386
387            // Edge initialization.
388            curr_line[0] = prev_line[1];
389            prev_line[w + 1] = prev_line[w];
390
391            self.encode_line(&prev_line, &mut curr_line, w)?;
392
393            std::mem::swap(&mut prev_line, &mut curr_line);
394        }
395
396        self.writer.end_scan();
397        // Take the writer output.
398        let mut result_writer = BitWriter::new();
399        std::mem::swap(&mut self.writer, &mut result_writer);
400        Ok(result_writer.into_bytes())
401    }
402
403    fn encode_line(&mut self, prev: &[i32], curr: &mut [i32], width: usize) -> DcmResult<()> {
404        let mut index = 0usize;
405        let mut rb = prev[index];
406        let mut rd = prev[index + 1];
407
408        while index < width {
409            let ra = curr[index];
410            let rc = rb;
411            rb = rd;
412            rd = prev[index + 2];
413
414            let d1 = rd - rb;
415            let d2 = rb - rc;
416            let d3 = rc - ra;
417
418            let q1 = quantize_from_lut(&self.quant_lut, d1, self.quant_range);
419            let q2 = quantize_from_lut(&self.quant_lut, d2, self.quant_range);
420            let q3 = quantize_from_lut(&self.quant_lut, d3, self.quant_range);
421
422            let qs = compute_context_id(q1, q2, q3);
423
424            if qs != 0 {
425                let x = curr[index + 1];
426                let recon = self.do_regular_encode(qs, x, ra, rb, rc)?;
427                curr[index + 1] = recon;
428                index += 1;
429            } else {
430                let count = self.do_run_mode_encode(curr, prev, index, width)?;
431                index += count;
432                if index < width {
433                    rb = prev[index];
434                    rd = prev[index + 1];
435                }
436            }
437        }
438
439        Ok(())
440    }
441
442    fn do_regular_encode(&mut self, qs: i32, x: i32, ra: i32, rb: i32, rc: i32) -> DcmResult<i32> {
443        let sign_val = bitwise_sign(qs);
444        let ctx_idx = apply_sign(qs, sign_val) as usize;
445        let ctx = &mut self.contexts[ctx_idx];
446
447        let k = ctx.get_golomb();
448        let px = self.traits.correct_prediction(
449            get_predicted_value(ra, rb, rc) + apply_sign(ctx.c as i32, sign_val),
450        );
451
452        let err_val = self.traits.compute_error_val(apply_sign(x - px, sign_val));
453
454        let mapped_err =
455            golomb::get_mapped_err_val(ctx.get_error_correction(k | self.traits.near) ^ err_val);
456        golomb::encode_mapped_value(
457            &mut self.writer,
458            k,
459            mapped_err,
460            self.traits.limit,
461            self.traits.qbpp,
462        );
463
464        ctx.update_variables(err_val, self.traits.near, self.traits.reset);
465
466        Ok(self
467            .traits
468            .compute_reconstructed(px, apply_sign(err_val, sign_val)))
469    }
470
471    fn do_run_mode_encode(
472        &mut self,
473        curr: &mut [i32],
474        prev: &[i32],
475        start_index: usize,
476        width: usize,
477    ) -> DcmResult<usize> {
478        let ra = curr[start_index];
479        let max_run = width - start_index;
480
481        // Count the run.
482        let mut run_length = 0usize;
483        while run_length < max_run {
484            let px = curr[start_index + 1 + run_length];
485            if self.traits.near == 0 {
486                if px != ra {
487                    break;
488                }
489            } else if (px - ra).abs() > self.traits.near {
490                break;
491            }
492            // Reconstruct as Ra for near-lossless.
493            curr[start_index + 1 + run_length] = ra;
494            run_length += 1;
495        }
496
497        let end_of_line = run_length == max_run;
498        self.encode_run_pixels(run_length as i32, end_of_line);
499
500        if end_of_line {
501            return Ok(run_length);
502        }
503
504        // Run interruption: encode the interruption sample.
505        let x = curr[start_index + 1 + run_length];
506        let rb = prev[start_index + 1 + run_length];
507        let recon = self.encode_ri_pixel(x, ra, rb)?;
508        curr[start_index + 1 + run_length] = recon;
509        self.run_index = self.run_index.saturating_sub(1);
510
511        Ok(run_length + 1)
512    }
513
514    fn encode_run_pixels(&mut self, mut run_length: i32, end_of_line: bool) {
515        while run_length >= (1 << J[self.run_index]) {
516            self.writer.append_ones(1);
517            run_length -= 1 << J[self.run_index];
518            self.run_index = (self.run_index + 1).min(31);
519        }
520
521        if end_of_line {
522            if run_length != 0 {
523                self.writer.append_ones(1);
524            }
525        } else {
526            self.writer.append(run_length, J[self.run_index] + 1);
527        }
528    }
529
530    fn encode_ri_pixel(&mut self, x: i32, ra: i32, rb: i32) -> DcmResult<i32> {
531        if (ra - rb).abs() <= self.traits.near {
532            let err_val = self.traits.compute_error_val(x - ra);
533            self.encode_ri_error(1, err_val);
534            Ok(self.traits.compute_reconstructed(ra, err_val))
535        } else {
536            let err_val = self.traits.compute_error_val((x - rb) * sign(rb - ra));
537            self.encode_ri_error(0, err_val);
538            Ok(self
539                .traits
540                .compute_reconstructed(rb, err_val * sign(rb - ra)))
541        }
542    }
543
544    fn encode_ri_error(&mut self, ctx_idx: usize, err_val: i32) {
545        let ctx = &self.run_contexts[ctx_idx];
546        let k = ctx.get_golomb();
547        let map = ctx.compute_map(err_val, k);
548        let em_err_val = 2 * err_val.abs() - ctx.ri_type - map;
549
550        let limit = self.traits.limit - J[self.run_index] - 1;
551        golomb::encode_mapped_value(&mut self.writer, k, em_err_val, limit, self.traits.qbpp);
552
553        let ctx = &mut self.run_contexts[ctx_idx];
554        ctx.update_variables(err_val, em_err_val);
555    }
556}
557
558// ── Tests ─────────────────────────────────────────────────────────────────────
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jpeg_ls::params::{compute_default, BASIC_RESET};

    /// Lossless 8-bit traits together with the default thresholds T1..T3.
    fn make_traits_8bit() -> (DerivedTraits, i32, i32, i32) {
        let max_val = 255;
        let near = 0;
        let defaults = compute_default(max_val, near);
        let traits = DerivedTraits::new(max_val, near, BASIC_RESET);
        (traits, defaults.t1, defaults.t2, defaults.t3)
    }

    /// Encode `pixels` (`w` x `h`) losslessly at 8 bits, then decode the scan
    /// with independently constructed traits.
    fn roundtrip_8bit(pixels: &[i32], w: usize, h: usize) -> Vec<i32> {
        let (traits, t1, t2, t3) = make_traits_8bit();
        let encoded = ScanEncoder::new(traits, t1, t2, t3, w, h)
            .encode(pixels)
            .unwrap();

        let decoder_traits = DerivedTraits::new(255, 0, BASIC_RESET);
        ScanDecoder::new(&encoded, decoder_traits, t1, t2, t3, w, h)
            .decode()
            .unwrap()
    }

    #[test]
    fn encode_decode_roundtrip_constant() {
        let pixels: Vec<i32> = vec![128; 8 * 4];
        assert_eq!(roundtrip_8bit(&pixels, 8, 4), pixels);
    }

    #[test]
    fn encode_decode_roundtrip_gradient() {
        let (w, h) = (16, 8);
        let pixels: Vec<i32> = (0..h)
            .flat_map(|y| (0..w).map(move |x| ((x * 16 + y * 8) % 256) as i32))
            .collect();
        assert_eq!(roundtrip_8bit(&pixels, w, h), pixels);
    }

    #[test]
    fn encode_decode_roundtrip_random_like() {
        let (w, h) = (32, 16);
        // Deterministic LCG so the test is reproducible.
        let mut state: u32 = 42;
        let pixels: Vec<i32> = (0..w * h)
            .map(|_| {
                state = state.wrapping_mul(1103515245).wrapping_add(12345);
                ((state >> 16) & 0xFF) as i32
            })
            .collect();
        assert_eq!(roundtrip_8bit(&pixels, w, h), pixels);
    }

    #[test]
    fn encode_decode_roundtrip_1x1() {
        let pixels = vec![200i32];
        assert_eq!(roundtrip_8bit(&pixels, 1, 1), pixels);
    }

    /// Flat blocks separated by sharp edges: intended to drive run mode,
    /// run interruptions with both |Ra - Rb| <= NEAR and > NEAR, and runs that
    /// reach the end of a line.
    #[test]
    fn encode_decode_roundtrip_blocks() {
        let (w, h) = (24, 9);
        let pixels: Vec<i32> = (0..h)
            .flat_map(|y| {
                (0..w).map(move |x| if (x / 4 + y / 3) % 2 == 0 { 50 } else { 200 })
            })
            .collect();
        assert_eq!(roundtrip_8bit(&pixels, w, h), pixels);
    }

    #[test]
    fn encode_decode_roundtrip_single_column() {
        let pixels: Vec<i32> = vec![10, 10, 200, 3, 3, 3, 77];
        assert_eq!(roundtrip_8bit(&pixels, 1, pixels.len()), pixels);
    }

    #[test]
    fn encode_rejects_wrong_pixel_count() {
        let (traits, t1, t2, t3) = make_traits_8bit();
        let result = ScanEncoder::new(traits, t1, t2, t3, 2, 2).encode(&[1, 2, 3]);
        assert!(matches!(result, Err(DcmError::CompressionError { .. })));
    }
}