use super::core::*;
use crate::error::TensorError;
impl Graph {
    /// Removes `node_id` from the graph together with every edge attached to it.
    ///
    /// Each incident edge is deleted from the edge table and detached from the
    /// adjacency list of the node on its other end (the producer's `outputs` or
    /// the consumer's `inputs`). The `name_to_node` entry keyed by this node's
    /// name is removed as well. The cached topological order is invalidated and
    /// `version` is bumped.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::invalid_argument` if `node_id` is not in the graph;
    /// nothing is modified in that case.
    pub fn remove_node(&mut self, node_id: NodeId) -> Result<(), TensorError> {
        // Take ownership of the node first: its edge lists can then be walked while
        // `self` is mutated, and lookups for `node_id` below simply miss, so only
        // the *other* endpoint of each edge gets updated.
        let node = self
            .nodes
            .remove(&node_id)
            .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
        self.name_to_node.remove(&node.name);
        for &edge_id in node.inputs.iter().chain(node.outputs.iter()) {
            // A self-loop is listed in both `inputs` and `outputs`; on the second
            // visit the edge is already gone and this branch is skipped.
            if let Some(edge) = self.edges.remove(&edge_id) {
                if let Some(producer) = self.nodes.get_mut(&edge.from_node) {
                    producer.outputs.retain(|&id| id != edge_id);
                }
                if let Some(consumer) = self.nodes.get_mut(&edge.to_node) {
                    consumer.inputs.retain(|&id| id != edge_id);
                }
            }
        }
        self.topological_order = None;
        self.version += 1;
        Ok(())
    }

    /// Removes a single edge and detaches it from both of its endpoints.
    ///
    /// The edge id is dropped from the producer's `outputs` and the consumer's
    /// `inputs`; an endpoint that is no longer in the graph is ignored. The cached
    /// topological order is invalidated and `version` is bumped.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::invalid_argument` if `edge_id` is not in the graph.
    pub fn remove_edge(&mut self, edge_id: EdgeId) -> Result<(), TensorError> {
        let edge = self
            .edges
            .remove(&edge_id)
            .ok_or_else(|| TensorError::invalid_argument(format!("Edge {} not found", edge_id)))?;
        if let Some(producer) = self.nodes.get_mut(&edge.from_node) {
            producer.outputs.retain(|&id| id != edge_id);
        }
        if let Some(consumer) = self.nodes.get_mut(&edge.to_node) {
            consumer.inputs.retain(|&id| id != edge_id);
        }
        self.topological_order = None;
        self.version += 1;
        Ok(())
    }

    /// Turns `node_id` into a `Constant` node holding `constant_value`.
    ///
    /// The node's `op_type` becomes `NodeType::Constant`, and the tensor is stored
    /// under the `"value"` attribute (overwriting a previous `"value"`; other
    /// attributes are kept). Every input edge is removed via
    /// [`Graph::remove_edge`]. Outgoing edges are not touched, except a self-loop,
    /// which is also an input and so gets removed.
    ///
    /// `version` is bumped once per removed input edge (by `remove_edge`) plus once
    /// for the conversion itself.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::invalid_argument` in two cases:
    /// - `node_id` is not in the graph;
    /// - one of the node's input edge ids is missing from the edge table.
    ///
    /// Both are checked before the node is modified.
    pub fn replace_with_constant(
        &mut self,
        node_id: NodeId,
        constant_value: crate::tensor::Tensor<f32>,
    ) -> Result<(), TensorError> {
        let node = self
            .nodes
            .get(&node_id)
            .ok_or_else(|| TensorError::invalid_argument(format!("Node {} not found", node_id)))?;
        // Validate up front so an inconsistent input list produces an error instead
        // of leaving the node half-converted (op type changed, inputs cleared).
        if let Some(missing) = node
            .inputs
            .iter()
            .copied()
            .find(|edge_id| !self.edges.contains_key(edge_id))
        {
            return Err(TensorError::invalid_argument(format!(
                "Edge {} not found",
                missing
            )));
        }

        let node = self
            .nodes
            .get_mut(&node_id)
            .expect("node existence was checked above");
        node.op_type = NodeType::Constant;
        node.attributes
            .insert("value".to_string(), AttributeValue::Tensor(constant_value));
        // Detach the input list from the node; `remove_edge` then cleans up the
        // edge table and the producers' `outputs`.
        let input_edges = std::mem::take(&mut node.inputs);
        for edge_id in input_edges {
            self.remove_edge(edge_id)?;
        }
        self.version += 1;
        Ok(())
    }

    /// Moves every outgoing edge of `from_node` so that it originates at `to_node`.
    ///
    /// Each edge id in `from_node`'s `outputs` that still refers to an edge has its
    /// `from_node` rewritten and is appended to `to_node`'s `outputs`. Afterwards
    /// `from_node` has no outputs; ids that no longer refer to an edge are dropped.
    /// Edge ids and each edge's `to_node` are unchanged, so consumers' `inputs`
    /// lists are not touched.
    ///
    /// Note: if `to_node` itself consumes an output of `from_node`, that edge
    /// becomes a self-loop on `to_node`.
    ///
    /// Redirecting a node onto itself is a no-op and returns `Ok(0)`.
    ///
    /// # Returns
    ///
    /// The number of edges redirected. The cached topological order is invalidated
    /// and `version` is bumped only when this is non-zero.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::invalid_argument` if either node is not in the graph.
    pub fn redirect_node_outputs(
        &mut self,
        from_node: NodeId,
        to_node: NodeId,
    ) -> Result<usize, TensorError> {
        if !self.nodes.contains_key(&from_node) {
            return Err(TensorError::invalid_argument(format!(
                "Source node {} not found",
                from_node
            )));
        }
        if !self.nodes.contains_key(&to_node) {
            return Err(TensorError::invalid_argument(format!(
                "Target node {} not found",
                to_node
            )));
        }
        // Without this guard the edges would be appended to the very list that is
        // emptied below, leaving live edges missing from the node's `outputs`.
        if from_node == to_node {
            return Ok(0);
        }

        let output_edges = std::mem::take(
            &mut self
                .nodes
                .get_mut(&from_node)
                .expect("Source node must exist after validation")
                .outputs,
        );
        let mut redirected_count = 0;
        for edge_id in output_edges {
            if let Some(edge) = self.edges.get_mut(&edge_id) {
                edge.from_node = to_node;
                self.nodes
                    .get_mut(&to_node)
                    .expect("Target node must exist after validation")
                    .outputs
                    .push(edge_id);
                redirected_count += 1;
            }
        }
        if redirected_count > 0 {
            self.topological_order = None;
            self.version += 1;
        }
        Ok(redirected_count)
    }
}