Skip to main content

j2k_cuda/
profile.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use core::fmt::Write as _;
4use std::cell::RefCell;
5use std::sync::OnceLock;
6use std::time::Instant;
7
8use j2k_core::BackendKind;
9use j2k_profile::{profile_stage_mode_from_env, ProfileStageMode};
10
11use crate::SurfaceResidency;
12
/// Name of the environment variable passed to `profile_stage_mode_from_env` to pick the
/// stage-profiling mode.
const PROFILE_ENV_VAR: &str = "J2K_PROFILE_STAGES";
/// Name of the environment variable whose value is used as the output file path for
/// Chrome-trace JSON exports; trace export is skipped when it is unset.
const CUDA_TRACE_ENV_VAR: &str = "J2K_CUDA_TRACE";
15
16thread_local! {
17    static PROFILE_SUMMARY: RefCell<j2k_profile::ProfileSummary> =
18        RefCell::new(j2k_profile::ProfileSummary::default().emit_on_drop());
19}
20
/// Fine-grained route-overhead measurements for a strict CUDA HTJ2K decode.
///
/// All `*_us` fields are durations in microseconds; all `*_dispatch_count` fields count
/// CUDA dispatches for one pipeline stage.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CudaHtj2kDecodeProfileDetail {
    /// Wall-clock time of the whole profiled decode, start to finish.
    pub wall_total_us: u128,
    /// Total of the individually reported decode stage timings.
    pub stage_sum_us: u128,
    /// Time spent uploading CUDA tables/resources.
    pub table_upload_us: u128,
    /// Time spent uploading the compressed payload/resources to CUDA.
    ///
    /// Mixed upload calls that carry compressed payload bytes together with
    /// decode metadata are counted here. A metadata-only job upload is not
    /// reported separately until the CUDA runtime exposes split timings.
    pub payload_upload_us: u128,
    /// Time spent uploading the CUDA decode job; kept at zero until split runtime
    /// timings exist.
    pub job_upload_us: u128,
    /// Time spent downloading CUDA status; kept at zero until split runtime timings
    /// exist.
    pub status_d2h_us: u128,
    /// Time spent downloading CUDA output; kept at zero until split runtime timings
    /// exist.
    pub output_d2h_us: u128,
    /// How many CUDA dispatches the HT cleanup/refinement stage issued.
    pub ht_dispatch_count: usize,
    /// How many CUDA dispatches the dequantization stage issued.
    pub dequant_dispatch_count: usize,
    /// How many CUDA dispatches the inverse DWT stage issued.
    pub idwt_dispatch_count: usize,
    /// How many CUDA dispatches the inverse MCT stage issued.
    pub mct_dispatch_count: usize,
    /// How many CUDA dispatches the store/format-conversion stage issued.
    pub store_dispatch_count: usize,
}
53
54/// Structured stage timings for a strict CUDA HTJ2K operation.
55#[derive(Clone, Debug, Default, PartialEq, Eq)]
56pub struct CudaHtj2kProfileReport {
57    /// CPU marker/box parse time.
58    pub parse_us: u128,
59    /// Native direct-plan construction time.
60    pub plan_us: u128,
61    /// Flat CUDA plan construction time.
62    pub flatten_us: u128,
63    /// Host-to-device upload time for payload and metadata.
64    pub h2d_us: u128,
65    /// HT cleanup kernel time.
66    pub ht_cleanup_us: u128,
67    /// HT refinement kernel time.
68    pub ht_refine_us: u128,
69    /// Dequantization kernel time.
70    pub dequant_us: u128,
71    /// Inverse DWT kernel time.
72    pub idwt_us: u128,
73    /// Inverse MCT kernel time.
74    pub mct_us: u128,
75    /// Store/format conversion kernel time.
76    pub store_us: u128,
77    /// Sum of measured decode stages.
78    ///
79    /// End-to-end wall time is reported in `detail.wall_total_us`.
80    pub total_us: u128,
81    /// Number of HTJ2K code blocks in the flat plan.
82    pub block_count: usize,
83    /// Number of compressed payload bytes uploaded to CUDA.
84    pub payload_bytes: usize,
85    /// Number of CUDA kernel dispatches.
86    pub dispatch_count: usize,
87    /// Surface residency represented by this profile.
88    pub residency: SurfaceResidency,
89    /// Detailed route-overhead profile for RCA.
90    pub detail: CudaHtj2kDecodeProfileDetail,
91}
92
93impl CudaHtj2kProfileReport {
94    /// Emit the report using `J2K_PROFILE_STAGES`, when enabled.
95    pub fn emit(&self, path: &str) {
96        emit_htj2k_profile_row(path, self);
97        export_trace_if_requested(path, self);
98    }
99}
100
101/// Structured stage timings for a strict CUDA HTJ2K encode operation.
102#[derive(Clone, Debug, PartialEq, Eq)]
103pub struct CudaHtj2kEncodeProfileReport {
104    /// Pixel deinterleave and level-shift CUDA stage time.
105    pub deinterleave_us: u128,
106    /// Forward MCT CUDA stage time.
107    pub mct_us: u128,
108    /// Forward DWT CUDA stage time.
109    pub dwt_us: u128,
110    /// Quantization CUDA stage time.
111    pub quantize_us: u128,
112    /// HTJ2K cleanup code-block encode CUDA stage time.
113    pub ht_encode_us: u128,
114    /// HTJ2K packetization CUDA stage time.
115    pub packetize_us: u128,
116    /// Total wall time for the measured encode call.
117    pub total_us: u128,
118    /// Input pixel byte count.
119    pub input_bytes: usize,
120    /// Output codestream byte count.
121    pub codestream_bytes: usize,
122    /// Number of HTJ2K code blocks encoded.
123    pub block_count: usize,
124    /// Number of CUDA kernel dispatches.
125    pub dispatch_count: usize,
126    /// Backend that satisfied the encode request.
127    pub backend: BackendKind,
128}
129
130impl Default for CudaHtj2kEncodeProfileReport {
131    fn default() -> Self {
132        Self {
133            deinterleave_us: 0,
134            mct_us: 0,
135            dwt_us: 0,
136            quantize_us: 0,
137            ht_encode_us: 0,
138            packetize_us: 0,
139            total_us: 0,
140            input_bytes: 0,
141            codestream_bytes: 0,
142            block_count: 0,
143            dispatch_count: 0,
144            backend: BackendKind::Cpu,
145        }
146    }
147}
148
149impl CudaHtj2kEncodeProfileReport {
150    /// Emit the report using `J2K_PROFILE_STAGES`, when enabled.
151    pub fn emit(&self, path: &str) {
152        emit_htj2k_encode_profile_row(path, self);
153        export_encode_trace_if_requested(path, self);
154    }
155}
156
/// Timestamp type used by the profiling helpers; currently `std::time::Instant`.
pub(crate) type ProfileInstant = Instant;
158
159fn profile_stage_mode() -> ProfileStageMode {
160    static MODE: OnceLock<ProfileStageMode> = OnceLock::new();
161    *MODE.get_or_init(|| profile_stage_mode_from_env(PROFILE_ENV_VAR))
162}
163
164pub(crate) fn profile_stages_enabled() -> bool {
165    profile_stage_mode() != ProfileStageMode::Disabled
166}
167
168pub(crate) fn profile_now(enabled: bool) -> Option<ProfileInstant> {
169    enabled.then(Instant::now)
170}
171
172pub(crate) fn elapsed_us(start: Option<ProfileInstant>) -> u128 {
173    start.map_or(0, |start| start.elapsed().as_micros())
174}
175
176#[cfg_attr(not(feature = "cuda-runtime"), allow(dead_code))]
177pub(crate) fn add_payload_resource_upload_us(
178    report: &mut CudaHtj2kProfileReport,
179    elapsed_us: u128,
180) {
181    report.h2d_us = report.h2d_us.saturating_add(elapsed_us);
182    report.detail.payload_upload_us = report.detail.payload_upload_us.saturating_add(elapsed_us);
183}
184
185#[cfg_attr(not(feature = "cuda-runtime"), allow(dead_code))]
186pub(crate) fn finalize_decode_total_us(report: &mut CudaHtj2kProfileReport) {
187    report.total_us = [
188        report.parse_us,
189        report.plan_us,
190        report.flatten_us,
191        report.h2d_us,
192        report.ht_cleanup_us,
193        report.ht_refine_us,
194        report.dequant_us,
195        report.idwt_us,
196        report.mct_us,
197        report.store_us,
198    ]
199    .into_iter()
200    .fold(0u128, u128::saturating_add);
201    report.detail.stage_sum_us = report.total_us;
202}
203
204pub(crate) fn emit_htj2k_profile_row(path: &str, report: &CudaHtj2kProfileReport) {
205    let parse_us = report.parse_us.to_string();
206    let plan_us = report.plan_us.to_string();
207    let flatten_us = report.flatten_us.to_string();
208    let h2d_us = report.h2d_us.to_string();
209    let ht_cleanup_us = report.ht_cleanup_us.to_string();
210    let ht_refine_us = report.ht_refine_us.to_string();
211    let dequant_us = report.dequant_us.to_string();
212    let idwt_us = report.idwt_us.to_string();
213    let mct_us = report.mct_us.to_string();
214    let store_us = report.store_us.to_string();
215    let total_us = report.total_us.to_string();
216    let block_count = report.block_count.to_string();
217    let payload_bytes = report.payload_bytes.to_string();
218    let dispatch_count = report.dispatch_count.to_string();
219    let residency = format!("{:?}", report.residency);
220    let wall_total_us = report.detail.wall_total_us.to_string();
221    let stage_sum_us = report.detail.stage_sum_us.to_string();
222    let table_upload_us = report.detail.table_upload_us.to_string();
223    let payload_upload_us = report.detail.payload_upload_us.to_string();
224    let job_upload_us = report.detail.job_upload_us.to_string();
225    let status_d2h_us = report.detail.status_d2h_us.to_string();
226    let output_d2h_us = report.detail.output_d2h_us.to_string();
227    let ht_dispatch_count = report.detail.ht_dispatch_count.to_string();
228    let dequant_dispatch_count = report.detail.dequant_dispatch_count.to_string();
229    let idwt_dispatch_count = report.detail.idwt_dispatch_count.to_string();
230    let mct_dispatch_count = report.detail.mct_dispatch_count.to_string();
231    let store_dispatch_count = report.detail.store_dispatch_count.to_string();
232
233    j2k_profile::emit_profile_row(
234        profile_stage_mode(),
235        &PROFILE_SUMMARY,
236        "j2k",
237        "cuda_htj2k",
238        path,
239        &[
240            ("parse_us", parse_us.as_str()),
241            ("plan_us", plan_us.as_str()),
242            ("flatten_us", flatten_us.as_str()),
243            ("h2d_us", h2d_us.as_str()),
244            ("ht_cleanup_us", ht_cleanup_us.as_str()),
245            ("ht_refine_us", ht_refine_us.as_str()),
246            ("dequant_us", dequant_us.as_str()),
247            ("idwt_us", idwt_us.as_str()),
248            ("mct_us", mct_us.as_str()),
249            ("store_us", store_us.as_str()),
250            ("total_us", total_us.as_str()),
251            ("block_count", block_count.as_str()),
252            ("payload_bytes", payload_bytes.as_str()),
253            ("dispatch_count", dispatch_count.as_str()),
254            ("residency", residency.as_str()),
255            ("wall_total_us", wall_total_us.as_str()),
256            ("stage_sum_us", stage_sum_us.as_str()),
257            ("table_upload_us", table_upload_us.as_str()),
258            ("payload_upload_us", payload_upload_us.as_str()),
259            ("job_upload_us", job_upload_us.as_str()),
260            ("status_d2h_us", status_d2h_us.as_str()),
261            ("output_d2h_us", output_d2h_us.as_str()),
262            ("ht_dispatch_count", ht_dispatch_count.as_str()),
263            ("dequant_dispatch_count", dequant_dispatch_count.as_str()),
264            ("idwt_dispatch_count", idwt_dispatch_count.as_str()),
265            ("mct_dispatch_count", mct_dispatch_count.as_str()),
266            ("store_dispatch_count", store_dispatch_count.as_str()),
267        ],
268    );
269}
270
271pub(crate) fn emit_htj2k_encode_profile_row(path: &str, report: &CudaHtj2kEncodeProfileReport) {
272    let deinterleave_us = report.deinterleave_us.to_string();
273    let mct_us = report.mct_us.to_string();
274    let dwt_us = report.dwt_us.to_string();
275    let quantize_us = report.quantize_us.to_string();
276    let ht_encode_us = report.ht_encode_us.to_string();
277    let packetize_us = report.packetize_us.to_string();
278    let total_us = report.total_us.to_string();
279    let input_bytes = report.input_bytes.to_string();
280    let codestream_bytes = report.codestream_bytes.to_string();
281    let block_count = report.block_count.to_string();
282    let dispatch_count = report.dispatch_count.to_string();
283    let backend = format!("{:?}", report.backend);
284
285    j2k_profile::emit_profile_row(
286        profile_stage_mode(),
287        &PROFILE_SUMMARY,
288        "j2k",
289        "cuda_htj2k_encode",
290        path,
291        &[
292            ("deinterleave_us", deinterleave_us.as_str()),
293            ("mct_us", mct_us.as_str()),
294            ("dwt_us", dwt_us.as_str()),
295            ("quantize_us", quantize_us.as_str()),
296            ("ht_encode_us", ht_encode_us.as_str()),
297            ("packetize_us", packetize_us.as_str()),
298            ("total_us", total_us.as_str()),
299            ("input_bytes", input_bytes.as_str()),
300            ("codestream_bytes", codestream_bytes.as_str()),
301            ("block_count", block_count.as_str()),
302            ("dispatch_count", dispatch_count.as_str()),
303            ("backend", backend.as_str()),
304        ],
305    );
306}
307
308fn export_trace_if_requested(path: &str, report: &CudaHtj2kProfileReport) {
309    let Some(trace_path) = std::env::var_os(CUDA_TRACE_ENV_VAR) else {
310        return;
311    };
312    let trace = chrome_trace_json(path, report);
313    if let Err(error) = std::fs::write(&trace_path, trace) {
314        std::eprintln!("j2k_profile codec=j2k op=cuda_htj2k_trace path=cuda error={error}");
315    }
316}
317
318fn chrome_trace_json(path: &str, report: &CudaHtj2kProfileReport) -> String {
319    let stages = [
320        ("parse", report.parse_us),
321        ("plan", report.plan_us),
322        ("flatten", report.flatten_us),
323        ("h2d", report.h2d_us),
324        ("ht_cleanup", report.ht_cleanup_us),
325        ("ht_refine", report.ht_refine_us),
326        ("dequant", report.dequant_us),
327        ("idwt", report.idwt_us),
328        ("mct", report.mct_us),
329        ("store", report.store_us),
330    ];
331    let mut trace = String::from("{\"traceEvents\":[");
332    let mut ts = 0u128;
333    for (index, (name, dur)) in stages.iter().enumerate() {
334        if index != 0 {
335            trace.push(',');
336        }
337        let event_ts = if *name == "ht_refine" {
338            ts.saturating_sub(report.ht_cleanup_us)
339        } else {
340            ts
341        };
342        write!(
343            trace,
344            "{{\"name\":\"{name}\",\"cat\":\"{path}\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":{event_ts},\"dur\":{dur}}}"
345        )
346        .expect("writing trace JSON to String failed");
347        if *name != "ht_refine" {
348            ts = ts.saturating_add(*dur);
349        }
350    }
351    trace.push_str("]}");
352    trace
353}
354
355fn export_encode_trace_if_requested(path: &str, report: &CudaHtj2kEncodeProfileReport) {
356    let Some(trace_path) = std::env::var_os(CUDA_TRACE_ENV_VAR) else {
357        return;
358    };
359    let trace = chrome_encode_trace_json(path, report);
360    if let Err(error) = std::fs::write(&trace_path, trace) {
361        std::eprintln!("j2k_profile codec=j2k op=cuda_htj2k_encode_trace path=cuda error={error}");
362    }
363}
364
365fn chrome_encode_trace_json(path: &str, report: &CudaHtj2kEncodeProfileReport) -> String {
366    let stages = [
367        ("deinterleave", report.deinterleave_us),
368        ("mct", report.mct_us),
369        ("dwt", report.dwt_us),
370        ("quantize", report.quantize_us),
371        ("ht_encode", report.ht_encode_us),
372        ("packetize", report.packetize_us),
373    ];
374    let mut trace = String::from("{\"traceEvents\":[");
375    let mut ts = 0u128;
376    for (index, (name, dur)) in stages.iter().enumerate() {
377        if index != 0 {
378            trace.push(',');
379        }
380        write!(
381            trace,
382            "{{\"name\":\"{name}\",\"cat\":\"{path}\",\"ph\":\"X\",\"pid\":1,\"tid\":1,\"ts\":{ts},\"dur\":{dur}}}"
383        )
384        .expect("writing trace JSON to String failed");
385        ts = ts.saturating_add(*dur);
386    }
387    trace.push_str("]}");
388    trace
389}
390
#[cfg(test)]
mod tests {
    use super::{
        add_payload_resource_upload_us, chrome_encode_trace_json, chrome_trace_json,
        finalize_decode_total_us, CudaHtj2kDecodeProfileDetail, CudaHtj2kEncodeProfileReport,
        CudaHtj2kProfileReport,
    };
    use j2k_core::BackendKind;

    use crate::SurfaceResidency;

    /// Stage timings in report field order, with refinement longer than cleanup.
    const DISTINCT_STAGES: [u128; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    /// Stage timings in report field order, with refinement exactly as long as cleanup.
    const FUSED_STAGES: [u128; 10] = [1, 2, 3, 4, 5, 5, 6, 7, 8, 9];

    /// Builds a CUDA-resident decode report from ten stage timings (parse through
    /// store) and an initial `total_us`; counters are 1 / 2 / 3 and the detail is empty.
    fn decode_report(stages: [u128; 10], total_us: u128) -> CudaHtj2kProfileReport {
        let [parse_us, plan_us, flatten_us, h2d_us, ht_cleanup_us, ht_refine_us, dequant_us, idwt_us, mct_us, store_us] =
            stages;
        CudaHtj2kProfileReport {
            parse_us,
            plan_us,
            flatten_us,
            h2d_us,
            ht_cleanup_us,
            ht_refine_us,
            dequant_us,
            idwt_us,
            mct_us,
            store_us,
            total_us,
            block_count: 1,
            payload_bytes: 2,
            dispatch_count: 3,
            residency: SurfaceResidency::CudaResidentDecode,
            detail: CudaHtj2kDecodeProfileDetail::default(),
        }
    }

    /// Leading part of a trace event, up to and including the phase marker.
    fn span_header(name: &str, cat: &str) -> String {
        format!("\"name\":\"{name}\",\"cat\":\"{cat}\",\"ph\":\"X\"")
    }

    /// A trace event up to and including its duration.
    fn timed_span(name: &str, cat: &str, ts: u128, dur: u128) -> String {
        format!("{},\"pid\":1,\"tid\":1,\"ts\":{ts},\"dur\":{dur}", span_header(name, cat))
    }

    #[test]
    fn finalize_decode_total_us_includes_cpu_and_cuda_stages() {
        // Start from a deliberately wrong total to prove it gets overwritten.
        let mut report = decode_report(DISTINCT_STAGES, 3);

        finalize_decode_total_us(&mut report);

        assert_eq!(report.total_us, 55);
        assert_eq!(report.detail.stage_sum_us, 55);
    }

    #[test]
    fn detailed_decode_profile_separates_wall_and_stage_sum() {
        let mut report = CudaHtj2kProfileReport {
            block_count: 10,
            payload_bytes: 11,
            dispatch_count: 12,
            ..decode_report(FUSED_STAGES, 0)
        };
        report.detail.wall_total_us = 100;
        report.detail.table_upload_us = 13;
        report.detail.payload_upload_us = 17;
        report.detail.ht_dispatch_count = 2;

        finalize_decode_total_us(&mut report);

        assert_eq!(report.detail.wall_total_us, 100);
        assert_eq!(report.detail.stage_sum_us, report.total_us);
        assert_eq!(report.detail.ht_dispatch_count, 2);
    }

    #[test]
    fn payload_resource_upload_detail_does_not_claim_job_status_split() {
        let mut report = CudaHtj2kProfileReport::default();

        add_payload_resource_upload_us(&mut report, 23);

        assert_eq!(report.h2d_us, 23);
        assert_eq!(report.detail.payload_upload_us, 23);
        // The split counters stay reserved at zero.
        assert_eq!(report.detail.job_upload_us, 0);
        assert_eq!(report.detail.status_d2h_us, 0);
        assert_eq!(report.detail.output_d2h_us, 0);
    }

    #[test]
    fn decode_trace_json_contains_ordered_stage_spans() {
        let report = decode_report(DISTINCT_STAGES, 55);

        let trace = chrome_trace_json("decode", &report);

        assert!(trace.starts_with("{\"traceEvents\":["));
        for stage in ["parse", "ht_cleanup", "store"] {
            assert!(trace.contains(&span_header(stage, "decode")));
        }
        assert!(trace.contains("\"ts\":0,\"dur\":1"));
        assert!(trace.contains("\"ts\":39,\"dur\":10"));
        assert!(trace.ends_with("]}"));
    }

    #[test]
    fn decode_trace_json_does_not_advance_time_for_fused_refinement() {
        let report = decode_report(FUSED_STAGES, 45);

        let trace = chrome_trace_json("decode", &report);

        assert!(trace.contains(&span_header("ht_refine", "decode")));
        assert!(trace.contains(&timed_span("ht_refine", "decode", 10, 5)));
        assert!(trace.contains(&timed_span("dequant", "decode", 15, 6)));
        assert!(trace.contains(&timed_span("store", "decode", 36, 9)));
    }

    #[test]
    fn encode_trace_json_contains_ordered_stage_spans() {
        let report = CudaHtj2kEncodeProfileReport {
            deinterleave_us: 11,
            mct_us: 12,
            dwt_us: 13,
            quantize_us: 14,
            ht_encode_us: 15,
            packetize_us: 16,
            total_us: 81,
            input_bytes: 100,
            codestream_bytes: 50,
            block_count: 4,
            dispatch_count: 6,
            backend: BackendKind::Cuda,
        };

        let trace = chrome_encode_trace_json("encode", &report);

        assert!(trace.starts_with("{\"traceEvents\":["));
        for stage in ["deinterleave", "ht_encode", "packetize"] {
            assert!(trace.contains(&span_header(stage, "encode")));
        }
        assert!(trace.contains("\"ts\":0,\"dur\":11"));
        assert!(trace.contains("\"ts\":65,\"dur\":16"));
        assert!(trace.ends_with("]}"));
    }
}
561}