Skip to main content

edgefirst_tflite/
profiler.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Op-level profiler using the `TFLite` telemetry profiler C API.
5//!
6//! The [`Profiler`] collects per-operator timing events during inference.
7//! Attach it to an [`InterpreterBuilder`](crate::InterpreterBuilder) before
8//! building, then read events after [`Interpreter::invoke`](crate::Interpreter::invoke).
9//!
10//! # Example
11//!
12//! ```no_run
13//! use edgefirst_tflite::{Library, Model, Interpreter, Profiler};
14//!
15//! let lib = Library::new()?;
16//! let model = Model::from_file(&lib, "model.tflite")?;
17//! let profiler = Profiler::new();
18//!
19//! let mut interp = Interpreter::builder(&lib)?
20//!     .profiler(&profiler)?
21//!     .build(&model)?;
22//!
23//! interp.invoke()?;
24//!
25//! for event in profiler.events() {
26//!     println!("{}: {}us (op={}, subgraph={})",
27//!         event.op_name, event.duration_us, event.op_idx, event.subgraph_idx);
28//! }
29//! # Ok::<(), edgefirst_tflite::Error>(())
30//! ```
31
32use std::collections::HashMap;
33use std::ffi::{c_char, c_void, CStr};
34use std::sync::{Arc, Mutex, PoisonError};
35use std::time::Instant;
36
37// ---------------------------------------------------------------------------
38// OpEvent
39// ---------------------------------------------------------------------------
40
/// A recorded per-op timing event from a single inference invocation.
///
/// One `OpEvent` is stored for every operator that completes a begin/end
/// callback pair or that self-reports its elapsed time.
///
/// The type implements `PartialEq`, `Eq` and `Hash` in addition to `Debug`
/// and `Clone`, so events can be compared directly (for example in tests)
/// or stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OpEvent {
    /// Operator name (e.g., `NeutronDelegate`, `SOFTMAX`, `Transpose`).
    pub op_name: String,
    /// Operator index in the subgraph.
    pub op_idx: i64,
    /// Subgraph index.
    pub subgraph_idx: i64,
    /// Duration in microseconds.
    pub duration_us: u64,
}
53
54// ---------------------------------------------------------------------------
55// TfLiteTelemetryProfilerStruct (C ABI)
56// ---------------------------------------------------------------------------
57
/// Signature of the `report_telemetry_event` callback slot.
type ReportTelemetryEventFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    event_name: *const c_char,
    status: u64,
);

/// Signature of the `report_telemetry_op_event` callback slot.
type ReportTelemetryOpEventFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    event_name: *const c_char,
    op_idx: i64,
    subgraph_idx: i64,
    status: u64,
);

/// Signature of the `report_settings` callback slot.
type ReportSettingsFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    setting_name: *const c_char,
    settings: *const c_void,
);

/// Signature of the `report_begin_op_invoke_event` callback slot; the
/// returned `u32` is the event handle later passed to the end callback.
type ReportBeginOpInvokeEventFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    op_name: *const c_char,
    op_idx: i64,
    subgraph_idx: i64,
) -> u32;

/// Signature of the `report_end_op_invoke_event` callback slot.
type ReportEndOpInvokeEventFn =
    unsafe extern "C" fn(profiler: *mut TfLiteTelemetryProfilerStruct, event_handle: u32);

/// Signature of the `report_op_invoke_event` callback slot.
type ReportOpInvokeEventFn = unsafe extern "C" fn(
    profiler: *mut TfLiteTelemetryProfilerStruct,
    op_name: *const c_char,
    elapsed_time: u64,
    op_idx: i64,
    subgraph_idx: i64,
);

/// Rust mirror of the C `TfLiteTelemetryProfilerStruct` declared in
/// `tensorflow/lite/profiling/telemetry/c/profiler.h`.
///
/// The struct is `#[repr(C)]`, so the field order below is part of the ABI
/// and must not be changed: one opaque `data` pointer followed by six
/// nullable function pointers. `data` refers to state owned by the Rust
/// side; the function pointers are invoked by the `TFLite` runtime during
/// inference to report events.
#[repr(C)]
struct TfLiteTelemetryProfilerStruct {
    /// Opaque user pointer handed back to every callback via `profiler`.
    data: *mut c_void,

    /// General telemetry event hook.
    report_telemetry_event: Option<ReportTelemetryEventFn>,

    /// Per-op telemetry event hook.
    report_telemetry_op_event: Option<ReportTelemetryOpEventFn>,

    /// Settings report hook.
    report_settings: Option<ReportSettingsFn>,

    /// Start-of-op hook; returns the handle for the matching end call.
    report_begin_op_invoke_event: Option<ReportBeginOpInvokeEventFn>,

    /// End-of-op hook; receives the handle produced by the begin hook.
    report_end_op_invoke_event: Option<ReportEndOpInvokeEventFn>,

    /// Hook for ops that report their own elapsed time.
    report_op_invoke_event: Option<ReportOpInvokeEventFn>,
}
117
118// ---------------------------------------------------------------------------
119// C callbacks
120// ---------------------------------------------------------------------------
121
122/// No-op callback for `ReportTelemetryEvent`.
123unsafe extern "C" fn report_telemetry_event_noop(
124    _profiler: *mut TfLiteTelemetryProfilerStruct,
125    _event_name: *const c_char,
126    _status: u64,
127) {
128}
129
130/// No-op callback for `ReportTelemetryOpEvent`.
131unsafe extern "C" fn report_telemetry_op_event_noop(
132    _profiler: *mut TfLiteTelemetryProfilerStruct,
133    _event_name: *const c_char,
134    _op_idx: i64,
135    _subgraph_idx: i64,
136    _status: u64,
137) {
138}
139
140/// No-op callback for `ReportSettings`.
141unsafe extern "C" fn report_settings_noop(
142    _profiler: *mut TfLiteTelemetryProfilerStruct,
143    _setting_name: *const c_char,
144    _settings: *const c_void,
145) {
146}
147
148/// Recover an `&Arc<Mutex<ProfilerInner>>` from the C struct's `data` field.
149///
150/// # Safety
151///
152/// The `data` field of `*profiler` must point to a live, heap-allocated
153/// `Arc<Mutex<ProfilerInner>>` (created via `Box::into_raw` in
154/// [`Profiler::new`]). The returned reference is only valid for the
155/// caller-chosen lifetime and must be kept local to the callback.
156unsafe fn inner_from_profiler<'a>(
157    profiler: *mut TfLiteTelemetryProfilerStruct,
158) -> &'a Arc<Mutex<ProfilerInner>> {
159    // SAFETY: Caller guarantees the pointer is valid and the pointee is alive
160    // for the duration of the returned borrow.
161    unsafe { &*((*profiler).data.cast::<Arc<Mutex<ProfilerInner>>>()) }
162}
163
164/// Called at the start of each op invocation. Records the start time and
165/// returns a handle that `TFLite` passes back to `report_end_op_invoke`.
166///
167/// # Safety
168///
169/// `profiler` must be a valid pointer to a `TfLiteTelemetryProfilerStruct`
170/// whose `data` field points to a live `Arc<Mutex<ProfilerInner>>`.
171/// `op_name` must be a valid, NUL-terminated C string.
172unsafe extern "C" fn report_begin_op_invoke(
173    profiler: *mut TfLiteTelemetryProfilerStruct,
174    op_name: *const c_char,
175    op_idx: i64,
176    subgraph_idx: i64,
177) -> u32 {
178    // SAFETY: Caller (TFLite runtime) upholds the data-pointer invariant.
179    let inner = unsafe { inner_from_profiler(profiler) };
180    let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
181    let handle = guard.next_handle;
182    guard.next_handle = guard.next_handle.wrapping_add(1);
183    // SAFETY: `op_name` is a valid C string provided by the TFLite runtime.
184    let name = unsafe { CStr::from_ptr(op_name) }
185        .to_string_lossy()
186        .into_owned();
187    guard
188        .pending
189        .insert(handle, (name, op_idx, subgraph_idx, Instant::now()));
190    handle
191}
192
193/// Called at the end of each op invocation. Computes elapsed time from the
194/// corresponding begin event and records a completed [`OpEvent`].
195///
196/// # Safety
197///
198/// `profiler` must be a valid pointer to a `TfLiteTelemetryProfilerStruct`
199/// whose `data` field points to a live `Arc<Mutex<ProfilerInner>>`.
200unsafe extern "C" fn report_end_op_invoke(
201    profiler: *mut TfLiteTelemetryProfilerStruct,
202    event_handle: u32,
203) {
204    // SAFETY: Caller (TFLite runtime) upholds the data-pointer invariant.
205    let inner = unsafe { inner_from_profiler(profiler) };
206    let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
207    if let Some((op_name, op_idx, subgraph_idx, start)) = guard.pending.remove(&event_handle) {
208        #[allow(clippy::cast_possible_truncation)]
209        let duration_us = start.elapsed().as_micros() as u64;
210        guard.events.push(OpEvent {
211            op_name,
212            op_idx,
213            subgraph_idx,
214            duration_us,
215        });
216    }
217}
218
219/// Called for ops that self-report their timing (`elapsed_time` in
220/// microseconds).
221///
222/// # Safety
223///
224/// `profiler` must be a valid pointer to a `TfLiteTelemetryProfilerStruct`
225/// whose `data` field points to a live `Arc<Mutex<ProfilerInner>>`.
226/// `op_name` must be a valid, NUL-terminated C string.
227unsafe extern "C" fn report_op_invoke_event(
228    profiler: *mut TfLiteTelemetryProfilerStruct,
229    op_name: *const c_char,
230    elapsed_time: u64,
231    op_idx: i64,
232    subgraph_idx: i64,
233) {
234    // SAFETY: Caller (TFLite runtime) upholds the data-pointer invariant.
235    let inner = unsafe { inner_from_profiler(profiler) };
236    let mut guard = inner.lock().unwrap_or_else(PoisonError::into_inner);
237    // SAFETY: `op_name` is a valid C string provided by the TFLite runtime.
238    let name = unsafe { CStr::from_ptr(op_name) }
239        .to_string_lossy()
240        .into_owned();
241    guard.events.push(OpEvent {
242        op_name: name,
243        op_idx,
244        subgraph_idx,
245        duration_us: elapsed_time,
246    });
247}
248
249// ---------------------------------------------------------------------------
250// ProfilerInner
251// ---------------------------------------------------------------------------
252
253/// Shared mutable state for the profiler, protected by a `Mutex`.
254struct ProfilerInner {
255    /// Completed op timing events.
256    events: Vec<OpEvent>,
257    /// In-flight events keyed by handle.
258    pending: HashMap<u32, (String, i64, i64, Instant)>,
259    /// Monotonically increasing handle counter.
260    next_handle: u32,
261}
262
263// ---------------------------------------------------------------------------
264// Profiler
265// ---------------------------------------------------------------------------
266
267/// Collects per-op timing events during `TFLite` inference.
268///
269/// Created via [`Profiler::new`], attached to an interpreter via
270/// [`InterpreterBuilder::profiler`](crate::InterpreterBuilder::profiler),
271/// then events are read after
272/// [`Interpreter::invoke`](crate::Interpreter::invoke).
273///
274/// The `Profiler` must outlive the [`Interpreter`](crate::Interpreter) it
275/// is attached to. This is guaranteed when the `Profiler` is declared
276/// before the `Interpreter` in the same scope, or when it is stored in a
277/// longer-lived struct.
278///
279/// # Example
280///
281/// ```no_run
282/// use edgefirst_tflite::{Library, Model, Interpreter, Profiler};
283///
284/// let lib = Library::new()?;
285/// let model = Model::from_file(&lib, "model.tflite")?;
286/// let profiler = Profiler::new();
287///
288/// let mut interp = Interpreter::builder(&lib)?
289///     .profiler(&profiler)?
290///     .build(&model)?;
291///
292/// interp.invoke()?;
293///
294/// for event in profiler.events() {
295///     println!("{}: {}us", event.op_name, event.duration_us);
296/// }
297/// # Ok::<(), edgefirst_tflite::Error>(())
298/// ```
299pub struct Profiler {
300    /// Shared state holding completed and in-flight events.
301    inner: Arc<Mutex<ProfilerInner>>,
302    /// Boxed C struct that `TFLite` holds a pointer to. Must not move after
303    /// the pointer is handed to `TfLiteInterpreterOptionsSetTelemetryProfiler`.
304    c_struct: Box<TfLiteTelemetryProfilerStruct>,
305    /// Raw pointer to a heap-allocated `Arc<Mutex<ProfilerInner>>` created
306    /// via `Box::into_raw`. Freed on drop.
307    data_ptr: *mut Arc<Mutex<ProfilerInner>>,
308}
309
310// SAFETY: The `c_struct` contains a `*mut c_void` data pointer to an
311// `Arc<Mutex<ProfilerInner>>`, which is itself `Send + Sync`. The C struct
312// is only mutated through the `Mutex`-protected inner state. The raw
313// `data_ptr` is never dereferenced outside the Mutex-guarded callbacks.
314unsafe impl Send for Profiler {}
315// SAFETY: All mutable access goes through the `Mutex` inside the `Arc`.
316unsafe impl Sync for Profiler {}
317
318impl Profiler {
319    /// Create a new profiler ready to be attached to an interpreter.
320    #[must_use]
321    pub fn new() -> Self {
322        let inner = Arc::new(Mutex::new(ProfilerInner {
323            events: Vec::new(),
324            pending: HashMap::new(),
325            next_handle: 0,
326        }));
327
328        // Heap-allocate a clone of the Arc so we have a stable pointer that
329        // the C callbacks can cast back to `&Arc<Mutex<ProfilerInner>>`.
330        let data_box = Box::new(inner.clone());
331        let data_ptr = Box::into_raw(data_box);
332
333        let c_struct = Box::new(TfLiteTelemetryProfilerStruct {
334            data: data_ptr.cast::<c_void>(),
335            report_telemetry_event: Some(report_telemetry_event_noop),
336            report_telemetry_op_event: Some(report_telemetry_op_event_noop),
337            report_settings: Some(report_settings_noop),
338            report_begin_op_invoke_event: Some(report_begin_op_invoke),
339            report_end_op_invoke_event: Some(report_end_op_invoke),
340            report_op_invoke_event: Some(report_op_invoke_event),
341        });
342
343        Self {
344            inner,
345            c_struct,
346            data_ptr,
347        }
348    }
349
350    /// Get a snapshot of all collected events since the last drain or clear.
351    #[must_use]
352    pub fn events(&self) -> Vec<OpEvent> {
353        let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
354        guard.events.clone()
355    }
356
357    /// Drain and return all collected events, leaving the internal list empty.
358    #[must_use]
359    pub fn drain_events(&self) -> Vec<OpEvent> {
360        let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
361        std::mem::take(&mut guard.events)
362    }
363
364    /// Clear all collected events without returning them.
365    pub fn clear(&self) {
366        let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
367        guard.events.clear();
368        guard.pending.clear();
369        guard.next_handle = 0;
370    }
371
372    /// Returns the number of completed events collected so far.
373    #[must_use]
374    pub fn event_count(&self) -> usize {
375        let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
376        guard.events.len()
377    }
378
379    /// Returns a raw mutable pointer to the C profiler struct for FFI.
380    ///
381    /// The returned pointer is valid for the lifetime of this `Profiler`.
382    /// It points to the `TfLiteTelemetryProfilerStruct` but is returned as
383    /// `*mut c_void` to avoid exposing the private C struct type.
384    pub(crate) fn as_ptr(&self) -> *mut c_void {
385        (self.c_struct.as_ref() as *const TfLiteTelemetryProfilerStruct)
386            .cast_mut()
387            .cast()
388    }
389}
390
391impl Default for Profiler {
392    fn default() -> Self {
393        Self::new()
394    }
395}
396
397impl Drop for Profiler {
398    fn drop(&mut self) {
399        // SAFETY: `data_ptr` was created via `Box::into_raw` in `new()`.
400        // We reconstruct the Box so the inner `Arc` is dropped properly,
401        // decrementing its refcount.
402        unsafe {
403            drop(Box::from_raw(self.data_ptr));
404        }
405    }
406}
407
408#[allow(clippy::missing_fields_in_debug)]
409impl std::fmt::Debug for Profiler {
410    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411        let guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
412        f.debug_struct("Profiler")
413            .field("events", &guard.events.len())
414            .field("pending", &guard.pending.len())
415            .finish_non_exhaustive()
416    }
417}
418
419// ---------------------------------------------------------------------------
420// Tests
421// ---------------------------------------------------------------------------
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Reinterprets the profiler's opaque FFI pointer as the concrete C
    /// struct so tests can call the registered callbacks directly.
    fn c_struct_ptr(profiler: &Profiler) -> *mut TfLiteTelemetryProfilerStruct {
        profiler.as_ptr().cast::<TfLiteTelemetryProfilerStruct>()
    }

    /// Builds an `OpEvent` in subgraph 0.
    fn make_event(op_name: &str, op_idx: i64, duration_us: u64) -> OpEvent {
        OpEvent {
            op_name: op_name.to_string(),
            op_idx,
            subgraph_idx: 0,
            duration_us,
        }
    }

    /// Appends `events` straight to the profiler's internal list,
    /// bypassing the C callbacks.
    fn inject(profiler: &Profiler, events: Vec<OpEvent>) {
        let mut state = profiler.inner.lock().unwrap();
        state.events.extend(events);
    }

    #[test]
    fn new_profiler_has_no_events() {
        let fresh = Profiler::new();
        assert!(fresh.events().is_empty());
        assert_eq!(fresh.event_count(), 0);
    }

    #[test]
    fn default_matches_new() {
        let fresh = Profiler::default();
        assert!(fresh.events().is_empty());
    }

    #[test]
    fn clear_resets_state() {
        let profiler = Profiler::new();
        inject(&profiler, vec![make_event("TEST_OP", 0, 100)]);
        assert_eq!(profiler.event_count(), 1);

        profiler.clear();
        assert_eq!(profiler.event_count(), 0);
    }

    #[test]
    fn drain_events_empties_list() {
        let profiler = Profiler::new();
        inject(
            &profiler,
            vec![make_event("OP_A", 1, 50), make_event("OP_B", 2, 75)],
        );

        let drained = profiler.drain_events();
        assert_eq!(drained.len(), 2);
        assert!(profiler.events().is_empty());
    }

    #[test]
    fn events_returns_snapshot() {
        let profiler = Profiler::new();
        inject(&profiler, vec![make_event("CONV2D", 0, 200)]);

        let snapshot = profiler.events();
        assert_eq!(snapshot.len(), 1);
        assert_eq!(snapshot[0].op_name, "CONV2D");
        assert_eq!(snapshot[0].duration_us, 200);
        // Taking a snapshot must not consume the stored events.
        assert_eq!(profiler.event_count(), 1);
    }

    #[test]
    fn debug_format() {
        let rendered = format!("{:?}", Profiler::new());
        assert!(rendered.contains("Profiler"));
        assert!(rendered.contains("events"));
    }

    #[test]
    fn op_event_debug_clone() {
        let original = make_event("SOFTMAX", 3, 42);
        let copy = original.clone();
        assert_eq!(copy.op_name, "SOFTMAX");
        assert_eq!(copy.op_idx, 3);
        assert_eq!(copy.duration_us, 42);

        let rendered = format!("{:?}", original);
        assert!(rendered.contains("SOFTMAX"));
    }

    #[test]
    fn profiler_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Profiler>();
    }

    #[test]
    fn c_struct_pointer_is_stable() {
        let profiler = Profiler::new();
        let first = profiler.as_ptr();
        let second = profiler.as_ptr();
        assert_eq!(first, second, "C struct pointer must be stable (boxed)");
    }

    #[test]
    fn begin_end_callback_round_trip() {
        let profiler = Profiler::new();
        let raw = c_struct_ptr(&profiler);
        let name = CStr::from_bytes_with_nul(b"TEST_OP\0").unwrap();

        // SAFETY: `raw` points at the C struct owned by `profiler`, which is
        // alive for the whole test.
        let begin = unsafe { (*raw).report_begin_op_invoke_event }.unwrap();
        // SAFETY: As above.
        let end = unsafe { (*raw).report_end_op_invoke_event }.unwrap();

        // Mimic the runtime: begin, do some "work", end.
        // SAFETY: `raw` is valid and `name` is a NUL-terminated string.
        let handle = unsafe { begin(raw, name.as_ptr(), 5, 0) };
        // Brief pause so the measured duration is nonzero.
        std::thread::sleep(std::time::Duration::from_micros(10));
        // SAFETY: `raw` is valid and `handle` came from the begin callback.
        unsafe { end(raw, handle) };

        let recorded = profiler.events();
        assert_eq!(recorded.len(), 1);
        let only = &recorded[0];
        assert_eq!(only.op_name, "TEST_OP");
        assert_eq!(only.op_idx, 5);
        assert_eq!(only.subgraph_idx, 0);
        // At least the sleep above must have been measured.
        assert!(only.duration_us > 0);
    }

    #[test]
    fn self_reported_op_invoke_callback() {
        let profiler = Profiler::new();
        let raw = c_struct_ptr(&profiler);
        let name = CStr::from_bytes_with_nul(b"DELEGATE_OP\0").unwrap();

        // SAFETY: `raw` points at the C struct owned by `profiler`.
        let report = unsafe { (*raw).report_op_invoke_event }.unwrap();
        // SAFETY: `raw` is valid and `name` is a NUL-terminated string.
        unsafe { report(raw, name.as_ptr(), 1234, 2, 1) };

        let recorded = profiler.events();
        assert_eq!(recorded.len(), 1);
        let only = &recorded[0];
        assert_eq!(only.op_name, "DELEGATE_OP");
        assert_eq!(only.duration_us, 1234);
        assert_eq!(only.op_idx, 2);
        assert_eq!(only.subgraph_idx, 1);
    }

    #[test]
    fn end_with_unknown_handle_is_ignored() {
        let profiler = Profiler::new();
        let raw = c_struct_ptr(&profiler);

        // SAFETY: `raw` points at the C struct owned by `profiler`.
        let end = unsafe { (*raw).report_end_op_invoke_event }.unwrap();
        // SAFETY: `raw` is valid; the handle matches no begin event.
        unsafe { end(raw, 999) };

        assert!(profiler.events().is_empty());
    }
}
604}