tx2_iff/
wavelet.rs

//! Integer Wavelet Transform (CDF 5/3) for the Layer 1 skeleton.
//!
//! Implements the Cohen-Daubechies-Feauveau 5/3 biorthogonal wavelet
//! using the lifting scheme, with all operations in integer arithmetic for
//! perfect reconstruction and deterministic behavior.
6
7use crate::error::{IffError, Result};
8use crate::prime::QuantizationTable;
9use crate::compression::{compress_rle, decompress_rle};
10use serde::{Deserialize, Serialize};
11
/// Wavelet subband type.
///
/// Names the four quadrants produced by one 2-D decomposition level.
/// NOTE(review): this enum is not referenced in the visible part of this file;
/// the quadrant positions below follow the labels used by the comments in
/// `Cdf53Transform::quantize` — confirm against the other users of this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubBand {
    /// Top-left quadrant; the region the next decomposition level operates on.
    LL,
    /// Top-right quadrant (as labelled in `quantize`).
    LH,
    /// Bottom-left quadrant (as labelled in `quantize`).
    HL,
    /// Bottom-right quadrant (as labelled in `quantize`).
    HH,
}
20
/// Wavelet decomposition structure (Compressed)
///
/// Holds the coefficients of every channel in RLE-compressed form together
/// with the metadata needed to split them back into per-channel buffers
/// (see `from_dense` / `to_dense`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaveletDecomposition {
    /// Original image width; `to_dense` expects `width * height` coefficients per channel
    pub width: u32,
    /// Original image height
    pub height: u32,
    /// Number of decomposition levels
    pub levels: usize,
    /// Number of channels
    pub channels: u8,
    /// Compressed coefficients (RLE encoded)
    /// Layout before compression: [Channel 0 Data] [Channel 1 Data] ...
    /// (the channels are concatenated first, then compressed as one stream)
    pub data: Vec<u8>,
}
35
36impl WaveletDecomposition {
37    /// Create a new empty decomposition
38    pub fn new(width: u32, height: u32, levels: usize, channels: u8) -> Self {
39        WaveletDecomposition {
40            width,
41            height,
42            levels,
43            channels,
44            data: Vec::new(),
45        }
46    }
47
48    /// Create from dense coefficient buffers
49    pub fn from_dense(
50        width: u32,
51        height: u32,
52        levels: usize,
53        channel_data: &[Vec<i32>],
54    ) -> Result<Self> {
55        let mut all_data = Vec::new();
56        for channel in channel_data {
57            all_data.extend_from_slice(channel);
58        }
59
60        let compressed = compress_rle(&all_data)?;
61
62        Ok(WaveletDecomposition {
63            width,
64            height,
65            levels,
66            channels: channel_data.len() as u8,
67            data: compressed,
68        })
69    }
70
71    /// Decompress to dense coefficient buffers
72    pub fn to_dense(&self) -> Result<Vec<Vec<i32>>> {
73        let total_pixels = (self.width * self.height) as usize;
74        let expected_len = total_pixels * self.channels as usize;
75
76        let all_data = decompress_rle(&self.data, Some(expected_len))?;
77
78        let mut channels = Vec::with_capacity(self.channels as usize);
79        for i in 0..self.channels as usize {
80            let start = i * total_pixels;
81            let end = start + total_pixels;
82            if end > all_data.len() {
83                return Err(IffError::Other("Insufficient data for channels".to_string()));
84            }
85            channels.push(all_data[start..end].to_vec());
86        }
87
88        Ok(channels)
89    }
90}
91
/// CDF 5/3 Integer Wavelet Transform
///
/// Stateless apart from the requested level count; the same instance can be
/// used for `forward`, `inverse` and `quantize`.
pub struct Cdf53Transform {
    /// Number of decomposition levels requested. Fewer levels are applied when
    /// the working region becomes smaller than 2 samples in either dimension.
    levels: usize,
}
97
98impl Cdf53Transform {
99    /// Create a new CDF 5/3 transform with specified levels
100    pub fn new(levels: usize) -> Self {
101        Cdf53Transform { levels }
102    }
103
104    /// Forward transform: image → wavelet coefficients (dense)
105    pub fn forward(&self, image: &[i32], width: usize, height: usize) -> Result<Vec<i32>> {
106        if image.len() != width * height {
107            return Err(IffError::Other(
108                "Image dimensions don't match data length".to_string(),
109            ));
110        }
111
112        // Working buffer (will be modified in-place)
113        let mut buffer = image.to_vec();
114
115        // Apply transform levels
116        let mut current_width = width;
117        let mut current_height = height;
118
119        for _ in 0..self.levels {
120            if current_width < 2 || current_height < 2 {
121                break; // Can't decompose further
122            }
123
124            // Transform rows
125            for y in 0..current_height {
126                self.forward_1d_row(&mut buffer, y, current_width, width);
127            }
128
129            // Transform columns
130            for x in 0..current_width {
131                self.forward_1d_col(&mut buffer, x, current_height, width);
132            }
133
134            // Next level operates on LL subband
135            current_width /= 2;
136            current_height /= 2;
137        }
138
139        Ok(buffer)
140    }
141
142    /// Inverse transform: wavelet coefficients (dense) → image
143    pub fn inverse(&self, coefficients: &[i32], width: usize, height: usize) -> Result<Vec<i32>> {
144        if coefficients.len() != width * height {
145            return Err(IffError::Other(
146                "Coefficient dimensions don't match data length".to_string(),
147            ));
148        }
149
150        let mut buffer = coefficients.to_vec();
151
152        // Apply inverse transform levels (in reverse order)
153
154        // Adjust starting level if image is too small
155
156        // We need to reconstruct the correct sequence of widths/heights
157        // Or just iterate backwards
158        // The forward loop goes: (w,h) -> (w/2, h/2) ...
159        // The inverse loop should go: ... (w/2, h/2) -> (w,h)
160
161        // Let's pre-calculate the dimensions for each level
162        let mut dims = Vec::new();
163        let mut w = width;
164        let mut h = height;
165        for _ in 0..self.levels {
166            dims.push((w, h));
167            w /= 2;
168            h /= 2;
169        }
170
171        for (w, h) in dims.iter().rev() {
172             let current_width = *w;
173             let current_height = *h;
174             
175             if current_width < 2 || current_height < 2 {
176                 continue;
177             }
178
179            // Inverse transform columns
180            for x in 0..current_width {
181                self.inverse_1d_col(&mut buffer, x, current_height, width);
182            }
183
184            // Inverse transform rows
185            for y in 0..current_height {
186                self.inverse_1d_row(&mut buffer, y, current_width, width);
187            }
188        }
189
190        Ok(buffer)
191    }
192
193    /// Quantize coefficients using 360-prime pattern (in-place)
194    pub fn quantize(&self, buffer: &mut [i32], width: usize, height: usize, table: &QuantizationTable) {
195        // We need to iterate over subbands to know coordinates
196        // But the buffer is dense.
197        // We can iterate over levels and subbands similar to extract_coefficients
198        
199        let mut current_width = width;
200        let mut current_height = height;
201
202        for level in 0..self.levels {
203            let half_w = current_width / 2;
204            let half_h = current_height / 2;
205            
206            if half_w == 0 || half_h == 0 { break; }
207
208            // Process subbands
209            // LL is skipped (it's processed in next level, or at the end)
210            // But wait, LL of level N is the input to level N+1.
211            // We only quantize the details (LH, HL, HH) of each level.
212            // And the final LL.
213
214            // LH subband
215            for y in 0..half_h {
216                for x in half_w..current_width {
217                    self.quantize_pixel(buffer, x, y, width, table);
218                }
219            }
220
221            // HL subband
222            for y in half_h..current_height {
223                for x in 0..half_w {
224                    self.quantize_pixel(buffer, x, y, width, table);
225                }
226            }
227
228            // HH subband
229            for y in half_h..current_height {
230                for x in half_w..current_width {
231                    self.quantize_pixel(buffer, x, y, width, table);
232                }
233            }
234            
235            // If this is the last level, also quantize the LL band
236            if level == self.levels - 1 {
237                for y in 0..half_h {
238                    for x in 0..half_w {
239                        self.quantize_pixel(buffer, x, y, width, table);
240                    }
241                }
242            }
243
244            current_width = half_w;
245            current_height = half_h;
246        }
247    }
248    
249    fn quantize_pixel(&self, buffer: &mut [i32], x: usize, y: usize, width: usize, table: &QuantizationTable) {
250        let idx = y * width + x;
251        let val = buffer[idx];
252        let step = table.get_step(x, y);
253        
254        let quantized = val / step as i32;
255        
256        // Aggressive sparsification
257        if quantized.abs() < 2 { // Reduced threshold from 3 to 2 to be less aggressive?
258             buffer[idx] = 0;
259        } else {
260             buffer[idx] = quantized * step as i32;
261        }
262    }
263
264    /// Forward 1D transform on a row (in-place lifting scheme)
265    fn forward_1d_row(&self, data: &mut [i32], y: usize, width: usize, stride: usize) {
266        if width < 2 {
267            return;
268        }
269
270        let offset = y * stride;
271        let mut temp = vec![0i32; width];
272
273        // Predict step: d[i] = s[2i+1] - floor((s[2i] + s[2i+2]) / 2)
274        for i in 0..(width / 2) {
275            let left = data[offset + 2 * i];
276            let right = if 2 * i + 2 < width {
277                data[offset + 2 * i + 2]
278            } else {
279                data[offset + 2 * i] // Mirror boundary
280            };
281            temp[width / 2 + i] = data[offset + 2 * i + 1] - ((left + right) / 2);
282        }
283
284        // Update step: s[i] = s[2i] + floor((d[i-1] + d[i] + 2) / 4)
285        for i in 0..(width / 2) {
286            let d_left = if i > 0 {
287                temp[width / 2 + i - 1]
288            } else {
289                temp[width / 2] // Mirror boundary
290            };
291            let d_right = if i < width / 2 {
292                temp[width / 2 + i]
293            } else {
294                temp[width / 2 + i - 1] // Mirror boundary
295            };
296            temp[i] = data[offset + 2 * i] + ((d_left + d_right + 2) / 4);
297        }
298
299        // Copy back
300        for i in 0..width {
301            data[offset + i] = temp[i];
302        }
303    }
304
305    /// Forward 1D transform on a column
306    fn forward_1d_col(&self, data: &mut [i32], x: usize, height: usize, stride: usize) {
307        if height < 2 {
308            return;
309        }
310
311        let mut temp = vec![0i32; height];
312
313        // Predict step
314        for i in 0..(height / 2) {
315            let top = data[2 * i * stride + x];
316            let bottom = if 2 * i + 2 < height {
317                data[(2 * i + 2) * stride + x]
318            } else {
319                data[2 * i * stride + x] // Mirror boundary
320            };
321            temp[height / 2 + i] = data[(2 * i + 1) * stride + x] - ((top + bottom) / 2);
322        }
323
324        // Update step
325        for i in 0..(height / 2) {
326            let d_top = if i > 0 {
327                temp[height / 2 + i - 1]
328            } else {
329                temp[height / 2] // Mirror boundary
330            };
331            let d_bottom = if i < height / 2 {
332                temp[height / 2 + i]
333            } else {
334                temp[height / 2 + i - 1] // Mirror boundary
335            };
336            temp[i] = data[2 * i * stride + x] + ((d_top + d_bottom + 2) / 4);
337        }
338
339        // Copy back
340        for i in 0..height {
341            data[i * stride + x] = temp[i];
342        }
343    }
344
345    /// Inverse 1D transform on a row
346    fn inverse_1d_row(&self, data: &mut [i32], y: usize, width: usize, stride: usize) {
347        if width < 2 {
348            return;
349        }
350
351        let offset = y * stride;
352        let mut temp = vec![0i32; width];
353
354        // Copy to temp
355        for i in 0..width {
356            temp[i] = data[offset + i];
357        }
358
359        // Undo update step
360        for i in 0..(width / 2) {
361            let d_left = if i > 0 {
362                temp[width / 2 + i - 1]
363            } else {
364                temp[width / 2]
365            };
366            let d_right = if i < width / 2 {
367                temp[width / 2 + i]
368            } else {
369                temp[width / 2 + i - 1]
370            };
371            data[offset + 2 * i] = temp[i] - ((d_left + d_right + 2) / 4);
372        }
373
374        // Undo predict step
375        for i in 0..(width / 2) {
376            let left = data[offset + 2 * i];
377            let right = if 2 * i + 2 < width {
378                data[offset + 2 * i + 2]
379            } else {
380                data[offset + 2 * i]
381            };
382            data[offset + 2 * i + 1] = temp[width / 2 + i] + ((left + right) / 2);
383        }
384    }
385
386    /// Inverse 1D transform on a column
387    fn inverse_1d_col(&self, data: &mut [i32], x: usize, height: usize, stride: usize) {
388        if height < 2 {
389            return;
390        }
391
392        let mut temp = vec![0i32; height];
393
394        // Copy to temp
395        for i in 0..height {
396            temp[i] = data[i * stride + x];
397        }
398
399        // Undo update step
400        for i in 0..(height / 2) {
401            let d_top = if i > 0 {
402                temp[height / 2 + i - 1]
403            } else {
404                temp[height / 2]
405            };
406            let d_bottom = if i < height / 2 {
407                temp[height / 2 + i]
408            } else {
409                temp[height / 2 + i - 1]
410            };
411            data[2 * i * stride + x] = temp[i] - ((d_top + d_bottom + 2) / 4);
412        }
413
414        // Undo predict step
415        for i in 0..(height / 2) {
416            let top = data[2 * i * stride + x];
417            let bottom = if 2 * i + 2 < height {
418                data[(2 * i + 2) * stride + x]
419            } else {
420                data[2 * i * stride + x]
421            };
422            data[(2 * i + 1) * stride + x] = temp[height / 2 + i] + ((top + bottom) / 2);
423        }
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-level forward transform followed by the inverse must return the
    /// original 8x8 image exactly.
    #[test]
    fn test_perfect_reconstruction() {
        // Small test image (8x8) holding a ramp of values starting at -128
        let width = 8;
        let height = 8;
        let image: Vec<i32> = (0..(width * height))
            .map(|i| (i % 256) as i32 - 128)
            .collect();

        let transform = Cdf53Transform::new(2);
        let coeffs = transform.forward(&image, width, height).unwrap();
        let reconstructed = transform.inverse(&coeffs, width, height).unwrap();

        // `zip` stops at the shorter side, so check the length explicitly first;
        // otherwise a truncated reconstruction would pass unnoticed.
        assert_eq!(
            image.len(),
            reconstructed.len(),
            "Reconstructed image has wrong length"
        );

        for (orig, recon) in image.iter().zip(reconstructed.iter()) {
            assert_eq!(orig, recon, "Perfect reconstruction failed");
        }
    }
}