Skip to main content

j2k_cuda/
session.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#[cfg(feature = "cuda-runtime")]
4use j2k_cuda_runtime::{
5    CudaBufferPool, CudaContext, CudaHtj2kDecodeTableResources, CudaHtj2kDecodeTables,
6};
7#[cfg(feature = "cuda-runtime")]
8use j2k_native::{ht_uvlc_table0, ht_uvlc_table1, ht_vlc_table0, ht_vlc_table1};
9#[cfg(all(test, feature = "cuda-runtime"))]
10use std::sync::atomic::{AtomicUsize, Ordering};
11
12#[cfg(feature = "cuda-runtime")]
13use crate::runtime::cuda_error;
14#[cfg(feature = "cuda-runtime")]
15use crate::Error;
16
/// Test-only tally of HTJ2K decode-table uploads.
///
/// Bumped once after every successful upload performed by
/// `CudaSession::htj2k_decode_table_resources`. Being a `static`, it is shared by every
/// test in the test binary rather than being tracked per session.
#[cfg(all(test, feature = "cuda-runtime"))]
static HTJ2K_DECODE_TABLE_UPLOADS: AtomicUsize = AtomicUsize::new(0);
19
/// CUDA adapter session that is kept alive and reused across submissions.
///
/// It counts recorded submissions. With the `cuda-runtime` feature it also lazily
/// caches the CUDA context and the device resources derived from it, so that later
/// calls reuse them instead of recreating them. The derived `Clone` copies the
/// submission count and clones each cached handle.
#[derive(Clone, Default)]
pub struct CudaSession {
    /// Number of submissions recorded via `record_submit` (saturating at `u64::MAX`).
    submissions: u64,
    /// CUDA context; `None` until `cuda_context` first succeeds.
    #[cfg(feature = "cuda-runtime")]
    context: Option<CudaContext>,
    /// HTJ2K decode-table resources; `None` until the first successful upload.
    #[cfg(feature = "cuda-runtime")]
    htj2k_decode_tables: Option<CudaHtj2kDecodeTableResources>,
    /// Pool obtained from `CudaContext::buffer_pool`; `None` until first requested.
    #[cfg(feature = "cuda-runtime")]
    decode_buffer_pool: Option<CudaBufferPool>,
    /// Pool obtained from `CudaContext::best_fit_buffer_pool`; `None` until first requested.
    #[cfg(feature = "cuda-runtime")]
    decode_batch_buffer_pool: Option<CudaBufferPool>,
}
33
34impl CudaSession {
35    /// Number of submissions recorded by this session.
36    pub fn submissions(&self) -> u64 {
37        self.submissions
38    }
39
40    #[cfg(feature = "cuda-runtime")]
41    /// True when a CUDA runtime context has been initialized.
42    pub fn is_runtime_initialized(&self) -> bool {
43        self.context.is_some()
44    }
45
46    #[cfg(feature = "cuda-runtime")]
47    pub(crate) fn cuda_context(&mut self) -> Result<CudaContext, Error> {
48        if self.context.is_none() {
49            self.context = Some(CudaContext::system_default().map_err(cuda_error)?);
50        }
51        self.context.clone().ok_or(Error::CudaUnavailable)
52    }
53
54    #[cfg(feature = "cuda-runtime")]
55    pub(crate) fn htj2k_decode_table_resources(
56        &mut self,
57    ) -> Result<CudaHtj2kDecodeTableResources, Error> {
58        if let Some(tables) = &self.htj2k_decode_tables {
59            return Ok(tables.clone());
60        }
61
62        let context = self.cuda_context()?;
63        let tables = CudaHtj2kDecodeTables {
64            vlc_table0: ht_vlc_table0(),
65            vlc_table1: ht_vlc_table1(),
66            uvlc_table0: ht_uvlc_table0(),
67            uvlc_table1: ht_uvlc_table1(),
68        };
69        let resources = context
70            .upload_htj2k_decode_table_resources(tables)
71            .map_err(cuda_error)?;
72        #[cfg(test)]
73        HTJ2K_DECODE_TABLE_UPLOADS.fetch_add(1, Ordering::Relaxed);
74        self.htj2k_decode_tables = Some(resources.clone());
75        Ok(resources)
76    }
77
78    #[cfg(feature = "cuda-runtime")]
79    pub(crate) fn decode_buffer_pool(&mut self) -> Result<CudaBufferPool, Error> {
80        if let Some(pool) = &self.decode_buffer_pool {
81            return Ok(pool.clone());
82        }
83        let context = self.cuda_context()?;
84        let pool = context.buffer_pool();
85        self.decode_buffer_pool = Some(pool.clone());
86        Ok(pool)
87    }
88
89    #[cfg(feature = "cuda-runtime")]
90    pub(crate) fn decode_batch_buffer_pool(&mut self) -> Result<CudaBufferPool, Error> {
91        if let Some(pool) = &self.decode_batch_buffer_pool {
92            return Ok(pool.clone());
93        }
94        let context = self.cuda_context()?;
95        let pool = context.best_fit_buffer_pool();
96        self.decode_batch_buffer_pool = Some(pool.clone());
97        Ok(pool)
98    }
99}
100
101impl j2k_core::DeviceSubmitSession for CudaSession {
102    fn record_submit(&mut self) {
103        self.submissions = self.submissions.saturating_add(1);
104    }
105}
106
107impl j2k_core::AcceleratorSession for CudaSession {
108    fn backend_kind(&self) -> j2k_core::BackendKind {
109        j2k_core::BackendKind::Cuda
110    }
111
112    fn execution_stats(&self) -> j2k_core::ExecutionStats {
113        j2k_core::ExecutionStats {
114            submissions: self.submissions,
115            ..j2k_core::ExecutionStats::default()
116        }
117    }
118}
119
/// Resets the test-only HTJ2K decode-table upload counter to zero.
#[cfg(all(test, feature = "cuda-runtime"))]
pub(crate) fn reset_htj2k_decode_table_uploads_for_test() {
    let counter = &HTJ2K_DECODE_TABLE_UPLOADS;
    counter.store(0, Ordering::Relaxed);
}

/// Returns the current value of the test-only HTJ2K decode-table upload counter.
///
/// NOTE(review): the counter is a single `static`, so any other test that uploads tables
/// while tests run in parallel can change the value observed here.
#[cfg(all(test, feature = "cuda-runtime"))]
pub(crate) fn htj2k_decode_table_uploads_for_test() -> usize {
    let counter = &HTJ2K_DECODE_TABLE_UPLOADS;
    counter.load(Ordering::Relaxed)
}
129
130impl std::fmt::Debug for CudaSession {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        let mut debug = f.debug_struct("CudaSession");
133        debug.field("submissions", &self.submissions);
134        #[cfg(feature = "cuda-runtime")]
135        debug.field("runtime_initialized", &self.is_runtime_initialized());
136        #[cfg(feature = "cuda-runtime")]
137        debug.field(
138            "htj2k_decode_tables_cached",
139            &self.htj2k_decode_tables.is_some(),
140        );
141        #[cfg(feature = "cuda-runtime")]
142        debug.field(
143            "decode_buffer_pool_cached",
144            &self.decode_buffer_pool.is_some(),
145        );
146        #[cfg(feature = "cuda-runtime")]
147        debug.field(
148            "decode_batch_buffer_pool_cached",
149            &self.decode_batch_buffer_pool.is_some(),
150        );
151        debug.finish_non_exhaustive()
152    }
153}
154
#[cfg(all(test, feature = "cuda-runtime"))]
mod tests {
    use super::{
        htj2k_decode_table_uploads_for_test, reset_htj2k_decode_table_uploads_for_test,
        CudaSession,
    };
    use crate::Error;

    /// True when `J2K_REQUIRE_CUDA_RUNTIME` is set, meaning a missing CUDA runtime must
    /// fail the test instead of skipping it.
    fn cuda_required() -> bool {
        std::env::var_os("J2K_REQUIRE_CUDA_RUNTIME").is_some()
    }

    /// Decides whether a test should return early: the first CUDA call reported an
    /// unavailable/failed runtime and the environment does not demand one.
    fn should_skip<T>(outcome: &Result<T, Error>) -> bool {
        let runtime_missing = matches!(
            outcome,
            Err(Error::CudaUnavailable | Error::CudaRuntime { .. })
        );
        runtime_missing && !cuda_required()
    }

    #[test]
    fn htj2k_decode_tables_are_uploaded_once_per_session() {
        reset_htj2k_decode_table_uploads_for_test();
        let mut session = CudaSession::default();

        let initial = session.htj2k_decode_table_resources();
        if should_skip(&initial) {
            return;
        }
        initial.expect("first HTJ2K decode table upload");
        session
            .htj2k_decode_table_resources()
            .expect("cached HTJ2K decode tables");

        assert_eq!(htj2k_decode_table_uploads_for_test(), 1);
    }

    #[test]
    fn cuda_session_reuses_one_decode_buffer_pool_when_required() {
        let mut session = CudaSession::default();

        let initial = session.decode_buffer_pool();
        if should_skip(&initial) {
            return;
        }
        let first = initial.expect("first decode buffer pool");
        let second = session
            .decode_buffer_pool()
            .expect("cached decode buffer pool");

        // The buffer taken through `first` is dropped at the end of this scope. It is
        // presumably returned to the pool, so `second` reporting a cached entry indicates
        // both handles share one pool.
        {
            let buffer = first.take(16).expect("pooled decode buffer");
            assert_eq!(buffer.byte_len(), 16);
        }

        assert!(second.cached_count().expect("shared pool cached count") >= 1);
    }
}
208}