use std::collections::HashMap;
use std::ffi::{c_char, c_void, CStr};
use std::sync::{Arc, Mutex, PoisonError};
use std::time::Instant;
/// Timing record for a single completed op invocation.
///
/// Events compare equal when all four fields match, which lets callers (and tests) use
/// `assert_eq!` or deduplicate recorded events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpEvent {
    /// Op name as received from the callback, converted lossily from the C string.
    pub op_name: String,
    /// Op index as passed to the callback.
    pub op_idx: i64,
    /// Subgraph index as passed to the callback.
    pub subgraph_idx: i64,
    /// Duration of the invocation. Begin/end pairs are measured here in microseconds;
    /// directly reported events store the caller's `elapsed_time` unchanged
    /// (presumably also microseconds — confirm against the reporting side).
    pub duration_us: u64,
}
/// C-layout profiler descriptor: an opaque `data` pointer followed by six optional callbacks.
///
/// `#[repr(C)]` fixes the field order and layout, so fields must not be reordered or retyped.
/// NOTE(review): the name and field set appear to mirror TensorFlow Lite's telemetry profiler
/// C struct — confirm the layout against the header of the TFLite version being linked.
#[repr(C)]
struct TfLiteTelemetryProfilerStruct {
/// Opaque state pointer handed back to every callback through `profiler`. In this file it
/// holds a pointer to a boxed `Arc<Mutex<ProfilerInner>>`.
data: *mut c_void,
/// Callback receiving an event name and a `u64` status. This file installs a no-op.
report_telemetry_event: Option<
unsafe extern "C" fn(
profiler: *mut TfLiteTelemetryProfilerStruct,
event_name: *const c_char,
status: u64,
),
>,
/// Callback receiving an event name, op index, subgraph index and a `u64` status.
/// This file installs a no-op.
report_telemetry_op_event: Option<
unsafe extern "C" fn(
profiler: *mut TfLiteTelemetryProfilerStruct,
event_name: *const c_char,
op_idx: i64,
subgraph_idx: i64,
status: u64,
),
>,
/// Callback receiving a setting name and an opaque settings pointer.
/// This file installs a no-op.
report_settings: Option<
unsafe extern "C" fn(
profiler: *mut TfLiteTelemetryProfilerStruct,
setting_name: *const c_char,
settings: *const c_void,
),
>,
/// Callback receiving an op name and its indices; returns a `u32` handle. The implementation
/// in this file expects that handle to be passed back to `report_end_op_invoke_event`.
report_begin_op_invoke_event: Option<
unsafe extern "C" fn(
profiler: *mut TfLiteTelemetryProfilerStruct,
op_name: *const c_char,
op_idx: i64,
subgraph_idx: i64,
) -> u32,
>,
/// Callback receiving a handle previously returned by the begin callback.
report_end_op_invoke_event: Option<
unsafe extern "C" fn(profiler: *mut TfLiteTelemetryProfilerStruct, event_handle: u32),
>,
/// Callback receiving an op name together with an already-measured elapsed time and the
/// op/subgraph indices.
report_op_invoke_event: Option<
unsafe extern "C" fn(
profiler: *mut TfLiteTelemetryProfilerStruct,
op_name: *const c_char,
elapsed_time: u64,
op_idx: i64,
subgraph_idx: i64,
),
>,
}
/// Placeholder for the `report_telemetry_event` slot of the callback table.
///
/// Every argument is ignored and nothing is recorded.
unsafe extern "C" fn report_telemetry_event_noop(
    _this: *mut TfLiteTelemetryProfilerStruct,
    _name: *const c_char,
    _status_code: u64,
) {
    // Intentionally empty: generic telemetry events are not collected.
}
/// Placeholder for the `report_telemetry_op_event` slot of the callback table.
///
/// Every argument is ignored and nothing is recorded.
unsafe extern "C" fn report_telemetry_op_event_noop(
    _this: *mut TfLiteTelemetryProfilerStruct,
    _name: *const c_char,
    _op: i64,
    _subgraph: i64,
    _status_code: u64,
) {
    // Intentionally empty: per-op telemetry status events are not collected.
}
/// Placeholder for the `report_settings` slot of the callback table.
///
/// Every argument is ignored; the settings pointer is never dereferenced.
unsafe extern "C" fn report_settings_noop(
    _this: *mut TfLiteTelemetryProfilerStruct,
    _name: *const c_char,
    _opaque_settings: *const c_void,
) {
    // Intentionally empty: settings reports are not collected.
}
/// Recovers the shared profiler state from the opaque `data` pointer of a callback table.
///
/// # Safety
///
/// `profiler` must point to a live `TfLiteTelemetryProfilerStruct` whose `data` field points
/// to a valid `Arc<Mutex<ProfilerInner>>`, and that `Arc` must stay alive for the whole of
/// the caller-chosen lifetime `'a`.
unsafe fn inner_from_profiler<'a>(
    profiler: *mut TfLiteTelemetryProfilerStruct,
) -> &'a Arc<Mutex<ProfilerInner>> {
    // SAFETY: both dereferences are covered by this function's contract.
    unsafe {
        let state_ptr = (*profiler).data.cast::<Arc<Mutex<ProfilerInner>>>();
        &*state_ptr
    }
}
/// Begin-of-op callback: records the start of an op invocation and returns a handle for it.
///
/// The op name, both indices and the current `Instant` are stored in the pending map under a
/// handle taken from a wrapping `u32` counter. `report_end_op_invoke` later turns that entry
/// into an `OpEvent`.
///
/// A null `op_name` is recorded as `"<unknown>"` instead of being handed to
/// `CStr::from_ptr`, which requires a non-null pointer.
///
/// # Safety
///
/// `profiler` must point to a live `TfLiteTelemetryProfilerStruct` whose `data` field points
/// to a valid `Arc<Mutex<ProfilerInner>>`. `op_name`, when non-null, must point to a
/// NUL-terminated C string that remains valid for the duration of this call.
unsafe extern "C" fn report_begin_op_invoke(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    op_name: *const c_char,
    op_idx: i64,
    subgraph_idx: i64,
) -> u32 {
    // Decode the name before locking so the mutex is held only for the bookkeeping.
    let name = if op_name.is_null() {
        String::from("<unknown>")
    } else {
        // SAFETY: `op_name` is non-null here and NUL-terminated per this function's contract.
        unsafe { CStr::from_ptr(op_name) }
            .to_string_lossy()
            .into_owned()
    };
    // SAFETY: the caller guarantees `profiler` and its `data` pointer are valid.
    let inner = unsafe { inner_from_profiler(profiler) };
    // A poisoned lock is recovered: losing profiling data is preferable to panicking
    // inside an `extern "C"` callback.
    let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
    let handle = guard.next_handle;
    guard.next_handle = handle.wrapping_add(1);
    // The start time is taken last so the measured span excludes the work above.
    guard
        .pending
        .insert(handle, (name, op_idx, subgraph_idx, Instant::now()));
    handle
}
/// End-of-op callback: completes the pending entry for `event_handle` and records an `OpEvent`.
///
/// The duration is the time elapsed since the matching begin callback, in microseconds,
/// saturating at `u64::MAX` instead of silently truncating. A handle with no pending entry
/// is ignored.
///
/// # Safety
///
/// `profiler` must point to a live `TfLiteTelemetryProfilerStruct` whose `data` field points
/// to a valid `Arc<Mutex<ProfilerInner>>`.
unsafe extern "C" fn report_end_op_invoke(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    event_handle: u32,
) {
    // SAFETY: the caller guarantees `profiler` and its `data` pointer are valid.
    let inner = unsafe { inner_from_profiler(profiler) };
    // A poisoned lock is recovered rather than panicking inside an `extern "C"` callback.
    let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
    let Some((op_name, op_idx, subgraph_idx, start)) = guard.pending.remove(&event_handle) else {
        // Unknown or already-completed handle: nothing to record.
        return;
    };
    // `as_micros` yields a `u128`; saturate instead of wrapping on (theoretical) overflow.
    let duration_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
    guard.events.push(OpEvent {
        op_name,
        op_idx,
        subgraph_idx,
        duration_us,
    });
}
/// Callback for ops that report their own elapsed time: records an `OpEvent` directly.
///
/// `elapsed_time` is stored unchanged as the event's `duration_us`.
/// A null `op_name` is recorded as `"<unknown>"` instead of being handed to
/// `CStr::from_ptr`, which requires a non-null pointer.
///
/// # Safety
///
/// `profiler` must point to a live `TfLiteTelemetryProfilerStruct` whose `data` field points
/// to a valid `Arc<Mutex<ProfilerInner>>`. `op_name`, when non-null, must point to a
/// NUL-terminated C string that remains valid for the duration of this call.
unsafe extern "C" fn report_op_invoke_event(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    op_name: *const c_char,
    elapsed_time: u64,
    op_idx: i64,
    subgraph_idx: i64,
) {
    // Decode the name before locking so the mutex is held only for the push.
    let name = if op_name.is_null() {
        String::from("<unknown>")
    } else {
        // SAFETY: `op_name` is non-null here and NUL-terminated per this function's contract.
        unsafe { CStr::from_ptr(op_name) }
            .to_string_lossy()
            .into_owned()
    };
    // SAFETY: the caller guarantees `profiler` and its `data` pointer are valid.
    let inner = unsafe { inner_from_profiler(profiler) };
    // A poisoned lock is recovered rather than panicking inside an `extern "C"` callback.
    let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
    guard.events.push(OpEvent {
        op_name: name,
        op_idx,
        subgraph_idx,
        duration_us: elapsed_time,
    });
}
/// Mutable profiler state shared between the `Profiler` handle and the C callbacks.
struct ProfilerInner {
/// Completed op events, in the order they were pushed.
events: Vec<OpEvent>,
/// Ops that have begun but not yet ended, keyed by the handle returned from the begin
/// callback. The tuple is `(op_name, op_idx, subgraph_idx, start_time)`.
pending: HashMap<u32, (String, i64, i64, Instant)>,
/// Handle the next begin callback will hand out; advanced with wrapping addition.
next_handle: u32,
}
/// Collects per-op timing events through a C callback table.
///
/// NOTE(review): the table is presumably handed to TensorFlow Lite via `as_ptr()`; the
/// consumer is not visible in this file.
pub struct Profiler {
/// Shared state; locked by the accessor methods. The callbacks lock the same mutex through
/// a second `Arc` handle.
inner: Arc<Mutex<ProfilerInner>>,
/// Heap-allocated callback table whose address `as_ptr()` returns.
c_struct: Box<TfLiteTelemetryProfilerStruct>,
/// Raw pointer to a boxed clone of `inner`; also stored in `c_struct.data` and freed in `Drop`.
data_ptr: *mut Arc<Mutex<ProfilerInner>>,
}
// SAFETY: the raw pointers held by `Profiler` (`data_ptr` and the `data` field of the boxed
// C struct) refer to a heap-allocated `Arc<Mutex<ProfilerInner>>`, and every access to the
// shared state in this file goes through that `Mutex`.
// NOTE(review): soundness also depends on how external code uses the pointer from `as_ptr()`.
unsafe impl Send for Profiler {}
// SAFETY: as for `Send` above; the `&self` methods only lock the mutex or read the address
// of the boxed C struct.
unsafe impl Sync for Profiler {}
impl Profiler {
#[must_use]
pub fn new() -> Self {
let inner = Arc::new(Mutex::new(ProfilerInner {
events: Vec::new(),
pending: HashMap::new(),
next_handle: 0,
}));
let data_box = Box::new(inner.clone());
let data_ptr = Box::into_raw(data_box);
let c_struct = Box::new(TfLiteTelemetryProfilerStruct {
data: data_ptr.cast::<c_void>(),
report_telemetry_event: Some(report_telemetry_event_noop),
report_telemetry_op_event: Some(report_telemetry_op_event_noop),
report_settings: Some(report_settings_noop),
report_begin_op_invoke_event: Some(report_begin_op_invoke),
report_end_op_invoke_event: Some(report_end_op_invoke),
report_op_invoke_event: Some(report_op_invoke_event),
});
Self {
inner,
c_struct,
data_ptr,
}
}
#[must_use]
pub fn events(&self) -> Vec<OpEvent> {
let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
guard.events.clone()
}
#[must_use]
pub fn drain_events(&self) -> Vec<OpEvent> {
let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
std::mem::take(&mut guard.events)
}
pub fn clear(&self) {
let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
guard.events.clear();
guard.pending.clear();
guard.next_handle = 0;
}
#[must_use]
pub fn event_count(&self) -> usize {
let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
guard.events.len()
}
pub(crate) fn as_ptr(&self) -> *mut c_void {
(self.c_struct.as_ref() as *const TfLiteTelemetryProfilerStruct)
.cast_mut()
.cast()
}
}
impl Default for Profiler {
fn default() -> Self {
Self::new()
}
}
/// Releases the boxed `Arc` handle whose raw pointer was stored in the callback table.
impl Drop for Profiler {
    fn drop(&mut self) {
        // SAFETY: `data_ptr` was produced by `Box::into_raw` in `Profiler::new`, and this is
        // the only place in this file that turns it back into a `Box`.
        let boxed_handle = unsafe { Box::from_raw(self.data_ptr) };
        drop(boxed_handle);
    }
}
/// Debug output shows only the number of recorded events and pending ops.
// The raw-pointer and callback-table fields are left out on purpose, hence the lint allowance.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for Profiler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
        let recorded = state.events.len();
        let in_flight = state.pending.len();
        let mut out = f.debug_struct("Profiler");
        out.field("events", &recorded);
        out.field("pending", &in_flight);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Views the profiler's opaque pointer as the callback table it points to.
    fn c_struct_ptr(profiler: &Profiler) -> *mut TfLiteTelemetryProfilerStruct {
        profiler.as_ptr().cast()
    }

    /// Builds an event on subgraph 0.
    fn event(op_name: &str, op_idx: i64, duration_us: u64) -> OpEvent {
        OpEvent {
            op_name: op_name.to_string(),
            op_idx,
            subgraph_idx: 0,
            duration_us,
        }
    }

    /// Appends events straight into the shared state, bypassing the C callbacks.
    fn inject(profiler: &Profiler, events: impl IntoIterator<Item = OpEvent>) {
        profiler.inner.lock().unwrap().events.extend(events);
    }

    #[test]
    fn new_profiler_has_no_events() {
        let profiler = Profiler::new();
        assert!(profiler.events().is_empty());
        assert_eq!(profiler.event_count(), 0);
    }

    #[test]
    fn default_matches_new() {
        assert!(Profiler::default().events().is_empty());
    }

    #[test]
    fn clear_resets_state() {
        let profiler = Profiler::new();
        inject(&profiler, [event("TEST_OP", 0, 100)]);
        assert_eq!(profiler.event_count(), 1);

        profiler.clear();
        assert_eq!(profiler.event_count(), 0);
    }

    #[test]
    fn drain_events_empties_list() {
        let profiler = Profiler::new();
        inject(&profiler, [event("OP_A", 1, 50), event("OP_B", 2, 75)]);

        let drained = profiler.drain_events();
        assert_eq!(drained.len(), 2);
        assert!(profiler.events().is_empty());
    }

    #[test]
    fn events_returns_snapshot() {
        let profiler = Profiler::new();
        inject(&profiler, [event("CONV2D", 0, 200)]);

        let snapshot = profiler.events();
        assert_eq!(snapshot.len(), 1);
        let first = &snapshot[0];
        assert_eq!(first.op_name, "CONV2D");
        assert_eq!(first.duration_us, 200);
        // Taking a snapshot must not consume the recorded events.
        assert_eq!(profiler.event_count(), 1);
    }

    #[test]
    fn debug_format() {
        let rendered = format!("{:?}", Profiler::new());
        assert!(rendered.contains("Profiler"));
        assert!(rendered.contains("events"));
    }

    #[test]
    fn op_event_debug_clone() {
        let original = event("SOFTMAX", 3, 42);

        let copy = original.clone();
        assert_eq!(copy.op_name, "SOFTMAX");
        assert_eq!(copy.op_idx, 3);
        assert_eq!(copy.duration_us, 42);

        let rendered = format!("{:?}", original);
        assert!(rendered.contains("SOFTMAX"));
    }

    #[test]
    fn profiler_is_send_and_sync() {
        fn requires_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        requires_send_sync::<Profiler>();
    }

    #[test]
    fn c_struct_pointer_is_stable() {
        let profiler = Profiler::new();
        let first = profiler.as_ptr();
        let second = profiler.as_ptr();
        assert_eq!(first, second, "C struct pointer must be stable (boxed)");
    }

    #[test]
    fn begin_end_callback_round_trip() {
        let profiler = Profiler::new();
        let table = c_struct_ptr(&profiler);
        let op_name = CStr::from_bytes_with_nul(b"TEST_OP\0").unwrap();

        // Drive the callbacks through the table, the way a C caller would.
        let begin = unsafe { (*table).report_begin_op_invoke_event }.unwrap();
        let end = unsafe { (*table).report_end_op_invoke_event }.unwrap();

        let handle = unsafe { begin(table, op_name.as_ptr(), 5, 0) };
        std::thread::sleep(std::time::Duration::from_micros(10));
        unsafe { end(table, handle) };

        let events = profiler.events();
        assert_eq!(events.len(), 1);
        let recorded = &events[0];
        assert_eq!(recorded.op_name, "TEST_OP");
        assert_eq!(recorded.op_idx, 5);
        assert_eq!(recorded.subgraph_idx, 0);
        assert!(recorded.duration_us > 0);
    }

    #[test]
    fn self_reported_op_invoke_callback() {
        let profiler = Profiler::new();
        let table = c_struct_ptr(&profiler);
        let op_name = CStr::from_bytes_with_nul(b"DELEGATE_OP\0").unwrap();

        let report = unsafe { (*table).report_op_invoke_event }.unwrap();
        unsafe { report(table, op_name.as_ptr(), 1234, 2, 1) };

        let events = profiler.events();
        assert_eq!(events.len(), 1);
        let recorded = &events[0];
        assert_eq!(recorded.op_name, "DELEGATE_OP");
        assert_eq!(recorded.duration_us, 1234);
        assert_eq!(recorded.op_idx, 2);
        assert_eq!(recorded.subgraph_idx, 1);
    }

    #[test]
    fn end_with_unknown_handle_is_ignored() {
        let profiler = Profiler::new();
        let table = c_struct_ptr(&profiler);

        let end = unsafe { (*table).report_end_op_invoke_event }.unwrap();
        unsafe { end(table, 999) };

        assert!(profiler.events().is_empty());
    }
}