use crate::codegen::CodeGenerator;
use crate::fx::types::{Edge, Node};
use crate::onnx_export::{export_to_onnx, OnnxExporter, OnnxModel};
use crate::{FxGraph, TorshResult};
use petgraph::graph::{Graph, NodeIndex};
/// Bundles a petgraph [`Graph`] carrying [`Node`] weights and [`Edge`] weights
/// with two lists of node indices.
///
/// The field names and types match the ones `FxGraph::new` fills in; how this
/// struct relates to `FxGraph` is not visible in this part of the file.
#[derive(Clone, Debug)]
pub struct FxGraphCore {
    /// Underlying graph storage: nodes hold `Node` values, edges hold `Edge` values.
    pub graph: Graph<Node, Edge>,
    /// Node indices designated as inputs.
    pub inputs: Vec<NodeIndex>,
    /// Node indices designated as outputs.
    pub outputs: Vec<NodeIndex>,
}
impl FxGraph {
    /// Creates a graph with an empty petgraph store and empty input/output lists.
    pub fn new() -> Self {
        let graph = Graph::new();
        let inputs = Vec::new();
        let outputs = Vec::new();
        Self {
            graph,
            inputs,
            outputs,
        }
    }

    /// Returns how many nodes the underlying graph currently holds.
    pub fn node_count(&self) -> usize {
        let storage = &self.graph;
        storage.node_count()
    }

    /// Returns how many edges the underlying graph currently holds.
    pub fn edge_count(&self) -> usize {
        let storage = &self.graph;
        storage.edge_count()
    }

    /// Returns the recorded input node indices as a slice.
    pub fn inputs(&self) -> &[NodeIndex] {
        self.inputs.as_slice()
    }

    /// Returns the recorded output node indices as a slice.
    pub fn outputs(&self) -> &[NodeIndex] {
        self.outputs.as_slice()
    }

    /// Looks up the node stored at `idx`.
    ///
    /// Yields `None` when the underlying graph has no weight for that index.
    pub fn get_node(&self, idx: NodeIndex) -> Option<&Node> {
        let weight = self.graph.node_weight(idx);
        weight
    }

    /// Inserts `node` into the underlying graph and returns the index assigned to it.
    ///
    /// The node is not registered as an input or output; use [`Self::add_input`]
    /// or [`Self::add_output`] for that.
    pub fn add_node(&mut self, node: Node) -> NodeIndex {
        let assigned = self.graph.add_node(node);
        assigned
    }

    /// Connects `source` to `target` with the given `edge` weight.
    ///
    /// The edge index produced by petgraph is intentionally dropped. Per the
    /// petgraph documentation, `Graph::add_edge` panics if either node index
    /// does not exist in the graph.
    pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, edge: Edge) {
        let _ = self.graph.add_edge(source, target, edge);
    }

    /// Appends `input` to the list of input node indices.
    ///
    /// No check is made that the index refers to a node in the graph.
    pub fn add_input(&mut self, input: NodeIndex) {
        self.inputs.push(input)
    }

    /// Appends `output` to the list of output node indices.
    ///
    /// No check is made that the index refers to a node in the graph.
    pub fn add_output(&mut self, output: NodeIndex) {
        self.outputs.push(output)
    }

    /// Iterates over every node as an `(index, &node)` pair, in the order
    /// produced by the graph's `node_indices()`.
    pub fn nodes(&self) -> impl Iterator<Item = (NodeIndex, &Node)> {
        let storage = &self.graph;
        // Indices come straight from `node_indices()`, so indexing with them is always valid.
        storage.node_indices().map(move |idx| (idx, &storage[idx]))
    }

    /// Writes a debug summary of the graph to standard output: node and edge
    /// counts, the input and output index lists, then one line per node.
    pub fn print(&self) {
        let node_total = self.node_count();
        let edge_total = self.edge_count();

        println!("FX Graph:");
        println!(" Nodes: {}", node_total);
        println!(" Edges: {}", edge_total);
        println!(" Inputs: {:?}", self.inputs);
        println!(" Outputs: {:?}", self.outputs);

        self.nodes()
            .for_each(|(idx, node)| println!(" Node {:?}: {:?}", idx, node));
    }

    /// Runs a freshly constructed [`CodeGenerator`] over this graph for the
    /// named `target` and returns whatever it produces.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CodeGenerator::generate_code`.
    pub fn generate_code(&self, target: &str) -> TorshResult<String> {
        CodeGenerator::new().generate_code(self, target)
    }

    /// Shorthand for [`Self::generate_code`] with the `"python"` target.
    pub fn to_python(&self) -> TorshResult<String> {
        let target = "python";
        self.generate_code(target)
    }

    /// Shorthand for [`Self::generate_code`] with the `"cpp"` target.
    pub fn to_cpp(&self) -> TorshResult<String> {
        let target = "cpp";
        self.generate_code(target)
    }

    /// Exports this graph through `export_to_onnx` without supplying a model name.
    pub fn to_onnx(&self) -> TorshResult<OnnxModel> {
        let model_name = None;
        export_to_onnx(self, model_name)
    }

    /// Exports this graph through `export_to_onnx`, passing `model_name` along.
    pub fn to_onnx_named(&self, model_name: String) -> TorshResult<OnnxModel> {
        let model_name = Some(model_name);
        export_to_onnx(self, model_name)
    }

    /// Builds a fresh [`OnnxExporter`] and returns the string produced by its
    /// `export_to_json` for this graph.
    pub fn to_onnx_json(&self) -> TorshResult<String> {
        OnnxExporter::new().export_to_json(self)
    }
}
impl Default for FxGraph {
fn default() -> Self {
Self::new()
}
}