gistools/readers/geotiff/
predictor.rs

1use alloc::vec::Vec;
2use core::ops::AddAssign;
3
/// Undo horizontal differencing (TIFF predictor `2`) on a row of 8-bit samples, in place.
///
/// Walking left to right, every element from index `stride` onward has the element `stride`
/// positions earlier added to it, wrapping on overflow. The first `stride` elements are left
/// as-is. Rows whose length is not a multiple of `stride` are handled without panicking.
///
/// ## Parameters
/// - `row`: the row of samples to decode
/// - `stride`: the distance, in samples, to the predicting sample (as passed by
///   `apply_predictor`: the number of samples per pixel, or `1` for planar data).
///   A `stride` of `0` leaves the row unchanged.
pub fn decode_row_acc_u8(row: &mut [u8], stride: usize) {
    // A zero stride would make every element "predict" itself; treat it as a no-op.
    if stride == 0 {
        return;
    }
    for i in stride..row.len() {
        row[i] = row[i].wrapping_add(row[i - stride]);
    }
}
16
/// Undo horizontal differencing (TIFF predictor `2`) on a row of 16-bit samples, in place.
///
/// Walking left to right, every element from index `stride` onward has the element `stride`
/// positions earlier added to it, wrapping on overflow. The first `stride` elements are left
/// as-is. Rows whose length is not a multiple of `stride` are handled without panicking.
///
/// ## Parameters
/// - `row`: the row of samples to decode
/// - `stride`: the distance, in samples, to the predicting sample (as passed by
///   `apply_predictor`: the number of samples per pixel, or `1` for planar data).
///   A `stride` of `0` leaves the row unchanged.
pub fn decode_row_acc_u16(row: &mut [u16], stride: usize) {
    // A zero stride would make every element "predict" itself; treat it as a no-op.
    if stride == 0 {
        return;
    }
    for i in stride..row.len() {
        row[i] = row[i].wrapping_add(row[i - stride]);
    }
}
29
/// Undo horizontal differencing (TIFF predictor `2`) on a row of 32-bit samples, in place.
///
/// Walking left to right, every element from index `stride` onward has the element `stride`
/// positions earlier added to it, wrapping on overflow. The first `stride` elements are left
/// as-is. Rows whose length is not a multiple of `stride` are handled without panicking.
///
/// ## Parameters
/// - `row`: the row of samples to decode
/// - `stride`: the distance, in samples, to the predicting sample (as passed by
///   `apply_predictor`: the number of samples per pixel, or `1` for planar data).
///   A `stride` of `0` leaves the row unchanged.
pub fn decode_row_acc_u32(row: &mut [u32], stride: usize) {
    // A zero stride would make every element "predict" itself; treat it as a no-op.
    if stride == 0 {
        return;
    }
    for i in stride..row.len() {
        row[i] = row[i].wrapping_add(row[i - stride]);
    }
}
42
/// Undo the TIFF floating-point predictor (predictor `3`) on a single row, in place.
///
/// Two steps are applied:
/// 1. Un-differencing: walking left to right, every element from index `stride` onward has the
///    element `stride` positions earlier added to it (with `T`'s `+=`).
/// 2. De-interleaving: the first `wc * bytes_per_sample` elements, where
///    `wc = row.len() / bytes_per_sample`, are viewed as `bytes_per_sample` consecutive planes of
///    `wc` elements. Output sample `i`, byte `b` is taken from plane `bytes_per_sample - b - 1`,
///    position `i`. Any trailing elements beyond that are only un-differenced.
///
/// The addition uses `T`'s `+=`. For byte rows, pass `core::num::Wrapping<u8>` so the sums wrap
/// modulo 256. Plain `u8` panics on overflow when overflow checks are enabled (e.g. debug builds).
///
/// ## Parameters
/// - `row`: the row to decode; the reordering treats each element as one byte of sample data
/// - `stride`: the distance, in elements, to the predicting element (as passed by
///   `apply_predictor`: the number of samples per pixel, or `1` for planar data).
///   A `stride` of `0` skips the un-differencing step.
/// - `bytes_per_sample`: the number of bytes per sample; `0` leaves the row unchanged
pub fn decode_row_floating_point<T>(row: &mut [T], stride: usize, bytes_per_sample: usize)
where
    T: AddAssign + Copy,
{
    // Nothing can be de-interleaved without a sample size (and dividing by it would panic).
    if bytes_per_sample == 0 {
        return;
    }

    // Left-to-right so each element adds its already-decoded predecessor.
    if stride > 0 {
        for i in stride..row.len() {
            let predicted_from = row[i - stride];
            row[i] += predicted_from;
        }
    }

    let wc = row.len() / bytes_per_sample;
    // The reorder reads from every plane while writing, so it needs a snapshot of the row.
    let copy = row.to_vec();
    for (i, sample) in row.chunks_exact_mut(bytes_per_sample).enumerate() {
        for (b, byte) in sample.iter_mut().enumerate() {
            *byte = copy[(bytes_per_sample - b - 1) * wc + i];
        }
    }
}
74
75/// Apply the specified predictor to a block
76///
77/// ## Parameters
78/// - `block`: the block to modify
79/// - `predictor`: the predictor
80/// - `width`: the block width
81/// - `height`: the block height
82/// - `bits_per_sample`: the number of bits per sample
83/// - `planar_configuration`: the planar configuration
84///
85/// ## Returns
86/// The modified block
87pub fn apply_predictor(
88    mut block: Vec<u8>,
89    predictor: i16,
90    width: usize,
91    height: usize,
92    bits_per_sample: Vec<u16>,
93    planar_configuration: i16,
94) -> Vec<u8> {
95    if predictor == 0 || predictor == 1 {
96        return block;
97    }
98
99    for i in 0..bits_per_sample.len() {
100        if !bits_per_sample[i].is_multiple_of(8) {
101            panic!("When decoding with predictor, only multiple of 8 bits are supported.");
102        }
103        if bits_per_sample[i] != bits_per_sample[0] {
104            panic!("When decoding with predictor, all samples must have the same size.");
105        }
106    }
107
108    let bytes_per_sample = (bits_per_sample[0] / 8) as usize;
109    let stride = if planar_configuration == 2 { 1 } else { bits_per_sample.len() };
110
111    for i in 0..height {
112        // Last strip will be truncated if height % stripHeight != 0
113        if i * stride * width * bytes_per_sample >= block.len() {
114            break;
115        }
116        if predictor == 2 {
117            // horizontal prediction
118            let row = &mut block[i * stride * width * bytes_per_sample
119                ..(i + 1) * stride * width * bytes_per_sample];
120            match bits_per_sample[0] {
121                8 => {
122                    decode_row_acc_u8(row, stride);
123                }
124                16 => {
125                    decode_row_acc_u16(as_u16_slice_mut(row), stride);
126                }
127                32 => {
128                    decode_row_acc_u32(as_u32_slice_mut(row), stride);
129                }
130                _ => panic!("Predictor 2 not allowed with {} bits per sample.", bits_per_sample[0]),
131            }
132        } else if predictor == 3 {
133            // horizontal floating point
134            let row = &mut block[i * stride * width * bytes_per_sample
135                ..(i + 1) * stride * width * bytes_per_sample];
136            decode_row_floating_point(row, stride, bytes_per_sample);
137        }
138    }
139
140    block
141}
142
/// Reinterpret a byte slice as a mutable slice of native-endian `u16`s.
///
/// The returned slice aliases `data`, so writes through it change the underlying bytes.
/// An empty input yields an empty slice.
///
/// ## Panics
/// If `data.len()` is odd, or a non-empty `data` does not start on a `u16` alignment boundary.
fn as_u16_slice_mut(data: &mut [u8]) -> &mut [u16] {
    // An empty slice's pointer is dangling and may be misaligned; there is nothing to view anyway.
    if data.is_empty() {
        return &mut [];
    }
    assert_eq!(data.len() % 2, 0);
    let ptr = data.as_mut_ptr().cast::<u16>();
    // `is_aligned` is the documented alignment test; `align_offset` is only a performance hint
    // and is permitted to return `usize::MAX`.
    assert!(ptr.is_aligned(), "u16 view requires 2-byte aligned data");
    let len = data.len() / 2;
    // SAFETY: `ptr` is non-null and aligned for `u16` (checked above). It covers
    // `len * 2 == data.len()` initialized bytes that stay exclusively borrowed for the returned
    // lifetime, and every bit pattern is a valid `u16`.
    unsafe { core::slice::from_raw_parts_mut(ptr, len) }
}
150
/// Reinterpret a byte slice as a mutable slice of native-endian `u32`s.
///
/// The returned slice aliases `data`, so writes through it change the underlying bytes.
/// An empty input yields an empty slice.
///
/// ## Panics
/// If `data.len()` is not a multiple of 4, or a non-empty `data` does not start on a `u32`
/// alignment boundary.
fn as_u32_slice_mut(data: &mut [u8]) -> &mut [u32] {
    // An empty slice's pointer is dangling and may be misaligned; there is nothing to view anyway.
    if data.is_empty() {
        return &mut [];
    }
    assert_eq!(data.len() % 4, 0);
    let ptr = data.as_mut_ptr().cast::<u32>();
    // `is_aligned` is the documented alignment test; `align_offset` is only a performance hint
    // and is permitted to return `usize::MAX`.
    assert!(ptr.is_aligned(), "u32 view requires 4-byte aligned data");
    let len = data.len() / 4;
    // SAFETY: `ptr` is non-null and aligned for `u32` (checked above). It covers
    // `len * 4 == data.len()` initialized bytes that stay exclusively borrowed for the returned
    // lifetime, and every bit pattern is a valid `u32`.
    unsafe { core::slice::from_raw_parts_mut(ptr, len) }
}