kn_graph/optimizer/
mod.rs1use crate::graph::Graph;
2use crate::optimizer::core::Optimizer;
3
4mod affine;
5mod core;
6pub mod recurse;
7
/// Settings controlling how [`optimize_graph`] rewrites a graph.
///
/// `Default::default()` turns optimization on, enables `fuse_layernorm` and `div_to_mul`,
/// and leaves `force_bias_through_conv` disabled.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct OptimizerSettings {
    /// Master switch. When `false`, [`optimize_graph`] returns an unmodified clone of its
    /// input and the other flags are ignored.
    pub optimize: bool,
    /// NOTE(review): consumed inside the optimizer submodules, not visible here. Judging by
    /// the name, it presumably forces a bias addition to be moved through an adjacent
    /// convolution even where that is not done by default. Confirm in `core`/`affine`.
    pub force_bias_through_conv: bool,
    /// NOTE(review): consumed inside the optimizer submodules. It presumably enables fusing
    /// the sequence of ops that forms a layer normalization into a single op. Confirm in `core`.
    pub fuse_layernorm: bool,
    /// NOTE(review): consumed inside the optimizer submodules. It presumably rewrites a
    /// division by a constant into a multiplication by its reciprocal. Confirm in `core`.
    pub div_to_mul: bool,
}

24pub fn optimize_graph(graph: &Graph, settings: OptimizerSettings) -> Graph {
26 if !settings.optimize {
27 return graph.clone();
28 }
29
30 let mut optimizer = Optimizer::new(settings, graph);
31
32 for &old_input in graph.inputs() {
34 let (shape, dtype) = graph.shape_dtype(old_input);
35 let new_input = optimizer.new_graph.input(shape.clone(), dtype);
36 optimizer.insert_mapping(old_input, new_input);
37 }
38
39 for &old_output in graph.outputs() {
41 let new_output = optimizer.visit_completely(old_output);
42 optimizer.new_graph.output(new_output);
43 }
44
45 optimizer.new_graph
46}
47
48#[allow(clippy::derivable_impls)]
49impl Default for OptimizerSettings {
50 fn default() -> Self {
51 OptimizerSettings {
52 optimize: true,
53 force_bias_through_conv: false,
54 fuse_layernorm: true,
55 div_to_mul: true,
56 }
57 }
58}