Skip to main content

perfetto_cuda_injection/
lib.rs

1// Copyright (C) 2026 David Reveman.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod callbacks;
16pub mod cupti_profiler;
17pub mod cupti_profiler_sys;
18pub mod metrics;
19pub mod state;
20
21use callbacks::{buffer_completed, buffer_requested, profiler_callback_handler};
22use metrics::parse_metrics;
23use perfetto_gpu_compute_injection::config::Config;
24use perfetto_gpu_compute_injection::injection_log;
25use perfetto_gpu_compute_injection::tracing::{
26    get_counter_config, get_counters_data_source, get_renderstages_data_source,
27    get_track_event_data_source, register_backend, trace_time_ns, GpuBackend,
28};
29use perfetto_sdk::{
30    data_source::{StopGuard, TraceContext},
31    producer::{Backends, Producer, ProducerInitArgsBuilder},
32    protos::{
33        common::builtin_clock::BuiltinClock,
34        trace::{
35            interned_data::interned_data::InternedData,
36            trace_packet::{TracePacket, TracePacketSequenceFlags},
37            track_event::track_event::EventName,
38        },
39    },
40    track_event::TrackEvent,
41};
42use perfetto_sdk_protos_gpu::protos::trace::gpu::gpu_interned_data::InternedDataExt as _;
43use perfetto_sdk_protos_gpu::protos::{
44    common::gpu_counter_descriptor::{
45        GpuCounterDescriptor, GpuCounterDescriptorGpuCounterGroup,
46        GpuCounterDescriptorGpuCounterGroupSpec, GpuCounterDescriptorGpuCounterSpec,
47    },
48    trace::{
49        gpu::{
50            gpu_counter_event::{
51                GpuCounterEvent, GpuCounterEventGpuCounter, InternedGpuCounterDescriptor,
52            },
53            gpu_render_stage_event::{
54                GpuRenderStageEvent, GpuRenderStageEventComputeKernelLaunch,
55                GpuRenderStageEventDim3, GpuRenderStageEventExtraComputeArg,
56                InternedComputeArgName, InternedComputeKernel, InternedGpuRenderStageSpecification,
57                InternedGpuRenderStageSpecificationRenderStageCategory, InternedGraphicsContext,
58                InternedGraphicsContextApi,
59            },
60        },
61        interned_data::interned_data::prelude::*,
62        trace_packet::prelude::*,
63    },
64};
65use state::{
66    ConsumerStartOffsets, GpuActivity, GLOBAL_STATE, HW_QUEUE_IID_OFFSET, KERNEL_STAGE_IID,
67    MEMCPY_STAGE_IID, MEMSET_STAGE_IID,
68};
69use std::{
70    collections::HashSet,
71    panic, ptr,
72    sync::atomic::{AtomicU8, Ordering},
73};
74
75use cupti_profiler as profiler;
76use cupti_profiler::bindings::*;
77
78// ---------------------------------------------------------------------------
79// CuptiBackend implementation
80// ---------------------------------------------------------------------------
81
/// Unit type through which this crate implements `GpuBackend` on top of CUPTI.
struct CuptiBackend;

/// Progress of the CUPTI teardown sequence, stored as a raw byte:
/// `0` = not started, `1` = running, `2` = finished.
static CUPTI_TEARDOWN_STATE: AtomicU8 = AtomicU8::new(0);

/// One kernel's counter sample, copied out of `GLOBAL_STATE` so that it can be
/// written to the trace after the global lock has been released.
struct CollectedCounterEvent {
    /// GPU id of the kernel; used to pick the counter descriptor for the sample.
    gpu_id: i32,
    /// Launch start timestamp (the zero-valued sample is written here).
    timestamp_start: u64,
    /// Launch end timestamp (the measured values are written here).
    timestamp_end: u64,
    /// One bit per data-source instance that asked for counters on this kernel.
    profiled_instances: u8,
    /// `(metric name, value)` pairs produced by the range profiler.
    metrics: Vec<(String, f64)>,
}
96
97impl GpuBackend for CuptiBackend {
98    fn default_data_source_suffix(&self) -> &'static str {
99        "nv"
100    }
101
102    fn on_first_consumer_start(&self) {
103        CUPTI_TEARDOWN_STATE.store(0, Ordering::SeqCst);
104        let _ = profiler::activity_enable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL);
105        let _ = profiler::activity_enable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_MEMCPY);
106        let _ = profiler::activity_enable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_MEMSET);
107        if let Ok(state) = GLOBAL_STATE.lock() {
108            let subscriber = state.subscriber_handle;
109            if !subscriber.is_null() {
110                let _ = unsafe {
111                    profiler::enable_callback(
112                        1,
113                        subscriber,
114                        CUpti_CallbackDomain_CUPTI_CB_DOMAIN_RUNTIME_API,
115                        CUpti_runtime_api_trace_cbid_enum_CUPTI_RUNTIME_TRACE_CBID_cudaDeviceReset_v3020,
116                    )
117                };
118            }
119        }
120    }
121
122    fn on_first_counters_start(&self) {
123        // Compute the union of counter_names across all active instances.
124        // The profiler collects all requested metrics; each instance only
125        // sees the subset it asked for during emission.
126        let mut union_names: Vec<String> = Vec::new();
127        let mut seen: HashSet<String> = HashSet::new();
128        for id in 0..8u32 {
129            if let Some(cfg) = get_counter_config(id) {
130                for name in &cfg.counter_names {
131                    if seen.insert(name.clone()) {
132                        union_names.push(name.clone());
133                    }
134                }
135            }
136        }
137        if !union_names.is_empty() {
138            if let Ok(mut state) = GLOBAL_STATE.lock() {
139                injection_log!(
140                    "using {} counter names from trace config (union of all instances)",
141                    union_names.len()
142                );
143                state.config.metrics = union_names;
144            }
145        }
146    }
147
148    fn on_last_counters_stop(&self) {}
149
150    fn on_renderstages_start_no_counters(&self) {}
151
152    fn register_counters_consumer(&self, inst_id: u32) {
153        if let Ok(mut state) = GLOBAL_STATE.lock() {
154            let offsets = ConsumerStartOffsets::snapshot(&state.context_data);
155            state.counters_consumers.insert(inst_id, offsets);
156        }
157    }
158
159    fn register_renderstages_consumer(&self, inst_id: u32) {
160        if let Ok(mut state) = GLOBAL_STATE.lock() {
161            let offsets = ConsumerStartOffsets::snapshot(&state.context_data);
162            state.renderstages_consumers.insert(inst_id, offsets);
163        }
164    }
165
166    fn flush_activity_buffers(&self) {
167        // cuptiActivityFlushAll deadlocks when called concurrently from
168        // multiple threads.  Use try_lock so that only one thread performs
169        // the actual CUPTI flush; the other skips it (the first call already
170        // delivers all pending records via the buffer_completed callback).
171        static FLUSH_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
172        if let Ok(_guard) = FLUSH_LOCK.try_lock() {
173            let _ = profiler::activity_flush_all(0);
174        }
175    }
176
177    fn finalize_range_profiler(&self) {
178        if let Ok(mut state) = GLOBAL_STATE.lock() {
179            for (_, data) in state.context_data.iter_mut() {
180                data.finalize_profiler(false);
181                if let Some(last_launch) = data
182                    .kernel_launches
183                    .iter_mut()
184                    .rev()
185                    .find(|l| l.profiled_instances != 0)
186                {
187                    if last_launch.end == 0 {
188                        last_launch.end = trace_time_ns();
189                    }
190                }
191            }
192        }
193    }
194
195    fn run_teardown(&self) {
196        if CUPTI_TEARDOWN_STATE
197            .compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst)
198            .is_ok()
199        {
200            let _ = panic::catch_unwind(|| {
201                let _ = profiler::get_last_error();
202
203                if let Ok(state) = GLOBAL_STATE.lock() {
204                    let subscriber = state.subscriber_handle;
205                    if !subscriber.is_null() {
206                        let _ = unsafe {
207                            profiler::enable_callback(
208                                0,
209                                subscriber,
210                                CUpti_CallbackDomain_CUPTI_CB_DOMAIN_RUNTIME_API,
211                                CUpti_runtime_api_trace_cbid_enum_CUPTI_RUNTIME_TRACE_CBID_cudaDeviceReset_v3020,
212                            )
213                        };
214                    }
215                }
216
217                let _ = profiler::activity_disable(
218                    CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL,
219                );
220                let _ = profiler::activity_disable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_MEMCPY);
221                let _ = profiler::activity_disable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_MEMSET);
222                let _ = profiler::activity_flush_all(
223                    CUpti_ActivityFlag_CUPTI_ACTIVITY_FLAG_FLUSH_FORCED,
224                );
225
226                self.finalize_range_profiler();
227            });
228            CUPTI_TEARDOWN_STATE.store(2, Ordering::SeqCst);
229        } else {
230            while CUPTI_TEARDOWN_STATE.load(Ordering::SeqCst) != 2 {
231                std::thread::sleep(std::time::Duration::from_millis(1));
232            }
233        }
234    }
235
236    fn emit_counter_events_for_instance(&self, inst_id: u32, stop_guard: Option<StopGuard>) {
237        let _ = panic::catch_unwind(|| {
238            // Phase 1: Collect data under GLOBAL_STATE lock, then release.
239            let collected_events = {
240                let mut state = match GLOBAL_STATE.lock() {
241                    Ok(s) => s,
242                    Err(_) => return,
243                };
244                let start_offsets = if stop_guard.is_some() {
245                    match state.counters_consumers.remove(&inst_id) {
246                        Some(o) => o,
247                        None => return,
248                    }
249                } else {
250                    match state.counters_consumers.get(&inst_id).cloned() {
251                        Some(o) => o,
252                        None => return,
253                    }
254                };
255                let mut events: Vec<CollectedCounterEvent> = Vec::new();
256                for (ctx_id, data) in state.context_data.iter() {
257                    let range_start = start_offsets.range_info.get(ctx_id).copied().unwrap_or(0);
258                    let launch_start = start_offsets
259                        .kernel_launches
260                        .get(ctx_id)
261                        .copied()
262                        .unwrap_or(0);
263                    let activity_start = start_offsets
264                        .kernel_activities
265                        .get(ctx_id)
266                        .copied()
267                        .unwrap_or(0);
268                    // `range_info` is dense (one entry per *profiled* kernel)
269                    // while `kernel_launches` and `kernel_activities` are sparse
270                    // (one entry per *every* cuLaunchKernel). Walk launches and
271                    // advance the range_info iterator only on profiled kernels
272                    // so each profiled kernel gets paired with its own metrics
273                    // — zipping them by index would slide range_info forward
274                    // by exactly `skip` positions and chop the last `skip`
275                    // results off the emitted events.
276                    let mut range_iter = data.range_info[range_start..].iter();
277                    for (launch, activity) in data.kernel_launches[launch_start..]
278                        .iter()
279                        .zip(data.kernel_activities[activity_start..].iter())
280                    {
281                        if launch.end == 0 {
282                            continue;
283                        }
284                        if launch.profiled_instances == 0 {
285                            continue;
286                        }
287                        let Some(range) = range_iter.next() else {
288                            break;
289                        };
290                        events.push(CollectedCounterEvent {
291                            timestamp_start: launch.start,
292                            timestamp_end: launch.end,
293                            gpu_id: activity.gpu_id() as i32,
294                            metrics: range
295                                .metric_and_values
296                                .iter()
297                                .map(|m| (m.metric_name.clone(), m.value))
298                                .collect(),
299                            profiled_instances: launch.profiled_instances,
300                        });
301                    }
302                }
303                events
304                // state (GLOBAL_STATE lock) dropped here
305            };
306
307            // Per-instance filtering using the profiled_instances bitmask:
308            // 1. Skip kernels where this instance's bit is not set.
309            // 2. Filter metrics by counter_names (only emit the counters this
310            //    instance requested).
311            let inst_bit = 1u8 << inst_id;
312            let collected_events: Vec<_> = if let Some(cfg) = get_counter_config(inst_id) {
313                let has_counter_filter = !cfg.counter_names.is_empty();
314                let requested: HashSet<&str> =
315                    cfg.counter_names.iter().map(|s| s.as_str()).collect();
316                collected_events
317                    .into_iter()
318                    .filter(|e| e.profiled_instances & inst_bit != 0)
319                    .map(|mut e| {
320                        if has_counter_filter {
321                            e.metrics
322                                .retain(|(name, _)| requested.contains(name.as_str()));
323                        }
324                        e
325                    })
326                    .filter(|e| !e.metrics.is_empty())
327                    .collect()
328            } else {
329                collected_events
330                    .into_iter()
331                    .filter(|e| e.profiled_instances & inst_bit != 0)
332                    .collect()
333            };
334
335            // Phase 2: Emit collected events without holding GLOBAL_STATE.
336            // This prevents deadlock with buffer_completed callback.
337            // Collect the set of GPU IDs present in this batch.
338            let gpu_ids: HashSet<i32> = collected_events.iter().map(|e| e.gpu_id).collect();
339            let mut stop_guard_opt = stop_guard;
340            get_counters_data_source().trace(|ctx: &mut TraceContext| {
341                if ctx.instance_index() != inst_id {
342                    return;
343                }
344                for event in &collected_events {
345                    ctx.with_incremental_state(|ctx: &mut TraceContext, inc_state| {
346                        let was_cleared = std::mem::replace(&mut inc_state.was_cleared, false);
347                        if was_cleared {
348                            emit_interned_counter_descriptors(ctx, &collected_events, &gpu_ids);
349                        }
350                        // Emit start sample (zero values).
351                        let desc_iid = event.gpu_id as u64 + 1;
352                        ctx.add_packet(|packet: &mut TracePacket| {
353                            packet
354                                .set_timestamp(event.timestamp_start)
355                                .set_timestamp_clock_id(BuiltinClock::BuiltinClockBoottime.into())
356                                .set_gpu_counter_event(|ce: &mut GpuCounterEvent| {
357                                    ce.set_counter_descriptor_iid(desc_iid);
358                                    for (i, _) in event.metrics.iter().enumerate() {
359                                        ce.set_counters(
360                                            |counter: &mut GpuCounterEventGpuCounter| {
361                                                counter.set_counter_id(i as u32).set_int_value(0);
362                                            },
363                                        );
364                                    }
365                                });
366                        });
367                        // Emit end sample (actual values).
368                        ctx.add_packet(|packet: &mut TracePacket| {
369                            packet
370                                .set_timestamp(event.timestamp_end)
371                                .set_timestamp_clock_id(BuiltinClock::BuiltinClockBoottime.into())
372                                .set_gpu_counter_event(|ce: &mut GpuCounterEvent| {
373                                    ce.set_counter_descriptor_iid(desc_iid);
374                                    for (i, (_, value)) in event.metrics.iter().enumerate() {
375                                        ce.set_counters(
376                                            |counter: &mut GpuCounterEventGpuCounter| {
377                                                counter
378                                                    .set_counter_id(i as u32)
379                                                    .set_double_value(*value);
380                                            },
381                                        );
382                                    }
383                                });
384                        });
385                    });
386                }
387                if let Some(sg) = stop_guard_opt.take() {
388                    let mut sg = Some(Some(sg));
389                    ctx.flush(move || drop(sg.take()));
390                }
391            });
392            drop(stop_guard_opt);
393        });
394    }
395
396    fn emit_renderstage_events_for_instance(&self, inst_id: u32, stop_guard: Option<StopGuard>) {
397        let _ = panic::catch_unwind(|| {
398            let (process_id, process_name) =
399                perfetto_gpu_compute_injection::config::get_process_info();
400
401            // PendingEvent holds all data needed to emit a render stage event.
402            struct PendingEvent {
403                start: u64,
404                end: u64,
405                name: String,
406                name_iid: u64,
407                channel_id: u32,
408                channel_type: u32,
409                activity_device_id: u32,
410                activity_context_id: u32,
411                activity_stream_id: u32,
412                stage_iid: u64,
413                correlation_id: u32,
414                // Structured compute kernel fields (kernel events only).
415                kernel_iid: Option<u64>,
416                kernel_mangled_name: Option<String>,
417                kernel_demangled_name: Option<String>,
418                kernel_arch: Option<String>,
419                kernel_registers_per_thread: Option<u64>,
420                kernel_shared_mem_static: Option<u64>,
421                kernel_func_cache_config: Option<String>,
422                kernel_shared_mem_config_size: Option<u64>,
423                launch_grid: Option<(u32, u32, u32)>,
424                launch_block: Option<(u32, u32, u32)>,
425                launch_args: Vec<(u64, KernelArgValue)>,
426            }
427
428            impl state::GpuActivity for PendingEvent {
429                fn start(&self) -> u64 {
430                    self.start
431                }
432                fn end(&self) -> u64 {
433                    self.end
434                }
435                fn device_id(&self) -> u32 {
436                    self.activity_device_id
437                }
438                fn context_id(&self) -> u32 {
439                    self.activity_context_id
440                }
441                fn stream_id(&self) -> u32 {
442                    self.activity_stream_id
443                }
444                fn channel_id(&self) -> u32 {
445                    self.channel_id
446                }
447                fn channel_type(&self) -> u32 {
448                    self.channel_type
449                }
450                fn stage_iid(&self) -> u64 {
451                    self.stage_iid
452                }
453                fn correlation_id(&self) -> u32 {
454                    self.correlation_id
455                }
456            }
457
458            // Phase 1: Collect all event data under GLOBAL_STATE lock, then release.
459            let (all_events, channels, contexts, updated_prev_ends, is_flush) = {
460                let mut state = match GLOBAL_STATE.lock() {
461                    Ok(s) => s,
462                    Err(_) => return,
463                };
464                let start_offsets = if stop_guard.is_some() {
465                    match state.renderstages_consumers.remove(&inst_id) {
466                        Some(o) => o,
467                        None => return,
468                    }
469                } else {
470                    match state.renderstages_consumers.get(&inst_id).cloned() {
471                        Some(o) => o,
472                        None => return,
473                    }
474                };
475                let is_flush = stop_guard.is_none();
476
477                // Build channel and context sets.
478                let mut channels: std::collections::HashSet<(u32, u32)> =
479                    std::collections::HashSet::new();
480                let mut contexts: std::collections::HashSet<u32> = std::collections::HashSet::new();
481                for (ctx_id, data) in state.context_data.iter() {
482                    let ka_start = start_offsets
483                        .kernel_activities
484                        .get(ctx_id)
485                        .copied()
486                        .unwrap_or(0);
487                    let mc_start = start_offsets
488                        .memcpy_activities
489                        .get(ctx_id)
490                        .copied()
491                        .unwrap_or(0);
492                    let ms_start = start_offsets
493                        .memset_activities
494                        .get(ctx_id)
495                        .copied()
496                        .unwrap_or(0);
497                    for activity in data.kernel_activities[ka_start..].iter() {
498                        channels.insert((activity.channel_id, activity.channel_type));
499                        contexts.insert(activity.context_id);
500                    }
501                    for activity in data.memcpy_activities[mc_start..].iter() {
502                        channels.insert((activity.channel_id, activity.channel_type));
503                        contexts.insert(activity.context_id);
504                    }
505                    for activity in data.memset_activities[ms_start..].iter() {
506                        channels.insert((activity.channel_id, activity.channel_type));
507                        contexts.insert(activity.context_id);
508                    }
509                }
510
511                // Build per-channel event lists.
512                let mut channel_events: std::collections::HashMap<(u32, u32), Vec<PendingEvent>> =
513                    std::collections::HashMap::new();
514
515                for (ctx_id, data) in state.context_data.iter() {
516                    let ka_start = start_offsets
517                        .kernel_activities
518                        .get(ctx_id)
519                        .copied()
520                        .unwrap_or(0);
521                    let launch_map: std::collections::HashMap<u32, &state::KernelLaunch> = data
522                        .kernel_launches
523                        .iter()
524                        .map(|l| (l.correlation_id, l))
525                        .collect();
526                    for activity in data.kernel_activities[ka_start..].iter() {
527                        let launch = launch_map.get(&activity.correlation_id);
528                        let (raw_start, raw_end) =
529                            if launch.is_some_and(|l| l.profiled_instances != 0) {
530                                let l = launch.unwrap();
531                                (l.start, l.end)
532                            } else {
533                                (activity.start, activity.end)
534                            };
535                        if raw_start == 0 || raw_end < raw_start {
536                            continue;
537                        }
538                        let demangled = perfetto_gpu_compute_injection::kernel::demangle_name(
539                            &activity.kernel_name,
540                        );
541                        let grid_size =
542                            activity.grid_size.0 * activity.grid_size.1 * activity.grid_size.2;
543                        let block_size =
544                            activity.block_size.0 * activity.block_size.1 * activity.block_size.2;
545                        let thread_count = grid_size * block_size;
546                        let cache_mode = launch.map_or(0, |l| l.cache_mode);
547                        let max_active_blocks = launch.map_or(0, |l| l.max_active_blocks_per_sm);
548                        let waves_per_multiprocessor = if data.num_sms > 0 && max_active_blocks > 0
549                        {
550                            grid_size as f64 / (data.num_sms * max_active_blocks) as f64
551                        } else {
552                            0.0
553                        };
554                        let regs_per_thread = activity.registers_per_thread as i32;
555                        let smem_per_block =
556                            activity.dynamic_shared_memory + activity.static_shared_memory;
557                        let warp_size = data.warp_size;
558                        let max_threads_sm = data.max_threads_per_sm;
559                        let max_blocks_sm = data.max_blocks_per_sm;
560                        let regs_per_sm = data.max_regs_per_sm;
561                        let smem_per_sm = data.max_smem_per_sm;
562                        let major = data.compute_capability.0;
563                        let minor = data.compute_capability.1;
564                        let warps_per_block = if warp_size > 0 {
565                            block_size / warp_size
566                        } else {
567                            0
568                        };
569                        let max_active_warps = max_active_blocks * warps_per_block;
570                        let regs_per_block = regs_per_thread * block_size;
571                        let max_warps_sm = if warp_size > 0 {
572                            max_threads_sm / warp_size
573                        } else {
574                            0
575                        };
576                        let max_active_warps_pct = if max_warps_sm > 0 {
577                            100.0 * max_active_warps as f64 / max_warps_sm as f64
578                        } else {
579                            0.0
580                        };
581                        let occupancy_limit_shared_mem = if smem_per_block != 0 {
582                            smem_per_sm / smem_per_block
583                        } else {
584                            16
585                        };
586                        let occupancy_limit_warps = if warps_per_block > 0 {
587                            max_warps_sm / warps_per_block
588                        } else {
589                            0
590                        };
591                        let occupancy_limit_registers = if regs_per_block != 0 {
592                            regs_per_sm / regs_per_block
593                        } else {
594                            16
595                        };
596
597                        // Build structured compute kernel fields.
598                        let arch = format!("CC_{}{}", major, minor);
599                        #[allow(nonstandard_style)]
600                        let func_cache_config_str = match cache_mode as u32 {
601                            CUfunc_cache_enum_CU_FUNC_CACHE_PREFER_NONE => "CachePreferNone",
602                            CUfunc_cache_enum_CU_FUNC_CACHE_PREFER_SHARED => "CachePreferShared",
603                            CUfunc_cache_enum_CU_FUNC_CACHE_PREFER_L1 => "CachePreferL1",
604                            CUfunc_cache_enum_CU_FUNC_CACHE_PREFER_EQUAL => "CachePreferEqual",
605                            _ => "n/a",
606                        };
607
608                        let launch_args: Vec<(u64, KernelArgValue)> = vec![
609                            (
610                                arg_iid("device"),
611                                KernelArgValue::Uint(activity.device_id as u64),
612                            ),
613                            (
614                                arg_iid("stream"),
615                                KernelArgValue::Uint(activity.stream_id as u64),
616                            ),
617                            (
618                                arg_iid("workgroup_size"),
619                                KernelArgValue::Uint(block_size as u64),
620                            ),
621                            (arg_iid("grid_size"), KernelArgValue::Uint(grid_size as u64)),
622                            (
623                                arg_iid("thread_count"),
624                                KernelArgValue::Uint(thread_count as u64),
625                            ),
626                            (
627                                arg_iid("shared_mem_dynamic"),
628                                KernelArgValue::Uint(activity.dynamic_shared_memory as u64),
629                            ),
630                            (
631                                arg_iid("waves_per_multiprocessor"),
632                                KernelArgValue::Double(waves_per_multiprocessor),
633                            ),
634                            (
635                                arg_iid("occupancy_limit_blocks"),
636                                KernelArgValue::Uint(max_blocks_sm as u64),
637                            ),
638                            (
639                                arg_iid("occupancy_limit_registers"),
640                                KernelArgValue::Uint(occupancy_limit_registers as u64),
641                            ),
642                            (
643                                arg_iid("occupancy_limit_shared_mem"),
644                                KernelArgValue::Uint(occupancy_limit_shared_mem as u64),
645                            ),
646                            (
647                                arg_iid("occupancy_limit_warps"),
648                                KernelArgValue::Uint(occupancy_limit_warps as u64),
649                            ),
650                            (
651                                arg_iid("sm__maximum_warps_per_active_cycle_pct"),
652                                KernelArgValue::Double(max_active_warps_pct),
653                            ),
654                            (
655                                arg_iid("sm__maximum_warps_avg_per_active_cycle"),
656                                KernelArgValue::Uint(max_active_warps as u64),
657                            ),
658                            (
659                                arg_iid("shared_mem_driver"),
660                                KernelArgValue::Uint(smem_per_block as u64),
661                            ),
662                        ];
663
664                        let simplified_name =
665                            perfetto_gpu_compute_injection::kernel::simplify_name(&demangled);
666                        channel_events
667                            .entry((activity.channel_id, activity.channel_type))
668                            .or_default()
669                            .push(PendingEvent {
670                                start: raw_start,
671                                end: raw_end,
672                                name: simplified_name.to_string(),
673                                name_iid: kernel_name_iid(simplified_name),
674                                channel_id: activity.channel_id,
675                                channel_type: activity.channel_type,
676                                activity_device_id: activity.device_id,
677                                activity_context_id: activity.context_id,
678                                activity_stream_id: activity.stream_id,
679                                stage_iid: state::KERNEL_STAGE_IID,
680                                correlation_id: activity.correlation_id,
681                                kernel_iid: Some(kernel_name_iid(&activity.kernel_name)),
682                                kernel_mangled_name: Some(activity.kernel_name.clone()),
683                                kernel_demangled_name: Some(demangled.clone()),
684                                kernel_arch: Some(arch),
685                                kernel_registers_per_thread: Some(
686                                    activity.registers_per_thread as u64,
687                                ),
688                                kernel_shared_mem_static: Some(
689                                    activity.static_shared_memory as u64,
690                                ),
691                                kernel_func_cache_config: Some(func_cache_config_str.to_string()),
692                                kernel_shared_mem_config_size: Some(49152),
693                                launch_grid: Some((
694                                    activity.grid_size.0 as u32,
695                                    activity.grid_size.1 as u32,
696                                    activity.grid_size.2 as u32,
697                                )),
698                                launch_block: Some((
699                                    activity.block_size.0 as u32,
700                                    activity.block_size.1 as u32,
701                                    activity.block_size.2 as u32,
702                                )),
703                                launch_args,
704                            });
705                    }
706                    let mc_start = start_offsets
707                        .memcpy_activities
708                        .get(ctx_id)
709                        .copied()
710                        .unwrap_or(0);
711                    for activity in data.memcpy_activities[mc_start..].iter() {
712                        if activity.start == 0 || activity.end < activity.start {
713                            injection_log!(
714                                "WARNING: memcpy activity with invalid timestamps \
715                                 (start={}, end={}), skipping: {} bytes",
716                                activity.start,
717                                activity.end,
718                                activity.bytes
719                            );
720                            continue;
721                        }
722                        let memcpy_name = format!("Memcpy {}", activity.direction_string());
723                        channel_events
724                            .entry((activity.channel_id, activity.channel_type))
725                            .or_default()
726                            .push(PendingEvent {
727                                start: activity.start,
728                                end: activity.end,
729                                name: memcpy_name.clone(),
730                                name_iid: kernel_name_iid(&memcpy_name),
731                                channel_id: activity.channel_id,
732                                channel_type: activity.channel_type,
733                                activity_device_id: activity.device_id,
734                                activity_context_id: activity.context_id,
735                                activity_stream_id: activity.stream_id,
736                                stage_iid: state::MEMCPY_STAGE_IID,
737                                correlation_id: activity.correlation_id,
738                                kernel_iid: None,
739                                kernel_mangled_name: None,
740                                kernel_demangled_name: None,
741                                kernel_arch: None,
742                                kernel_registers_per_thread: None,
743                                kernel_shared_mem_static: None,
744                                kernel_func_cache_config: None,
745                                kernel_shared_mem_config_size: None,
746                                launch_grid: None,
747                                launch_block: None,
748                                launch_args: vec![
749                                    (
750                                        arg_iid("device"),
751                                        KernelArgValue::Uint(activity.device_id as u64),
752                                    ),
753                                    (
754                                        arg_iid("stream"),
755                                        KernelArgValue::Uint(activity.stream_id as u64),
756                                    ),
757                                ],
758                            });
759                    }
760                    let ms_start = start_offsets
761                        .memset_activities
762                        .get(ctx_id)
763                        .copied()
764                        .unwrap_or(0);
765                    for activity in data.memset_activities[ms_start..].iter() {
766                        if activity.start == 0 || activity.end < activity.start {
767                            injection_log!(
768                                "WARNING: memset activity with invalid timestamps \
769                                 (start={}, end={}), skipping: {} bytes",
770                                activity.start,
771                                activity.end,
772                                activity.bytes
773                            );
774                            continue;
775                        }
776                        channel_events
777                            .entry((activity.channel_id, activity.channel_type))
778                            .or_default()
779                            .push(PendingEvent {
780                                start: activity.start,
781                                end: activity.end,
782                                name: "Memset".to_string(),
783                                name_iid: kernel_name_iid("Memset"),
784                                channel_id: activity.channel_id,
785                                channel_type: activity.channel_type,
786                                activity_device_id: activity.device_id,
787                                activity_context_id: activity.context_id,
788                                activity_stream_id: activity.stream_id,
789                                stage_iid: state::MEMSET_STAGE_IID,
790                                correlation_id: activity.correlation_id,
791                                kernel_iid: None,
792                                kernel_mangled_name: None,
793                                kernel_demangled_name: None,
794                                kernel_arch: None,
795                                kernel_registers_per_thread: None,
796                                kernel_shared_mem_static: None,
797                                kernel_func_cache_config: None,
798                                kernel_shared_mem_config_size: None,
799                                launch_grid: None,
800                                launch_block: None,
801                                launch_args: vec![
802                                    (
803                                        arg_iid("device"),
804                                        KernelArgValue::Uint(activity.device_id as u64),
805                                    ),
806                                    (
807                                        arg_iid("stream"),
808                                        KernelArgValue::Uint(activity.stream_id as u64),
809                                    ),
810                                ],
811                            });
812                    }
813                }
814
815                // Sort each channel's events by start time, then clamp to
816                // prevent overlaps caused by mixed timestamp sources.
817                let mut updated_prev_ends: std::collections::HashMap<(u32, u32), u64> =
818                    std::collections::HashMap::new();
819                for (channel_key, events) in channel_events.iter_mut() {
820                    events.sort_by_key(|e| e.start);
821                    let mut prev_end: u64 = start_offsets
822                        .channel_prev_end
823                        .get(channel_key)
824                        .copied()
825                        .unwrap_or(0);
826                    for event in events.iter_mut() {
827                        if event.start < prev_end {
828                            event.start = prev_end;
829                        }
830                        if event.end <= event.start {
831                            event.end = event.start + 1;
832                        }
833                        prev_end = event.end;
834                    }
835                    updated_prev_ends.insert(*channel_key, prev_end);
836                }
837
838                // Merge all channel events into a single list sorted by timestamp.
839                let mut all_events: Vec<PendingEvent> = channel_events
840                    .into_values()
841                    .flat_map(|v| v.into_iter())
842                    .collect();
843                all_events.sort_by_key(|e| e.start);
844
845                // Advance consumer offsets NOW (while holding the lock)
846                // to exactly what we consumed. This prevents the race
847                // where advance_and_drain sets offsets to a length that
848                // includes activities added by buffer_completed between
849                // emit and drain.
850                if is_flush {
851                    let lengths: Vec<(u32, usize, usize, usize, usize)> = state
852                        .context_data
853                        .iter()
854                        .map(|(&ctx_id, data)| {
855                            (
856                                ctx_id,
857                                data.kernel_launches.len(),
858                                data.kernel_activities.len(),
859                                data.memcpy_activities.len(),
860                                data.memset_activities.len(),
861                            )
862                        })
863                        .collect();
864                    if let Some(offsets) = state.renderstages_consumers.get_mut(&inst_id) {
865                        for (ctx_id, kl, ka, mc, ms) in &lengths {
866                            offsets.kernel_launches.insert(*ctx_id, *kl);
867                            offsets.kernel_activities.insert(*ctx_id, *ka);
868                            offsets.memcpy_activities.insert(*ctx_id, *mc);
869                            offsets.memset_activities.insert(*ctx_id, *ms);
870                        }
871                    }
872                }
873
874                (all_events, channels, contexts, updated_prev_ends, is_flush)
875                // state (GLOBAL_STATE lock) dropped here
876            };
877
878            // Phase 2: Emit collected events without holding GLOBAL_STATE.
879            // This prevents deadlock with buffer_completed callback.
880
881            // Collect unique event names and kernels for interning.
882            let mut unique_event_names: Vec<(u64, String)> = Vec::new();
883            {
884                let mut seen_name_iids: HashSet<u64> = HashSet::new();
885                for event in &all_events {
886                    if seen_name_iids.insert(event.name_iid) {
887                        unique_event_names.push((event.name_iid, event.name.clone()));
888                    }
889                }
890            }
891            let mut unique_kernels: Vec<UniqueKernel> = Vec::new();
892            {
893                let mut seen_kernel_iids: HashSet<u64> = HashSet::new();
894                for event in &all_events {
895                    if let Some(kiid) = event.kernel_iid {
896                        if seen_kernel_iids.insert(kiid) {
897                            unique_kernels.push(UniqueKernel {
898                                iid: kiid,
899                                mangled_name: event.kernel_mangled_name.clone().unwrap_or_default(),
900                                demangled_name: event
901                                    .kernel_demangled_name
902                                    .clone()
903                                    .unwrap_or_default(),
904                                arch: event.kernel_arch.clone().unwrap_or_default(),
905                                registers_per_thread: event
906                                    .kernel_registers_per_thread
907                                    .unwrap_or(0),
908                                shared_mem_static: event.kernel_shared_mem_static.unwrap_or(0),
909                                func_cache_config: event
910                                    .kernel_func_cache_config
911                                    .clone()
912                                    .unwrap_or_default(),
913                                shared_mem_config_size: event
914                                    .kernel_shared_mem_config_size
915                                    .unwrap_or(0),
916                                process_name: process_name.clone(),
917                                process_id: process_id as u64,
918                            });
919                        }
920                    }
921                }
922            }
923
924            let mut stop_guard_opt = stop_guard;
925            get_renderstages_data_source().trace(|ctx: &mut TraceContext| {
926                if ctx.instance_index() != inst_id {
927                    return;
928                }
929                for event in &all_events {
930                    let timestamp = event.start;
931                    let duration_ns = event.end.saturating_sub(event.start);
932                    ctx.with_incremental_state(|ctx: &mut TraceContext, inc_state| {
933                        let was_cleared = std::mem::replace(&mut inc_state.was_cleared, false);
934                        let rs_ctx = RenderStageContext {
935                            channels: &channels,
936                            contexts: &contexts,
937                            process_id,
938                            unique_kernels: &unique_kernels,
939                            unique_event_names: &unique_event_names,
940                        };
941                        emit_render_stage_event(
942                            ctx,
943                            event,
944                            timestamp,
945                            duration_ns,
946                            was_cleared,
947                            &rs_ctx,
948                            event.name_iid,
949                            event.kernel_iid,
950                            event.launch_grid,
951                            event.launch_block,
952                            &event.launch_args,
953                        );
954                    });
955                }
956                if let Some(sg) = stop_guard_opt.take() {
957                    let mut sg = Some(Some(sg));
958                    ctx.flush(move || drop(sg.take()));
959                }
960            });
961            drop(stop_guard_opt);
962
963            // Phase 3: Re-acquire lock to update channel_prev_end for flush path.
964            if is_flush {
965                if let Ok(mut state) = GLOBAL_STATE.lock() {
966                    if let Some(offsets) = state.renderstages_consumers.get_mut(&inst_id) {
967                        offsets.channel_prev_end = updated_prev_ends;
968                    }
969                }
970            }
971        });
972    }
973
974    fn flush_renderstage_events(&self) {
975        self.flush_activity_buffers();
976        let inst_ids: Vec<u32> = GLOBAL_STATE
977            .lock()
978            .map(|s| s.renderstages_consumers.keys().copied().collect())
979            .unwrap_or_default();
980        for inst_id in inst_ids {
981            self.emit_renderstage_events_for_instance(inst_id, None);
982        }
983        if let Ok(mut state) = GLOBAL_STATE.lock() {
984            state.advance_and_drain_renderstage_events();
985        }
986    }
987
988    fn flush_counter_events(&self) {
989        self.flush_activity_buffers();
990        let inst_ids: Vec<u32> = GLOBAL_STATE
991            .lock()
992            .map(|s| s.counters_consumers.keys().copied().collect())
993            .unwrap_or_default();
994        for inst_id in inst_ids {
995            self.emit_counter_events_for_instance(inst_id, None);
996        }
997        if let Ok(mut state) = GLOBAL_STATE.lock() {
998            state.advance_and_drain_counter_events();
999        }
1000    }
1001}
1002
1003// ---------------------------------------------------------------------------
1004// Compute kernel structured proto constants and helpers
1005// ---------------------------------------------------------------------------
1006
/// Interning table for compute-argument names, as `(name_iid, name)` pairs.
///
/// `emit_interned_specifications` interns every entry, in this order, as an
/// `InternedComputeArgName`. `arg_iid` resolves a name to its id by scanning
/// this table, so every name passed to `arg_iid` must appear here, or
/// `arg_iid` panics.
/// NOTE(review): the ids are assumed to be unique and stable for the lifetime
/// of a trace — confirm before renumbering or removing entries.
const COMPUTE_ARG_NAMES: &[(u64, &str)] = &[
    (1, "workgroup_size"),
    (2, "grid_size"),
    (3, "thread_count"),
    (4, "shared_mem_dynamic"),
    (5, "waves_per_multiprocessor"),
    (6, "occupancy_limit_blocks"),
    (7, "occupancy_limit_registers"),
    (8, "occupancy_limit_shared_mem"),
    (9, "occupancy_limit_warps"),
    (10, "sm__maximum_warps_per_active_cycle_pct"),
    (11, "sm__maximum_warps_avg_per_active_cycle"),
    (12, "process_name"),
    (13, "process_id"),
    (14, "registers_per_thread"),
    (15, "shared_mem_static"),
    (16, "func_cache_config"),
    (17, "device"),
    (18, "stream"),
    (19, "shared_mem_config_size"),
    (20, "shared_mem_driver"),
];
1029
1030fn arg_iid(name: &str) -> u64 {
1031    COMPUTE_ARG_NAMES
1032        .iter()
1033        .find(|(_, n)| *n == name)
1034        .unwrap_or_else(|| panic!("unknown compute arg name: {name}"))
1035        .0
1036}
1037
/// Derives a stable interning id from `name`.
///
/// Uses the djb2 string hash: seed 5381, then `h = h * 33 + byte` for each
/// byte, with wrapping arithmetic. The result is never 0, because 0 is
/// reserved: a hash that lands on 0 is remapped to 1.
fn kernel_name_iid(name: &str) -> u64 {
    let hash = name.bytes().fold(5381u64, |acc, byte| {
        acc.wrapping_mul(33).wrapping_add(u64::from(byte))
    });
    // Only a hash of exactly 0 changes here; every other value is >= 1.
    hash.max(1)
}
1051
/// Value of a per-launch compute argument attached to a render-stage event.
///
/// Each variant corresponds to one value setter on
/// `GpuRenderStageEventExtraComputeArg`, chosen by `emit_render_stage_event`.
#[derive(Debug, Clone, PartialEq)]
enum KernelArgValue {
    /// Written with `set_uint_value`.
    Uint(u64),
    /// Written with `set_double_value`.
    Double(f64),
    /// Written with `set_string_value`.
    // No visible caller constructs this variant; it is kept so string-valued
    // launch arguments can be attached without touching the emitter, and the
    // allowance silences the resulting dead-code warning.
    #[allow(dead_code)]
    Str(String),
}
1058
1059fn set_kernel_arg_uint(kernel: &mut InternedComputeKernel, name: &str, value: u64) {
1060    kernel.set_args(|arg: &mut GpuRenderStageEventExtraComputeArg| {
1061        arg.set_name_iid(arg_iid(name));
1062        arg.set_uint_value(value);
1063    });
1064}
1065
1066fn set_kernel_arg_string(kernel: &mut InternedComputeKernel, name: &str, value: &str) {
1067    kernel.set_args(|arg: &mut GpuRenderStageEventExtraComputeArg| {
1068        arg.set_name_iid(arg_iid(name));
1069        arg.set_string_value(value);
1070    });
1071}
1072
1073// ---------------------------------------------------------------------------
1074// Render stage helpers
1075// ---------------------------------------------------------------------------
1076
/// Information about a unique kernel for interning.
///
/// One value is built per distinct kernel iid in a batch. It is emitted once
/// as an `InternedComputeKernel` by `emit_interned_specifications`, and the
/// numeric and string attributes below become compute args on that kernel.
#[derive(Debug, Clone)]
struct UniqueKernel {
    /// Interning id; render-stage events for this kernel reference it as
    /// their `kernel_iid`.
    iid: u64,
    /// Kernel name as reported by the activity record; emitted as the
    /// interned kernel's `name`.
    mangled_name: String,
    /// Demangled form of `mangled_name`.
    demangled_name: String,
    /// Target architecture string; emitted as the interned kernel's `arch`.
    arch: String,
    /// Emitted as the `registers_per_thread` compute arg.
    registers_per_thread: u64,
    /// Emitted as the `shared_mem_static` compute arg.
    shared_mem_static: u64,
    /// Emitted as the `func_cache_config` compute arg.
    func_cache_config: String,
    /// Emitted as the `shared_mem_config_size` compute arg.
    shared_mem_config_size: u64,
    /// Emitted as the `process_name` compute arg.
    process_name: String,
    /// Emitted as the `process_id` compute arg.
    process_id: u64,
}
1090
1091fn emit_interned_specifications(
1092    packet: &mut TracePacket,
1093    channels: &std::collections::HashSet<(u32, u32)>,
1094    contexts: &std::collections::HashSet<u32>,
1095    process_id: i32,
1096    unique_kernels: &[UniqueKernel],
1097    unique_event_names: &[(u64, String)],
1098) {
1099    packet.set_sequence_flags(TracePacketSequenceFlags::SeqIncrementalStateCleared as u32);
1100    packet.set_interned_data(|interned: &mut InternedData| {
1101        for context_id in contexts {
1102            interned.set_graphics_contexts(|ctx: &mut InternedGraphicsContext| {
1103                ctx.set_iid(*context_id as u64);
1104                ctx.set_pid(process_id);
1105                ctx.set_api(InternedGraphicsContextApi::Cuda);
1106            });
1107        }
1108        for (channel_id, channel_type) in channels {
1109            let queue_iid = *channel_id as u64 + HW_QUEUE_IID_OFFSET;
1110            let queue_category = match *channel_type {
1111                1 => InternedGpuRenderStageSpecificationRenderStageCategory::Compute,
1112                _ => InternedGpuRenderStageSpecificationRenderStageCategory::Other,
1113            };
1114            interned.set_gpu_specifications(|spec: &mut InternedGpuRenderStageSpecification| {
1115                spec.set_iid(queue_iid);
1116                spec.set_name(format!("Channel #{}", channel_id + 1));
1117                spec.set_category(queue_category);
1118            });
1119        }
1120        interned.set_gpu_specifications(|spec: &mut InternedGpuRenderStageSpecification| {
1121            spec.set_iid(KERNEL_STAGE_IID);
1122            spec.set_name("Kernel");
1123            spec.set_category(InternedGpuRenderStageSpecificationRenderStageCategory::Compute);
1124        });
1125        interned.set_gpu_specifications(|spec: &mut InternedGpuRenderStageSpecification| {
1126            spec.set_iid(MEMCPY_STAGE_IID);
1127            spec.set_name("MemoryTransfer");
1128            spec.set_category(InternedGpuRenderStageSpecificationRenderStageCategory::Other);
1129        });
1130        interned.set_gpu_specifications(|spec: &mut InternedGpuRenderStageSpecification| {
1131            spec.set_iid(MEMSET_STAGE_IID);
1132            spec.set_name("MemorySet");
1133            spec.set_category(InternedGpuRenderStageSpecificationRenderStageCategory::Other);
1134        });
1135        for &(iid, name) in COMPUTE_ARG_NAMES {
1136            interned.set_compute_arg_names(|an: &mut InternedComputeArgName| {
1137                an.set_iid(iid);
1138                an.set_name(name);
1139            });
1140        }
1141        for (iid, name) in unique_event_names {
1142            interned.set_event_names(|en: &mut EventName| {
1143                en.set_iid(*iid);
1144                en.set_name(name);
1145            });
1146        }
1147        for kernel in unique_kernels {
1148            interned.set_compute_kernels(|ck: &mut InternedComputeKernel| {
1149                ck.set_iid(kernel.iid);
1150                ck.set_name(&kernel.mangled_name);
1151                ck.set_demangled_name(&kernel.demangled_name);
1152                ck.set_arch(&kernel.arch);
1153                set_kernel_arg_uint(ck, "registers_per_thread", kernel.registers_per_thread);
1154                set_kernel_arg_uint(ck, "shared_mem_static", kernel.shared_mem_static);
1155                set_kernel_arg_string(ck, "func_cache_config", &kernel.func_cache_config);
1156                set_kernel_arg_uint(ck, "shared_mem_config_size", kernel.shared_mem_config_size);
1157                set_kernel_arg_string(ck, "process_name", &kernel.process_name);
1158                set_kernel_arg_uint(ck, "process_id", kernel.process_id);
1159            });
1160        }
1161    });
1162}
1163
/// Hardware block prefix groups for organizing instrumented counters.
/// Each entry is (metric_name_prefix, display_name).
///
/// `emit_interned_counter_descriptors` uses an entry's index in this table as
/// its counter-group id. A group is emitted only when at least one metric name
/// starts with its prefix, and metrics that match no prefix are left
/// ungrouped.
const COUNTER_GROUPS: &[(&str, &str)] = &[
    ("dram__", "DRAM"),
    ("gpc__", "GPC"),
    ("sm__", "SM"),
    ("gpu__", "GPU"),
    ("l1tex__", "L1TEX"),
    ("lts__", "LTS"),
];
1174
1175/// Emit interned counter descriptors for all GPUs in the batch.
1176///
1177/// Each GPU gets one `InternedGpuCounterDescriptor` with iid = gpu_id + 1,
1178/// containing all counter specs with simple 0-based counter_ids. The gpu_id
1179/// on the interned descriptor handles per-GPU track separation.
1180///
1181/// Counters are organized into hardware block groups based on their metric
1182/// name prefix (e.g. `sm__` → "SM", `dram__` → "DRAM").
1183fn emit_interned_counter_descriptors(
1184    ctx: &mut TraceContext,
1185    collected_events: &[CollectedCounterEvent],
1186    gpu_ids: &HashSet<i32>,
1187) {
1188    ctx.add_packet(|packet: &mut TracePacket| {
1189        packet.set_sequence_flags(TracePacketSequenceFlags::SeqIncrementalStateCleared as u32);
1190        packet.set_interned_data(|interned: &mut InternedData| {
1191            for &gpu_id in gpu_ids {
1192                // Find the first event for this GPU to get metric names.
1193                let Some(sample) = collected_events.iter().find(|e| e.gpu_id == gpu_id) else {
1194                    continue;
1195                };
1196                interned.set_gpu_counter_descriptors(|desc: &mut InternedGpuCounterDescriptor| {
1197                    desc.set_iid(gpu_id as u64 + 1);
1198                    desc.set_gpu_id(gpu_id);
1199                    desc.set_counter_descriptor(|cd: &mut GpuCounterDescriptor| {
1200                        for (i, (metric_name, _)) in sample.metrics.iter().enumerate() {
1201                            cd.set_specs(|spec: &mut GpuCounterDescriptorGpuCounterSpec| {
1202                                spec.set_counter_id(i as u32);
1203                                spec.set_name(metric_name);
1204                                spec.set_groups(GpuCounterDescriptorGpuCounterGroup::Compute);
1205                            });
1206                        }
1207                        // Group counters by hardware block prefix.
1208                        for (group_id, &(prefix, group_name)) in COUNTER_GROUPS.iter().enumerate() {
1209                            let member_ids: Vec<u32> = sample
1210                                .metrics
1211                                .iter()
1212                                .enumerate()
1213                                .filter(|(_, (name, _))| name.starts_with(prefix))
1214                                .map(|(i, _)| i as u32)
1215                                .collect();
1216                            if !member_ids.is_empty() {
1217                                cd.set_counter_groups(
1218                                    |g: &mut GpuCounterDescriptorGpuCounterGroupSpec| {
1219                                        g.set_group_id(group_id as u32);
1220                                        g.set_name(group_name);
1221                                        for &id in &member_ids {
1222                                            g.set_counter_ids(id);
1223                                        }
1224                                    },
1225                                );
1226                            }
1227                        }
1228                        cd.set_supports_instrumented_sampling(true);
1229                        cd.set_supports_counter_names(true);
1230                        cd.set_supports_counter_name_globs(true);
1231                    });
1232                });
1233            }
1234        });
1235    });
1236}
1237
1238struct RenderStageContext<'a> {
1239    channels: &'a std::collections::HashSet<(u32, u32)>,
1240    contexts: &'a std::collections::HashSet<u32>,
1241    process_id: i32,
1242    unique_kernels: &'a [UniqueKernel],
1243    unique_event_names: &'a [(u64, String)],
1244}
1245
1246#[allow(clippy::too_many_arguments)]
1247fn emit_render_stage_event<T: GpuActivity>(
1248    ctx: &mut TraceContext,
1249    activity: &T,
1250    timestamp: u64,
1251    duration_ns: u64,
1252    emit_interned: bool,
1253    rs_ctx: &RenderStageContext,
1254    name_iid: u64,
1255    kernel_iid: Option<u64>,
1256    launch_grid: Option<(u32, u32, u32)>,
1257    launch_block: Option<(u32, u32, u32)>,
1258    launch_args: &[(u64, KernelArgValue)],
1259) {
1260    let hw_queue_iid = activity.channel_id() as u64 + HW_QUEUE_IID_OFFSET;
1261    let context_iid = activity.context_id() as u64;
1262    let gpu_id = activity.gpu_id() as i32;
1263    let stage_iid = activity.stage_iid();
1264
1265    ctx.add_packet(|packet: &mut TracePacket| {
1266        packet
1267            .set_timestamp(timestamp)
1268            .set_timestamp_clock_id(BuiltinClock::BuiltinClockBoottime.into())
1269            .set_gpu_render_stage_event(|event: &mut GpuRenderStageEvent| {
1270                event
1271                    .set_event_id(activity.correlation_id() as u64)
1272                    .set_duration(duration_ns)
1273                    .set_gpu_id(gpu_id)
1274                    .set_hw_queue_iid(hw_queue_iid)
1275                    .set_stage_iid(stage_iid)
1276                    .set_context(context_iid)
1277                    .set_name_iid(name_iid);
1278                if let Some(kiid) = kernel_iid {
1279                    event.set_kernel_iid(kiid);
1280                }
1281                if kernel_iid.is_some()
1282                    || launch_grid.is_some()
1283                    || launch_block.is_some()
1284                    || !launch_args.is_empty()
1285                {
1286                    event.set_launch(|launch: &mut GpuRenderStageEventComputeKernelLaunch| {
1287                        if let Some((gx, gy, gz)) = launch_grid {
1288                            launch.set_grid_size(|d: &mut GpuRenderStageEventDim3| {
1289                                d.set_x(gx);
1290                                d.set_y(gy);
1291                                d.set_z(gz);
1292                            });
1293                        }
1294                        if let Some((bx, by, bz)) = launch_block {
1295                            launch.set_workgroup_size(|d: &mut GpuRenderStageEventDim3| {
1296                                d.set_x(bx);
1297                                d.set_y(by);
1298                                d.set_z(bz);
1299                            });
1300                        }
1301                        for (name_iid, value) in launch_args {
1302                            launch.set_args(|arg: &mut GpuRenderStageEventExtraComputeArg| {
1303                                arg.set_name_iid(*name_iid);
1304                                match value {
1305                                    KernelArgValue::Uint(v) => {
1306                                        arg.set_uint_value(*v);
1307                                    }
1308                                    KernelArgValue::Double(v) => {
1309                                        arg.set_double_value(*v);
1310                                    }
1311                                    KernelArgValue::Str(v) => {
1312                                        arg.set_string_value(v);
1313                                    }
1314                                }
1315                            });
1316                        }
1317                    });
1318                }
1319            });
1320        if emit_interned {
1321            emit_interned_specifications(
1322                packet,
1323                rs_ctx.channels,
1324                rs_ctx.contexts,
1325                rs_ctx.process_id,
1326                rs_ctx.unique_kernels,
1327                rs_ctx.unique_event_names,
1328            );
1329        }
1330    });
1331}
1332
1333// ---------------------------------------------------------------------------
1334// API call track event emission
1335// ---------------------------------------------------------------------------
1336
1337// ---------------------------------------------------------------------------
1338// Atexit fallback
1339// ---------------------------------------------------------------------------
1340
1341extern "C" fn end_execution() {
1342    let _ = panic::catch_unwind(|| {
1343        let nvidia = CuptiBackend;
1344        nvidia.run_teardown();
1345        let counter_ids: Vec<u32> = GLOBAL_STATE
1346            .lock()
1347            .map(|s| s.counters_consumers.keys().copied().collect())
1348            .unwrap_or_default();
1349        let renderstage_ids: Vec<u32> = GLOBAL_STATE
1350            .lock()
1351            .map(|s| s.renderstages_consumers.keys().copied().collect())
1352            .unwrap_or_default();
1353        for inst_id in counter_ids {
1354            nvidia.emit_counter_events_for_instance(inst_id, None);
1355        }
1356        for inst_id in renderstage_ids {
1357            nvidia.emit_renderstage_events_for_instance(inst_id, None);
1358        }
1359    });
1360}
1361
1362unsafe extern "C" fn activity_timestamp_callback() -> u64 {
1363    perfetto_gpu_compute_injection::tracing::trace_time_ns()
1364}
1365
1366fn register_profiler_callbacks() -> Result<CUpti_SubscriberHandle, CUptiResult> {
1367    unsafe { profiler::activity_register_timestamp_callback(Some(activity_timestamp_callback)) }?;
1368
1369    let subscriber =
1370        unsafe { profiler::subscribe(Some(profiler_callback_handler), ptr::null_mut()) }?;
1371
1372    // Enable entire RUNTIME and DRIVER API domains so we can emit
1373    // track events directly from callbacks for all API calls.
1374    unsafe {
1375        profiler::enable_domain(
1376            1,
1377            subscriber,
1378            CUpti_CallbackDomain_CUPTI_CB_DOMAIN_RUNTIME_API,
1379        )
1380    }?;
1381    unsafe {
1382        profiler::enable_domain(
1383            1,
1384            subscriber,
1385            CUpti_CallbackDomain_CUPTI_CB_DOMAIN_DRIVER_API,
1386        )
1387    }?;
1388    unsafe {
1389        profiler::enable_callback(
1390            1,
1391            subscriber,
1392            CUpti_CallbackDomain_CUPTI_CB_DOMAIN_RESOURCE,
1393            CUpti_CallbackIdResource_CUPTI_CBID_RESOURCE_CONTEXT_CREATED,
1394        )
1395    }?;
1396    unsafe {
1397        profiler::enable_callback(
1398            1,
1399            subscriber,
1400            CUpti_CallbackDomain_CUPTI_CB_DOMAIN_RESOURCE,
1401            CUpti_CallbackIdResource_CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING,
1402        )
1403    }?;
1404    unsafe { profiler::enable_domain(1, subscriber, CUpti_CallbackDomain_CUPTI_CB_DOMAIN_STATE) }?;
1405    unsafe { profiler::enable_domain(1, subscriber, CUpti_CallbackDomain_CUPTI_CB_DOMAIN_NVTX) }?;
1406    unsafe {
1407        profiler::activity_register_callbacks(Some(buffer_requested), Some(buffer_completed))
1408    }?;
1409
1410    profiler::activity_enable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)?;
1411    profiler::activity_enable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_MEMCPY)?;
1412    profiler::activity_enable(CUpti_ActivityKind_CUPTI_ACTIVITY_KIND_MEMSET)?;
1413
1414    unsafe { libc::atexit(end_execution) };
1415
1416    Ok(subscriber)
1417}
1418
1419// ---------------------------------------------------------------------------
1420// Public C entry point
1421// ---------------------------------------------------------------------------
1422
1423/// NVIDIA injection entry point (called via CUDA_INJECTION64_PATH mechanism).
1424#[no_mangle]
1425pub extern "C" fn InitializeInjection() -> i32 {
1426    let result = panic::catch_unwind(|| {
1427        let mut config = Config::from_env();
1428
1429        // Log CUPTI version for debugging compatibility issues
1430        let mut cupti_version: u32 = 0;
1431        unsafe { profiler::bindings::cuptiGetVersion(&mut cupti_version) };
1432        injection_log!("CUPTI version: {}", cupti_version);
1433
1434        // Use system thread IDs (gettid on Linux) so that activity record
1435        // threadId values match Perfetto's thread track UUIDs.
1436        let _ = unsafe {
1437            cuptiSetThreadIdType(CUpti_ActivityThreadIdType_CUPTI_ACTIVITY_THREAD_ID_TYPE_SYSTEM)
1438        };
1439
1440        let metrics_str = std::env::var("INJECTION_METRICS").unwrap_or_default();
1441        config.metrics = parse_metrics(&metrics_str);
1442
1443        register_backend(CuptiBackend);
1444
1445        let producer_args = ProducerInitArgsBuilder::new().backends(Backends::SYSTEM);
1446        Producer::init(producer_args.build());
1447        TrackEvent::init();
1448        let _ = get_counters_data_source();
1449        let _ = get_renderstages_data_source();
1450        let _ = get_track_event_data_source();
1451
1452        if let Ok(mut state) = GLOBAL_STATE.lock() {
1453            if !state.injection_initialized {
1454                state.injection_initialized = true;
1455                state.config = config;
1456
1457                match register_profiler_callbacks() {
1458                    Ok(subscriber) => {
1459                        state.subscriber_handle = subscriber;
1460                    }
1461                    Err(e) => {
1462                        injection_log!("Failed to register callbacks: {:?}", e);
1463                        return 0;
1464                    }
1465                }
1466            }
1467        }
1468        1
1469    });
1470    result.unwrap_or(0)
1471}