Skip to main content

j2k_cuda/
decoder.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#[cfg(all(test, feature = "cuda-runtime"))]
4use core::cell::Cell;
5use core::convert::Infallible;
6#[cfg(feature = "cuda-runtime")]
7use std::sync::Arc;
8
9use j2k::{
10    adapter::device_plan::{DeviceDecodePlan, DeviceDecodeRequest},
11    J2kDecoder as CpuDecoder, J2kError, J2kScratchPool as CpuJ2kScratchPool, J2kView,
12};
13#[cfg(feature = "cuda-runtime")]
14use j2k_core::BackendKind;
15use j2k_core::{
16    submit_ready_device, BackendRequest, DecodeOutcome, Downscale, ImageCodec, ImageDecode,
17    ImageDecodeDevice, ImageDecodeSubmit, PixelFormat, ReadySubmission, Rect,
18};
19#[cfg(feature = "cuda-runtime")]
20use j2k_cuda_runtime::{
21    CudaBufferPool, CudaBufferPoolTakeTrace, CudaDeviceBuffer, CudaError, CudaHtj2kCleanupTarget,
22    CudaHtj2kCodeBlockJob, CudaHtj2kDecodeResources, CudaHtj2kDecodeTableResources,
23    CudaHtj2kDequantizeTarget, CudaJ2kIdwtJob, CudaJ2kIdwtTarget, CudaJ2kInverseMctJob,
24    CudaJ2kRect, CudaJ2kStoreGray16Job, CudaJ2kStoreGray8Job, CudaJ2kStoreRgb16Job,
25    CudaJ2kStoreRgb16MctJob, CudaJ2kStoreRgb8Job, CudaJ2kStoreRgb8MctJob,
26    CudaJ2kStoreRgb8MctTarget, CudaPooledDeviceBuffer, CudaQueuedExecution, CudaQueuedHtj2kCleanup,
27};
28use j2k_native::{DecodeSettings, DecoderContext as NativeDecoderContext, Image as NativeImage};
29
30#[cfg(feature = "cuda-runtime")]
31use crate::runtime::cuda_error;
32use crate::runtime::{validate_surface_request, wrap_cpu_staged_cuda_surface, wrap_surface};
33#[cfg(feature = "cuda-runtime")]
34use crate::surface::{cuda_range_storage, Storage};
35use crate::{
36    profile, CudaHtj2kDecodePlan, CudaHtj2kDecodeProfileDetail, CudaHtj2kProfileReport,
37    CudaSession, Error, Surface,
38};
39#[cfg(feature = "cuda-runtime")]
40use crate::{
41    CudaHtj2kBandId, CudaHtj2kIdwtStep, CudaHtj2kStoreStep, CudaHtj2kTransform, CudaSurfaceStats,
42    SurfaceResidency,
43};
44
/// Message text: the strict resident HTJ2K decode kernels are missing from this build.
#[cfg(feature = "cuda-runtime")]
const CUDA_HTJ2K_KERNELS_NOT_READY: &str =
    "strict CUDA HTJ2K resident codestream decode kernels are not available in this build";
/// Message text: the requested output pixel format is outside the listed set.
#[cfg(feature = "cuda-runtime")]
const CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED: &str =
    "strict CUDA HTJ2K resident decode currently accepts Gray8, Gray16, Rgb8, Rgba8, Rgb16, and Rgba16 output";
/// Message text: a decode plan carried internally inconsistent ranges.
#[cfg(feature = "cuda-runtime")]
const CUDA_HTJ2K_PLAN_INVARIANT_FAILED: &str =
    "strict CUDA HTJ2K resident decode plan has invalid internal ranges";
/// Message text: the plan's store step layout is not the single grayscale one.
#[cfg(feature = "cuda-runtime")]
const CUDA_HTJ2K_STORE_UNSUPPORTED: &str =
    "strict CUDA HTJ2K resident decode requires a single grayscale store step";
/// Message text: the combined payload of a batch decode is too large.
#[cfg(feature = "cuda-runtime")]
const CUDA_HTJ2K_BATCH_PAYLOAD_TOO_LARGE: &str =
    "strict CUDA HTJ2K resident batch decode payload is too large";
/// Name of the environment variable whose mere presence switches on IDWT
/// host tracing (see `cuda_idwt_trace_enabled`).
#[cfg(feature = "cuda-runtime")]
const CUDA_IDWT_TRACE_ENV_VAR: &str = "J2K_CUDA_IDWT_TRACE";
62
#[cfg(all(test, feature = "cuda-runtime"))]
thread_local! {
    /// Per-thread call counter used by the test-only helpers in this file;
    /// every thread starts with a count of zero.
    static CUDA_HTJ2K_BATCH_DECODE_CALLS: Cell<usize> = const { Cell::new(0) };
}
67
/// Test helper: sets the current thread's batch-decode call counter back to zero.
#[cfg(all(test, feature = "cuda-runtime"))]
pub(crate) fn testing_reset_cuda_htj2k_batch_decode_calls() {
    CUDA_HTJ2K_BATCH_DECODE_CALLS.with(|counter| {
        counter.set(0);
    });
}
72
/// Test helper: returns the current thread's batch-decode call counter.
#[cfg(all(test, feature = "cuda-runtime"))]
pub(crate) fn testing_cuda_htj2k_batch_decode_calls() -> usize {
    CUDA_HTJ2K_BATCH_DECODE_CALLS.with(|counter| counter.get())
}
77
/// One row of host-side IDWT batch trace data.
///
/// Each field is printed verbatim as a `key=value` pair by
/// `format_cuda_idwt_batch_host_trace_row`.
#[cfg(any(test, feature = "cuda-runtime"))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct CudaIdwtBatchHostTraceRow {
    /// Number of components in the batch.
    component_count: usize,
    /// Number of steps in the batch.
    step_count: usize,
    /// Host time attributed to output allocation (presumably microseconds, per the `_us` suffix).
    output_alloc_us: u128,
    /// Host time attributed to target construction (presumably microseconds).
    target_build_us: u128,
    /// Host time attributed to enqueueing (presumably microseconds).
    enqueue_us: u128,
    /// Count of output buffer takes.
    output_take_count: usize,
    /// Count of takes counted as pool reuse.
    output_pool_reuse_count: usize,
    /// Count of takes counted as pool allocation.
    output_pool_alloc_count: usize,
    /// Accumulated scanned-entry count for the output pool.
    output_pool_scanned_count: usize,
    /// Largest free-entry count recorded for the output pool.
    output_pool_max_free_count: usize,
    /// Total bytes requested for outputs.
    output_requested_bytes: usize,
}
93
/// Renders a trace row as a single `j2k_profile` line of space-separated
/// `key=value` pairs, with fields in the struct's declaration order.
#[cfg(any(test, feature = "cuda-runtime"))]
fn format_cuda_idwt_batch_host_trace_row(row: CudaIdwtBatchHostTraceRow) -> String {
    // Exhaustive destructuring: adding a field to the row without extending
    // the trace line becomes a compile error instead of a silent omission.
    let CudaIdwtBatchHostTraceRow {
        component_count,
        step_count,
        output_alloc_us,
        target_build_us,
        enqueue_us,
        output_take_count,
        output_pool_reuse_count,
        output_pool_alloc_count,
        output_pool_scanned_count,
        output_pool_max_free_count,
        output_requested_bytes,
    } = row;
    format!(
        "j2k_profile codec=j2k op=cuda_idwt_batch_host path=decode \
         component_count={} step_count={} output_alloc_us={} target_build_us={} enqueue_us={} \
         output_take_count={} output_pool_reuse_count={} output_pool_alloc_count={} \
         output_pool_scanned_count={} output_pool_max_free_count={} output_requested_bytes={}",
        component_count,
        step_count,
        output_alloc_us,
        target_build_us,
        enqueue_us,
        output_take_count,
        output_pool_reuse_count,
        output_pool_alloc_count,
        output_pool_scanned_count,
        output_pool_max_free_count,
        output_requested_bytes
    )
}
114
115#[cfg(feature = "cuda-runtime")]
/// Running totals over output-pool take traces, updated by `add_take`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct CudaIdwtOutputPoolTraceTotals {
    /// Number of takes recorded.
    take_count: usize,
    /// Number of recorded takes flagged as reused.
    reuse_count: usize,
    /// Number of recorded takes not flagged as reused.
    alloc_count: usize,
    /// Sum of the per-take scanned counts.
    scanned_count: usize,
    /// Largest pre-take free count seen so far.
    max_free_count: usize,
    /// Sum of the per-take requested lengths.
    requested_bytes: usize,
}
125
126#[cfg(feature = "cuda-runtime")]
127impl CudaIdwtOutputPoolTraceTotals {
128    fn add_take(&mut self, trace: CudaBufferPoolTakeTrace) {
129        self.take_count = self.take_count.saturating_add(1);
130        if trace.reused {
131            self.reuse_count = self.reuse_count.saturating_add(1);
132        } else {
133            self.alloc_count = self.alloc_count.saturating_add(1);
134        }
135        self.scanned_count = self.scanned_count.saturating_add(trace.scanned_count);
136        self.max_free_count = self.max_free_count.max(trace.free_count_before);
137        self.requested_bytes = self.requested_bytes.saturating_add(trace.requested_len);
138    }
139}
140
141#[cfg(feature = "cuda-runtime")]
142fn cuda_idwt_trace_enabled() -> bool {
143    std::env::var_os(CUDA_IDWT_TRACE_ENV_VAR).is_some()
144}
145
146#[cfg(feature = "cuda-runtime")]
/// Returns the microseconds elapsed since `start`, or `0` when no start
/// instant was captured.
fn elapsed_host_us(start: Option<std::time::Instant>) -> u128 {
    match start {
        Some(began) => began.elapsed().as_micros(),
        None => 0,
    }
}
150
151/// CUDA-facing JPEG 2000 decoder wrapper.
152pub struct J2kDecoder<'a> {
153    inner: CpuDecoder<'a>,
154    pool: CpuJ2kScratchPool,
155}
156
157impl<'a> J2kDecoder<'a> {
158    /// Create a CUDA-facing decoder from compressed bytes.
159    pub fn new(input: &'a [u8]) -> Result<Self, Error> {
160        Ok(Self {
161            inner: CpuDecoder::new(input)?,
162            pool: CpuJ2kScratchPool::new(),
163        })
164    }
165
166    fn decode_to_surface_impl(
167        &mut self,
168        session: &mut CudaSession,
169        fmt: PixelFormat,
170        backend: BackendRequest,
171    ) -> Result<Surface, Error> {
172        validate_surface_request(backend)?;
173        if matches!(backend, BackendRequest::Cuda) {
174            return self.decode_to_cuda_resident_surface_impl(session, fmt);
175        }
176        let dims = self.inner.info().dimensions;
177        let stride = dims.0 as usize * fmt.bytes_per_pixel();
178        let mut out = vec![0u8; stride * dims.1 as usize];
179        if j2k_profile::gpu_route_profile_enabled() {
180            let request_s = format!("{backend:?}");
181            let fmt_s = format!("{fmt:?}");
182            let width_s = dims.0.to_string();
183            let height_s = dims.1.to_string();
184            j2k_profile::emit_gpu_route_profile(
185                "j2k",
186                "cuda",
187                &[
188                    ("op", "full"),
189                    ("request", request_s.as_str()),
190                    ("fmt", fmt_s.as_str()),
191                    ("width", width_s.as_str()),
192                    ("height", height_s.as_str()),
193                    ("decision", "cpu_decode_then_wrap"),
194                ],
195            );
196        }
197        self.inner
198            .decode_into_with_scratch(&mut self.pool, &mut out, stride, fmt)?;
199        wrap_surface(out, dims, fmt, backend, session)
200    }
201
202    fn decode_to_cuda_resident_surface_impl(
203        &mut self,
204        session: &mut CudaSession,
205        fmt: PixelFormat,
206    ) -> Result<Surface, Error> {
207        decode_to_cuda_resident_surface_impl(self, session, fmt)
208    }
209
210    fn decode_region_to_cuda_resident_surface_impl(
211        &mut self,
212        session: &mut CudaSession,
213        fmt: PixelFormat,
214        roi: Rect,
215    ) -> Result<Surface, Error> {
216        decode_region_to_cuda_resident_surface_impl(self, session, fmt, roi)
217    }
218
219    fn decode_scaled_to_cuda_resident_surface_impl(
220        &mut self,
221        session: &mut CudaSession,
222        fmt: PixelFormat,
223        scale: Downscale,
224    ) -> Result<Surface, Error> {
225        decode_scaled_to_cuda_resident_surface_impl(self, session, fmt, scale)
226    }
227
228    fn decode_region_scaled_to_cuda_resident_surface_impl(
229        &mut self,
230        session: &mut CudaSession,
231        fmt: PixelFormat,
232        roi: Rect,
233        scale: Downscale,
234    ) -> Result<Surface, Error> {
235        decode_region_scaled_to_cuda_resident_surface_impl(self, session, fmt, roi, scale)
236    }
237
238    fn decode_region_to_surface_impl(
239        &mut self,
240        session: &mut CudaSession,
241        fmt: PixelFormat,
242        roi: Rect,
243        backend: BackendRequest,
244    ) -> Result<Surface, Error> {
245        validate_surface_request(backend)?;
246        if matches!(backend, BackendRequest::Cuda) {
247            return self.decode_region_to_cuda_resident_surface_impl(session, fmt, roi);
248        }
249        let plan = DeviceDecodePlan::for_image(
250            self.inner.info().dimensions,
251            DeviceDecodeRequest::Region { roi },
252        )?;
253        let dims = plan.output_dims();
254        let stride = dims.0 as usize * fmt.bytes_per_pixel();
255        let mut out = vec![0u8; stride * dims.1 as usize];
256        self.inner
257            .decode_region_into(&mut self.pool, &mut out, stride, fmt, plan.source_rect())?;
258        wrap_surface(out, dims, fmt, backend, session)
259    }
260
261    fn decode_scaled_to_surface_impl(
262        &mut self,
263        session: &mut CudaSession,
264        fmt: PixelFormat,
265        scale: Downscale,
266        backend: BackendRequest,
267    ) -> Result<Surface, Error> {
268        validate_surface_request(backend)?;
269        if matches!(backend, BackendRequest::Cuda) {
270            return self.decode_scaled_to_cuda_resident_surface_impl(session, fmt, scale);
271        }
272        let dims = DeviceDecodePlan::for_image(
273            self.inner.info().dimensions,
274            DeviceDecodeRequest::Scaled { scale },
275        )?
276        .output_dims();
277        let stride = dims.0 as usize * fmt.bytes_per_pixel();
278        let mut out = vec![0u8; stride * dims.1 as usize];
279        self.inner
280            .decode_scaled_into(&mut self.pool, &mut out, stride, fmt, scale)?;
281        wrap_surface(out, dims, fmt, backend, session)
282    }
283
284    fn decode_region_scaled_to_surface_impl(
285        &mut self,
286        session: &mut CudaSession,
287        fmt: PixelFormat,
288        roi: Rect,
289        scale: Downscale,
290        backend: BackendRequest,
291    ) -> Result<Surface, Error> {
292        validate_surface_request(backend)?;
293        if matches!(backend, BackendRequest::Cuda) {
294            return self
295                .decode_region_scaled_to_cuda_resident_surface_impl(session, fmt, roi, scale);
296        }
297        let plan = DeviceDecodePlan::for_image(
298            self.inner.info().dimensions,
299            DeviceDecodeRequest::RegionScaled { roi, scale },
300        )?;
301        let dims = plan.output_dims();
302        let stride = dims.0 as usize * fmt.bytes_per_pixel();
303        let mut out = vec![0u8; stride * dims.1 as usize];
304        self.inner.decode_region_scaled_into(
305            &mut self.pool,
306            &mut out,
307            stride,
308            fmt,
309            plan.source_rect(),
310            scale,
311        )?;
312        wrap_surface(out, dims, fmt, backend, session)
313    }
314
315    /// Strictly decode a full HTJ2K image into a CUDA-backed surface using an
316    /// existing backend session.
317    pub fn decode_to_device_with_session(
318        &mut self,
319        fmt: PixelFormat,
320        session: &mut CudaSession,
321    ) -> Result<Surface, Error> {
322        self.decode_to_surface_impl(session, fmt, BackendRequest::Cuda)
323    }
324
325    /// Strictly decode a full HTJ2K image into a CUDA-backed surface and return
326    /// a structured profile report for CPU planning and CUDA stages.
327    pub fn decode_to_device_with_session_and_profile(
328        &mut self,
329        fmt: PixelFormat,
330        session: &mut CudaSession,
331    ) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
332        decode_to_cuda_resident_surface_with_profile_impl(self, session, fmt)
333    }
334
335    /// Strictly decode a batch of full HTJ2K images into CUDA-backed surfaces
336    /// using an existing backend session.
337    pub fn decode_batch_to_device_with_session(
338        inputs: &[&[u8]],
339        fmt: PixelFormat,
340        session: &mut CudaSession,
341    ) -> Result<Vec<Surface>, Error> {
342        decode_batch_to_cuda_resident_surface_with_profile_control(inputs, session, fmt, false)
343            .map(|(surfaces, _report)| surfaces)
344    }
345
346    /// Strictly decode a batch of full HTJ2K images into CUDA-backed surfaces
347    /// and return one aggregate profile report for the shared batch.
348    pub fn decode_batch_to_device_with_session_and_profile(
349        inputs: &[&[u8]],
350        fmt: PixelFormat,
351        session: &mut CudaSession,
352    ) -> Result<(Vec<Surface>, CudaHtj2kProfileReport), Error> {
353        decode_batch_to_cuda_resident_surface_with_profile_control(inputs, session, fmt, true)
354    }
355
356    /// Strictly decode a full-resolution HTJ2K region into a CUDA-backed
357    /// surface using an existing backend session.
358    pub(crate) fn decode_region_to_device_with_session(
359        &mut self,
360        fmt: PixelFormat,
361        roi: Rect,
362        session: &mut CudaSession,
363    ) -> Result<Surface, Error> {
364        self.decode_region_to_surface_impl(session, fmt, roi, BackendRequest::Cuda)
365    }
366
367    /// Strictly decode a reduced-resolution HTJ2K image into a CUDA-backed
368    /// surface using an existing backend session.
369    pub(crate) fn decode_scaled_to_device_with_session(
370        &mut self,
371        fmt: PixelFormat,
372        scale: Downscale,
373        session: &mut CudaSession,
374    ) -> Result<Surface, Error> {
375        self.decode_scaled_to_surface_impl(session, fmt, scale, BackendRequest::Cuda)
376    }
377
378    /// Strictly decode a reduced-resolution HTJ2K region into a CUDA-backed
379    /// surface using an existing backend session.
380    pub(crate) fn decode_region_scaled_to_device_with_session(
381        &mut self,
382        fmt: PixelFormat,
383        roi: Rect,
384        scale: Downscale,
385        session: &mut CudaSession,
386    ) -> Result<Surface, Error> {
387        self.decode_region_scaled_to_surface_impl(session, fmt, roi, scale, BackendRequest::Cuda)
388    }
389
390    /// Decode a full image through the CPU path and wrap it as a host surface.
391    pub fn decode_to_host_surface(&mut self, fmt: PixelFormat) -> Result<Surface, Error> {
392        let mut session = CudaSession::default();
393        self.decode_to_surface_impl(&mut session, fmt, BackendRequest::Cpu)
394    }
395
396    /// Build a flat CUDA HTJ2K grayscale decode plan and return stage timings.
397    pub fn build_cuda_htj2k_grayscale_plan_with_profile(
398        &mut self,
399        fmt: PixelFormat,
400    ) -> Result<(CudaHtj2kDecodePlan, CudaHtj2kProfileReport), Error> {
401        let total_start = profile::profile_now(true);
402
403        let parse_start = profile::profile_now(true);
404        let image = NativeImage::new(self.inner.bytes(), &DecodeSettings::default())
405            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
406        let parse_us = profile::elapsed_us(parse_start);
407
408        let plan_start = profile::profile_now(true);
409        let mut native_context = NativeDecoderContext::default();
410        let native_plan = image
411            .build_direct_grayscale_plan_with_context(&mut native_context)
412            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
413        let plan_us = profile::elapsed_us(plan_start);
414
415        let flatten_start = profile::profile_now(true);
416        let cuda_plan = CudaHtj2kDecodePlan::from_grayscale_direct_plan(&native_plan, fmt, (0, 0))?;
417        let flatten_us = profile::elapsed_us(flatten_start);
418
419        let report = CudaHtj2kProfileReport {
420            parse_us,
421            plan_us,
422            flatten_us,
423            total_us: profile::elapsed_us(total_start),
424            block_count: cuda_plan.code_blocks().len(),
425            payload_bytes: cuda_plan.payload().len(),
426            dispatch_count: 0,
427            residency: crate::SurfaceResidency::CudaResidentDecode,
428            detail: CudaHtj2kDecodeProfileDetail::default(),
429            ..CudaHtj2kProfileReport::default()
430        };
431        report.emit("plan");
432        Ok((cuda_plan, report))
433    }
434
435    /// Build a flat CUDA HTJ2K grayscale region decode plan and return stage timings.
436    pub fn build_cuda_htj2k_grayscale_region_plan_with_profile(
437        &mut self,
438        fmt: PixelFormat,
439        roi: Rect,
440    ) -> Result<(CudaHtj2kDecodePlan, CudaHtj2kProfileReport), Error> {
441        let total_start = profile::profile_now(true);
442
443        let parse_start = profile::profile_now(true);
444        let image = NativeImage::new(self.inner.bytes(), &DecodeSettings::default())
445            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
446        let parse_us = profile::elapsed_us(parse_start);
447
448        let plan_start = profile::profile_now(true);
449        let mut native_context = NativeDecoderContext::default();
450        let native_plan = image
451            .build_direct_grayscale_plan_region_with_context(
452                &mut native_context,
453                (roi.x, roi.y, roi.w, roi.h),
454            )
455            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
456        let plan_us = profile::elapsed_us(plan_start);
457
458        let flatten_start = profile::profile_now(true);
459        let cuda_plan = CudaHtj2kDecodePlan::from_grayscale_direct_plan_region(
460            &native_plan,
461            fmt,
462            (roi.x, roi.y),
463            (roi.w, roi.h),
464        )?;
465        let flatten_us = profile::elapsed_us(flatten_start);
466
467        let report = CudaHtj2kProfileReport {
468            parse_us,
469            plan_us,
470            flatten_us,
471            total_us: profile::elapsed_us(total_start),
472            block_count: cuda_plan.code_blocks().len(),
473            payload_bytes: cuda_plan.payload().len(),
474            dispatch_count: 0,
475            residency: crate::SurfaceResidency::CudaResidentDecode,
476            detail: CudaHtj2kDecodeProfileDetail::default(),
477            ..CudaHtj2kProfileReport::default()
478        };
479        report.emit("plan");
480        Ok((cuda_plan, report))
481    }
482
483    /// Build a flat reduced-resolution CUDA HTJ2K grayscale decode plan and
484    /// return stage timings.
485    pub fn build_cuda_htj2k_grayscale_scaled_plan_with_profile(
486        &mut self,
487        fmt: PixelFormat,
488        output_dimensions: (u32, u32),
489    ) -> Result<(CudaHtj2kDecodePlan, CudaHtj2kProfileReport), Error> {
490        let total_start = profile::profile_now(true);
491
492        let parse_start = profile::profile_now(true);
493        let image = NativeImage::new(
494            self.inner.bytes(),
495            &DecodeSettings {
496                target_resolution: Some(output_dimensions),
497                ..DecodeSettings::default()
498            },
499        )
500        .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
501        let parse_us = profile::elapsed_us(parse_start);
502
503        let plan_start = profile::profile_now(true);
504        let mut native_context = NativeDecoderContext::default();
505        let native_plan = image
506            .build_direct_grayscale_plan_with_context(&mut native_context)
507            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
508        let plan_us = profile::elapsed_us(plan_start);
509
510        let flatten_start = profile::profile_now(true);
511        let cuda_plan = CudaHtj2kDecodePlan::from_grayscale_direct_plan(&native_plan, fmt, (0, 0))?;
512        let flatten_us = profile::elapsed_us(flatten_start);
513
514        let report = CudaHtj2kProfileReport {
515            parse_us,
516            plan_us,
517            flatten_us,
518            total_us: profile::elapsed_us(total_start),
519            block_count: cuda_plan.code_blocks().len(),
520            payload_bytes: cuda_plan.payload().len(),
521            dispatch_count: 0,
522            residency: crate::SurfaceResidency::CudaResidentDecode,
523            detail: CudaHtj2kDecodeProfileDetail::default(),
524            ..CudaHtj2kProfileReport::default()
525        };
526        report.emit("plan");
527        Ok((cuda_plan, report))
528    }
529
530    /// Build a flat reduced-resolution CUDA HTJ2K grayscale region decode
531    /// plan and return stage timings.
532    pub fn build_cuda_htj2k_grayscale_region_scaled_plan_with_profile(
533        &mut self,
534        fmt: PixelFormat,
535        scaled_roi: Rect,
536        output_dimensions: (u32, u32),
537    ) -> Result<(CudaHtj2kDecodePlan, CudaHtj2kProfileReport), Error> {
538        let total_start = profile::profile_now(true);
539
540        let parse_start = profile::profile_now(true);
541        let image = NativeImage::new(
542            self.inner.bytes(),
543            &DecodeSettings {
544                target_resolution: Some(output_dimensions),
545                ..DecodeSettings::default()
546            },
547        )
548        .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
549        let parse_us = profile::elapsed_us(parse_start);
550
551        let plan_start = profile::profile_now(true);
552        let mut native_context = NativeDecoderContext::default();
553        let native_plan = image
554            .build_direct_grayscale_plan_region_with_context(
555                &mut native_context,
556                (scaled_roi.x, scaled_roi.y, scaled_roi.w, scaled_roi.h),
557            )
558            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
559        let plan_us = profile::elapsed_us(plan_start);
560
561        let flatten_start = profile::profile_now(true);
562        let cuda_plan = CudaHtj2kDecodePlan::from_grayscale_direct_plan_region(
563            &native_plan,
564            fmt,
565            (scaled_roi.x, scaled_roi.y),
566            (scaled_roi.w, scaled_roi.h),
567        )?;
568        let flatten_us = profile::elapsed_us(flatten_start);
569
570        let report = CudaHtj2kProfileReport {
571            parse_us,
572            plan_us,
573            flatten_us,
574            total_us: profile::elapsed_us(total_start),
575            block_count: cuda_plan.code_blocks().len(),
576            payload_bytes: cuda_plan.payload().len(),
577            dispatch_count: 0,
578            residency: crate::SurfaceResidency::CudaResidentDecode,
579            detail: CudaHtj2kDecodeProfileDetail::default(),
580            ..CudaHtj2kProfileReport::default()
581        };
582        report.emit("plan");
583        Ok((cuda_plan, report))
584    }
585
586    /// Build flat CUDA HTJ2K RGB component plans and return stage timings.
587    #[cfg(feature = "cuda-runtime")]
588    fn build_cuda_htj2k_color_plans_with_profile(
589        &mut self,
590        fmt: PixelFormat,
591    ) -> Result<CudaHtj2kColorDecodePlans, Error> {
592        let mut native_context = NativeDecoderContext::default();
593        build_cuda_htj2k_color_plans_from_bytes_with_profile(
594            self.inner.bytes(),
595            fmt,
596            &mut native_context,
597        )
598    }
599
600    #[cfg(feature = "cuda-runtime")]
601    fn build_cuda_htj2k_color_scaled_plans_with_profile(
602        &mut self,
603        fmt: PixelFormat,
604        output_dimensions: (u32, u32),
605    ) -> Result<CudaHtj2kColorDecodePlans, Error> {
606        let total_start = profile::profile_now(true);
607
608        let parse_start = profile::profile_now(true);
609        let image = NativeImage::new(
610            self.inner.bytes(),
611            &DecodeSettings {
612                target_resolution: Some(output_dimensions),
613                ..DecodeSettings::default()
614            },
615        )
616        .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
617        let parse_us = profile::elapsed_us(parse_start);
618
619        let plan_start = profile::profile_now(true);
620        let mut native_context = NativeDecoderContext::default();
621        let native_plan = image
622            .build_direct_color_plan_with_context(&mut native_context)
623            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
624        let plan_us = profile::elapsed_us(plan_start);
625
626        let flatten_start = profile::profile_now(true);
627        let mut payload = Vec::new();
628        let mut components = Vec::with_capacity(native_plan.component_plans.len());
629        for component_plan in &native_plan.component_plans {
630            let mut component =
631                CudaHtj2kDecodePlan::from_grayscale_direct_plan(component_plan, fmt, (0, 0))?;
632            component.append_payload_to_shared(&mut payload)?;
633            components.push(component);
634        }
635        let flatten_us = profile::elapsed_us(flatten_start);
636        let block_count = components
637            .iter()
638            .map(|plan| plan.code_blocks().len())
639            .sum::<usize>();
640        let payload_bytes = payload.len();
641        let report = CudaHtj2kProfileReport {
642            parse_us,
643            plan_us,
644            flatten_us,
645            total_us: profile::elapsed_us(total_start),
646            block_count,
647            payload_bytes,
648            dispatch_count: 0,
649            residency: crate::SurfaceResidency::CudaResidentDecode,
650            detail: CudaHtj2kDecodeProfileDetail::default(),
651            ..CudaHtj2kProfileReport::default()
652        };
653        report.emit("plan");
654
655        Ok(CudaHtj2kColorDecodePlans {
656            dimensions: native_plan.dimensions,
657            mct_dimensions: native_plan.dimensions,
658            bit_depths: native_plan.bit_depths,
659            mct: native_plan.mct,
660            transform: CudaHtj2kTransform::from_native(native_plan.transform),
661            payload,
662            components,
663            report,
664        })
665    }
666
667    #[cfg(feature = "cuda-runtime")]
668    fn build_cuda_htj2k_color_region_plans_with_profile(
669        &mut self,
670        fmt: PixelFormat,
671        roi: Rect,
672    ) -> Result<CudaHtj2kColorDecodePlans, Error> {
673        let total_start = profile::profile_now(true);
674
675        let parse_start = profile::profile_now(true);
676        let image = NativeImage::new(self.inner.bytes(), &DecodeSettings::default())
677            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
678        let parse_us = profile::elapsed_us(parse_start);
679
680        let plan_start = profile::profile_now(true);
681        let mut native_context = NativeDecoderContext::default();
682        let native_plan = image
683            .build_direct_color_plan_with_context(&mut native_context)
684            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
685        let plan_us = profile::elapsed_us(plan_start);
686
687        let flatten_start = profile::profile_now(true);
688        let mut payload = Vec::new();
689        let mut components = Vec::with_capacity(native_plan.component_plans.len());
690        for component_plan in &native_plan.component_plans {
691            let mut component = CudaHtj2kDecodePlan::from_grayscale_direct_plan_region(
692                component_plan,
693                fmt,
694                (roi.x, roi.y),
695                (roi.w, roi.h),
696            )?;
697            component.append_payload_to_shared(&mut payload)?;
698            components.push(component);
699        }
700        let flatten_us = profile::elapsed_us(flatten_start);
701        let block_count = components
702            .iter()
703            .map(|plan| plan.code_blocks().len())
704            .sum::<usize>();
705        let payload_bytes = payload.len();
706        let report = CudaHtj2kProfileReport {
707            parse_us,
708            plan_us,
709            flatten_us,
710            total_us: profile::elapsed_us(total_start),
711            block_count,
712            payload_bytes,
713            dispatch_count: 0,
714            residency: crate::SurfaceResidency::CudaResidentDecode,
715            detail: CudaHtj2kDecodeProfileDetail::default(),
716            ..CudaHtj2kProfileReport::default()
717        };
718        report.emit("plan");
719
720        Ok(CudaHtj2kColorDecodePlans {
721            dimensions: (roi.w, roi.h),
722            mct_dimensions: native_plan.dimensions,
723            bit_depths: native_plan.bit_depths,
724            mct: native_plan.mct,
725            transform: CudaHtj2kTransform::from_native(native_plan.transform),
726            payload,
727            components,
728            report,
729        })
730    }
731
732    #[cfg(feature = "cuda-runtime")]
733    fn build_cuda_htj2k_color_region_scaled_plans_with_profile(
734        &mut self,
735        fmt: PixelFormat,
736        scaled_roi: Rect,
737        output_dimensions: (u32, u32),
738    ) -> Result<CudaHtj2kColorDecodePlans, Error> {
739        let total_start = profile::profile_now(true);
740
741        let parse_start = profile::profile_now(true);
742        let image = NativeImage::new(
743            self.inner.bytes(),
744            &DecodeSettings {
745                target_resolution: Some(output_dimensions),
746                ..DecodeSettings::default()
747            },
748        )
749        .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
750        let parse_us = profile::elapsed_us(parse_start);
751
752        let plan_start = profile::profile_now(true);
753        let mut native_context = NativeDecoderContext::default();
754        let native_plan = image
755            .build_direct_color_plan_with_context(&mut native_context)
756            .map_err(|error| Error::Decode(J2kError::Backend(error.to_string())))?;
757        let plan_us = profile::elapsed_us(plan_start);
758
759        let flatten_start = profile::profile_now(true);
760        let mut payload = Vec::new();
761        let mut components = Vec::with_capacity(native_plan.component_plans.len());
762        for component_plan in &native_plan.component_plans {
763            let mut component = CudaHtj2kDecodePlan::from_grayscale_direct_plan_region(
764                component_plan,
765                fmt,
766                (scaled_roi.x, scaled_roi.y),
767                (scaled_roi.w, scaled_roi.h),
768            )?;
769            component.append_payload_to_shared(&mut payload)?;
770            components.push(component);
771        }
772        let flatten_us = profile::elapsed_us(flatten_start);
773        let block_count = components
774            .iter()
775            .map(|plan| plan.code_blocks().len())
776            .sum::<usize>();
777        let payload_bytes = payload.len();
778        let report = CudaHtj2kProfileReport {
779            parse_us,
780            plan_us,
781            flatten_us,
782            total_us: profile::elapsed_us(total_start),
783            block_count,
784            payload_bytes,
785            dispatch_count: 0,
786            residency: crate::SurfaceResidency::CudaResidentDecode,
787            detail: CudaHtj2kDecodeProfileDetail::default(),
788            ..CudaHtj2kProfileReport::default()
789        };
790        report.emit("plan");
791
792        Ok(CudaHtj2kColorDecodePlans {
793            dimensions: (scaled_roi.w, scaled_roi.h),
794            mct_dimensions: native_plan.dimensions,
795            bit_depths: native_plan.bit_depths,
796            mct: native_plan.mct,
797            transform: CudaHtj2kTransform::from_native(native_plan.transform),
798            payload,
799            components,
800            report,
801        })
802    }
803
804    /// Decode a full image on CPU and upload it into a CUDA buffer using an
805    /// existing backend session.
806    pub fn decode_to_cpu_staged_cuda_surface_with_session(
807        &mut self,
808        fmt: PixelFormat,
809        session: &mut CudaSession,
810    ) -> Result<Surface, Error> {
811        let dims = self.inner.info().dimensions;
812        let stride = dims.0 as usize * fmt.bytes_per_pixel();
813        let mut out = vec![0u8; stride * dims.1 as usize];
814        self.inner
815            .decode_into_with_scratch(&mut self.pool, &mut out, stride, fmt)?;
816        wrap_cpu_staged_cuda_surface(&out, dims, fmt, session)
817    }
818
819    /// Decode a region on CPU and upload it into a CUDA buffer using an
820    /// existing backend session.
821    pub fn decode_region_to_cpu_staged_cuda_surface_with_session(
822        &mut self,
823        fmt: PixelFormat,
824        roi: Rect,
825        session: &mut CudaSession,
826    ) -> Result<Surface, Error> {
827        let plan = DeviceDecodePlan::for_image(
828            self.inner.info().dimensions,
829            DeviceDecodeRequest::Region { roi },
830        )?;
831        let dims = plan.output_dims();
832        let stride = dims.0 as usize * fmt.bytes_per_pixel();
833        let mut out = vec![0u8; stride * dims.1 as usize];
834        self.inner
835            .decode_region_into(&mut self.pool, &mut out, stride, fmt, plan.source_rect())?;
836        wrap_cpu_staged_cuda_surface(&out, dims, fmt, session)
837    }
838
839    /// Decode a scaled image on CPU and upload it into a CUDA buffer using an
840    /// existing backend session.
841    pub fn decode_scaled_to_cpu_staged_cuda_surface_with_session(
842        &mut self,
843        fmt: PixelFormat,
844        scale: Downscale,
845        session: &mut CudaSession,
846    ) -> Result<Surface, Error> {
847        let dims = DeviceDecodePlan::for_image(
848            self.inner.info().dimensions,
849            DeviceDecodeRequest::Scaled { scale },
850        )?
851        .output_dims();
852        let stride = dims.0 as usize * fmt.bytes_per_pixel();
853        let mut out = vec![0u8; stride * dims.1 as usize];
854        self.inner
855            .decode_scaled_into(&mut self.pool, &mut out, stride, fmt, scale)?;
856        wrap_cpu_staged_cuda_surface(&out, dims, fmt, session)
857    }
858
859    /// Decode a scaled region on CPU and upload it into a CUDA buffer using an
860    /// existing backend session.
861    pub fn decode_region_scaled_to_cpu_staged_cuda_surface_with_session(
862        &mut self,
863        fmt: PixelFormat,
864        roi: Rect,
865        scale: Downscale,
866        session: &mut CudaSession,
867    ) -> Result<Surface, Error> {
868        let plan = DeviceDecodePlan::for_image(
869            self.inner.info().dimensions,
870            DeviceDecodeRequest::RegionScaled { roi, scale },
871        )?;
872        let dims = plan.output_dims();
873        let stride = dims.0 as usize * fmt.bytes_per_pixel();
874        let mut out = vec![0u8; stride * dims.1 as usize];
875        self.inner.decode_region_scaled_into(
876            &mut self.pool,
877            &mut out,
878            stride,
879            fmt,
880            plan.source_rect(),
881            scale,
882        )?;
883        wrap_cpu_staged_cuda_surface(&out, dims, fmt, session)
884    }
885}
886
/// A pooled device buffer tagged with the HTJ2K band it belongs to.
///
/// NOTE(review): presumably the buffer holds that band's coefficients —
/// confirm against the code that fills it.
#[cfg(feature = "cuda-runtime")]
struct CudaCoefficientBand {
    /// Identifier of the band this buffer is associated with.
    band_id: CudaHtj2kBandId,
    /// Device buffer taken from a CUDA buffer pool.
    buffer: CudaPooledDeviceBuffer,
}
892
/// Code-block jobs recorded for one band, to be processed later.
///
/// NOTE(review): `band_index` presumably indexes
/// `CudaComponentDecodeWork::bands` — confirm against the producer.
#[cfg(feature = "cuda-runtime")]
struct CudaPendingDequantBand {
    /// Index of the band these jobs refer to.
    band_index: usize,
    /// HTJ2K code-block jobs queued for the band.
    jobs: Vec<CudaHtj2kCodeBlockJob>,
    /// Output size for the band, counted in words.
    output_words: usize,
}
899
/// Intermediate decode state for a single image component: its band buffers,
/// the bands still waiting for dequantization, the final store step and the
/// counters/timings accumulated so far.
#[cfg(feature = "cuda-runtime")]
struct CudaComponentDecodeWork {
    /// Band buffers belonging to this component.
    bands: Vec<CudaCoefficientBand>,
    /// Bands whose code-block jobs have been recorded but not yet processed.
    pending_dequant_bands: Vec<CudaPendingDequantBand>,
    /// Description of the store step for this component.
    store: CudaHtj2kStoreStep,
    /// Running count of kernel dispatches.
    dispatches: usize,
    /// Running count of dispatches attributed to decoding.
    decode_dispatches: usize,
    /// Per-stage timings gathered for this component.
    timings: CudaDecodeStageTimings,
}
909
/// A batch of queued IDWT executions together with its dispatch counters.
///
/// NOTE(review): presumably the queued executions must stay alive until the
/// work that depends on them has been submitted — confirm with the
/// `CudaQueuedExecution` contract.
#[cfg(feature = "cuda-runtime")]
struct CudaQueuedIdwtBatch {
    /// Executions that have been queued for this batch.
    queued: Vec<CudaQueuedExecution>,
    /// Number of kernel dispatches issued for the batch.
    kernel_dispatches: usize,
    /// Number of those dispatches attributed to decoding.
    decode_dispatches: usize,
}
916
/// Result of decoding one component on the device: the buffer holding the
/// component data, the store step describing how to write it out, and the
/// dispatch counters/timings collected while producing it.
#[cfg(feature = "cuda-runtime")]
struct CudaDecodedComponent {
    /// Pooled device buffer with the decoded component.
    buffer: CudaPooledDeviceBuffer,
    /// Store-step parameters (source/copy/output geometry, addend).
    store: CudaHtj2kStoreStep,
    /// Kernel dispatches issued while decoding this component.
    dispatches: usize,
    /// Dispatches attributed to decoding.
    decode_dispatches: usize,
    /// Per-stage timings gathered for this component.
    timings: CudaDecodeStageTimings,
}
925
/// Everything prepared for one image ahead of an RGB8 MCT store: the color
/// plans, the decoded components, the dispatch counters so far and the store
/// job itself.
#[cfg(feature = "cuda-runtime")]
struct CudaPreparedRgb8MctBatchStore {
    /// Color decode plans the image was decoded from.
    color: CudaHtj2kColorDecodePlans,
    /// Decoded components that feed the store.
    decoded_components: Vec<CudaDecodedComponent>,
    /// Kernel dispatches accumulated before the store.
    dispatches: usize,
    /// Dispatches attributed to decoding, accumulated before the store.
    decode_dispatches: usize,
    /// The RGB8 MCT store job for this image.
    job: CudaJ2kStoreRgb8MctJob,
}
934
/// Per-stage durations and dispatch counts collected while decoding.
///
/// The duration fields are added to the `*_us` fields of a
/// `CudaHtj2kProfileReport` by `add_to_report`, so they are expected to be in
/// microseconds.
#[cfg(feature = "cuda-runtime")]
#[derive(Clone, Copy, Debug, Default)]
struct CudaDecodeStageTimings {
    /// Host-to-device time; added to `report.h2d_us`.
    h2d: u128,
    /// Payload upload time; added to `report.detail.payload_upload_us`.
    payload_upload: u128,
    /// Status read-back time; added to `report.detail.status_d2h_us`.
    status_d2h: u128,
    /// HT cleanup time; added to `report.ht_cleanup_us`.
    ht_cleanup: u128,
    /// HT refine time; added to `report.ht_refine_us`.
    ht_refine: u128,
    /// Dequantization time; added to `report.dequant_us`.
    dequant: u128,
    /// HT dispatch count; added to `report.detail.ht_dispatch_count`.
    ht_dispatch_count: usize,
    /// IDWT time; added to `report.idwt_us`.
    idwt: u128,
    /// Dequantization dispatch count; added to
    /// `report.detail.dequant_dispatch_count`.
    dequant_dispatch_count: usize,
    /// IDWT dispatch count; added to `report.detail.idwt_dispatch_count`.
    idwt_dispatch_count: usize,
}
949
#[cfg(feature = "cuda-runtime")]
impl CudaDecodeStageTimings {
    /// Adds every duration and dispatch count in `self` to the matching field
    /// of `report`, saturating at the maximum instead of overflowing.
    fn add_to_report(self, report: &mut CudaHtj2kProfileReport) {
        let Self {
            h2d,
            payload_upload,
            status_d2h,
            ht_cleanup,
            ht_refine,
            dequant,
            ht_dispatch_count,
            idwt,
            dequant_dispatch_count,
            idwt_dispatch_count,
        } = self;

        // Stage durations stored directly on the report.
        report.h2d_us = report.h2d_us.saturating_add(h2d);
        report.ht_cleanup_us = report.ht_cleanup_us.saturating_add(ht_cleanup);
        report.ht_refine_us = report.ht_refine_us.saturating_add(ht_refine);
        report.dequant_us = report.dequant_us.saturating_add(dequant);
        report.idwt_us = report.idwt_us.saturating_add(idwt);

        // Finer-grained durations and dispatch counts stored in the detail.
        let detail = &mut report.detail;
        detail.payload_upload_us = detail.payload_upload_us.saturating_add(payload_upload);
        detail.status_d2h_us = detail.status_d2h_us.saturating_add(status_d2h);
        detail.ht_dispatch_count = detail.ht_dispatch_count.saturating_add(ht_dispatch_count);
        detail.dequant_dispatch_count = detail
            .dequant_dispatch_count
            .saturating_add(dequant_dispatch_count);
        detail.idwt_dispatch_count = detail
            .idwt_dispatch_count
            .saturating_add(idwt_dispatch_count);
    }
}
977
/// The host-side result of planning a color HTJ2K decode: one CUDA plan per
/// component, the payload bytes they share, and the color metadata copied
/// from the native plan.
#[cfg(feature = "cuda-runtime")]
struct CudaHtj2kColorDecodePlans {
    /// Dimensions of the surface to produce (the ROI size for region plans).
    dimensions: (u32, u32),
    /// Dimensions reported by the native plan.
    mct_dimensions: (u32, u32),
    /// Bit depth of each of the three components.
    bit_depths: [u8; 3],
    /// MCT flag copied from the native plan.
    mct: bool,
    /// Wavelet transform, converted from the native plan's transform.
    transform: CudaHtj2kTransform,
    /// Payload bytes shared by all component plans.
    payload: Vec<u8>,
    /// One decode plan per component.
    components: Vec<CudaHtj2kDecodePlan>,
    /// Profile report started while planning.
    report: CudaHtj2kProfileReport,
}
989
/// Decodes the full image into a CUDA-resident surface and discards the
/// profile report. Stage timings are not explicitly requested here.
#[cfg(feature = "cuda-runtime")]
fn decode_to_cuda_resident_surface_impl(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
) -> Result<Surface, Error> {
    let (surface, _report) =
        decode_to_cuda_resident_surface_with_profile_control(decoder, session, fmt, false)?;
    Ok(surface)
}
999
/// Decodes the full image into a CUDA-resident surface with stage timings
/// requested, returning the surface together with its profile report.
#[cfg(feature = "cuda-runtime")]
fn decode_to_cuda_resident_surface_with_profile_impl(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    let collect_stage_timings = true;
    decode_to_cuda_resident_surface_with_profile_control(
        decoder,
        session,
        fmt,
        collect_stage_timings,
    )
}
1008
/// Full-frame CUDA-resident decode entry point that selects the grayscale or
/// color pipeline from `fmt`.
///
/// Stage timings are collected when the caller asks for them or when
/// `profile::profile_stages_enabled()` reports that they are enabled.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` for any pixel format other than the
/// gray and RGB/RGBA 8/16-bit formats; otherwise propagates pipeline errors.
#[cfg(feature = "cuda-runtime")]
fn decode_to_cuda_resident_surface_with_profile_control(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    collect_stage_timings: bool,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    let collect_stage_timings = collect_stage_timings || profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);

    let is_grayscale = matches!(fmt, PixelFormat::Gray8 | PixelFormat::Gray16);
    let is_color = matches!(
        fmt,
        PixelFormat::Rgb8 | PixelFormat::Rgba8 | PixelFormat::Rgb16 | PixelFormat::Rgba16
    );

    if is_grayscale {
        decode_grayscale_cuda_resident_surface_with_profile(
            decoder,
            session,
            fmt,
            wall_started,
            collect_stage_timings,
        )
    } else if is_color {
        decode_color_cuda_resident_surface_with_profile(
            decoder,
            session,
            fmt,
            wall_started,
            collect_stage_timings,
        )
    } else {
        Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED,
        })
    }
}
1042
/// Batch CUDA-resident decode entry point.
///
/// An empty `inputs` slice yields an empty surface list and a default report
/// tagged `CudaResidentDecode`, whatever `fmt` is. Otherwise only the
/// RGB/RGBA 8/16-bit formats are forwarded to the color batch pipeline.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` for non-empty batches with any
/// other pixel format; otherwise propagates pipeline errors.
#[cfg(feature = "cuda-runtime")]
fn decode_batch_to_cuda_resident_surface_with_profile_control(
    inputs: &[&[u8]],
    session: &mut CudaSession,
    fmt: PixelFormat,
    collect_stage_timings: bool,
) -> Result<(Vec<Surface>, CudaHtj2kProfileReport), Error> {
    // Test-only call counter for this entry point.
    #[cfg(all(test, feature = "cuda-runtime"))]
    CUDA_HTJ2K_BATCH_DECODE_CALLS.with(|counter| counter.set(counter.get().saturating_add(1)));

    let collect_stage_timings = collect_stage_timings || profile::profile_stages_enabled();

    if inputs.is_empty() {
        let empty_report = CudaHtj2kProfileReport {
            residency: SurfaceResidency::CudaResidentDecode,
            ..CudaHtj2kProfileReport::default()
        };
        return Ok((Vec::new(), empty_report));
    }

    let is_color = matches!(
        fmt,
        PixelFormat::Rgb8 | PixelFormat::Rgba8 | PixelFormat::Rgb16 | PixelFormat::Rgba16
    );
    if !is_color {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED,
        });
    }

    decode_color_cuda_resident_batch_surfaces_with_profile(
        inputs,
        session,
        fmt,
        collect_stage_timings,
    )
}
1077
/// Decodes `roi` into a CUDA-resident surface.
///
/// The request is validated through `DeviceDecodePlan::for_image`. When the
/// plan reports that it covers the full frame, the full-frame path is used;
/// otherwise the plan's source rectangle is handed to the grayscale or color
/// region pipeline chosen from `fmt`.
///
/// # Errors
///
/// Propagates plan construction failures and pipeline errors, and returns
/// `Error::UnsupportedCudaRequest` for unsupported pixel formats.
#[cfg(feature = "cuda-runtime")]
fn decode_region_to_cuda_resident_surface_impl(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    roi: Rect,
) -> Result<Surface, Error> {
    let image_dimensions = decoder.inner.info().dimensions;
    let plan = DeviceDecodePlan::for_image(image_dimensions, DeviceDecodeRequest::Region { roi })?;
    if plan.is_full_frame() {
        return decode_to_cuda_resident_surface_impl(decoder, session, fmt);
    }

    let is_grayscale = matches!(fmt, PixelFormat::Gray8 | PixelFormat::Gray16);
    let is_color = matches!(
        fmt,
        PixelFormat::Rgb8 | PixelFormat::Rgba8 | PixelFormat::Rgb16 | PixelFormat::Rgba16
    );

    if is_grayscale {
        decode_grayscale_cuda_resident_region_surface(decoder, session, fmt, plan.source_rect())
    } else if is_color {
        decode_color_cuda_resident_region_surface(decoder, session, fmt, plan.source_rect())
    } else {
        Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED,
        })
    }
}
1105
/// Decodes the image at a reduced scale into a CUDA-resident surface.
///
/// `Downscale::None` is served by the full-frame path. For any other scale
/// the output dimensions come from `DeviceDecodePlan::for_image` and are
/// passed to the grayscale or color scaled pipeline chosen from `fmt`.
///
/// # Errors
///
/// Propagates plan construction failures and pipeline errors, and returns
/// `Error::UnsupportedCudaRequest` for unsupported pixel formats.
#[cfg(feature = "cuda-runtime")]
fn decode_scaled_to_cuda_resident_surface_impl(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    scale: Downscale,
) -> Result<Surface, Error> {
    if scale == Downscale::None {
        return decode_to_cuda_resident_surface_impl(decoder, session, fmt);
    }

    let image_dimensions = decoder.inner.info().dimensions;
    let scaled_plan =
        DeviceDecodePlan::for_image(image_dimensions, DeviceDecodeRequest::Scaled { scale })?;
    let output_dimensions = scaled_plan.output_dims();

    let is_grayscale = matches!(fmt, PixelFormat::Gray8 | PixelFormat::Gray16);
    let is_color = matches!(
        fmt,
        PixelFormat::Rgb8 | PixelFormat::Rgba8 | PixelFormat::Rgb16 | PixelFormat::Rgba16
    );

    if is_grayscale {
        decode_grayscale_cuda_resident_scaled_surface(decoder, session, fmt, output_dimensions)
    } else if is_color {
        decode_color_cuda_resident_scaled_surface(decoder, session, fmt, output_dimensions)
    } else {
        Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED,
        })
    }
}
1134
/// Decodes `roi` at a reduced scale into a CUDA-resident surface.
///
/// `Downscale::None` is served by the unscaled region path. Otherwise two
/// plans are built from the source dimensions: a `Scaled` plan for the
/// dimensions of the whole downscaled image, and a `RegionScaled` plan whose
/// output rectangle is the ROI to decode. Both are passed to the grayscale or
/// color pipeline chosen from `fmt`.
///
/// # Errors
///
/// Propagates plan construction failures and pipeline errors, and returns
/// `Error::UnsupportedCudaRequest` for unsupported pixel formats.
#[cfg(feature = "cuda-runtime")]
fn decode_region_scaled_to_cuda_resident_surface_impl(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    roi: Rect,
    scale: Downscale,
) -> Result<Surface, Error> {
    if scale == Downscale::None {
        return decode_region_to_cuda_resident_surface_impl(decoder, session, fmt, roi);
    }

    let source_dimensions = decoder.inner.info().dimensions;
    let full_scaled_plan =
        DeviceDecodePlan::for_image(source_dimensions, DeviceDecodeRequest::Scaled { scale })?;
    let scaled_dimensions = full_scaled_plan.output_dims();
    let region_plan = DeviceDecodePlan::for_image(
        source_dimensions,
        DeviceDecodeRequest::RegionScaled { roi, scale },
    )?;
    let scaled_roi = region_plan.output_rect();

    let is_grayscale = matches!(fmt, PixelFormat::Gray8 | PixelFormat::Gray16);
    let is_color = matches!(
        fmt,
        PixelFormat::Rgb8 | PixelFormat::Rgba8 | PixelFormat::Rgb16 | PixelFormat::Rgba16
    );

    if is_grayscale {
        decode_grayscale_cuda_resident_region_scaled_surface(
            decoder,
            session,
            fmt,
            scaled_roi,
            scaled_dimensions,
        )
    } else if is_color {
        decode_color_cuda_resident_region_scaled_surface(
            decoder,
            session,
            fmt,
            scaled_roi,
            scaled_dimensions,
        )
    } else {
        Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED,
        })
    }
}
1180
/// Plans a full-frame grayscale decode and runs it on the device, returning
/// the surface and the profile report that was started during planning.
#[cfg(feature = "cuda-runtime")]
fn decode_grayscale_cuda_resident_surface_with_profile(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    wall_started: Option<profile::ProfileInstant>,
    collect_stage_timings: bool,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    let (grayscale_plan, mut plan_report) =
        decoder.build_cuda_htj2k_grayscale_plan_with_profile(fmt)?;
    decode_grayscale_cuda_resident_surface_with_plan_profile(
        session,
        fmt,
        &grayscale_plan,
        &mut plan_report,
        wall_started,
        collect_stage_timings,
    )
}
1199
/// Plans a grayscale decode restricted to `roi`, runs it on the device and
/// returns only the surface. Stage timings follow
/// `profile::profile_stages_enabled()`.
#[cfg(feature = "cuda-runtime")]
fn decode_grayscale_cuda_resident_region_surface(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    roi: Rect,
) -> Result<Surface, Error> {
    let collect_stage_timings = profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);
    let (region_plan, mut plan_report) =
        decoder.build_cuda_htj2k_grayscale_region_plan_with_profile(fmt, roi)?;
    let (surface, _report) = decode_grayscale_cuda_resident_surface_with_plan_profile(
        session,
        fmt,
        &region_plan,
        &mut plan_report,
        wall_started,
        collect_stage_timings,
    )?;
    Ok(surface)
}
1221
/// Plans a grayscale decode at `output_dimensions`, runs it on the device and
/// returns only the surface. Stage timings follow
/// `profile::profile_stages_enabled()`.
#[cfg(feature = "cuda-runtime")]
fn decode_grayscale_cuda_resident_scaled_surface(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    output_dimensions: (u32, u32),
) -> Result<Surface, Error> {
    let collect_stage_timings = profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);
    let (scaled_plan, mut plan_report) =
        decoder.build_cuda_htj2k_grayscale_scaled_plan_with_profile(fmt, output_dimensions)?;
    let (surface, _report) = decode_grayscale_cuda_resident_surface_with_plan_profile(
        session,
        fmt,
        &scaled_plan,
        &mut plan_report,
        wall_started,
        collect_stage_timings,
    )?;
    Ok(surface)
}
1243
/// Plans a grayscale decode of `scaled_roi` within an image downscaled to
/// `scaled_dimensions`, runs it on the device and returns only the surface.
/// Stage timings follow `profile::profile_stages_enabled()`.
#[cfg(feature = "cuda-runtime")]
fn decode_grayscale_cuda_resident_region_scaled_surface(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    scaled_roi: Rect,
    scaled_dimensions: (u32, u32),
) -> Result<Surface, Error> {
    let collect_stage_timings = profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);
    let (region_scaled_plan, mut plan_report) = decoder
        .build_cuda_htj2k_grayscale_region_scaled_plan_with_profile(
            fmt,
            scaled_roi,
            scaled_dimensions,
        )?;
    let (surface, _report) = decode_grayscale_cuda_resident_surface_with_plan_profile(
        session,
        fmt,
        &region_scaled_plan,
        &mut plan_report,
        wall_started,
        collect_stage_timings,
    )?;
    Ok(surface)
}
1269
/// Executes a grayscale decode plan on the device and wraps the stored pixels
/// in a CUDA-resident `Surface`.
///
/// Steps: obtain the CUDA context, the HTJ2K decode tables and the decode
/// buffer pool from `session`; decode the single component; run the Gray8 or
/// Gray16 store kernel selected by `fmt`; then fold dispatch counts and
/// timings into `report`, finalize and emit it.
///
/// Returns the surface together with a clone of the updated `report`.
///
/// # Errors
///
/// Propagates session, decode and CUDA errors. A `fmt` other than
/// `Gray8`/`Gray16` surfaces as a `CudaError::InvalidArgument` mapped through
/// `cuda_error`.
#[cfg(feature = "cuda-runtime")]
fn decode_grayscale_cuda_resident_surface_with_plan_profile(
    session: &mut CudaSession,
    fmt: PixelFormat,
    plan: &CudaHtj2kDecodePlan,
    report: &mut CudaHtj2kProfileReport,
    wall_started: Option<profile::ProfileInstant>,
    collect_stage_timings: bool,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    let context = session.cuda_context()?;

    // Time spent obtaining the decode tables counts as host-to-device time
    // and is also recorded separately in the detail section.
    let tables_started = profile::profile_now(collect_stage_timings);
    let table_resources = session.htj2k_decode_table_resources()?;
    let tables_us = profile::elapsed_us(tables_started);
    report.h2d_us = report.h2d_us.saturating_add(tables_us);
    report.detail.table_upload_us = report.detail.table_upload_us.saturating_add(tables_us);

    let pool = session.decode_buffer_pool()?;
    let component = decode_cuda_component_plan(
        &context,
        plan,
        &table_resources,
        &pool,
        collect_stage_timings,
    )?;

    let input_rect = &component.store.input_rect;
    let input_width = input_rect.x1.saturating_sub(input_rect.x0);
    let component_buffer = pooled_cuda_buffer(&component.buffer)?;

    // The two store jobs carry the same parameters; only the job type and the
    // kernel differ between 8-bit and 16-bit output.
    let store_pixels = || match fmt {
        PixelFormat::Gray8 => context.j2k_store_gray8_device(
            component_buffer,
            CudaJ2kStoreGray8Job {
                input_width,
                source_x: component.store.source_x,
                source_y: component.store.source_y,
                copy_width: component.store.copy_width,
                copy_height: component.store.copy_height,
                output_width: component.store.output_width,
                output_height: component.store.output_height,
                output_x: component.store.output_x,
                output_y: component.store.output_y,
                addend: component.store.addend,
                bit_depth: u32::from(plan.bit_depth()),
            },
        ),
        PixelFormat::Gray16 => context.j2k_store_gray16_device(
            component_buffer,
            CudaJ2kStoreGray16Job {
                input_width,
                source_x: component.store.source_x,
                source_y: component.store.source_y,
                copy_width: component.store.copy_width,
                copy_height: component.store.copy_height,
                output_width: component.store.output_width,
                output_height: component.store.output_height,
                output_x: component.store.output_x,
                output_y: component.store.output_y,
                addend: component.store.addend,
                bit_depth: u32::from(plan.bit_depth()),
            },
        ),
        _ => Err(CudaError::InvalidArgument {
            message: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED.to_string(),
        }),
    };
    let (store_output, store_us) = context
        .time_default_stream_named_us_if(
            collect_stage_timings,
            "j2k.htj2k.decode.store.gray",
            store_pixels,
        )
        .map_err(cuda_error)?;

    let (surface_buffer, store_stats) = store_output.into_parts();
    let store_dispatches = store_stats.kernel_dispatches();
    let dispatches = component.dispatches.saturating_add(store_dispatches);
    let decode_dispatches = component
        .decode_dispatches
        .saturating_add(store_stats.decode_kernel_dispatches());

    // Fold this decode's counters and timings into the report, then close it.
    report.dispatch_count = dispatches;
    component.timings.add_to_report(report);
    report.store_us = report.store_us.saturating_add(store_us);
    report.detail.store_dispatch_count = report
        .detail
        .store_dispatch_count
        .saturating_add(store_dispatches);
    report.detail.wall_total_us = profile::elapsed_us(wall_started);
    profile::finalize_decode_total_us(report);
    report.emit("decode");

    let (width, height) = (component.store.output_width, component.store.output_height);
    let stats = CudaSurfaceStats {
        total: dispatches,
        copy: 0,
        decode: decode_dispatches,
    };
    let surface = Surface {
        backend: BackendKind::Cuda,
        residency: SurfaceResidency::CudaResidentDecode,
        dimensions: (width, height),
        fmt,
        // Rows are tightly packed.
        pitch_bytes: width as usize * fmt.bytes_per_pixel(),
        stats,
        storage: Storage::Cuda(surface_buffer),
    };
    Ok((surface, report.clone()))
}
1379
/// Plans a full-frame color decode and runs it on the device, returning the
/// surface and the profile report.
#[cfg(feature = "cuda-runtime")]
fn decode_color_cuda_resident_surface_with_profile(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    wall_started: Option<profile::ProfileInstant>,
    collect_stage_timings: bool,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    let color_plans = decoder.build_cuda_htj2k_color_plans_with_profile(fmt)?;
    decode_color_cuda_resident_surface_with_plans_profile(
        session,
        fmt,
        color_plans,
        wall_started,
        collect_stage_timings,
    )
}
1397
/// Plans a color decode at `output_dimensions`, runs it on the device and
/// returns only the surface. Stage timings follow
/// `profile::profile_stages_enabled()`.
#[cfg(feature = "cuda-runtime")]
fn decode_color_cuda_resident_scaled_surface(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    output_dimensions: (u32, u32),
) -> Result<Surface, Error> {
    let collect_stage_timings = profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);
    let color_plans =
        decoder.build_cuda_htj2k_color_scaled_plans_with_profile(fmt, output_dimensions)?;
    let (surface, _report) = decode_color_cuda_resident_surface_with_plans_profile(
        session,
        fmt,
        color_plans,
        wall_started,
        collect_stage_timings,
    )?;
    Ok(surface)
}
1417
/// Plans a color decode restricted to `roi`, runs it on the device and
/// returns only the surface. Stage timings follow
/// `profile::profile_stages_enabled()`.
#[cfg(feature = "cuda-runtime")]
fn decode_color_cuda_resident_region_surface(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    roi: Rect,
) -> Result<Surface, Error> {
    let collect_stage_timings = profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);
    let color_plans = decoder.build_cuda_htj2k_color_region_plans_with_profile(fmt, roi)?;
    let (surface, _report) = decode_color_cuda_resident_surface_with_plans_profile(
        session,
        fmt,
        color_plans,
        wall_started,
        collect_stage_timings,
    )?;
    Ok(surface)
}
1437
/// Plans a color decode of `scaled_roi` within an image downscaled to
/// `scaled_dimensions`, runs it on the device and returns only the surface.
/// Stage timings follow `profile::profile_stages_enabled()`.
#[cfg(feature = "cuda-runtime")]
fn decode_color_cuda_resident_region_scaled_surface(
    decoder: &mut J2kDecoder<'_>,
    session: &mut CudaSession,
    fmt: PixelFormat,
    scaled_roi: Rect,
    scaled_dimensions: (u32, u32),
) -> Result<Surface, Error> {
    let collect_stage_timings = profile::profile_stages_enabled();
    let wall_started = profile::profile_now(collect_stage_timings);
    let color_plans = decoder.build_cuda_htj2k_color_region_scaled_plans_with_profile(
        fmt,
        scaled_roi,
        scaled_dimensions,
    )?;
    let (surface, _report) = decode_color_cuda_resident_surface_with_plans_profile(
        session,
        fmt,
        color_plans,
        wall_started,
        collect_stage_timings,
    )?;
    Ok(surface)
}
1462
/// Executes a set of color decode plans on the device and produces a
/// CUDA-resident surface plus the profile report.
///
/// Exactly three component plans are required. The function obtains the CUDA
/// context, buffer pool and decode tables from `session`, uploads the shared
/// payload, prepares per-component work, runs the cleanup/dequant batches and
/// hands everything to
/// `finish_color_cuda_resident_surface_with_component_work`.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` when the plan does not contain
/// three components; otherwise propagates session, CUDA and pipeline errors.
#[cfg(feature = "cuda-runtime")]
fn decode_color_cuda_resident_surface_with_plans_profile(
    session: &mut CudaSession,
    fmt: PixelFormat,
    mut color: CudaHtj2kColorDecodePlans,
    wall_started: Option<profile::ProfileInstant>,
    collect_stage_timings: bool,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    let component_count = color.components.len();
    if component_count != 3 {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    }

    let context = session.cuda_context()?;
    let pool = session.decode_buffer_pool()?;

    // Time spent obtaining the decode tables counts as host-to-device time
    // and is also recorded separately in the detail section.
    let tables_started = profile::profile_now(collect_stage_timings);
    let table_resources = session.htj2k_decode_table_resources()?;
    let tables_us = profile::elapsed_us(tables_started);
    let report = &mut color.report;
    report.h2d_us = report.h2d_us.saturating_add(tables_us);
    report.detail.table_upload_us = report.detail.table_upload_us.saturating_add(tables_us);

    // Upload the payload shared by all three components.
    let payload_started = profile::profile_now(collect_stage_timings);
    let decode_resources = context
        .upload_htj2k_decode_resources_with_tables(&color.payload, &table_resources)
        .map_err(cuda_error)?;
    let payload_us = profile::elapsed_us(payload_started);
    profile::add_payload_resource_upload_us(&mut color.report, payload_us);

    let mut component_work = Vec::with_capacity(component_count);
    for component_plan in &color.components {
        let work = decode_cuda_component_subbands_with_resources(
            &context,
            component_plan,
            &pool,
            collect_stage_timings,
        )?;
        component_work.push(work);
    }

    run_component_cleanup_dequant_batches(
        &context,
        &decode_resources,
        &mut component_work,
        &pool,
        collect_stage_timings,
    )?;

    finish_color_cuda_resident_surface_with_component_work(
        &context,
        &pool,
        fmt,
        color,
        component_work,
        wall_started,
        collect_stage_timings,
        true,
        true,
    )
}
1521
/// Decodes a batch of color images into CUDA-resident surfaces, sharing one
/// payload upload across the whole batch.
///
/// Host side: every input is planned (each must yield exactly three
/// components) and its payload is appended to one shared buffer. Device side:
/// the shared payload is uploaded once, per-component work is prepared for
/// all images, and the cleanup/dequant batches run over the whole batch. If
/// `can_batch_color_idwt` allows it, the IDWT is batched as well, and if
/// `can_batch_rgb8_mct_color_store` also allows it the surfaces are produced
/// by the RGB8 MCT batch store; otherwise each image is finished on its own.
///
/// Returns the surfaces in input order together with an aggregate report.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` when an input does not plan to
/// three components or when the prepared work cannot be split per image;
/// otherwise propagates planning, session and CUDA errors.
#[cfg(feature = "cuda-runtime")]
fn decode_color_cuda_resident_batch_surfaces_with_profile(
    inputs: &[&[u8]],
    session: &mut CudaSession,
    fmt: PixelFormat,
    collect_stage_timings: bool,
) -> Result<(Vec<Surface>, CudaHtj2kProfileReport), Error> {
    let batch_wall_started = profile::profile_now(collect_stage_timings);

    // Host-side planning: one set of color plans per input, all payloads
    // gathered into a single shared buffer.
    let mut colors = Vec::with_capacity(inputs.len());
    let mut shared_payload = Vec::new();
    let mut native_context = NativeDecoderContext::default();
    for bytes in inputs {
        let mut image_plans =
            build_cuda_htj2k_color_plans_from_bytes_with_profile(bytes, fmt, &mut native_context)?;
        if image_plans.components.len() != 3 {
            return Err(Error::UnsupportedCudaRequest {
                reason: CUDA_HTJ2K_KERNELS_NOT_READY,
            });
        }
        append_color_payload_to_shared(&mut image_plans, &mut shared_payload)?;
        colors.push(image_plans);
    }

    // Device resources: context, batch pool, decode tables, shared payload.
    let context = session.cuda_context()?;
    let pool = session.decode_batch_buffer_pool()?;
    let tables_started = profile::profile_now(collect_stage_timings);
    let table_resources = session.htj2k_decode_table_resources()?;
    let table_upload_us = profile::elapsed_us(tables_started);
    let payload_started = profile::profile_now(collect_stage_timings);
    let decode_resources = context
        .upload_htj2k_decode_resources_with_tables_and_pool(
            &shared_payload,
            &table_resources,
            &pool,
        )
        .map_err(cuda_error)?;
    let payload_upload_us = profile::elapsed_us(payload_started);

    // Prepare work for every component of every image, in input order.
    let total_components: usize = colors.iter().map(|color| color.components.len()).sum();
    let mut all_component_work = Vec::with_capacity(total_components);
    for component_plan in colors.iter().flat_map(|color| color.components.iter()) {
        let work = decode_cuda_component_subbands_with_resources(
            &context,
            component_plan,
            &pool,
            collect_stage_timings,
        )?;
        all_component_work.push(work);
    }
    run_component_cleanup_dequant_batches(
        &context,
        &decode_resources,
        &mut all_component_work,
        &pool,
        collect_stage_timings,
    )?;

    // The flattened plan list borrows `colors`, so it lives in its own scope
    // and is released before `colors` is moved into a finisher below.
    let (idwt_batched, pending_idwt_batch) = {
        let batch_components = colors
            .iter()
            .flat_map(|color| color.components.iter())
            .collect::<Vec<_>>();
        let idwt_batched = can_batch_color_idwt(&batch_components);
        let pending = if idwt_batched {
            run_color_component_idwt_batches(
                &context,
                &batch_components,
                &mut all_component_work,
                &pool,
                collect_stage_timings,
            )?
        } else {
            None
        };
        (idwt_batched, pending)
    };

    // The batched store is only considered when the IDWT was batched too.
    let can_use_batch_store = if idwt_batched {
        can_batch_rgb8_mct_color_store(fmt, &colors, &all_component_work)?
    } else {
        false
    };

    let (surfaces, reports) = if can_use_batch_store {
        finish_color_cuda_resident_batch_surfaces_with_rgb8_mct_store(
            &context,
            fmt,
            colors,
            all_component_work,
            collect_stage_timings,
        )?
    } else {
        // Finish every image individually, taking its share of the prepared
        // work in order.
        let mut surfaces = Vec::with_capacity(colors.len());
        let mut reports = Vec::with_capacity(colors.len());
        let mut remaining_work = all_component_work.into_iter();
        for color in colors {
            let expected = color.components.len();
            let component_work = remaining_work.by_ref().take(expected).collect::<Vec<_>>();
            if component_work.len() != expected {
                return Err(Error::UnsupportedCudaRequest {
                    reason: CUDA_HTJ2K_KERNELS_NOT_READY,
                });
            }
            let (surface, report) = finish_color_cuda_resident_surface_with_component_work(
                &context,
                &pool,
                fmt,
                color,
                component_work,
                None,
                collect_stage_timings,
                !idwt_batched,
                false,
            )?;
            surfaces.push(surface);
            reports.push(report);
        }
        (surfaces, reports)
    };
    // The queued IDWT batch is kept until the surfaces have been produced.
    drop(pending_idwt_batch);

    let aggregate = finalize_color_batch_decode_report(
        &reports,
        table_upload_us,
        payload_upload_us,
        batch_wall_started,
    );
    aggregate.emit("decode_batch");

    Ok((surfaces, aggregate))
}
1650
/// Parses `input`, builds the native direct color plan, and flattens it into CUDA component plans.
///
/// Each native component plan is converted with `CudaHtj2kDecodePlan::from_grayscale_direct_plan`
/// and its payload is appended to a single buffer shared by the whole image. Parse, plan, and
/// flatten timings are always collected here, and the resulting report is emitted under the
/// `"plan"` label before returning.
///
/// # Errors
///
/// Native parse and planning failures are reported as `Error::Decode(J2kError::Backend(..))`.
/// Failures while converting a component plan or appending its payload are propagated.
#[cfg(feature = "cuda-runtime")]
fn build_cuda_htj2k_color_plans_from_bytes_with_profile<'a>(
    input: &'a [u8],
    fmt: PixelFormat,
    native_context: &mut NativeDecoderContext<'a>,
) -> Result<CudaHtj2kColorDecodePlans, Error> {
    /// Wraps a native decoder failure as a backend decode error.
    fn backend_error<E: ToString>(error: E) -> Error {
        Error::Decode(J2kError::Backend(error.to_string()))
    }

    let overall_timer = profile::profile_now(true);

    let parse_timer = profile::profile_now(true);
    let image = NativeImage::new(input, &DecodeSettings::default()).map_err(backend_error)?;
    let parse_us = profile::elapsed_us(parse_timer);

    let plan_timer = profile::profile_now(true);
    let native_plan = image
        .build_direct_color_plan_with_context(native_context)
        .map_err(backend_error)?;
    let plan_us = profile::elapsed_us(plan_timer);

    // Flatten: one CUDA plan per native component, all payload bytes in one shared buffer.
    let flatten_timer = profile::profile_now(true);
    let mut payload = Vec::new();
    let mut components = Vec::with_capacity(native_plan.component_plans.len());
    for native_component in &native_plan.component_plans {
        let mut cuda_component =
            CudaHtj2kDecodePlan::from_grayscale_direct_plan(native_component, fmt, (0, 0))?;
        cuda_component.append_payload_to_shared(&mut payload)?;
        components.push(cuda_component);
    }
    let flatten_us = profile::elapsed_us(flatten_timer);

    let block_count: usize = components
        .iter()
        .map(|component| component.code_blocks().len())
        .sum();
    let payload_bytes = payload.len();
    let report = CudaHtj2kProfileReport {
        parse_us,
        plan_us,
        flatten_us,
        total_us: profile::elapsed_us(overall_timer),
        block_count,
        payload_bytes,
        dispatch_count: 0,
        residency: SurfaceResidency::CudaResidentDecode,
        detail: CudaHtj2kDecodeProfileDetail::default(),
        ..CudaHtj2kProfileReport::default()
    };
    report.emit("plan");

    Ok(CudaHtj2kColorDecodePlans {
        dimensions: native_plan.dimensions,
        // The MCT extent starts out identical to the image dimensions.
        mct_dimensions: native_plan.dimensions,
        bit_depths: native_plan.bit_depths,
        mct: native_plan.mct,
        transform: CudaHtj2kTransform::from_native(native_plan.transform),
        payload,
        components,
        report,
    })
}
1710
/// Builds the batch-level profile report from the per-image `reports`.
///
/// The per-image reports are summed, then the batch-wide table and payload upload times are
/// added both to the host-to-device total and to their dedicated detail counters. The wall time
/// is measured from `batch_wall_started`, and `profile::finalize_decode_total_us` runs last.
#[cfg(feature = "cuda-runtime")]
fn finalize_color_batch_decode_report(
    reports: &[CudaHtj2kProfileReport],
    table_upload_us: u128,
    payload_upload_us: u128,
    batch_wall_started: Option<profile::ProfileInstant>,
) -> CudaHtj2kProfileReport {
    let mut combined = aggregate_decode_reports(reports);

    // Both uploads count as host-to-device time for the batch as a whole.
    let upload_us = table_upload_us.saturating_add(payload_upload_us);
    combined.h2d_us = combined.h2d_us.saturating_add(upload_us);

    let detail = &mut combined.detail;
    detail.table_upload_us = detail.table_upload_us.saturating_add(table_upload_us);
    detail.payload_upload_us = detail.payload_upload_us.saturating_add(payload_upload_us);
    detail.wall_total_us = profile::elapsed_us(batch_wall_started);

    profile::finalize_decode_total_us(&mut combined);
    combined
}
1735
/// Decides whether a whole batch can use the fused RGB8 MCT batch store.
///
/// Returns `Ok(true)` only when `fmt` is `Rgb8` or `Rgba8`, the batch is non-empty, and every
/// image uses the multi-component transform with three stores that
/// `can_fuse_mct_store_for_stores` accepts.
///
/// # Errors
///
/// For an `Rgb8`/`Rgba8` request, returns `Error::UnsupportedCudaRequest` when an image does not
/// have exactly three components or when `all_component_work` does not line up with `colors`
/// (too few or too many entries). Store validation errors are propagated.
#[cfg(feature = "cuda-runtime")]
fn can_batch_rgb8_mct_color_store(
    fmt: PixelFormat,
    colors: &[CudaHtj2kColorDecodePlans],
    all_component_work: &[CudaComponentDecodeWork],
) -> Result<bool, Error> {
    let not_ready = || Error::UnsupportedCudaRequest {
        reason: CUDA_HTJ2K_KERNELS_NOT_READY,
    };

    if !matches!(fmt, PixelFormat::Rgb8 | PixelFormat::Rgba8) {
        return Ok(false);
    }

    // Walk the flat work list three entries at a time, one triple per color image.
    let mut remaining = all_component_work;
    for color in colors {
        if color.components.len() != 3 {
            return Err(not_ready());
        }
        let [work0, work1, work2, rest @ ..] = remaining else {
            return Err(not_ready());
        };
        if !color.mct {
            return Ok(false);
        }
        let stores = [&work0.store, &work1.store, &work2.store];
        validate_color_stores(stores, color.dimensions)?;
        if !can_fuse_mct_store_for_stores(stores) {
            return Ok(false);
        }
        remaining = rest;
    }

    // Leftover work means the plans and the work list disagree.
    if !remaining.is_empty() {
        return Err(not_ready());
    }
    Ok(!colors.is_empty())
}
1778
/// Finishes a batch of color images with one fused RGB8 MCT store into a shared device buffer.
///
/// The flat `all_component_work` list is split per image, each image is prepared with
/// `prepare_rgb8_mct_batch_store`, and a single batch store call produces one contiguous device
/// buffer. Every returned surface references its own range of that shared buffer.
///
/// The dispatch counts and the timing of the single batch store are attributed to the first
/// image only, so they are counted once.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` when the work list does not line up with `colors` or
/// when the store returns a different number of ranges than images. Preparation errors are
/// propagated and CUDA failures are mapped through `cuda_error`.
#[cfg(feature = "cuda-runtime")]
fn finish_color_cuda_resident_batch_surfaces_with_rgb8_mct_store(
    context: &j2k_cuda_runtime::CudaContext,
    fmt: PixelFormat,
    colors: Vec<CudaHtj2kColorDecodePlans>,
    all_component_work: Vec<CudaComponentDecodeWork>,
    collect_stage_timings: bool,
) -> Result<(Vec<Surface>, Vec<CudaHtj2kProfileReport>), Error> {
    let mut prepared_stores = Vec::with_capacity(colors.len());
    let mut remaining_work = all_component_work.into_iter();
    for color in colors {
        let expected = color.components.len();
        let image_work: Vec<_> = remaining_work.by_ref().take(expected).collect();
        if image_work.len() != expected {
            return Err(Error::UnsupportedCudaRequest {
                reason: CUDA_HTJ2K_KERNELS_NOT_READY,
            });
        }
        prepared_stores.push(prepare_rgb8_mct_batch_store(fmt, color, image_work)?);
    }
    // Every work item must belong to some image.
    if remaining_work.next().is_some() {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    }

    let targets = prepared_stores
        .iter()
        .map(rgb8_mct_batch_store_target)
        .collect::<Result<Vec<_>, Error>>()?;
    let (store_output, store_us) = context
        .time_default_stream_named_us_if(
            collect_stage_timings,
            "j2k.htj2k.decode.store.color.batch",
            || context.j2k_store_rgb8_mct_batch_contiguous_device(&targets),
        )
        .map_err(cuda_error)?;
    // The targets borrow `prepared_stores`; release them before it is consumed below.
    drop(targets);

    let (surface_buffer, surface_ranges, store_stats) = store_output.into_parts();
    if surface_ranges.len() != prepared_stores.len() {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    }
    let shared_surface_buffer = Arc::new(surface_buffer);
    let store_dispatches = store_stats.kernel_dispatches();
    let store_decode_dispatches = store_stats.decode_kernel_dispatches();

    let mut surfaces = Vec::with_capacity(prepared_stores.len());
    let mut reports = Vec::with_capacity(prepared_stores.len());
    for (index, (mut entry, surface_range)) in
        prepared_stores.into_iter().zip(surface_ranges).enumerate()
    {
        // Only the first image carries the cost of the shared batch store.
        let (batch_dispatches, batch_decode_dispatches, batch_store_us) = if index == 0 {
            (store_dispatches, store_decode_dispatches, store_us)
        } else {
            (0, 0, 0)
        };
        let dispatches = entry.dispatches.saturating_add(batch_dispatches);
        let decode_dispatches = entry
            .decode_dispatches
            .saturating_add(batch_decode_dispatches);

        let report = &mut entry.color.report;
        report.dispatch_count = dispatches;
        report.store_us = report.store_us.saturating_add(batch_store_us);
        report.detail.store_dispatch_count = report
            .detail
            .store_dispatch_count
            .saturating_add(batch_dispatches);
        profile::finalize_decode_total_us(report);

        let dimensions = entry.color.dimensions;
        surfaces.push(Surface {
            backend: BackendKind::Cuda,
            residency: SurfaceResidency::CudaResidentDecode,
            dimensions,
            fmt,
            pitch_bytes: dimensions.0 as usize * fmt.bytes_per_pixel(),
            stats: CudaSurfaceStats {
                total: dispatches,
                copy: 0,
                decode: decode_dispatches,
            },
            storage: cuda_range_storage(
                Arc::clone(&shared_surface_buffer),
                surface_range.offset,
                surface_range.len,
            ),
        });
        reports.push(entry.color.report);
    }

    Ok((surfaces, reports))
}
1880
/// Finishes the three component decodes of one image and builds its fused RGB8 MCT store job.
///
/// The per-component dispatch counters are summed and the component timings are folded into
/// `color.report`. Output geometry (copy size, output size, output offset) is taken from the
/// first component's store step.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` when there are not exactly three components, when the
/// image does not use the multi-component transform, or when the three stores cannot be fused.
/// Errors from finishing a component or validating the stores are propagated.
#[cfg(feature = "cuda-runtime")]
fn prepare_rgb8_mct_batch_store(
    fmt: PixelFormat,
    mut color: CudaHtj2kColorDecodePlans,
    component_work: Vec<CudaComponentDecodeWork>,
) -> Result<CudaPreparedRgb8MctBatchStore, Error> {
    let decoded_components = component_work
        .into_iter()
        .map(finish_cuda_component_decode)
        .collect::<Result<Vec<_>, Error>>()?;
    let [plane0, plane1, plane2] = decoded_components.as_slice() else {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    };
    let stores = [&plane0.store, &plane1.store, &plane2.store];
    validate_color_stores(stores, color.dimensions)?;
    let fusable = color.mct && can_fuse_mct_store_for_stores(stores);
    if !fusable {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    }

    // One pass over the components: total the dispatch counters and merge the timings.
    let mut dispatches = 0usize;
    let mut decode_dispatches = 0usize;
    for component in &decoded_components {
        dispatches += component.dispatches;
        decode_dispatches += component.decode_dispatches;
        component.timings.add_to_report(&mut color.report);
    }

    let geometry = &plane0.store;
    let job = CudaJ2kStoreRgb8MctJob {
        store: CudaJ2kStoreRgb8Job {
            input_width0: color_store_input_width(&plane0.store),
            input_width1: color_store_input_width(&plane1.store),
            input_width2: color_store_input_width(&plane2.store),
            source_x0: plane0.store.source_x,
            source_y0: plane0.store.source_y,
            source_x1: plane1.store.source_x,
            source_y1: plane1.store.source_y,
            source_x2: plane2.store.source_x,
            source_y2: plane2.store.source_y,
            copy_width: geometry.copy_width,
            copy_height: geometry.copy_height,
            output_width: geometry.output_width,
            output_height: geometry.output_height,
            output_x: geometry.output_x,
            output_y: geometry.output_y,
            addend0: bit_depth_addend(color.bit_depths[0]),
            addend1: bit_depth_addend(color.bit_depths[1]),
            addend2: bit_depth_addend(color.bit_depths[2]),
            bit_depth0: u32::from(color.bit_depths[0]),
            bit_depth1: u32::from(color.bit_depths[1]),
            bit_depth2: u32::from(color.bit_depths[2]),
            rgba: u32::from(matches!(fmt, PixelFormat::Rgba8)),
        },
        irreversible97: u32::from(color.transform == CudaHtj2kTransform::Irreversible97),
    };

    Ok(CudaPreparedRgb8MctBatchStore {
        color,
        decoded_components,
        dispatches,
        decode_dispatches,
        job,
    })
}
1957
/// Pairs the three decoded plane buffers of a prepared image with its store job.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` when the prepared image does not hold exactly three
/// decoded components; errors from `pooled_cuda_buffer` are propagated.
#[cfg(feature = "cuda-runtime")]
fn rgb8_mct_batch_store_target(
    prepared: &CudaPreparedRgb8MctBatchStore,
) -> Result<CudaJ2kStoreRgb8MctTarget<'_>, Error> {
    match prepared.decoded_components.as_slice() {
        [first, second, third] => Ok(CudaJ2kStoreRgb8MctTarget {
            plane0: pooled_cuda_buffer(&first.buffer)?,
            plane1: pooled_cuda_buffer(&second.buffer)?,
            plane2: pooled_cuda_buffer(&third.buffer)?,
            job: prepared.job,
        }),
        _ => Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        }),
    }
}
1974
/// Reports whether the three component stores share input width and source origin.
///
/// The second and third stores are each compared against the first; all three must agree on
/// input width, `source_x`, and `source_y`.
#[cfg(feature = "cuda-runtime")]
fn can_fuse_mct_store_for_stores(stores: [&CudaHtj2kStoreStep; 3]) -> bool {
    let [lead, second, third] = stores;
    let lead_width = color_store_input_width(lead);
    let matches_lead = |other: &CudaHtj2kStoreStep| {
        color_store_input_width(other) == lead_width
            && other.source_x == lead.source_x
            && other.source_y == lead.source_y
    };
    matches_lead(second) && matches_lead(third)
}
1987
/// Width of the store step's input rectangle (`x1 - x0`), clamped to zero if `x1 < x0`.
#[cfg(feature = "cuda-runtime")]
fn color_store_input_width(store: &CudaHtj2kStoreStep) -> u32 {
    let rect = &store.input_rect;
    rect.x1.saturating_sub(rect.x0)
}
1992
/// Turns the decode work of one three-component color image into a CUDA-resident surface.
///
/// Steps, in order:
/// 1. When `run_idwt` is set, run the inverse wavelet steps here: batched across the components
///    when `can_batch_color_idwt` allows it, otherwise component by component. When it is not
///    set, no IDWT work is issued by this function.
/// 2. Finish every component and validate the three store steps against the image dimensions.
/// 3. Apply the inverse multi-component transform when `color.mct` is set: fused into the store
///    kernel when the three planes share input width and source origin, otherwise as a separate
///    kernel before the store.
/// 4. Interleave the planes into an 8-bit or 16-bit RGB(A) surface.
///
/// Dispatch counts, MCT and store timings, and the wall time measured from `wall_started` are
/// folded into the returned report. The report is emitted under the `"decode"` label only when
/// `emit_report` is set.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` when the work does not contain exactly three
/// components or when the MCT plane size does not fit in `u32`. CUDA failures, including an
/// output `fmt` other than `Rgb8`/`Rgba8`/`Rgb16`/`Rgba16`, are mapped through `cuda_error`.
/// Errors from the IDWT, component-finish, validation, and buffer helpers are propagated.
#[cfg(feature = "cuda-runtime")]
#[allow(clippy::too_many_arguments)]
fn finish_color_cuda_resident_surface_with_component_work(
    context: &j2k_cuda_runtime::CudaContext,
    pool: &CudaBufferPool,
    fmt: PixelFormat,
    mut color: CudaHtj2kColorDecodePlans,
    mut component_work: Vec<CudaComponentDecodeWork>,
    wall_started: Option<profile::ProfileInstant>,
    collect_stage_timings: bool,
    run_idwt: bool,
    emit_report: bool,
) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
    // NOTE(review): the value returned by the batched IDWT is held until after the store has
    // been issued — presumably it owns resources the queued work still reads; confirm.
    let pending_idwt_batch = if run_idwt {
        let batch_components = color.components.iter().collect::<Vec<_>>();
        if can_batch_color_idwt(&batch_components) {
            run_color_component_idwt_batches(
                context,
                &batch_components,
                &mut component_work,
                pool,
                collect_stage_timings,
            )?
        } else {
            for (plan, work) in color.components.iter().zip(component_work.iter_mut()) {
                run_cuda_component_idwt_steps(
                    context,
                    plan.idwt_steps(),
                    work,
                    pool,
                    collect_stage_timings,
                )?;
            }
            None
        }
    } else {
        None
    };
    let decoded_components = component_work
        .into_iter()
        .map(finish_cuda_component_decode)
        .collect::<Result<Vec<_>, Error>>()?;
    let [component0, component1, component2] = decoded_components.as_slice() else {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    };
    let stores = [&component0.store, &component1.store, &component2.store];
    validate_color_stores(stores, color.dimensions)?;

    let mut dispatches = decoded_components
        .iter()
        .map(|component| component.dispatches)
        .sum::<usize>();
    let mut decode_dispatches = decoded_components
        .iter()
        .map(|component| component.decode_dispatches)
        .sum::<usize>();
    for component in &decoded_components {
        component.timings.add_to_report(&mut color.report);
    }
    let component0_buffer = pooled_cuda_buffer(&component0.buffer)?;
    let component1_buffer = pooled_cuda_buffer(&component1.buffer)?;
    let component2_buffer = pooled_cuda_buffer(&component2.buffer)?;
    // Shared helpers keep this path in agreement with the batch store path.
    let input_width0 = color_store_input_width(&component0.store);
    let input_width1 = color_store_input_width(&component1.store);
    let input_width2 = color_store_input_width(&component2.store);
    let irreversible97 = u32::from(color.transform == CudaHtj2kTransform::Irreversible97);
    let mct_store_addends = [
        bit_depth_addend(color.bit_depths[0]),
        bit_depth_addend(color.bit_depths[1]),
        bit_depth_addend(color.bit_depths[2]),
    ];
    // Fusing requires the MCT to be enabled and all three planes to share width and origin.
    let can_fuse_mct_store = color.mct && can_fuse_mct_store_for_stores(stores);
    let addends = if can_fuse_mct_store {
        // The fused store kernel performs the inverse MCT itself and needs the addends.
        mct_store_addends
    } else if color.mct {
        // Planes cannot be fused: run the inverse MCT as its own kernel first.
        let mct_len = u32::try_from(checked_area(
            color.mct_dimensions.0,
            color.mct_dimensions.1,
        )?)
        .map_err(|_| Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        })?;
        let stats = context
            .time_default_stream_named_us_if(collect_stage_timings, "j2k.htj2k.decode.mct", || {
                context.j2k_inverse_mct_device(
                    component0_buffer,
                    component1_buffer,
                    component2_buffer,
                    CudaJ2kInverseMctJob {
                        len: mct_len,
                        irreversible97,
                        addend0: mct_store_addends[0],
                        addend1: mct_store_addends[1],
                        addend2: mct_store_addends[2],
                    },
                )
            })
            .map_err(cuda_error)?;
        let (stats, mct_us) = stats;
        dispatches = dispatches.saturating_add(stats.kernel_dispatches());
        decode_dispatches = decode_dispatches.saturating_add(stats.decode_kernel_dispatches());
        color.report.mct_us = color.report.mct_us.saturating_add(mct_us);
        color.report.detail.mct_dispatch_count = color
            .report
            .detail
            .mct_dispatch_count
            .saturating_add(stats.kernel_dispatches());
        // The addends were handed to the MCT kernel above, so the store adds nothing further
        // (presumably the kernel already applied them — confirm against the kernel source).
        [0.0, 0.0, 0.0]
    } else {
        // No MCT: each plane keeps the addend recorded in its own store step.
        [
            component0.store.addend,
            component1.store.addend,
            component2.store.addend,
        ]
    };
    let (store_output, store_us) = context
        .time_default_stream_named_us_if(
            collect_stage_timings,
            "j2k.htj2k.decode.store.color",
            || match fmt {
                PixelFormat::Rgb8 | PixelFormat::Rgba8 => {
                    // Output geometry is taken from the first component's store step.
                    let store_job = CudaJ2kStoreRgb8Job {
                        input_width0,
                        input_width1,
                        input_width2,
                        source_x0: component0.store.source_x,
                        source_y0: component0.store.source_y,
                        source_x1: component1.store.source_x,
                        source_y1: component1.store.source_y,
                        source_x2: component2.store.source_x,
                        source_y2: component2.store.source_y,
                        copy_width: component0.store.copy_width,
                        copy_height: component0.store.copy_height,
                        output_width: component0.store.output_width,
                        output_height: component0.store.output_height,
                        output_x: component0.store.output_x,
                        output_y: component0.store.output_y,
                        addend0: addends[0],
                        addend1: addends[1],
                        addend2: addends[2],
                        bit_depth0: u32::from(color.bit_depths[0]),
                        bit_depth1: u32::from(color.bit_depths[1]),
                        bit_depth2: u32::from(color.bit_depths[2]),
                        rgba: u32::from(fmt == PixelFormat::Rgba8),
                    };
                    if can_fuse_mct_store {
                        context.j2k_store_rgb8_mct_device(
                            component0_buffer,
                            component1_buffer,
                            component2_buffer,
                            CudaJ2kStoreRgb8MctJob {
                                store: store_job,
                                irreversible97,
                            },
                        )
                    } else {
                        context.j2k_store_rgb8_device(
                            component0_buffer,
                            component1_buffer,
                            component2_buffer,
                            store_job,
                        )
                    }
                }
                PixelFormat::Rgb16 | PixelFormat::Rgba16 => {
                    // Same layout as the 8-bit job, for the 16-bit store kernels.
                    let store_job = CudaJ2kStoreRgb16Job {
                        input_width0,
                        input_width1,
                        input_width2,
                        source_x0: component0.store.source_x,
                        source_y0: component0.store.source_y,
                        source_x1: component1.store.source_x,
                        source_y1: component1.store.source_y,
                        source_x2: component2.store.source_x,
                        source_y2: component2.store.source_y,
                        copy_width: component0.store.copy_width,
                        copy_height: component0.store.copy_height,
                        output_width: component0.store.output_width,
                        output_height: component0.store.output_height,
                        output_x: component0.store.output_x,
                        output_y: component0.store.output_y,
                        addend0: addends[0],
                        addend1: addends[1],
                        addend2: addends[2],
                        bit_depth0: u32::from(color.bit_depths[0]),
                        bit_depth1: u32::from(color.bit_depths[1]),
                        bit_depth2: u32::from(color.bit_depths[2]),
                        rgba: u32::from(fmt == PixelFormat::Rgba16),
                    };
                    if can_fuse_mct_store {
                        context.j2k_store_rgb16_mct_device(
                            component0_buffer,
                            component1_buffer,
                            component2_buffer,
                            CudaJ2kStoreRgb16MctJob {
                                store: store_job,
                                irreversible97,
                            },
                        )
                    } else {
                        context.j2k_store_rgb16_device(
                            component0_buffer,
                            component1_buffer,
                            component2_buffer,
                            store_job,
                        )
                    }
                }
                _ => Err(CudaError::InvalidArgument {
                    message: CUDA_HTJ2K_OUTPUT_FORMAT_UNSUPPORTED.to_string(),
                }),
            },
        )
        .map_err(cuda_error)?;
    drop(pending_idwt_batch);
    let (surface_buffer, store_stats) = store_output.into_parts();
    dispatches = dispatches.saturating_add(store_stats.kernel_dispatches());
    decode_dispatches = decode_dispatches.saturating_add(store_stats.decode_kernel_dispatches());
    color.report.dispatch_count = dispatches;
    color.report.store_us = color.report.store_us.saturating_add(store_us);
    color.report.detail.store_dispatch_count = color
        .report
        .detail
        .store_dispatch_count
        .saturating_add(store_stats.kernel_dispatches());
    color.report.detail.wall_total_us = profile::elapsed_us(wall_started);
    profile::finalize_decode_total_us(&mut color.report);
    if emit_report {
        color.report.emit("decode");
    }

    let surface = Surface {
        backend: BackendKind::Cuda,
        residency: SurfaceResidency::CudaResidentDecode,
        dimensions: color.dimensions,
        fmt,
        pitch_bytes: color.dimensions.0 as usize * fmt.bytes_per_pixel(),
        stats: CudaSurfaceStats {
            total: dispatches,
            copy: 0,
            decode: decode_dispatches,
        },
        storage: Storage::Cuda(surface_buffer),
    };
    Ok((surface, color.report))
}
2261
/// Moves one image's payload to the end of `shared_payload` and rebases its component offsets.
///
/// The rebase offset is the length of `shared_payload` before the append. Capacity is reserved
/// up front, so an allocation failure is reported before any component offset is changed.
/// `color.payload` is left empty afterwards.
///
/// # Errors
///
/// Returns `Error::UnsupportedCudaRequest` with `CUDA_HTJ2K_BATCH_PAYLOAD_TOO_LARGE` when the
/// current length does not fit in `u64` or the reservation fails. Rebase errors are propagated.
#[cfg(feature = "cuda-runtime")]
fn append_color_payload_to_shared(
    color: &mut CudaHtj2kColorDecodePlans,
    shared_payload: &mut Vec<u8>,
) -> Result<(), Error> {
    /// Maps any size-related failure to the shared "payload too large" error.
    fn payload_too_large<E>(_: E) -> Error {
        Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_BATCH_PAYLOAD_TOO_LARGE,
        }
    }

    let rebase_by = u64::try_from(shared_payload.len()).map_err(payload_too_large)?;
    shared_payload
        .try_reserve(color.payload.len())
        .map_err(payload_too_large)?;

    for plan in &mut color.components {
        plan.rebase_payload_offsets(rebase_by)?;
    }
    shared_payload.append(&mut color.payload);
    Ok(())
}
2281
/// Sums `reports` into a single report marked as a CUDA-resident decode.
///
/// Starts from a default report and folds each input in with `add_decode_report`.
#[cfg(feature = "cuda-runtime")]
fn aggregate_decode_reports(reports: &[CudaHtj2kProfileReport]) -> CudaHtj2kProfileReport {
    let seed = CudaHtj2kProfileReport {
        residency: SurfaceResidency::CudaResidentDecode,
        ..CudaHtj2kProfileReport::default()
    };
    reports.iter().fold(seed, |mut total, report| {
        add_decode_report(&mut total, report);
        total
    })
}
2293
/// Adds every accumulated timing and counter of `report` onto `aggregate`, saturating on overflow.
///
/// Covers the stage timings, block/payload/dispatch totals, and the upload, transfer, and
/// per-stage dispatch counters under `detail`. Fields not listed below are left untouched.
#[cfg(feature = "cuda-runtime")]
fn add_decode_report(aggregate: &mut CudaHtj2kProfileReport, report: &CudaHtj2kProfileReport) {
    // `accumulate!(a.b)` expands to `aggregate.a.b = aggregate.a.b.saturating_add(report.a.b)`.
    macro_rules! accumulate {
        ($($field:ident).+) => {
            aggregate.$($field).+ = aggregate.$($field).+.saturating_add(report.$($field).+)
        };
    }

    // Stage timings.
    accumulate!(parse_us);
    accumulate!(plan_us);
    accumulate!(flatten_us);
    accumulate!(h2d_us);
    accumulate!(ht_cleanup_us);
    accumulate!(ht_refine_us);
    accumulate!(dequant_us);
    accumulate!(idwt_us);
    accumulate!(mct_us);
    accumulate!(store_us);

    // Totals.
    accumulate!(block_count);
    accumulate!(payload_bytes);
    accumulate!(dispatch_count);

    // Detailed upload and transfer timings.
    accumulate!(detail.table_upload_us);
    accumulate!(detail.payload_upload_us);
    accumulate!(detail.job_upload_us);
    accumulate!(detail.status_d2h_us);
    accumulate!(detail.output_d2h_us);

    // Per-stage dispatch counters.
    accumulate!(detail.ht_dispatch_count);
    accumulate!(detail.dequant_dispatch_count);
    accumulate!(detail.idwt_dispatch_count);
    accumulate!(detail.mct_dispatch_count);
    accumulate!(detail.store_dispatch_count);
}
2352
2353#[cfg(not(feature = "cuda-runtime"))]
2354fn decode_to_cuda_resident_surface_impl(
2355    _decoder: &mut J2kDecoder<'_>,
2356    _session: &mut CudaSession,
2357    _fmt: PixelFormat,
2358) -> Result<Surface, Error> {
2359    Err(Error::CudaUnavailable)
2360}
2361
2362#[cfg(not(feature = "cuda-runtime"))]
2363fn decode_to_cuda_resident_surface_with_profile_impl(
2364    _decoder: &mut J2kDecoder<'_>,
2365    _session: &mut CudaSession,
2366    _fmt: PixelFormat,
2367) -> Result<(Surface, CudaHtj2kProfileReport), Error> {
2368    Err(Error::CudaUnavailable)
2369}
2370
2371#[cfg(not(feature = "cuda-runtime"))]
2372fn decode_region_to_cuda_resident_surface_impl(
2373    _decoder: &mut J2kDecoder<'_>,
2374    _session: &mut CudaSession,
2375    _fmt: PixelFormat,
2376    _roi: Rect,
2377) -> Result<Surface, Error> {
2378    Err(Error::CudaUnavailable)
2379}
2380
2381#[cfg(not(feature = "cuda-runtime"))]
2382fn decode_scaled_to_cuda_resident_surface_impl(
2383    _decoder: &mut J2kDecoder<'_>,
2384    _session: &mut CudaSession,
2385    _fmt: PixelFormat,
2386    _scale: Downscale,
2387) -> Result<Surface, Error> {
2388    Err(Error::CudaUnavailable)
2389}
2390
2391#[cfg(not(feature = "cuda-runtime"))]
2392fn decode_region_scaled_to_cuda_resident_surface_impl(
2393    _decoder: &mut J2kDecoder<'_>,
2394    _session: &mut CudaSession,
2395    _fmt: PixelFormat,
2396    _roi: Rect,
2397    _scale: Downscale,
2398) -> Result<Surface, Error> {
2399    Err(Error::CudaUnavailable)
2400}
2401
2402#[cfg(not(feature = "cuda-runtime"))]
2403fn decode_batch_to_cuda_resident_surface_with_profile_control(
2404    _inputs: &[&[u8]],
2405    _session: &mut CudaSession,
2406    _fmt: PixelFormat,
2407    _collect_stage_timings: bool,
2408) -> Result<(Vec<Surface>, CudaHtj2kProfileReport), Error> {
2409    Err(Error::CudaUnavailable)
2410}
2411
/// Uploads the plan's payload together with the shared decode tables, then
/// decodes one component using the uploaded resources.
///
/// The time measured around the upload call is added to both the `h2d` and the
/// `payload_upload` timings of the returned component.
///
/// # Errors
///
/// Returns the upload failure (converted through `cuda_error`) or whatever
/// `decode_cuda_component_plan_with_resources` reports.
#[cfg(feature = "cuda-runtime")]
fn decode_cuda_component_plan(
    context: &j2k_cuda_runtime::CudaContext,
    plan: &CudaHtj2kDecodePlan,
    tables: &CudaHtj2kDecodeTableResources,
    pool: &CudaBufferPool,
    collect_stage_timings: bool,
) -> Result<CudaDecodedComponent, Error> {
    // Measure only the upload itself; the decode stages keep their own timings.
    let upload_started = profile::profile_now(collect_stage_timings);
    let uploaded_resources = context
        .upload_htj2k_decode_resources_with_tables(plan.payload(), tables)
        .map_err(cuda_error)?;
    let upload_us = profile::elapsed_us(upload_started);

    let mut decoded = decode_cuda_component_plan_with_resources(
        context,
        plan,
        &uploaded_resources,
        pool,
        collect_stage_timings,
    )?;

    // The upload counts both as generic host-to-device time and as payload upload time.
    let timings = &mut decoded.timings;
    timings.h2d = timings.h2d.saturating_add(upload_us);
    timings.payload_upload = timings.payload_upload.saturating_add(upload_us);
    Ok(decoded)
}
2439
/// Splits a total kernel dispatch count into `(remaining, dequant)` dispatches.
///
/// * `0` dispatches split into `(0, 0)`.
/// * A single dispatch is not attributed to dequantization: `(1, 0)`.
/// * For two or more, exactly one dispatch is attributed to dequantization and
///   the rest form the first element.
#[cfg(test)]
fn split_htj2k_subband_decode_dispatches(kernel_dispatches: usize) -> (usize, usize) {
    match kernel_dispatches {
        0 => (0, 0),
        1 => (1, 0),
        total => (total - 1, 1),
    }
}
2452
2453#[cfg(feature = "cuda-runtime")]
/// Dispatch count reported for a batched cleanup over `target_count` targets:
/// `1` when there is at least one target, `0` for an empty batch.
fn htj2k_batched_cleanup_dispatches(target_count: usize) -> usize {
    target_count.min(1)
}
2457
/// Dispatch count reported for a batched dequantization over `target_count`
/// targets: `1` when there is at least one target, `0` for an empty batch.
#[cfg(any(feature = "cuda-runtime", test))]
fn htj2k_batched_dequant_dispatches(target_count: usize) -> usize {
    if target_count == 0 {
        0
    } else {
        1
    }
}
2462
2463#[cfg(feature = "cuda-runtime")]
/// Returns `(cleanup_dispatches, dequant_dispatches)` for one batched pass over
/// `target_count` targets.
///
/// An empty batch reports `(0, 0)`. Otherwise one cleanup dispatch is reported,
/// plus one dequantization dispatch unless `fused_cleanup_dequant` is set.
fn htj2k_batched_cleanup_dequant_dispatches(
    target_count: usize,
    fused_cleanup_dequant: bool,
) -> (usize, usize) {
    match (target_count, fused_cleanup_dequant) {
        (0, _) => (0, 0),
        (_, true) => (1, 0),
        (_, false) => (1, 1),
    }
}
2477
/// Decodes one component when the decode resources are already on the device.
///
/// The stages run strictly in sequence: per-subband work preparation, the
/// cleanup/dequantize batch for this single work item, the plan's IDWT steps,
/// and finally extraction of the decoded component.
///
/// # Errors
///
/// Propagates the first error reported by any of the stages.
#[cfg(feature = "cuda-runtime")]
fn decode_cuda_component_plan_with_resources(
    context: &j2k_cuda_runtime::CudaContext,
    plan: &CudaHtj2kDecodePlan,
    decode_resources: &CudaHtj2kDecodeResources,
    pool: &CudaBufferPool,
    collect_stage_timings: bool,
) -> Result<CudaDecodedComponent, Error> {
    let mut component_work =
        decode_cuda_component_subbands_with_resources(context, plan, pool, collect_stage_timings)?;

    // The batch runner works on slices of work items; hand it a one-element view.
    let single_component = std::slice::from_mut(&mut component_work);
    run_component_cleanup_dequant_batches(
        context,
        decode_resources,
        single_component,
        pool,
        collect_stage_timings,
    )?;

    let idwt_steps = plan.idwt_steps();
    run_cuda_component_idwt_steps(
        context,
        idwt_steps,
        &mut component_work,
        pool,
        collect_stage_timings,
    )?;

    finish_cuda_component_decode(component_work)
}
2504
/// Prepares the per-subband decode work for one component.
///
/// For every subband of the plan this builds the code-block jobs, allocates a
/// coefficient buffer from `pool` and records the buffer as a band. Subbands
/// that have at least one job are additionally queued in
/// `pending_dequant_bands`. The wall time of each allocation call is added to
/// `timings.h2d`; both dispatch counters of the returned work start at zero.
///
/// # Errors
///
/// * `UnsupportedCudaRequest` with `CUDA_HTJ2K_PLAN_INVARIANT_FAILED` when a
///   subband's code-block range overflows or lies outside the plan's blocks.
/// * `UnsupportedCudaRequest` with `CUDA_HTJ2K_STORE_UNSUPPORTED` when the plan
///   does not contain exactly one store step.
/// * Any error from job construction, area computation or buffer allocation.
#[cfg(feature = "cuda-runtime")]
fn decode_cuda_component_subbands_with_resources(
    context: &j2k_cuda_runtime::CudaContext,
    plan: &CudaHtj2kDecodePlan,
    pool: &CudaBufferPool,
    collect_stage_timings: bool,
) -> Result<CudaComponentDecodeWork, Error> {
    let subband_count = plan.subbands().len();
    // Reserve room for the IDWT outputs that later stages append to `bands`.
    let mut bands = Vec::with_capacity(subband_count + plan.idwt_steps().len());
    let mut pending_dequant_bands = Vec::with_capacity(subband_count);
    let mut timings = CudaDecodeStageTimings::default();

    for subband in plan.subbands() {
        let first_block = subband.code_block_start as usize;
        let Some(block_end) = first_block.checked_add(subband.code_block_count as usize) else {
            return Err(Error::UnsupportedCudaRequest {
                reason: CUDA_HTJ2K_PLAN_INVARIANT_FAILED,
            });
        };
        let Some(code_blocks) = plan.code_blocks().get(first_block..block_end) else {
            return Err(Error::UnsupportedCudaRequest {
                reason: CUDA_HTJ2K_PLAN_INVARIANT_FAILED,
            });
        };

        let mut jobs = Vec::with_capacity(code_blocks.len());
        for block in code_blocks {
            jobs.push(cuda_code_block_job_from_plan_block(block, subband.width)?);
        }

        let output_words = checked_area(subband.width, subband.height)?;
        let allocation_started = profile::profile_now(collect_stage_timings);
        let allocation = context
            .allocate_htj2k_codeblock_coefficients_with_pool(&jobs, output_words, pool)
            .map_err(cuda_error)?;
        timings.h2d = timings
            .h2d
            .saturating_add(profile::elapsed_us(allocation_started));

        let (buffer, _, _) = allocation.into_parts();
        let band_index = bands.len();
        bands.push(CudaCoefficientBand {
            band_id: subband.band_id,
            buffer,
        });

        // A subband without code blocks keeps its buffer but needs no decode pass.
        if jobs.is_empty() {
            continue;
        }
        pending_dequant_bands.push(CudaPendingDequantBand {
            band_index,
            jobs,
            output_words,
        });
    }

    match plan.store_steps() {
        [store] => Ok(CudaComponentDecodeWork {
            bands,
            pending_dequant_bands,
            store: *store,
            dispatches: 0,
            decode_dispatches: 0,
            timings,
        }),
        _ => Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_STORE_UNSUPPORTED,
        }),
    }
}
2572
/// Runs the HT cleanup decode and the dequantization for every pending band of
/// every work item in `component_work`, using batched runtime calls, and then
/// clears the pending lists.
///
/// Two paths exist:
///
/// * If no pending job carries refinement data (`refinement_length > 0` or more
///   than one coding pass), a single fused cleanup+dequantize call is made.
/// * Otherwise cleanup and dequantization are issued as two separate batched
///   calls. When `collect_stage_timings` is `false`, the cleanup is only
///   enqueued and its handle is kept alive until the dequantization has been
///   issued, after which it is finished.
///
/// All timing and dispatch counters of the whole batch are charged to the first
/// work item that has pending bands. Returns immediately when nothing is pending.
///
/// # Errors
///
/// Propagates buffer lookup errors from `pooled_cuda_buffer` and runtime errors
/// converted through `cuda_error`.
#[cfg(feature = "cuda-runtime")]
fn run_component_cleanup_dequant_batches(
    context: &j2k_cuda_runtime::CudaContext,
    decode_resources: &CudaHtj2kDecodeResources,
    component_work: &mut [CudaComponentDecodeWork],
    pool: &CudaBufferPool,
    collect_stage_timings: bool,
) -> Result<(), Error> {
    // Total number of bands, across all work items, that still need decoding.
    let pending_count = component_work
        .iter()
        .map(|work| work.pending_dequant_bands.len())
        .sum::<usize>();
    if pending_count == 0 {
        return Ok(());
    }
    // The batch is accounted against the first work item that contributed a band.
    // With `pending_count > 0` such an item exists, so the error branch is defensive.
    let accounting_index = component_work
        .iter()
        .position(|work| !work.pending_dequant_bands.is_empty())
        .ok_or(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        })?;

    // Any job with refinement bytes or more than one coding pass rules out the fused path.
    let has_refinement = component_work.iter().any(|work| {
        work.pending_dequant_bands.iter().any(|pending| {
            pending
                .jobs
                .iter()
                .any(|job| job.refinement_length > 0 || u32::from(job.number_of_coding_passes) > 1)
        })
    });
    // One cleanup target per pending band, borrowing the band's buffer and its jobs.
    let cleanup_targets = component_work
        .iter()
        .flat_map(|work| {
            work.pending_dequant_bands
                .iter()
                .map(move |pending| (work, pending))
        })
        .map(|(work, pending)| {
            let coefficients = pooled_cuda_buffer(&work.bands[pending.band_index].buffer)?;
            Ok(CudaHtj2kCleanupTarget {
                coefficients,
                jobs: &pending.jobs,
                output_words: pending.output_words,
            })
        })
        .collect::<Result<Vec<_>, Error>>()?;
    if !has_refinement {
        // Fused path: cleanup and dequantization happen in one batched call.
        let stage_start = profile::profile_now(collect_stage_timings);
        let ((stats, runtime_timings), fused_us) = context
            .time_default_stream_named_us_if(
                collect_stage_timings,
                "j2k.htj2k.decode.cleanup_dequantize.batch",
                || {
                    context
                        .decode_htj2k_codeblocks_cleanup_dequantize_multi_with_resources_and_pool_timed(
                            decode_resources,
                            &cleanup_targets,
                            pool,
                            collect_stage_timings,
                        )
                },
            )
            .map_err(cuda_error)?;
        let stage_wall_us = profile::elapsed_us(stage_start);
        let (cleanup_dispatches, dequant_dispatches) =
            htj2k_batched_cleanup_dequant_dispatches(pending_count, true);
        {
            let accounting = &mut component_work[accounting_index];
            // Wall time not covered by the timed stage is booked as host-to-device time.
            accounting.timings.h2d = accounting
                .timings
                .h2d
                .saturating_add(stage_wall_us.saturating_sub(fused_us));
            accounting.timings.ht_cleanup = accounting.timings.ht_cleanup.saturating_add(fused_us);
            accounting.timings.status_d2h = accounting
                .timings
                .status_d2h
                .saturating_add(runtime_timings.status_d2h_us);
            accounting.timings.ht_dispatch_count = accounting
                .timings
                .ht_dispatch_count
                .saturating_add(cleanup_dispatches);
            accounting.timings.dequant_dispatch_count = accounting
                .timings
                .dequant_dispatch_count
                .saturating_add(dequant_dispatches);
            accounting.dispatches = accounting
                .dispatches
                .saturating_add(stats.kernel_dispatches());
            accounting.decode_dispatches = accounting
                .decode_dispatches
                .saturating_add(stats.decode_kernel_dispatches());
        }

        for work in component_work {
            work.pending_dequant_bands.clear();
        }
        return Ok(());
    }
    // Split path: cleanup first, dequantization afterwards.
    // `queued_cleanup` is only populated by the untimed (enqueue-only) variant below.
    let mut queued_cleanup: Option<CudaQueuedHtj2kCleanup> = None;
    let stage_start = profile::profile_now(collect_stage_timings);
    let (stats, cleanup_us, status_d2h_us) = if collect_stage_timings {
        let ((stats, runtime_timings), cleanup_us) = context
            .time_default_stream_named_us_if(
                collect_stage_timings,
                "j2k.htj2k.decode.cleanup.batch",
                || {
                    context.decode_htj2k_codeblocks_cleanup_multi_with_resources_and_pool_timed(
                        decode_resources,
                        &cleanup_targets,
                        pool,
                        collect_stage_timings,
                    )
                },
            )
            .map_err(cuda_error)?;
        (stats, cleanup_us, runtime_timings.status_d2h_us)
    } else {
        let (queued, cleanup_us) = context
            .time_default_stream_named_us_if(false, "j2k.htj2k.decode.cleanup.batch", || {
                context.decode_htj2k_codeblocks_cleanup_multi_enqueue_with_resources_and_pool(
                    decode_resources,
                    &cleanup_targets,
                    pool,
                )
            })
            .map_err(cuda_error)?;
        let stats = queued.execution();
        queued_cleanup = Some(queued);
        // No status read-back time is reported by the enqueue-only variant.
        (stats, cleanup_us, 0)
    };
    // The targets borrow `component_work`; presumably they are released here so the
    // mutable accounting access below is allowed.
    drop(cleanup_targets);
    let stage_wall_us = profile::elapsed_us(stage_start);
    {
        let accounting = &mut component_work[accounting_index];
        accounting.timings.h2d = accounting
            .timings
            .h2d
            .saturating_add(stage_wall_us.saturating_sub(cleanup_us));
        accounting.timings.ht_cleanup = accounting.timings.ht_cleanup.saturating_add(cleanup_us);
        accounting.timings.status_d2h = accounting.timings.status_d2h.saturating_add(status_d2h_us);
        // NOTE(review): `has_refinement` is always true here because the
        // `!has_refinement` case returned above, so `cleanup_us` is always
        // booked under `ht_refine` as well as `ht_cleanup`.
        if has_refinement {
            accounting.timings.ht_refine = accounting.timings.ht_refine.saturating_add(cleanup_us);
        }
        accounting.timings.ht_dispatch_count = accounting
            .timings
            .ht_dispatch_count
            .saturating_add(htj2k_batched_cleanup_dispatches(pending_count));
        accounting.dispatches = accounting
            .dispatches
            .saturating_add(stats.kernel_dispatches());
        accounting.decode_dispatches = accounting
            .decode_dispatches
            .saturating_add(stats.decode_kernel_dispatches());
    }

    let stage_start = profile::profile_now(collect_stage_timings);
    let (stats, dequant_us, dequant_target_count) = {
        let dequant_target_count = pending_count;
        // With a queued cleanup the dequantization is driven from that handle;
        // otherwise a fresh target list is built from the pending bands.
        let dequant_result = if let Some(queued) = queued_cleanup.as_ref() {
            context.time_default_stream_named_us_if(
                collect_stage_timings,
                "j2k.htj2k.decode.dequantize.batch",
                || context.j2k_dequantize_queued_htj2k_cleanup_with_pool(queued),
            )
        } else {
            let dequant_targets = component_work
                .iter()
                .flat_map(|work| {
                    work.pending_dequant_bands
                        .iter()
                        .map(move |pending| (work, pending))
                })
                .map(|(work, pending)| {
                    let coefficients = pooled_cuda_buffer(&work.bands[pending.band_index].buffer)?;
                    Ok(CudaHtj2kDequantizeTarget {
                        coefficients,
                        jobs: &pending.jobs,
                        output_words: pending.output_words,
                    })
                })
                .collect::<Result<Vec<_>, Error>>()?;
            context.time_default_stream_named_us_if(
                collect_stage_timings,
                "j2k.htj2k.decode.dequantize.batch",
                || {
                    context.j2k_dequantize_htj2k_codeblocks_multi_device_with_pool(
                        &dequant_targets,
                        pool,
                    )
                },
            )
        };
        let (stats, dequant_us) = match dequant_result {
            Ok(result) => result,
            Err(error) => {
                // Finish the queued cleanup before bailing out.
                // NOTE(review): if `finish()` itself fails, its error is returned
                // and the original dequantize error is lost — confirm this is intended.
                if let Some(queued) = queued_cleanup.take() {
                    queued.finish().map_err(cuda_error)?;
                }
                return Err(cuda_error(error));
            }
        };
        (stats, dequant_us, dequant_target_count)
    };
    let stage_wall_us = profile::elapsed_us(stage_start);
    {
        let accounting = &mut component_work[accounting_index];
        accounting.timings.h2d = accounting
            .timings
            .h2d
            .saturating_add(stage_wall_us.saturating_sub(dequant_us));
        accounting.timings.dequant = accounting.timings.dequant.saturating_add(dequant_us);
        accounting.timings.dequant_dispatch_count = accounting
            .timings
            .dequant_dispatch_count
            .saturating_add(htj2k_batched_dequant_dispatches(dequant_target_count));
        accounting.dispatches = accounting
            .dispatches
            .saturating_add(stats.kernel_dispatches());
        accounting.decode_dispatches = accounting
            .decode_dispatches
            .saturating_add(stats.decode_kernel_dispatches());
    }
    // On success the queued cleanup (if any) is finished only after the
    // dequantization has been issued.
    if let Some(queued) = queued_cleanup.take() {
        queued.finish().map_err(cuda_error)?;
    }

    // Every pending band has been processed; later stages must not see them again.
    for work in component_work {
        work.pending_dequant_bands.clear();
    }
    Ok(())
}
2804
/// Executes the given inverse-DWT steps for one component, one step at a time.
///
/// For each step the four input bands (LL, HL, LH, HH) are looked up in
/// `work.bands`, a single-device IDWT is run, and the result is appended to
/// `work.bands` under the step's `output_band_id`. The measured stage time and
/// the dispatch counts reported by the runtime are added to `work`.
///
/// # Errors
///
/// Propagates band or buffer lookup failures and runtime errors converted
/// through `cuda_error`.
#[cfg(feature = "cuda-runtime")]
fn run_cuda_component_idwt_steps(
    context: &j2k_cuda_runtime::CudaContext,
    steps: &[CudaHtj2kIdwtStep],
    work: &mut CudaComponentDecodeWork,
    pool: &CudaBufferPool,
    collect_stage_timings: bool,
) -> Result<(), Error> {
    for step in steps {
        // Resolve all four bands first, then their device buffers.
        let ll_band = find_cuda_band(&work.bands, step.ll_band_id)?;
        let hl_band = find_cuda_band(&work.bands, step.hl_band_id)?;
        let lh_band = find_cuda_band(&work.bands, step.lh_band_id)?;
        let hh_band = find_cuda_band(&work.bands, step.hh_band_id)?;
        let ll = pooled_cuda_buffer(&ll_band.buffer)?;
        let hl = pooled_cuda_buffer(&hl_band.buffer)?;
        let lh = pooled_cuda_buffer(&lh_band.buffer)?;
        let hh = pooled_cuda_buffer(&hh_band.buffer)?;
        let job = cuda_idwt_job_from_step(step);

        let (output, idwt_us) = context
            .time_default_stream_named_us_if(collect_stage_timings, "j2k.htj2k.decode.idwt", || {
                // The timed and untimed runtime entry points take identical arguments.
                if collect_stage_timings {
                    context.j2k_inverse_dwt_single_device_with_pool(ll, hl, lh, hh, job, pool)
                } else {
                    context
                        .j2k_inverse_dwt_single_device_untimed_with_pool(ll, hl, lh, hh, job, pool)
                }
            })
            .map_err(cuda_error)?;

        let (buffer, stats) = output.into_parts();
        let kernel_dispatches = stats.kernel_dispatches();
        work.timings.idwt = work.timings.idwt.saturating_add(idwt_us);
        work.timings.idwt_dispatch_count = work
            .timings
            .idwt_dispatch_count
            .saturating_add(kernel_dispatches);
        work.dispatches = work.dispatches.saturating_add(kernel_dispatches);
        work.decode_dispatches = work
            .decode_dispatches
            .saturating_add(stats.decode_kernel_dispatches());

        // Publish the result so that subsequent lookups by band id can find it.
        work.bands.push(CudaCoefficientBand {
            band_id: step.output_band_id,
            buffer,
        });
    }
    Ok(())
}
2862
/// Turns finished component work into a `CudaDecodedComponent`.
///
/// The band whose id equals the store step's `input_band_id` is removed from
/// `work.bands` and its buffer becomes the component's buffer; the store step,
/// dispatch counters and timings are carried over unchanged.
///
/// # Errors
///
/// Returns `UnsupportedCudaRequest` with `CUDA_HTJ2K_KERNELS_NOT_READY` when no
/// band matches the store step's input band id.
#[cfg(feature = "cuda-runtime")]
fn finish_cuda_component_decode(
    mut work: CudaComponentDecodeWork,
) -> Result<CudaDecodedComponent, Error> {
    let store_input = work
        .bands
        .iter()
        .position(|band| band.band_id == work.store.input_band_id);
    let Some(store_input) = store_input else {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    };

    // Band order is irrelevant from here on, so the cheap `swap_remove` is fine.
    let CudaCoefficientBand { buffer, .. } = work.bands.swap_remove(store_input);
    Ok(CudaDecodedComponent {
        buffer,
        store: work.store,
        dispatches: work.dispatches,
        decode_dispatches: work.decode_dispatches,
        timings: work.timings,
    })
}
2883
/// Reports whether all component plans have the same number of IDWT steps.
///
/// Returns `false` for an empty slice and `true` for a single component.
#[cfg(feature = "cuda-runtime")]
fn can_batch_color_idwt(components: &[&CudaHtj2kDecodePlan]) -> bool {
    match components.split_first() {
        None => false,
        Some((first, rest)) => {
            let expected_steps = first.idwt_steps().len();
            rest.iter()
                .all(|component| component.idwt_steps().len() == expected_steps)
        }
    }
}
2893
/// Enqueues the batched IDWT for all colour components and records its cost.
///
/// The measured stage time and the batch's dispatch counts are added to the
/// first entry of `component_work` (if any).
///
/// Returns `Ok(None)` when `collect_stage_timings` is set — the queued batch is
/// dropped before returning — and `Ok(Some(batch))` otherwise, handing the
/// queued batch to the caller.
///
/// # Errors
///
/// Returns the enqueue failure converted through `cuda_error`.
#[cfg(feature = "cuda-runtime")]
fn run_color_component_idwt_batches(
    context: &j2k_cuda_runtime::CudaContext,
    components: &[&CudaHtj2kDecodePlan],
    component_work: &mut [CudaComponentDecodeWork],
    pool: &CudaBufferPool,
    collect_stage_timings: bool,
) -> Result<Option<CudaQueuedIdwtBatch>, Error> {
    let (queued_batch, idwt_us) = context
        .time_default_stream_named_us_if(
            collect_stage_timings,
            "j2k.htj2k.decode.idwt.batch",
            || enqueue_color_component_idwt_batches(context, components, component_work, pool),
        )
        .map_err(cuda_error)?;

    let kernel_dispatches = queued_batch.kernel_dispatches;
    let decode_dispatches = queued_batch.decode_dispatches;
    if let Some(first_work) = component_work.first_mut() {
        let timings = &mut first_work.timings;
        timings.idwt = timings.idwt.saturating_add(idwt_us);
        timings.idwt_dispatch_count = timings
            .idwt_dispatch_count
            .saturating_add(kernel_dispatches);
        first_work.dispatches = first_work.dispatches.saturating_add(kernel_dispatches);
        first_work.decode_dispatches = first_work
            .decode_dispatches
            .saturating_add(decode_dispatches);
    }

    // NOTE(review): this total is computed but never read; kept as in the original.
    let _queued_resource_count: usize = queued_batch
        .queued
        .iter()
        .map(CudaQueuedExecution::resource_count)
        .sum();

    // With stage timings enabled the batch is dropped here instead of being returned.
    Ok((!collect_stage_timings).then_some(queued_batch))
}
2935
/// Allocates the IDWT output buffers for every (step, component) pair, builds
/// one target list per step and enqueues all of them with a single
/// batch-sequence call.
///
/// The number of steps is taken from the first component; every other
/// component must provide a step at each of those indices. The freshly
/// allocated output buffers are pushed onto the matching `component_work`
/// entry's `bands` before the targets are built, so they remain there even if
/// a later part of this function fails.
///
/// When `cuda_idwt_trace_enabled()` is true, host-side timings and pool
/// statistics for the call are printed to stderr.
///
/// # Errors
///
/// * `CudaError::InvalidArgument` when `components` and `component_work` differ
///   in length, when a component lacks a step, or when an output byte size
///   overflows.
/// * Errors from `pool.take`/`take_with_trace`, from band or buffer lookups
///   (mapped through `cuda_invalid_decode_plan`), and from the enqueue call.
#[cfg(feature = "cuda-runtime")]
fn enqueue_color_component_idwt_batches(
    context: &j2k_cuda_runtime::CudaContext,
    components: &[&CudaHtj2kDecodePlan],
    component_work: &mut [CudaComponentDecodeWork],
    pool: &CudaBufferPool,
) -> Result<CudaQueuedIdwtBatch, CudaError> {
    // Plans and work items are paired by index below, so their counts must match.
    if components.len() != component_work.len() {
        return Err(CudaError::InvalidArgument {
            message: CUDA_HTJ2K_KERNELS_NOT_READY.to_string(),
        });
    }
    // Nothing to enqueue for an empty component list.
    let Some(first) = components.first() else {
        return Ok(CudaQueuedIdwtBatch {
            queued: Vec::new(),
            kernel_dispatches: 0,
            decode_dispatches: 0,
        });
    };

    let mut queued = Vec::with_capacity(first.idwt_steps().len());
    let mut kernel_dispatches = 0usize;
    let mut decode_dispatches = 0usize;
    let step_count = first.idwt_steps().len();
    let trace_enabled = cuda_idwt_trace_enabled();
    // The fallible part runs inside an immediately-invoked closure so that the
    // error handling after it can inspect `queued`.
    let enqueue_result = (|| -> Result<(), CudaError> {
        let mut output_pool_trace = CudaIdwtOutputPoolTraceTotals::default();
        let output_alloc_start = trace_enabled.then(std::time::Instant::now);
        // Phase 1: take one output buffer per (step, component) and register it
        // as a band under the step's output band id.
        for step_index in 0..step_count {
            for (component_index, component) in components.iter().enumerate() {
                let step = component.idwt_steps().get(step_index).ok_or_else(|| {
                    CudaError::InvalidArgument {
                        message: CUDA_HTJ2K_KERNELS_NOT_READY.to_string(),
                    }
                })?;
                // An inverted rectangle saturates to a zero extent instead of wrapping.
                let width = step.rect.x1.saturating_sub(step.rect.x0);
                let height = step.rect.y1.saturating_sub(step.rect.y0);
                let output_words = checked_area(width, height).map_err(cuda_invalid_decode_plan)?;
                // Output buffers are sized as one `f32` per coefficient word.
                let output_bytes = output_words
                    .checked_mul(std::mem::size_of::<f32>())
                    .ok_or_else(|| CudaError::InvalidArgument {
                        message: CUDA_HTJ2K_KERNELS_NOT_READY.to_string(),
                    })?;
                let buffer = if trace_enabled {
                    let (buffer, trace) = pool.take_with_trace(output_bytes)?;
                    output_pool_trace.add_take(trace);
                    buffer
                } else {
                    pool.take(output_bytes)?
                };
                component_work[component_index]
                    .bands
                    .push(CudaCoefficientBand {
                        band_id: step.output_band_id,
                        buffer,
                    });
            }
        }
        let output_alloc_us = elapsed_host_us(output_alloc_start);

        // Phase 2: for every step, build one target per component. All bands,
        // including the outputs registered in phase 1, are resolved by band id.
        let target_build_start = trace_enabled.then(std::time::Instant::now);
        let mut target_batches = Vec::with_capacity(step_count);
        for step_index in 0..step_count {
            let targets = components
                .iter()
                .enumerate()
                .map(|(component_index, component)| {
                    let step = component.idwt_steps().get(step_index).ok_or_else(|| {
                        CudaError::InvalidArgument {
                            message: CUDA_HTJ2K_KERNELS_NOT_READY.to_string(),
                        }
                    })?;
                    let work = &component_work[component_index];
                    let ll = find_cuda_band(&work.bands, step.ll_band_id)
                        .map_err(cuda_invalid_decode_plan)?;
                    let hl = find_cuda_band(&work.bands, step.hl_band_id)
                        .map_err(cuda_invalid_decode_plan)?;
                    let lh = find_cuda_band(&work.bands, step.lh_band_id)
                        .map_err(cuda_invalid_decode_plan)?;
                    let hh = find_cuda_band(&work.bands, step.hh_band_id)
                        .map_err(cuda_invalid_decode_plan)?;
                    let output = find_cuda_band(&work.bands, step.output_band_id)
                        .map_err(cuda_invalid_decode_plan)?;
                    Ok(CudaJ2kIdwtTarget {
                        ll: pooled_cuda_buffer(&ll.buffer).map_err(cuda_invalid_decode_plan)?,
                        hl: pooled_cuda_buffer(&hl.buffer).map_err(cuda_invalid_decode_plan)?,
                        lh: pooled_cuda_buffer(&lh.buffer).map_err(cuda_invalid_decode_plan)?,
                        hh: pooled_cuda_buffer(&hh.buffer).map_err(cuda_invalid_decode_plan)?,
                        output: pooled_cuda_buffer(&output.buffer)
                            .map_err(cuda_invalid_decode_plan)?,
                        job: cuda_idwt_job_from_step(step),
                    })
                })
                .collect::<Result<Vec<_>, CudaError>>()?;
            target_batches.push(targets);
        }
        let target_build_us = elapsed_host_us(target_build_start);
        // Phase 3: hand all per-step target lists to the runtime in one call.
        let target_slices = target_batches.iter().map(Vec::as_slice).collect::<Vec<_>>();
        let enqueue_start = trace_enabled.then(std::time::Instant::now);
        let queued_execution =
            context.j2k_inverse_dwt_batch_sequence_enqueue_with_pool(&target_slices, pool)?;
        let enqueue_us = elapsed_host_us(enqueue_start);
        let execution = queued_execution.execution();
        kernel_dispatches = kernel_dispatches.saturating_add(execution.kernel_dispatches());
        decode_dispatches = decode_dispatches.saturating_add(execution.decode_kernel_dispatches());
        queued.push(queued_execution);
        if trace_enabled {
            let row = CudaIdwtBatchHostTraceRow {
                component_count: components.len(),
                step_count,
                output_alloc_us,
                target_build_us,
                enqueue_us,
                output_take_count: output_pool_trace.take_count,
                output_pool_reuse_count: output_pool_trace.reuse_count,
                output_pool_alloc_count: output_pool_trace.alloc_count,
                output_pool_scanned_count: output_pool_trace.scanned_count,
                output_pool_max_free_count: output_pool_trace.max_free_count,
                output_requested_bytes: output_pool_trace.requested_bytes,
            };
            eprintln!("{}", format_cuda_idwt_batch_host_trace_row(row));
        }
        Ok(())
    })();
    if let Err(error) = enqueue_result {
        // Best-effort synchronize before reporting the original error; a
        // synchronize failure is deliberately ignored.
        // NOTE(review): `queued` only receives an element right before the
        // closure returns `Ok`, so this branch looks unreachable as written.
        if !queued.is_empty() {
            let _ = context.synchronize();
        }
        return Err(error);
    }

    Ok(CudaQueuedIdwtBatch {
        queued,
        kernel_dispatches,
        decode_dispatches,
    })
}
3073
/// Converts a planned code block into the job description used by the runtime.
///
/// The job's `output_offset` is `output_y * subband_width + output_x`, computed
/// with overflow checks. `stripe_causal` becomes `true` for any non-zero plan
/// value. All other fields are copied over unchanged.
///
/// # Errors
///
/// Returns `UnsupportedCudaRequest` with `CUDA_HTJ2K_KERNELS_NOT_READY` when the
/// output offset does not fit its integer type.
#[cfg(feature = "cuda-runtime")]
fn cuda_code_block_job_from_plan_block(
    block: &crate::CudaHtj2kCodeBlock,
    subband_width: u32,
) -> Result<CudaHtj2kCodeBlockJob, Error> {
    let row_start = block.output_y.checked_mul(subband_width);
    let Some(output_offset) = row_start.and_then(|start| start.checked_add(block.output_x)) else {
        return Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        });
    };
    let stripe_causal = block.stripe_causal != 0;

    Ok(CudaHtj2kCodeBlockJob {
        payload_offset: block.payload_offset,
        width: block.width,
        height: block.height,
        payload_len: block.payload_len,
        cleanup_length: block.cleanup_length,
        refinement_length: block.refinement_length,
        missing_bit_planes: block.missing_bit_planes,
        num_bitplanes: block.num_bitplanes,
        number_of_coding_passes: block.number_of_coding_passes,
        output_stride: block.output_stride,
        output_offset,
        dequantization_step: block.dequantization_step,
        stripe_causal,
    })
}
3102
/// Checks that the three colour store steps describe a full-image store.
///
/// A store step is accepted when all of the following hold:
///
/// * it writes at output origin `(0, 0)`;
/// * its copy size and its output size both equal `dimensions`;
/// * its source window (`source_x/y` plus copy size) fits inside its input
///   rectangle without overflowing;
/// * its source origin matches that of the first store step.
///
/// # Errors
///
/// Returns `UnsupportedCudaRequest` with `CUDA_HTJ2K_KERNELS_NOT_READY` as soon
/// as any store step violates one of the conditions.
#[cfg(feature = "cuda-runtime")]
fn validate_color_stores(
    stores: [&CudaHtj2kStoreStep; 3],
    dimensions: (u32, u32),
) -> Result<(), Error> {
    let (image_width, image_height) = dimensions;
    let reference = stores[0];

    let all_supported = stores.iter().all(|store| {
        let input_width = store.input_rect.x1.saturating_sub(store.input_rect.x0);
        let input_height = store.input_rect.y1.saturating_sub(store.input_rect.y0);
        // An overflowing end coordinate can never fit inside the input rectangle.
        let fits_horizontally = matches!(
            store.source_x.checked_add(store.copy_width),
            Some(end_x) if end_x <= input_width
        );
        let fits_vertically = matches!(
            store.source_y.checked_add(store.copy_height),
            Some(end_y) if end_y <= input_height
        );

        fits_horizontally
            && fits_vertically
            && store.output_x == 0
            && store.output_y == 0
            && store.copy_width == image_width
            && store.copy_height == image_height
            && store.output_width == image_width
            && store.output_height == image_height
            && store.source_x == reference.source_x
            && store.source_y == reference.source_y
    });

    if all_supported {
        Ok(())
    } else {
        Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        })
    }
}
3144
3145#[cfg(feature = "cuda-runtime")]
/// Returns `2^(bit_depth - 1)` as an `f32`.
///
/// The exponent is clamped to `0..=15`, so a `bit_depth` of `0` or `1` yields `1.0` and any
/// depth of `16` or more yields `32768.0`.
// NOTE(review): presumably the level-shift offset added when storing samples; confirm at the
// call sites.
fn bit_depth_addend(bit_depth: u8) -> f32 {
    // Largest shift that still fits in the `u16` intermediate.
    const MAX_SHIFT: u8 = 15;
    let exponent = bit_depth.saturating_sub(1).min(MAX_SHIFT);
    let addend = 1_u16 << exponent;
    f32::from(addend)
}
3150
/// Computes `width * height` as a `usize` element count.
///
/// Both dimensions go through a checked conversion before the checked multiplication, so a
/// dimension that does not fit in `usize` is rejected instead of being truncated by an `as`
/// cast.
///
/// # Errors
///
/// Returns [`Error::UnsupportedCudaRequest`] when either dimension does not fit in `usize` or
/// when the product overflows `usize`.
#[cfg(feature = "cuda-runtime")]
fn checked_area(width: u32, height: u32) -> Result<usize, Error> {
    let area = match (usize::try_from(width), usize::try_from(height)) {
        (Ok(width), Ok(height)) => width.checked_mul(height),
        _ => None,
    };
    area.ok_or(Error::UnsupportedCudaRequest {
        reason: CUDA_HTJ2K_KERNELS_NOT_READY,
    })
}
3161
/// Returns the first entry of `bands` whose `band_id` equals the requested one.
///
/// # Errors
///
/// Returns [`Error::UnsupportedCudaRequest`] when no band carries `band_id`.
#[cfg(feature = "cuda-runtime")]
fn find_cuda_band(
    bands: &[CudaCoefficientBand],
    band_id: CudaHtj2kBandId,
) -> Result<&CudaCoefficientBand, Error> {
    for candidate in bands {
        if candidate.band_id == band_id {
            return Ok(candidate);
        }
    }
    Err(Error::UnsupportedCudaRequest {
        reason: CUDA_HTJ2K_KERNELS_NOT_READY,
    })
}
3174
/// Borrows the device buffer held by a pooled buffer.
///
/// # Errors
///
/// Returns [`Error::UnsupportedCudaRequest`] when `as_device_buffer` yields `None`.
#[cfg(feature = "cuda-runtime")]
fn pooled_cuda_buffer(buffer: &CudaPooledDeviceBuffer) -> Result<&CudaDeviceBuffer, Error> {
    match buffer.as_device_buffer() {
        Some(device_buffer) => Ok(device_buffer),
        None => Err(Error::UnsupportedCudaRequest {
            reason: CUDA_HTJ2K_KERNELS_NOT_READY,
        }),
    }
}
3183
/// Wraps a crate-level [`Error`] as a runtime `CudaError::InvalidArgument`, using the error's
/// `Display` text as the message.
// NOTE(review): the by-value parameter is presumably kept so the function can be passed
// directly to `map_err`; confirm at the call sites before removing the `allow`.
#[cfg(feature = "cuda-runtime")]
#[allow(clippy::needless_pass_by_value)]
fn cuda_invalid_decode_plan(error: Error) -> CudaError {
    let message = error.to_string();
    CudaError::InvalidArgument { message }
}
3191
/// Copies the four corner coordinates of a plan rectangle into the runtime rectangle type.
#[cfg(feature = "cuda-runtime")]
fn cuda_runtime_rect(rect: crate::CudaHtj2kRect) -> CudaJ2kRect {
    let (x0, y0) = (rect.x0, rect.y0);
    let (x1, y1) = (rect.x1, rect.y1);
    CudaJ2kRect { x0, y0, x1, y1 }
}
3201
/// Builds the runtime IDWT job for one planned IDWT step.
///
/// The output rectangle and the four sub-band rectangles are converted with
/// [`cuda_runtime_rect`]; `irreversible97` is `1` when the step uses
/// `CudaHtj2kTransform::Irreversible97` and `0` otherwise.
#[cfg(feature = "cuda-runtime")]
fn cuda_idwt_job_from_step(step: &CudaHtj2kIdwtStep) -> CudaJ2kIdwtJob {
    let uses_irreversible97 = step.transform == CudaHtj2kTransform::Irreversible97;

    let rect = cuda_runtime_rect(step.rect);
    let ll_rect = cuda_runtime_rect(step.ll_rect);
    let hl_rect = cuda_runtime_rect(step.hl_rect);
    let lh_rect = cuda_runtime_rect(step.lh_rect);
    let hh_rect = cuda_runtime_rect(step.hh_rect);

    CudaJ2kIdwtJob {
        rect,
        ll_rect,
        hl_rect,
        lh_rect,
        hh_rect,
        irreversible97: u32::from(uses_irreversible97),
    }
}
3213
/// Tests for the CUDA decode path; compiled only for test builds with `cuda-runtime` enabled.
#[cfg(all(test, feature = "cuda-runtime"))]
mod tests {
    use super::{
        build_cuda_htj2k_color_plans_from_bytes_with_profile, can_batch_color_idwt,
        cuda_code_block_job_from_plan_block, htj2k_batched_cleanup_dequant_dispatches,
        htj2k_batched_cleanup_dispatches, htj2k_batched_dequant_dispatches, CudaDecodeStageTimings,
    };
    use j2k_core::PixelFormat;
    use j2k_native::{encode_htj2k, DecoderContext as NativeDecoderContext, EncodeOptions};

    use crate::CudaHtj2kCodeBlock;

    /// Reports whether `J2K_REQUIRE_CUDA_RUNTIME` is set. When it is not, tests that need a
    /// device return early on `CudaUnavailable` / `CudaRuntime` errors instead of failing.
    fn cuda_runtime_required() -> bool {
        let requirement = std::env::var_os("J2K_REQUIRE_CUDA_RUNTIME");
        requirement.is_some()
    }

    /// Reversible encode options with `levels` decomposition levels and defaults elsewhere.
    fn reversible_options(levels: u8) -> EncodeOptions {
        EncodeOptions {
            reversible: true,
            num_decomposition_levels: levels,
            ..EncodeOptions::default()
        }
    }

    /// Encodes a deterministic `width` x `height` RGB8 pattern derived from `seed` as HTJ2K.
    fn rgb8_htj2k_fixture(width: u32, height: u32, levels: u8, seed: u16) -> Vec<u8> {
        let mut pixels = Vec::with_capacity(width as usize * height as usize * 3);
        let seed = u32::from(seed);
        for idx in 0..width * height {
            let red = u8::try_from((idx * seed + idx / 3) & 0xff).expect("red");
            let green = u8::try_from((idx * (seed + 11) + 7) & 0xff).expect("green");
            let blue = u8::try_from((idx * (seed + 23) + 19) & 0xff).expect("blue");
            pixels.extend([red, green, blue]);
        }
        let options = reversible_options(levels);
        encode_htj2k(&pixels, width, height, 3, 8, false, &options)
            .expect("encode RGB HTJ2K fixture")
    }

    /// The runtime job keeps the plan's stride and derives the offset from the sub-band width.
    #[test]
    fn cuda_runtime_code_block_job_preserves_plan_output_stride() {
        let subband_width = 64;
        let plan_block = CudaHtj2kCodeBlock {
            subband_index: 0,
            payload_offset: 13,
            payload_len: 5,
            cleanup_length: 5,
            refinement_length: 0,
            output_x: 3,
            output_y: 2,
            width: 4,
            height: 5,
            output_stride: 99,
            missing_bit_planes: 1,
            number_of_coding_passes: 1,
            num_bitplanes: 8,
            stripe_causal: 0,
            dequantization_step: 1.0,
        };

        let job = cuda_code_block_job_from_plan_block(&plan_block, subband_width)
            .expect("valid CUDA code-block runtime job");

        // output_y * subband_width + output_x = 2 * 64 + 3.
        assert_eq!(job.output_offset, 131);
        assert_eq!(job.output_stride, 99);
    }

    /// Any non-empty batch costs one shared dispatch; an empty batch costs none.
    #[test]
    fn batched_cleanup_and_dequant_dispatch_helpers_count_one_shared_dispatch() {
        for (count, expected) in [(0, 0), (1, 1), (3, 1)] {
            assert_eq!(htj2k_batched_cleanup_dispatches(count), expected);
        }
        for (count, expected) in [(0, 0), (1, 1), (3, 1)] {
            assert_eq!(htj2k_batched_dequant_dispatches(count), expected);
        }
        let combined_cases = [
            (0, true, (0, 0)),
            (1, true, (1, 0)),
            (3, true, (1, 0)),
            (1, false, (1, 1)),
            (3, false, (1, 1)),
        ];
        for (count, flag, expected) in combined_cases {
            assert_eq!(
                htj2k_batched_cleanup_dequant_dispatches(count, flag),
                expected
            );
        }
    }

    /// An empty input batch decodes to no surfaces and an all-zero report.
    #[test]
    fn profiled_cuda_batch_decode_api_accepts_empty_batch() {
        let inputs: [&[u8]; 0] = [];
        let mut session = crate::CudaSession::default();

        let outcome = crate::J2kDecoder::decode_batch_to_device_with_session_and_profile(
            &inputs,
            PixelFormat::Rgb8,
            &mut session,
        );
        let (surfaces, report) = outcome.expect("empty CUDA batch decode");

        assert!(surfaces.is_empty());
        assert_eq!(report.block_count, 0);
        assert_eq!(report.payload_bytes, 0);
    }

    /// Batch-decoding two 16x16 RGB images must match decoding each one on its own, both per
    /// surface and in the tightly packed batch download.
    #[test]
    fn cuda_batch_decode_two_color_images_matches_single_when_runtime_required() {
        let options = reversible_options(1);
        let pixels_a: Vec<u8> = (0u16..16 * 16 * 3)
            .map(|idx| u8::try_from((idx * 7 + idx / 5) & 0xff).expect("masked byte"))
            .collect();
        let codestream_a =
            encode_htj2k(&pixels_a, 16, 16, 3, 8, false, &options).expect("encode fixture A");
        let pixels_b: Vec<u8> = (0u16..16 * 16 * 3)
            .map(|idx| u8::try_from((idx * 11 + 23) & 0xff).expect("masked byte"))
            .collect();
        let codestream_b =
            encode_htj2k(&pixels_b, 16, 16, 3, 8, false, &options).expect("encode fixture B");
        let inputs = [codestream_a.as_slice(), codestream_b.as_slice()];
        let mut batch_session = crate::CudaSession::default();

        let outcome = crate::J2kDecoder::decode_batch_to_device_with_session_and_profile(
            &inputs,
            PixelFormat::Rgb8,
            &mut batch_session,
        );
        let (surfaces, report) = match outcome {
            Ok(decoded) => decoded,
            // Without a mandatory runtime, a missing or failing device skips the test.
            Err(crate::Error::CudaUnavailable | crate::Error::CudaRuntime { .. })
                if !cuda_runtime_required() =>
            {
                return
            }
            Err(error) => panic!("batch CUDA decode failed: {error}"),
        };

        let frame_len = 16 * 16 * 3;
        let row_stride = 16 * 3;
        assert_eq!(surfaces.len(), 2);
        assert_eq!(report.detail.ht_dispatch_count, 1);
        assert_eq!(report.detail.dequant_dispatch_count, 0);
        assert_eq!(report.detail.store_dispatch_count, 1);
        let tight_pixels =
            crate::Surface::download_batch_tight(&surfaces).expect("download tight CUDA batch");
        assert_eq!(tight_pixels.len(), surfaces.len() * frame_len);

        for (index, (codestream, batch_surface)) in
            inputs.iter().zip(surfaces.iter()).enumerate()
        {
            let mut single_session = crate::CudaSession::default();
            let mut decoder = crate::J2kDecoder::new(codestream).expect("single decoder");
            let single_surface = decoder
                .decode_to_device_with_session(PixelFormat::Rgb8, &mut single_session)
                .expect("single CUDA decode");

            let mut single_pixels = vec![0u8; frame_len];
            single_surface
                .download_into(&mut single_pixels, row_stride)
                .expect("download single decode");
            let mut batch_pixels = vec![0u8; frame_len];
            batch_surface
                .download_into(&mut batch_pixels, row_stride)
                .expect("download batch decode");

            assert_eq!(batch_pixels, single_pixels);
            assert_eq!(
                &tight_pixels[index * frame_len..(index + 1) * frame_len],
                single_pixels.as_slice()
            );
        }
    }

    /// Two 32x32 images with different decomposition levels must still match their
    /// single-image decodes; an `UnsupportedCudaRequest` outcome is also accepted.
    #[test]
    fn cuda_batch_decode_mixed_idwt_shapes_avoids_fused_batch_store_without_idwt_batch() {
        let one_level = rgb8_htj2k_fixture(32, 32, 1, 7);
        let two_levels = rgb8_htj2k_fixture(32, 32, 2, 19);
        let inputs = [one_level.as_slice(), two_levels.as_slice()];
        let mut batch_session = crate::CudaSession::default();

        let outcome = crate::J2kDecoder::decode_batch_to_device_with_session(
            &inputs,
            PixelFormat::Rgb8,
            &mut batch_session,
        );
        let surfaces = match outcome {
            Ok(decoded) => decoded,
            Err(crate::Error::CudaUnavailable | crate::Error::CudaRuntime { .. })
                if !cuda_runtime_required() =>
            {
                return
            }
            Err(crate::Error::UnsupportedCudaRequest { .. }) => return,
            Err(error) => panic!("mixed-shape batch CUDA decode failed: {error}"),
        };

        let frame_len = 32 * 32 * 3;
        let row_stride = 32 * 3;
        assert_eq!(surfaces.len(), inputs.len());
        for (codestream, batch_surface) in inputs.iter().zip(surfaces.iter()) {
            let mut single_session = crate::CudaSession::default();
            let mut decoder = crate::J2kDecoder::new(codestream).expect("single decoder");
            let single_surface = decoder
                .decode_to_device_with_session(PixelFormat::Rgb8, &mut single_session)
                .expect("single CUDA decode");

            let mut single_pixels = vec![0u8; frame_len];
            single_surface
                .download_into(&mut single_pixels, row_stride)
                .expect("download single decode");
            let mut batch_pixels = vec![0u8; frame_len];
            batch_surface
                .download_into(&mut batch_pixels, row_stride)
                .expect("download mixed-shape batch decode");

            assert_eq!(batch_pixels, single_pixels);
        }
    }

    /// Stage timings land in the matching top-level and detail fields of the report.
    #[test]
    fn decode_stage_timings_report_status_download_detail() {
        let timings = CudaDecodeStageTimings {
            h2d: 17,
            status_d2h: 5,
            ..Default::default()
        };
        let mut report = crate::CudaHtj2kProfileReport::default();

        timings.add_to_report(&mut report);

        assert_eq!(report.h2d_us, 17);
        assert_eq!(report.detail.status_d2h_us, 5);
    }

    /// Color plans keep a single shared payload; per-component payloads stay empty and every
    /// code block's byte range lies inside the shared payload.
    #[test]
    fn color_plan_flattens_one_shared_payload_for_component_decode() {
        let pixels: Vec<u8> = (0u16..4 * 4 * 3)
            .map(|idx| u8::try_from((idx * 13 + idx / 3) & 0xff).expect("masked byte"))
            .collect();
        let options = reversible_options(1);
        let codestream =
            encode_htj2k(&pixels, 4, 4, 3, 8, false, &options).expect("encode HTJ2K RGB fixture");
        let mut decoder = crate::J2kDecoder::new(&codestream).expect("decoder");

        let plans = decoder
            .build_cuda_htj2k_color_plans_with_profile(PixelFormat::Rgb8)
            .expect("CUDA color plans");

        let shared_len = plans.payload.len();
        assert_eq!(plans.components.len(), 3);
        assert!(!plans.payload.is_empty());
        assert_eq!(plans.report.payload_bytes, shared_len);
        for component in &plans.components {
            assert!(component.payload().is_empty());
            for block in component.code_blocks() {
                let start = usize::try_from(block.payload_offset).expect("payload offset");
                let end = start + block.payload_len as usize;
                assert!(end <= shared_len);
            }
        }
    }

    /// Planning from raw bytes must agree with planning through a decoder instance.
    #[test]
    fn byte_color_plan_builder_matches_decoder_color_plan() {
        let pixels: Vec<u8> = (0u16..8 * 8 * 3)
            .map(|idx| u8::try_from((idx * 19 + idx / 5) & 0xff).expect("masked byte"))
            .collect();
        let options = reversible_options(1);
        let codestream =
            encode_htj2k(&pixels, 8, 8, 3, 8, false, &options).expect("encode HTJ2K RGB fixture");

        let mut decoder = crate::J2kDecoder::new(&codestream).expect("decoder");
        let decoder_plan = decoder
            .build_cuda_htj2k_color_plans_with_profile(PixelFormat::Rgb8)
            .expect("decoder CUDA color plans");
        let mut native_context = NativeDecoderContext::default();
        let byte_plan = build_cuda_htj2k_color_plans_from_bytes_with_profile(
            &codestream,
            PixelFormat::Rgb8,
            &mut native_context,
        )
        .expect("byte CUDA color plans");

        assert_eq!(byte_plan.dimensions, decoder_plan.dimensions);
        assert_eq!(byte_plan.mct_dimensions, decoder_plan.mct_dimensions);
        assert_eq!(byte_plan.bit_depths, decoder_plan.bit_depths);
        assert_eq!(byte_plan.mct, decoder_plan.mct);
        assert_eq!(byte_plan.components.len(), decoder_plan.components.len());
        assert_eq!(byte_plan.payload.len(), decoder_plan.payload.len());

        let byte_block_counts = byte_plan
            .components
            .iter()
            .map(|component| component.code_blocks().len())
            .collect::<Vec<_>>();
        let decoder_block_counts = decoder_plan
            .components
            .iter()
            .map(|component| component.code_blocks().len())
            .collect::<Vec<_>>();
        assert_eq!(byte_block_counts, decoder_block_counts);
    }

    /// Components from two identically shaped images are eligible for one IDWT batch.
    #[test]
    fn multi_image_color_components_can_share_one_idwt_batch() {
        let pixels: Vec<u8> = (0u16..16 * 16 * 3)
            .map(|idx| u8::try_from((idx * 17 + idx / 7) & 0xff).expect("masked byte"))
            .collect();
        let options = reversible_options(1);
        let codestream =
            encode_htj2k(&pixels, 16, 16, 3, 8, false, &options).expect("encode HTJ2K RGB fixture");

        let mut first_decoder = crate::J2kDecoder::new(&codestream).expect("first decoder");
        let mut second_decoder = crate::J2kDecoder::new(&codestream).expect("second decoder");
        let first_plan = first_decoder
            .build_cuda_htj2k_color_plans_with_profile(PixelFormat::Rgb8)
            .expect("first CUDA color plans");
        let second_plan = second_decoder
            .build_cuda_htj2k_color_plans_with_profile(PixelFormat::Rgb8)
            .expect("second CUDA color plans");

        let components: Vec<_> = first_plan
            .components
            .iter()
            .chain(second_plan.components.iter())
            .collect();

        assert_eq!(components.len(), 6);
        assert!(can_batch_color_idwt(&components));
    }

    /// Source-level guard: the batched color IDWT path must not contain the unconditional
    /// synchronize snippet below.
    #[test]
    fn batched_color_idwt_defers_completion_to_store_sync() {
        let forbidden_sync =
            "if !collect_stage_timings {\n        context.synchronize().map_err(cuda_error)?;\n    }";
        let source = include_str!("decoder.rs");

        assert!(
            !source.contains(forbidden_sync),
            "batched color IDWT should keep queued resources live and let the following store synchronize"
        );
    }
}
3545
3546impl ImageCodec for J2kDecoder<'_> {
3547    type Error = Error;
3548    type Warning = Infallible;
3549    type Pool = CpuJ2kScratchPool;
3550}
3551
3552impl<'a> ImageDecode<'a> for J2kDecoder<'a> {
3553    type View = J2kView<'a>;
3554
3555    fn inspect(input: &'a [u8]) -> Result<j2k_core::Info, Self::Error> {
3556        Ok(CpuDecoder::inspect(input)?)
3557    }
3558
3559    fn parse(input: &'a [u8]) -> Result<Self::View, Self::Error> {
3560        Ok(J2kView::parse(input)?)
3561    }
3562
3563    fn from_view(view: Self::View) -> Result<Self, Self::Error> {
3564        Ok(Self {
3565            inner: CpuDecoder::from_view(view)?,
3566            pool: CpuJ2kScratchPool::new(),
3567        })
3568    }
3569
3570    fn decode_into(
3571        &mut self,
3572        out: &mut [u8],
3573        stride: usize,
3574        fmt: PixelFormat,
3575    ) -> Result<DecodeOutcome<Self::Warning>, Self::Error> {
3576        Ok(self.inner.decode_into(out, stride, fmt)?)
3577    }
3578
3579    fn decode_into_with_scratch(
3580        &mut self,
3581        pool: &mut Self::Pool,
3582        out: &mut [u8],
3583        stride: usize,
3584        fmt: PixelFormat,
3585    ) -> Result<DecodeOutcome<Self::Warning>, Self::Error> {
3586        Ok(self
3587            .inner
3588            .decode_into_with_scratch(pool, out, stride, fmt)?)
3589    }
3590
3591    fn decode_region_into(
3592        &mut self,
3593        pool: &mut Self::Pool,
3594        out: &mut [u8],
3595        stride: usize,
3596        fmt: PixelFormat,
3597        roi: Rect,
3598    ) -> Result<DecodeOutcome<Self::Warning>, Self::Error> {
3599        Ok(self.inner.decode_region_into(pool, out, stride, fmt, roi)?)
3600    }
3601
3602    fn decode_scaled_into(
3603        &mut self,
3604        pool: &mut Self::Pool,
3605        out: &mut [u8],
3606        stride: usize,
3607        fmt: PixelFormat,
3608        scale: Downscale,
3609    ) -> Result<DecodeOutcome<Self::Warning>, Self::Error> {
3610        Ok(self
3611            .inner
3612            .decode_scaled_into(pool, out, stride, fmt, scale)?)
3613    }
3614
3615    fn decode_region_scaled_into(
3616        &mut self,
3617        pool: &mut Self::Pool,
3618        out: &mut [u8],
3619        stride: usize,
3620        fmt: PixelFormat,
3621        roi: Rect,
3622        scale: Downscale,
3623    ) -> Result<DecodeOutcome<Self::Warning>, Self::Error> {
3624        Ok(self
3625            .inner
3626            .decode_region_scaled_into(pool, out, stride, fmt, roi, scale)?)
3627    }
3628}
3629
3630impl<'a> ImageDecodeDevice<'a> for J2kDecoder<'a> {
3631    type DeviceSurface = Surface;
3632}
3633
3634impl<'a> ImageDecodeSubmit<'a> for J2kDecoder<'a> {
3635    type Session = CudaSession;
3636    type DeviceSurface = Surface;
3637    type SubmittedSurface = ReadySubmission<Surface, Error>;
3638
3639    fn submit_to_device(
3640        &mut self,
3641        session: &mut Self::Session,
3642        fmt: PixelFormat,
3643        backend: BackendRequest,
3644    ) -> Result<Self::SubmittedSurface, Self::Error> {
3645        validate_surface_request(backend)?;
3646        Ok(submit_ready_device(session, |session| {
3647            self.decode_to_surface_impl(session, fmt, backend)
3648        }))
3649    }
3650
3651    fn submit_region_to_device(
3652        &mut self,
3653        session: &mut Self::Session,
3654        fmt: PixelFormat,
3655        roi: Rect,
3656        backend: BackendRequest,
3657    ) -> Result<Self::SubmittedSurface, Self::Error> {
3658        validate_surface_request(backend)?;
3659        Ok(submit_ready_device(session, |session| {
3660            self.decode_region_to_surface_impl(session, fmt, roi, backend)
3661        }))
3662    }
3663
3664    fn submit_scaled_to_device(
3665        &mut self,
3666        session: &mut Self::Session,
3667        fmt: PixelFormat,
3668        scale: Downscale,
3669        backend: BackendRequest,
3670    ) -> Result<Self::SubmittedSurface, Self::Error> {
3671        validate_surface_request(backend)?;
3672        Ok(submit_ready_device(session, |session| {
3673            self.decode_scaled_to_surface_impl(session, fmt, scale, backend)
3674        }))
3675    }
3676
3677    fn submit_region_scaled_to_device(
3678        &mut self,
3679        session: &mut Self::Session,
3680        fmt: PixelFormat,
3681        roi: Rect,
3682        scale: Downscale,
3683        backend: BackendRequest,
3684    ) -> Result<Self::SubmittedSurface, Self::Error> {
3685        validate_surface_request(backend)?;
3686        Ok(submit_ready_device(session, |session| {
3687            self.decode_region_scaled_to_surface_impl(session, fmt, roi, scale, backend)
3688        }))
3689    }
3690}
3691
/// Tests for the dispatch-count helpers and the IDWT host trace formatter; these run in any
/// test build.
#[cfg(test)]
mod dispatch_tests {
    use super::{
        format_cuda_idwt_batch_host_trace_row, htj2k_batched_dequant_dispatches,
        split_htj2k_subband_decode_dispatches, CudaIdwtBatchHostTraceRow,
    };

    /// The split helper maps a dispatch total to the expected pair of counts.
    #[test]
    fn htj2k_decode_dispatch_split_separates_ht_and_dequant_counts() {
        let cases = [(0, (0, 0)), (1, (1, 0)), (2, (1, 1)), (3, (2, 1))];
        for (total, expected_split) in cases {
            assert_eq!(split_htj2k_subband_decode_dispatches(total), expected_split);
        }
    }

    /// Batched dequantization is a single dispatch regardless of batch size, and none when
    /// the batch is empty.
    #[test]
    fn htj2k_batched_dequant_dispatch_count_is_one_for_any_non_empty_batch() {
        for (batch_len, expected) in [(0, 0), (1, 1), (48, 1)] {
            assert_eq!(htj2k_batched_dequant_dispatches(batch_len), expected);
        }
    }

    /// The trace row is rendered as one `key=value` line with a fixed field order.
    #[test]
    fn cuda_idwt_batch_host_trace_row_reports_host_split() {
        let expected = "j2k_profile codec=j2k op=cuda_idwt_batch_host path=decode component_count=327 step_count=5 output_alloc_us=11 target_build_us=22 enqueue_us=33 output_take_count=1635 output_pool_reuse_count=1600 output_pool_alloc_count=35 output_pool_scanned_count=2400 output_pool_max_free_count=1700 output_requested_bytes=28";
        let row = CudaIdwtBatchHostTraceRow {
            component_count: 327,
            step_count: 5,
            output_alloc_us: 11,
            target_build_us: 22,
            enqueue_us: 33,
            output_take_count: 1635,
            output_pool_reuse_count: 1600,
            output_pool_alloc_count: 35,
            output_pool_scanned_count: 2400,
            output_pool_max_free_count: 1700,
            output_requested_bytes: 28,
        };

        let formatted = format_cuda_idwt_batch_host_trace_row(row);

        assert_eq!(formatted, expected);
    }
}