torsh-fx
Graph transformation and optimization framework for ToRSh, providing TorchFX-compatible functionality.
Overview
TorshFX is a toolkit for capturing, analyzing, and transforming PyTorch-style programs. It provides:
- Graph Capture: Convert eager mode code to graph representation
- Graph Transformation: Modify and optimize computational graphs
- Symbolic Tracing: Trace through model code to build graphs
- Graph Optimization: Apply passes for performance improvements
- Code Generation: Convert graphs back to executable code
Usage
Basic Symbolic Tracing
use *;
use *;
// Define a model
// Trace the model
let model = new;
let tracer = new;
let graph_module = tracer.trace?;
// Print the graph
println!;
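In PyTorch's torch.fx, symbolic tracing runs the model on proxy values that record each operation instead of computing it. The sketch below shows that idea in plain Rust using only the standard library. `Proxy`, `Op`, and `trace` are hypothetical names for illustration only and are not the torsh-fx `Tracer` API.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};
use std::rc::Rc;

/// One recorded operation (hypothetical toy IR; inputs are node indices).
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Placeholder(String),
    CallFunction(&'static str, Vec<usize>),
    Output(usize),
}

/// Stands in for a tensor during tracing: every operation applied to it
/// appends a node to the shared graph instead of computing a value.
#[derive(Clone)]
struct Proxy {
    id: usize,
    graph: Rc<RefCell<Vec<Op>>>,
}

impl Proxy {
    fn record(&self, op: Op) -> Proxy {
        let mut g = self.graph.borrow_mut();
        g.push(op);
        Proxy { id: g.len() - 1, graph: Rc::clone(&self.graph) }
    }

    fn relu(&self) -> Proxy {
        self.record(Op::CallFunction("relu", vec![self.id]))
    }
}

impl Add for Proxy {
    type Output = Proxy;
    fn add(self, rhs: Proxy) -> Proxy {
        self.record(Op::CallFunction("add", vec![self.id, rhs.id]))
    }
}

impl Mul for Proxy {
    type Output = Proxy;
    fn mul(self, rhs: Proxy) -> Proxy {
        self.record(Op::CallFunction("mul", vec![self.id, rhs.id]))
    }
}

/// Run `model` on one placeholder proxy per input and return the recorded graph.
fn trace(n_inputs: usize, model: impl Fn(Vec<Proxy>) -> Proxy) -> Vec<Op> {
    let graph = Rc::new(RefCell::new(Vec::new()));
    let inputs: Vec<Proxy> = (0..n_inputs)
        .map(|i| {
            let mut g = graph.borrow_mut();
            g.push(Op::Placeholder(format!("x{i}")));
            Proxy { id: g.len() - 1, graph: Rc::clone(&graph) }
        })
        .collect();
    let out = model(inputs);
    graph.borrow_mut().push(Op::Output(out.id));
    let ops = graph.borrow().clone();
    ops
}

fn main() {
    // Model: relu((x0 + x1) * x0)
    let graph = trace(2, |x| ((x[0].clone() + x[1].clone()) * x[0].clone()).relu());
    for (i, op) in graph.iter().enumerate() {
        println!("%{i} = {op:?}");
    }
}
```

Because operator overloading drives the recording, the model code is written exactly as it would be for eager execution. The catch, as in torch.fx, is that data-dependent control flow cannot be captured this way.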
Graph Transformation
use *;
// Apply optimization passes
let optimized = graph_module
.transform?
.transform?
.transform?;
// Custom transformation
;
let transformed = graph_module.transform?;
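A typical transformation pass is dead-code elimination. The following is a minimal sketch on a hypothetical toy IR (`Node`, `node`, and `eliminate_dead_code` are illustrative names, not torsh-fx types). It assumes nodes are stored in topological order, with `inputs` holding indices of earlier nodes.

```rust
/// Toy graph node (hypothetical): nodes are kept in topological order and
/// `inputs` holds indices of earlier nodes.
#[derive(Debug, Clone, PartialEq)]
struct Node {
    name: String,
    op: String,
    inputs: Vec<usize>,
}

fn node(name: &str, op: &str, inputs: &[usize]) -> Node {
    Node { name: name.to_string(), op: op.to_string(), inputs: inputs.to_vec() }
}

/// Dead-code elimination: drop nodes whose results never reach the output.
/// Placeholders are always kept so the graph's signature is unchanged.
fn eliminate_dead_code(graph: &[Node]) -> Vec<Node> {
    let mut live = vec![false; graph.len()];
    // Walk backwards; in topological order every user of node i comes after i,
    // so live[i] is already final when we reach it.
    for i in (0..graph.len()).rev() {
        if graph[i].op == "output" || graph[i].op == "placeholder" {
            live[i] = true;
        }
        if live[i] {
            for &j in &graph[i].inputs {
                live[j] = true;
            }
        }
    }
    // Rebuild the graph, renumbering inputs to the compacted indices.
    let mut remap = vec![usize::MAX; graph.len()];
    let mut out = Vec::new();
    for (i, n) in graph.iter().enumerate().filter(|&(i, _)| live[i]) {
        remap[i] = out.len();
        let mut kept = n.clone();
        kept.inputs = n.inputs.iter().map(|&j| remap[j]).collect();
        out.push(kept);
    }
    out
}

fn main() {
    let graph = vec![
        node("x", "placeholder", &[]),
        node("y", "placeholder", &[]),
        node("unused", "mul", &[0, 1]),
        node("sum", "add", &[0, 1]),
        node("out", "output", &[3]),
    ];
    for n in eliminate_dead_code(&graph) {
        println!("{} = {}({:?})", n.name, n.op, n.inputs);
    }
}
```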
Subgraph Matching and Rewriting
use *;
// Define a pattern to match
let pattern = pattern! ;
// Define replacement
let replacement = ;
// Apply rewriter
let rewriter = new;
let optimized = rewriter.rewrite?;
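Pattern-based rewriting can be illustrated with a common fusion: replacing `relu(add(a, b))` with a single `add_relu(a, b)` node. The sketch below is a hypothetical stand-alone example rather than the torsh-fx `pattern!` or rewriter API. One check matters for correctness: the match is rejected if the `add` result has any user other than the `relu`.

```rust
/// Toy graph node (hypothetical), stored in topological order.
#[derive(Debug, Clone, PartialEq)]
struct Node {
    name: String,
    op: String,
    inputs: Vec<usize>,
}

fn node(name: &str, op: &str, inputs: &[usize]) -> Node {
    Node { name: name.to_string(), op: op.to_string(), inputs: inputs.to_vec() }
}

/// Rewrite relu(add(a, b)) into add_relu(a, b). The match only fires when the
/// add has no other users; otherwise removing it would change their inputs.
fn fuse_add_relu(graph: &[Node]) -> Vec<Node> {
    let mut users = vec![0usize; graph.len()];
    for n in graph {
        for &j in &n.inputs {
            users[j] += 1;
        }
    }
    let mut out: Vec<Node> = Vec::new();
    let mut remap = vec![usize::MAX; graph.len()];
    for (i, n) in graph.iter().enumerate() {
        let matched = n.op == "relu"
            && n.inputs.len() == 1
            && graph[n.inputs[0]].op == "add"
            && users[n.inputs[0]] == 1;
        if matched {
            // The add was already emitted: turn it into the fused node in place
            // and point later users of the relu at it.
            let fused = remap[n.inputs[0]];
            out[fused].op = "add_relu".to_string();
            out[fused].name = n.name.clone();
            remap[i] = fused;
        } else {
            let mut kept = n.clone();
            kept.inputs = n.inputs.iter().map(|&j| remap[j]).collect();
            remap[i] = out.len();
            out.push(kept);
        }
    }
    out
}

fn main() {
    let graph = vec![
        node("x", "placeholder", &[]),
        node("y", "placeholder", &[]),
        node("s", "add", &[0, 1]),
        node("r", "relu", &[2]),
        node("out", "output", &[3]),
    ];
    for n in fuse_add_relu(&graph) {
        println!("{} = {}({:?})", n.name, n.op, n.inputs);
    }
}
```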
Quantization with FX
use *;
// Prepare model for quantization
let prepared = prepare_fx?;
// Run calibration
for batch in calibration_data
// Convert to quantized model
let quantized = convert_fx?;
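The calibration loop exists so that observers can record the ranges of activations before conversion. The sketch below shows one common scheme, a min/max observer producing affine uint8 parameters (real ~= scale * (q - zero_point)), similar in spirit to PyTorch's `MinMaxObserver`. It is an assumption for illustration; the observers and quantization schemes that `prepare_fx`/`convert_fx` actually use in torsh-fx may differ.

```rust
/// Tracks the running min/max of every value seen during calibration.
struct MinMaxObserver {
    min: f32,
    max: f32,
}

impl MinMaxObserver {
    fn new() -> Self {
        Self { min: f32::INFINITY, max: f32::NEG_INFINITY }
    }

    fn observe(&mut self, batch: &[f32]) {
        for &v in batch {
            self.min = self.min.min(v);
            self.max = self.max.max(v);
        }
    }

    /// Affine uint8 parameters so that real ~= scale * (q - zero_point).
    fn qparams(&self) -> (f32, i32) {
        let (qmin, qmax) = (0i32, 255i32);
        // Widen the range to include 0.0 so zero is exactly representable.
        let lo = self.min.min(0.0);
        let hi = self.max.max(0.0);
        let scale = ((hi - lo) / (qmax - qmin) as f32).max(f32::EPSILON);
        let zero_point = (qmin as f32 - lo / scale).round() as i32;
        (scale, zero_point.clamp(qmin, qmax))
    }
}

fn quantize(x: f32, scale: f32, zero_point: i32) -> u8 {
    ((x / scale).round() as i32 + zero_point).clamp(0, 255) as u8
}

fn dequantize(q: u8, scale: f32, zero_point: i32) -> f32 {
    scale * (q as i32 - zero_point) as f32
}

fn main() {
    let mut obs = MinMaxObserver::new();
    // Calibration: feed representative batches through the observer.
    for batch in [[-1.0f32, 0.5], [2.0, 3.0]] {
        obs.observe(&batch);
    }
    let (scale, zp) = obs.qparams();
    let q = quantize(1.3, scale, zp);
    println!("scale={scale} zero_point={zp} q(1.3)={q} back={}", dequantize(q, scale, zp));
}
```

With observed range [-1, 3], the scale is 4/255 and the zero point is round(63.75) = 64, so 0.0 maps exactly to code 64.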
Graph Analysis
use *;
// Analyze graph properties
let analyzer = new;
// Get operation count
let op_count = analyzer.count_operations;
println!;
println!;
// Analyze shapes
let shape_prop = new;
let shapes = shape_prop.propagate?;
// Find bottlenecks
let profiler = new;
let profile = profiler.profile?;
let bottlenecks = profile.find_bottlenecks;
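Two of these analyses, operation counting and static shape propagation, can be sketched over a hypothetical toy IR as shown below. `Op`, `count_operations`, and `propagate_shapes` are illustrative names, not the torsh-fx analyzer API. The sketch also assumes 2-D matmul and elementwise `add` without broadcasting. Bottleneck profiling needs runtime measurements and is not shown.

```rust
use std::collections::BTreeMap;

/// Toy graph (hypothetical): nodes in topological order, operands are indices.
#[derive(Debug, Clone)]
enum Op {
    Placeholder(Vec<usize>), // input with a known static shape
    MatMul(usize, usize),
    Add(usize, usize),
    Relu(usize),
    Output(usize),
}

fn count_operations(graph: &[Op]) -> BTreeMap<&'static str, usize> {
    let mut counts = BTreeMap::new();
    for op in graph {
        let kind = match op {
            Op::Placeholder(_) => "placeholder",
            Op::MatMul(..) => "matmul",
            Op::Add(..) => "add",
            Op::Relu(_) => "relu",
            Op::Output(_) => "output",
        };
        *counts.entry(kind).or_insert(0) += 1;
    }
    counts
}

/// Propagate shapes node by node; fails on the first incompatible operation.
fn propagate_shapes(graph: &[Op]) -> Result<Vec<Vec<usize>>, String> {
    let mut shapes: Vec<Vec<usize>> = Vec::with_capacity(graph.len());
    for (i, op) in graph.iter().enumerate() {
        let shape = match op {
            Op::Placeholder(s) => s.clone(),
            Op::MatMul(a, b) => {
                let (sa, sb) = (&shapes[*a], &shapes[*b]);
                if sa.len() != 2 || sb.len() != 2 || sa[1] != sb[0] {
                    return Err(format!("node {i}: matmul mismatch {sa:?} x {sb:?}"));
                }
                vec![sa[0], sb[1]]
            }
            Op::Add(a, b) => {
                if shapes[*a] != shapes[*b] {
                    return Err(format!("node {i}: add mismatch"));
                }
                shapes[*a].clone()
            }
            Op::Relu(a) | Op::Output(a) => shapes[*a].clone(),
        };
        shapes.push(shape);
    }
    Ok(shapes)
}

fn main() {
    // output(relu(x @ w + b)) with x: [4, 8], w: [8, 16], b: [4, 16]
    let graph = vec![
        Op::Placeholder(vec![4, 8]),
        Op::Placeholder(vec![8, 16]),
        Op::MatMul(0, 1),
        Op::Placeholder(vec![4, 16]),
        Op::Add(2, 3),
        Op::Relu(4),
        Op::Output(5),
    ];
    println!("ops: {:?}", count_operations(&graph));
    println!("shapes: {:?}", propagate_shapes(&graph));
}
```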
Custom Graph Passes
use ;
// Define custom pass
// Use pass manager
let pass_manager = new
.add_pass
.add_pass
.add_pass;
let optimized = pass_manager.run?;
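A pass manager usually runs its passes repeatedly until none of them changes the graph. The following is a minimal sketch of that fixed-point loop with a builder-style `add_pass`. The `Pass` trait, the two example passes, and the `max_iterations` cap are assumptions made for illustration, and the "graph" is simplified to a straight-line chain of op names.

```rust
/// Toy graph (hypothetical): a straight-line chain of op names.
type Graph = Vec<String>;

/// A pass rewrites the graph in place and reports whether it changed anything.
trait Pass {
    fn name(&self) -> &str;
    fn run(&self, graph: &mut Graph) -> bool;
}

/// Removes no-op "identity" nodes.
struct RemoveIdentity;

impl Pass for RemoveIdentity {
    fn name(&self) -> &str {
        "remove_identity"
    }
    fn run(&self, graph: &mut Graph) -> bool {
        let before = graph.len();
        graph.retain(|op| op.as_str() != "identity");
        graph.len() != before
    }
}

/// relu(relu(x)) == relu(x), so adjacent relus in a chain collapse to one.
struct DedupRelu;

impl Pass for DedupRelu {
    fn name(&self) -> &str {
        "dedup_relu"
    }
    fn run(&self, graph: &mut Graph) -> bool {
        let before = graph.len();
        graph.dedup_by(|a, b| a.as_str() == "relu" && b.as_str() == "relu");
        graph.len() != before
    }
}

struct PassManager {
    passes: Vec<Box<dyn Pass>>,
    max_iterations: usize,
}

impl PassManager {
    fn new() -> Self {
        Self { passes: Vec::new(), max_iterations: 10 }
    }

    fn add_pass(mut self, pass: impl Pass + 'static) -> Self {
        self.passes.push(Box::new(pass));
        self
    }

    /// Run every pass in order, repeating until a full sweep changes nothing
    /// (a fixed point) or `max_iterations` is reached. Returns the passes that fired.
    fn run(&self, graph: &mut Graph) -> Vec<String> {
        let mut fired = Vec::new();
        for _ in 0..self.max_iterations {
            let mut changed = false;
            for pass in &self.passes {
                if pass.run(graph) {
                    fired.push(pass.name().to_string());
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }
        fired
    }
}

fn main() {
    let mut graph: Graph = ["linear", "identity", "relu", "identity", "relu", "output"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    let fired = PassManager::new()
        .add_pass(RemoveIdentity)
        .add_pass(DedupRelu)
        .run(&mut graph);
    println!("{graph:?} after {fired:?}");
}
```

Iterating to a fixed point matters because passes enable each other. Here the two relus only become adjacent after the identity nodes between them are removed.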
Interpreter Mode
use *;
// Create custom interpreter
// Run with custom interpreter
let interpreter = new;
let output = interpreter.run?;
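An interpreter executes the graph node by node and stores each result in an environment. In torch.fx, a custom interpreter is made by overriding per-opcode methods such as `call_function`. The sketch below mimics that design with a trait default method. The `Interpreter` trait, `Plain`, and `LeakyRelu` are hypothetical names, not the torsh-fx interpreter types, and values are plain `f64` scalars rather than tensors.

```rust
/// Toy graph (hypothetical): nodes in topological order, arguments are indices.
enum Node {
    Placeholder(usize), // index into the input list
    CallFunction(&'static str, Vec<usize>),
    Output(usize),
}

/// Walks the graph in order. `call_function` is a default method so a custom
/// interpreter can override how individual functions execute.
trait Interpreter {
    fn call_function(&mut self, name: &str, args: &[f64]) -> f64 {
        match name {
            "add" => args[0] + args[1],
            "mul" => args[0] * args[1],
            "relu" => args[0].max(0.0),
            other => panic!("unsupported function: {other}"),
        }
    }

    fn run(&mut self, graph: &[Node], inputs: &[f64]) -> f64 {
        let mut env: Vec<f64> = Vec::with_capacity(graph.len());
        for node in graph {
            let value = match node {
                Node::Placeholder(i) => inputs[*i],
                Node::CallFunction(name, args) => {
                    let vals: Vec<f64> = args.iter().map(|&a| env[a]).collect();
                    self.call_function(name, &vals)
                }
                Node::Output(a) => return env[*a],
            };
            env.push(value);
        }
        panic!("graph has no output node");
    }
}

/// Default behaviour only.
struct Plain;
impl Interpreter for Plain {}

/// Custom interpreter: counts calls and swaps relu for a leaky variant.
struct LeakyRelu {
    calls: usize,
}

impl Interpreter for LeakyRelu {
    fn call_function(&mut self, name: &str, args: &[f64]) -> f64 {
        self.calls += 1;
        if name == "relu" {
            return if args[0] > 0.0 { args[0] } else { 0.01 * args[0] };
        }
        // Fall back to the default arithmetic for everything else.
        Plain.call_function(name, args)
    }
}

fn main() {
    // output(relu(x0 * x1))
    let graph = vec![
        Node::Placeholder(0),
        Node::Placeholder(1),
        Node::CallFunction("mul", vec![0, 1]),
        Node::CallFunction("relu", vec![2]),
        Node::Output(3),
    ];
    println!("plain: {}", Plain.run(&graph, &[2.0, -3.0]));
    let mut custom = LeakyRelu { calls: 0 };
    let out = custom.run(&graph, &[2.0, -3.0]);
    println!("custom: {out} ({} calls)", custom.calls);
}
```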
Serialization
// Save graph module
graph_module.save?;
// Load graph module
let loaded = load?;
// Export to ONNX-like format
let exported = graph_module.export?;
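The on-disk format used by `save`/`load` is not described here. As a hypothetical illustration of what graph serialization involves, the sketch below writes one node per line as `name<TAB>op<TAB>comma-separated input indices` and validates the result on load. The format is invented for the example and is neither the torsh-fx format nor ONNX. It also assumes names and op strings contain no tabs or newlines.

```rust
/// Toy graph node (hypothetical), stored in topological order.
#[derive(Debug, Clone, PartialEq)]
struct Node {
    name: String,
    op: String,
    inputs: Vec<usize>,
}

/// One node per line: name<TAB>op<TAB>comma-separated input indices.
fn save(graph: &[Node]) -> String {
    graph
        .iter()
        .map(|n| {
            let ins: Vec<String> = n.inputs.iter().map(|i| i.to_string()).collect();
            format!("{}\t{}\t{}\n", n.name, n.op, ins.join(","))
        })
        .collect()
}

fn load(text: &str) -> Result<Vec<Node>, String> {
    let mut graph: Vec<Node> = Vec::new();
    for (lineno, line) in text.lines().enumerate() {
        let fields: Vec<&str> = line.split('\t').collect();
        if fields.len() != 3 {
            return Err(format!("line {}: expected 3 fields", lineno + 1));
        }
        let inputs = if fields[2].is_empty() {
            Vec::new()
        } else {
            fields[2]
                .split(',')
                .map(|s| s.parse::<usize>().map_err(|e| format!("line {}: {e}", lineno + 1)))
                .collect::<Result<Vec<_>, _>>()?
        };
        // Inputs must refer to earlier nodes so the loaded graph stays a DAG
        // in topological order.
        if inputs.iter().any(|&i| i >= graph.len()) {
            return Err(format!("line {}: forward reference", lineno + 1));
        }
        graph.push(Node { name: fields[0].to_string(), op: fields[1].to_string(), inputs });
    }
    Ok(graph)
}

fn main() {
    let graph = vec![
        Node { name: "x".into(), op: "placeholder".into(), inputs: vec![] },
        Node { name: "r".into(), op: "relu".into(), inputs: vec![0] },
        Node { name: "out".into(), op: "output".into(), inputs: vec![1] },
    ];
    let text = save(&graph);
    print!("{text}");
    println!("round-trip ok: {}", load(&text) == Ok(graph));
}
```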
Graph IR
The FX intermediate representation (IR) consists of:
- Nodes: Individual operations (placeholder, call_function, call_method, call_module, output)
- Graph: DAG of nodes representing computation
- GraphModule: Combination of graph and module state
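The three IR pieces listed above can be modeled as plain Rust types. The sketch below is a hypothetical data layout for illustration, not torsh-fx's actual definitions. It stores nodes in topological order, so the DAG property reduces to a simple check that every argument refers to an earlier node.

```rust
use std::collections::HashMap;

/// The five node kinds listed above.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Opcode {
    Placeholder,
    CallFunction,
    CallMethod,
    CallModule,
    Output,
}

#[derive(Debug, Clone)]
struct Node {
    name: String,
    opcode: Opcode,
    target: String,   // function, method, or submodule name ("" if unused)
    args: Vec<usize>, // indices of the nodes whose results this node consumes
}

#[derive(Debug, Default)]
struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    fn add(&mut self, name: &str, opcode: Opcode, target: &str, args: &[usize]) -> usize {
        self.nodes.push(Node { name: name.into(), opcode, target: target.into(), args: args.to_vec() });
        self.nodes.len() - 1
    }

    /// With nodes in topological order, acyclicity means every argument refers
    /// to an earlier node. The last node must be the graph's single output.
    fn validate(&self) -> Result<(), String> {
        for (i, n) in self.nodes.iter().enumerate() {
            if let Some(&bad) = n.args.iter().find(|&&a| a >= i) {
                return Err(format!("{} uses node {bad}, which is not defined earlier", n.name));
            }
        }
        let outputs = self.nodes.iter().filter(|n| n.opcode == Opcode::Output).count();
        match self.nodes.last() {
            Some(last) if last.opcode == Opcode::Output && outputs == 1 => Ok(()),
            _ => Err("graph must end with exactly one output node".to_string()),
        }
    }
}

/// The graph together with the state (e.g. weights) of the submodules it calls.
struct GraphModule {
    graph: Graph,
    state: HashMap<String, Vec<f32>>,
}

fn main() {
    let mut g = Graph::default();
    let x = g.add("x", Opcode::Placeholder, "", &[]);
    let h = g.add("linear_1", Opcode::CallModule, "linear", &[x]);
    let r = g.add("relu_1", Opcode::CallFunction, "relu", &[h]);
    let t = g.add("t_1", Opcode::CallMethod, "t", &[r]);
    g.add("output", Opcode::Output, "", &[t]);
    let gm = GraphModule {
        graph: g,
        state: HashMap::from([("linear.weight".to_string(), vec![0.0; 16])]),
    };
    println!("{:?}", gm.graph.validate());
    for n in &gm.graph.nodes {
        println!("{} = {:?}[{}]({:?})", n.name, n.opcode, n.target, n.args);
    }
    println!("state keys: {:?}", gm.state.keys().collect::<Vec<_>>());
}
```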
Integration with JIT
// Convert FX graph to JIT
let jit_module = compile_fx?;
// Optimize with JIT
let optimized = jit_module.optimize?;
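Compiling a graph removes the per-node dispatch that an interpreter pays on every call. As a rough, std-only stand-in for that idea, the sketch below uses closure compilation: the graph is translated once into nested closures. This is not machine-code generation and not necessarily how `compile_fx` works. The `Node` enum and the `compile`/`compile_node` functions are hypothetical.

```rust
/// Toy graph (hypothetical IR): nodes in topological order, last node is the result.
enum Node {
    Input(usize),
    Add(usize, usize),
    Mul(usize, usize),
    Relu(usize),
}

type Compiled = Box<dyn Fn(&[f64]) -> f64>;

/// Translate the graph once into nested closures; calling the result no
/// longer matches on node kinds.
fn compile(graph: &[Node]) -> Compiled {
    compile_node(graph, graph.len() - 1)
}

fn compile_node(graph: &[Node], i: usize) -> Compiled {
    match graph[i] {
        Node::Input(k) => Box::new(move |x: &[f64]| x[k]),
        Node::Add(a, b) => {
            let (fa, fb) = (compile_node(graph, a), compile_node(graph, b));
            Box::new(move |x: &[f64]| fa(x) + fb(x))
        }
        Node::Mul(a, b) => {
            let (fa, fb) = (compile_node(graph, a), compile_node(graph, b));
            Box::new(move |x: &[f64]| fa(x) * fb(x))
        }
        Node::Relu(a) => {
            let fa = compile_node(graph, a);
            Box::new(move |x: &[f64]| fa(x).max(0.0))
        }
    }
}

fn main() {
    // relu(x0 * x1 + x0)
    let graph = [Node::Input(0), Node::Input(1), Node::Mul(0, 1), Node::Add(2, 0), Node::Relu(3)];
    let f = compile(&graph);
    println!("{} {}", f(&[2.0, 3.0]), f(&[2.0, -3.0]));
}
```

This simple recursive translation expands the DAG into a tree, so a node with several users is recomputed once per user. A real compiler would instead materialize such shared values once.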
License
Licensed under the Apache License, Version 2.0. See LICENSE for details.