1#[cfg(feature = "cuda-runtime")]
4use j2k_cuda_runtime::{
5 CudaBufferPool, CudaContext, CudaHtj2kDecodeTableResources, CudaHtj2kDecodeTables,
6};
7#[cfg(feature = "cuda-runtime")]
8use j2k_native::{ht_uvlc_table0, ht_uvlc_table1, ht_vlc_table0, ht_vlc_table1};
9#[cfg(all(test, feature = "cuda-runtime"))]
10use std::sync::atomic::{AtomicUsize, Ordering};
11
12#[cfg(feature = "cuda-runtime")]
13use crate::runtime::cuda_error;
14#[cfg(feature = "cuda-runtime")]
15use crate::Error;
16
/// Test-only counter of HTJ2K decode table uploads performed by any
/// `CudaSession` in this process.
///
/// It is incremented by `CudaSession::htj2k_decode_table_resources` after a
/// successful upload and is read/reset through the `*_for_test` helpers in
/// this module.
#[cfg(all(test, feature = "cuda-runtime"))]
static HTJ2K_DECODE_TABLE_UPLOADS: AtomicUsize = AtomicUsize::new(0);
19
/// Per-session state for the CUDA backend.
///
/// A default-constructed session has a submission count of zero and, when the
/// `cuda-runtime` feature is enabled, no cached runtime objects; those are
/// created on demand by the crate-internal accessor methods and then reused.
#[derive(Clone, Default)]
pub struct CudaSession {
    /// Number of submissions recorded through `record_submit`.
    submissions: u64,
    /// CUDA context cached after the first successful `cuda_context` call.
    #[cfg(feature = "cuda-runtime")]
    context: Option<CudaContext>,
    /// HTJ2K decode table resources cached after their first successful upload.
    #[cfg(feature = "cuda-runtime")]
    htj2k_decode_tables: Option<CudaHtj2kDecodeTableResources>,
    /// Buffer pool handed out by `decode_buffer_pool`, cached after first use.
    #[cfg(feature = "cuda-runtime")]
    decode_buffer_pool: Option<CudaBufferPool>,
    /// Buffer pool handed out by `decode_batch_buffer_pool`, cached after first use.
    #[cfg(feature = "cuda-runtime")]
    decode_batch_buffer_pool: Option<CudaBufferPool>,
}
33
34impl CudaSession {
35 pub fn submissions(&self) -> u64 {
37 self.submissions
38 }
39
40 #[cfg(feature = "cuda-runtime")]
41 pub fn is_runtime_initialized(&self) -> bool {
43 self.context.is_some()
44 }
45
46 #[cfg(feature = "cuda-runtime")]
47 pub(crate) fn cuda_context(&mut self) -> Result<CudaContext, Error> {
48 if self.context.is_none() {
49 self.context = Some(CudaContext::system_default().map_err(cuda_error)?);
50 }
51 self.context.clone().ok_or(Error::CudaUnavailable)
52 }
53
54 #[cfg(feature = "cuda-runtime")]
55 pub(crate) fn htj2k_decode_table_resources(
56 &mut self,
57 ) -> Result<CudaHtj2kDecodeTableResources, Error> {
58 if let Some(tables) = &self.htj2k_decode_tables {
59 return Ok(tables.clone());
60 }
61
62 let context = self.cuda_context()?;
63 let tables = CudaHtj2kDecodeTables {
64 vlc_table0: ht_vlc_table0(),
65 vlc_table1: ht_vlc_table1(),
66 uvlc_table0: ht_uvlc_table0(),
67 uvlc_table1: ht_uvlc_table1(),
68 };
69 let resources = context
70 .upload_htj2k_decode_table_resources(tables)
71 .map_err(cuda_error)?;
72 #[cfg(test)]
73 HTJ2K_DECODE_TABLE_UPLOADS.fetch_add(1, Ordering::Relaxed);
74 self.htj2k_decode_tables = Some(resources.clone());
75 Ok(resources)
76 }
77
78 #[cfg(feature = "cuda-runtime")]
79 pub(crate) fn decode_buffer_pool(&mut self) -> Result<CudaBufferPool, Error> {
80 if let Some(pool) = &self.decode_buffer_pool {
81 return Ok(pool.clone());
82 }
83 let context = self.cuda_context()?;
84 let pool = context.buffer_pool();
85 self.decode_buffer_pool = Some(pool.clone());
86 Ok(pool)
87 }
88
89 #[cfg(feature = "cuda-runtime")]
90 pub(crate) fn decode_batch_buffer_pool(&mut self) -> Result<CudaBufferPool, Error> {
91 if let Some(pool) = &self.decode_batch_buffer_pool {
92 return Ok(pool.clone());
93 }
94 let context = self.cuda_context()?;
95 let pool = context.best_fit_buffer_pool();
96 self.decode_batch_buffer_pool = Some(pool.clone());
97 Ok(pool)
98 }
99}
100
101impl j2k_core::DeviceSubmitSession for CudaSession {
102 fn record_submit(&mut self) {
103 self.submissions = self.submissions.saturating_add(1);
104 }
105}
106
107impl j2k_core::AcceleratorSession for CudaSession {
108 fn backend_kind(&self) -> j2k_core::BackendKind {
109 j2k_core::BackendKind::Cuda
110 }
111
112 fn execution_stats(&self) -> j2k_core::ExecutionStats {
113 j2k_core::ExecutionStats {
114 submissions: self.submissions,
115 ..j2k_core::ExecutionStats::default()
116 }
117 }
118}
119
/// Resets the test-only HTJ2K decode table upload counter to zero.
#[cfg(all(test, feature = "cuda-runtime"))]
pub(crate) fn reset_htj2k_decode_table_uploads_for_test() {
    HTJ2K_DECODE_TABLE_UPLOADS.store(0, Ordering::Relaxed);
}
124
/// Returns the current value of the test-only HTJ2K decode table upload counter.
#[cfg(all(test, feature = "cuda-runtime"))]
pub(crate) fn htj2k_decode_table_uploads_for_test() -> usize {
    HTJ2K_DECODE_TABLE_UPLOADS.load(Ordering::Relaxed)
}
129
130impl std::fmt::Debug for CudaSession {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 let mut debug = f.debug_struct("CudaSession");
133 debug.field("submissions", &self.submissions);
134 #[cfg(feature = "cuda-runtime")]
135 debug.field("runtime_initialized", &self.is_runtime_initialized());
136 #[cfg(feature = "cuda-runtime")]
137 debug.field(
138 "htj2k_decode_tables_cached",
139 &self.htj2k_decode_tables.is_some(),
140 );
141 #[cfg(feature = "cuda-runtime")]
142 debug.field(
143 "decode_buffer_pool_cached",
144 &self.decode_buffer_pool.is_some(),
145 );
146 #[cfg(feature = "cuda-runtime")]
147 debug.field(
148 "decode_batch_buffer_pool_cached",
149 &self.decode_batch_buffer_pool.is_some(),
150 );
151 debug.finish_non_exhaustive()
152 }
153}
154
#[cfg(all(test, feature = "cuda-runtime"))]
mod tests {
    use super::{
        htj2k_decode_table_uploads_for_test, reset_htj2k_decode_table_uploads_for_test,
        CudaSession,
    };
    use crate::Error;

    /// Returns `true` when `J2K_REQUIRE_CUDA_RUNTIME` is set in the environment,
    /// in which case CUDA failures must fail the test instead of skipping it.
    fn cuda_required() -> bool {
        std::env::var_os("J2K_REQUIRE_CUDA_RUNTIME").is_some()
    }

    /// Returns `true` when `outcome` is a CUDA availability/runtime error and
    /// the environment does not require a working CUDA runtime.
    fn should_skip<T>(outcome: &Result<T, Error>) -> bool {
        let cuda_failed = matches!(
            outcome,
            Err(Error::CudaUnavailable | Error::CudaRuntime { .. })
        );
        cuda_failed && !cuda_required()
    }

    /// Two table requests on one session must trigger exactly one upload.
    ///
    /// NOTE(review): the upload counter is a process-wide static, so another
    /// test uploading tables concurrently could presumably disturb the exact
    /// count asserted here; confirm how the suite is run.
    #[test]
    fn htj2k_decode_tables_are_uploaded_once_per_session() {
        reset_htj2k_decode_table_uploads_for_test();
        let mut session = CudaSession::default();

        let first = session.htj2k_decode_table_resources();
        if should_skip(&first) {
            return;
        }
        first.expect("first HTJ2K decode table upload");

        let cached = session.htj2k_decode_table_resources();
        cached.expect("cached HTJ2K decode tables");

        assert_eq!(htj2k_decode_table_uploads_for_test(), 1);
    }

    /// Both pool handles returned by one session must refer to the same pool:
    /// a buffer taken and released through the first handle shows up in the
    /// cached count observed through the second.
    #[test]
    fn cuda_session_reuses_one_decode_buffer_pool_when_required() {
        let mut session = CudaSession::default();

        let first = session.decode_buffer_pool();
        if should_skip(&first) {
            return;
        }
        let first = first.expect("first decode buffer pool");
        let second = session
            .decode_buffer_pool()
            .expect("cached decode buffer pool");

        let buffer = first.take(16).expect("pooled decode buffer");
        assert_eq!(buffer.byte_len(), 16);
        // Release the buffer before inspecting the pool through the other handle.
        drop(buffer);

        let cached = second.cached_count().expect("shared pool cached count");
        assert!(cached >= 1);
    }
}