use crate::{backend::Backend, ElementConversion};
use core::f64::consts::SQRT_2;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
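        // Elements less than or equal to zero are masked and filled with zero.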
        let mask = B::lower_equal_elem(tensor.clone(), 0.elem());
        B::mask_fill(tensor, mask, 0.elem())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the forward pass.
    /// * `grad` - The incoming gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward<const D: usize>(
        output: B::TensorPrimitive<D>,
        grad: B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D> {
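        // The gradient is zeroed wherever the forward output was clamped to zero.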
        let mask = B::lower_equal_elem(output, 0.elem());
        B::mask_fill(grad, mask, 0.elem())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
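        // Exact formulation: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).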
        let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::erf(x);
        let x = B::add_scalar(x, 1i32.elem());
        let x = B::mul(tensor, x);
        B::div_scalar(x, 2i32.elem())
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn gelu_backward<const D: usize>(
        x: B::TensorPrimitive<D>,
        grad: B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D> {
        // Derivative of the approximate gelu implementation based on tanh.
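        //
        // The constants below come from differentiating the tanh approximation
        // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))): with
        // u = 0.797885 * x + 0.0356774 * x^3, the derivative is
        // 0.5 * tanh(u) + (0.398942 * x + 0.0535161 * x^3) * sech^2(u) + 0.5.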
        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::powf(x.clone(), 3.0);

        let c1 = B::mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::mul_scalar(x3, constant_3.elem());
        let c4 = B::mul_scalar(x, constant_4.elem());

        let inner1 = B::add(c1, c2);
        let inner2 = B::add(c3, c4);

        let tanh = B::tanh(inner1);

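        // sech^2(inner1), computed as 1 - tanh^2(inner1).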
        let sech = B::powf(tanh.clone(), 2.0);
        let sech = B::neg(sech);
        let sech = B::add_scalar(sech, 1.elem());

        let y1 = B::mul_scalar(tanh, 0.5.elem());
        let y2 = B::mul(inner2, sech);
        let y2 = B::add_scalar(y2, 0.5.elem());
        let y = B::add(y1, y2);

        B::mul(y, grad)
    }
}
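
// A minimal sketch (hypothetical, for illustration only) of how a backend could override one of
// the defaults above with a fused kernel; `MyBackend` and `relu_fused_kernel` are assumed names,
// not part of this crate.
//
// impl ActivationOps<MyBackend> for MyBackend {
//     fn relu<const D: usize>(
//         tensor: <MyBackend as Backend>::TensorPrimitive<D>,
//     ) -> <MyBackend as Backend>::TensorPrimitive<D> {
//         relu_fused_kernel(tensor)
//     }
// }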