Skip to main content

j2k_cuda/
codec.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use core::convert::Infallible;
4
5use j2k::{
6    adapter::device_plan::{DeviceDecodePlan, DeviceDecodeRequest},
7    J2kCodec as CpuCodec, J2kContext as CpuJ2kContext, J2kDecoder as CpuDecoder,
8    J2kScratchPool as CpuJ2kScratchPool,
9};
10use j2k_core::{
11    submit_ready_device, BackendRequest, Downscale, ImageCodec, PixelFormat, ReadySubmission, Rect,
12    TileBatchDecode, TileBatchDecodeDevice, TileBatchDecodeManyDevice, TileBatchDecodeSubmit,
13};
14
15use crate::runtime::{validate_surface_request, wrap_surface};
16use crate::{CudaSession, Error, J2kDecoder, Surface};
17
/// Zero-sized marker that selects the CUDA-surface J2K codec.
///
/// `Codec` holds no state. It exists so that the tile-batch decode traits
/// implemented in this file have a concrete type to be implemented on.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Codec;
21
22impl ImageCodec for Codec {
23    type Error = Error;
24    type Warning = Infallible;
25    type Pool = CpuJ2kScratchPool;
26}
27
28impl Codec {
29    fn supports_cuda_batch_format(fmt: PixelFormat) -> bool {
30        matches!(
31            fmt,
32            PixelFormat::Rgb8 | PixelFormat::Rgba8 | PixelFormat::Rgb16 | PixelFormat::Rgba16
33        )
34    }
35
36    #[cfg(feature = "cuda-runtime")]
37    fn decode_tiles_to_cuda_batch(
38        inputs: &[&[u8]],
39        fmt: PixelFormat,
40        session: &mut CudaSession,
41    ) -> Result<Vec<Surface>, Error> {
42        J2kDecoder::decode_batch_to_device_with_session(inputs, fmt, session)
43    }
44
45    #[cfg(not(feature = "cuda-runtime"))]
46    fn decode_tiles_to_cuda_batch(
47        _inputs: &[&[u8]],
48        _fmt: PixelFormat,
49        _session: &mut CudaSession,
50    ) -> Result<Vec<Surface>, Error> {
51        Err(Error::CudaUnavailable)
52    }
53
54    fn decode_tile_to_surface_impl(
55        ctx: &mut j2k_core::DecoderContext<CpuJ2kContext>,
56        session: &mut CudaSession,
57        pool: &mut CpuJ2kScratchPool,
58        input: &[u8],
59        fmt: PixelFormat,
60        backend: BackendRequest,
61    ) -> Result<Surface, Error> {
62        validate_surface_request(backend)?;
63        if matches!(backend, BackendRequest::Cuda) {
64            let mut decoder = J2kDecoder::new(input)?;
65            return decoder.decode_to_device_with_session(fmt, session);
66        }
67        let dims = CpuDecoder::inspect(input)?.dimensions;
68        let stride = dims.0 as usize * fmt.bytes_per_pixel();
69        let mut out = vec![0u8; stride * dims.1 as usize];
70        CpuCodec::decode_tile(ctx, pool, input, &mut out, stride, fmt)?;
71        wrap_surface(out, dims, fmt, backend, session)
72    }
73
74    fn decode_tile_region_to_surface_impl(
75        ctx: &mut j2k_core::DecoderContext<CpuJ2kContext>,
76        session: &mut CudaSession,
77        pool: &mut CpuJ2kScratchPool,
78        input: &[u8],
79        fmt: PixelFormat,
80        roi: Rect,
81        backend: BackendRequest,
82    ) -> Result<Surface, Error> {
83        validate_surface_request(backend)?;
84        if matches!(backend, BackendRequest::Cuda) {
85            let mut decoder = J2kDecoder::new(input)?;
86            return decoder.decode_region_to_device_with_session(fmt, roi, session);
87        }
88        let dims = DeviceDecodePlan::for_image(
89            CpuDecoder::inspect(input)?.dimensions,
90            DeviceDecodeRequest::Region { roi },
91        )?
92        .output_dims();
93        let stride = dims.0 as usize * fmt.bytes_per_pixel();
94        let mut out = vec![0u8; stride * dims.1 as usize];
95        CpuCodec::decode_tile_region(ctx, pool, input, &mut out, stride, fmt, roi)?;
96        wrap_surface(out, dims, fmt, backend, session)
97    }
98
99    fn decode_tile_scaled_to_surface_impl(
100        ctx: &mut j2k_core::DecoderContext<CpuJ2kContext>,
101        session: &mut CudaSession,
102        pool: &mut CpuJ2kScratchPool,
103        input: &[u8],
104        fmt: PixelFormat,
105        scale: Downscale,
106        backend: BackendRequest,
107    ) -> Result<Surface, Error> {
108        validate_surface_request(backend)?;
109        if matches!(backend, BackendRequest::Cuda) {
110            let mut decoder = J2kDecoder::new(input)?;
111            return decoder.decode_scaled_to_device_with_session(fmt, scale, session);
112        }
113        let dims = DeviceDecodePlan::for_image(
114            CpuDecoder::inspect(input)?.dimensions,
115            DeviceDecodeRequest::Scaled { scale },
116        )?
117        .output_dims();
118        let stride = dims.0 as usize * fmt.bytes_per_pixel();
119        let mut out = vec![0u8; stride * dims.1 as usize];
120        CpuCodec::decode_tile_scaled(ctx, pool, input, &mut out, stride, fmt, scale)?;
121        wrap_surface(out, dims, fmt, backend, session)
122    }
123
124    #[allow(clippy::too_many_arguments)]
125    fn decode_tile_region_scaled_to_surface_impl(
126        ctx: &mut j2k_core::DecoderContext<CpuJ2kContext>,
127        session: &mut CudaSession,
128        pool: &mut CpuJ2kScratchPool,
129        input: &[u8],
130        fmt: PixelFormat,
131        roi: Rect,
132        scale: Downscale,
133        backend: BackendRequest,
134    ) -> Result<Surface, Error> {
135        validate_surface_request(backend)?;
136        if matches!(backend, BackendRequest::Cuda) {
137            let mut decoder = J2kDecoder::new(input)?;
138            return decoder.decode_region_scaled_to_device_with_session(fmt, roi, scale, session);
139        }
140        let dims = DeviceDecodePlan::for_image(
141            CpuDecoder::inspect(input)?.dimensions,
142            DeviceDecodeRequest::RegionScaled { roi, scale },
143        )?
144        .output_dims();
145        let stride = dims.0 as usize * fmt.bytes_per_pixel();
146        let mut out = vec![0u8; stride * dims.1 as usize];
147        CpuCodec::decode_tile_region_scaled(ctx, pool, input, &mut out, stride, fmt, roi, scale)?;
148        wrap_surface(out, dims, fmt, backend, session)
149    }
150}
151
152impl TileBatchDecodeSubmit for Codec {
153    type Context = CpuJ2kContext;
154    type Session = CudaSession;
155    type DeviceSurface = Surface;
156    type SubmittedSurface = ReadySubmission<Surface, Error>;
157
158    fn submit_tile_to_device(
159        ctx: &mut j2k_core::DecoderContext<Self::Context>,
160        session: &mut Self::Session,
161        pool: &mut Self::Pool,
162        input: &[u8],
163        fmt: PixelFormat,
164        backend: BackendRequest,
165    ) -> Result<Self::SubmittedSurface, Self::Error> {
166        validate_surface_request(backend)?;
167        Ok(submit_ready_device(session, |session| {
168            Self::decode_tile_to_surface_impl(ctx, session, pool, input, fmt, backend)
169        }))
170    }
171
172    fn submit_tile_region_to_device(
173        ctx: &mut j2k_core::DecoderContext<Self::Context>,
174        session: &mut Self::Session,
175        pool: &mut Self::Pool,
176        input: &[u8],
177        fmt: PixelFormat,
178        roi: Rect,
179        backend: BackendRequest,
180    ) -> Result<Self::SubmittedSurface, Self::Error> {
181        validate_surface_request(backend)?;
182        Ok(submit_ready_device(session, |session| {
183            Self::decode_tile_region_to_surface_impl(ctx, session, pool, input, fmt, roi, backend)
184        }))
185    }
186
187    fn submit_tile_scaled_to_device(
188        ctx: &mut j2k_core::DecoderContext<Self::Context>,
189        session: &mut Self::Session,
190        pool: &mut Self::Pool,
191        input: &[u8],
192        fmt: PixelFormat,
193        scale: Downscale,
194        backend: BackendRequest,
195    ) -> Result<Self::SubmittedSurface, Self::Error> {
196        validate_surface_request(backend)?;
197        Ok(submit_ready_device(session, |session| {
198            Self::decode_tile_scaled_to_surface_impl(ctx, session, pool, input, fmt, scale, backend)
199        }))
200    }
201
202    fn submit_tile_region_scaled_to_device(
203        ctx: &mut j2k_core::DecoderContext<Self::Context>,
204        session: &mut Self::Session,
205        pool: &mut Self::Pool,
206        input: &[u8],
207        fmt: PixelFormat,
208        roi: Rect,
209        scale: Downscale,
210        backend: BackendRequest,
211    ) -> Result<Self::SubmittedSurface, Self::Error> {
212        validate_surface_request(backend)?;
213        Ok(submit_ready_device(session, |session| {
214            Self::decode_tile_region_scaled_to_surface_impl(
215                ctx, session, pool, input, fmt, roi, scale, backend,
216            )
217        }))
218    }
219}
220
221impl TileBatchDecodeDevice for Codec {
222    type Context = CpuJ2kContext;
223    type DeviceSurface = Surface;
224}
225
226impl TileBatchDecodeManyDevice for Codec {
227    type Context = CpuJ2kContext;
228    type DeviceSurface = Surface;
229
230    fn decode_tiles_to_device(
231        ctx: &mut j2k_core::DecoderContext<Self::Context>,
232        pool: &mut Self::Pool,
233        inputs: &[&[u8]],
234        fmt: PixelFormat,
235        backend: BackendRequest,
236    ) -> Result<Vec<Self::DeviceSurface>, Self::Error> {
237        validate_surface_request(backend)?;
238        if inputs.is_empty() {
239            return Ok(Vec::new());
240        }
241
242        let mut session = CudaSession::default();
243        if matches!(backend, BackendRequest::Cuda) && Self::supports_cuda_batch_format(fmt) {
244            return Self::decode_tiles_to_cuda_batch(inputs, fmt, &mut session);
245        }
246
247        inputs
248            .iter()
249            .map(|input| {
250                Self::decode_tile_to_surface_impl(ctx, &mut session, pool, input, fmt, backend)
251            })
252            .collect()
253    }
254}
255
#[cfg(all(test, feature = "cuda-runtime"))]
mod tests {
    use j2k_core::{BackendRequest, DecoderContext, PixelFormat, TileBatchDecodeManyDevice};
    use j2k_test_support::{cuda_runtime_required, htj2k_rgb8_pattern_fixture};

    use super::{Codec, CpuJ2kContext, CpuJ2kScratchPool};
    use crate::decoder::{
        testing_cuda_htj2k_batch_decode_calls, testing_reset_cuda_htj2k_batch_decode_calls,
    };
    use crate::{Error, SurfaceResidency};

    /// Builds the encoded RGB8 HTJ2K fixture used below; the third argument to
    /// the fixture generator is fixed at `17`.
    fn rgb8_htj2k_fixture(width: u32, height: u32) -> Vec<u8> {
        htj2k_rgb8_pattern_fixture(width, height, 17)
    }

    /// An explicit CUDA request for two RGB8 tiles must hit the batch decode
    /// entry point exactly once and yield device-resident surfaces.
    #[test]
    fn explicit_cuda_rgb_many_decode_uses_batch_api_once() {
        testing_reset_cuda_htj2k_batch_decode_calls();

        let encoded = rgb8_htj2k_fixture(32, 32);
        let tiles = [&encoded[..], &encoded[..]];
        let mut ctx = DecoderContext::<CpuJ2kContext>::new();
        let mut pool = CpuJ2kScratchPool::new();

        let outcome = Codec::decode_tiles_to_device(
            &mut ctx,
            &mut pool,
            &tiles,
            PixelFormat::Rgb8,
            BackendRequest::Cuda,
        );

        // The counter is checked before the outcome is inspected.
        assert_eq!(testing_cuda_htj2k_batch_decode_calls(), 1);

        let surfaces = match outcome {
            Ok(surfaces) => surfaces,
            Err(Error::CudaUnavailable) => {
                // Only tolerated when the environment does not demand CUDA.
                assert!(!cuda_runtime_required());
                return;
            }
            Err(error) => panic!("unexpected strict CUDA RGB batch error: {error}"),
        };

        assert_eq!(surfaces.len(), tiles.len());
        for surface in surfaces {
            assert_eq!(surface.residency(), SurfaceResidency::CudaResidentDecode);
            assert_eq!(surface.as_host_bytes(), None);
        }
    }
}