use std::cell::RefCell;
use std::collections::HashMap;
use crate::tensor::{DType, Tensor};
/// A single operation captured while tracing.
///
/// `inputs` and `outputs` refer to values by the names the tracer assigned to
/// them (parameter names, input-spec names, or generated names), not by tensor.
#[derive(Debug, Clone)]
pub struct TracedNode {
/// Operator type, e.g. `"Gemm"` or `"Conv"` as emitted by this module's helpers.
pub op_type: String,
/// Names of the values consumed by the operation, in operand order.
pub inputs: Vec<String>,
/// Names of the values produced by the operation.
pub outputs: Vec<String>,
/// Named attributes attached to the operation.
pub attributes: Vec<TracedAttribute>,
}
/// A named attribute attached to a [`TracedNode`].
///
/// In every variant the first field is the attribute name and the second is its value.
#[derive(Debug, Clone)]
pub enum TracedAttribute {
/// A single integer (e.g. `transB`).
Int(String, i64),
/// A single float (e.g. `alpha`, `beta`).
Float(String, f32),
/// A list of integers (e.g. `kernel_shape`, `strides`, `pads`, `dilations`).
Ints(String, Vec<i64>),
/// A list of floats.
Floats(String, Vec<f32>),
/// A string value.
String(String, String),
}
/// Static type information recorded for a named value: its shape and element type.
#[derive(Debug, Clone)]
pub struct ValueInfo {
/// Dimensions of the value, as reported by / declared for the tensor.
pub shape: Vec<usize>,
/// Element type of the value.
pub dtype: DType,
}
/// The result of [`trace`]: recorded nodes plus graph-level inputs, outputs and initializers.
pub struct TracedGraph {
/// Nodes in the order they were recorded.
pub nodes: Vec<TracedNode>,
/// Graph inputs, one per entry of the `input_specs` passed to [`trace`], in the same order.
pub inputs: Vec<(String, ValueInfo)>,
/// Graph outputs; [`trace`] always produces exactly one.
pub outputs: Vec<(String, ValueInfo)>,
/// Named constant tensors: the `state_dict` given to [`trace`] plus any added during tracing.
pub initializers: Vec<(String, Tensor)>,
}
/// Mutable recording state for one [`trace`] call; stored in the `TRACER` thread-local while active.
struct Tracer {
/// Recorded nodes, in recording order.
nodes: Vec<TracedNode>,
/// Value name per tensor storage, keyed by `tensor.storage.ptr_id()`.
/// NOTE(review): presumably an address-like id; if a storage is freed and its id is
/// reused during a trace, a new tensor could inherit a stale name — confirm.
value_names: HashMap<usize, String>,
/// Shape/dtype per value name. `trace` does not copy this into the returned `TracedGraph`.
value_info: HashMap<String, ValueInfo>,
/// Constant tensors by name; `TracerHandle::add_initializer` skips names already present.
initializers: Vec<(String, Tensor)>,
/// Counter used to generate `"{prefix}_{n}"` names.
next_id: usize,
/// Fusion nesting depth; nodes are not recorded while this is non-zero.
suppress_depth: usize,
}
impl Tracer {
fn new() -> Self {
Self {
nodes: Vec::new(),
value_names: HashMap::new(),
value_info: HashMap::new(),
initializers: Vec::new(),
next_id: 0,
suppress_depth: 0,
}
}
fn fresh_name(&mut self, prefix: &str) -> String {
let name = format!("{}_{}", prefix, self.next_id);
self.next_id += 1;
name
}
}
// Per-thread tracer slot. `trace` fills it right before running `forward_fn`
// and empties it right after; everywhere else it is `None`.
thread_local! {
static TRACER: RefCell<Option<Tracer>> = RefCell::new(None);
}
/// Reports whether a trace is currently active on this thread.
///
/// # Panics
/// Panics if called from inside a [`with_tracer`] callback, because the tracer
/// slot is mutably borrowed for the whole callback.
pub fn is_tracing() -> bool {
    TRACER.with(|slot| matches!(*slot.borrow(), Some(_)))
}
/// Reports whether node recording is currently suppressed by an open fusion region.
///
/// Returns `false` when no trace is active.
///
/// # Panics
/// Panics if called from inside a [`with_tracer`] callback, because the tracer
/// slot is mutably borrowed for the whole callback.
pub fn is_suppressed() -> bool {
    TRACER.with(|slot| match slot.borrow().as_ref() {
        Some(active) => active.suppress_depth > 0,
        None => false,
    })
}
/// Runs `f` with mutable access to this thread's active tracer and returns its result.
///
/// # Panics
/// - Panics with `"ONNX tracer not active"` when no trace is running.
/// - The tracer slot stays mutably borrowed while `f` runs. If `f` calls
///   `with_tracer`, [`is_tracing`] or [`is_suppressed`], that call panics with a
///   `RefCell` borrow error.
pub fn with_tracer<F, R>(f: F) -> R
where
    F: FnOnce(&mut TracerHandle<'_>) -> R,
{
    TRACER.with(|slot| {
        let mut active = slot.borrow_mut();
        let inner = match active.as_mut() {
            Some(tracer) => tracer,
            None => panic!("ONNX tracer not active"),
        };
        f(&mut TracerHandle { inner })
    })
}
/// Mutable access to the active tracer; only obtainable inside a [`with_tracer`] callback.
pub struct TracerHandle<'a> {
inner: &'a mut Tracer,
}
impl TracerHandle<'_> {
    /// Returns the name bound to `tensor`'s storage, generating one if needed.
    ///
    /// Lookup is by `tensor.storage.ptr_id()`, so tensors sharing a storage share
    /// a name. When no binding exists yet, a fresh `"val_{n}"` name is generated
    /// and the tensor's current shape and dtype are recorded under it.
    pub fn name_of(&mut self, tensor: &Tensor) -> String {
        let key = tensor.storage.ptr_id();
        match self.inner.value_names.get(&key).cloned() {
            Some(known) => known,
            None => {
                let generated = self.inner.fresh_name("val");
                self.bind_value(key, generated.clone(), tensor);
                generated
            }
        }
    }

    /// Binds `name` to `tensor`'s storage and records the tensor's shape and dtype
    /// under `name`.
    ///
    /// Any earlier binding for the same storage is replaced. Shape/dtype entries
    /// recorded under the old name are left in place.
    pub fn set_name(&mut self, tensor: &Tensor, name: String) {
        let key = tensor.storage.ptr_id();
        self.bind_value(key, name, tensor);
    }

    /// Shared bookkeeping for `name_of`/`set_name`: maps storage key → name and
    /// name → (shape, dtype).
    fn bind_value(&mut self, key: usize, name: String, tensor: &Tensor) {
        self.inner.value_names.insert(key, name.clone());
        let info = ValueInfo {
            shape: tensor.shape().to_vec(),
            dtype: tensor.dtype(),
        };
        self.inner.value_info.insert(name, info);
    }

    /// Appends `node` to the graph. Inside a fusion region the node is discarded instead.
    pub fn record_node(&mut self, node: TracedNode) {
        if self.inner.suppress_depth != 0 {
            return;
        }
        self.inner.nodes.push(node);
    }

    /// Registers a clone of `tensor` as an initializer called `name`.
    ///
    /// If an initializer with that name already exists, the existing one is kept
    /// and `tensor` is ignored. The duplicate check is a linear scan over all
    /// initializers.
    pub fn add_initializer(&mut self, name: &str, tensor: &Tensor) {
        let is_new = self
            .inner
            .initializers
            .iter()
            .all(|(existing, _)| existing != name);
        if is_new {
            self.inner.initializers.push((name.to_owned(), tensor.clone()));
        }
    }

    /// Opens a (nestable) fusion region. Nodes passed to `record_node` are dropped
    /// until every open region has been closed with `leave_fusion`.
    pub fn enter_fusion(&mut self) {
        self.inner.suppress_depth += 1;
    }

    /// Closes one fusion level. Calls made when no region is open are ignored.
    pub fn leave_fusion(&mut self) {
        if self.inner.suppress_depth > 0 {
            self.inner.suppress_depth -= 1;
        }
    }

    /// Generates a new `"{prefix}_{n}"` name from the tracer's shared counter.
    pub fn fresh_name(&mut self, prefix: &str) -> String {
        self.inner.fresh_name(prefix)
    }

    /// Records (or overwrites) the shape/dtype info stored for `name`.
    pub fn register_value(&mut self, name: &str, info: ValueInfo) {
        let key = name.to_string();
        self.inner.value_info.insert(key, info);
    }
}
pub fn trace<F>(
state_dict: &HashMap<String, Tensor>,
input_specs: &[(&str, Vec<usize>, DType)],
forward_fn: F,
) -> TracedGraph
where
F: FnOnce(&[Tensor]) -> Tensor,
{
let mut inputs = Vec::with_capacity(input_specs.len());
let mut input_infos = Vec::new();
let mut tracer = Tracer::new();
for (name, tensor) in state_dict {
let key = tensor.storage.ptr_id();
tracer.value_names.insert(key, name.clone());
tracer.value_info.insert(
name.clone(),
ValueInfo {
shape: tensor.shape().to_vec(),
dtype: tensor.dtype(),
},
);
tracer.initializers.push((name.clone(), tensor.clone()));
}
for (name, shape, dtype) in input_specs {
let numel: usize = shape.iter().product();
let tensor = Tensor::new(vec![0.0f32; numel], shape.clone());
let key = tensor.storage.ptr_id();
tracer.value_names.insert(key, name.to_string());
tracer.value_info.insert(
name.to_string(),
ValueInfo {
shape: shape.clone(),
dtype: *dtype,
},
);
input_infos.push((name.to_string(), ValueInfo { shape: shape.clone(), dtype: *dtype }));
inputs.push(tensor);
}
TRACER.with(|t| *t.borrow_mut() = Some(tracer));
let output = forward_fn(&inputs);
let tracer = TRACER.with(|t| t.borrow_mut().take()).expect("tracer lost");
let output_name = tracer
.value_names
.get(&output.storage.ptr_id())
.cloned()
.unwrap_or_else(|| "output_0".to_string());
let output_info = ValueInfo {
shape: output.shape().to_vec(),
dtype: output.dtype(),
};
TracedGraph {
nodes: tracer.nodes,
inputs: input_infos,
outputs: vec![(output_name, output_info)],
initializers: tracer.initializers,
}
}
/// Records a single-input, single-output op as a node of type `op_type`.
///
/// Does nothing unless a trace is active and no fusion region is open.
pub fn record_unary(input: &Tensor, output: &Tensor, op_type: &str) {
    if is_tracing() && !is_suppressed() {
        with_tracer(|handle| {
            // Name the input before the output so fresh ids follow data-flow order.
            let source = handle.name_of(input);
            let sink = handle.name_of(output);
            let node = TracedNode {
                op_type: op_type.to_string(),
                inputs: vec![source],
                outputs: vec![sink],
                attributes: Vec::new(),
            };
            handle.record_node(node);
        });
    }
}
/// Records a two-input, single-output op as a node of type `op_type`.
///
/// The node's inputs are `[lhs, rhs]` in that order. Does nothing unless a trace
/// is active and no fusion region is open.
pub fn record_binary(lhs: &Tensor, rhs: &Tensor, output: &Tensor, op_type: &str) {
    if is_tracing() && !is_suppressed() {
        with_tracer(|handle| {
            // Operands are named left-to-right, then the result, so fresh ids follow data flow.
            let operands: Vec<String> = [lhs, rhs]
                .iter()
                .map(|operand| handle.name_of(operand))
                .collect();
            let result = handle.name_of(output);
            handle.record_node(TracedNode {
                op_type: op_type.to_string(),
                inputs: operands,
                outputs: vec![result],
                attributes: Vec::new(),
            });
        });
    }
}
/// Records a fully-connected layer as a `Gemm` node.
///
/// The node's inputs are, in order:
/// 1. `input`;
/// 2. the weight, registered as an initializer named `weight_name`;
/// 3. the bias, registered as `bias_name`. This is included only when both `bias`
///    and `bias_name` are `Some`; if just one is given, the bias is silently omitted.
///
/// Attributes: `alpha = 1.0`, `beta = 1.0`, `transB = 0`.
///
/// Does nothing when no trace is active. Unlike [`record_unary`] and
/// [`record_binary`], this does not check [`is_suppressed`] first. Inside a fusion
/// region the names and initializers are therefore still registered, and only the
/// node itself is dropped (by `TracerHandle::record_node`).
///
/// NOTE(review): with `transB = 0`, ONNX `Gemm` uses the weight untransposed,
/// i.e. laid out as `(in_features, out_features)`. Confirm that this matches how
/// the weight is stored.
pub fn trace_linear(
    input: &Tensor,
    weight: &Tensor,
    bias: Option<&Tensor>,
    output: &Tensor,
    weight_name: &str,
    bias_name: Option<&str>,
) {
    if is_tracing() {
        with_tracer(|handle| {
            let source = handle.name_of(input);
            let sink = handle.name_of(output);

            handle.add_initializer(weight_name, weight);
            let mut operands = vec![source, weight_name.to_string()];
            if let (Some(bias_tensor), Some(bias_id)) = (bias, bias_name) {
                handle.add_initializer(bias_id, bias_tensor);
                operands.push(bias_id.to_string());
            }

            let attributes = vec![
                TracedAttribute::Float("alpha".to_string(), 1.0),
                TracedAttribute::Float("beta".to_string(), 1.0),
                TracedAttribute::Int("transB".to_string(), 0),
            ];

            handle.record_node(TracedNode {
                op_type: "Gemm".to_string(),
                inputs: operands,
                outputs: vec![sink],
                attributes,
            });
        });
    }
}
/// Records a 2-D convolution with a square kernel, stride and padding as a `Conv` node.
///
/// The weight is registered as an initializer named `weight_name`. Its value info
/// is (re)declared as `[out_channels, in_channels, kernel_size, kernel_size]` with
/// the weight's dtype; this declared shape does not depend on the shape `weight`
/// itself reports. The bias is added as a third input only when both `bias` and
/// `bias_name` are `Some`; otherwise it is silently omitted.
///
/// Attributes: `kernel_shape = [k, k]`, `strides = [s, s]`,
/// `pads = [p, p, p, p]`, `dilations = [1, 1]`.
///
/// Does nothing when no trace is active. This function does not check
/// [`is_suppressed`] itself. Inside a fusion region the names and initializers are
/// therefore still registered, and only the node is dropped (by
/// `TracerHandle::record_node`).
pub fn trace_conv2d(
    input: &Tensor,
    output: &Tensor,
    weight_name: &str,
    weight: &Tensor,
    bias_name: Option<&str>,
    bias: Option<&Tensor>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    in_channels: usize,
    out_channels: usize,
) {
    if is_tracing() {
        with_tracer(|handle| {
            let source = handle.name_of(input);
            let sink = handle.name_of(output);

            handle.add_initializer(weight_name, weight);
            let declared_weight = ValueInfo {
                shape: vec![out_channels, in_channels, kernel_size, kernel_size],
                dtype: weight.dtype(),
            };
            handle.register_value(weight_name, declared_weight);

            let mut operands = vec![source, weight_name.to_string()];
            if let (Some(bias_tensor), Some(bias_id)) = (bias, bias_name) {
                handle.add_initializer(bias_id, bias_tensor);
                operands.push(bias_id.to_string());
            }

            // Square settings: each scalar is repeated once per spatial axis
            // (pads needs begin + end values for both axes, hence 4).
            let per_axis = |value: usize, count: usize| vec![value as i64; count];
            let attributes = vec![
                TracedAttribute::Ints("kernel_shape".to_string(), per_axis(kernel_size, 2)),
                TracedAttribute::Ints("strides".to_string(), per_axis(stride, 2)),
                TracedAttribute::Ints("pads".to_string(), per_axis(padding, 4)),
                TracedAttribute::Ints("dilations".to_string(), vec![1, 1]),
            ];

            handle.record_node(TracedNode {
                op_type: "Conv".to_string(),
                inputs: operands,
                outputs: vec![sink],
                attributes,
            });
        });
    }
}