Skip to main content

j2k_cuda/
direct_plan.rs

1use std::collections::HashMap;
2
3use j2k_core::PixelFormat;
4use j2k_native::{
5    idwt_band_index, J2kDirectGrayscalePlan, J2kDirectGrayscaleStep, J2kDirectIdwtStep,
6    J2kDirectStoreStep, J2kRect, J2kWaveletTransform,
7};
8
9use crate::Error;
10
/// Reason reported when the native plan contains a classic (non-HT) sub-band step.
const CLASSIC_J2K_NOT_CUDA_HTJ2K: &str =
    "strict CUDA codestream decode only accepts HTJ2K direct-plan subbands";
/// Reason reported when plan conversion ends up with no HT code block at all.
const EMPTY_HTJ2K_PLAN: &str = "strict CUDA HTJ2K plan contains no HT code blocks";
/// Reason reported when two IDWT steps of one plan use different wavelet transforms.
const MIXED_TRANSFORMS_UNSUPPORTED: &str = "strict CUDA HTJ2K plan contains mixed DWT transforms";
/// Reason reported when a size, count, stride or offset does not fit the flat plan's integer
/// fields, or when reserving or accumulating payload space fails.
const PLAN_PAYLOAD_TOO_LARGE: &str = "strict CUDA HTJ2K plan payload is too large";
/// Reason reported when a job's cleanup + refinement lengths disagree with its data length.
const PLAN_BLOCK_LENGTH_MISMATCH: &str =
    "strict CUDA HTJ2K plan block lengths do not match payload bytes";
/// Reason reported when store-step geometry overflows or does not contain the requested
/// output rectangle.
const PLAN_OUTPUT_RECT_MISMATCH: &str =
    "strict CUDA HTJ2K plan store does not fit the requested output rectangle";
/// Reason reported when a code block that would be emitted has a nonzero ROI shift.
const ROI_MAXSHIFT_UNSUPPORTED: &str =
    "strict CUDA HTJ2K plan does not support ROI maxshift decode";
22
/// CUDA-side DWT transform selector for a flat HTJ2K plan.
///
/// The enum is `repr(u32)`; the discriminants are written out so the numeric values
/// are visible at the definition instead of depending on declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum CudaHtj2kTransform {
    /// Reversible 5/3 transform.
    Reversible53 = 0,
    /// Irreversible 9/7 transform.
    Irreversible97 = 1,
}

/// Stable CUDA-side identifier for a direct-plan coefficient band.
pub type CudaHtj2kBandId = u32;
35
impl CudaHtj2kTransform {
    /// Converts the native decoder's wavelet selector into the CUDA-side enum.
    ///
    /// The mapping is variant-for-variant; the `match` is exhaustive so a new native
    /// variant becomes a compile error here rather than a silent fallback.
    pub(crate) fn from_native(value: J2kWaveletTransform) -> Self {
        match value {
            J2kWaveletTransform::Reversible53 => Self::Reversible53,
            J2kWaveletTransform::Irreversible97 => Self::Irreversible97,
        }
    }
}
44
/// Flat POD HTJ2K code-block metadata consumed by CUDA kernels.
///
/// One entry is produced per native HT job that survives plan conversion; geometry and
/// coding parameters are copied from the job unchanged.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct CudaHtj2kCodeBlock {
    /// Index of the parent sub-band in [`CudaHtj2kDecodePlan::subbands`].
    pub subband_index: u32,
    /// Byte offset into [`CudaHtj2kDecodePlan::payload`].
    ///
    /// The plan's payload-rebasing methods add a base to this value, after which it
    /// is an offset into whatever buffer that base refers to.
    pub payload_offset: u64,
    /// Total payload byte length for this code block.
    ///
    /// Plan conversion rejects jobs where this differs from
    /// `cleanup_length + refinement_length`.
    pub payload_len: u32,
    /// Cleanup segment length in bytes.
    pub cleanup_length: u32,
    /// Refinement segment length in bytes.
    pub refinement_length: u32,
    /// X offset within the target sub-band coefficient buffer.
    pub output_x: u32,
    /// Y offset within the target sub-band coefficient buffer.
    pub output_y: u32,
    /// Code-block width in samples.
    pub width: u32,
    /// Code-block height in samples.
    pub height: u32,
    /// Output row stride, in samples.
    pub output_stride: u32,
    /// Missing most-significant bit planes.
    pub missing_bit_planes: u8,
    /// Number of coding passes present.
    pub number_of_coding_passes: u8,
    /// Total coded bitplanes for the parent sub-band.
    pub num_bitplanes: u8,
    /// Nonzero when vertically causal context was enabled.
    ///
    /// Plan conversion writes this from a `bool`, so it is either 0 or 1.
    pub stripe_causal: u8,
    /// Dequantization step to apply to decoded coefficients.
    pub dequantization_step: f32,
}
80
/// Flat POD sub-band geometry consumed by CUDA kernels.
///
/// One entry is produced per HT sub-band step of the native plan, in plan order.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct CudaHtj2kSubband {
    /// Stable CUDA direct-plan band id.
    pub band_id: CudaHtj2kBandId,
    /// Absolute x0 coordinate in component space.
    pub x0: u32,
    /// Absolute y0 coordinate in component space.
    pub y0: u32,
    /// Absolute x1 coordinate in component space.
    pub x1: u32,
    /// Absolute y1 coordinate in component space.
    pub y1: u32,
    /// Sub-band width in samples.
    pub width: u32,
    /// Sub-band height in samples.
    pub height: u32,
    /// First code-block index for this sub-band.
    ///
    /// Indexes into [`CudaHtj2kDecodePlan::code_blocks`].
    pub code_block_start: u32,
    /// Number of code blocks for this sub-band.
    ///
    /// Counts only the code blocks actually emitted; jobs skipped by region filtering
    /// during plan conversion are not included, so this may be zero.
    pub code_block_count: u32,
}
104
/// Flat POD IDWT step consumed by CUDA kernels.
///
/// Describes one inverse-DWT step: the four input bands (LL, HL, LH, HH), each with
/// its band id and rectangle, and the output band they combine into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct CudaHtj2kIdwtStep {
    /// Stable identifier of the output coefficient band produced by this step.
    pub output_band_id: CudaHtj2kBandId,
    /// DWT transform to apply.
    ///
    /// Plan conversion rejects plans whose IDWT steps disagree on the transform, so
    /// within one [`CudaHtj2kDecodePlan`] every step carries the same value.
    pub transform: CudaHtj2kTransform,
    /// Output rectangle.
    pub rect: CudaHtj2kRect,
    /// LL input band id.
    pub ll_band_id: CudaHtj2kBandId,
    /// LL input rectangle.
    pub ll_rect: CudaHtj2kRect,
    /// HL input band id.
    pub hl_band_id: CudaHtj2kBandId,
    /// HL input rectangle.
    pub hl_rect: CudaHtj2kRect,
    /// LH input band id.
    pub lh_band_id: CudaHtj2kBandId,
    /// LH input rectangle.
    pub lh_rect: CudaHtj2kRect,
    /// HH input band id.
    pub hh_band_id: CudaHtj2kBandId,
    /// HH input rectangle.
    pub hh_rect: CudaHtj2kRect,
}
132
/// Flat POD store step consumed by CUDA kernels.
///
/// As produced by plan conversion, a store step is already cropped to the requested
/// output region: `copy_width`/`copy_height` and `output_width`/`output_height` both
/// equal the requested output dimensions, `output_x`/`output_y` are zero, and
/// `source_x`/`source_y` are shifted by the requested origin relative to the native
/// store's destination offset.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct CudaHtj2kStoreStep {
    /// Stable identifier of the input coefficient band.
    pub input_band_id: CudaHtj2kBandId,
    /// Source rectangle.
    pub input_rect: CudaHtj2kRect,
    /// Source x offset.
    pub source_x: u32,
    /// Source y offset.
    pub source_y: u32,
    /// Number of samples copied per row.
    pub copy_width: u32,
    /// Number of rows copied.
    pub copy_height: u32,
    /// Destination row width.
    pub output_width: u32,
    /// Destination height.
    pub output_height: u32,
    /// Destination x offset.
    pub output_x: u32,
    /// Destination y offset.
    pub output_y: u32,
    /// Constant level-shift addend.
    pub addend: f32,
}
160
/// Flat POD rectangle used inside CUDA HTJ2K plan metadata.
///
/// Half-open on both axes: `x0`/`y0` are inclusive, `x1`/`y1` are exclusive. Values
/// are copied field-for-field from the native `J2kRect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct CudaHtj2kRect {
    /// Inclusive left coordinate.
    pub x0: u32,
    /// Inclusive top coordinate.
    pub y0: u32,
    /// Exclusive right coordinate.
    pub x1: u32,
    /// Exclusive bottom coordinate.
    pub y1: u32,
}
174
/// Flat CUDA HTJ2K decode plan.
///
/// Built from a native direct grayscale plan; all tables are plain vectors of POD
/// records plus one contiguous payload byte buffer.
#[derive(Debug, Clone)]
pub struct CudaHtj2kDecodePlan {
    // Output dimensions requested when the plan was built (not necessarily the full image).
    dimensions: (u32, u32),
    // Bit depth copied from the native plan.
    bit_depth: u8,
    // Pixel format requested by the caller; stored as given.
    output_format: PixelFormat,
    // Output origin requested by the caller; stored as given.
    output_origin: (u32, u32),
    // Transform shared by every IDWT step; `Reversible53` when the plan has no IDWT step.
    transform: CudaHtj2kTransform,
    // Code-block bytes of all emitted jobs, concatenated in plan order. Left empty once
    // the bytes have been moved into a shared buffer.
    payload: Vec<u8>,
    // One record per emitted code block, grouped by sub-band in plan order.
    code_blocks: Vec<CudaHtj2kCodeBlock>,
    // One record per HT sub-band step of the native plan.
    subbands: Vec<CudaHtj2kSubband>,
    // IDWT steps in native plan order.
    idwt_steps: Vec<CudaHtj2kIdwtStep>,
    // Store steps in native plan order, cropped to the requested output region.
    store_steps: Vec<CudaHtj2kStoreStep>,
}
189
190impl CudaHtj2kDecodePlan {
191    pub(crate) fn from_grayscale_direct_plan(
192        plan: &J2kDirectGrayscalePlan,
193        output_format: PixelFormat,
194        output_origin: (u32, u32),
195    ) -> Result<Self, Error> {
196        Self::from_grayscale_direct_plan_region(plan, output_format, output_origin, plan.dimensions)
197    }
198
199    pub(crate) fn from_grayscale_direct_plan_region(
200        plan: &J2kDirectGrayscalePlan,
201        output_format: PixelFormat,
202        output_origin: (u32, u32),
203        output_dimensions: (u32, u32),
204    ) -> Result<Self, Error> {
205        let capacity_hint = cuda_plan_capacity_hint(plan)?;
206        let mut payload = Vec::with_capacity(capacity_hint.payload_bytes);
207        let mut code_blocks = Vec::with_capacity(capacity_hint.code_blocks);
208        let mut subbands = Vec::with_capacity(capacity_hint.subbands);
209        let mut idwt_steps = Vec::with_capacity(capacity_hint.idwt_steps);
210        let mut store_steps = Vec::with_capacity(capacity_hint.store_steps);
211        let mut transform = None;
212        let mut saw_classic = false;
213        let required_regions = if output_origin == (0, 0) && output_dimensions == plan.dimensions {
214            None
215        } else {
216            Some(required_regions_for_direct_plan(plan)?)
217        };
218
219        for step in &plan.steps {
220            match step {
221                J2kDirectGrayscaleStep::HtSubBand(subband) => {
222                    let subband_index = u32::try_from(subbands.len()).map_err(|_| {
223                        Error::UnsupportedCudaRequest {
224                            reason: PLAN_PAYLOAD_TOO_LARGE,
225                        }
226                    })?;
227                    let code_block_start = u32::try_from(code_blocks.len()).map_err(|_| {
228                        Error::UnsupportedCudaRequest {
229                            reason: PLAN_PAYLOAD_TOO_LARGE,
230                        }
231                    })?;
232                    for job in &subband.jobs {
233                        let payload_offset = u64::try_from(payload.len()).map_err(|_| {
234                            Error::UnsupportedCudaRequest {
235                                reason: PLAN_PAYLOAD_TOO_LARGE,
236                            }
237                        })?;
238                        let payload_len = u32::try_from(job.data.len()).map_err(|_| {
239                            Error::UnsupportedCudaRequest {
240                                reason: PLAN_PAYLOAD_TOO_LARGE,
241                            }
242                        })?;
243                        let expected_len = job
244                            .cleanup_length
245                            .checked_add(job.refinement_length)
246                            .ok_or(Error::UnsupportedCudaRequest {
247                                reason: PLAN_BLOCK_LENGTH_MISMATCH,
248                            })?;
249                        if expected_len != payload_len {
250                            return Err(Error::UnsupportedCudaRequest {
251                                reason: PLAN_BLOCK_LENGTH_MISMATCH,
252                            });
253                        }
254                        let output_stride = u32::try_from(job.output_stride).map_err(|_| {
255                            Error::UnsupportedCudaRequest {
256                                reason: PLAN_PAYLOAD_TOO_LARGE,
257                            }
258                        })?;
259                        if let Some(required_regions) = &required_regions {
260                            if !required_regions
261                                .get(&subband.band_id)
262                                .is_some_and(|required| {
263                                    required.intersects(
264                                        job.output_x,
265                                        job.output_y,
266                                        job.width,
267                                        job.height,
268                                    )
269                                })
270                            {
271                                continue;
272                            }
273                        }
274                        if job.roi_shift != 0 {
275                            return Err(Error::UnsupportedCudaRequest {
276                                reason: ROI_MAXSHIFT_UNSUPPORTED,
277                            });
278                        }
279                        payload.extend_from_slice(&job.data);
280                        code_blocks.push(CudaHtj2kCodeBlock {
281                            subband_index,
282                            payload_offset,
283                            payload_len,
284                            cleanup_length: job.cleanup_length,
285                            refinement_length: job.refinement_length,
286                            output_x: job.output_x,
287                            output_y: job.output_y,
288                            width: job.width,
289                            height: job.height,
290                            output_stride,
291                            missing_bit_planes: job.missing_bit_planes,
292                            number_of_coding_passes: job.number_of_coding_passes,
293                            num_bitplanes: job.num_bitplanes,
294                            stripe_causal: u8::from(job.stripe_causal),
295                            dequantization_step: job.dequantization_step,
296                        });
297                    }
298                    let code_block_count = u32::try_from(
299                        code_blocks.len() - code_block_start as usize,
300                    )
301                    .map_err(|_| Error::UnsupportedCudaRequest {
302                        reason: PLAN_PAYLOAD_TOO_LARGE,
303                    })?;
304                    subbands.push(CudaHtj2kSubband {
305                        band_id: subband.band_id,
306                        x0: subband.rect.x0,
307                        y0: subband.rect.y0,
308                        x1: subband.rect.x1,
309                        y1: subband.rect.y1,
310                        width: subband.width,
311                        height: subband.height,
312                        code_block_start,
313                        code_block_count,
314                    });
315                }
316                J2kDirectGrayscaleStep::ClassicSubBand(_) => saw_classic = true,
317                J2kDirectGrayscaleStep::Idwt(step) => {
318                    let step_transform = CudaHtj2kTransform::from_native(step.transform);
319                    match transform {
320                        Some(existing) if existing != step_transform => {
321                            return Err(Error::UnsupportedCudaRequest {
322                                reason: MIXED_TRANSFORMS_UNSUPPORTED,
323                            });
324                        }
325                        Some(_) => {}
326                        None => transform = Some(step_transform),
327                    }
328                    idwt_steps.push(convert_idwt_step(*step));
329                }
330                J2kDirectGrayscaleStep::Store(step) => {
331                    store_steps.push(convert_store_step(*step, output_origin, output_dimensions)?);
332                }
333            }
334        }
335
336        if saw_classic {
337            return Err(Error::UnsupportedCudaRequest {
338                reason: CLASSIC_J2K_NOT_CUDA_HTJ2K,
339            });
340        }
341        if code_blocks.is_empty() {
342            return Err(Error::UnsupportedCudaRequest {
343                reason: EMPTY_HTJ2K_PLAN,
344            });
345        }
346
347        Ok(Self {
348            dimensions: output_dimensions,
349            bit_depth: plan.bit_depth,
350            output_format,
351            output_origin,
352            transform: transform.unwrap_or(CudaHtj2kTransform::Reversible53),
353            payload,
354            code_blocks,
355            subbands,
356            idwt_steps,
357            store_steps,
358        })
359    }
360
    /// Output dimensions of the decoded surface.
    ///
    /// This is the output size requested when the plan was built.
    pub fn dimensions(&self) -> (u32, u32) {
        self.dimensions
    }

    /// Source component bit depth.
    pub fn bit_depth(&self) -> u8 {
        self.bit_depth
    }

    /// Output pixel format requested by the caller.
    pub fn output_format(&self) -> PixelFormat {
        self.output_format
    }

    /// Destination origin in the caller-visible output surface.
    pub fn output_origin(&self) -> (u32, u32) {
        self.output_origin
    }

    /// DWT transform used by IDWT kernels.
    ///
    /// Reports `Reversible53` for a plan that contains no IDWT step.
    pub fn transform(&self) -> CudaHtj2kTransform {
        self.transform
    }

    /// Contiguous cleanup/refinement payload bytes.
    ///
    /// Empty once the payload has been moved into a shared buffer.
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }
390
391    #[cfg_attr(not(feature = "cuda-runtime"), allow(dead_code))]
392    pub(crate) fn append_payload_to_shared(
393        &mut self,
394        shared_payload: &mut Vec<u8>,
395    ) -> Result<(), Error> {
396        let base =
397            u64::try_from(shared_payload.len()).map_err(|_| Error::UnsupportedCudaRequest {
398                reason: PLAN_PAYLOAD_TOO_LARGE,
399            })?;
400        shared_payload
401            .try_reserve(self.payload.len())
402            .map_err(|_| Error::UnsupportedCudaRequest {
403                reason: PLAN_PAYLOAD_TOO_LARGE,
404            })?;
405        for block in &mut self.code_blocks {
406            block.payload_offset =
407                block
408                    .payload_offset
409                    .checked_add(base)
410                    .ok_or(Error::UnsupportedCudaRequest {
411                        reason: PLAN_PAYLOAD_TOO_LARGE,
412                    })?;
413        }
414        shared_payload.append(&mut self.payload);
415        Ok(())
416    }
417
418    #[cfg_attr(not(feature = "cuda-runtime"), allow(dead_code))]
419    pub(crate) fn rebase_payload_offsets(&mut self, base: u64) -> Result<(), Error> {
420        for block in &mut self.code_blocks {
421            block.payload_offset =
422                block
423                    .payload_offset
424                    .checked_add(base)
425                    .ok_or(Error::UnsupportedCudaRequest {
426                        reason: PLAN_PAYLOAD_TOO_LARGE,
427                    })?;
428        }
429        Ok(())
430    }
431
    /// Flat code-block metadata.
    ///
    /// Blocks of one sub-band are contiguous; see
    /// [`CudaHtj2kSubband::code_block_start`] and
    /// [`CudaHtj2kSubband::code_block_count`].
    pub fn code_blocks(&self) -> &[CudaHtj2kCodeBlock] {
        &self.code_blocks
    }

    /// Flat sub-band metadata.
    pub fn subbands(&self) -> &[CudaHtj2kSubband] {
        &self.subbands
    }

    /// Flat IDWT step metadata.
    pub fn idwt_steps(&self) -> &[CudaHtj2kIdwtStep] {
        &self.idwt_steps
    }

    /// Flat store step metadata.
    pub fn store_steps(&self) -> &[CudaHtj2kStoreStep] {
        &self.store_steps
    }

    /// Number of per-code-block decode dispatches implied by the plan.
    ///
    /// Currently equal to the number of code blocks in the plan.
    pub fn dispatch_count_hint(&self) -> usize {
        self.code_blocks.len()
    }
456}
457
/// Upper-bound element counts used to pre-size the vectors of a flat CUDA plan.
///
/// The counts cover the whole native plan; region filtering may later emit fewer
/// code blocks and payload bytes than reserved.
#[derive(Debug, Default)]
struct CudaPlanCapacityHint {
    // Sum of `data.len()` over every job of every HT sub-band.
    payload_bytes: usize,
    // Total number of jobs across all HT sub-bands.
    code_blocks: usize,
    // Number of HT sub-band steps.
    subbands: usize,
    // Number of IDWT steps.
    idwt_steps: usize,
    // Number of store steps.
    store_steps: usize,
}
466
467fn cuda_plan_capacity_hint(plan: &J2kDirectGrayscalePlan) -> Result<CudaPlanCapacityHint, Error> {
468    let mut hint = CudaPlanCapacityHint::default();
469    for step in &plan.steps {
470        match step {
471            J2kDirectGrayscaleStep::HtSubBand(subband) => {
472                hint.subbands = hint.subbands.saturating_add(1);
473                hint.code_blocks = hint.code_blocks.checked_add(subband.jobs.len()).ok_or(
474                    Error::UnsupportedCudaRequest {
475                        reason: PLAN_PAYLOAD_TOO_LARGE,
476                    },
477                )?;
478                for job in &subband.jobs {
479                    hint.payload_bytes = hint.payload_bytes.checked_add(job.data.len()).ok_or(
480                        Error::UnsupportedCudaRequest {
481                            reason: PLAN_PAYLOAD_TOO_LARGE,
482                        },
483                    )?;
484                }
485            }
486            J2kDirectGrayscaleStep::ClassicSubBand(_) => {}
487            J2kDirectGrayscaleStep::Idwt(_) => {
488                hint.idwt_steps = hint.idwt_steps.saturating_add(1);
489            }
490            J2kDirectGrayscaleStep::Store(_) => {
491                hint.store_steps = hint.store_steps.saturating_add(1);
492            }
493        }
494    }
495    Ok(hint)
496}
497
498fn convert_idwt_step(step: J2kDirectIdwtStep) -> CudaHtj2kIdwtStep {
499    CudaHtj2kIdwtStep {
500        output_band_id: step.output_band_id,
501        transform: CudaHtj2kTransform::from_native(step.transform),
502        rect: convert_rect(step.rect),
503        ll_band_id: step.ll_band_id,
504        ll_rect: convert_rect(step.ll),
505        hl_band_id: step.hl_band_id,
506        hl_rect: convert_rect(step.hl),
507        lh_band_id: step.lh_band_id,
508        lh_rect: convert_rect(step.lh),
509        hh_band_id: step.hh_band_id,
510        hh_rect: convert_rect(step.hh),
511    }
512}
513
/// Half-open rectangle (`x0..x1`, `y0..y1`) of a band that must be available.
#[derive(Clone, Copy, Debug)]
struct RequiredBandRegion {
    x0: u32,
    y0: u32,
    x1: u32,
    y1: u32,
}

impl RequiredBandRegion {
    /// Builds a region, or returns `None` when it would be empty on either axis.
    fn new(x0: u32, y0: u32, x1: u32, y1: u32) -> Option<Self> {
        if x0 >= x1 || y0 >= y1 {
            return None;
        }
        Some(Self { x0, y0, x1, y1 })
    }

    /// Grows the region by `margin` on every side.
    ///
    /// The low edges saturate at zero; the high edges saturate on overflow and are
    /// then capped at `width` / `height`.
    fn expanded(self, margin: u32, width: u32, height: u32) -> Self {
        let Self { x0, y0, x1, y1 } = self;
        Self {
            x0: x0.saturating_sub(margin),
            y0: y0.saturating_sub(margin),
            x1: width.min(x1.saturating_add(margin)),
            y1: height.min(y1.saturating_add(margin)),
        }
    }

    /// Returns the bounding box of `self` and `other`.
    const fn union(self, other: Self) -> Self {
        // Local const helpers: the comparisons must be usable in a `const fn`.
        const fn lesser(a: u32, b: u32) -> u32 {
            if a < b {
                a
            } else {
                b
            }
        }
        const fn greater(a: u32, b: u32) -> u32 {
            if a > b {
                a
            } else {
                b
            }
        }
        Self {
            x0: lesser(self.x0, other.x0),
            y0: lesser(self.y0, other.y0),
            x1: greater(self.x1, other.x1),
            y1: greater(self.y1, other.y1),
        }
    }

    /// Reports whether the rectangle at (`x0`, `y0`) of size `width` x `height`
    /// overlaps this region. The rectangle's far edges saturate instead of wrapping.
    fn intersects(self, x0: u32, y0: u32, width: u32, height: u32) -> bool {
        let right = x0.saturating_add(width);
        let bottom = y0.saturating_add(height);
        let overlaps_x = self.x0 < right && x0 < self.x1;
        let overlaps_y = self.y0 < bottom && y0 < self.y1;
        overlaps_x && overlaps_y
    }
}
567
568fn required_regions_for_direct_plan(
569    plan: &J2kDirectGrayscalePlan,
570) -> Result<HashMap<CudaHtj2kBandId, RequiredBandRegion>, Error> {
571    let mut required = HashMap::<CudaHtj2kBandId, RequiredBandRegion>::new();
572    for step in &plan.steps {
573        let J2kDirectGrayscaleStep::Store(store) = step else {
574            continue;
575        };
576        let source_right =
577            store
578                .source_x
579                .checked_add(store.copy_width)
580                .ok_or(Error::UnsupportedCudaRequest {
581                    reason: PLAN_OUTPUT_RECT_MISMATCH,
582                })?;
583        let source_bottom =
584            store
585                .source_y
586                .checked_add(store.copy_height)
587                .ok_or(Error::UnsupportedCudaRequest {
588                    reason: PLAN_OUTPUT_RECT_MISMATCH,
589                })?;
590        if let Some(region) =
591            RequiredBandRegion::new(store.source_x, store.source_y, source_right, source_bottom)
592        {
593            add_required_region(&mut required, store.input_band_id, region);
594        }
595    }
596
597    for step in plan.steps.iter().rev() {
598        let J2kDirectGrayscaleStep::Idwt(idwt) = step else {
599            continue;
600        };
601        let Some(output_region) = required.get(&idwt.output_band_id).copied() else {
602            continue;
603        };
604        let expanded = output_region.expanded(
605            idwt_required_output_margin(idwt.transform),
606            idwt.rect.width(),
607            idwt.rect.height(),
608        );
609        add_idwt_input_required_regions(&mut required, idwt, expanded);
610    }
611    Ok(required)
612}
613
614fn add_required_region(
615    required: &mut HashMap<CudaHtj2kBandId, RequiredBandRegion>,
616    band_id: CudaHtj2kBandId,
617    region: RequiredBandRegion,
618) {
619    required
620        .entry(band_id)
621        .and_modify(|existing| *existing = existing.union(region))
622        .or_insert(region);
623}
624
625const fn idwt_required_output_margin(transform: J2kWaveletTransform) -> u32 {
626    match transform {
627        J2kWaveletTransform::Reversible53 => 16,
628        J2kWaveletTransform::Irreversible97 => 40,
629    }
630}
631
632fn add_idwt_input_required_regions(
633    required: &mut HashMap<CudaHtj2kBandId, RequiredBandRegion>,
634    idwt: &J2kDirectIdwtStep,
635    output_region: RequiredBandRegion,
636) {
637    add_required_region(
638        required,
639        idwt.ll_band_id,
640        idwt_input_required_region(
641            output_region,
642            idwt.rect.x0,
643            idwt.rect.y0,
644            true,
645            true,
646            idwt.ll.width(),
647            idwt.ll.height(),
648        ),
649    );
650    add_required_region(
651        required,
652        idwt.hl_band_id,
653        idwt_input_required_region(
654            output_region,
655            idwt.rect.x0,
656            idwt.rect.y0,
657            false,
658            true,
659            idwt.hl.width(),
660            idwt.hl.height(),
661        ),
662    );
663    add_required_region(
664        required,
665        idwt.lh_band_id,
666        idwt_input_required_region(
667            output_region,
668            idwt.rect.x0,
669            idwt.rect.y0,
670            true,
671            false,
672            idwt.lh.width(),
673            idwt.lh.height(),
674        ),
675    );
676    add_required_region(
677        required,
678        idwt.hh_band_id,
679        idwt_input_required_region(
680            output_region,
681            idwt.rect.x0,
682            idwt.rect.y0,
683            false,
684            false,
685            idwt.hh.width(),
686            idwt.hh.height(),
687        ),
688    );
689}
690
691#[allow(clippy::fn_params_excessive_bools)]
692fn idwt_input_required_region(
693    output_region: RequiredBandRegion,
694    output_origin_x: u32,
695    output_origin_y: u32,
696    low_x: bool,
697    low_y: bool,
698    band_width: u32,
699    band_height: u32,
700) -> RequiredBandRegion {
701    let x0 = idwt_band_index(output_origin_x, output_region.x0, low_x);
702    let x1 = idwt_band_index(output_origin_x, output_region.x1 - 1, low_x).saturating_add(1);
703    let y0 = idwt_band_index(output_origin_y, output_region.y0, low_y);
704    let y1 = idwt_band_index(output_origin_y, output_region.y1 - 1, low_y).saturating_add(1);
705    RequiredBandRegion {
706        x0: x0.min(band_width),
707        y0: y0.min(band_height),
708        x1: x1.min(band_width),
709        y1: y1.min(band_height),
710    }
711}
712
713fn convert_store_step(
714    step: J2kDirectStoreStep,
715    output_origin: (u32, u32),
716    output_dimensions: (u32, u32),
717) -> Result<CudaHtj2kStoreStep, Error> {
718    if output_dimensions.0 == 0 || output_dimensions.1 == 0 {
719        return Err(Error::UnsupportedCudaRequest {
720            reason: PLAN_OUTPUT_RECT_MISMATCH,
721        });
722    }
723    let region_end_x =
724        output_origin
725            .0
726            .checked_add(output_dimensions.0)
727            .ok_or(Error::UnsupportedCudaRequest {
728                reason: PLAN_OUTPUT_RECT_MISMATCH,
729            })?;
730    let region_end_y =
731        output_origin
732            .1
733            .checked_add(output_dimensions.1)
734            .ok_or(Error::UnsupportedCudaRequest {
735                reason: PLAN_OUTPUT_RECT_MISMATCH,
736            })?;
737    let store_end_x =
738        step.output_x
739            .checked_add(step.copy_width)
740            .ok_or(Error::UnsupportedCudaRequest {
741                reason: PLAN_OUTPUT_RECT_MISMATCH,
742            })?;
743    let store_end_y =
744        step.output_y
745            .checked_add(step.copy_height)
746            .ok_or(Error::UnsupportedCudaRequest {
747                reason: PLAN_OUTPUT_RECT_MISMATCH,
748            })?;
749    if output_origin.0 < step.output_x
750        || output_origin.1 < step.output_y
751        || region_end_x > store_end_x
752        || region_end_y > store_end_y
753    {
754        return Err(Error::UnsupportedCudaRequest {
755            reason: PLAN_OUTPUT_RECT_MISMATCH,
756        });
757    }
758    let source_x = step
759        .source_x
760        .checked_add(output_origin.0 - step.output_x)
761        .ok_or(Error::UnsupportedCudaRequest {
762            reason: PLAN_OUTPUT_RECT_MISMATCH,
763        })?;
764    let source_y = step
765        .source_y
766        .checked_add(output_origin.1 - step.output_y)
767        .ok_or(Error::UnsupportedCudaRequest {
768            reason: PLAN_OUTPUT_RECT_MISMATCH,
769        })?;
770    Ok(CudaHtj2kStoreStep {
771        input_band_id: step.input_band_id,
772        input_rect: convert_rect(step.input_rect),
773        source_x,
774        source_y,
775        copy_width: output_dimensions.0,
776        copy_height: output_dimensions.1,
777        output_width: output_dimensions.0,
778        output_height: output_dimensions.1,
779        output_x: 0,
780        output_y: 0,
781        addend: step.addend,
782    })
783}
784
785fn convert_rect(rect: J2kRect) -> CudaHtj2kRect {
786    CudaHtj2kRect {
787        x0: rect.x0,
788        y0: rect.y0,
789        x1: rect.x1,
790        y1: rect.y1,
791    }
792}
793
794#[cfg(test)]
795mod tests {
796    use super::*;
797    use j2k_core::CodecError;
798    use j2k_native::{HtOwnedCodeBlockBatchJob, HtOwnedSubBandPlan};
799
    /// Builds a minimal native plan fixture: a 1x1, 8-bit image with one HT sub-band
    /// (band 0) holding a single 1x1 job, followed by one store step that copies that
    /// sample from band 0 with an addend of 128.0.
    ///
    /// The segment lengths, job bytes and output stride are supplied by the caller so
    /// tests can make them consistent or deliberately inconsistent.
    fn one_block_direct_plan(
        cleanup_length: u32,
        refinement_length: u32,
        data: Vec<u8>,
        output_stride: usize,
    ) -> J2kDirectGrayscalePlan {
        J2kDirectGrayscalePlan {
            dimensions: (1, 1),
            bit_depth: 8,
            steps: vec![
                J2kDirectGrayscaleStep::HtSubBand(HtOwnedSubBandPlan {
                    band_id: 0,
                    rect: J2kRect {
                        x0: 0,
                        y0: 0,
                        x1: 1,
                        y1: 1,
                    },
                    width: 1,
                    height: 1,
                    jobs: vec![HtOwnedCodeBlockBatchJob {
                        output_x: 0,
                        output_y: 0,
                        data,
                        cleanup_length,
                        refinement_length,
                        width: 1,
                        height: 1,
                        output_stride,
                        missing_bit_planes: 0,
                        number_of_coding_passes: 1,
                        num_bitplanes: 8,
                        roi_shift: 0,
                        stripe_causal: false,
                        strict: true,
                        dequantization_step: 1.0,
                    }],
                }),
                J2kDirectGrayscaleStep::Store(J2kDirectStoreStep {
                    input_band_id: 0,
                    input_rect: J2kRect {
                        x0: 0,
                        y0: 0,
                        x1: 1,
                        y1: 1,
                    },
                    source_x: 0,
                    source_y: 0,
                    copy_width: 1,
                    copy_height: 1,
                    output_width: 1,
                    output_height: 1,
                    output_x: 0,
                    output_y: 0,
                    addend: 128.0,
                }),
            ],
        }
    }
859
    /// Builds a valid single-block CUDA plan whose payload is `data`.
    ///
    /// The whole payload is declared as the cleanup segment (refinement length 0) with
    /// an output stride of 1, and the plan is converted for `Gray8` output at origin
    /// `(0, 0)`. Panics if the fixture cannot be converted.
    fn one_block_plan(data: Vec<u8>) -> CudaHtj2kDecodePlan {
        let payload_len = u32::try_from(data.len()).expect("fixture payload length");
        let direct = one_block_direct_plan(payload_len, 0, data, 1);
        CudaHtj2kDecodePlan::from_grayscale_direct_plan(&direct, PixelFormat::Gray8, (0, 0))
            .expect("CUDA plan")
    }
866
867    fn two_block_direct_plan() -> J2kDirectGrayscalePlan {
868        J2kDirectGrayscalePlan {
869            dimensions: (2, 1),
870            bit_depth: 8,
871            steps: vec![
872                J2kDirectGrayscaleStep::HtSubBand(HtOwnedSubBandPlan {
873                    band_id: 0,
874                    rect: J2kRect {
875                        x0: 0,
876                        y0: 0,
877                        x1: 2,
878                        y1: 1,
879                    },
880                    width: 2,
881                    height: 1,
882                    jobs: vec![
883                        HtOwnedCodeBlockBatchJob {
884                            output_x: 0,
885                            output_y: 0,
886                            data: vec![1],
887                            cleanup_length: 1,
888                            refinement_length: 0,
889                            width: 1,
890                            height: 1,
891                            output_stride: 2,
892                            missing_bit_planes: 0,
893                            number_of_coding_passes: 1,
894                            num_bitplanes: 8,
895                            roi_shift: 0,
896                            stripe_causal: false,
897                            strict: true,
898                            dequantization_step: 1.0,
899                        },
900                        HtOwnedCodeBlockBatchJob {
901                            output_x: 1,
902                            output_y: 0,
903                            data: vec![2],
904                            cleanup_length: 1,
905                            refinement_length: 0,
906                            width: 1,
907                            height: 1,
908                            output_stride: 2,
909                            missing_bit_planes: 0,
910                            number_of_coding_passes: 1,
911                            num_bitplanes: 8,
912                            roi_shift: 0,
913                            stripe_causal: false,
914                            strict: true,
915                            dequantization_step: 1.0,
916                        },
917                    ],
918                }),
919                J2kDirectGrayscaleStep::Store(J2kDirectStoreStep {
920                    input_band_id: 0,
921                    input_rect: J2kRect {
922                        x0: 0,
923                        y0: 0,
924                        x1: 2,
925                        y1: 1,
926                    },
927                    source_x: 0,
928                    source_y: 0,
929                    copy_width: 2,
930                    copy_height: 1,
931                    output_width: 2,
932                    output_height: 1,
933                    output_x: 0,
934                    output_y: 0,
935                    addend: 128.0,
936                }),
937            ],
938        }
939    }
940
941    #[test]
942    fn append_payload_to_shared_offsets_blocks_and_drains_local_payload() {
943        let mut first = one_block_plan(vec![1, 2]);
944        let mut second = one_block_plan(vec![3, 4, 5]);
945        let mut shared = Vec::new();
946
947        first
948            .append_payload_to_shared(&mut shared)
949            .expect("append first payload");
950        second
951            .append_payload_to_shared(&mut shared)
952            .expect("append second payload");
953
954        assert_eq!(shared, vec![1, 2, 3, 4, 5]);
955        assert!(first.payload().is_empty());
956        assert!(second.payload().is_empty());
957        assert_eq!(first.code_blocks()[0].payload_offset, 0);
958        assert_eq!(second.code_blocks()[0].payload_offset, 2);
959    }
960
961    #[test]
962    fn rebase_payload_offsets_preserves_shared_payload_for_larger_batch() {
963        let mut plan = one_block_plan(vec![7, 8]);
964        let mut shared = Vec::new();
965        plan.append_payload_to_shared(&mut shared)
966            .expect("append local payload");
967
968        plan.rebase_payload_offsets(4096).expect("rebase payload");
969
970        assert_eq!(shared, vec![7, 8]);
971        assert_eq!(plan.code_blocks()[0].payload_offset, 4096);
972    }
973
974    #[test]
975    fn full_frame_plan_keeps_all_blocks_while_region_plan_prunes() {
976        let direct = two_block_direct_plan();
977        let full =
978            CudaHtj2kDecodePlan::from_grayscale_direct_plan(&direct, PixelFormat::Gray8, (0, 0))
979                .expect("full CUDA plan");
980        let mut region_direct = two_block_direct_plan();
981        let J2kDirectGrayscaleStep::Store(store) = &mut region_direct.steps[1] else {
982            panic!("expected store fixture");
983        };
984        store.source_x = 1;
985        store.copy_width = 1;
986        store.output_x = 1;
987        let region = CudaHtj2kDecodePlan::from_grayscale_direct_plan_region(
988            &region_direct,
989            PixelFormat::Gray8,
990            (1, 0),
991            (1, 1),
992        )
993        .expect("region CUDA plan");
994
995        assert_eq!(full.code_blocks().len(), 2);
996        assert_eq!(region.code_blocks().len(), 1);
997        assert_eq!(region.code_blocks()[0].output_x, 1);
998    }
999
1000    #[test]
1001    fn rejects_block_length_mismatch() {
1002        let direct = one_block_direct_plan(1, 2, vec![0xAA, 0xBB], 1);
1003
1004        let error =
1005            CudaHtj2kDecodePlan::from_grayscale_direct_plan(&direct, PixelFormat::Gray8, (0, 0))
1006                .expect_err("mismatched cleanup/refinement lengths must be rejected");
1007
1008        assert!(error.is_unsupported());
1009        assert!(
1010            error
1011                .to_string()
1012                .contains("block lengths do not match payload bytes"),
1013            "unexpected error: {error}"
1014        );
1015    }
1016
1017    #[test]
1018    fn rejects_roi_maxshift_jobs() {
1019        let mut direct = one_block_direct_plan(1, 0, vec![0xAA], 1);
1020        let J2kDirectGrayscaleStep::HtSubBand(subband) = &mut direct.steps[0] else {
1021            panic!("fixture starts with one HT sub-band");
1022        };
1023        subband.jobs[0].roi_shift = 7;
1024
1025        let error =
1026            CudaHtj2kDecodePlan::from_grayscale_direct_plan(&direct, PixelFormat::Gray8, (0, 0))
1027                .expect_err("ROI maxshift jobs must be rejected");
1028
1029        assert!(error.is_unsupported());
1030        assert!(
1031            error.to_string().contains("ROI maxshift decode"),
1032            "unexpected error: {error}"
1033        );
1034    }
1035
1036    #[test]
1037    fn rejects_output_stride_overflow() {
1038        let direct = one_block_direct_plan(1, 0, vec![0xAA], usize::MAX);
1039
1040        let error =
1041            CudaHtj2kDecodePlan::from_grayscale_direct_plan(&direct, PixelFormat::Gray8, (0, 0))
1042                .expect_err("unrepresentable output stride must be rejected");
1043
1044        assert!(error.is_unsupported());
1045    }
1046
1047    #[test]
1048    fn rejects_mixed_idwt_transforms() {
1049        let mut direct = one_block_direct_plan(1, 0, vec![0xAA], 1);
1050        let rect = J2kRect {
1051            x0: 0,
1052            y0: 0,
1053            x1: 1,
1054            y1: 1,
1055        };
1056        direct.steps.insert(
1057            1,
1058            J2kDirectGrayscaleStep::Idwt(J2kDirectIdwtStep {
1059                output_band_id: 4,
1060                rect,
1061                transform: J2kWaveletTransform::Reversible53,
1062                ll_band_id: 0,
1063                ll: rect,
1064                hl_band_id: 1,
1065                hl: rect,
1066                lh_band_id: 2,
1067                lh: rect,
1068                hh_band_id: 3,
1069                hh: rect,
1070            }),
1071        );
1072        direct.steps.insert(
1073            2,
1074            J2kDirectGrayscaleStep::Idwt(J2kDirectIdwtStep {
1075                output_band_id: 8,
1076                rect,
1077                transform: J2kWaveletTransform::Irreversible97,
1078                ll_band_id: 4,
1079                ll: rect,
1080                hl_band_id: 5,
1081                hl: rect,
1082                lh_band_id: 6,
1083                lh: rect,
1084                hh_band_id: 7,
1085                hh: rect,
1086            }),
1087        );
1088
1089        let error =
1090            CudaHtj2kDecodePlan::from_grayscale_direct_plan(&direct, PixelFormat::Gray8, (0, 0))
1091                .expect_err("mixed transforms must be rejected");
1092
1093        assert!(error.is_unsupported());
1094        assert!(
1095            error.to_string().contains("mixed DWT transforms"),
1096            "unexpected error: {error}"
1097        );
1098    }
1099
1100    #[test]
1101    fn region_plan_rejects_store_outside_output_rect() {
1102        let direct = one_block_direct_plan(1, 0, vec![0xAA], 1);
1103
1104        let error = CudaHtj2kDecodePlan::from_grayscale_direct_plan_region(
1105            &direct,
1106            PixelFormat::Gray8,
1107            (1, 1),
1108            (0, 0),
1109        )
1110        .expect_err("store outside compact output rectangle must be rejected");
1111
1112        assert!(error.is_unsupported());
1113        assert!(
1114            error
1115                .to_string()
1116                .contains("store does not fit the requested output rectangle"),
1117            "unexpected error: {error}"
1118        );
1119    }
1120}