use super::error::PruningError;
use std::collections::{HashMap, HashSet};
/// Kind of connection a [`GraphEdge`] represents between two graph nodes.
///
/// `DependencyGraph::validate` only checks dimension compatibility for
/// `Sequential` edges; the other kinds are not shape-checked there.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DependencyType {
    Sequential,
    Skip,
    Concat,
    Add,
    Mul,
    Attention,
}
/// A single node (layer/operation) in a [`DependencyGraph`].
///
/// A dimension of `0` means "unknown": `DependencyGraph::validate` skips the
/// shape check whenever either side of a `Sequential` edge reports `0`.
#[derive(Clone, Debug)]
pub struct GraphNode {
    /// Unique key under which the graph stores this node.
    pub id: String,
    /// Name stored separately from `id`.
    pub name: String,
    /// Category of the node.
    pub node_type: NodeType,
    /// Size of the node's output dimension (`0` = unknown).
    pub output_dim: usize,
    /// Size of the node's input dimension (`0` = unknown).
    pub input_dim: usize,
    /// Whether the node may be selected for pruning (defaults to `true`).
    pub prunable: bool,
}

impl GraphNode {
    /// Builds a node with the given `id`, `name` and `node_type`.
    ///
    /// Both dimensions start at `0` and the node starts out prunable; use
    /// [`with_dims`](Self::with_dims) and [`with_prunable`](Self::with_prunable)
    /// to change that.
    pub fn new(id: impl Into<String>, name: impl Into<String>, node_type: NodeType) -> Self {
        let id: String = id.into();
        let name: String = name.into();
        GraphNode {
            id,
            name,
            node_type,
            input_dim: 0,
            output_dim: 0,
            prunable: true,
        }
    }

    /// Returns this node with `input_dim` and `output_dim` replaced.
    #[must_use]
    pub fn with_dims(self, input_dim: usize, output_dim: usize) -> Self {
        GraphNode {
            input_dim,
            output_dim,
            ..self
        }
    }

    /// Returns this node with its `prunable` flag replaced.
    #[must_use]
    pub fn with_prunable(self, prunable: bool) -> Self {
        GraphNode { prunable, ..self }
    }
}
/// Coarse category of a [`GraphNode`].
///
/// `DependencyGraph::linear_chain` creates every node as `Linear`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeType {
    Linear,
    Conv,
    LayerNorm,
    BatchNorm,
    Embedding,
    Attention,
    MLP,
    Activation,
    Pooling,
    Input,
    Output,
    Other,
}
/// A directed edge between two nodes of a [`DependencyGraph`], identified by
/// their `id`s.
#[derive(Clone, Debug)]
pub struct GraphEdge {
    /// Id of the source node.
    pub from: String,
    /// Id of the target node.
    pub to: String,
    /// Kind of dependency this edge represents.
    pub dep_type: DependencyType,
    /// Defaults to `0`. Presumably the tensor dimension the dependency couples
    /// (NOTE(review): confirm against the pruning code that reads it).
    pub dim_index: usize,
}

impl GraphEdge {
    /// Builds an edge `from -> to` of kind `dep_type`, with `dim_index` set to `0`.
    pub fn new(from: impl Into<String>, to: impl Into<String>, dep_type: DependencyType) -> Self {
        let from: String = from.into();
        let to: String = to.into();
        GraphEdge {
            from,
            to,
            dep_type,
            dim_index: 0,
        }
    }

    /// Returns this edge with `dim_index` replaced.
    #[must_use]
    pub fn with_dim_index(self, dim_index: usize) -> Self {
        GraphEdge { dim_index, ..self }
    }
}
/// Directed graph of [`GraphNode`]s connected by [`GraphEdge`]s.
///
/// Every edge is stored twice: once in the outgoing list of its source and
/// once in the incoming list of its target, so traversal works in both
/// directions without scanning.
#[derive(Clone, Debug)]
pub struct DependencyGraph {
    /// Nodes keyed by their `id`.
    nodes: HashMap<String, GraphNode>,
    /// Outgoing edges keyed by source node id.
    edges_out: HashMap<String, Vec<GraphEdge>>,
    /// Incoming edges keyed by target node id.
    edges_in: HashMap<String, Vec<GraphEdge>>,
}
impl DependencyGraph {
    /// Creates an empty graph with no nodes and no edges.
    #[must_use]
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges_out: HashMap::new(),
            edges_in: HashMap::new(),
        }
    }

    /// Inserts `node`, keyed by its `id`, and registers empty edge lists for it.
    ///
    /// If a node with the same id already exists, its data is replaced but any
    /// edges already attached to that id are kept.
    pub fn add_node(&mut self, node: GraphNode) {
        let id = node.id.clone();
        self.nodes.insert(id.clone(), node);
        self.edges_out.entry(id.clone()).or_default();
        self.edges_in.entry(id).or_default();
    }

    /// Adds the directed edge `edge.from -> edge.to`.
    ///
    /// The edge is recorded both as outgoing from its source and as incoming
    /// to its target.
    ///
    /// # Errors
    ///
    /// Returns [`PruningError::InvalidPattern`] if either endpoint has not been
    /// added with [`add_node`](Self::add_node); the graph is unchanged then.
    pub fn add_edge(&mut self, edge: GraphEdge) -> Result<(), PruningError> {
        if !self.nodes.contains_key(&edge.from) {
            return Err(PruningError::InvalidPattern {
                message: format!("Source node '{}' not found", edge.from),
            });
        }
        if !self.nodes.contains_key(&edge.to) {
            return Err(PruningError::InvalidPattern {
                message: format!("Target node '{}' not found", edge.to),
            });
        }
        self.edges_out
            .entry(edge.from.clone())
            .or_default()
            .push(edge.clone());
        self.edges_in.entry(edge.to.clone()).or_default().push(edge);
        Ok(())
    }

    /// Looks up a node by id.
    #[must_use]
    pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
        self.nodes.get(id)
    }

    /// Looks up a node by id for mutation.
    ///
    /// Nodes are keyed by the `id` they had when inserted; changing `id`
    /// through this reference does NOT re-key the node or its edges.
    pub fn get_node_mut(&mut self, id: &str) -> Option<&mut GraphNode> {
        self.nodes.get_mut(id)
    }

    /// Iterates over all nodes in unspecified (hash-map) order.
    pub fn nodes(&self) -> impl Iterator<Item = &GraphNode> {
        self.nodes.values()
    }

    /// Edges leaving node `id`; empty if the id is unknown.
    #[must_use]
    pub fn edges_from(&self, id: &str) -> &[GraphEdge] {
        self.edges_out.get(id).map_or(&[], Vec::as_slice)
    }

    /// Edges entering node `id`; empty if the id is unknown.
    #[must_use]
    pub fn edges_to(&self, id: &str) -> &[GraphEdge] {
        self.edges_in.get(id).map_or(&[], Vec::as_slice)
    }

    /// Number of nodes in the graph.
    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the graph (each edge counted once).
    #[must_use]
    pub fn num_edges(&self) -> usize {
        self.edges_out.values().map(Vec::len).sum()
    }

    /// Ids of every node transitively reachable from `node_id` along outgoing
    /// edges.
    ///
    /// `node_id` itself is included only if it lies on a cycle. An unknown id
    /// yields an empty set.
    #[must_use]
    pub fn downstream_dependents(&self, node_id: &str) -> HashSet<String> {
        Self::reachable(&self.edges_out, node_id, |edge| edge.to.as_str())
    }

    /// Ids of every node from which `node_id` is transitively reachable, i.e.
    /// nodes found by walking incoming edges backwards.
    ///
    /// `node_id` itself is included only if it lies on a cycle. An unknown id
    /// yields an empty set.
    #[must_use]
    pub fn upstream_dependents(&self, node_id: &str) -> HashSet<String> {
        Self::reachable(&self.edges_in, node_id, |edge| edge.from.as_str())
    }

    /// Depth-first walk shared by the two dependent queries.
    ///
    /// `adjacency` maps a node id to the edges to follow from it, and
    /// `far_end` picks the neighbour id on each of those edges. Returns every
    /// neighbour discovered; each node is expanded at most once, so cycles
    /// terminate.
    fn reachable(
        adjacency: &HashMap<String, Vec<GraphEdge>>,
        start: &str,
        far_end: fn(&GraphEdge) -> &str,
    ) -> HashSet<String> {
        let mut reached = HashSet::new();
        let mut visited = HashSet::new();
        let mut stack = vec![start.to_string()];
        while let Some(current) = stack.pop() {
            // Look the edges up before `current` is moved into `visited`; the
            // result borrows `adjacency`, not `current`.
            let edges = adjacency.get(&current);
            // `insert` returns false when this node was already expanded.
            if !visited.insert(current) {
                continue;
            }
            for edge in edges.into_iter().flatten() {
                let next = far_end(edge);
                reached.insert(next.to_string());
                stack.push(next.to_string());
            }
        }
        reached
    }

    /// Nodes whose `prunable` flag is set, in unspecified order.
    #[must_use]
    pub fn prunable_nodes(&self) -> Vec<&GraphNode> {
        self.nodes.values().filter(|n| n.prunable).collect()
    }

    /// Checks the graph's structural invariants.
    ///
    /// # Errors
    ///
    /// - [`PruningError::InvalidPattern`] if any edge has an endpoint (source
    ///   or target) that is not a node of the graph.
    /// - [`PruningError::ShapeMismatch`] if a `Sequential` edge connects a node
    ///   whose `output_dim` differs from the target's `input_dim`. A dimension
    ///   of `0` is treated as unknown and is not compared.
    ///
    /// Dangling-endpoint errors are reported before any shape error.
    pub fn validate(&self) -> Result<(), PruningError> {
        for edge in self.edges_out.values().flatten() {
            // The target is checked first so that an edge missing both
            // endpoints reports its target.
            for endpoint in [&edge.to, &edge.from] {
                if !self.nodes.contains_key(endpoint) {
                    return Err(PruningError::InvalidPattern {
                        message: format!("Edge references unknown node: {endpoint}"),
                    });
                }
            }
        }
        for edge in self
            .edges_out
            .values()
            .flatten()
            .filter(|edge| edge.dep_type == DependencyType::Sequential)
        {
            if let (Some(from), Some(to)) = (self.nodes.get(&edge.from), self.nodes.get(&edge.to)) {
                let both_known = from.output_dim != 0 && to.input_dim != 0;
                if both_known && from.output_dim != to.input_dim {
                    return Err(PruningError::ShapeMismatch {
                        expected: vec![from.output_dim],
                        got: vec![to.input_dim],
                    });
                }
            }
        }
        Ok(())
    }

    /// Builds a chain of `Linear` nodes `layer_0 -> layer_1 -> ...` joined by
    /// `Sequential` edges.
    ///
    /// Node `i` gets id `layer_{i}`, name `names[i]` and dims
    /// `layer_dims[i] = (input_dim, output_dim)`. If the slices differ in
    /// length, only the first `min(layer_dims.len(), names.len())` entries are
    /// used.
    #[must_use]
    pub fn linear_chain(layer_dims: &[(usize, usize)], names: &[&str]) -> Self {
        let mut graph = Self::new();
        for (i, (&(in_dim, out_dim), &name)) in layer_dims.iter().zip(names).enumerate() {
            let node = GraphNode::new(format!("layer_{i}"), name, NodeType::Linear)
                .with_dims(in_dim, out_dim);
            graph.add_node(node);
        }
        // Iterate over the nodes that actually exist rather than over `names`,
        // so no edge ever points at a node that was never created.
        let count = layer_dims.len().min(names.len());
        for i in 1..count {
            let edge = GraphEdge::new(
                format!("layer_{}", i - 1),
                format!("layer_{i}"),
                DependencyType::Sequential,
            );
            let added = graph.add_edge(edge);
            debug_assert!(
                added.is_ok(),
                "both endpoints of a chain edge were inserted above"
            );
        }
        graph
    }
}
impl Default for DependencyGraph {
fn default() -> Self {
Self::new()
}
}
/// A set of planned structural removals.
#[derive(Clone, Debug)]
pub struct PruningPlan {
    /// Channel indices to remove, grouped by key (presumably a node id —
    /// NOTE(review): confirm against the code that builds plans).
    pub channel_removals: HashMap<String, Vec<usize>>,
    /// Layers to remove entirely (presumably node ids — confirm).
    pub layer_removals: Vec<String>,
    /// Private flag; NOTE(review): presumably set once the plan has been
    /// checked — confirm against the code that reads it.
    validated: bool,
}
include!("removed.rs");
include!("graph_tests.rs");