#![allow(dead_code)]
use std::collections::HashMap;
#[derive(Debug, Clone, PartialEq)]
pub struct NodeSpec {
pub id: String,
pub node_type: String,
pub params: HashMap<String, String>,
pub inputs: Vec<String>,
pub outputs: Vec<String>,
}
impl NodeSpec {
#[must_use]
pub fn new(id: impl Into<String>, node_type: impl Into<String>) -> Self {
Self {
id: id.into(),
node_type: node_type.into(),
params: HashMap::new(),
inputs: vec![],
outputs: vec![],
}
}
#[must_use]
pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.params.insert(key.into(), value.into());
self
}
#[must_use]
pub fn with_inputs(mut self, inputs: Vec<String>) -> Self {
self.inputs = inputs;
self
}
#[must_use]
pub fn with_outputs(mut self, outputs: Vec<String>) -> Self {
self.outputs = outputs;
self
}
}
pub trait OptimizationPass: Send + Sync {
fn name(&self) -> &str;
fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize;
}
pub struct ConstantFoldingPass;
impl ConstantFoldingPass {
#[must_use]
pub fn new() -> Self {
Self
}
fn is_identity(node: &NodeSpec) -> bool {
match node.node_type.as_str() {
"Scale" => {
let factor = node
.params
.get("factor")
.map(String::as_str)
.unwrap_or("1.0");
factor == "1.0" || factor == "1"
}
_ => false,
}
}
}
impl Default for ConstantFoldingPass {
fn default() -> Self {
Self::new()
}
}
impl OptimizationPass for ConstantFoldingPass {
fn name(&self) -> &str {
"ConstantFolding"
}
fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
let before = nodes.len();
let identity_ids: Vec<String> = nodes
.iter()
.filter(|n| Self::is_identity(n))
.map(|n| n.id.clone())
.collect();
if identity_ids.is_empty() {
return 0;
}
let identity_set: std::collections::HashSet<&str> =
identity_ids.iter().map(String::as_str).collect();
for node in nodes.iter_mut() {
node.inputs = node
.inputs
.iter()
.flat_map(|inp| {
if identity_set.contains(inp.as_str()) {
vec![] } else {
vec![inp.clone()]
}
})
.collect();
}
nodes.retain(|n| !identity_set.contains(n.id.as_str()));
before - nodes.len()
}
}
pub struct DeadNodeEliminationPass;
impl DeadNodeEliminationPass {
#[must_use]
pub fn new() -> Self {
Self
}
}
impl Default for DeadNodeEliminationPass {
fn default() -> Self {
Self::new()
}
}
impl OptimizationPass for DeadNodeEliminationPass {
fn name(&self) -> &str {
"DeadNodeElimination"
}
fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
if nodes.len() <= 1 {
return 0; }
let before = nodes.len();
let referenced: std::collections::HashSet<String> = nodes
.iter()
.flat_map(|n| n.inputs.iter().chain(n.outputs.iter()).cloned())
.collect();
nodes.retain(|n| !n.outputs.is_empty() || referenced.contains(&n.id));
before - nodes.len()
}
}
pub struct NodeFusionPass;
impl NodeFusionPass {
#[must_use]
pub fn new() -> Self {
Self
}
}
impl Default for NodeFusionPass {
fn default() -> Self {
Self::new()
}
}
impl OptimizationPass for NodeFusionPass {
fn name(&self) -> &str {
"NodeFusion"
}
fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
let mut fusions = 0usize;
let mut i = 0;
while i + 1 < nodes.len() {
let (a, b) = (&nodes[i], &nodes[i + 1]);
let is_fusable = a.node_type == "Brightness"
&& b.node_type == "Contrast"
&& b.inputs.contains(&a.id);
if is_fusable {
let mut params = a.params.clone();
for (k, v) in &b.params {
params.insert(k.clone(), v.clone());
}
let fused = NodeSpec {
id: format!("{}_{}", a.id, b.id),
node_type: "BrightnessContrast".to_string(),
params,
inputs: a.inputs.clone(),
outputs: b.outputs.clone(),
};
let a_id = a.id.clone();
let b_id = b.id.clone();
let fused_id = fused.id.clone();
nodes.remove(i + 1);
nodes[i] = fused;
for node in nodes.iter_mut() {
for inp in &mut node.inputs {
if *inp == a_id || *inp == b_id {
*inp = fused_id.clone();
}
}
for out in &mut node.outputs {
if *out == a_id || *out == b_id {
*out = fused_id.clone();
}
}
}
fusions += 1;
} else {
i += 1;
}
}
fusions
}
}
#[derive(Debug, Clone)]
pub struct OptimizationReport {
pub passes_applied: Vec<String>,
pub nodes_before: usize,
pub nodes_after: usize,
pub optimizations: usize,
}
#[derive(Default)]
pub struct GraphOptimizer {
passes: Vec<Box<dyn OptimizationPass>>,
}
impl GraphOptimizer {
#[must_use]
pub fn new() -> Self {
Self { passes: vec![] }
}
pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
self.passes.push(pass);
}
#[must_use]
pub fn run(&self, mut nodes: Vec<NodeSpec>) -> (Vec<NodeSpec>, OptimizationReport) {
let nodes_before = nodes.len();
let mut passes_applied = Vec::new();
let mut total_optimizations = 0;
for pass in &self.passes {
let count = pass.optimize(&mut nodes);
passes_applied.push(pass.name().to_string());
total_optimizations += count;
}
let report = OptimizationReport {
passes_applied,
nodes_before,
nodes_after: nodes.len(),
optimizations: total_optimizations,
};
(nodes, report)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn scale_node(id: &str, factor: &str) -> NodeSpec {
NodeSpec::new(id, "Scale").with_param("factor", factor)
}
fn brightness_node(id: &str) -> NodeSpec {
NodeSpec::new(id, "Brightness").with_param("value", "1.2")
}
fn contrast_node(id: &str, input: &str) -> NodeSpec {
NodeSpec::new(id, "Contrast")
.with_param("value", "1.1")
.with_inputs(vec![input.to_string()])
}
#[test]
fn test_constant_folding_removes_scale_one() {
let pass = ConstantFoldingPass::new();
let mut nodes = vec![scale_node("s1", "1.0")];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 1);
assert!(nodes.is_empty());
}
#[test]
fn test_constant_folding_keeps_scale_two() {
let pass = ConstantFoldingPass::new();
let mut nodes = vec![scale_node("s1", "2.0")];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 0);
assert_eq!(nodes.len(), 1);
}
#[test]
fn test_constant_folding_integer_one() {
let pass = ConstantFoldingPass::new();
let mut nodes = vec![scale_node("s1", "1")];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 1);
}
#[test]
fn test_constant_folding_mixed_nodes() {
let pass = ConstantFoldingPass::new();
let mut nodes = vec![scale_node("s1", "1.0"), scale_node("s2", "0.5")];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 1);
assert_eq!(nodes.len(), 1);
assert_eq!(nodes[0].id, "s2");
}
#[test]
fn test_dead_node_elimination_no_outputs() {
let pass = DeadNodeEliminationPass::new();
let mut nodes = vec![NodeSpec::new("a", "Filter"), NodeSpec::new("b", "Filter")];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 2);
assert!(nodes.is_empty());
}
#[test]
fn test_dead_node_elimination_referenced_node_kept() {
let pass = DeadNodeEliminationPass::new();
let mut nodes = vec![
NodeSpec::new("a", "Source").with_outputs(vec!["b".to_string()]),
NodeSpec::new("b", "Sink").with_inputs(vec!["a".to_string()]),
];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 0);
}
#[test]
fn test_dead_node_elimination_single_node_preserved() {
let pass = DeadNodeEliminationPass::new();
let mut nodes = vec![NodeSpec::new("a", "Source")];
let removed = pass.optimize(&mut nodes);
assert_eq!(removed, 0); assert_eq!(nodes.len(), 1);
}
#[test]
fn test_node_fusion_brightness_contrast() {
let pass = NodeFusionPass::new();
let mut nodes = vec![
brightness_node("b1").with_outputs(vec!["c1".to_string()]),
contrast_node("c1", "b1"),
];
let fusions = pass.optimize(&mut nodes);
assert_eq!(fusions, 1);
assert_eq!(nodes.len(), 1);
assert_eq!(nodes[0].node_type, "BrightnessContrast");
}
#[test]
fn test_node_fusion_no_match() {
let pass = NodeFusionPass::new();
let mut nodes = vec![NodeSpec::new("a", "Scale"), NodeSpec::new("b", "Gamma")];
let fusions = pass.optimize(&mut nodes);
assert_eq!(fusions, 0);
assert_eq!(nodes.len(), 2);
}
#[test]
fn test_node_fusion_fused_node_has_merged_params() {
let pass = NodeFusionPass::new();
let mut nodes = vec![
brightness_node("b1").with_outputs(vec!["c1".to_string()]),
contrast_node("c1", "b1"),
];
pass.optimize(&mut nodes);
assert!(nodes[0].params.contains_key("value"));
}
#[test]
fn test_optimizer_empty_graph() {
let mut opt = GraphOptimizer::new();
opt.add_pass(Box::new(ConstantFoldingPass::new()));
let (nodes, report) = opt.run(vec![]);
assert!(nodes.is_empty());
assert_eq!(report.nodes_before, 0);
assert_eq!(report.nodes_after, 0);
}
#[test]
fn test_optimizer_report_fields() {
let mut opt = GraphOptimizer::new();
opt.add_pass(Box::new(ConstantFoldingPass::new()));
let nodes = vec![scale_node("s1", "1.0"), scale_node("s2", "2.0")];
let (_, report) = opt.run(nodes);
assert_eq!(report.nodes_before, 2);
assert_eq!(report.nodes_after, 1);
assert_eq!(report.optimizations, 1);
assert_eq!(report.passes_applied, vec!["ConstantFolding"]);
}
#[test]
fn test_optimizer_multiple_passes() {
let mut opt = GraphOptimizer::new();
opt.add_pass(Box::new(ConstantFoldingPass::new()));
opt.add_pass(Box::new(DeadNodeEliminationPass::new()));
let nodes = vec![scale_node("s1", "1.0"), NodeSpec::new("orphan", "Filter")];
let (_, report) = opt.run(nodes);
assert_eq!(report.passes_applied.len(), 2);
}
#[test]
fn test_optimizer_no_passes() {
let opt = GraphOptimizer::new();
let nodes = vec![NodeSpec::new("a", "Filter")];
let (result, report) = opt.run(nodes);
assert_eq!(result.len(), 1);
assert_eq!(report.optimizations, 0);
assert!(report.passes_applied.is_empty());
}
}