1#![allow(dead_code)]
7
8use std::collections::HashMap;
9
/// One node of an optimization graph, described entirely with strings.
///
/// Edges are stored by id on both ends: `inputs` names the nodes this node reads
/// from and `outputs` names the nodes that consume it. Optimization passes read
/// and rewrite both lists.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NodeSpec {
    /// Identifier other nodes use to reference this node (uniqueness is not enforced here).
    pub id: String,
    /// Operation name, e.g. `"Scale"`, `"Brightness"` or `"Contrast"`.
    pub node_type: String,
    /// Operation parameters as raw key/value strings, e.g. `"factor" => "1.0"`.
    pub params: HashMap<String, String>,
    /// Ids of the nodes whose results feed into this node.
    pub inputs: Vec<String>,
    /// Ids of the nodes that consume this node's result.
    pub outputs: Vec<String>,
}
29
30impl NodeSpec {
31 #[must_use]
33 pub fn new(id: impl Into<String>, node_type: impl Into<String>) -> Self {
34 Self {
35 id: id.into(),
36 node_type: node_type.into(),
37 params: HashMap::new(),
38 inputs: vec![],
39 outputs: vec![],
40 }
41 }
42
43 #[must_use]
45 pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
46 self.params.insert(key.into(), value.into());
47 self
48 }
49
50 #[must_use]
52 pub fn with_inputs(mut self, inputs: Vec<String>) -> Self {
53 self.inputs = inputs;
54 self
55 }
56
57 #[must_use]
59 pub fn with_outputs(mut self, outputs: Vec<String>) -> Self {
60 self.outputs = outputs;
61 self
62 }
63}
64
65pub trait OptimizationPass: Send + Sync {
71 fn name(&self) -> &str;
73
74 fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize;
76}
77
78pub struct ConstantFoldingPass;
87
88impl ConstantFoldingPass {
89 #[must_use]
91 pub fn new() -> Self {
92 Self
93 }
94
95 fn is_identity(node: &NodeSpec) -> bool {
96 match node.node_type.as_str() {
97 "Scale" => {
98 let factor = node
99 .params
100 .get("factor")
101 .map(String::as_str)
102 .unwrap_or("1.0");
103 factor == "1.0" || factor == "1"
104 }
105 _ => false,
106 }
107 }
108}
109
110impl Default for ConstantFoldingPass {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl OptimizationPass for ConstantFoldingPass {
117 fn name(&self) -> &str {
118 "ConstantFolding"
119 }
120
121 fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
122 let before = nodes.len();
123
124 let identity_ids: Vec<String> = nodes
126 .iter()
127 .filter(|n| Self::is_identity(n))
128 .map(|n| n.id.clone())
129 .collect();
130
131 if identity_ids.is_empty() {
132 return 0;
133 }
134
135 let identity_set: std::collections::HashSet<&str> =
136 identity_ids.iter().map(String::as_str).collect();
137
138 for node in nodes.iter_mut() {
141 node.inputs = node
142 .inputs
143 .iter()
144 .flat_map(|inp| {
145 if identity_set.contains(inp.as_str()) {
146 vec![] } else {
153 vec![inp.clone()]
154 }
155 })
156 .collect();
157 }
158
159 nodes.retain(|n| !identity_set.contains(n.id.as_str()));
161
162 before - nodes.len()
163 }
164}
165
166pub struct DeadNodeEliminationPass;
176
177impl DeadNodeEliminationPass {
178 #[must_use]
180 pub fn new() -> Self {
181 Self
182 }
183}
184
185impl Default for DeadNodeEliminationPass {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
191impl OptimizationPass for DeadNodeEliminationPass {
192 fn name(&self) -> &str {
193 "DeadNodeElimination"
194 }
195
196 fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
197 if nodes.len() <= 1 {
198 return 0; }
200
201 let before = nodes.len();
202
203 let referenced: std::collections::HashSet<String> = nodes
206 .iter()
207 .flat_map(|n| n.inputs.iter().chain(n.outputs.iter()).cloned())
208 .collect();
209
210 nodes.retain(|n| !n.outputs.is_empty() || referenced.contains(&n.id));
212
213 before - nodes.len()
214 }
215}
216
217pub struct NodeFusionPass;
226
227impl NodeFusionPass {
228 #[must_use]
230 pub fn new() -> Self {
231 Self
232 }
233}
234
235impl Default for NodeFusionPass {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241impl OptimizationPass for NodeFusionPass {
242 fn name(&self) -> &str {
243 "NodeFusion"
244 }
245
246 fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
247 let mut fusions = 0usize;
248 let mut i = 0;
249
250 while i + 1 < nodes.len() {
251 let (a, b) = (&nodes[i], &nodes[i + 1]);
252
253 let is_fusable = a.node_type == "Brightness"
255 && b.node_type == "Contrast"
256 && b.inputs.contains(&a.id);
257
258 if is_fusable {
259 let mut params = a.params.clone();
261 for (k, v) in &b.params {
262 params.insert(k.clone(), v.clone());
263 }
264
265 let fused = NodeSpec {
266 id: format!("{}_{}", a.id, b.id),
267 node_type: "BrightnessContrast".to_string(),
268 params,
269 inputs: a.inputs.clone(),
270 outputs: b.outputs.clone(),
271 };
272
273 let a_id = a.id.clone();
274 let b_id = b.id.clone();
275 let fused_id = fused.id.clone();
276
277 nodes.remove(i + 1);
279 nodes[i] = fused;
280
281 for node in nodes.iter_mut() {
283 for inp in &mut node.inputs {
284 if *inp == a_id || *inp == b_id {
285 *inp = fused_id.clone();
286 }
287 }
288 for out in &mut node.outputs {
289 if *out == a_id || *out == b_id {
290 *out = fused_id.clone();
291 }
292 }
293 }
294
295 fusions += 1;
296 } else {
299 i += 1;
300 }
301 }
302
303 fusions
304 }
305}
306
/// Summary of one run of a sequence of optimization passes.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct OptimizationReport {
    /// Names of the passes that were run, in execution order (recorded even when
    /// a pass applied nothing).
    pub passes_applied: Vec<String>,
    /// Node count before the first pass ran.
    pub nodes_before: usize,
    /// Node count after the last pass ran.
    pub nodes_after: usize,
    /// Sum of the counts returned by every pass's `optimize`.
    pub optimizations: usize,
}
323
324#[derive(Default)]
330pub struct GraphOptimizer {
331 passes: Vec<Box<dyn OptimizationPass>>,
332}
333
334impl GraphOptimizer {
335 #[must_use]
337 pub fn new() -> Self {
338 Self { passes: vec![] }
339 }
340
341 pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
343 self.passes.push(pass);
344 }
345
346 #[must_use]
350 pub fn run(&self, mut nodes: Vec<NodeSpec>) -> (Vec<NodeSpec>, OptimizationReport) {
351 let nodes_before = nodes.len();
352 let mut passes_applied = Vec::new();
353 let mut total_optimizations = 0;
354
355 for pass in &self.passes {
356 let count = pass.optimize(&mut nodes);
357 passes_applied.push(pass.name().to_string());
358 total_optimizations += count;
359 }
360
361 let report = OptimizationReport {
362 passes_applied,
363 nodes_before,
364 nodes_after: nodes.len(),
365 optimizations: total_optimizations,
366 };
367
368 (nodes, report)
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Scale` node with the given `factor` param.
    fn scale(id: &str, factor: &str) -> NodeSpec {
        NodeSpec::new(id, "Scale").with_param("factor", factor)
    }

    /// Builds a `Brightness` node whose `value` param is `1.2`.
    fn brightness(id: &str) -> NodeSpec {
        NodeSpec::new(id, "Brightness").with_param("value", "1.2")
    }

    /// Builds a `Contrast` node (`value` = `1.1`) that reads from `input`.
    fn contrast(id: &str, input: &str) -> NodeSpec {
        NodeSpec::new(id, "Contrast")
            .with_param("value", "1.1")
            .with_inputs(ids(&[input]))
    }

    /// Converts borrowed ids into the owned list the builders expect.
    fn ids(list: &[&str]) -> Vec<String> {
        list.iter().map(|id| (*id).to_string()).collect()
    }

    /// Runs a single pass and hands back its count together with the graph.
    fn apply(pass: &dyn OptimizationPass, mut nodes: Vec<NodeSpec>) -> (usize, Vec<NodeSpec>) {
        let count = pass.optimize(&mut nodes);
        (count, nodes)
    }

    // --- ConstantFoldingPass ---------------------------------------------------

    #[test]
    fn test_constant_folding_removes_scale_one() {
        let (removed, nodes) = apply(&ConstantFoldingPass::new(), vec![scale("s1", "1.0")]);
        assert_eq!(removed, 1);
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_constant_folding_keeps_scale_two() {
        let (removed, nodes) = apply(&ConstantFoldingPass::new(), vec![scale("s1", "2.0")]);
        assert_eq!(removed, 0);
        assert_eq!(nodes.len(), 1);
    }

    #[test]
    fn test_constant_folding_integer_one() {
        let (removed, _) = apply(&ConstantFoldingPass::new(), vec![scale("s1", "1")]);
        assert_eq!(removed, 1);
    }

    #[test]
    fn test_constant_folding_mixed_nodes() {
        let graph = vec![scale("s1", "1.0"), scale("s2", "0.5")];
        let (removed, nodes) = apply(&ConstantFoldingPass::new(), graph);
        assert_eq!(removed, 1);
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].id, "s2");
    }

    // --- DeadNodeEliminationPass -----------------------------------------------

    #[test]
    fn test_dead_node_elimination_no_outputs() {
        let graph = vec![NodeSpec::new("a", "Filter"), NodeSpec::new("b", "Filter")];
        let (removed, nodes) = apply(&DeadNodeEliminationPass::new(), graph);
        assert_eq!(removed, 2);
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_dead_node_elimination_referenced_node_kept() {
        let graph = vec![
            NodeSpec::new("a", "Source").with_outputs(ids(&["b"])),
            NodeSpec::new("b", "Sink").with_inputs(ids(&["a"])),
        ];
        let (removed, _) = apply(&DeadNodeEliminationPass::new(), graph);
        assert_eq!(removed, 0);
    }

    #[test]
    fn test_dead_node_elimination_single_node_preserved() {
        let graph = vec![NodeSpec::new("a", "Source")];
        let (removed, nodes) = apply(&DeadNodeEliminationPass::new(), graph);
        assert_eq!(removed, 0);
        assert_eq!(nodes.len(), 1);
    }

    // --- NodeFusionPass ----------------------------------------------------------

    #[test]
    fn test_node_fusion_brightness_contrast() {
        let graph = vec![brightness("b1").with_outputs(ids(&["c1"])), contrast("c1", "b1")];
        let (fusions, nodes) = apply(&NodeFusionPass::new(), graph);
        assert_eq!(fusions, 1);
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].node_type, "BrightnessContrast");
    }

    #[test]
    fn test_node_fusion_no_match() {
        let graph = vec![NodeSpec::new("a", "Scale"), NodeSpec::new("b", "Gamma")];
        let (fusions, nodes) = apply(&NodeFusionPass::new(), graph);
        assert_eq!(fusions, 0);
        assert_eq!(nodes.len(), 2);
    }

    #[test]
    fn test_node_fusion_fused_node_has_merged_params() {
        let graph = vec![brightness("b1").with_outputs(ids(&["c1"])), contrast("c1", "b1")];
        let (_, nodes) = apply(&NodeFusionPass::new(), graph);
        assert!(nodes[0].params.contains_key("value"));
    }

    // --- GraphOptimizer ----------------------------------------------------------

    #[test]
    fn test_optimizer_empty_graph() {
        let mut opt = GraphOptimizer::new();
        opt.add_pass(Box::new(ConstantFoldingPass::new()));
        let (nodes, report) = opt.run(Vec::new());
        assert!(nodes.is_empty());
        assert_eq!((report.nodes_before, report.nodes_after), (0, 0));
    }

    #[test]
    fn test_optimizer_report_fields() {
        let mut opt = GraphOptimizer::new();
        opt.add_pass(Box::new(ConstantFoldingPass::new()));
        let (_, report) = opt.run(vec![scale("s1", "1.0"), scale("s2", "2.0")]);
        assert_eq!(report.nodes_before, 2);
        assert_eq!(report.nodes_after, 1);
        assert_eq!(report.optimizations, 1);
        assert_eq!(report.passes_applied, vec!["ConstantFolding"]);
    }

    #[test]
    fn test_optimizer_multiple_passes() {
        let mut opt = GraphOptimizer::new();
        opt.add_pass(Box::new(ConstantFoldingPass::new()));
        opt.add_pass(Box::new(DeadNodeEliminationPass::new()));
        let (_, report) = opt.run(vec![scale("s1", "1.0"), NodeSpec::new("orphan", "Filter")]);
        assert_eq!(report.passes_applied.len(), 2);
    }

    #[test]
    fn test_optimizer_no_passes() {
        let (result, report) = GraphOptimizer::new().run(vec![NodeSpec::new("a", "Filter")]);
        assert_eq!(result.len(), 1);
        assert_eq!(report.optimizations, 0);
        assert!(report.passes_applied.is_empty());
    }
}
543}