use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
use crate::dtype::DType;
use crate::runtime_config::RuntimeConfig;
/// Identifier assigned to each recorded operation trace.
pub type TraceId = u64;
/// Process-wide tracer state: created lazily by `OpTracer::global` and looked
/// up (never created) by `TraceBuilder` when recording into a trace.
static OP_TRACER: OnceLock<Arc<Mutex<OpTracerInternal>>> = OnceLock::new();
/// Settings that control what the operation tracer records.
#[derive(Debug, Clone)]
pub struct TraceConfig {
    /// Master switch; no trace is started while this is `false`.
    pub enabled: bool,
    /// Maximum number of retained traces; beyond this the oldest are evicted.
    pub max_traces: usize,
    /// Request capture of tensor values.
    /// NOTE(review): not consulted by the tracer code in this file.
    pub capture_values: bool,
    /// Request capture of output tensors.
    /// NOTE(review): not consulted by the tracer code in this file.
    pub capture_outputs: bool,
    /// Request capture of stack traces.
    /// NOTE(review): not consulted by the tracer code in this file.
    pub capture_stack_trace: bool,
    /// Substring patterns; when non-empty, only operations whose name contains
    /// at least one of them are traced.
    pub operation_filters: Vec<String>,
    /// Maximum nesting depth that is traced; `0` means unlimited.
    pub max_depth: usize,
    /// Request stopping on errors.
    /// NOTE(review): not consulted by the tracer code in this file.
    pub break_on_error: bool,
}

impl Default for TraceConfig {
    /// Tracing off, 10 000 retained traces, no filters, unlimited depth,
    /// no extra capture, `break_on_error` set.
    fn default() -> Self {
        TraceConfig {
            enabled: false,
            capture_values: false,
            capture_outputs: false,
            capture_stack_trace: false,
            break_on_error: true,
            max_traces: 10_000,
            max_depth: 0,
            operation_filters: Vec::new(),
        }
    }
}
/// Lightweight description of a tensor taking part in a traced operation.
#[derive(Debug, Clone)]
pub struct TensorMetadata {
    /// Label for the tensor, e.g. `"lhs"` or `"result"`.
    pub name: String,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Element type, if one was supplied via [`TensorMetadata::with_dtype`].
    pub dtype: Option<DType>,
    /// Element count: the product of `shape` (1 for an empty shape).
    pub numel: usize,
    /// `numel * dtype.size()` once a dtype is set; 0 until then.
    pub size_bytes: usize,
    /// Contiguity flag; `true` unless overridden.
    pub is_contiguous: bool,
    /// Element values, only when explicitly supplied by the caller.
    pub values: Option<Vec<f64>>,
}

impl TensorMetadata {
    /// Describes a tensor called `name` with the given `shape`.
    ///
    /// The dtype is left unset (so `size_bytes` is 0), the tensor is marked
    /// contiguous and no values are attached.
    pub fn new(name: impl Into<String>, shape: Vec<usize>) -> Self {
        let element_count: usize = shape.iter().product();
        Self {
            name: name.into(),
            dtype: None,
            numel: element_count,
            size_bytes: 0,
            is_contiguous: true,
            values: None,
            shape,
        }
    }

    /// Sets the element type and derives `size_bytes` as `numel * dtype.size()`.
    pub fn with_dtype(self, dtype: DType) -> Self {
        let size_bytes = self.numel * dtype.size();
        Self {
            dtype: Some(dtype),
            size_bytes,
            ..self
        }
    }

    /// Overrides the contiguity flag.
    pub fn with_contiguous(self, is_contiguous: bool) -> Self {
        Self {
            is_contiguous,
            ..self
        }
    }

    /// Attaches element values.
    pub fn with_values(self, values: Vec<f64>) -> Self {
        Self {
            values: Some(values),
            ..self
        }
    }
}
/// Record of one traced operation: what ran, on which tensors, how long it
/// took, where it sits in the nesting, and whether it failed.
#[derive(Debug, Clone)]
pub struct OperationTrace {
    /// Id assigned by the tracer.
    pub id: TraceId,
    /// Id of the enclosing trace, or `None` for a top-level trace.
    pub parent_id: Option<TraceId>,
    /// Operation name given when the trace was started.
    pub operation: String,
    /// Optional free-form category; also checked by `matches_filter`.
    pub category: Option<String>,
    /// Tensors recorded as inputs, in recording order.
    pub inputs: Vec<TensorMetadata>,
    /// Tensors recorded as outputs, in recording order.
    pub outputs: Vec<TensorMetadata>,
    /// Moment the trace record was created.
    pub start_time: Instant,
    /// Time elapsed since `start_time`; set on completion or error.
    pub duration: Option<Duration>,
    /// Nesting depth at which the trace was started.
    pub depth: usize,
    /// Optional stack trace text.
    pub stack_trace: Option<String>,
    /// Arbitrary key/value annotations.
    pub metadata: HashMap<String, String>,
    /// Whether the trace was marked as failed.
    pub had_error: bool,
    /// Message given when the trace was marked as failed.
    pub error_message: Option<String>,
}

impl OperationTrace {
    /// Creates an open (not yet completed) trace; the start time is taken now.
    fn new(id: TraceId, parent_id: Option<TraceId>, operation: String, depth: usize) -> Self {
        let start_time = Instant::now();
        Self {
            id,
            parent_id,
            operation,
            depth,
            start_time,
            category: None,
            inputs: Vec::new(),
            outputs: Vec::new(),
            duration: None,
            stack_trace: None,
            metadata: HashMap::new(),
            had_error: false,
            error_message: None,
        }
    }

    /// Sets (or replaces) the category.
    pub fn set_category(&mut self, category: impl Into<String>) {
        let category: String = category.into();
        self.category = Some(category);
    }

    /// Appends an input tensor description.
    pub fn add_input(&mut self, input: TensorMetadata) {
        let inputs = &mut self.inputs;
        inputs.push(input);
    }

    /// Appends an output tensor description.
    pub fn add_output(&mut self, output: TensorMetadata) {
        let outputs = &mut self.outputs;
        outputs.push(output);
    }

    /// Inserts an annotation, replacing any previous value for the same key.
    pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let (key, value): (String, String) = (key.into(), value.into());
        self.metadata.insert(key, value);
    }

    /// Records the elapsed time since `start_time` as the duration.
    pub fn complete(&mut self) {
        let elapsed = self.start_time.elapsed();
        self.duration = Some(elapsed);
    }

    /// Flags the trace as failed with `error` and records its duration.
    pub fn mark_error(&mut self, error: impl Into<String>) {
        self.had_error = true;
        let message: String = error.into();
        self.error_message = Some(message);
        self.complete();
    }

    /// True when `filter` is a substring of the operation name or category.
    fn matches_filter(&self, filter: &str) -> bool {
        if self.operation.contains(filter) {
            return true;
        }
        matches!(&self.category, Some(category) if category.contains(filter))
    }
}
pub struct TraceBuilder {
trace_id: TraceId,
}
impl TraceBuilder {
fn new(trace_id: TraceId) -> Self {
Self { trace_id }
}
pub fn record_input(&self, name: impl Into<String>, shape: Vec<usize>) {
let metadata = TensorMetadata::new(name, shape);
if let Some(tracer) = OP_TRACER.get() {
if let Ok(mut tracer) = tracer.lock() {
if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
trace.add_input(metadata);
}
}
}
}
pub fn record_input_with_dtype(
&self,
name: impl Into<String>,
shape: Vec<usize>,
dtype: DType,
) {
let metadata = TensorMetadata::new(name, shape).with_dtype(dtype);
if let Some(tracer) = OP_TRACER.get() {
if let Ok(mut tracer) = tracer.lock() {
if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
trace.add_input(metadata);
}
}
}
}
pub fn record_output(&self, name: impl Into<String>, shape: Vec<usize>) {
let metadata = TensorMetadata::new(name, shape);
if let Some(tracer) = OP_TRACER.get() {
if let Ok(mut tracer) = tracer.lock() {
if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
trace.add_output(metadata);
}
}
}
}
pub fn record_output_with_dtype(
&self,
name: impl Into<String>,
shape: Vec<usize>,
dtype: DType,
) {
let metadata = TensorMetadata::new(name, shape).with_dtype(dtype);
if let Some(tracer) = OP_TRACER.get() {
if let Ok(mut tracer) = tracer.lock() {
if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
trace.add_output(metadata);
}
}
}
}
pub fn add_metadata(&self, key: impl Into<String>, value: impl Into<String>) {
if let Some(tracer) = OP_TRACER.get() {
if let Ok(mut tracer) = tracer.lock() {
if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
trace.add_metadata(key, value);
}
}
}
}
pub fn set_category(&self, category: impl Into<String>) {
if let Some(tracer) = OP_TRACER.get() {
if let Ok(mut tracer) = tracer.lock() {
if let Some(trace) = tracer.traces.get_mut(&self.trace_id) {
trace.set_category(category);
}
}
}
}
}
/// Mutable tracer state, kept behind the mutex inside `OpTracer`.
struct OpTracerInternal {
    /// Active configuration (enable flag, filters, limits).
    config: TraceConfig,
    /// Retained traces, keyed by id.
    traces: HashMap<TraceId, OperationTrace>,
    /// Retained trace ids, oldest first; drives ordered listing and eviction.
    trace_order: VecDeque<TraceId>,
    /// Id for the next trace. Starts at 1, so 0 never names a real trace.
    next_id: TraceId,
    /// Depth assigned to the next started trace; kept equal to `depth_stack.len()`.
    current_depth: usize,
    /// Ids of started-but-not-yet-closed traces, innermost last.
    depth_stack: Vec<TraceId>,
    /// Breakpoint flags keyed by operation name.
    breakpoints: HashMap<String, bool>,
}

impl OpTracerInternal {
    /// Creates empty state with the default (disabled) configuration.
    fn new() -> Self {
        Self {
            config: TraceConfig::default(),
            traces: HashMap::new(),
            trace_order: VecDeque::new(),
            next_id: 1,
            current_depth: 0,
            depth_stack: Vec::new(),
            breakpoints: HashMap::new(),
        }
    }

    /// Decides whether a new trace for `operation` should be recorded.
    ///
    /// Tracing must be enabled. When `max_depth > 0`, the current nesting must
    /// be shallower than `max_depth`. When filters are configured, `operation`
    /// must contain at least one of them as a substring.
    fn should_trace(&self, operation: &str) -> bool {
        if !self.config.enabled {
            return false;
        }
        if self.config.max_depth > 0 && self.current_depth >= self.config.max_depth {
            return false;
        }
        if !self.config.operation_filters.is_empty() {
            return self
                .config
                .operation_filters
                .iter()
                .any(|f| operation.contains(f));
        }
        true
    }

    /// Opens a trace for `operation` as a child of the innermost open trace.
    ///
    /// Returns `None` when `should_trace` rejects the operation. Once more than
    /// `max_traces` traces are retained, the oldest are evicted.
    /// NOTE(review): eviction does not spare still-open traces, and with
    /// `max_traces == 0` the new trace is evicted immediately.
    fn start_trace(&mut self, operation: String) -> Option<TraceId> {
        if !self.should_trace(&operation) {
            return None;
        }
        let trace_id = self.next_id;
        self.next_id += 1;
        let parent_id = self.depth_stack.last().copied();
        let trace = OperationTrace::new(trace_id, parent_id, operation, self.current_depth);
        self.traces.insert(trace_id, trace);
        self.trace_order.push_back(trace_id);
        self.depth_stack.push(trace_id);
        self.current_depth += 1;
        while self.trace_order.len() > self.config.max_traces {
            if let Some(old_id) = self.trace_order.pop_front() {
                self.traces.remove(&old_id);
            }
        }
        Some(trace_id)
    }

    /// Records the duration of `trace_id` and closes it in the nesting stack.
    fn complete_trace(&mut self, trace_id: TraceId) {
        if let Some(trace) = self.traces.get_mut(&trace_id) {
            trace.complete();
        }
        self.pop_depth(trace_id);
    }

    /// Marks `trace_id` as failed with `error` and closes it in the nesting stack.
    fn mark_error(&mut self, trace_id: TraceId, error: String) {
        if let Some(trace) = self.traces.get_mut(&trace_id) {
            trace.mark_error(error);
        }
        // A failed trace must leave the nesting stack too. Otherwise every
        // later trace gets the wrong parent/depth, and with `max_depth` set,
        // tracing eventually stops altogether.
        self.pop_depth(trace_id);
    }

    /// Removes `trace_id`, and anything still open above it, from the nesting stack.
    ///
    /// In the normal case `trace_id` is the innermost entry. If inner traces
    /// were never closed (e.g. their closure panicked), they are discarded
    /// together with `trace_id`, so the stack cannot stay wedged. Unknown ids
    /// (already closed) leave the stack untouched.
    fn pop_depth(&mut self, trace_id: TraceId) {
        if let Some(pos) = self.depth_stack.iter().rposition(|&id| id == trace_id) {
            self.depth_stack.truncate(pos);
            self.current_depth = self.depth_stack.len();
        }
    }
}
pub struct OpTracer {
inner: Arc<Mutex<OpTracerInternal>>,
}
impl OpTracer {
pub fn global() -> Self {
let inner = OP_TRACER
.get_or_init(|| Arc::new(Mutex::new(OpTracerInternal::new())))
.clone();
Self { inner }
}
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(OpTracerInternal::new())),
}
}
pub fn set_enabled(&self, enabled: bool) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.config.enabled = enabled;
}
}
pub fn is_enabled(&self) -> bool {
self.inner.lock().map_or(false, |t| t.config.enabled)
}
pub fn set_config(&self, config: TraceConfig) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.config = config;
}
}
pub fn get_config(&self) -> TraceConfig {
self.inner
.lock()
.map_or(TraceConfig::default(), |t| t.config.clone())
}
pub fn add_filter(&self, pattern: impl Into<String>) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.config.operation_filters.push(pattern.into());
}
}
pub fn clear_filters(&self) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.config.operation_filters.clear();
}
}
pub fn set_breakpoint(&self, operation: impl Into<String>) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.breakpoints.insert(operation.into(), true);
}
}
pub fn remove_breakpoint(&self, operation: &str) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.breakpoints.remove(operation);
}
}
pub fn has_breakpoint(&self, operation: &str) -> bool {
self.inner.lock().map_or(false, |t| {
t.breakpoints.get(operation).copied().unwrap_or(false)
})
}
pub fn get_trace(&self, trace_id: TraceId) -> Option<OperationTrace> {
self.inner.lock().ok()?.traces.get(&trace_id).cloned()
}
pub fn get_all_traces(&self) -> Vec<OperationTrace> {
self.inner.lock().map_or(Vec::new(), |t| {
t.trace_order
.iter()
.filter_map(|id| t.traces.get(id).cloned())
.collect()
})
}
pub fn get_filtered_traces(&self, filter: &str) -> Vec<OperationTrace> {
self.inner.lock().map_or(Vec::new(), |t| {
t.trace_order
.iter()
.filter_map(|id| t.traces.get(id))
.filter(|trace| trace.matches_filter(filter))
.cloned()
.collect()
})
}
pub fn clear_traces(&self) {
if let Ok(mut tracer) = self.inner.lock() {
tracer.traces.clear();
tracer.trace_order.clear();
}
}
pub fn get_statistics(&self) -> TraceStatistics {
let tracer = match self.inner.lock() {
Ok(t) => t,
Err(_) => return TraceStatistics::default(),
};
let total_traces = tracer.traces.len();
let total_errors = tracer.traces.values().filter(|t| t.had_error).count();
let total_duration: Duration = tracer.traces.values().filter_map(|t| t.duration).sum();
let operations_by_type: HashMap<String, usize> =
tracer
.traces
.values()
.fold(HashMap::new(), |mut acc, trace| {
*acc.entry(trace.operation.clone()).or_insert(0) += 1;
acc
});
TraceStatistics {
total_traces,
total_errors,
total_duration,
operations_by_type,
}
}
}
impl Default for OpTracer {
fn default() -> Self {
Self::global()
}
}
/// Aggregate figures over the traces a tracer currently retains.
#[derive(Clone, Debug, Default)]
pub struct TraceStatistics {
    /// Number of retained traces.
    pub total_traces: usize,
    /// How many of them were marked as failed.
    pub total_errors: usize,
    /// Sum of the recorded durations. Nested traces are summed independently,
    /// so time spent in a child may also be counted inside its parent.
    pub total_duration: Duration,
    /// Number of retained traces per operation name.
    pub operations_by_type: HashMap<String, usize>,
}
pub fn trace_operation<F>(operation: impl Into<String>, f: F) -> Option<TraceId>
where
F: FnOnce(&TraceBuilder),
{
let operation = operation.into();
let tracer = OpTracer::global();
let runtime_config = RuntimeConfig::global();
if !runtime_config.should_collect_metrics(&operation) {
return None;
}
let trace_id = {
let mut inner = tracer.inner.lock().ok()?;
inner.start_trace(operation.clone())?
};
let builder = TraceBuilder::new(trace_id);
f(&builder);
{
let mut inner = tracer.inner.lock().ok()?;
inner.complete_trace(trace_id);
}
Some(trace_id)
}
pub fn trace_operation_result<F, T, E>(operation: impl Into<String>, f: F) -> Result<T, E>
where
F: FnOnce(&TraceBuilder) -> Result<T, E>,
E: std::fmt::Display,
{
let operation = operation.into();
let tracer = OpTracer::global();
let trace_id = {
let mut inner = tracer.inner.lock().ok().ok_or_else(|| {
panic!("Failed to acquire tracer lock")
})?;
inner.start_trace(operation.clone())
};
let builder = trace_id.map(TraceBuilder::new);
let result = match builder.as_ref() {
Some(b) => f(b),
None => f(&TraceBuilder::new(0)), };
if let Some(tid) = trace_id {
let mut inner = tracer
.inner
.lock()
.ok()
.ok_or_else(|| panic!("Failed to acquire tracer lock"))?;
match &result {
Ok(_) => inner.complete_trace(tid),
Err(e) => inner.mark_error(tid, e.to_string()),
}
}
result
}
// Unit tests for the tracer. Each test drives a fresh `OpTracer::new()`
// instance rather than the global tracer, so tests do not share state.
// Every `inner.lock()` guard lives in its own `{ ... }` scope, or ends with its
// loop iteration. This matters because `OpTracer` accessors such as `get_trace`
// lock the same std `Mutex` again, and re-locking it on the thread that holds
// it does not return (it may deadlock or panic).
#[cfg(test)]
mod tests {
use super::*;
// A fresh tracer starts disabled and follows `set_enabled`.
#[test]
fn test_tracer_enable_disable() {
let tracer = OpTracer::new();
assert!(!tracer.is_enabled());
tracer.set_enabled(true);
assert!(tracer.is_enabled());
tracer.set_enabled(false);
assert!(!tracer.is_enabled());
}
// Start/complete round trip: the trace is retrievable while open, and gets a
// duration once completed.
#[test]
fn test_trace_operation() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
let trace_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace("test_op".to_string())
.expect("start_trace should succeed")
};
assert!(tracer.get_trace(trace_id).is_some());
{
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.complete_trace(trace_id);
}
let trace = tracer.get_trace(trace_id).expect("trace should exist");
assert_eq!(trace.operation, "test_op");
assert!(trace.duration.is_some());
}
// Inputs and outputs attached to a trace are kept in recording order.
#[test]
fn test_trace_with_inputs_outputs() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
let trace_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace("matmul".to_string())
.expect("start_trace should succeed")
};
{
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
if let Some(trace) = inner.traces.get_mut(&trace_id) {
trace.add_input(TensorMetadata::new("lhs", vec![10, 20]).with_dtype(DType::F32));
trace.add_input(TensorMetadata::new("rhs", vec![20, 30]).with_dtype(DType::F32));
trace
.add_output(TensorMetadata::new("result", vec![10, 30]).with_dtype(DType::F32));
}
}
{
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.complete_trace(trace_id);
}
let trace = tracer.get_trace(trace_id).expect("trace should exist");
assert_eq!(trace.inputs.len(), 2);
assert_eq!(trace.outputs.len(), 1);
assert_eq!(trace.inputs[0].shape, vec![10, 20]);
assert_eq!(trace.outputs[0].shape, vec![10, 30]);
}
// With a filter configured, only operations containing the pattern are traced.
#[test]
fn test_trace_filtering() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
tracer.add_filter("matmul");
let trace_id1 = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.start_trace("matmul".to_string())
};
assert!(trace_id1.is_some());
let trace_id2 = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.start_trace("add".to_string())
};
assert!(trace_id2.is_none());
}
// A trace started while another is open becomes its child, one level deeper.
#[test]
fn test_trace_hierarchy() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
let parent_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace("parent_op".to_string())
.expect("start_trace should succeed")
};
let child_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace("child_op".to_string())
.expect("start_trace should succeed")
};
{
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.complete_trace(child_id);
inner.complete_trace(parent_id);
}
let parent_trace = tracer
.get_trace(parent_id)
.expect("parent trace should exist");
let child_trace = tracer
.get_trace(child_id)
.expect("child trace should exist");
assert_eq!(parent_trace.depth, 0);
assert_eq!(child_trace.depth, 1);
assert_eq!(child_trace.parent_id, Some(parent_id));
}
// Breakpoints can be set, queried and removed by operation name.
#[test]
fn test_breakpoints() {
let tracer = OpTracer::new();
tracer.set_breakpoint("critical_op");
assert!(tracer.has_breakpoint("critical_op"));
tracer.remove_breakpoint("critical_op");
assert!(!tracer.has_breakpoint("critical_op"));
}
// Statistics count every retained trace; none of these failed.
#[test]
fn test_trace_statistics() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
for i in 0..5 {
let trace_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace(format!("op_{}", i))
.expect("start_trace should succeed")
};
// This guard is released at the end of each loop iteration.
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.complete_trace(trace_id);
}
let stats = tracer.get_statistics();
assert_eq!(stats.total_traces, 5);
assert_eq!(stats.total_errors, 0);
}
// `mark_error` flags the trace, stores the message and is counted in statistics.
#[test]
fn test_error_tracing() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
let trace_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace("failing_op".to_string())
.expect("start_trace should succeed")
};
{
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.mark_error(trace_id, "Test error".to_string());
}
let trace = tracer.get_trace(trace_id).expect("trace should exist");
assert!(trace.had_error);
assert_eq!(trace.error_message, Some("Test error".to_string()));
let stats = tracer.get_statistics();
assert_eq!(stats.total_errors, 1);
}
// Only the newest `max_traces` traces are retained; older ones are evicted.
#[test]
fn test_max_traces_limit() {
let tracer = OpTracer::new();
let mut config = TraceConfig::default();
config.enabled = true;
config.max_traces = 5;
tracer.set_config(config);
for i in 0..10 {
let trace_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace(format!("op_{}", i))
.expect("start_trace should succeed")
};
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner.complete_trace(trace_id);
}
let all_traces = tracer.get_all_traces();
// 10 traces were started with `max_traces = 5`, so only 5 remain.
assert_eq!(all_traces.len(), 5); }
// `clear_traces` drops retained traces, including ones that are still open.
#[test]
fn test_clear_traces() {
let tracer = OpTracer::new();
tracer.set_enabled(true);
let trace_id = {
let mut inner = tracer.inner.lock().expect("lock should not be poisoned");
inner
.start_trace("test_op".to_string())
.expect("start_trace should succeed")
};
assert!(tracer.get_trace(trace_id).is_some());
tracer.clear_traces();
assert!(tracer.get_trace(trace_id).is_none());
}
}