use crate::{Edge, FxGraph, Node, TorshResult};
use petgraph::graph::NodeIndex;
use std::collections::HashMap;
pub struct ModuleTracer {
graph: FxGraph,
node_map: HashMap<String, NodeIndex>,
next_node_id: usize,
}
impl ModuleTracer {
pub fn new() -> Self {
Self {
graph: FxGraph::new(),
node_map: HashMap::new(),
next_node_id: 0,
}
}
fn next_node_name(&mut self) -> String {
let node_id = self.next_node_id;
let name = format!("node_{node_id}");
self.next_node_id += 1;
name
}
pub fn add_input(&mut self, name: &str) -> NodeIndex {
let node_idx = self.graph.graph.add_node(Node::Input(name.to_string()));
self.graph.inputs.push(node_idx);
self.node_map.insert(name.to_string(), node_idx);
node_idx
}
pub fn add_call(&mut self, op_name: &str, args: Vec<String>) -> NodeIndex {
let node_name = self.next_node_name();
let node_idx = self
.graph
.graph
.add_node(Node::Call(op_name.to_string(), args.clone()));
for arg in &args {
if let Some(&arg_idx) = self.node_map.get(arg) {
self.graph
.graph
.add_edge(arg_idx, node_idx, Edge { name: arg.clone() });
}
}
self.node_map.insert(node_name.clone(), node_idx);
node_idx
}
pub fn add_output(&mut self, input_name: &str) -> NodeIndex {
let node_idx = self.graph.graph.add_node(Node::Output);
if let Some(&input_idx) = self.node_map.get(input_name) {
self.graph.graph.add_edge(
input_idx,
node_idx,
Edge {
name: input_name.to_string(),
},
);
}
self.graph.outputs.push(node_idx);
node_idx
}
pub fn add_conditional(
&mut self,
condition: &str,
then_branch: Vec<String>,
else_branch: Vec<String>,
) -> NodeIndex {
let node_name = self.next_node_name();
let node_idx = self.graph.graph.add_node(Node::Conditional {
condition: condition.to_string(),
then_branch: then_branch.clone(),
else_branch: else_branch.clone(),
});
if let Some(&cond_idx) = self.node_map.get(condition) {
self.graph.graph.add_edge(
cond_idx,
node_idx,
Edge {
name: condition.to_string(),
},
);
}
for input in then_branch.iter().chain(else_branch.iter()) {
if let Some(&input_idx) = self.node_map.get(input) {
self.graph.graph.add_edge(
input_idx,
node_idx,
Edge {
name: input.clone(),
},
);
}
}
self.node_map.insert(node_name, node_idx);
node_idx
}
pub fn add_loop(
&mut self,
condition: &str,
body: Vec<String>,
loop_vars: Vec<String>,
) -> NodeIndex {
let node_name = self.next_node_name();
let node_idx = self.graph.graph.add_node(Node::Loop {
condition: condition.to_string(),
body: body.clone(),
loop_vars: loop_vars.clone(),
});
if let Some(&cond_idx) = self.node_map.get(condition) {
self.graph.graph.add_edge(
cond_idx,
node_idx,
Edge {
name: condition.to_string(),
},
);
}
for input in body.iter().chain(loop_vars.iter()) {
if let Some(&input_idx) = self.node_map.get(input) {
self.graph.graph.add_edge(
input_idx,
node_idx,
Edge {
name: input.clone(),
},
);
}
}
self.node_map.insert(node_name, node_idx);
node_idx
}
pub fn add_merge(&mut self, inputs: Vec<String>) -> NodeIndex {
let node_name = self.next_node_name();
let node_idx = self.graph.graph.add_node(Node::Merge {
inputs: inputs.clone(),
});
for input in &inputs {
if let Some(&input_idx) = self.node_map.get(input) {
self.graph.graph.add_edge(
input_idx,
node_idx,
Edge {
name: input.clone(),
},
);
}
}
self.node_map.insert(node_name, node_idx);
node_idx
}
pub fn add_get_attr(&mut self, target: &str, attr: &str) -> NodeIndex {
let node_name = self.next_node_name();
let node_idx = self.graph.graph.add_node(Node::GetAttr {
target: target.to_string(),
attr: attr.to_string(),
});
if let Some(&target_idx) = self.node_map.get(target) {
self.graph.graph.add_edge(
target_idx,
node_idx,
Edge {
name: target.to_string(),
},
);
}
self.node_map.insert(node_name, node_idx);
node_idx
}
pub fn finalize(self) -> FxGraph {
self.graph
}
}
/// Interface for types that can be traced.
///
/// Every `Module` automatically implements [`TracingProxy`] through the blanket impl in
/// this file.
pub trait Module {
    /// Maps a list of input strings to a list of output strings, or fails with an error.
    ///
    /// NOTE(review): the strings are presumably tensor/node names — confirm with implementors.
    fn forward(&self, inputs: &[String]) -> TorshResult<Vec<String>>;
}
/// Produces a graph for `_module`.
///
/// The module argument is currently unused: the result is always the same three-node graph,
/// an input named `"input"` feeding a `"linear"` call that in turn feeds the single output.
pub fn trace(_module: &dyn Module) -> TorshResult<FxGraph> {
    let mut builder = ModuleTracer::new();
    builder.add_input("input");
    builder.add_call("linear", vec![String::from("input")]);
    // "node_0" is the generated name of the first (and only) call added above.
    builder.add_output("node_0");
    Ok(builder.finalize())
}
pub struct SymbolicTensor {
pub name: String,
pub shape: Vec<usize>,
pub node: NodeIndex,
}
impl SymbolicTensor {
pub fn new(name: String, shape: Vec<usize>, node: NodeIndex) -> Self {
Self { name, shape, node }
}
}
/// Records an implementor's forward pass into a [`ModuleTracer`] using symbolic tensors.
pub trait TracingProxy {
    /// Adds whatever nodes represent `self` to `tracer`, consuming the symbolic `inputs`,
    /// and returns the symbolic outputs or an error.
    fn trace_forward(
        &self,
        tracer: &mut ModuleTracer,
        inputs: &[SymbolicTensor],
    ) -> TorshResult<Vec<SymbolicTensor>>;
}
impl<T: Module> TracingProxy for T {
    /// Records the module as a single `"module_call"` node whose operands are `inputs`.
    ///
    /// Returns one symbolic tensor for that node. Its shape is copied from the first input
    /// (empty when there are no inputs), and its name is the name under which the tracer
    /// registered the call, so it can be passed on as an operand of later nodes.
    fn trace_forward(
        &self,
        tracer: &mut ModuleTracer,
        inputs: &[SymbolicTensor],
    ) -> TorshResult<Vec<SymbolicTensor>> {
        let input_names: Vec<String> = inputs.iter().map(|t| t.name.clone()).collect();
        let output_shape = inputs.first().map(|t| t.shape.clone()).unwrap_or_default();

        // `add_call` registers the node as `node_{counter}`, using the tracer's own counter.
        // That counter is not the petgraph index: inputs and outputs occupy graph indices
        // without advancing it. Deriving the name from the index would therefore hand back
        // a name the tracer cannot resolve, or one that resolves to a different node.
        let call_id = tracer.next_node_id;
        let output_name = format!("node_{call_id}");
        let output_node = tracer.add_call("module_call", input_names);
        debug_assert_eq!(tracer.node_map.get(&output_name), Some(&output_node));

        Ok(vec![SymbolicTensor::new(
            output_name,
            output_shape,
            output_node,
        )])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned name list that the tracer's `add_*` methods take.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().copied().map(String::from).collect()
    }

    /// Input -> call -> output yields one graph input and one graph output.
    #[test]
    fn test_tracer_basic() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("relu", names(&["x"]));
        tracer.add_output("node_0");

        let graph = tracer.finalize();
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    /// The constructor stores the name and shape it is given.
    #[test]
    fn test_symbolic_tensor() {
        let tensor = SymbolicTensor::new(String::from("test"), vec![2, 3], NodeIndex::new(0));
        assert_eq!(tensor.name, "test");
        assert_eq!(tensor.shape, vec![2, 3]);
    }

    /// A conditional over two branch calls ends up in the graph with all its operands.
    #[test]
    fn test_conditional_node() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("gt", names(&["x"]));
        tracer.add_call("relu", names(&["x"]));
        tracer.add_call("sigmoid", names(&["x"]));
        tracer.add_conditional("node_0", names(&["node_1"]), names(&["node_2"]));
        tracer.add_output("node_3");

        let graph = tracer.finalize();
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
        assert!(graph.node_count() >= 5);
    }

    /// A loop node can be built from a condition, a body op and a loop variable.
    #[test]
    fn test_loop_node() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_call("lt", names(&["x"]));
        tracer.add_call("add", names(&["x"]));
        tracer.add_loop("node_0", names(&["node_1"]), names(&["x"]));
        tracer.add_output("node_2");

        let graph = tracer.finalize();
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    /// Two independent inputs can be merged into a single output.
    #[test]
    fn test_merge_node() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        tracer.add_input("y");
        tracer.add_call("relu", names(&["x"]));
        tracer.add_call("sigmoid", names(&["y"]));
        tracer.add_merge(names(&["node_0", "node_1"]));
        tracer.add_output("node_2");

        let graph = tracer.finalize();
        assert_eq!(graph.inputs.len(), 2);
        assert_eq!(graph.outputs.len(), 1);
    }

    /// An attribute read on an input can be routed to the output.
    #[test]
    fn test_get_attr_node() {
        let mut tracer = ModuleTracer::new();
        tracer.add_input("module");
        tracer.add_get_attr("module", "weight");
        tracer.add_output("node_0");

        let graph = tracer.finalize();
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }
}