use tract_core::ops::binary::TypedBinOp;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::math::{Mul, Pow, Tanh};
use tract_nnef::internal::*;
use crate::rule_ensure;
use super::{
find_succ_add_with, find_succ_add_with_const, find_succ_mul_with_const,
matches_single_input_const, next_node,
};
/// Registers the approximate GELU operator with the given NNEF `registry`.
///
/// Two things are installed: the dumper (`ser_gelu_approx`) and the
/// `tract_transformers_gelu_approx` primitive, which is loaded through
/// `de_gelu_approx`. The primitive declares a tensor `input`, a logical
/// `fast_impl` parameter, and a single tensor result named `output`.
pub fn register(registry: &mut Registry) {
    // Declared parameters and results of the primitive.
    let parameters =
        [TypeName::Scalar.tensor().named("input"), TypeName::Logical.named("fast_impl")];
    let results = [("output", TypeName::Scalar.tensor())];

    registry.register_dumper(ser_gelu_approx);
    registry.register_primitive(
        "tract_transformers_gelu_approx",
        &parameters,
        &results,
        de_gelu_approx,
    );
}
fn de_gelu_approx(
builder: &mut ModelBuilder,
invocation: &ResolvedInvocation,
) -> TractResult<Value> {
let input = invocation.named_arg_as(builder, "input")?;
let fast_impl = invocation.named_arg_as(builder, "fast_impl")?;
builder.wire(GeluApproximate { fast_impl }, &[input])
}
/// Dumper for [`GeluApproximate`] nodes.
///
/// Emits an invocation of `tract_transformers_gelu_approx` whose positional
/// argument is the already-mapped value of the node's first input and whose
/// `fast_impl` named argument mirrors `op.fast_impl`.
fn ser_gelu_approx(
    ast: &mut IntoAst,
    node: &TypedNode,
    op: &GeluApproximate,
) -> TractResult<Option<Arc<RValue>>> {
    // Look up the expression previously recorded for the node's first input.
    let input = ast.mapping[&node.inputs[0]].clone();
    let named_args = [("fast_impl", logical(op.fast_impl))];
    let call = invocation("tract_transformers_gelu_approx", &[input], &named_args);
    Ok(Some(call))
}
/// Tanh-based approximation of GELU, evaluated element-wise as
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N)))`.
///
/// `N` is 3 by default and 2 when [`fast_impl`](Self::fast_impl) is set
/// (see the `EvalOp` implementation in this file).
#[derive(Clone, Debug, Default, Hash)]
pub struct GeluApproximate {
    /// Selects the variant using `x^2` instead of `x^3` in the inner
    /// polynomial; also switches the reported op name to `GeluApproximateFast`.
    pub fast_impl: bool,
}
impl Op for GeluApproximate {
    /// Returns the display name of the op: `GeluApproximateFast` when
    /// `fast_impl` is set, `GeluApproximate` otherwise.
    fn name(&self) -> StaticName {
        // Convert the static literal directly instead of going through
        // `to_string()`, which allocated a fresh `String` on every call.
        if self.fast_impl {
            "GeluApproximateFast".into()
        } else {
            "GeluApproximate".into()
        }
    }

    op_as_typed_op!();
}
impl EvalOp for GeluApproximate {
    /// The op keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Applies the approximate GELU element-wise to the single input tensor.
    ///
    /// The computation is carried out in `f32` and the result is cast back to
    /// the input's datum type. The exponent of the inner polynomial term is 2
    /// when `fast_impl` is set and 3 otherwise.
    ///
    /// # Errors
    /// Fails if the input cannot be cast to `f32`, viewed as an `f32` slice,
    /// rebuilt into a tensor of the input shape, or cast back.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let original_dt = input.datum_type();
        let input_f32 = input.cast_to_dt(DatumType::F32)?;

        let scale = (2.0 / std::f32::consts::PI).sqrt();
        let exponent = if self.fast_impl { 2 } else { 3 };

        let values = input_f32.as_slice::<f32>()?;
        let mut activated = Vec::with_capacity(values.len());
        for &x in values {
            let inner = scale * (x + 0.044715 * x.powi(exponent));
            activated.push(0.5 * x * (1.0 + f32::tanh(inner)));
        }

        let output_f32 = Tensor::from_shape(input.shape(), &activated)?;
        let output = output_f32.cast_to_dt(original_dt)?.into_owned();
        Ok(tvec![output.into_tvalue()])
    }
}
impl TypedOp for GeluApproximate {
    /// The single output has the same datum type and shape as the first input.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input = inputs[0];
        Ok(tvec!(input.datum_type.fact(input.shape.clone())))
    }

    as_op!();
}
/// Rewrite rule that collapses the tanh-based GELU approximation subgraph into
/// a single [`GeluApproximate`] node.
///
/// Searched pattern, anchored on the `Pow` node the rule is invoked on:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N)))` with `N` in `{2, 3}`
/// (`N == 2` produces `fast_impl = true`).
///
/// The `find_succ_*`, `matches_single_input_const` and `next_node` helpers live
/// in the parent module and `rule_ensure!` in the crate root; the per-step
/// comments below describe the sub-expression each call is expected to match.
/// NOTE(review): `rule_ensure!` presumably bails out of the rule when its
/// condition is false — confirm against its definition.
///
/// Returns `Ok(Some(patch))` when the whole chain is recognised. Every explicit
/// `let ... else` guard returns `Ok(None)` on a mismatch.
///
/// # Errors
/// Propagates failures from reading the node's input facts or from building
/// the patch.
pub fn gelu_approx_rule(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &TypedBinOp,
) -> TractResult<Option<TypedModelPatch>> {
    rule_ensure!(op.0.is::<Pow>());
    let pow_node = node;

    // Only F32 and F16 inputs are handled.
    let in_fact = model.node_input_facts(pow_node.id)?[0];
    let dt = in_fact.datum_type;
    rule_ensure!(matches!(dt, DatumType::F32 | DatumType::F16));

    // x^N, N in {2, 3}; the exponent 2 selects the fast variant.
    rule_ensure!(
        matches_single_input_const(model, pow_node, 3.0)
            || matches_single_input_const(model, pow_node, 2.0)
    );
    let fast_impl = matches_single_input_const(model, pow_node, 2.0);

    // 0.044715 * x^N
    let Some(mul_coef_a) = find_succ_mul_with_const(model, pow_node, 0.044715) else {
        return Ok(None);
    };

    // x + 0.044715 * x^N
    let Some(x_plus_mul_coef_a) = find_succ_add_with(model, mul_coef_a, &pow_node.inputs[0]) else {
        return Ok(None);
    };

    // sqrt(2/pi) * (x + 0.044715 * x^N)
    let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
    let Some(mul_sqrt_2_over_pi) =
        find_succ_mul_with_const(model, x_plus_mul_coef_a, sqrt_2_over_pi)
    else {
        return Ok(None);
    };

    // tanh(sqrt(2/pi) * (x + 0.044715 * x^N))
    let Some(tanh_succ) = next_node(model, mul_sqrt_2_over_pi) else { return Ok(None) };
    let Some(tanh_succ_op) = tanh_succ.op_as::<ElementWiseOp>() else { return Ok(None) };
    rule_ensure!(tanh_succ_op.0.is::<Tanh>());

    // 1 + tanh(...)
    let Some(tanh_plus_1) = find_succ_add_with_const(model, tanh_succ, 1.0) else {
        return Ok(None);
    };

    // The node after `1 + tanh(...)` must be a Mul.
    let Some(mul_succ) = next_node(model, tanh_plus_1) else { return Ok(None) };
    let Some(mul_succ_op) = mul_succ.op_as::<TypedBinOp>() else { return Ok(None) };
    rule_ensure!(mul_succ_op.0.is::<Mul>());

    // The final `0.5 * x * (...)` product can be associated in two ways.
    let last_node_id = if mul_succ.inputs.contains(&pow_node.inputs[0]) {
        // tmp = x * (1 + tanh(...)); out = 0.5 * tmp
        let Some(last_mul_with_0_5) = find_succ_mul_with_const(model, mul_succ, 0.5) else {
            return Ok(None);
        };
        last_mul_with_0_5.id
    } else {
        // tmp = 0.5 * x; out = tmp * (1 + tanh(...))
        // Pick the first input of `mul_succ` that is itself produced by a Mul.
        let Some(x_mul_0_5) = mul_succ.inputs.iter().find_map(|i| {
            let n = &model.nodes()[i.node];
            let op = n.op_as::<TypedBinOp>()?;
            op.0.is::<Mul>().then_some(n)
        }) else {
            return Ok(None);
        };
        rule_ensure!(matches_single_input_const(model, x_mul_0_5, 0.5));
        rule_ensure!(x_mul_0_5.inputs.contains(&pow_node.inputs[0]));
        mul_succ.id
    };

    // Build the patch only once the whole pattern is confirmed, so that
    // non-matching Pow nodes do not pay for patch creation and input taps.
    let mut patch = TypedModelPatch::default();
    let gelu_approx_input = patch.taps(model, &pow_node.inputs)?;
    let out = patch.wire_node(
        format!("{node_name}.gelu_approx"),
        GeluApproximate { fast_impl },
        &[gelu_approx_input[0]],
    )?;
    // Redirect consumers of the last matched node to the fused op's output.
    patch.shunt_outside(model, last_node_id.into(), out[0])?;
    Ok(Some(patch))
}