use crate::graph::Graph;
use crate::optimizer::core::Optimizer;
mod affine;
mod core;
pub mod recurse;
/// Knobs controlling how [`optimize_graph`] rewrites a graph.
///
/// When `optimize` is `false`, [`optimize_graph`] returns a clone of the input graph before any
/// optimizer is constructed, so the remaining flags have no effect in that case. The other flags
/// are passed to the optimizer core and are not read in this module. Their descriptions below
/// are inferred from their names only (NOTE(review): confirm against `optimizer::core`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct OptimizerSettings {
    /// Master switch: `false` makes `optimize_graph` return an unmodified clone.
    pub optimize: bool,
    /// Presumably forces bias terms to be moved through convolutions — TODO confirm.
    pub force_bias_through_conv: bool,
    /// Presumably enables fusing layer-normalization patterns — TODO confirm.
    pub fuse_layernorm: bool,
    /// Presumably enables rewriting divisions as multiplications — TODO confirm.
    pub div_to_mul: bool,
}
/// Produces an optimized version of `graph` according to `settings`.
///
/// If `settings.optimize` is `false`, the input graph is cloned and returned unchanged.
/// Otherwise an [`Optimizer`] is created for `graph`, and the graph it accumulates in its
/// `new_graph` field is returned after two passes:
///
/// 1. Each input of `graph` is re-declared in the new graph with the same shape and dtype, in
///    the original order, and recorded as the replacement for the old input.
/// 2. Each output of `graph` is passed to `Optimizer::visit_completely`, and the returned value
///    is registered as an output of the new graph, again in the original order.
///
/// Inputs are mapped before any output is visited. Presumably this is so that rewriting can
/// resolve references to graph inputs (NOTE(review): confirm in `optimizer::core`).
pub fn optimize_graph(graph: &Graph, settings: OptimizerSettings) -> Graph {
    if settings.optimize {
        let mut opt = Optimizer::new(settings, graph);

        // Pass 1: recreate every graph input in the new graph and link old -> new.
        for &input in graph.inputs() {
            let (shape, dtype) = graph.shape_dtype(input);
            let replacement = opt.new_graph.input(shape.clone(), dtype);
            opt.insert_mapping(input, replacement);
        }

        // Pass 2: rewrite each output and expose the result as a new-graph output.
        for &output in graph.outputs() {
            let rewritten = opt.visit_completely(output);
            opt.new_graph.output(rewritten);
        }

        opt.new_graph
    } else {
        graph.clone()
    }
}
#[allow(clippy::derivable_impls)]
impl Default for OptimizerSettings {
fn default() -> Self {
OptimizerSettings {
optimize: true,
force_bias_through_conv: false,
fuse_layernorm: true,
div_to_mul: true,
}
}
}