1use std::collections::HashMap;
33use std::ffi::{c_char, c_void, CStr};
34use std::sync::{Arc, Mutex, PoisonError};
35use std::time::Instant;
36
/// One completed operator invocation captured by the profiler.
///
/// Events are produced either by a begin/end callback pair (duration measured on the Rust side)
/// or by the self-reporting op-invoke callback (duration supplied by the caller).
///
/// The type is comparable and hashable so callers can use it in `assert_eq!`, de-duplicate
/// events, or key collections by event.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OpEvent {
    /// Operator name, converted lossily from the C string handed to the callback.
    pub op_name: String,
    /// Operator index exactly as passed to the callback.
    pub op_idx: i64,
    /// Subgraph index exactly as passed to the callback.
    pub subgraph_idx: i64,
    /// Duration of the invocation in microseconds.
    ///
    /// For self-reported events the caller's `elapsed_time` is stored unchanged
    /// (assumed to already be in microseconds — TODO confirm against the runtime).
    pub duration_us: u64,
}
53
/// Signature of the generic telemetry-event slot.
type TelemetryEventFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    event_name: *const c_char,
    status: u64,
);

/// Signature of the per-operator telemetry-event slot.
type TelemetryOpEventFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    event_name: *const c_char,
    op_idx: i64,
    subgraph_idx: i64,
    status: u64,
);

/// Signature of the settings-report slot; `settings` is treated as opaque here.
type SettingsFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    setting_name: *const c_char,
    settings: *const c_void,
);

/// Signature of the begin-op-invoke slot; the returned `u32` is the event handle later passed
/// to the end-op-invoke slot.
type BeginOpInvokeFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    op_name: *const c_char,
    op_idx: i64,
    subgraph_idx: i64,
) -> u32;

/// Signature of the end-op-invoke slot.
type EndOpInvokeFn =
    unsafe extern "C" fn(profiler: *mut TfLiteTelemetryProfilerStruct, event_handle: u32);

/// Signature of the self-reporting op-invoke slot, where the caller supplies the elapsed time.
type OpInvokeFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    op_name: *const c_char,
    elapsed_time: u64,
    op_idx: i64,
    subgraph_idx: i64,
);

/// Callback table exposed to native code through `Profiler::as_ptr`.
///
/// `#[repr(C)]` pins the field order and layout: one data pointer followed by six nullable
/// function pointers. Fields must not be reordered.
///
/// NOTE(review): the layout presumably mirrors the C `TfLiteTelemetryProfilerStruct` from the
/// TensorFlow Lite telemetry header — confirm against that header when upgrading the runtime.
#[repr(C)]
struct TfLiteTelemetryProfilerStruct {
    /// Opaque user-data pointer made available to every callback via `profiler`.
    data: *mut c_void,

    /// Generic telemetry event callback.
    report_telemetry_event: Option<TelemetryEventFn>,

    /// Per-operator telemetry event callback.
    report_telemetry_op_event: Option<TelemetryOpEventFn>,

    /// Settings report callback.
    report_settings: Option<SettingsFn>,

    /// Called when an operator invocation starts; returns a handle for the matching end call.
    report_begin_op_invoke_event: Option<BeginOpInvokeFn>,

    /// Called when the invocation identified by `event_handle` finishes.
    report_end_op_invoke_event: Option<EndOpInvokeFn>,

    /// Called with an already-measured elapsed time for an operator invocation.
    report_op_invoke_event: Option<OpInvokeFn>,
}
117
/// No-op handler installed in the `report_telemetry_event` slot of the callback table.
///
/// Every argument is ignored and nothing is recorded.
unsafe extern "C" fn report_telemetry_event_noop(
    _profiler: *mut TfLiteTelemetryProfilerStruct,
    _event_name: *const c_char,
    _status: u64,
) {
}
129
/// No-op handler installed in the `report_telemetry_op_event` slot of the callback table.
///
/// Every argument is ignored and nothing is recorded.
unsafe extern "C" fn report_telemetry_op_event_noop(
    _profiler: *mut TfLiteTelemetryProfilerStruct,
    _event_name: *const c_char,
    _op_idx: i64,
    _subgraph_idx: i64,
    _status: u64,
) {
}
139
/// No-op handler installed in the `report_settings` slot of the callback table.
///
/// The settings pointer is never dereferenced; every argument is ignored.
unsafe extern "C" fn report_settings_noop(
    _profiler: *mut TfLiteTelemetryProfilerStruct,
    _setting_name: *const c_char,
    _settings: *const c_void,
) {
}
147
148unsafe fn inner_from_profiler<'a>(
157 profiler: *mut TfLiteTelemetryProfilerStruct,
158) -> &'a Arc<Mutex<ProfilerInner>> {
159 unsafe { &*((*profiler).data.cast::<Arc<Mutex<ProfilerInner>>>()) }
162}
163
164unsafe extern "C" fn report_begin_op_invoke(
173 profiler: *mut TfLiteTelemetryProfilerStruct,
174 op_name: *const c_char,
175 op_idx: i64,
176 subgraph_idx: i64,
177) -> u32 {
178 let inner = unsafe { inner_from_profiler(profiler) };
180 let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
181 let handle = guard.next_handle;
182 guard.next_handle = guard.next_handle.wrapping_add(1);
183 let name = unsafe { CStr::from_ptr(op_name) }
185 .to_string_lossy()
186 .into_owned();
187 guard
188 .pending
189 .insert(handle, (name, op_idx, subgraph_idx, Instant::now()));
190 handle
191}
192
193unsafe extern "C" fn report_end_op_invoke(
201 profiler: *mut TfLiteTelemetryProfilerStruct,
202 event_handle: u32,
203) {
204 let inner = unsafe { inner_from_profiler(profiler) };
206 let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
207 if let Some((op_name, op_idx, subgraph_idx, start)) = guard.pending.remove(&event_handle) {
208 #[allow(clippy::cast_possible_truncation)]
209 let duration_us = start.elapsed().as_micros() as u64;
210 guard.events.push(OpEvent {
211 op_name,
212 op_idx,
213 subgraph_idx,
214 duration_us,
215 });
216 }
217}
218
219unsafe extern "C" fn report_op_invoke_event(
228 profiler: *mut TfLiteTelemetryProfilerStruct,
229 op_name: *const c_char,
230 elapsed_time: u64,
231 op_idx: i64,
232 subgraph_idx: i64,
233) {
234 let inner = unsafe { inner_from_profiler(profiler) };
236 let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
237 let name = unsafe { CStr::from_ptr(op_name) }
239 .to_string_lossy()
240 .into_owned();
241 guard.events.push(OpEvent {
242 op_name: name,
243 op_idx,
244 subgraph_idx,
245 duration_us: elapsed_time,
246 });
247}
248
249struct ProfilerInner {
255 events: Vec<OpEvent>,
257 pending: HashMap<u32, (String, i64, i64, Instant)>,
259 next_handle: u32,
261}
262
/// Safe owner of the C callback table and of the state the callbacks record into.
///
/// Recorded events are read back through `events`, `drain_events` and `event_count`.
pub struct Profiler {
    /// Shared state; a second strong reference lives behind `data_ptr` for the callbacks.
    inner: Arc<Mutex<ProfilerInner>>,
    /// Heap-allocated callback table, boxed so the pointer from `as_ptr` stays stable.
    c_struct: Box<TfLiteTelemetryProfilerStruct>,
    /// Raw pointer obtained from `Box::into_raw` in `new`; also stored in `c_struct.data` and
    /// released in `Drop`.
    data_ptr: *mut Arc<Mutex<ProfilerInner>>,
}
309
// SAFETY: the raw pointers inside `Profiler` refer to heap allocations created in
// `Profiler::new` and owned by the value itself, so they remain valid when it moves threads.
unsafe impl Send for Profiler {}
// SAFETY: every method that touches shared state goes through the `Mutex`, and the callback
// table is not written to after construction in this file.
// NOTE(review): confirm no other code in the crate writes through the pointer from `as_ptr`.
unsafe impl Sync for Profiler {}
317
318impl Profiler {
319 #[must_use]
321 pub fn new() -> Self {
322 let inner = Arc::new(Mutex::new(ProfilerInner {
323 events: Vec::new(),
324 pending: HashMap::new(),
325 next_handle: 0,
326 }));
327
328 let data_box = Box::new(inner.clone());
331 let data_ptr = Box::into_raw(data_box);
332
333 let c_struct = Box::new(TfLiteTelemetryProfilerStruct {
334 data: data_ptr.cast::<c_void>(),
335 report_telemetry_event: Some(report_telemetry_event_noop),
336 report_telemetry_op_event: Some(report_telemetry_op_event_noop),
337 report_settings: Some(report_settings_noop),
338 report_begin_op_invoke_event: Some(report_begin_op_invoke),
339 report_end_op_invoke_event: Some(report_end_op_invoke),
340 report_op_invoke_event: Some(report_op_invoke_event),
341 });
342
343 Self {
344 inner,
345 c_struct,
346 data_ptr,
347 }
348 }
349
350 #[must_use]
352 pub fn events(&self) -> Vec<OpEvent> {
353 let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
354 guard.events.clone()
355 }
356
357 #[must_use]
359 pub fn drain_events(&self) -> Vec<OpEvent> {
360 let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
361 std::mem::take(&mut guard.events)
362 }
363
364 pub fn clear(&self) {
366 let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
367 guard.events.clear();
368 guard.pending.clear();
369 guard.next_handle = 0;
370 }
371
372 #[must_use]
374 pub fn event_count(&self) -> usize {
375 let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
376 guard.events.len()
377 }
378
379 pub(crate) fn as_ptr(&self) -> *mut c_void {
385 (self.c_struct.as_ref() as *const TfLiteTelemetryProfilerStruct)
386 .cast_mut()
387 .cast()
388 }
389}
390
impl Default for Profiler {
    /// Equivalent to [`Profiler::new`].
    fn default() -> Self {
        Self::new()
    }
}
396
397impl Drop for Profiler {
398 fn drop(&mut self) {
399 unsafe {
403 drop(Box::from_raw(self.data_ptr));
404 }
405 }
406}
407
408#[allow(clippy::missing_fields_in_debug)]
409impl std::fmt::Debug for Profiler {
410 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
412 f.debug_struct("Profiler")
413 .field("events", &guard.events.len())
414 .field("pending", &guard.pending.len())
415 .finish_non_exhaustive()
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Views the profiler's callback table through its concrete type.
    fn table_ptr(profiler: &Profiler) -> *mut TfLiteTelemetryProfilerStruct {
        profiler.as_ptr().cast()
    }

    /// Injects a finished event straight into the shared state, bypassing the callbacks.
    fn inject(profiler: &Profiler, op_name: &str, op_idx: i64, duration_us: u64) {
        let mut state = profiler.inner.lock().unwrap();
        state.events.push(OpEvent {
            op_name: op_name.to_string(),
            op_idx,
            subgraph_idx: 0,
            duration_us,
        });
    }

    #[test]
    fn new_profiler_has_no_events() {
        let fresh = Profiler::new();
        assert!(fresh.events().is_empty());
        assert_eq!(fresh.event_count(), 0);
    }

    #[test]
    fn default_matches_new() {
        let fresh = Profiler::default();
        assert!(fresh.events().is_empty());
    }

    #[test]
    fn clear_resets_state() {
        let profiler = Profiler::new();
        inject(&profiler, "TEST_OP", 0, 100);
        assert_eq!(profiler.event_count(), 1);

        profiler.clear();
        assert_eq!(profiler.event_count(), 0);
    }

    #[test]
    fn drain_events_empties_list() {
        let profiler = Profiler::new();
        inject(&profiler, "OP_A", 1, 50);
        inject(&profiler, "OP_B", 2, 75);

        let drained = profiler.drain_events();
        assert_eq!(drained.len(), 2);
        assert!(profiler.events().is_empty());
    }

    #[test]
    fn events_returns_snapshot() {
        let profiler = Profiler::new();
        inject(&profiler, "CONV2D", 0, 200);

        let snapshot = profiler.events();
        assert_eq!(snapshot.len(), 1);
        assert_eq!(snapshot[0].op_name, "CONV2D");
        assert_eq!(snapshot[0].duration_us, 200);
        // Taking a snapshot must not consume the recorded events.
        assert_eq!(profiler.event_count(), 1);
    }

    #[test]
    fn debug_format() {
        let rendered = format!("{:?}", Profiler::new());
        assert!(rendered.contains("Profiler"));
        assert!(rendered.contains("events"));
    }

    #[test]
    fn op_event_debug_clone() {
        let original = OpEvent {
            op_name: "SOFTMAX".to_string(),
            op_idx: 3,
            subgraph_idx: 0,
            duration_us: 42,
        };

        let copy = original.clone();
        assert_eq!(copy.op_name, "SOFTMAX");
        assert_eq!(copy.op_idx, 3);
        assert_eq!(copy.duration_us, 42);

        let rendered = format!("{original:?}");
        assert!(rendered.contains("SOFTMAX"));
    }

    #[test]
    fn profiler_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Profiler>();
    }

    #[test]
    fn c_struct_pointer_is_stable() {
        let profiler = Profiler::new();
        let first = profiler.as_ptr();
        let second = profiler.as_ptr();
        assert_eq!(first, second, "C struct pointer must be stable (boxed)");
    }

    #[test]
    fn begin_end_callback_round_trip() {
        let profiler = Profiler::new();
        let table = table_ptr(&profiler);
        let name = CStr::from_bytes_with_nul(b"TEST_OP\0").unwrap();

        // SAFETY: `table` points at the live callback table owned by `profiler`, and `name`
        // is a valid NUL-terminated string that outlives the call.
        let handle = unsafe {
            let begin = (*table).report_begin_op_invoke_event.unwrap();
            begin(table, name.as_ptr(), 5, 0)
        };
        // Make sure a measurable amount of time passes between begin and end.
        std::thread::sleep(std::time::Duration::from_micros(10));
        // SAFETY: same table as above; `handle` came from the matching begin call.
        unsafe {
            let end = (*table).report_end_op_invoke_event.unwrap();
            end(table, handle);
        }

        let recorded = profiler.events();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].op_name, "TEST_OP");
        assert_eq!(recorded[0].op_idx, 5);
        assert_eq!(recorded[0].subgraph_idx, 0);
        assert!(recorded[0].duration_us > 0);
    }

    #[test]
    fn self_reported_op_invoke_callback() {
        let profiler = Profiler::new();
        let table = table_ptr(&profiler);
        let name = CStr::from_bytes_with_nul(b"DELEGATE_OP\0").unwrap();

        // SAFETY: `table` points at the live callback table owned by `profiler`, and `name`
        // is a valid NUL-terminated string that outlives the call.
        unsafe {
            let report = (*table).report_op_invoke_event.unwrap();
            report(table, name.as_ptr(), 1234, 2, 1);
        }

        let recorded = profiler.events();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].op_name, "DELEGATE_OP");
        assert_eq!(recorded[0].duration_us, 1234);
        assert_eq!(recorded[0].op_idx, 2);
        assert_eq!(recorded[0].subgraph_idx, 1);
    }

    #[test]
    fn end_with_unknown_handle_is_ignored() {
        let profiler = Profiler::new();
        let table = table_ptr(&profiler);

        // SAFETY: `table` points at the live callback table owned by `profiler`.
        unsafe {
            let end = (*table).report_end_op_invoke_event.unwrap();
            end(table, 999);
        }

        assert!(profiler.events().is_empty());
    }
}