use crate::interpreter::{interpret_with_inputs, GraphInterpreter};
use crate::{FxGraph, TorshResult};
use std::collections::HashMap;
use torsh_core::device::DeviceType;
use torsh_nn::{Module, ModuleBase, Parameter};
use torsh_tensor::Tensor;
pub struct GraphModule {
base: ModuleBase,
graph: FxGraph,
interpreter: GraphInterpreter,
input_names: Vec<String>,
output_names: Vec<String>,
}
impl GraphModule {
pub fn new(graph: FxGraph) -> Self {
let mut input_names = Vec::new();
let mut output_names = Vec::new();
for &input_idx in graph.inputs() {
if let Some(crate::Node::Input(name)) = graph.get_node(input_idx) {
input_names.push(name.clone());
}
}
for (i, _) in graph.outputs().iter().enumerate() {
output_names.push(format!("output_{}", i));
}
Self {
base: ModuleBase::new(),
graph,
interpreter: GraphInterpreter::new(DeviceType::Cpu),
input_names,
output_names,
}
}
pub fn with_device(graph: FxGraph, device: DeviceType) -> Self {
let mut module = Self::new(graph);
module.interpreter = GraphInterpreter::new(device);
module
}
pub fn graph(&self) -> &FxGraph {
&self.graph
}
pub fn graph_mut(&mut self) -> &mut FxGraph {
&mut self.graph
}
pub fn input_names(&self) -> &[String] {
&self.input_names
}
pub fn output_names(&self) -> &[String] {
&self.output_names
}
pub fn execute(&mut self, inputs: HashMap<String, Tensor>) -> TorshResult<Vec<Tensor>> {
self.interpreter.run(&self.graph, inputs)
}
pub fn execute_positional(&mut self, inputs: Vec<Tensor>) -> TorshResult<Vec<Tensor>> {
if inputs.len() != self.input_names.len() {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Expected {} inputs, got {}",
self.input_names.len(),
inputs.len()
)));
}
let named_inputs: HashMap<String, Tensor> = self
.input_names
.iter()
.zip(inputs.into_iter())
.map(|(name, tensor)| (name.clone(), tensor))
.collect();
self.execute(named_inputs)
}
pub fn optimize(&mut self) -> TorshResult<()> {
use crate::passes::PassManager;
use crate::subgraph_rewriter::apply_fusion_optimizations;
let pass_manager = PassManager::default_optimization_passes();
pass_manager.run(&mut self.graph)?;
apply_fusion_optimizations(&mut self.graph)?;
Ok(())
}
pub fn print_graph(&self) {
self.graph.print();
}
pub fn graph_stats(&self) -> GraphStats {
let mut op_counts = HashMap::new();
let mut param_count = 0;
let mut total_params = 0;
for (_, node) in self.graph.nodes() {
match node {
crate::Node::Call(op_name, _) => {
*op_counts.entry(op_name.clone()).or_insert(0) += 1;
}
_ => {}
}
}
for param in self.base.parameters.values() {
param_count += 1;
total_params += param.tensor().read().shape().numel();
}
GraphStats {
node_count: self.graph.node_count(),
edge_count: self.graph.edge_count(),
input_count: self.graph.inputs().len(),
output_count: self.graph.outputs().len(),
operation_counts: op_counts,
parameter_count: param_count,
total_parameters: total_params,
}
}
pub fn clone_graph(&self) -> FxGraph {
FxGraph::new()
}
pub fn export_graph(&self) -> TorshResult<String> {
let mut output = String::new();
output.push_str(&format!(
"Graph with {} nodes, {} edges\n",
self.graph.node_count(),
self.graph.edge_count()
));
output.push_str("Inputs:\n");
for name in &self.input_names {
output.push_str(&format!(" - {}\n", name));
}
output.push_str("Operations:\n");
for (idx, node) in self.graph.nodes() {
match node {
crate::Node::Call(op_name, args) => {
output.push_str(&format!(" {:?}: {} (args: {:?})\n", idx, op_name, args));
}
crate::Node::Input(name) => {
output.push_str(&format!(" {:?}: Input({})\n", idx, name));
}
crate::Node::Output => {
output.push_str(&format!(" {:?}: Output\n", idx));
}
crate::Node::Conditional { condition, then_branch, else_branch } => {
output.push_str(&format!(" {:?}: Conditional(condition: {}, then: {:?}, else: {:?})\n", idx, condition, then_branch, else_branch));
}
crate::Node::Loop { condition, body, loop_vars } => {
output.push_str(&format!(" {:?}: Loop(condition: {}, body: {:?}, vars: {:?})\n", idx, condition, body, loop_vars));
}
crate::Node::Merge { inputs } => {
output.push_str(&format!(" {:?}: Merge(inputs: {:?})\n", idx, inputs));
}
crate::Node::GetAttr { target, attr } => {
output.push_str(&format!(" {:?}: GetAttr(target: {}, attr: {})\n", idx, target, attr));
}
}
}
output.push_str("Outputs:\n");
for name in &self.output_names {
output.push_str(&format!(" - {}\n", name));
}
Ok(output)
}
}
#[derive(Debug)]
pub struct GraphStats {
pub node_count: usize,
pub edge_count: usize,
pub input_count: usize,
pub output_count: usize,
pub operation_counts: HashMap<String, usize>,
pub parameter_count: usize,
pub total_parameters: usize,
}
impl Module for GraphModule {
fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
if self.input_names.len() != 1 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"forward() expects exactly one input, use execute() for multiple inputs"
.to_string(),
));
}
let mut inputs = HashMap::new();
inputs.insert(self.input_names[0].clone(), input.clone());
let outputs = interpret_with_inputs(&self.graph, inputs)?;
outputs.into_iter().next().ok_or_else(|| {
torsh_core::error::TorshError::InvalidArgument("Graph produced no outputs".to_string())
})
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn to_device(&mut self, device: DeviceType) -> TorshResult<()> {
self.interpreter = GraphInterpreter::new(device);
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for GraphModule {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GraphModule")
.field("inputs", &self.input_names)
.field("outputs", &self.output_names)
.field("node_count", &self.graph.node_count())
.field("edge_count", &self.graph.edge_count())
.finish()
}
}
pub struct GraphModuleBuilder {
graph: Option<FxGraph>,
device: DeviceType,
optimize: bool,
}
impl GraphModuleBuilder {
pub fn new() -> Self {
Self {
graph: None,
device: DeviceType::Cpu,
optimize: false,
}
}
pub fn graph(mut self, graph: FxGraph) -> Self {
self.graph = Some(graph);
self
}
pub fn device(mut self, device: DeviceType) -> Self {
self.device = device;
self
}
pub fn optimize(mut self, optimize: bool) -> Self {
self.optimize = optimize;
self
}
pub fn build(self) -> TorshResult<GraphModule> {
let graph = self.graph.ok_or_else(|| {
torsh_core::error::TorshError::InvalidArgument("Graph not provided".to_string())
})?;
let mut module = GraphModule::with_device(graph, self.device);
if self.optimize {
module.optimize()?;
}
Ok(module)
}
}
impl Default for GraphModuleBuilder {
fn default() -> Self {
Self::new()
}
}
pub fn create_graph_module(graph: FxGraph) -> GraphModule {
GraphModule::new(graph)
}
pub fn create_optimized_graph_module(graph: FxGraph) -> TorshResult<GraphModule> {
GraphModuleBuilder::new()
.graph(graph)
.optimize(true)
.build()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tracer::ModuleTracer;
use torsh_tensor::creation::ones;
#[test]
fn test_graph_module_creation() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let module = GraphModule::new(graph);
assert_eq!(module.input_names().len(), 1);
assert_eq!(module.output_names().len(), 1);
assert_eq!(module.input_names()[0], "x");
}
#[test]
fn test_graph_module_forward() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let module = GraphModule::new(graph);
let input = ones(&[2, 3]);
match module.forward(&input) {
Ok(output) => {
assert_eq!(output.shape().dims(), &[2, 3]);
}
Err(_) => {
}
}
}
#[test]
fn test_graph_module_execute() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let mut module = GraphModule::new(graph);
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), ones(&[2, 3]));
match module.execute(inputs) {
Ok(outputs) => {
assert_eq!(outputs.len(), 1);
}
Err(_) => {
}
}
let positional_inputs = vec![ones(&[2, 3])];
match module.execute_positional(positional_inputs) {
Ok(outputs) => {
assert_eq!(outputs.len(), 1);
}
Err(_) => {
}
}
}
#[test]
fn test_graph_stats() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_call("sigmoid", vec!["node_0".to_string()]);
tracer.add_output("node_1");
let graph = tracer.finalize();
let module = GraphModule::new(graph);
let stats = module.graph_stats();
assert_eq!(stats.input_count, 1);
assert_eq!(stats.output_count, 1);
assert!(stats.operation_counts.contains_key("relu"));
assert!(stats.operation_counts.contains_key("sigmoid"));
}
#[test]
fn test_graph_module_builder() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let module = GraphModuleBuilder::new()
.graph(graph)
.device(DeviceType::Cpu)
.optimize(false)
.build()
.unwrap();
assert_eq!(module.input_names().len(), 1);
}
#[test]
fn test_graph_export() {
let mut tracer = ModuleTracer::new();
tracer.add_input("x");
tracer.add_call("relu", vec!["x".to_string()]);
tracer.add_output("node_0");
let graph = tracer.finalize();
let module = GraphModule::new(graph);
let exported = module.export_graph().unwrap();
assert!(exported.contains("Graph with"));
assert!(exported.contains("Inputs:"));
assert!(exported.contains("Operations:"));
assert!(exported.contains("Outputs:"));
}
}