Skip to main content

openjph_core/transform/
mod.rs

1//! Wavelet and color transforms (DWT 5/3, 9/7, RCT, ICT).
2//!
3//! Port of `ojph_transform.h/cpp` and `ojph_colour.h/cpp`.
4
5pub(crate) mod colour;
6pub(crate) mod simd;
7pub(crate) mod wavelet;
8
9use std::sync::OnceLock;
10
11use crate::mem::LineBuf;
12
13// =========================================================================
14// Lifting step — port of C++ `union lifting_step`
15// =========================================================================
16
/// Reversible lifting step parameters (5/3 DWT).
///
/// Field names mirror the ATK marker-segment parameters they hold
/// (`Aatk`, `Batk`, `Eatk`). Equality and hashing are by value, so steps can
/// be compared directly (e.g. in tests) or used as map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RevLiftingStep {
    /// Lifting coefficient (Aatk).
    pub a: i16,
    /// Additive residue (Batk).
    pub b: i16,
    /// Power-of-2 shift (Eatk).
    pub e: u8,
}
27
/// Irreversible lifting step parameters (9/7 DWT).
///
/// Compares by value (`PartialEq`); `Eq`/`Hash` are intentionally absent
/// because the coefficient is an `f32`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct IrvLiftingStep {
    /// Lifting coefficient (Aatk).
    pub a: f32,
}
34
35/// A single lifting step — either reversible (integer) or irreversible (float).
36#[derive(Debug, Clone, Copy)]
37pub enum LiftingStep {
38    Reversible(RevLiftingStep),
39    Irreversible(IrvLiftingStep),
40}
41
42impl Default for LiftingStep {
43    fn default() -> Self {
44        LiftingStep::Reversible(RevLiftingStep::default())
45    }
46}
47
48impl LiftingStep {
49    /// Access as reversible step. Panics if irreversible.
50    #[inline]
51    pub fn rev(&self) -> &RevLiftingStep {
52        match self {
53            LiftingStep::Reversible(r) => r,
54            _ => panic!("expected reversible lifting step"),
55        }
56    }
57
58    /// Access as irreversible step. Panics if reversible.
59    #[inline]
60    pub fn irv(&self) -> &IrvLiftingStep {
61        match self {
62            LiftingStep::Irreversible(i) => i,
63            _ => panic!("expected irreversible lifting step"),
64        }
65    }
66}
67
68// =========================================================================
69// ParamAtk — port of C++ `struct param_atk`
70// =========================================================================
71
72/// Maximum number of inline lifting steps (matches C++ d_store[6]).
73const MAX_INLINE_STEPS: usize = 6;
74
75/// Arbitrary Transformation Kernel parameters.
76///
77/// Stores the lifting steps for one wavelet kernel (e.g., the standard 5/3 or
78/// 9/7 filter).
79#[derive(Debug, Clone)]
80pub struct ParamAtk {
81    /// ATK marker segment length.
82    pub latk: u16,
83    /// Satk — carries filter type information.
84    pub satk: u16,
85    /// Scaling factor K (irreversible only).
86    pub katk: f32,
87    /// Number of lifting steps.
88    pub natk: u8,
89    /// The lifting step coefficients.
90    pub steps: Vec<LiftingStep>,
91}
92
93impl Default for ParamAtk {
94    fn default() -> Self {
95        Self {
96            latk: 0,
97            satk: 0,
98            katk: 0.0,
99            natk: 0,
100            steps: Vec::with_capacity(MAX_INLINE_STEPS),
101        }
102    }
103}
104
105impl ParamAtk {
106    /// Returns the number of lifting steps.
107    #[inline]
108    pub fn get_num_steps(&self) -> u32 {
109        self.natk as u32
110    }
111
112    /// Returns a reference to the `s`-th lifting step.
113    #[inline]
114    pub fn get_step(&self, s: u32) -> &LiftingStep {
115        debug_assert!((s as u8) < self.natk);
116        &self.steps[s as usize]
117    }
118
119    /// Returns the scaling factor K (irreversible kernels).
120    #[inline]
121    pub fn get_k(&self) -> f32 {
122        self.katk
123    }
124
125    /// Initializes for the standard irreversible 9/7 wavelet.
126    #[allow(clippy::excessive_precision)]
127    pub fn init_irv97(&mut self) {
128        // Match OpenJPH's stored step order in param_atk::init_irv97().
129        const DELTA: f32 = 0.443_506_85; // step 0
130        const GAMMA: f32 = 0.882_911_08; // step 1
131        const BETA: f32 = -0.052_980_118; // step 2
132        const ALPHA: f32 = -1.586_134_3; // step 3
133        const K: f32 = 1.230_174_1;
134
135        self.natk = 4;
136        self.katk = K;
137        self.steps.clear();
138        self.steps
139            .push(LiftingStep::Irreversible(IrvLiftingStep { a: DELTA }));
140        self.steps
141            .push(LiftingStep::Irreversible(IrvLiftingStep { a: GAMMA }));
142        self.steps
143            .push(LiftingStep::Irreversible(IrvLiftingStep { a: BETA }));
144        self.steps
145            .push(LiftingStep::Irreversible(IrvLiftingStep { a: ALPHA }));
146    }
147
148    /// Initializes for the standard reversible 5/3 wavelet.
149    pub fn init_rev53(&mut self) {
150        // Match OpenJPH's stored step order in param_atk::init_rev53().
151        self.natk = 2;
152        self.katk = 0.0;
153        self.steps.clear();
154        self.steps
155            .push(LiftingStep::Reversible(RevLiftingStep { a: 1, b: 2, e: 2 }));
156        self.steps.push(LiftingStep::Reversible(RevLiftingStep {
157            a: -1,
158            b: 1,
159            e: 1,
160        }));
161    }
162}
163
164// =========================================================================
165// Function pointer types — wavelet transforms
166// =========================================================================
167
168/// Reversible / irreversible vertical lifting step.
169pub type RevVertStepFn = fn(
170    s: &LiftingStep,
171    sig: &LineBuf,
172    other: &LineBuf,
173    aug: &mut LineBuf,
174    repeat: u32,
175    synthesis: bool,
176);
177
178/// Reversible horizontal analysis (forward DWT, split into low/high).
179pub type RevHorzAnaFn = fn(
180    atk: &ParamAtk,
181    ldst: &mut LineBuf,
182    hdst: &mut LineBuf,
183    src: &LineBuf,
184    width: u32,
185    even: bool,
186);
187
188/// Reversible horizontal synthesis (inverse DWT, merge low/high).
189pub type RevHorzSynFn = fn(
190    atk: &ParamAtk,
191    dst: &mut LineBuf,
192    lsrc: &mut LineBuf,
193    hsrc: &mut LineBuf,
194    width: u32,
195    even: bool,
196);
197
198/// Irreversible vertical lifting step (same shape as reversible).
199pub type IrvVertStepFn = fn(
200    s: &LiftingStep,
201    sig: &LineBuf,
202    other: &LineBuf,
203    aug: &mut LineBuf,
204    repeat: u32,
205    synthesis: bool,
206);
207
208/// Multiply line by normalization constant K.
209pub type IrvVertTimesKFn = fn(k: f32, aug: &mut LineBuf, repeat: u32);
210
211/// Irreversible horizontal analysis.
212pub type IrvHorzAnaFn = fn(
213    atk: &ParamAtk,
214    ldst: &mut LineBuf,
215    hdst: &mut LineBuf,
216    src: &LineBuf,
217    width: u32,
218    even: bool,
219);
220
221/// Irreversible horizontal synthesis.
222pub type IrvHorzSynFn = fn(
223    atk: &ParamAtk,
224    dst: &mut LineBuf,
225    lsrc: &mut LineBuf,
226    hsrc: &mut LineBuf,
227    width: u32,
228    even: bool,
229);
230
231/// Runtime-dispatched wavelet transform function table.
232pub struct WaveletTransformFns {
233    pub rev_vert_step: RevVertStepFn,
234    pub rev_horz_ana: RevHorzAnaFn,
235    pub rev_horz_syn: RevHorzSynFn,
236    pub irv_vert_step: IrvVertStepFn,
237    pub irv_vert_times_k: IrvVertTimesKFn,
238    pub irv_horz_ana: IrvHorzAnaFn,
239    pub irv_horz_syn: IrvHorzSynFn,
240}
241
242// =========================================================================
243// Function pointer types — colour transforms
244// =========================================================================
245
246/// Reversible sample conversion (integer shift).
247pub type RevConvertFn = fn(
248    src_line: &LineBuf,
249    src_line_offset: u32,
250    dst_line: &mut LineBuf,
251    dst_line_offset: u32,
252    shift: i64,
253    width: u32,
254);
255
256/// Irreversible: float → integer quantization.
257pub type IrvConvertToIntegerFn = fn(
258    src_line: &LineBuf,
259    dst_line: &mut LineBuf,
260    dst_line_offset: u32,
261    bit_depth: u32,
262    is_signed: bool,
263    width: u32,
264);
265
266/// Irreversible: integer → float dequantization.
267pub type IrvConvertToFloatFn = fn(
268    src_line: &LineBuf,
269    src_line_offset: u32,
270    dst_line: &mut LineBuf,
271    bit_depth: u32,
272    is_signed: bool,
273    width: u32,
274);
275
276/// RCT forward/backward (integer buffers).
277pub type RctFn = fn(
278    c0: &LineBuf,
279    c1: &LineBuf,
280    c2: &LineBuf,
281    d0: &mut LineBuf,
282    d1: &mut LineBuf,
283    d2: &mut LineBuf,
284    repeat: u32,
285);
286
287/// ICT forward/backward (float buffers).
288pub type IctFn = fn(
289    c0: &[f32],
290    c1: &[f32],
291    c2: &[f32],
292    d0: &mut [f32],
293    d1: &mut [f32],
294    d2: &mut [f32],
295    repeat: u32,
296);
297
298/// Runtime-dispatched colour transform function table.
299pub struct ColourTransformFns {
300    pub rev_convert: RevConvertFn,
301    pub rev_convert_nlt_type3: RevConvertFn,
302    pub irv_convert_to_integer: IrvConvertToIntegerFn,
303    pub irv_convert_to_float: IrvConvertToFloatFn,
304    pub irv_convert_to_integer_nlt_type3: IrvConvertToIntegerFn,
305    pub irv_convert_to_float_nlt_type3: IrvConvertToFloatFn,
306    pub rct_forward: RctFn,
307    pub rct_backward: RctFn,
308    pub ict_forward: IctFn,
309    pub ict_backward: IctFn,
310}
311
312// =========================================================================
313// Runtime dispatch — OnceLock singletons
314// =========================================================================
315
316static WAVELET_FNS: OnceLock<WaveletTransformFns> = OnceLock::new();
317static COLOUR_FNS: OnceLock<ColourTransformFns> = OnceLock::new();
318
319/// Initializes wavelet transform function pointers (called once, lazily).
320pub fn init_wavelet_transform_functions() -> &'static WaveletTransformFns {
321    WAVELET_FNS.get_or_init(|| {
322        // Start with generic implementations.
323        let mut fns = WaveletTransformFns {
324            rev_vert_step: wavelet::gen_rev_vert_step,
325            rev_horz_ana: wavelet::gen_rev_horz_ana,
326            rev_horz_syn: wavelet::gen_rev_horz_syn,
327            irv_vert_step: wavelet::gen_irv_vert_step,
328            irv_vert_times_k: wavelet::gen_irv_vert_times_k,
329            irv_horz_ana: wavelet::gen_irv_horz_ana,
330            irv_horz_syn: wavelet::gen_irv_horz_syn,
331        };
332
333        // SIMD dispatch: select the best available implementation.
334        #[cfg(target_arch = "aarch64")]
335        {
336            // aarch64 always has NEON
337            fns.rev_vert_step = simd::neon::neon_rev_vert_step;
338            fns.irv_vert_step = simd::neon::neon_irv_vert_step;
339            fns.irv_vert_times_k = simd::neon::neon_irv_vert_times_k;
340        }
341
342        #[cfg(target_arch = "x86_64")]
343        {
344            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
345                fns.rev_vert_step = simd::x86::avx2_rev_vert_step;
346                fns.irv_vert_step = simd::x86::avx2_irv_vert_step;
347                fns.irv_vert_times_k = simd::x86::avx2_irv_vert_times_k;
348            } else if is_x86_feature_detected!("sse2") {
349                fns.rev_vert_step = simd::x86::sse2_rev_vert_step;
350                fns.irv_vert_step = simd::x86::sse2_irv_vert_step;
351                fns.irv_vert_times_k = simd::x86::sse2_irv_vert_times_k;
352            }
353        }
354
355        fns
356    })
357}
358
359/// Initializes colour transform function pointers (called once, lazily).
360pub fn init_colour_transform_functions() -> &'static ColourTransformFns {
361    COLOUR_FNS.get_or_init(|| {
362        // Start with generic implementations.
363        let mut fns = ColourTransformFns {
364            rev_convert: colour::gen_rev_convert,
365            rev_convert_nlt_type3: colour::gen_rev_convert_nlt_type3,
366            irv_convert_to_integer: colour::gen_irv_convert_to_integer,
367            irv_convert_to_float: colour::gen_irv_convert_to_float,
368            irv_convert_to_integer_nlt_type3: colour::gen_irv_convert_to_integer_nlt_type3,
369            irv_convert_to_float_nlt_type3: colour::gen_irv_convert_to_float_nlt_type3,
370            rct_forward: colour::gen_rct_forward,
371            rct_backward: colour::gen_rct_backward,
372            ict_forward: colour::gen_ict_forward,
373            ict_backward: colour::gen_ict_backward,
374        };
375
376        // SIMD dispatch for colour transforms.
377        #[cfg(target_arch = "aarch64")]
378        {
379            fns.rct_forward = simd::neon_colour::neon_rct_forward;
380            fns.rct_backward = simd::neon_colour::neon_rct_backward;
381            fns.ict_forward = simd::neon_colour::neon_ict_forward;
382            fns.ict_backward = simd::neon_colour::neon_ict_backward;
383        }
384
385        #[cfg(target_arch = "x86_64")]
386        {
387            if is_x86_feature_detected!("sse2") {
388                fns.rct_forward = simd::x86_colour::sse2_rct_forward;
389                fns.rct_backward = simd::x86_colour::sse2_rct_backward;
390                fns.ict_forward = simd::x86_colour::sse2_ict_forward;
391                fns.ict_backward = simd::x86_colour::sse2_ict_backward;
392            }
393        }
394
395        fns
396    })
397}
398
399/// Returns a reference to the lazily-initialized wavelet function table.
400#[inline]
401pub fn wavelet_fns() -> &'static WaveletTransformFns {
402    init_wavelet_transform_functions()
403}
404
405/// Returns a reference to the lazily-initialized colour function table.
406#[inline]
407pub fn colour_fns() -> &'static ColourTransformFns {
408    init_colour_transform_functions()
409}
410
#[cfg(test)]
mod tests {
    use super::{LiftingStep, ParamAtk};

    #[test]
    fn rev53_step_order_matches_openjph() {
        let mut kernel = ParamAtk::default();
        kernel.init_rev53();
        assert_eq!(kernel.get_num_steps(), 2);

        // (a, b, e) per step, in stored order.
        let wanted: [(i16, i16, u8); 2] = [(1, 2, 2), (-1, 1, 1)];
        for (idx, want) in (0u32..).zip(wanted) {
            let LiftingStep::Reversible(step) = kernel.get_step(idx) else {
                panic!("expected reversible step {idx}");
            };
            assert_eq!((step.a, step.b, step.e), want);
        }
    }

    #[test]
    fn irv97_step_order_matches_openjph() {
        let mut kernel = ParamAtk::default();
        kernel.init_irv97();
        assert_eq!(kernel.get_num_steps(), 4);

        let coeffs: Vec<f32> = (0..kernel.get_num_steps())
            .map(|idx| match kernel.get_step(idx) {
                LiftingStep::Irreversible(step) => step.a,
                LiftingStep::Reversible(_) => panic!("expected irreversible step"),
            })
            .collect();

        let wanted: [f32; 4] = [0.443_506_85, 0.882_911_1, -0.052_980_118, -1.586_134_3];
        for (have, want) in coeffs.iter().zip(wanted.iter()) {
            assert!((have - want).abs() < 1e-7);
        }
    }
}