use std::collections::HashMap;
use indexmap::IndexMap;
use crate::dtype::{DScalar, DTensor, DType};
use crate::graph::{BinaryOp, Graph, Operation, ReduceOp, UnaryOp, Value};
use crate::optimizer::OptimizerSettings;
use crate::optimizer::recurse::heap_recurse;
/// Rewrites an existing [`Graph`] into a new, optimized graph,
/// keeping a mapping from old values to their translated counterparts.
#[derive(Debug)]
pub struct Optimizer<'a> {
    // Which optimizations are enabled (e.g. `fuse_layernorm`, `div_to_mul`).
    settings: OptimizerSettings,
    /// The source graph being read from; never mutated.
    pub old_graph: &'a Graph,
    /// The optimized graph being built up.
    pub new_graph: Graph,
    // Maps each visited value in `old_graph` to its replacement in `new_graph`.
    map: HashMap<Value, Value>,
}
/// Result of a visiting step: `Err(v)` signals that `v` has not been translated yet
/// and must be visited first (see `visit`, which errors with the unmapped value,
/// and the `heap_recurse` driver in `visit_completely`).
pub type VisitResult<T> = Result<T, Value>;
impl<'a> Optimizer<'a> {
/// Create an optimizer that will translate `old_graph` into a fresh graph
/// according to `settings`.
pub fn new(settings: OptimizerSettings, old_graph: &'a Graph) -> Self {
    let new_graph = Graph::new();
    Optimizer {
        old_graph,
        new_graph,
        map: HashMap::default(),
        settings,
    }
}
/// Translate `old` and (transitively) everything it depends on, returning the new value.
/// Driven by `heap_recurse`: each `Err(v)` from `visit_single_cached` presumably tells the
/// driver to process `v` first — recursion state lives on the heap, not the call stack.
pub fn visit_completely(&mut self, old: Value) -> Value {
    heap_recurse(old, |curr_old| self.visit_single_cached(curr_old))
}
/// Look up the already-translated counterpart of `old`.
/// Returns `Err(old)` when it has not been visited yet, which the
/// `heap_recurse` driver uses to schedule it.
fn visit(&mut self, old: Value) -> VisitResult<Value> {
    self.map.get(&old).copied().ok_or(old)
}
/// Translate a single value, reusing the cached result when it was already visited.
fn visit_single_cached(&mut self, old: Value) -> VisitResult<Value> {
    // Fast path: already translated.
    if let Some(cached) = self.map.get(&old).copied() {
        return Ok(cached);
    }
    // Slow path: translate and record the mapping.
    let new = self.visit_single_new(old)?;
    self.insert_mapping(old, new);
    Ok(new)
}
/// Record that `old` (in the old graph) maps to `new` (in the new graph).
///
/// Panics if the shapes disagree or if `old` was already mapped —
/// both indicate a bug in the optimizer itself.
pub fn insert_mapping(&mut self, old: Value, new: Value) {
    // A replacement value must have the exact same shape as the original.
    assert_eq!(self.old_graph[old].shape, self.new_graph[new].shape);
    // Each old value may only ever be mapped once.
    let prev = self.map.insert(old, new);
    assert!(prev.is_none());
}
/// Translate `old_value` into the new graph, first attempting pattern fusion,
/// then falling back to a 1:1 copy of the operation.
fn visit_single_new(&mut self, old_value: Value) -> VisitResult<Value> {
    // Try to replace this value (and part of the subgraph feeding it) with a fused op.
    if let Some(fused) = self.try_fuse(old_value)? {
        self.new_graph
            .set_debug_id(fused, self.old_graph[old_value].debug_id.clone());
        return Ok(fused);
    }
    let old_info = &self.old_graph[old_value];
    let shape = old_info.shape.clone();
    let old_operation = &old_info.operation;
    // Make sure every input is already mapped; `?` hands any unmapped input
    // back to the heap_recurse driver to be processed first.
    for old_input in old_operation.inputs() {
        self.visit(old_input)?;
    }
    // All inputs were successfully visited above, so this unwrap cannot fail.
    let new_operation = old_operation.clone_map_inputs(|old_input| self.visit(old_input).unwrap());
    let new_value = self.new_graph.push(shape, old_info.dtype, new_operation);
    self.new_graph.set_debug_id(new_value, old_info.debug_id.clone());
    Ok(new_value)
}
/// Try all enabled fusions on `old_start`, in a fixed priority order,
/// returning the first successful replacement (if any).
fn try_fuse(&mut self, old_start: Value) -> VisitResult<Option<Value>> {
    // All fusions below only apply to f32 values.
    let (_, dtype) = self.old_graph.shape_dtype(old_start);
    if dtype != DType::F32 {
        return Ok(None);
    }

    // Attempt each fusion until one succeeds; the order is significant.
    let mut fused = None;
    if self.settings.fuse_layernorm {
        fused = self.try_fuse_layernorm(old_start)?;
    }
    if fused.is_none() {
        fused = self.try_fuse_clamp(old_start)?;
    }
    if fused.is_none() {
        fused = self.try_fuse_conv_affine(old_start)?;
    }
    if fused.is_none() && self.settings.div_to_mul {
        fused = self.try_convert_div_to_mul(old_start)?;
    }
    Ok(fused)
}
/// Try to recognize a layernorm expression rooted at `value`:
/// `(x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps)`
/// and hand it off to `try_fuse_layernorm_inner` for replacement.
fn try_fuse_layernorm(&mut self, value: Value) -> VisitResult<Option<Value>> {
    // Use-count of every intermediate value matched below. Before fusing we check that
    // each one is hidden and used exactly this many times, so removing the intermediates
    // cannot break other users of them.
    let mut fused_values = IndexMap::<Value, usize>::new();
    // Look up an operation while recording that we consumed one use of `v`.
    let mut op = |v| {
        *fused_values.entry(v).or_insert(0) += 1;
        &self.old_graph[v].operation
    };
    // Root: zeroed / broadcast(std).
    if let &Operation::Binary {
        left: zeroed0,
        right: std_broadcast,
        op: BinaryOp::Div,
    } = &self.old_graph[value].operation
    {
        // zeroed = input - broadcast(view(mean)).
        if let &Operation::Binary {
            left: input0,
            right: mean_broadcast,
            op: BinaryOp::Sub,
        } = op(zeroed0)
        {
            if let &Operation::Broadcast { input: mean_view } = op(mean_broadcast) {
                if let &Operation::View { input: mean } = op(mean_view) {
                    // mean = mean(input) over axes0.
                    if let &Operation::Reduce {
                        input: input1,
                        axes: ref axes0,
                        op: ReduceOp::Mean,
                    } = op(mean)
                    {
                        // std = sqrt(var + eps), broadcast back to the input shape.
                        if let &Operation::Broadcast { input: std } = op(std_broadcast) {
                            if let &Operation::Unary {
                                input: stable_var,
                                op: UnaryOp::Sqrt,
                            } = op(std)
                            {
                                if let &Operation::Binary {
                                    left: var_view,
                                    right: const_eps,
                                    op: BinaryOp::Add,
                                } = op(stable_var)
                                {
                                    if let &Operation::View { input: var } = op(var_view) {
                                        // var = mean(zeroed^2) over axes1.
                                        if let &Operation::Reduce {
                                            input: pow,
                                            axes: ref axes1,
                                            op: ReduceOp::Mean,
                                        } = op(var)
                                        {
                                            if let &Operation::Binary {
                                                left: zeroed1,
                                                right: const_2,
                                                op: BinaryOp::Pow,
                                            } = op(pow)
                                            {
                                                // Count this second use of the zeroed value too.
                                                op(zeroed1);
                                                // Both branches must share the same input, the same
                                                // centered value, and the same reduction axes.
                                                if input0 != input1 || zeroed0 != zeroed1 || axes0 != axes1 {
                                                    return Ok(None);
                                                }
                                                // Bail out if any intermediate is visible elsewhere
                                                // or used more often than we accounted for.
                                                if fused_values.iter().any(|(&fused_value, &count)| {
                                                    !self.old_graph.is_hidden_with_uses(fused_value, count)
                                                }) {
                                                    return Ok(None);
                                                }
                                                return self
                                                    .try_fuse_layernorm_inner(input0, axes0, const_2, const_eps);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    Ok(None)
}
/// Final checks for layernorm fusion: a single reduction axis, a constant f32 epsilon,
/// and an exponent constant of exactly 2.0. Emits the fused layernorm on success.
fn try_fuse_layernorm_inner(
    &mut self,
    old_input: Value,
    axes: &[usize],
    old_const_2: Value,
    old_const_eps: Value,
) -> VisitResult<Option<Value>> {
    // Layernorm is only fused over exactly one axis.
    let axis = match axes {
        &[single] => single,
        _ => return Ok(None),
    };
    // The epsilon must be a single f32 constant.
    let eps = if let Some(DScalar::F32(eps)) = self.old_graph.as_single_const(old_const_eps) {
        *eps
    } else {
        return Ok(None);
    };
    // The pow exponent must be exactly 2.0 (the variance squares the centered input).
    if !self.old_graph.is_const_filled_with(old_const_2, DScalar::f32(2.0)) {
        return Ok(None);
    }
    let new_input = self.visit(old_input)?;
    Ok(Some(self.new_graph.layernorm(new_input, axis, eps)))
}
/// Collapse a chain of `min`/`max`-with-constant operations into a single clamp.
fn try_fuse_clamp(&mut self, old_start: Value) -> VisitResult<Option<Value>> {
    // Accumulated clamp bounds, tightened for every min/max we absorb.
    let mut lower = f32::NEG_INFINITY;
    let mut upper = f32::INFINITY;

    // Walk backwards through min/max ops whose right operand is a single f32 constant.
    let followed = self.follow_if(old_start, |_, _, operation| {
        if let &Operation::Binary { left, right, op } = operation {
            if matches!(op, BinaryOp::Min | BinaryOp::Max) {
                if let Some(DScalar::F32(bound)) = self.old_graph.as_single_const(right) {
                    if matches!(op, BinaryOp::Min) {
                        // `min(x, c)` caps the value from above.
                        upper = f32::min(upper, *bound);
                    } else {
                        // `max(x, c)` caps the value from below.
                        lower = f32::max(lower, *bound);
                    }
                    return Ok(Some(left));
                }
            }
        }
        Ok(None)
    })?;

    match followed {
        Some(old_input) => {
            let new_input = self.visit(old_input)?;
            Ok(Some(self.new_graph.clamp::<f32>(new_input, lower, upper)))
        }
        None => Ok(None),
    }
}
/// Try to fuse an affine group (scale/bias operations around a convolution,
/// as built by `try_build_affine_group`) into fewer ops in the new graph.
fn try_fuse_conv_affine(&mut self, old_start: Value) -> VisitResult<Option<Value>> {
    match self.try_build_affine_group(old_start)? {
        Some(group) => {
            let new_input = self.visit(group.old_input())?;
            let fused = group.apply_fused(self.settings, &mut self.new_graph, new_input);
            Ok(Some(fused))
        }
        None => Ok(None),
    }
}
/// Rewrite `left / right` as `left * (1 / right)` when `right` is a constant
/// f32 tensor, precomputing the reciprocal at optimization time.
fn try_convert_div_to_mul(&mut self, old_start: Value) -> VisitResult<Option<Value>> {
    // Only a division qualifies.
    let (left, right) = match &self.old_graph[old_start].operation {
        &Operation::Binary {
            left,
            right,
            op: BinaryOp::Div,
        } => (left, right),
        _ => return Ok(None),
    };
    match self.old_graph.as_const(right) {
        Some(DTensor::F32(data)) => {
            // Elementwise reciprocal of the constant divisor.
            let reciprocal = data.mapv(|x| x.recip()).into_shared();
            let new_right = self.new_graph.constant_tensor(DTensor::F32(reciprocal));
            let new_left = self.visit(left)?;
            Ok(Some(self.new_graph.mul(new_left, new_right)))
        }
        _ => Ok(None),
    }
}
/// Repeatedly step from `start` through values accepted by `next`, returning the final
/// value reached, or `None` if no step was taken at all.
pub fn follow_if(
    &self,
    start: Value,
    mut next: impl FnMut(&Graph, Value, &Operation) -> VisitResult<Option<Value>>,
) -> VisitResult<Option<Value>> {
    let mut curr = start;
    // Only traverse values that are hidden and used exactly once, so collapsing
    // the chain cannot affect any other user.
    while self.old_graph.is_hidden_with_uses(curr, 1) {
        match next(self.old_graph, curr, &self.old_graph[curr].operation)? {
            Some(stepped) => curr = stepped,
            None => break,
        }
    }
    // `None` means we never moved away from the starting value.
    if curr != start {
        Ok(Some(curr))
    } else {
        Ok(None)
    }
}
}