use crate::fx::types::{Edge, Node};
use crate::graph_analysis::GraphDiff;
use crate::graph_analysis::GraphDifference;
use crate::performance::{GraphCompression, ParallelTraversal};
use crate::{FxGraph, TorshResult};
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
impl FxGraph {
/// Collects every node for which the `node_type` predicate returns `true`.
///
/// Each entry pairs the node's index with a reference to the node, in the
/// order produced by `self.nodes()`.
pub fn nodes_of_type(&self, node_type: fn(&Node) -> bool) -> Vec<(NodeIndex, &Node)> {
    let mut selected = Vec::new();
    for (idx, node) in self.nodes() {
        if node_type(node) {
            selected.push((idx, node));
        }
    }
    selected
}
/// Returns all `Node::Input` nodes together with their indices.
pub fn input_nodes(&self) -> Vec<(NodeIndex, &Node)> {
    fn is_input(node: &Node) -> bool {
        matches!(node, Node::Input(_))
    }
    self.nodes_of_type(is_input)
}
/// Returns all `Node::Output` nodes together with their indices.
pub fn output_nodes(&self) -> Vec<(NodeIndex, &Node)> {
    fn is_output(node: &Node) -> bool {
        matches!(node, Node::Output)
    }
    self.nodes_of_type(is_output)
}
/// Returns all `Node::Call` nodes together with their indices.
pub fn call_nodes(&self) -> Vec<(NodeIndex, &Node)> {
    fn is_call(node: &Node) -> bool {
        matches!(node, Node::Call(_, _))
    }
    self.nodes_of_type(is_call)
}
/// Returns all `Node::Conditional` nodes together with their indices.
pub fn conditional_nodes(&self) -> Vec<(NodeIndex, &Node)> {
    fn is_conditional(node: &Node) -> bool {
        matches!(node, Node::Conditional { .. })
    }
    self.nodes_of_type(is_conditional)
}
/// Returns all `Node::Loop` nodes together with their indices.
pub fn loop_nodes(&self) -> Vec<(NodeIndex, &Node)> {
    fn is_loop(node: &Node) -> bool {
        matches!(node, Node::Loop { .. })
    }
    self.nodes_of_type(is_loop)
}
/// Returns the distinct operation names used by `Node::Call` nodes, sorted
/// in ascending order.
///
/// Names are gathered by reference, sorted and de-duplicated, and only the
/// surviving unique names are cloned. This avoids the quadratic
/// `Vec::contains` scan per call node while producing the same result.
pub fn get_operation_names(&self) -> Vec<String> {
    let mut unique: Vec<&String> = self
        .call_nodes()
        .into_iter()
        .filter_map(|(_, node)| match node {
            Node::Call(op_name, _) => Some(op_name),
            _ => None,
        })
        .collect();
    // Sorting first makes equal names adjacent so `dedup` removes them all.
    unique.sort_unstable();
    unique.dedup();
    unique.into_iter().cloned().collect()
}
/// Returns `true` if at least one `Node::Call` node uses the operation
/// named `op_name`.
///
/// The node iterator is scanned directly, so no intermediate vector of call
/// nodes is allocated and the search stops at the first match.
pub fn contains_operation(&self, op_name: &str) -> bool {
    self.nodes()
        .any(|(_, node)| matches!(node, Node::Call(name, _) if name == op_name))
}
/// Returns every `Node::Call` node whose operation name equals `op_name`,
/// paired with its index.
pub fn nodes_by_operation(&self, op_name: &str) -> Vec<(NodeIndex, &Node)> {
    let mut selected = self.call_nodes();
    // Drop call nodes that invoke a different operation, keeping order.
    selected.retain(|(_, node)| matches!(node, Node::Call(name, _) if name == op_name));
    selected
}
/// Returns every edge of the graph as a `(source, target, weight)` triple,
/// in the order yielded by the underlying graph's `edge_references()`.
pub fn edges(&self) -> Vec<(NodeIndex, NodeIndex, &Edge)> {
    use petgraph::visit::EdgeRef;
    let mut triples = Vec::new();
    for edge_ref in self.graph.edge_references() {
        triples.push((edge_ref.source(), edge_ref.target(), edge_ref.weight()));
    }
    triples
}
/// Counts the edges leaving `node_idx`.
///
/// This counts edges, not distinct successor nodes.
pub fn node_fanout(&self, node_idx: NodeIndex) -> usize {
    use petgraph::Direction::Outgoing;
    let outgoing_edges = self.graph.edges_directed(node_idx, Outgoing);
    outgoing_edges.count()
}
/// Counts the edges arriving at `node_idx`.
///
/// This counts edges, not distinct predecessor nodes.
pub fn node_fanin(&self, node_idx: NodeIndex) -> usize {
    use petgraph::Direction::Incoming;
    let incoming_edges = self.graph.edges_directed(node_idx, Incoming);
    incoming_edges.count()
}
/// Lists the nodes whose fan-out is at least `threshold`.
///
/// Each entry is `(node index, fan-out)`. Entries are ordered by descending
/// fan-out; the sort is stable, so nodes with equal fan-out keep the order
/// in which `self.nodes()` produced them.
pub fn find_high_fanout_nodes(&self, threshold: usize) -> Vec<(NodeIndex, usize)> {
    let mut ranked: Vec<(NodeIndex, usize)> = self
        .nodes()
        .map(|(idx, _)| (idx, self.node_fanout(idx)))
        .filter(|&(_, fanout)| fanout >= threshold)
        .collect();
    ranked.sort_by_key(|&(_, fanout)| std::cmp::Reverse(fanout));
    ranked
}
/// Produces a compressed copy of this graph.
///
/// Delegates to `GraphCompression::deduplicate_operations`, passing `self`
/// by shared reference, and returns whatever graph or error that call yields.
pub fn compress(&self) -> TorshResult<FxGraph> {
GraphCompression::deduplicate_operations(self)
}
/// Produces a copy of this graph with redundant nodes removed.
///
/// Delegates to `GraphCompression::remove_redundant_nodes`, passing `self`
/// by shared reference, and returns whatever graph or error that call yields.
pub fn optimize_nodes(&self) -> TorshResult<FxGraph> {
GraphCompression::remove_redundant_nodes(self)
}
/// Computes the difference between this graph and `other`.
///
/// Delegates to `GraphDiff::diff(self, other)`; `self` is passed as the
/// first argument and `other` as the second.
pub fn diff(&self, other: &FxGraph) -> GraphDifference {
GraphDiff::diff(self, other)
}
/// Creates a `ParallelTraversal` over this graph.
///
/// The receiver is an `Arc<Self>`, which is moved into
/// `ParallelTraversal::new`; callers that still need the graph should clone
/// the `Arc` before calling.
pub fn parallel_traversal(self: Arc<Self>) -> ParallelTraversal {
ParallelTraversal::new(self)
}
pub fn optimize(&self) -> TorshResult<FxGraph> {
let mut optimized = self.compress()?;
let orphaned_nodes = optimized.find_orphaned_nodes();
for &orphaned_idx in &orphaned_nodes {
optimized.graph.remove_node(orphaned_idx);
}
let dead_end_nodes = optimized.find_dead_end_nodes();
for &dead_end_idx in &dead_end_nodes {
optimized.graph.remove_node(dead_end_idx);
}
optimized
.inputs
.retain(|&idx| optimized.graph.node_weight(idx).is_some());
optimized
.outputs
.retain(|&idx| optimized.graph.node_weight(idx).is_some());
Ok(optimized)
}
pub fn subgraph(&self, node_indices: &[NodeIndex]) -> TorshResult<FxGraph> {
let mut subgraph = FxGraph::new();
let mut node_mapping = std::collections::HashMap::new();
let node_set: std::collections::HashSet<_> = node_indices.iter().collect();
for &idx in node_indices {
if let Some(node) = self.get_node(idx) {
let new_idx = subgraph.add_node(node.clone());
node_mapping.insert(idx, new_idx);
if self.inputs.contains(&idx) {
subgraph.add_input(new_idx);
}
if self.outputs.contains(&idx) {
subgraph.add_output(new_idx);
}
}
}
for (src_idx, target_idx, edge) in self.edges() {
if node_set.contains(&src_idx) && node_set.contains(&target_idx) {
if let (Some(&new_src), Some(&new_target)) =
(node_mapping.get(&src_idx), node_mapping.get(&target_idx))
{
subgraph.add_edge(new_src, new_target, edge.clone());
}
}
}
Ok(subgraph)
}
/// Copies all nodes, edges, inputs and outputs of `other` into `self`.
///
/// Returns a map from each node index in `other` to the index of its copy
/// in `self`. The current implementation always returns `Ok`.
pub fn merge(&mut self, other: &FxGraph) -> TorshResult<HashMap<NodeIndex, NodeIndex>> {
    // Clone every node first so edge endpoints can be translated afterwards.
    let node_mapping: HashMap<NodeIndex, NodeIndex> = other
        .nodes()
        .map(|(old_idx, node)| (old_idx, self.add_node(node.clone())))
        .collect();

    for (old_src, old_dst, edge) in other.edges() {
        let endpoints = node_mapping.get(&old_src).zip(node_mapping.get(&old_dst));
        if let Some((&new_src, &new_dst)) = endpoints {
            self.add_edge(new_src, new_dst, edge.clone());
        }
    }

    let merged_inputs = other
        .inputs
        .iter()
        .filter_map(|old_idx| node_mapping.get(old_idx).copied());
    for new_idx in merged_inputs {
        self.add_input(new_idx);
    }

    let merged_outputs = other
        .outputs
        .iter()
        .filter_map(|old_idx| node_mapping.get(old_idx).copied());
    for new_idx in merged_outputs {
        self.add_output(new_idx);
    }

    Ok(node_mapping)
}
/// Finds one path for every (input, output) pair that is connected.
///
/// Pairs are visited input-major (all outputs for the first input, then the
/// next input, ...). At most one path per pair is reported, namely the one
/// returned by `find_path`; pairs without a path contribute nothing.
pub fn find_all_paths(&self) -> Vec<Vec<NodeIndex>> {
    let mut routes = Vec::new();
    for &source in &self.inputs {
        routes.extend(
            self.outputs
                .iter()
                .filter_map(|&sink| self.find_path(source, sink)),
        );
    }
    routes
}
/// Searches depth-first for a directed path from `start` to `end`.
///
/// Returns the node sequence from `start` to `end` (both inclusive) for the
/// first path the search encounters, which is not necessarily the shortest,
/// or `None` when `end` cannot be reached. When `start == end` the result is
/// the single-element path `[start]`.
pub fn find_path(&self, start: NodeIndex, end: NodeIndex) -> Option<Vec<NodeIndex>> {
    let mut seen = HashSet::new();
    let mut route = Vec::new();
    let reached = self.dfs_find_path(start, end, &mut seen, &mut route);
    reached.then(|| route)
}
/// Recursive depth-first step used by `find_path`.
///
/// * `current` - node being expanded.
/// * `target` - node the search is looking for.
/// * `visited` - nodes already expanded; they are never expanded again.
/// * `path` - nodes from the search origin to `current`; on success it
///   holds the complete path ending in `target`.
///
/// Returns `true` when `target` was reached. On failure, `path` is restored
/// to the state it had when this call started.
fn dfs_find_path(
    &self,
    current: NodeIndex,
    target: NodeIndex,
    visited: &mut HashSet<NodeIndex>,
    path: &mut Vec<NodeIndex>,
) -> bool {
    if current == target {
        path.push(current);
        return true;
    }
    // `insert` returns `false` when `current` was already in the set, i.e.
    // it has been expanded before and must not be explored again.
    if !visited.insert(current) {
        return false;
    }
    path.push(current);
    for edge_ref in self
        .graph
        .edges_directed(current, petgraph::Direction::Outgoing)
    {
        if self.dfs_find_path(edge_ref.target(), target, visited, path) {
            return true;
        }
    }
    // No successor leads to `target`: backtrack.
    path.pop();
    false
}
}