Skip to main content

perfetto_hip_injection/
lib.rs

1// Copyright (C) 2026 David Reveman.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod callbacks;
16mod metrics;
17pub mod rocprofiler_sys;
18mod state;
19
20use perfetto_gpu_compute_injection::injection_log;
21use perfetto_gpu_compute_injection::tracing::{
22    get_counters_data_source, get_next_event_id, get_renderstages_data_source, register_backend,
23    GpuBackend, GOT_FIRST_RENDERSTAGES,
24};
25use perfetto_sdk::{
26    data_source::{StopGuard, TraceContext},
27    protos::{
28        common::builtin_clock::BuiltinClock,
29        trace::{
30            interned_data::interned_data::InternedData,
31            trace_packet::{TracePacket, TracePacketSequenceFlags},
32        },
33    },
34    track_event::TrackEvent,
35    track_event_categories,
36};
37use perfetto_sdk_protos_gpu::protos::{
38    common::gpu_counter_descriptor::{
39        GpuCounterDescriptor, GpuCounterDescriptorGpuCounterGroup,
40        GpuCounterDescriptorGpuCounterSpec,
41    },
42    trace::{
43        gpu::{
44            gpu_counter_event::{GpuCounterEvent, GpuCounterEventGpuCounter},
45            gpu_render_stage_event::{
46                GpuRenderStageEvent, GpuRenderStageEventExtraData,
47                InternedGpuRenderStageSpecification,
48                InternedGpuRenderStageSpecificationRenderStageCategory, InternedGraphicsContext,
49                InternedGraphicsContextApi,
50            },
51        },
52        interned_data::interned_data::prelude::*,
53        trace_packet::prelude::*,
54    },
55};
56use rocprofiler_sys::*;
57use state::{ConsumerStartOffsets, CounterConsumerStartOffsets, GLOBAL_STATE};
58use std::{collections::HashSet, panic, sync::atomic::Ordering};
59
60// ---------------------------------------------------------------------------
61// Track event categories for HIP API call tracing
62// ---------------------------------------------------------------------------
63
64track_event_categories! {
65    pub mod hip_te_ns {
66        ( "hip", "HIP Runtime API calls", [ "api" ] ),
67    }
68}
69use hip_te_ns as perfetto_te_ns;
70
/// Interned stage-specification IID for HIP kernel dispatches.
const AMD_KERNEL_STAGE_IID: u64 = 1;
/// Interned stage-specification IID for HIP memory copies.
const AMD_MEMCPY_STAGE_IID: u64 = AMD_KERNEL_STAGE_IID + 1;
/// Interned stage-specification IID for HIP memory sets.
const AMD_MEMSET_STAGE_IID: u64 = AMD_MEMCPY_STAGE_IID + 1;
/// Base of the hardware-queue IID range. Queue and stage specifications are both
/// interned as `gpu_specifications` entries and so share one IID space; queue
/// IIDs (`(queue_handle & 0xFFFF) + AMD_HW_QUEUE_IID_OFFSET` for kernel queues,
/// the offset itself for memcpy/memset events) start here to stay clear of the
/// stage IIDs above.
const AMD_HW_QUEUE_IID_OFFSET: u64 = 1000;
77
78// ---------------------------------------------------------------------------
79// RocprofilerBackend implementation
80// ---------------------------------------------------------------------------
81
/// Stateless `GpuBackend` implementation backed by rocprofiler; all mutable
/// data it works with lives in `GLOBAL_STATE`.
struct RocprofilerBackend;
83
84impl GpuBackend for RocprofilerBackend {
85    fn default_data_source_suffix(&self) -> &'static str {
86        "amd"
87    }
88
89    fn on_first_consumer_start(&self) {
90        // Start the rocprofiler tracing context so buffer records flow.
91        let context_handle = GLOBAL_STATE.lock().ok().and_then(|s| s.tracing_context);
92        if let Some(handle) = context_handle {
93            let ctx = rocprofiler_context_id_t { handle };
94            let status = unsafe { rocprofiler_start_context(ctx) };
95            if status != ROCPROFILER_STATUS_SUCCESS {
96                injection_log!("rocprofiler_start_context failed: {}", status);
97            }
98        }
99    }
100
101    fn on_renderstages_start_no_counters(&self) {
102        // No-op for RocProfiler: context is already started by on_first_consumer_start.
103    }
104
105    fn register_renderstages_consumer(&self, inst_id: u32) {
106        if let Ok(mut state) = GLOBAL_STATE.lock() {
107            let offsets = ConsumerStartOffsets::snapshot(&state);
108            state.renderstages_consumers.insert(inst_id, offsets);
109        }
110    }
111
112    fn run_teardown(&self) {
113        // For AMD: stop the tracing context and flush the buffer.
114        let (context_handle, buffer_handle) = {
115            let state = match GLOBAL_STATE.lock() {
116                Ok(s) => s,
117                Err(_) => return,
118            };
119            (state.tracing_context, state.tracing_buffer)
120        };
121
122        if let Some(handle) = context_handle {
123            let ctx = rocprofiler_context_id_t { handle };
124            let _ = unsafe { rocprofiler_stop_context(ctx) };
125        }
126
127        if let Some(handle) = buffer_handle {
128            let buf = rocprofiler_buffer_id_t { handle };
129            let _ = unsafe { rocprofiler_flush_buffer(buf) };
130        }
131    }
132
133    fn flush_activity_buffers(&self) {
134        let buffer_handle = GLOBAL_STATE.lock().ok().and_then(|s| s.tracing_buffer);
135        if let Some(handle) = buffer_handle {
136            let buf = rocprofiler_buffer_id_t { handle };
137            let _ = unsafe { rocprofiler_flush_buffer(buf) };
138        }
139    }
140
141    fn emit_renderstage_events_for_instance(&self, inst_id: u32, stop_guard: Option<StopGuard>) {
142        let _ = panic::catch_unwind(|| {
143            let (process_id, process_name) =
144                perfetto_gpu_compute_injection::config::get_process_info();
145
146            // Phase 1: Collect all event data under GLOBAL_STATE lock, then release.
147            struct PendingRenderStageEvent {
148                start_ns: u64,
149                end_ns: u64,
150                gpu_id: i32,
151                hw_queue_iid: u64,
152                stage_iid: u64,
153                name: String,
154                extra_fields: Vec<(String, String)>,
155            }
156
157            let (events, queues) = {
158                let mut state = match GLOBAL_STATE.lock() {
159                    Ok(s) => s,
160                    Err(_) => return,
161                };
162                let start_offsets = if stop_guard.is_some() {
163                    match state.renderstages_consumers.remove(&inst_id) {
164                        Some(o) => o,
165                        None => return,
166                    }
167                } else {
168                    match state.renderstages_consumers.get(&inst_id).cloned() {
169                        Some(o) => o,
170                        None => return,
171                    }
172                };
173                let kd_start = start_offsets.kernel_dispatches;
174                let mc_start = start_offsets.memcopies;
175                let ms_start = start_offsets.memsets;
176
177                // Collect unique queue handles for interned specs.
178                let mut queues: std::collections::HashSet<u64> = std::collections::HashSet::new();
179                for kd in state.kernel_dispatches[kd_start..].iter() {
180                    queues.insert(kd.queue_handle);
181                }
182
183                let mut events: Vec<PendingRenderStageEvent> = Vec::new();
184
185                // Kernel dispatch events.
186                for kd in state.kernel_dispatches[kd_start..].iter() {
187                    let demangled =
188                        perfetto_gpu_compute_injection::kernel::demangle_name(&kd.kernel_name);
189                    let grid_size = kd.grid.0 * kd.grid.1 * kd.grid.2;
190                    let workgroup_size = kd.workgroup.0 * kd.workgroup.1 * kd.workgroup.2;
191                    let thread_count = grid_size * workgroup_size;
192                    let total_waves = if kd.wave_front_size > 0 {
193                        (grid_size * workgroup_size).div_ceil(kd.wave_front_size)
194                    } else {
195                        0
196                    };
197                    let waves_per_cu = if kd.cu_count > 0 {
198                        total_waves as f64 / kd.cu_count as f64
199                    } else {
200                        0.0
201                    };
202                    let hw_queue_iid = (kd.queue_handle & 0xFFFF) + AMD_HW_QUEUE_IID_OFFSET;
203                    let extra_fields: Vec<(String, String)> = vec![
204                        ("kernel_name".to_string(), kd.kernel_name.clone()),
205                        ("kernel_demangled_name".to_string(), demangled.clone()),
206                        ("kernel_type".to_string(), "Compute".to_string()),
207                        ("process_id".to_string(), process_id.to_string()),
208                        ("process_name".to_string(), process_name.clone()),
209                        ("device_id".to_string(), kd.device_index.to_string()),
210                        ("arch".to_string(), kd.arch.clone()),
211                        ("queue_id".to_string(), kd.queue_handle.to_string()),
212                        ("launch__grid_size".to_string(), grid_size.to_string()),
213                        ("launch__grid_size_x".to_string(), kd.grid.0.to_string()),
214                        ("launch__grid_size_y".to_string(), kd.grid.1.to_string()),
215                        ("launch__grid_size_z".to_string(), kd.grid.2.to_string()),
216                        ("launch__block_size".to_string(), workgroup_size.to_string()),
217                        (
218                            "launch__block_size_x".to_string(),
219                            kd.workgroup.0.to_string(),
220                        ),
221                        (
222                            "launch__block_size_y".to_string(),
223                            kd.workgroup.1.to_string(),
224                        ),
225                        (
226                            "launch__block_size_z".to_string(),
227                            kd.workgroup.2.to_string(),
228                        ),
229                        ("launch__thread_count".to_string(), thread_count.to_string()),
230                        (
231                            "launch__waves_per_multiprocessor".to_string(),
232                            format!("{:.2}", waves_per_cu),
233                        ),
234                    ];
235                    events.push(PendingRenderStageEvent {
236                        start_ns: kd.start_ns,
237                        end_ns: kd.end_ns,
238                        gpu_id: kd.device_index,
239                        hw_queue_iid,
240                        stage_iid: AMD_KERNEL_STAGE_IID,
241                        name: perfetto_gpu_compute_injection::kernel::simplify_name(&demangled)
242                            .to_string(),
243                        extra_fields,
244                    });
245                }
246
247                // Memory copy events.
248                for mc in state.memcopies[mc_start..].iter() {
249                    let memcpy_name = match mc.direction {
250                        1 => "Memcpy HtoH",
251                        2 => "Memcpy HtoD",
252                        3 => "Memcpy DtoH",
253                        4 => "Memcpy DtoD",
254                        _ => "Memcpy",
255                    };
256                    let extra_fields: Vec<(String, String)> = vec![
257                        ("process_id".to_string(), process_id.to_string()),
258                        ("process_name".to_string(), process_name.clone()),
259                        ("device_id".to_string(), mc.device_index.to_string()),
260                        ("size_bytes".to_string(), mc.bytes.to_string()),
261                        ("direction".to_string(), mc.direction.to_string()),
262                    ];
263                    events.push(PendingRenderStageEvent {
264                        start_ns: mc.start_ns,
265                        end_ns: mc.end_ns,
266                        gpu_id: mc.device_index,
267                        hw_queue_iid: AMD_HW_QUEUE_IID_OFFSET,
268                        stage_iid: AMD_MEMCPY_STAGE_IID,
269                        name: memcpy_name.to_string(),
270                        extra_fields,
271                    });
272                }
273
274                // Memory set events.
275                for ms in state.memsets[ms_start..].iter() {
276                    let extra_fields: Vec<(String, String)> = vec![
277                        ("process_id".to_string(), process_id.to_string()),
278                        ("process_name".to_string(), process_name.clone()),
279                        ("device_id".to_string(), ms.device_index.to_string()),
280                    ];
281                    events.push(PendingRenderStageEvent {
282                        start_ns: ms.start_ns,
283                        end_ns: ms.end_ns,
284                        gpu_id: ms.device_index,
285                        hw_queue_iid: AMD_HW_QUEUE_IID_OFFSET,
286                        stage_iid: AMD_MEMSET_STAGE_IID,
287                        name: "Memset".to_string(),
288                        extra_fields,
289                    });
290                }
291
292                let emitted = events.len();
293                injection_log!(
294                    "emitted {} AMD render stage events (instance {})",
295                    emitted,
296                    inst_id
297                );
298
299                (events, queues)
300                // state (GLOBAL_STATE lock) dropped here
301            };
302
303            // Phase 2: Emit collected events without holding GLOBAL_STATE.
304            // This prevents deadlock with buffer_callback which also needs GLOBAL_STATE.
305            let mut stop_guard_opt = stop_guard;
306            get_renderstages_data_source().trace(|ctx: &mut TraceContext| {
307                if ctx.instance_index() != inst_id {
308                    return;
309                }
310
311                ctx.with_incremental_state(|ctx: &mut TraceContext, inc_state| {
312                    let was_cleared =
313                        std::mem::replace(&mut inc_state.was_cleared, false);
314                    let got_first =
315                        GOT_FIRST_RENDERSTAGES.fetch_or(1 << inst_id, Ordering::SeqCst);
316                    let emit_interned =
317                        was_cleared || got_first & (1 << inst_id) == 0;
318
319                    if emit_interned {
320                        ctx.add_packet(|packet: &mut TracePacket| {
321                            packet.set_sequence_flags(
322                                TracePacketSequenceFlags::SeqIncrementalStateCleared as u32,
323                            );
324                            packet.set_interned_data(|interned: &mut InternedData| {
325                                interned.set_graphics_contexts(
326                                    |gctx: &mut InternedGraphicsContext| {
327                                        gctx.set_iid(1);
328                                        gctx.set_pid(process_id);
329                                        gctx.set_api(InternedGraphicsContextApi::Hip);
330                                    },
331                                );
332                                for &queue_handle in &queues {
333                                    let iid =
334                                        (queue_handle & 0xFFFF) + AMD_HW_QUEUE_IID_OFFSET;
335                                    interned.set_gpu_specifications(
336                                        |spec: &mut InternedGpuRenderStageSpecification| {
337                                            spec.set_iid(iid);
338                                            spec.set_name(format!(
339                                                "Queue ({})",
340                                                queue_handle
341                                            ));
342                                            spec.set_category(
343                                                InternedGpuRenderStageSpecificationRenderStageCategory::Compute,
344                                            );
345                                        },
346                                    );
347                                }
348                                interned.set_gpu_specifications(
349                                    |spec: &mut InternedGpuRenderStageSpecification| {
350                                        spec.set_iid(AMD_KERNEL_STAGE_IID);
351                                        spec.set_name("Kernel");
352                                        spec.set_description("HIP Kernel");
353                                        spec.set_category(
354                                            InternedGpuRenderStageSpecificationRenderStageCategory::Compute,
355                                        );
356                                    },
357                                );
358                                interned.set_gpu_specifications(
359                                    |spec: &mut InternedGpuRenderStageSpecification| {
360                                        spec.set_iid(AMD_MEMCPY_STAGE_IID);
361                                        spec.set_name("MemoryTransfer");
362                                        spec.set_description("HIP Memory Transfer");
363                                        spec.set_category(
364                                            InternedGpuRenderStageSpecificationRenderStageCategory::Other,
365                                        );
366                                    },
367                                );
368                                interned.set_gpu_specifications(
369                                    |spec: &mut InternedGpuRenderStageSpecification| {
370                                        spec.set_iid(AMD_MEMSET_STAGE_IID);
371                                        spec.set_name("MemorySet");
372                                        spec.set_description("HIP Memory Set");
373                                        spec.set_category(
374                                            InternedGpuRenderStageSpecificationRenderStageCategory::Other,
375                                        );
376                                    },
377                                );
378                            });
379                        });
380                    }
381
382                    for event in &events {
383                        let duration_ns = event.end_ns.saturating_sub(event.start_ns);
384                        ctx.add_packet(|packet: &mut TracePacket| {
385                            packet
386                                .set_timestamp(event.start_ns)
387                                .set_timestamp_clock_id(
388                                    BuiltinClock::BuiltinClockBoottime.into(),
389                                )
390                                .set_gpu_render_stage_event(
391                                    |re: &mut GpuRenderStageEvent| {
392                                        re.set_event_id(get_next_event_id())
393                                            .set_duration(duration_ns)
394                                            .set_gpu_id(event.gpu_id)
395                                            .set_hw_queue_iid(event.hw_queue_iid)
396                                            .set_stage_iid(event.stage_iid)
397                                            .set_context(1)
398                                            .set_name(&event.name);
399                                        for (name, value) in &event.extra_fields {
400                                            re.set_extra_data(
401                                                |ed: &mut GpuRenderStageEventExtraData| {
402                                                    ed.set_name(name);
403                                                    ed.set_value(value);
404                                                },
405                                            );
406                                        }
407                                    },
408                                );
409                        });
410                    }
411                });
412
413                let mut sg = Some(stop_guard_opt.take());
414                ctx.flush(move || drop(sg.take()));
415            });
416            drop(stop_guard_opt);
417        });
418    }
419
420    fn on_first_counters_start(&self) {
421        // Set up per-agent counter configs and configure the dispatch counting
422        // service. This is deferred to first consumer start so we don't
423        // enumerate counters when only gpu.renderstages is enabled.
424        use callbacks::{dispatch_counting_callback, record_counting_callback};
425
426        let requested_metrics: Vec<String> = GLOBAL_STATE
427            .lock()
428            .ok()
429            .map(|s| s.config.metrics.clone())
430            .unwrap_or_default();
431        let agent_handles: Vec<u64> = GLOBAL_STATE
432            .lock()
433            .ok()
434            .map(|s| s.agents.keys().copied().collect())
435            .unwrap_or_default();
436
437        for &agent_handle in &agent_handles {
438            let agent_id = rocprofiler_agent_id_t {
439                handle: agent_handle,
440            };
441
442            // Enumerate available counters for this agent.
443            let mut available_counter_ids: Vec<rocprofiler_counter_id_t> = Vec::new();
444            unsafe extern "C" fn counters_cb(
445                _agent_id: rocprofiler_agent_id_t,
446                counters: *mut rocprofiler_counter_id_t,
447                num_counters: usize,
448                user_data: *mut std::os::raw::c_void,
449            ) -> rocprofiler_status_t {
450                if !counters.is_null() && num_counters > 0 {
451                    let out = &mut *(user_data as *mut Vec<rocprofiler_counter_id_t>);
452                    let slice = std::slice::from_raw_parts(counters, num_counters);
453                    out.extend_from_slice(slice);
454                }
455                ROCPROFILER_STATUS_SUCCESS
456            }
457            let status = unsafe {
458                rocprofiler_iterate_agent_supported_counters(
459                    agent_id,
460                    Some(counters_cb),
461                    &mut available_counter_ids as *mut Vec<rocprofiler_counter_id_t>
462                        as *mut std::os::raw::c_void,
463                )
464            };
465            if status != ROCPROFILER_STATUS_SUCCESS {
466                injection_log!(
467                    "agent {:#x}: rocprofiler_iterate_agent_supported_counters failed: {}",
468                    agent_handle,
469                    status
470                );
471            }
472
473            // Build name → counter_id map from available counters.
474            let mut name_to_id: std::collections::HashMap<String, rocprofiler_counter_id_t> =
475                std::collections::HashMap::new();
476            for &cid in &available_counter_ids {
477                let mut info = std::mem::MaybeUninit::<rocprofiler_counter_info_v0_t>::zeroed();
478                let status = unsafe {
479                    rocprofiler_query_counter_info(
480                        cid,
481                        ROCPROFILER_COUNTER_INFO_VERSION_0,
482                        info.as_mut_ptr() as *mut std::os::raw::c_void,
483                    )
484                };
485                if status != ROCPROFILER_STATUS_SUCCESS {
486                    continue;
487                }
488                let info = unsafe { info.assume_init() };
489                if info.name.is_null() {
490                    continue;
491                }
492                let name = unsafe { std::ffi::CStr::from_ptr(info.name) }
493                    .to_string_lossy()
494                    .into_owned();
495                name_to_id.insert(name, cid);
496            }
497
498            injection_log!(
499                "agent {:#x}: {} counters available",
500                agent_handle,
501                name_to_id.len()
502            );
503
504            // Match requested metrics against available counters.
505            let mut matched_ids: Vec<rocprofiler_counter_id_t> = Vec::new();
506            let mut matched_names: Vec<String> = Vec::new();
507            for metric in &requested_metrics {
508                if let Some(&cid) = name_to_id.get(metric) {
509                    matched_ids.push(cid);
510                    matched_names.push(metric.clone());
511                } else {
512                    injection_log!(
513                        "agent {:#x}: requested counter '{}' not available",
514                        agent_handle,
515                        metric
516                    );
517                }
518            }
519
520            if matched_ids.is_empty() {
521                injection_log!(
522                    "agent {:#x}: no matching counters found, skipping counter config",
523                    agent_handle
524                );
525                continue;
526            }
527
528            injection_log!(
529                "agent {:#x}: configuring {} counters: {:?}",
530                agent_handle,
531                matched_names.len(),
532                matched_names
533            );
534
535            // Create counter config for this agent.
536            let mut config_id = rocprofiler_counter_config_id_t { handle: 0 };
537            let status = unsafe {
538                rocprofiler_create_counter_config(
539                    agent_id,
540                    matched_ids.as_mut_ptr(),
541                    matched_ids.len(),
542                    &mut config_id,
543                )
544            };
545            if status != ROCPROFILER_STATUS_SUCCESS {
546                injection_log!(
547                    "agent {:#x}: rocprofiler_create_counter_config failed: {}",
548                    agent_handle,
549                    status
550                );
551                continue;
552            }
553
554            if let Ok(mut state) = GLOBAL_STATE.lock() {
555                state.counter_configs.insert(agent_handle, config_id.handle);
556                // Build counter_id_to_index mapping (only needed once, same for all agents).
557                if state.counter_names.is_empty() {
558                    state.counter_names = matched_names;
559                    for (idx, cid) in matched_ids.iter().enumerate() {
560                        state.counter_id_to_index.insert(cid.handle, idx);
561                    }
562                }
563            }
564        }
565
566        // Configure callback dispatch counting service on the tracing context.
567        // The context may already be started (by on_first_consumer_start), so we
568        // must stop it first — rocprofiler requires service configuration before
569        // context start.
570        let has_counter_configs = GLOBAL_STATE
571            .lock()
572            .ok()
573            .map(|s| !s.counter_configs.is_empty())
574            .unwrap_or(false);
575        if has_counter_configs {
576            let context_handle = GLOBAL_STATE.lock().ok().and_then(|s| s.tracing_context);
577            if let Some(handle) = context_handle {
578                let tracing_ctx = rocprofiler_context_id_t { handle };
579                // Stop context so we can add the counting service.
580                let _ = unsafe { rocprofiler_stop_context(tracing_ctx) };
581                let status = unsafe {
582                    rocprofiler_configure_callback_dispatch_counting_service(
583                        tracing_ctx,
584                        Some(dispatch_counting_callback),
585                        std::ptr::null_mut(),
586                        Some(record_counting_callback),
587                        std::ptr::null_mut(),
588                    )
589                };
590                if status != ROCPROFILER_STATUS_SUCCESS {
591                    injection_log!(
592                        "rocprofiler_configure_callback_dispatch_counting_service failed: {}",
593                        status
594                    );
595                }
596                // Restart context with the counting service now configured.
597                let status = unsafe { rocprofiler_start_context(tracing_ctx) };
598                if status != ROCPROFILER_STATUS_SUCCESS {
599                    injection_log!(
600                        "rocprofiler_start_context after counter config failed: {}",
601                        status
602                    );
603                }
604            }
605        }
606    }
607
608    fn register_counters_consumer(&self, inst_id: u32) {
609        if let Ok(mut state) = GLOBAL_STATE.lock() {
610            let offsets = CounterConsumerStartOffsets {
611                counter_results: state.counter_results.len(),
612            };
613            state.counters_consumers.insert(inst_id, offsets);
614        }
615    }
616
617    fn emit_counter_events_for_instance(&self, inst_id: u32, stop_guard: Option<StopGuard>) {
618        let _ = panic::catch_unwind(|| {
619            // Collected counter event data.
620            struct CollectedCounterEvent {
621                start_ns: u64,
622                end_ns: u64,
623                device_index: i32,
624                values: Vec<f64>,
625            }
626
627            // Phase 1: Collect data under GLOBAL_STATE lock, then release.
628            let (collected_events, counter_names) = {
629                let mut state = match GLOBAL_STATE.lock() {
630                    Ok(s) => s,
631                    Err(_) => return,
632                };
633                let start_offsets = if stop_guard.is_some() {
634                    match state.counters_consumers.remove(&inst_id) {
635                        Some(o) => o,
636                        None => return,
637                    }
638                } else {
639                    match state.counters_consumers.get(&inst_id).cloned() {
640                        Some(o) => o,
641                        None => return,
642                    }
643                };
644                let counter_names = state.counter_names.clone();
645                if counter_names.is_empty() {
646                    return;
647                }
648                let cr_start = start_offsets.counter_results;
649                let events: Vec<CollectedCounterEvent> = state.counter_results[cr_start..]
650                    .iter()
651                    .map(|r| CollectedCounterEvent {
652                        start_ns: r.start_ns,
653                        end_ns: r.end_ns,
654                        device_index: r.device_index,
655                        values: r.values.clone(),
656                    })
657                    .collect();
658                let emitted = events.len();
659                injection_log!("emitted {} counter events (instance {})", emitted, inst_id);
660                (events, counter_names)
661                // state (GLOBAL_STATE lock) dropped here
662            };
663
664            // Phase 2: Emit collected events without holding GLOBAL_STATE.
665            // This prevents deadlock with buffer_callback which also needs GLOBAL_STATE.
666            //
667            // Counter ID offset and multi-GPU workaround (mode 1 / legacy inline
668            // descriptors):
669            //
670            // Counter IDs use a large offset (4096) to avoid collision with the
671            // gpu-probes crate's counters (Temperature, Power, Utilization, etc.)
672            // which use small counter_id values starting at 1.
673            //
674            // For multi-GPU support we assign globally unique counter_ids per GPU:
675            //   counter_id = COUNTER_ID_OFFSET + gpu_id * num_metrics + metric_index
676            //
677            // This is needed because the trace processor's mode 1 (inline
678            // counter_descriptor) keys its gpu_counter_state_ map by counter_id
679            // alone — there is no gpu_id in the key. If two GPUs share the same
680            // counter_id, the second GPU's descriptor overwrites the first and all
681            // samples land on one GPU's track.
682            //
683            // Each GPU's first GpuCounterEvent includes an inline
684            // counter_descriptor so the trace processor creates separate tracks
685            // per GPU (the descriptor's gpu_id associates tracks with the right
686            // GPU in the hierarchy).
687            //
688            // TODO: Switch to mode 2 (interned descriptors via
689            // counter_descriptor_iid and InternedGpuCounterDescriptor) once the
690            // gpu_counter_descriptors field is available in perfetto-sdk-protos-gpu's
691            // InternedData extension. At that point, counter_ids can be simplified
692            // and the per-GPU spacing logic can be removed.
693            const COUNTER_ID_OFFSET: u32 = 4096;
694            let num_metrics = counter_names.len() as u32;
695            let mut gpus_needing_descriptors: HashSet<i32> =
696                collected_events.iter().map(|e| e.device_index).collect();
697            let mut stop_guard_opt = stop_guard;
698            get_counters_data_source().trace(|ctx: &mut TraceContext| {
699                if ctx.instance_index() != inst_id {
700                    return;
701                }
702                for result in &collected_events {
703                    let gpu_id = result.device_index;
704                    let emit_desc = gpus_needing_descriptors.remove(&gpu_id);
705                    if emit_desc {
706                        ctx.add_packet(|packet: &mut TracePacket| {
707                            packet
708                                .set_timestamp(result.start_ns)
709                                .set_timestamp_clock_id(
710                                    BuiltinClock::BuiltinClockBoottime.into(),
711                                )
712                                .set_gpu_counter_event(|event: &mut GpuCounterEvent| {
713                                    event.set_gpu_id(gpu_id).set_counter_descriptor(
714                                        |desc: &mut GpuCounterDescriptor| {
715                                            for (i, name) in counter_names.iter().enumerate() {
716                                                desc.set_specs(|spec: &mut GpuCounterDescriptorGpuCounterSpec| {
717                                                    spec.set_counter_id(
718                                                        COUNTER_ID_OFFSET + gpu_id as u32 * num_metrics + i as u32,
719                                                    );
720                                                    spec.set_name(name);
721                                                    spec.set_groups(GpuCounterDescriptorGpuCounterGroup::Compute);
722                                                });
723                                            }
724                                        },
725                                    );
726                                });
727                        });
728                    }
729                    ctx.add_packet(|packet: &mut TracePacket| {
730                        packet
731                            .set_timestamp(result.start_ns)
732                            .set_timestamp_clock_id(BuiltinClock::BuiltinClockBoottime.into())
733                            .set_gpu_counter_event(|event: &mut GpuCounterEvent| {
734                                event.set_gpu_id(gpu_id);
735                                for i in 0..counter_names.len() {
736                                    event.set_counters(
737                                        |counter: &mut GpuCounterEventGpuCounter| {
738                                            counter
739                                                .set_counter_id(
740                                                    COUNTER_ID_OFFSET + gpu_id as u32 * num_metrics + i as u32,
741                                                )
742                                                .set_int_value(0);
743                                        },
744                                    );
745                                }
746                            });
747                    });
748                    ctx.add_packet(|packet: &mut TracePacket| {
749                        packet
750                            .set_timestamp(result.end_ns)
751                            .set_timestamp_clock_id(BuiltinClock::BuiltinClockBoottime.into())
752                            .set_gpu_counter_event(|event: &mut GpuCounterEvent| {
753                                event.set_gpu_id(gpu_id);
754                                for (i, &value) in result.values.iter().enumerate() {
755                                    event.set_counters(
756                                        |counter: &mut GpuCounterEventGpuCounter| {
757                                            counter
758                                                .set_counter_id(
759                                                    COUNTER_ID_OFFSET + gpu_id as u32 * num_metrics + i as u32,
760                                                )
761                                                .set_double_value(value);
762                                        },
763                                    );
764                                }
765                            });
766                    });
767                }
768                let mut sg = Some(stop_guard_opt.take());
769                ctx.flush(move || drop(sg.take()));
770            });
771            drop(stop_guard_opt);
772        });
773    }
774
775    fn flush_renderstage_events(&self) {
776        // Force rocprofiler to deliver buffered records so that
777        // kernel_dispatches / memcopies / memsets vectors are up to date.
778        self.flush_activity_buffers();
779        let inst_ids: Vec<u32> = GLOBAL_STATE
780            .lock()
781            .map(|s| s.renderstages_consumers.keys().copied().collect())
782            .unwrap_or_default();
783        for inst_id in inst_ids {
784            self.emit_renderstage_events_for_instance(inst_id, None);
785        }
786        if let Ok(mut state) = GLOBAL_STATE.lock() {
787            state.advance_and_drain_renderstage_events();
788        }
789    }
790
791    fn flush_counter_events(&self) {
792        // Force rocprofiler to deliver buffered records so that
793        // kernel_dispatches are up to date for counter emission.
794        self.flush_activity_buffers();
795        let inst_ids: Vec<u32> = GLOBAL_STATE
796            .lock()
797            .map(|s| s.counters_consumers.keys().copied().collect())
798            .unwrap_or_default();
799        for inst_id in inst_ids {
800            self.emit_counter_events_for_instance(inst_id, None);
801        }
802        if let Ok(mut state) = GLOBAL_STATE.lock() {
803            state.advance_and_drain_counter_events();
804        }
805    }
806}
807
808// ---------------------------------------------------------------------------
809// HIP API call track event emission
810// ---------------------------------------------------------------------------
811
812// ---------------------------------------------------------------------------
813// Atexit fallback
814// ---------------------------------------------------------------------------
815
816extern "C" fn end_execution() {
817    let _ = panic::catch_unwind(|| {
818        let amd = RocprofilerBackend;
819        amd.run_teardown();
820        let (renderstage_ids, counter_ids): (Vec<u32>, Vec<u32>) = GLOBAL_STATE
821            .lock()
822            .map(|s| {
823                (
824                    s.renderstages_consumers.keys().copied().collect(),
825                    s.counters_consumers.keys().copied().collect(),
826                )
827            })
828            .unwrap_or_default();
829        for inst_id in renderstage_ids {
830            amd.emit_renderstage_events_for_instance(inst_id, None);
831        }
832        for inst_id in counter_ids {
833            amd.emit_counter_events_for_instance(inst_id, None);
834        }
835    });
836}
837
838// ---------------------------------------------------------------------------
839// AMD rocprofiler initialization helpers
840// ---------------------------------------------------------------------------
841
842/// Initialize AMD: populate agent map, create contexts and buffers.
843/// Called from `tool_initialize` during `rocprofiler_configure`.
844fn initialize_rocprofiler() -> rocprofiler_status_t {
845    use callbacks::{agents_callback, buffer_callback, code_object_callback};
846
847    // Enumerate GPU agents.
848    unsafe {
849        rocprofiler_query_available_agents(
850            ROCPROFILER_AGENT_INFO_VERSION_0,
851            Some(agents_callback),
852            std::mem::size_of::<rocprofiler_agent_v0_t>(),
853            std::ptr::null_mut(),
854        )
855    };
856
857    // Create utility context (always-on) for code object callbacks.
858    let mut utility_ctx = rocprofiler_context_id_t { handle: 0 };
859    let status = unsafe { rocprofiler_create_context(&mut utility_ctx) };
860    if status != ROCPROFILER_STATUS_SUCCESS {
861        return status;
862    }
863
864    // Configure code object kernel symbol callback on utility context.
865    let status = unsafe {
866        rocprofiler_configure_callback_tracing_service(
867            utility_ctx,
868            ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT,
869            std::ptr::null_mut(),
870            0,
871            Some(code_object_callback),
872            std::ptr::null_mut(),
873        )
874    };
875    if status != ROCPROFILER_STATUS_SUCCESS {
876        return status;
877    }
878
879    // Start utility context immediately (it always runs).
880    let status = unsafe { rocprofiler_start_context(utility_ctx) };
881    if status != ROCPROFILER_STATUS_SUCCESS {
882        return status;
883    }
884
885    // Create tracing context (started/stopped when Perfetto consumers connect).
886    let mut tracing_ctx = rocprofiler_context_id_t { handle: 0 };
887    let status = unsafe { rocprofiler_create_context(&mut tracing_ctx) };
888    if status != ROCPROFILER_STATUS_SUCCESS {
889        return status;
890    }
891
892    // Create buffer for kernel dispatch and memory copy records.
893    let mut buffer_id = rocprofiler_buffer_id_t { handle: 0 };
894    let status = unsafe {
895        rocprofiler_create_buffer(
896            tracing_ctx,
897            4 * 1024 * 1024, // 4 MiB buffer
898            0,               // flush on every record
899            ROCPROFILER_BUFFER_POLICY_LOSSLESS,
900            Some(buffer_callback),
901            std::ptr::null_mut(),
902            &mut buffer_id,
903        )
904    };
905    if status != ROCPROFILER_STATUS_SUCCESS {
906        return status;
907    }
908
909    // Configure kernel dispatch buffer tracing.
910    let status = unsafe {
911        rocprofiler_configure_buffer_tracing_service(
912            tracing_ctx,
913            ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH,
914            std::ptr::null_mut(),
915            0,
916            buffer_id,
917        )
918    };
919    if status != ROCPROFILER_STATUS_SUCCESS {
920        return status;
921    }
922
923    // Configure memory copy buffer tracing.
924    let status = unsafe {
925        rocprofiler_configure_buffer_tracing_service(
926            tracing_ctx,
927            ROCPROFILER_BUFFER_TRACING_MEMORY_COPY,
928            std::ptr::null_mut(),
929            0,
930            buffer_id,
931        )
932    };
933    if status != ROCPROFILER_STATUS_SUCCESS {
934        return status;
935    }
936
937    // Configure HIP runtime API buffer tracing. Records are always collected;
938    // emission is gated on the "hip" track event category being enabled.
939    let status = unsafe {
940        rocprofiler_configure_buffer_tracing_service(
941            tracing_ctx,
942            ROCPROFILER_BUFFER_TRACING_HIP_RUNTIME_API,
943            std::ptr::null_mut(),
944            0,
945            buffer_id,
946        )
947    };
948    if status != ROCPROFILER_STATUS_SUCCESS {
949        injection_log!(
950            "rocprofiler_configure_buffer_tracing_service (HIP_RUNTIME_API) failed: {}",
951            status
952        );
953        // Non-fatal: API call tracing is optional.
954    }
955
956    // Store context and buffer handles in global state.
957    if let Ok(mut state) = GLOBAL_STATE.lock() {
958        state.utility_context = Some(utility_ctx.handle);
959        state.tracing_context = Some(tracing_ctx.handle);
960        state.tracing_buffer = Some(buffer_id.handle);
961    }
962
963    ROCPROFILER_STATUS_SUCCESS
964}
965
966// ---------------------------------------------------------------------------
967// Public C entry point
968// ---------------------------------------------------------------------------
969
970/// AMD registration entry point (called by ROCm runtime at library load).
971///
972/// Returns a static `rocprofiler_tool_configure_result_t` with `initialize`
973/// and `finalize` function pointers.
974#[no_mangle]
975#[allow(clippy::not_unsafe_ptr_arg_deref)]
976pub extern "C" fn rocprofiler_configure(
977    _version: u32,
978    _runtime_version: *const std::os::raw::c_char,
979    _priority: u32,
980    client_id: *mut rocprofiler_client_id_t,
981) -> *mut rocprofiler_tool_configure_result_t {
982    let _ = panic::catch_unwind(|| {
983        if !client_id.is_null() {
984            unsafe {
985                (*client_id).name = c"perfetto-hip-injection".as_ptr();
986            }
987        }
988    });
989
990    unsafe extern "C" fn tool_initialize(
991        _finalize_func: rocprofiler_client_finalize_t,
992        _tool_data: *mut std::os::raw::c_void,
993    ) -> i32 {
994        let result = panic::catch_unwind(|| {
995            use perfetto_gpu_compute_injection::config::Config;
996            use perfetto_sdk::producer::{Backends, Producer, ProducerInitArgsBuilder};
997
998            let config = Config::from_env();
999
1000            // Parse metrics from environment.
1001            let metrics_str = std::env::var("INJECTION_METRICS").unwrap_or_default();
1002            let metrics = metrics::parse_metrics(&metrics_str);
1003
1004            if let Ok(mut state) = crate::state::GLOBAL_STATE.lock() {
1005                if !state.initialized {
1006                    state.initialized = true;
1007                    state.config = config;
1008                    state.config.metrics = metrics;
1009                }
1010            }
1011
1012            // Set up rocprofiler contexts, buffers, and services before
1013            // registering Perfetto data sources. Data source on_start
1014            // callbacks need the tracing context to be available.
1015            let status = initialize_rocprofiler();
1016            if status != ROCPROFILER_STATUS_SUCCESS {
1017                injection_log!("rocprofiler initialization failed: {}", status);
1018                return -1;
1019            }
1020
1021            register_backend(RocprofilerBackend);
1022
1023            // Initialize Perfetto producer and register data sources.
1024            let producer_args = ProducerInitArgsBuilder::new().backends(Backends::SYSTEM);
1025            Producer::init(producer_args.build());
1026            let _ = get_renderstages_data_source();
1027            let _ = get_counters_data_source();
1028
1029            // Initialize track event categories for HIP API call tracing.
1030            // HIP API buffer tracing is configured in initialize_rocprofiler();
1031            // the category callback only controls whether events are emitted.
1032            TrackEvent::init();
1033            let _ = perfetto_te_ns::register();
1034
1035            unsafe { libc::atexit(end_execution) };
1036
1037            injection_log!("AMD rocprofiler tool initialized");
1038            0
1039        });
1040        result.unwrap_or(-1)
1041    }
1042
1043    unsafe extern "C" fn tool_finalize(_tool_data: *mut std::os::raw::c_void) {
1044        let _ = panic::catch_unwind(|| {
1045            injection_log!("AMD rocprofiler tool finalizing");
1046        });
1047    }
1048
1049    static mut CONFIGURE_RESULT: rocprofiler_tool_configure_result_t =
1050        rocprofiler_tool_configure_result_t {
1051            size: std::mem::size_of::<rocprofiler_tool_configure_result_t>(),
1052            initialize: Some(tool_initialize),
1053            finalize: Some(tool_finalize),
1054            tool_data: std::ptr::null_mut(),
1055        };
1056
1057    std::ptr::addr_of_mut!(CONFIGURE_RESULT)
1058}