1use core::slice;
53use std::{
54 ffi::{c_void, CStr},
55 mem::{forget, transmute},
56 os::raw::c_char,
57 ptr::{null, null_mut},
58 sync::Arc,
59 time::Duration,
60};
61
62use crate::{
63 error::{Error, CSTR_CONVERT_ERROR_PLUG},
64 from_char_array,
65 message::Shape,
66 sys, to_cstring, Buffer, MemoryType,
67};
68
bitflags::bitflags! {
    /// Trace level flags handed to Triton when a trace is created.
    ///
    /// Every flag takes its value from the matching
    /// `TRITONSERVER_TRACE_LEVEL_*` constant of the `sys` bindings.
    struct Level: u32 {
        /// Value of `TRITONSERVER_TRACE_LEVEL_DISABLED`; used when no handler is registered.
        const DISABLED = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_DISABLED;
        /// Value of `TRITONSERVER_TRACE_LEVEL_MIN`; marked deprecated.
        #[deprecated]
        const MIN = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MIN;
        /// Value of `TRITONSERVER_TRACE_LEVEL_MAX`; marked deprecated.
        #[deprecated]
        const MAX = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MAX;
        /// Value of `TRITONSERVER_TRACE_LEVEL_TIMESTAMPS`; used when an activity handler is registered.
        const TIMESTAMPS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TIMESTAMPS;
        /// Value of `TRITONSERVER_TRACE_LEVEL_TENSORS`; used when a tensor activity handler is registered.
        const TENSORS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TENSORS;
    }
}
91
92impl Level {
93 #[allow(dead_code)]
94 fn as_str(self) -> &'static str {
96 unsafe {
97 let ptr = sys::TRITONSERVER_InferenceTraceLevelString(self.bits());
98 assert!(!ptr.is_null());
99 CStr::from_ptr(ptr)
100 .to_str()
101 .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
102 }
103 }
104}
105
/// Kinds of trace events that are delivered to the trace handlers.
///
/// The enum is `#[repr(u32)]` and every discriminant is the matching
/// `TRITONSERVER_TRACE_*` activity constant from the `sys` bindings.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Activity {
    /// `TRITONSERVER_TRACE_REQUEST_START`.
    RequestStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_START,
    /// `TRITONSERVER_TRACE_QUEUE_START`.
    QueueStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_QUEUE_START,
    /// `TRITONSERVER_TRACE_COMPUTE_START`.
    ComputeStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_START,
    /// `TRITONSERVER_TRACE_COMPUTE_INPUT_END`.
    ComputeInputEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_INPUT_END,
    /// `TRITONSERVER_TRACE_COMPUTE_OUTPUT_START`.
    ComputeOutputStart =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_OUTPUT_START,
    /// `TRITONSERVER_TRACE_COMPUTE_END`.
    ComputeEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_END,
    /// `TRITONSERVER_TRACE_REQUEST_END`.
    RequestEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_END,
    /// `TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT`.
    TensorQueueInput = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT,
    /// `TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT`.
    TensorBackendInput =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT,
    /// `TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT`.
    TensorBackendOutput =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT,
    /// `TRITONSERVER_TRACE_CUSTOM_ACTIVITY`.
    CustomActivity = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_CUSTOM_ACTIVITY,
}
125
/// Callback interface for timestamped trace events.
///
/// A handler passed to [`Trace::new_with_handle`] is invoked from the
/// activity callback registered with Triton. Implementors must be
/// `Send + Sync + 'static`.
pub trait TraceHandler: Send + Sync + 'static {
    /// Handles a single trace event.
    ///
    /// * `trace` - the trace the event was reported for.
    /// * `event` - the kind of activity that occurred.
    /// * `event_time` - the event timestamp, built from the nanosecond value
    ///   Triton passes to the callback.
    fn trace_activity(&self, trace: &Trace, event: Activity, event_time: Duration);
}
139
/// No-op handler: lets `()` fill the handler type parameter when timestamp
/// events are not wanted (see [`NOOP`]).
impl TraceHandler for () {
    // Deliberately ignores every event.
    fn trace_activity(&self, _trace: &Trace, _event: Activity, _event_time: Duration) {}
}
143
/// Callback interface for tensor trace events.
///
/// A handler passed to [`Trace::new_with_handle`] is invoked from the tensor
/// activity callback registered with Triton. Implementors must be
/// `Send + Sync + 'static`.
pub trait TensorTraceHandler: Send + Sync + 'static {
    /// Handles a single tensor trace event.
    ///
    /// * `trace` - the trace the event was reported for.
    /// * `event` - the kind of activity that occurred.
    /// * `tensor_data` - a non-owning [`Buffer`] over the tensor bytes handed
    ///   to the callback by Triton.
    /// * `tensor_shape` - the tensor's name, data type and dimensions.
    fn trace_tensor_activity(
        &self,
        trace: &Trace,
        event: Activity,
        tensor_data: &Buffer,
        tensor_shape: Shape,
    );
}
162
/// No-op handler: lets `()` fill the handler type parameter when tensor
/// events are not wanted (see [`NOOP`]).
impl TensorTraceHandler for () {
    // Deliberately ignores every event.
    fn trace_tensor_activity(
        &self,
        _trace: &Trace,
        _event: Activity,
        _tensor_data: &Buffer,
        _tensor_shape: Shape,
    ) {
    }
}
173
/// An absent handler typed as `Option<()>`.
///
/// `()` implements both [`TraceHandler`] and [`TensorTraceHandler`], so this
/// constant can be passed to [`Trace::new_with_handle`] for either handler
/// slot that should stay empty.
pub const NOOP: Option<()> = None;
176
/// Bundle of the optional handlers attached to a trace.
///
/// An `Arc` of this struct is turned into a raw pointer and given to Triton
/// as the user pointer of the trace callbacks.
struct TraceCallbackItems<H: TraceHandler, T: TensorTraceHandler> {
    /// Receiver of timestamp events, if any.
    activity_handler: Option<H>,
    /// Receiver of tensor events, if any.
    tensor_activity_handler: Option<T>,
}
181
/// Method-less marker trait used to type-erase `TraceCallbackItems<H, T>`,
/// so that [`Trace`] can hold the handlers without being generic itself.
trait DynamicTypeHelper: Send + Sync {}
// Every handler bundle can be stored behind `Arc<dyn DynamicTypeHelper>`.
impl<H: TraceHandler, T: TensorTraceHandler> DynamicTypeHelper for TraceCallbackItems<H, T> {}
188
/// Handle to a Triton inference trace together with the handlers attached to it.
pub struct Trace {
    /// Wrapper around the raw trace pointer; deletes the trace when dropped.
    pub(crate) ptr: TraceInner,
    /// Type-erased shared reference that keeps the handlers alive for as long
    /// as this `Trace` exists.
    handlers_copy: Arc<dyn DynamicTypeHelper>,
}
197
/// Wrapper around a raw `TRITONSERVER_InferenceTrace` pointer.
///
/// Its `Drop` implementation deletes a non-null trace through Triton.
pub(crate) struct TraceInner(pub(crate) *mut sys::TRITONSERVER_InferenceTrace);
// NOTE(review): these impls assert that the underlying Triton trace object may
// be moved to and used from any thread — confirm against Triton's guarantees.
unsafe impl Send for TraceInner {}
unsafe impl Sync for TraceInner {}
201
202impl PartialEq for Trace {
203 fn eq(&self, other: &Self) -> bool {
204 let left = match self.id() {
205 Ok(l) => l,
206 Err(err) => {
207 log::warn!("Error getting ID for two Traces comparison: {err}");
208 return false;
209 }
210 };
211 let right = match other.id() {
212 Ok(r) => r,
213 Err(err) => {
214 log::warn!("Error getting ID for two Traces comparison: {err}");
215 return false;
216 }
217 };
218 left == right
219 }
220}
221impl Eq for Trace {}
222
223impl Trace {
224 pub fn new_with_handle<H: TraceHandler, T: TensorTraceHandler>(
240 parent_id: u64,
241 activity_handler: Option<H>,
242 tensor_activity_handler: Option<T>,
243 ) -> Result<Self, Error> {
244 let enable_activity = activity_handler.is_some();
245 let enable_tensor_activity = tensor_activity_handler.is_some();
246
247 let level = match (enable_activity, enable_tensor_activity) {
248 (true, true) => Level::TENSORS | Level::TIMESTAMPS,
249 (true, false) => Level::TIMESTAMPS,
250 (false, true) => Level::TENSORS,
251 (false, false) => Level::DISABLED,
252 };
253
254 let mut ptr = null_mut::<sys::TRITONSERVER_InferenceTrace>();
255 let handlers = Arc::new(TraceCallbackItems {
256 activity_handler,
257 tensor_activity_handler,
258 });
259 let raw_handlers = Arc::into_raw(handlers.clone()) as *mut c_void;
260
261 triton_call!(sys::TRITONSERVER_InferenceTraceTensorNew(
262 &mut ptr as *mut _,
263 level.bits(),
264 parent_id,
265 enable_activity.then_some(activity_wraper::<H, T>),
266 enable_tensor_activity.then_some(tensor_activity_wrapper::<H, T>),
267 Some(delete::<H, T>),
268 raw_handlers,
269 ))?;
270
271 assert!(!ptr.is_null());
272 let trace = Trace {
273 ptr: TraceInner(ptr),
274 handlers_copy: handlers,
275 };
276 Ok(trace)
277 }
278
279 pub fn report_activity<N: AsRef<str>>(
284 &self,
285 timestamp: Duration,
286 activity_name: N,
287 ) -> Result<(), Error> {
288 let name = to_cstring(activity_name)?;
289 triton_call!(sys::TRITONSERVER_InferenceTraceReportActivity(
290 self.ptr.0,
291 timestamp.as_nanos() as _,
292 name.as_ptr()
293 ))
294 }
295
296 pub fn id(&self) -> Result<u64, Error> {
299 let mut id: u64 = 0;
300 triton_call!(
301 sys::TRITONSERVER_InferenceTraceId(self.ptr.0, &mut id as *mut _),
302 id
303 )
304 }
305
306 pub fn parent_id(&self) -> Result<u64, Error> {
310 let mut id: u64 = 0;
311 triton_call!(
312 sys::TRITONSERVER_InferenceTraceParentId(self.ptr.0, &mut id as *mut _),
313 id
314 )
315 }
316
317 pub fn model_name(&self) -> Result<String, Error> {
319 let mut name = null::<c_char>();
320 triton_call!(
321 sys::TRITONSERVER_InferenceTraceModelName(self.ptr.0, &mut name as *mut _),
322 from_char_array(name)
323 )
324 }
325
326 pub fn model_version(&self) -> Result<i64, Error> {
328 let mut version: i64 = 0;
329 triton_call!(
330 sys::TRITONSERVER_InferenceTraceModelVersion(self.ptr.0, &mut version as *mut _),
331 version
332 )
333 }
334
335 pub fn request_id(&self) -> Result<String, Error> {
338 let mut request_id = null::<c_char>();
339
340 triton_call!(
341 sys::TRITONSERVER_InferenceTraceRequestId(self.ptr.0, &mut request_id as *mut _),
342 from_char_array(request_id)
343 )
344 }
345
346 pub fn spawn_child(&self) -> Result<Trace, Error> {
351 let mut trace = null_mut();
352 triton_call!(
353 sys::TRITONSERVER_InferenceTraceSpawnChildTrace(self.ptr.0, &mut trace),
354 Trace {
355 ptr: TraceInner(trace),
356 handlers_copy: self.handlers_copy.clone(),
357 }
358 )
359 }
360
361 pub fn set_context(&mut self, context: String) -> Result<&mut Self, Error> {
363 let context = to_cstring(context)?;
364 triton_call!(
365 sys::TRITONSERVER_InferenceTraceSetContext(self.ptr.0, context.as_ptr()),
366 self
367 )
368 }
369
370 pub fn context(&self) -> Result<String, Error> {
372 let mut context = null::<c_char>();
373 triton_call!(
374 sys::TRITONSERVER_InferenceTraceContext(self.ptr.0, &mut context as *mut _),
375 from_char_array(context)
376 )
377 }
378}
379
380impl Drop for TraceInner {
381 fn drop(&mut self) {
382 if !self.0.is_null() {
383 unsafe {
384 sys::TRITONSERVER_InferenceTraceDelete(self.0);
385 }
386 }
387 }
388}
389
390unsafe extern "C" fn delete<H: TraceHandler, T: TensorTraceHandler>(
391 this: *mut sys::TRITONSERVER_InferenceTrace,
392 userp: *mut c_void,
393) {
394 if !userp.is_null() && !this.is_null() {
395 sys::TRITONSERVER_InferenceTraceDelete(this);
396 Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
397 }
398}
399
400unsafe extern "C" fn activity_wraper<H: TraceHandler, T: TensorTraceHandler>(
401 trace: *mut sys::TRITONSERVER_InferenceTrace,
402 activity: sys::TRITONSERVER_InferenceTraceActivity,
403 timestamp_ns: u64,
404 userp: *mut ::std::os::raw::c_void,
405) {
406 if !userp.is_null() {
407 let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
408 let foo_trace = Trace {
409 ptr: TraceInner(trace),
410 handlers_copy: handle.clone(),
411 };
412 let activity: Activity = transmute(activity);
413
414 let timestamp = Duration::from_nanos(timestamp_ns);
415
416 if let Some(activity_handle) = handle.activity_handler.as_ref() {
417 activity_handle.trace_activity(&foo_trace, activity, timestamp)
418 };
419
420 forget(handle);
422 forget(foo_trace.ptr);
423 }
424}
425
426unsafe extern "C" fn tensor_activity_wrapper<H: TraceHandler, T: TensorTraceHandler>(
427 trace: *mut sys::TRITONSERVER_InferenceTrace,
428 activity: sys::TRITONSERVER_InferenceTraceActivity,
429 name: *const ::std::os::raw::c_char,
430 datatype: sys::TRITONSERVER_DataType,
431 base: *const ::std::os::raw::c_void,
432 byte_size: usize,
433 shape: *const i64,
434 dim_count: u64,
435 memory_type: sys::TRITONSERVER_MemoryType,
436 _memory_type_id: i64,
437 userp: *mut ::std::os::raw::c_void,
438) {
439 if !userp.is_null() {
440 let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
441
442 let foo_trace = Trace {
443 ptr: TraceInner(trace),
444 handlers_copy: handle.clone(),
445 };
446 let activity: Activity = transmute(activity);
447
448 let data_type = unsafe { transmute::<u32, crate::memory::DataType>(datatype) };
449 let memory_type: MemoryType = unsafe { transmute(memory_type) };
450
451 let tensor_shape = Shape {
452 name: from_char_array(name),
453 datatype: data_type,
454 dims: slice::from_raw_parts(shape, dim_count as _).to_vec(),
455 };
456
457 let tensor_data = Buffer {
458 ptr: base as *mut _,
459 len: byte_size,
460 data_type,
461 memory_type,
462 owned: false,
463 };
464
465 if let Some(tensor_activity_handler) = handle.tensor_activity_handler.as_ref() {
466 tensor_activity_handler.trace_tensor_activity(
467 &foo_trace,
468 activity,
469 &tensor_data,
470 tensor_shape,
471 )
472 };
473
474 forget(handle);
475 forget(foo_trace.ptr);
476 }
478}