kn_graph/optimizer/
mod.rs

1use crate::graph::Graph;
2use crate::optimizer::core::Optimizer;
3
4mod affine;
5mod core;
6pub mod recurse;
7
/// Configuration flags that control what the graph optimizer is allowed to do.
///
/// `Default::default()` yields a reasonable starting configuration.
#[derive(Clone, Copy, Debug)]
pub struct OptimizerSettings {
    /// Master switch: when `false`, no optimization is performed at all.
    pub optimize: bool,
    /// When `true`, a bias operation followed by a convolution is always moved _through_
    /// the convolution, even if that requires switching to a bias constant that is no
    /// longer spatially broadcast.
    pub force_bias_through_conv: bool,
    /// When `true`, attempt to fuse a matching sequence of operations into a single
    /// LayerNorm operation.
    pub fuse_layernorm: bool,
    /// When `true`, rewrite a division by a constant as a multiplication by the inverse
    /// constant.
    pub div_to_mul: bool,
}
23
24/// Optimize the given graph according to the given settings. Returns a new, fully independent graph.
25pub fn optimize_graph(graph: &Graph, settings: OptimizerSettings) -> Graph {
26    if !settings.optimize {
27        return graph.clone();
28    }
29
30    let mut optimizer = Optimizer::new(settings, graph);
31
32    // ensure all inputs are copied over in the same order
33    for &old_input in graph.inputs() {
34        let (shape, dtype) = graph.shape_dtype(old_input);
35        let new_input = optimizer.new_graph.input(shape.clone(), dtype);
36        optimizer.insert_mapping(old_input, new_input);
37    }
38
39    // register all outputs, again in the same order as before
40    for &old_output in graph.outputs() {
41        let new_output = optimizer.visit_completely(old_output);
42        optimizer.new_graph.output(new_output);
43    }
44
45    optimizer.new_graph
46}
47
48#[allow(clippy::derivable_impls)]
49impl Default for OptimizerSettings {
50    fn default() -> Self {
51        OptimizerSettings {
52            optimize: true,
53            force_bias_through_conv: false,
54            fuse_layernorm: true,
55            div_to_mul: true,
56        }
57    }
58}