1use core::fmt::Write as _;
4use std::cell::RefCell;
5use std::sync::OnceLock;
6use std::time::Instant;
7
8use j2k_core::BackendKind;
9use j2k_profile::{profile_stage_mode_from_env, ProfileStageMode};
10
11use crate::SurfaceResidency;
12
/// Environment variable that selects the stage-profiling mode. It is parsed once, by
/// `profile_stage_mode`, through `j2k_profile::profile_stage_mode_from_env`.
const PROFILE_ENV_VAR: &str = "J2K_PROFILE_STAGES";
/// Environment variable naming the file that Chrome trace JSON is written to when a report is
/// emitted. When it is unset, no trace is exported.
const CUDA_TRACE_ENV_VAR: &str = "J2K_CUDA_TRACE";

thread_local! {
    /// Per-thread profile summary passed to every `j2k_profile::emit_profile_row` call in this
    /// module. It is built with `emit_on_drop()`, which presumably flushes the summary when the
    /// thread-local is destroyed. NOTE(review): confirm against `j2k_profile`.
    static PROFILE_SUMMARY: RefCell<j2k_profile::ProfileSummary> =
        RefCell::new(j2k_profile::ProfileSummary::default().emit_on_drop());
}
20
/// Finer-grained decode timings and dispatch counts carried in `CudaHtj2kProfileReport::detail`.
///
/// Every field is emitted as its own column by `emit_htj2k_profile_row`. `*_us` fields are in
/// microseconds, matching `elapsed_us`, which reports `Duration::as_micros`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CudaHtj2kDecodeProfileDetail {
    /// Overall elapsed time of the decode, kept separate from `stage_sum_us`.
    /// `finalize_decode_total_us` does not touch it; presumably the caller measures it
    /// end to end.
    pub wall_total_us: u128,
    /// Sum of the per-stage timings. `finalize_decode_total_us` sets it equal to `total_us`.
    pub stage_sum_us: u128,
    /// Table upload time (per the field name). No helper in this module writes it.
    pub table_upload_us: u128,
    /// Payload/resource upload time. `add_payload_resource_upload_us` adds to this field and
    /// to `CudaHtj2kProfileReport::h2d_us` together.
    pub payload_upload_us: u128,
    /// Job-descriptor upload time (per the field name). `add_payload_resource_upload_us` does
    /// not attribute anything here.
    pub job_upload_us: u128,
    /// Device-to-host copy time for status data (per the field name).
    pub status_d2h_us: u128,
    /// Device-to-host copy time for decoded output (per the field name).
    pub output_d2h_us: u128,
    /// Number of HT block-decoding dispatches (per the field name).
    pub ht_dispatch_count: usize,
    /// Number of dequantization dispatches (per the field name).
    pub dequant_dispatch_count: usize,
    /// Number of inverse-DWT dispatches (per the field name).
    pub idwt_dispatch_count: usize,
    /// Number of multi-component-transform dispatches (per the field name).
    pub mct_dispatch_count: usize,
    /// Number of store/output dispatches (per the field name).
    pub store_dispatch_count: usize,
}
53
/// Per-decode stage timings and counters for the CUDA HTJ2K path.
///
/// Emitted as a `cuda_htj2k` profile row and, optionally, as a Chrome trace (see `emit`).
/// `*_us` fields are in microseconds.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CudaHtj2kProfileReport {
    /// Codestream parsing time.
    pub parse_us: u128,
    /// Decode planning time.
    pub plan_us: u128,
    /// Flattening time (per the field name; what is flattened is set by callers).
    pub flatten_us: u128,
    /// Host-to-device transfer time. `add_payload_resource_upload_us` adds to it.
    pub h2d_us: u128,
    /// HT cleanup-pass time.
    pub ht_cleanup_us: u128,
    /// HT refinement-pass time. The trace exporter draws it as overlapping `ht_cleanup`
    /// rather than after it.
    pub ht_refine_us: u128,
    /// Dequantization time.
    pub dequant_us: u128,
    /// Inverse DWT time.
    pub idwt_us: u128,
    /// Multi-component transform time.
    pub mct_us: u128,
    /// Output store time.
    pub store_us: u128,
    /// Total decode time. `finalize_decode_total_us` overwrites it with the saturating sum of
    /// the ten stage fields above.
    pub total_us: u128,
    /// Number of code blocks decoded (per the field name).
    pub block_count: usize,
    /// Payload size in bytes (per the field name).
    pub payload_bytes: usize,
    /// Total number of dispatches (per the field name).
    pub dispatch_count: usize,
    /// Where the decoded surface lives; emitted using its `Debug` representation.
    pub residency: SurfaceResidency,
    /// Finer-grained timings and dispatch counts.
    pub detail: CudaHtj2kDecodeProfileDetail,
}
92
93impl CudaHtj2kProfileReport {
94 pub fn emit(&self, path: &str) {
96 emit_htj2k_profile_row(path, self);
97 export_trace_if_requested(path, self);
98 }
99}
100
/// Per-encode stage timings and counters for the CUDA HTJ2K encoder.
///
/// Emitted as a `cuda_htj2k_encode` profile row and, optionally, as a Chrome trace (see
/// `emit`). `*_us` fields are in microseconds. `Default` is implemented by hand because
/// `backend` defaults to `BackendKind::Cpu`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CudaHtj2kEncodeProfileReport {
    /// Input deinterleave time.
    pub deinterleave_us: u128,
    /// Multi-component transform time.
    pub mct_us: u128,
    /// Forward DWT time.
    pub dwt_us: u128,
    /// Quantization time.
    pub quantize_us: u128,
    /// HT block-encoding time.
    pub ht_encode_us: u128,
    /// Packetization time.
    pub packetize_us: u128,
    /// Total encode time, as supplied by the caller (not computed in this module).
    pub total_us: u128,
    /// Input size in bytes (per the field name).
    pub input_bytes: usize,
    /// Produced codestream size in bytes (per the field name).
    pub codestream_bytes: usize,
    /// Number of code blocks encoded (per the field name).
    pub block_count: usize,
    /// Total number of dispatches (per the field name).
    pub dispatch_count: usize,
    /// Backend that performed the encode; emitted using its `Debug` representation.
    pub backend: BackendKind,
}
129
130impl Default for CudaHtj2kEncodeProfileReport {
131 fn default() -> Self {
132 Self {
133 deinterleave_us: 0,
134 mct_us: 0,
135 dwt_us: 0,
136 quantize_us: 0,
137 ht_encode_us: 0,
138 packetize_us: 0,
139 total_us: 0,
140 input_bytes: 0,
141 codestream_bytes: 0,
142 block_count: 0,
143 dispatch_count: 0,
144 backend: BackendKind::Cpu,
145 }
146 }
147}
148
149impl CudaHtj2kEncodeProfileReport {
150 pub fn emit(&self, path: &str) {
152 emit_htj2k_encode_profile_row(path, self);
153 export_encode_trace_if_requested(path, self);
154 }
155}
156
/// Timestamp type used to time profiling stages; an alias of `std::time::Instant`.
pub(crate) type ProfileInstant = Instant;
158
159fn profile_stage_mode() -> ProfileStageMode {
160 static MODE: OnceLock<ProfileStageMode> = OnceLock::new();
161 *MODE.get_or_init(|| profile_stage_mode_from_env(PROFILE_ENV_VAR))
162}
163
164pub(crate) fn profile_stages_enabled() -> bool {
165 profile_stage_mode() != ProfileStageMode::Disabled
166}
167
168pub(crate) fn profile_now(enabled: bool) -> Option<ProfileInstant> {
169 enabled.then(Instant::now)
170}
171
172pub(crate) fn elapsed_us(start: Option<ProfileInstant>) -> u128 {
173 start.map_or(0, |start| start.elapsed().as_micros())
174}
175
176#[cfg_attr(not(feature = "cuda-runtime"), allow(dead_code))]
177pub(crate) fn add_payload_resource_upload_us(
178 report: &mut CudaHtj2kProfileReport,
179 elapsed_us: u128,
180) {
181 report.h2d_us = report.h2d_us.saturating_add(elapsed_us);
182 report.detail.payload_upload_us = report.detail.payload_upload_us.saturating_add(elapsed_us);
183}
184
185#[cfg_attr(not(feature = "cuda-runtime"), allow(dead_code))]
186pub(crate) fn finalize_decode_total_us(report: &mut CudaHtj2kProfileReport) {
187 report.total_us = [
188 report.parse_us,
189 report.plan_us,
190 report.flatten_us,
191 report.h2d_us,
192 report.ht_cleanup_us,
193 report.ht_refine_us,
194 report.dequant_us,
195 report.idwt_us,
196 report.mct_us,
197 report.store_us,
198 ]
199 .into_iter()
200 .fold(0u128, u128::saturating_add);
201 report.detail.stage_sum_us = report.total_us;
202}
203
204pub(crate) fn emit_htj2k_profile_row(path: &str, report: &CudaHtj2kProfileReport) {
205 let parse_us = report.parse_us.to_string();
206 let plan_us = report.plan_us.to_string();
207 let flatten_us = report.flatten_us.to_string();
208 let h2d_us = report.h2d_us.to_string();
209 let ht_cleanup_us = report.ht_cleanup_us.to_string();
210 let ht_refine_us = report.ht_refine_us.to_string();
211 let dequant_us = report.dequant_us.to_string();
212 let idwt_us = report.idwt_us.to_string();
213 let mct_us = report.mct_us.to_string();
214 let store_us = report.store_us.to_string();
215 let total_us = report.total_us.to_string();
216 let block_count = report.block_count.to_string();
217 let payload_bytes = report.payload_bytes.to_string();
218 let dispatch_count = report.dispatch_count.to_string();
219 let residency = format!("{:?}", report.residency);
220 let wall_total_us = report.detail.wall_total_us.to_string();
221 let stage_sum_us = report.detail.stage_sum_us.to_string();
222 let table_upload_us = report.detail.table_upload_us.to_string();
223 let payload_upload_us = report.detail.payload_upload_us.to_string();
224 let job_upload_us = report.detail.job_upload_us.to_string();
225 let status_d2h_us = report.detail.status_d2h_us.to_string();
226 let output_d2h_us = report.detail.output_d2h_us.to_string();
227 let ht_dispatch_count = report.detail.ht_dispatch_count.to_string();
228 let dequant_dispatch_count = report.detail.dequant_dispatch_count.to_string();
229 let idwt_dispatch_count = report.detail.idwt_dispatch_count.to_string();
230 let mct_dispatch_count = report.detail.mct_dispatch_count.to_string();
231 let store_dispatch_count = report.detail.store_dispatch_count.to_string();
232
233 j2k_profile::emit_profile_row(
234 profile_stage_mode(),
235 &PROFILE_SUMMARY,
236 "j2k",
237 "cuda_htj2k",
238 path,
239 &[
240 ("parse_us", parse_us.as_str()),
241 ("plan_us", plan_us.as_str()),
242 ("flatten_us", flatten_us.as_str()),
243 ("h2d_us", h2d_us.as_str()),
244 ("ht_cleanup_us", ht_cleanup_us.as_str()),
245 ("ht_refine_us", ht_refine_us.as_str()),
246 ("dequant_us", dequant_us.as_str()),
247 ("idwt_us", idwt_us.as_str()),
248 ("mct_us", mct_us.as_str()),
249 ("store_us", store_us.as_str()),
250 ("total_us", total_us.as_str()),
251 ("block_count", block_count.as_str()),
252 ("payload_bytes", payload_bytes.as_str()),
253 ("dispatch_count", dispatch_count.as_str()),
254 ("residency", residency.as_str()),
255 ("wall_total_us", wall_total_us.as_str()),
256 ("stage_sum_us", stage_sum_us.as_str()),
257 ("table_upload_us", table_upload_us.as_str()),
258 ("payload_upload_us", payload_upload_us.as_str()),
259 ("job_upload_us", job_upload_us.as_str()),
260 ("status_d2h_us", status_d2h_us.as_str()),
261 ("output_d2h_us", output_d2h_us.as_str()),
262 ("ht_dispatch_count", ht_dispatch_count.as_str()),
263 ("dequant_dispatch_count", dequant_dispatch_count.as_str()),
264 ("idwt_dispatch_count", idwt_dispatch_count.as_str()),
265 ("mct_dispatch_count", mct_dispatch_count.as_str()),
266 ("store_dispatch_count", store_dispatch_count.as_str()),
267 ],
268 );
269}
270
271pub(crate) fn emit_htj2k_encode_profile_row(path: &str, report: &CudaHtj2kEncodeProfileReport) {
272 let deinterleave_us = report.deinterleave_us.to_string();
273 let mct_us = report.mct_us.to_string();
274 let dwt_us = report.dwt_us.to_string();
275 let quantize_us = report.quantize_us.to_string();
276 let ht_encode_us = report.ht_encode_us.to_string();
277 let packetize_us = report.packetize_us.to_string();
278 let total_us = report.total_us.to_string();
279 let input_bytes = report.input_bytes.to_string();
280 let codestream_bytes = report.codestream_bytes.to_string();
281 let block_count = report.block_count.to_string();
282 let dispatch_count = report.dispatch_count.to_string();
283 let backend = format!("{:?}", report.backend);
284
285 j2k_profile::emit_profile_row(
286 profile_stage_mode(),
287 &PROFILE_SUMMARY,
288 "j2k",
289 "cuda_htj2k_encode",
290 path,
291 &[
292 ("deinterleave_us", deinterleave_us.as_str()),
293 ("mct_us", mct_us.as_str()),
294 ("dwt_us", dwt_us.as_str()),
295 ("quantize_us", quantize_us.as_str()),
296 ("ht_encode_us", ht_encode_us.as_str()),
297 ("packetize_us", packetize_us.as_str()),
298 ("total_us", total_us.as_str()),
299 ("input_bytes", input_bytes.as_str()),
300 ("codestream_bytes", codestream_bytes.as_str()),
301 ("block_count", block_count.as_str()),
302 ("dispatch_count", dispatch_count.as_str()),
303 ("backend", backend.as_str()),
304 ],
305 );
306}
307
308fn export_trace_if_requested(path: &str, report: &CudaHtj2kProfileReport) {
309 let Some(trace_path) = std::env::var_os(CUDA_TRACE_ENV_VAR) else {
310 return;
311 };
312 let trace = chrome_trace_json(path, report);
313 if let Err(error) = std::fs::write(&trace_path, trace) {
314 std::eprintln!("j2k_profile codec=j2k op=cuda_htj2k_trace path=cuda error={error}");
315 }
316}
317
318fn chrome_trace_json(path: &str, report: &CudaHtj2kProfileReport) -> String {
319 let stages = [
320 ("parse", report.parse_us),
321 ("plan", report.plan_us),
322 ("flatten", report.flatten_us),
323 ("h2d", report.h2d_us),
324 ("ht_cleanup", report.ht_cleanup_us),
325 ("ht_refine", report.ht_refine_us),
326 ("dequant", report.dequant_us),
327 ("idwt", report.idwt_us),
328 ("mct", report.mct_us),
329 ("store", report.store_us),
330 ];
331 let mut trace = String::from("{\"traceEvents\":[");
332 let mut ts = 0u128;
333 for (index, (name, dur)) in stages.iter().enumerate() {
334 if index != 0 {
335 trace.push(',');
336 }
337 let event_ts = if *name == "ht_refine" {
338 ts.saturating_sub(report.ht_cleanup_us)
339 } else {
340 ts
341 };
342 write!(
343 trace,
344 "{{\"name\":\"{name}\",\"cat\":\"{path}\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":{event_ts},\"dur\":{dur}}}"
345 )
346 .expect("writing trace JSON to String failed");
347 if *name != "ht_refine" {
348 ts = ts.saturating_add(*dur);
349 }
350 }
351 trace.push_str("]}");
352 trace
353}
354
355fn export_encode_trace_if_requested(path: &str, report: &CudaHtj2kEncodeProfileReport) {
356 let Some(trace_path) = std::env::var_os(CUDA_TRACE_ENV_VAR) else {
357 return;
358 };
359 let trace = chrome_encode_trace_json(path, report);
360 if let Err(error) = std::fs::write(&trace_path, trace) {
361 std::eprintln!("j2k_profile codec=j2k op=cuda_htj2k_encode_trace path=cuda error={error}");
362 }
363}
364
365fn chrome_encode_trace_json(path: &str, report: &CudaHtj2kEncodeProfileReport) -> String {
366 let stages = [
367 ("deinterleave", report.deinterleave_us),
368 ("mct", report.mct_us),
369 ("dwt", report.dwt_us),
370 ("quantize", report.quantize_us),
371 ("ht_encode", report.ht_encode_us),
372 ("packetize", report.packetize_us),
373 ];
374 let mut trace = String::from("{\"traceEvents\":[");
375 let mut ts = 0u128;
376 for (index, (name, dur)) in stages.iter().enumerate() {
377 if index != 0 {
378 trace.push(',');
379 }
380 write!(
381 trace,
382 "{{\"name\":\"{name}\",\"cat\":\"{path}\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":{ts},\"dur\":{dur}}}"
383 )
384 .expect("writing trace JSON to String failed");
385 ts = ts.saturating_add(*dur);
386 }
387 trace.push_str("]}");
388 trace
389}
390
#[cfg(test)]
mod tests {
    use super::{
        add_payload_resource_upload_us, chrome_encode_trace_json, chrome_trace_json,
        finalize_decode_total_us, CudaHtj2kDecodeProfileDetail, CudaHtj2kEncodeProfileReport,
        CudaHtj2kProfileReport,
    };
    use j2k_core::BackendKind;

    use crate::SurfaceResidency;

    /// Builds a resident-decode report from the ten stage timings, given in pipeline order:
    /// parse, plan, flatten, h2d, ht_cleanup, ht_refine, dequant, idwt, mct, store.
    /// Counters are fixed at block_count=1, payload_bytes=2, dispatch_count=3.
    fn decode_report(stages_us: [u128; 10], total_us: u128) -> CudaHtj2kProfileReport {
        let [parse_us, plan_us, flatten_us, h2d_us, ht_cleanup_us, ht_refine_us, dequant_us, idwt_us, mct_us, store_us] =
            stages_us;
        CudaHtj2kProfileReport {
            parse_us,
            plan_us,
            flatten_us,
            h2d_us,
            ht_cleanup_us,
            ht_refine_us,
            dequant_us,
            idwt_us,
            mct_us,
            store_us,
            total_us,
            block_count: 1,
            payload_bytes: 2,
            dispatch_count: 3,
            residency: SurfaceResidency::CudaResidentDecode,
            detail: CudaHtj2kDecodeProfileDetail::default(),
        }
    }

    #[test]
    fn finalize_decode_total_us_includes_cpu_and_cuda_stages() {
        // A stale total (3) must be replaced by the full stage sum.
        let mut report = decode_report([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3);

        finalize_decode_total_us(&mut report);

        assert_eq!(report.total_us, 55);
        assert_eq!(report.detail.stage_sum_us, 55);
    }

    #[test]
    fn detailed_decode_profile_separates_wall_and_stage_sum() {
        let mut report = CudaHtj2kProfileReport {
            block_count: 10,
            payload_bytes: 11,
            dispatch_count: 12,
            ..decode_report([1, 2, 3, 4, 5, 5, 6, 7, 8, 9], 0)
        };
        report.detail.wall_total_us = 100;
        report.detail.table_upload_us = 13;
        report.detail.payload_upload_us = 17;
        report.detail.ht_dispatch_count = 2;

        finalize_decode_total_us(&mut report);

        // Wall time is independent of the stage sum and must survive finalization.
        assert_eq!(report.detail.wall_total_us, 100);
        assert_eq!(report.detail.stage_sum_us, report.total_us);
        assert_eq!(report.detail.ht_dispatch_count, 2);
    }

    #[test]
    fn payload_resource_upload_detail_does_not_claim_job_status_split() {
        let mut report = CudaHtj2kProfileReport::default();

        add_payload_resource_upload_us(&mut report, 23);

        assert_eq!(report.h2d_us, 23);
        assert_eq!(report.detail.payload_upload_us, 23);
        let detail = &report.detail;
        for untouched_us in [detail.job_upload_us, detail.status_d2h_us, detail.output_d2h_us] {
            assert_eq!(untouched_us, 0);
        }
    }

    #[test]
    fn decode_trace_json_contains_ordered_stage_spans() {
        let report = decode_report([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 55);

        let trace = chrome_trace_json("decode", &report);

        assert!(trace.starts_with("{\"traceEvents\":["));
        for expected in [
            "\"name\":\"parse\",\"cat\":\"decode\",\"ph\":\"X\"",
            "\"name\":\"ht_cleanup\",\"cat\":\"decode\",\"ph\":\"X\"",
            "\"name\":\"store\",\"cat\":\"decode\",\"ph\":\"X\"",
            "\"ts\":0,\"dur\":1",
            "\"ts\":39,\"dur\":10",
        ] {
            assert!(trace.contains(expected));
        }
        assert!(trace.ends_with("]}"));
    }

    #[test]
    fn decode_trace_json_does_not_advance_time_for_fused_refinement() {
        let report = decode_report([1, 2, 3, 4, 5, 5, 6, 7, 8, 9], 45);

        let trace = chrome_trace_json("decode", &report);

        for expected in [
            "\"name\":\"ht_refine\",\"cat\":\"decode\",\"ph\":\"X\"",
            "\"name\":\"ht_refine\",\"cat\":\"decode\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":10,\"dur\":5",
            "\"name\":\"dequant\",\"cat\":\"decode\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":15,\"dur\":6",
            "\"name\":\"store\",\"cat\":\"decode\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":36,\"dur\":9",
        ] {
            assert!(trace.contains(expected));
        }
    }

    #[test]
    fn encode_trace_json_contains_ordered_stage_spans() {
        let report = CudaHtj2kEncodeProfileReport {
            deinterleave_us: 11,
            mct_us: 12,
            dwt_us: 13,
            quantize_us: 14,
            ht_encode_us: 15,
            packetize_us: 16,
            total_us: 81,
            input_bytes: 100,
            codestream_bytes: 50,
            block_count: 4,
            dispatch_count: 6,
            backend: BackendKind::Cuda,
        };

        let trace = chrome_encode_trace_json("encode", &report);

        assert!(trace.starts_with("{\"traceEvents\":["));
        for expected in [
            "\"name\":\"deinterleave\",\"cat\":\"encode\",\"ph\":\"X\"",
            "\"name\":\"ht_encode\",\"cat\":\"encode\",\"ph\":\"X\"",
            "\"name\":\"packetize\",\"cat\":\"encode\",\"ph\":\"X\"",
            "\"ts\":0,\"dur\":11",
            "\"ts\":65,\"dur\":16",
        ] {
            assert!(trace.contains(expected));
        }
        assert!(trace.ends_with("]}"));
    }
}