use crate::errors::GraphResult;
use crate::graph::traits::GraphQuery;
use crate::graph::Graph;
use crate::transformer::optimization::constraints::{
validate_assembly, AssemblyReport, ConstraintReport, TopologyConstraint, TopologyDefect,
TopologyValidator,
};
use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
use std::collections::HashMap;
/// A single, recordable mutation of the computation graph.
///
/// Operations are accumulated by the editing methods of `CadStyleEditor`
/// and stored in [`HistoryEntry::operations`] so an action can later be
/// inspected or marked as reverted.
#[derive(Debug, Clone)]
pub enum EditOperation {
    /// A node with the given id and operator was inserted.
    AddNode {
        node_id: usize,
        operator_type: OperatorType,
    },
    /// A node was (or is scheduled to be) removed; the operator is kept
    /// so the removal is self-describing in the history.
    RemoveNode {
        node_id: usize,
        operator_type: OperatorType,
    },
    /// An edge `from -> to` carrying the named weight tensor was added.
    AddEdge {
        from: usize,
        to: usize,
        weight_name: String,
    },
    /// The edge `from -> to` was (or is scheduled to be) removed.
    RemoveEdge {
        from: usize,
        to: usize,
    },
    /// A node's operator was swapped from `old_type` to `new_type`.
    ModifyNode {
        node_id: usize,
        old_type: OperatorType,
        new_type: OperatorType,
    },
    /// A whole module (identified by `path`) was replaced; both sides
    /// are recorded as lists of node ids.
    ReplaceModule {
        path: String,
        old_module: Vec<usize>,
        new_module: Vec<usize>,
    },
}
/// One recorded editing action in the editor's history.
#[derive(Debug, Clone)]
pub struct HistoryEntry {
    /// Human-readable label of the action (e.g. "solve_constraints").
    pub description: String,
    /// Milliseconds since the Unix epoch when the entry was recorded.
    pub timestamp: u128,
    /// The individual operations that made up this action.
    pub operations: Vec<EditOperation>,
    /// Set to `true` by `rollback`/`undo`; note the graph itself is not
    /// restored, only this flag is flipped.
    pub reverted: bool,
}
/// A detached fragment of the computation graph.
///
/// Node ids refer to indices in the graph the fragment was extracted
/// from (or, for a replacement module, ids local to the fragment that
/// get remapped on insertion).
#[derive(Debug, Clone)]
pub struct SubGraph {
    /// `(node_id, operator)` pairs contained in the fragment.
    pub nodes: Vec<(usize, OperatorType)>,
    /// `(from, to, weight_name)` triples for edges inside the fragment.
    pub edges: Vec<(usize, usize, String)>,
    /// Node ids acting as the fragment's entry points.
    pub inputs: Vec<usize>,
    /// Node ids acting as the fragment's exit points.
    pub outputs: Vec<usize>,
}
impl SubGraph {
    /// Builds an empty fragment: no nodes, edges, inputs or outputs.
    pub fn new() -> Self {
        Self {
            nodes: vec![],
            edges: vec![],
            inputs: vec![],
            outputs: vec![],
        }
    }

    /// Number of nodes contained in this fragment.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges contained in this fragment.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
}
impl Default for SubGraph {
fn default() -> Self {
Self::new()
}
}
/// Interactive, CAD-inspired editor over a transformer computation
/// graph: defect detection, constraint solving, module extraction and
/// replacement, with an append-only history of actions.
pub struct CadStyleEditor<'a> {
    // Mutably borrowed graph being edited for the editor's lifetime.
    graph: &'a mut Graph<OperatorType, WeightTensor>,
    // Constraint set used by detect/solve/validate.
    validator: TopologyValidator,
    // Append-only record of editing actions (oldest first).
    history: Vec<HistoryEntry>,
    // Last extracted SubGraph per module path.
    module_cache: HashMap<String, SubGraph>,
    // When true, editing methods record their operations into history.
    auto_save: bool,
}
impl<'a> CadStyleEditor<'a> {
    /// Creates an editor with an empty constraint set, empty history,
    /// and history recording (auto-save) enabled.
    pub fn new(graph: &'a mut Graph<OperatorType, WeightTensor>) -> Self {
        Self {
            graph,
            validator: TopologyValidator::new(),
            history: Vec::new(),
            module_cache: HashMap::new(),
            auto_save: true,
        }
    }

    /// Like [`Self::new`], but the validator starts with its default
    /// constraint set installed.
    pub fn with_defaults(graph: &'a mut Graph<OperatorType, WeightTensor>) -> Self {
        let mut editor = Self::new(graph);
        editor.validator = TopologyValidator::with_default_constraints();
        editor
    }

    /// Enables or disables history recording for subsequent edits.
    pub fn set_auto_save(&mut self, enabled: bool) {
        self.auto_save = enabled;
    }

    /// Full edit history, oldest entry first.
    pub fn history(&self) -> &[HistoryEntry] {
        &self.history
    }

    /// Number of recorded history entries.
    pub fn history_len(&self) -> usize {
        self.history.len()
    }

    /// Runs the validator's defect detection over the current graph.
    pub fn detect_defects(&self) -> GraphResult<Vec<TopologyDefect>> {
        self.validator.detect_defects(self.graph)
    }

    /// Registers an additional topology constraint.
    ///
    /// Always succeeds; the `Result` return is kept for interface
    /// stability with the other editing methods.
    pub fn add_constraint(&mut self, constraint: TopologyConstraint) -> GraphResult<()> {
        self.validator.add_constraint(constraint);
        Ok(())
    }

    /// Detects defects, generates repair operations, applies the edge
    /// insertions among them, then re-validates the graph.
    ///
    /// Only `AddEdge` repairs are applied today; the other operation
    /// variants are recorded for the history but intentionally left
    /// unapplied. When auto-save is on and anything was proposed, the
    /// whole batch is saved as one history entry.
    pub fn solve_constraints(&mut self) -> GraphResult<ConstraintReport> {
        use crate::graph::traits::GraphOps;
        let mut operations = Vec::new();
        let defects = self.detect_defects()?;
        for defect in &defects {
            match defect.defect_type {
                crate::transformer::optimization::constraints::DefectType::IsolatedNode => {
                    self.fix_isolated_node(defect.location, &mut operations)?;
                }
                crate::transformer::optimization::constraints::DefectType::DisconnectedComponent => {
                    self.fix_disconnected_component(defect.location, &mut operations)?;
                }
                _ => {
                    // Other defect kinds have no automated repair yet.
                }
            }
        }
        for operation in &operations {
            match operation {
                EditOperation::AddEdge { from, to, weight_name } => {
                    // Resolve the recorded usize ids back to live node
                    // indices; a stale id silently skips the insertion.
                    let from_node = self.graph.nodes()
                        .find(|n| n.index().index() == *from)
                        .map(|n| n.index());
                    let to_node = self.graph.nodes()
                        .find(|n| n.index().index() == *to)
                        .map(|n| n.index());
                    if let (Some(from_idx), Some(to_idx)) = (from_node, to_node) {
                        // Placeholder single-element weight for repair edges.
                        let weight = WeightTensor::new(weight_name.clone(), vec![1.0], vec![1]);
                        let _ = self.graph.add_edge(from_idx, to_idx, weight);
                    }
                }
                // Recorded for history purposes but not applied here.
                EditOperation::RemoveEdge { from: _, to: _ } => {}
                EditOperation::AddNode { node_id: _, operator_type: _ } => {}
                EditOperation::RemoveNode { node_id: _, operator_type: _ } => {}
                EditOperation::ModifyNode { node_id: _, old_type: _, new_type: _ } => {}
                EditOperation::ReplaceModule { path: _, old_module: _, new_module: _ } => {}
            }
        }
        let report = self.validator.validate(self.graph)?;
        if self.auto_save && !operations.is_empty() {
            self.save_to_history("solve_constraints".to_string(), operations);
        }
        Ok(report)
    }

    /// Extracts every node whose `Debug` rendering contains `path` into
    /// a [`SubGraph`] and caches the result under `path`.
    ///
    /// NOTE(review): the match is a case-sensitive substring test on
    /// the operator's `Debug` output; only the first matched node is
    /// registered as an input while every match becomes an output, and
    /// edges are not collected — confirm this is the intended module
    /// boundary heuristic.
    pub fn extract_module(&mut self, path: &str) -> GraphResult<SubGraph> {
        let mut subgraph = SubGraph::new();
        for node_ref in self.graph.nodes() {
            let node_id = node_ref.index().index();
            let node_data = node_ref.data();
            if format!("{:?}", node_data).contains(path) {
                subgraph.nodes.push((node_id, node_data.clone()));
                subgraph.outputs.push(node_id);
                if subgraph.inputs.is_empty() {
                    subgraph.inputs.push(node_id);
                }
            }
        }
        self.module_cache.insert(path.to_string(), subgraph.clone());
        Ok(subgraph)
    }

    /// Replaces the module matched by `path` with `new_module`.
    ///
    /// The replacement's nodes are inserted and its internal edges are
    /// re-created against the freshly assigned node indices. Removal of
    /// the old module's nodes/edges is currently only *recorded* as
    /// operations, not applied to the graph. When auto-save is on, the
    /// whole action is saved as one history entry.
    pub fn replace_module(
        &mut self,
        path: &str,
        new_module: SubGraph,
    ) -> GraphResult<()> {
        use crate::graph::traits::GraphOps;
        let mut operations = Vec::new();
        let old_module = self.extract_module(path)?;
        let old_node_ids: Vec<usize> = old_module.nodes.iter().map(|(id, _)| *id).collect();
        // Record every edge touching the old module for removal.
        let mut edges_to_remove = Vec::new();
        for edge_ref in self.graph.edges() {
            let src = edge_ref.source().index();
            let dst = edge_ref.target().index();
            if old_node_ids.contains(&src) || old_node_ids.contains(&dst) {
                edges_to_remove.push((src, dst));
            }
        }
        for (src, dst) in &edges_to_remove {
            operations.push(EditOperation::RemoveEdge {
                from: *src,
                to: *dst,
            });
        }
        for (node_id, operator_type) in &old_module.nodes {
            operations.push(EditOperation::RemoveNode {
                node_id: *node_id,
                operator_type: operator_type.clone(),
            });
        }
        // Insert the replacement nodes, remembering old-id -> new-id so
        // the module's internal edges can be remapped below.
        let mut new_node_mapping: HashMap<usize, usize> = HashMap::new();
        for (old_node_id, operator_type) in &new_module.nodes {
            let new_idx = self.graph.add_node(operator_type.clone())?;
            new_node_mapping.insert(*old_node_id, new_idx.index());
            operations.push(EditOperation::AddNode {
                node_id: new_idx.index(),
                operator_type: operator_type.clone(),
            });
        }
        // Re-create the module-internal edges against the new indices.
        // BUG FIX: previously the weight tensor was built but the edge
        // was never inserted into the graph, leaving the replacement
        // module edge-less.
        for (from, to, weight_name) in &new_module.edges {
            if let (Some(&new_from), Some(&new_to)) = (
                new_node_mapping.get(from),
                new_node_mapping.get(to),
            ) {
                let from_idx = self.graph.nodes()
                    .find(|n| n.index().index() == new_from)
                    .map(|n| n.index());
                let to_idx = self.graph.nodes()
                    .find(|n| n.index().index() == new_to)
                    .map(|n| n.index());
                if let (Some(fi), Some(ti)) = (from_idx, to_idx) {
                    let weight = WeightTensor::new(
                        weight_name.clone(),
                        vec![1.0],
                        vec![1],
                    );
                    let _ = self.graph.add_edge(fi, ti, weight);
                }
                operations.push(EditOperation::AddEdge {
                    from: new_from,
                    to: new_to,
                    weight_name: weight_name.clone(),
                });
            }
        }
        if self.auto_save {
            operations.push(EditOperation::ReplaceModule {
                path: path.to_string(),
                old_module: old_module.nodes.iter().map(|(id, _)| *id).collect(),
                new_module: new_module.nodes.iter().map(|(id, _)| *id).collect(),
            });
            self.save_to_history(format!("replace_module: {}", path), operations);
        }
        Ok(())
    }

    /// Runs the free-standing assembly validation over the graph.
    pub fn validate_assembly(&self) -> GraphResult<AssemblyReport> {
        validate_assembly(self.graph)
    }

    /// Marks every history entry from `index` onward as reverted.
    ///
    /// Returns `Ok(false)` for an out-of-range index. NOTE(review):
    /// this only flips the `reverted` flag — the graph itself is not
    /// restored to its earlier state.
    pub fn rollback(&mut self, index: usize) -> GraphResult<bool> {
        if index >= self.history.len() {
            return Ok(false);
        }
        for entry in self.history.iter_mut().skip(index) {
            entry.reverted = true;
        }
        Ok(true)
    }

    /// Marks the most recent history entry as reverted; returns
    /// `Ok(false)` when the history is empty.
    pub fn undo(&mut self) -> GraphResult<bool> {
        if self.history.is_empty() {
            return Ok(false);
        }
        let last_index = self.history.len() - 1;
        self.rollback(last_index)
    }

    /// Cache of previously extracted modules, keyed by path.
    pub fn module_cache(&self) -> &HashMap<String, SubGraph> {
        &self.module_cache
    }

    /// Shared access to the topology validator.
    pub fn validator(&self) -> &TopologyValidator {
        &self.validator
    }

    /// Mutable access to the topology validator.
    pub fn validator_mut(&mut self) -> &mut TopologyValidator {
        &mut self.validator
    }

    /// Gradient-guided structure optimization over a differentiable
    /// mirror of the graph (feature `tensor` only).
    ///
    /// NOTE(review): the finite-difference probe below evaluates
    /// `loss_fn` without perturbing the edge probability, so `grad` is
    /// always 0.0 and `update_structure` receives zero gradients; only
    /// annealing/discretization have an effect. A real fix needs an API
    /// to nudge a single edge probability — TODO confirm with the
    /// `DifferentiableGraph` authors.
    #[cfg(feature = "tensor")]
    pub fn optimize_with_gradients(
        &mut self,
        loss_fn: &dyn Fn(&crate::tensor::differentiable::DifferentiableGraph<Vec<f64>>) -> f64,
        steps: usize,
        _learning_rate: f64,
    ) -> GraphResult<OptimizationReport> {
        use crate::tensor::differentiable::{DifferentiableGraph, GradientConfig};
        use crate::graph::traits::GraphBase;
        use std::collections::HashMap;
        let num_nodes = self.graph.node_count();
        let mut diff_graph = DifferentiableGraph::with_config(
            num_nodes,
            GradientConfig::default()
                .with_sparsity(0.001)
                .with_smoothness(0.0001),
        );
        // Mirror every concrete edge as a learnable edge (prob 0.9).
        for edge_ref in self.graph.edges() {
            let src = edge_ref.source().index();
            let dst = edge_ref.target().index();
            diff_graph.add_learnable_edge(src, dst, 0.9);
        }
        let initial_loss = loss_fn(&diff_graph);
        let mut final_loss = initial_loss;
        let mut losses = vec![initial_loss];
        let initial_edge_count = diff_graph.num_edges();
        for step in 0..steps {
            let loss = loss_fn(&diff_graph);
            final_loss = loss;
            losses.push(loss);
            let mut gradients = HashMap::new();
            let edges: Vec<(usize, usize, f64)> = diff_graph.get_learnable_edges()
                .iter()
                .map(|e| (e.src, e.dst, e.probability))
                .collect();
            for (src, dst, _prob) in edges {
                let eps = 1e-5;
                let _current_prob = diff_graph.get_edge_probability(src, dst)
                    .unwrap_or(0.5);
                // See the NOTE(review) above: this difference is 0.0.
                let grad = (loss_fn(&diff_graph) - loss) / eps;
                gradients.insert((src, dst), grad);
            }
            diff_graph.update_structure(&gradients);
            diff_graph.anneal_temperature();
            if step % 10 == 0 {
                eprintln!("Step {}: loss={:.6}, temp={:.4}", step, loss, diff_graph.temperature());
            }
        }
        diff_graph.discretize();
        let pruned_edges = diff_graph.get_learnable_edges()
            .iter()
            .filter(|e| !e.exists)
            .count();
        for edge_ref in self.graph.edges() {
            let src = edge_ref.source().index();
            let dst = edge_ref.target().index();
            let should_exist = diff_graph.get_edge_exists(src, dst)
                .unwrap_or(true);
            if !should_exist {
                // TODO: removing pruned edges from the concrete graph
                // is not implemented yet.
            }
        }
        Ok(OptimizationReport {
            initial_loss,
            final_loss,
            losses,
            steps,
            pruned_edges,
            total_edges: initial_edge_count,
        })
    }

    /// Appends a timestamped history entry for `operations`.
    fn save_to_history(&mut self, description: String, operations: Vec<EditOperation>) {
        let entry = HistoryEntry {
            description,
            // Clock-before-epoch errors degrade to timestamp 0 rather
            // than panicking.
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis(),
            operations,
            reverted: false,
        };
        self.history.push(entry);
    }

    /// Proposes a bidirectional connection between an isolated node and
    /// the node with the numerically nearest id. No-op when the graph
    /// has no other node.
    fn fix_isolated_node(
        &mut self,
        node_id: usize,
        operations: &mut Vec<EditOperation>,
    ) -> GraphResult<()> {
        let other_nodes: Vec<usize> = self.graph.nodes()
            .map(|n| n.index().index())
            .filter(|&id| id != node_id)
            .collect();
        if other_nodes.is_empty() {
            return Ok(());
        }
        // "Nearest" by id distance is a heuristic, not topology-aware.
        let nearest_node = other_nodes
            .iter()
            .min_by_key(|&&id| (id as i64 - node_id as i64).abs())
            .copied()
            .unwrap_or(other_nodes[0]);
        operations.push(EditOperation::AddEdge {
            from: node_id,
            to: nearest_node,
            weight_name: format!("fix_isolated_{}_to_{}", node_id, nearest_node),
        });
        operations.push(EditOperation::AddEdge {
            from: nearest_node,
            to: node_id,
            weight_name: format!("fix_isolated_{}_to_{}", nearest_node, node_id),
        });
        Ok(())
    }

    /// Proposes bidirectional edges linking `component_start` to the
    /// first node of the first connected component.
    ///
    /// NOTE(review): the component containing `component_start` is
    /// computed but unused, and components[0] is assumed to be the
    /// "main" one — if `component_start` already lies in it, the
    /// proposed edges stay inside a single component.
    fn fix_disconnected_component(
        &mut self,
        component_start: usize,
        operations: &mut Vec<EditOperation>,
    ) -> GraphResult<()> {
        use crate::algorithms::community::connected_components;
        use crate::node::NodeIndex;
        let components = connected_components(self.graph);
        if components.len() <= 1 {
            return Ok(());
        }
        let start_node_idx = NodeIndex::new(component_start, 0);
        let _component_containing_start = components.iter()
            .position(|comp| comp.contains(&start_node_idx))
            .unwrap_or(0);
        let main_component = &components[0];
        let target_node_idx = main_component.first()
            .map(|n| n.index())
            .unwrap_or(0);
        operations.push(EditOperation::AddEdge {
            from: component_start,
            to: target_node_idx,
            weight_name: format!("fix_disconnected_{}_to_{}", component_start, target_node_idx),
        });
        operations.push(EditOperation::AddEdge {
            from: target_node_idx,
            to: component_start,
            weight_name: format!("fix_disconnected_{}_to_{}", target_node_idx, component_start),
        });
        Ok(())
    }
}
/// Summary of a gradient-guided structure-optimization run.
#[derive(Debug, Clone)]
pub struct OptimizationReport {
    pub initial_loss: f64,
    pub final_loss: f64,
    pub losses: Vec<f64>,
    pub steps: usize,
    pub pruned_edges: usize,
    pub total_edges: usize,
}

impl OptimizationReport {
    /// Fraction of learnable edges that were pruned; 0.0 when the run
    /// started with no edges at all.
    pub fn pruning_ratio(&self) -> f64 {
        match self.total_edges {
            0 => 0.0,
            total => self.pruned_edges as f64 / total as f64,
        }
    }

    /// Total loss decrease over the run (positive means improvement).
    pub fn loss_reduction(&self) -> f64 {
        self.initial_loss - self.final_loss
    }
}
/// Collects every node whose `Debug` rendering contains `path_pattern`
/// into a standalone [`SubGraph`].
///
/// Each matching node is registered as both an input and an output;
/// edges are not copied.
pub fn build_subgraph(
    graph: &Graph<OperatorType, WeightTensor>,
    path_pattern: &str,
) -> GraphResult<SubGraph> {
    let mut result = SubGraph::new();
    for node_ref in graph.nodes() {
        let id = node_ref.index().index();
        let data = node_ref.data();
        let rendered = format!("{:?}", data);
        if rendered.contains(path_pattern) {
            result.nodes.push((id, data.clone()));
            result.inputs.push(id);
            result.outputs.push(id);
        }
    }
    Ok(result)
}
/// Loose structural equality between two fragments: same node count,
/// same edge count, and the same sequence of operator types (compared
/// via their `Debug` rendering). Node ids and edge endpoints are
/// deliberately ignored.
pub fn subgraph_equivalent(a: &SubGraph, b: &SubGraph) -> bool {
    let same_types = || {
        a.nodes
            .iter()
            .map(|(_, t)| format!("{:?}", t))
            .eq(b.nodes.iter().map(|(_, t)| format!("{:?}", t)))
    };
    a.node_count() == b.node_count() && a.edge_count() == b.edge_count() && same_types()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::traits::GraphOps;

    // A fresh SubGraph must be completely empty.
    #[test]
    fn test_subgraph_creation() {
        let subgraph = SubGraph::new();
        assert_eq!(subgraph.node_count(), 0);
        assert_eq!(subgraph.edge_count(), 0);
    }

    // A new editor starts with an empty history.
    #[test]
    fn test_editor_creation() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
        let editor = CadStyleEditor::new(&mut graph);
        assert_eq!(editor.history_len(), 0);
    }

    // A single unconnected node must be reported as a defect
    // (isolated node) by the validator.
    #[test]
    fn test_defect_detection() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
        let _node = graph
            .add_node(OperatorType::Linear {
                in_features: 512,
                out_features: 512,
            })
            .unwrap();
        let editor = CadStyleEditor::new(&mut graph);
        let defects = editor.detect_defects().unwrap();
        assert!(!defects.is_empty());
    }

    // extract_module matches case-sensitively on the Debug rendering:
    // "attention" does NOT match `OperatorType::Attention`, so the
    // extracted subgraph is empty — but the (empty) result is still
    // cached under the requested path.
    #[test]
    fn test_module_extraction() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();
        let _node = graph
            .add_node(OperatorType::Attention {
                num_heads: 8,
                hidden_dim: 512,
            })
            .unwrap();
        let mut editor = CadStyleEditor::new(&mut graph);
        let subgraph = editor.extract_module("attention").unwrap();
        assert_eq!(subgraph.node_count(), 0); assert!(editor.module_cache().contains_key("attention"));
    }

    // Equivalence compares operator-type sequences, not node ids.
    #[test]
    fn test_subgraph_equivalent() {
        let mut a = SubGraph::new();
        a.nodes.push((0, OperatorType::Linear {
            in_features: 512,
            out_features: 512,
        }));
        let mut b = SubGraph::new();
        b.nodes.push((0, OperatorType::Linear {
            in_features: 512,
            out_features: 512,
        }));
        assert!(subgraph_equivalent(&a, &b));
        let mut c = SubGraph::new();
        c.nodes.push((0, OperatorType::Attention {
            num_heads: 8,
            hidden_dim: 512,
        }));
        assert!(!subgraph_equivalent(&a, &c));
    }
}