burn_backend/backend/ops/activation.rs

use crate::tensor::FloatTensor;
use crate::{Backend, Scalar, TensorMetadata, get_device_settings};
use core::f64::consts::SQRT_2;

/// Activation function operations over float tensors.
pub trait ActivationOps<B: Backend> {
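    /// Applies the leaky rectified linear unit function element-wise.
    ///
    /// Negative inputs are scaled by `negative_slope`; non-negative inputs pass through unchanged.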
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let mask = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope);

        B::float_mask_where(tensor, mask, scaled_tensor)
    }

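    /// Applies the rectified linear unit function element-wise: `relu(x) = max(0, x)`.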
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into(), bool_dtype);

        B::float_mask_fill(tensor, mask, 0f32.into())
    }

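    /// Computes the gradient of [relu](Self::relu) from the forward `output`:
    /// the incoming gradient is zeroed wherever the forward output is zero.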
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&output)).bool_dtype;
        let mask = B::float_lower_equal_elem(output, 0f32.into(), bool_dtype);

        B::float_mask_fill(grad, mask, 0f32.into())
    }

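    /// Applies the Gaussian error linear unit element-wise:
    /// `gelu(x) = x * Φ(x) = 0.5 * x * (1 + erf(x / √2))`.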
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.into());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1f32.into());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2f32.into())
    }

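    /// Applies the parametric rectified linear unit element-wise: negative inputs are
    /// scaled by the learnable tensor `alpha`; non-negative inputs pass through unchanged.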
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let mask = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);

        B::float_mask_where(tensor, mask, scaled_tensor)
    }

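    /// Computes the gradient of [gelu](Self::gelu) with respect to `x`.
    ///
    /// Uses the derivative of the tanh approximation
    /// `gelu(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))`, which gives
    /// `gelu'(x) = 0.5 * tanh(u) + (0.0535161x³ + 0.398942x) * sech²(u) + 0.5`
    /// with `u = 0.0356774x³ + 0.797885x`.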
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // 0.0356774 ≈ 0.044715 * √(2/π), 0.797885 ≈ √(2/π),
        // 0.0535161 ≈ 3 * 0.044715 * √(2/π) / 2, 0.398942 ≈ √(2/π) / 2.
        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::float_powi_scalar(x.clone(), 3.into());

        let c1 = B::float_mul_scalar(x3.clone(), constant_1.into());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.into());
        let c3 = B::float_mul_scalar(x3, constant_3.into());
        let c4 = B::float_mul_scalar(x, constant_4.into());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

        // sech²(inner1) = 1 - tanh²(inner1)
        let sech = B::float_powi_scalar(tanh.clone(), 2.into());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.into());

        let y1 = B::float_mul_scalar(tanh, 0.5.into());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.into());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

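    /// Applies the sigmoid function element-wise: `sigmoid(x) = 1 / (1 + exp(-x))`.
    ///
    /// Computed as `exp(-log(1 + exp(-x)))` in f32 precision for numerical stability,
    /// then cast back to the input dtype.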
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.into(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

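    /// Computes the gradient of [sigmoid](Self::sigmoid) from the forward `output`:
    /// `grad * output * (1 - output)`.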
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.into()),
        );
        B::float_mul(value, grad)
    }

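    /// Applies the hard sigmoid function element-wise: `clamp(alpha * x + beta, 0, 1)`,
    /// computed in f32 precision and cast back to the input dtype.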
    fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);

        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta),
            0.0.into(),
            1.0.into(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

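    /// Applies the log sigmoid function element-wise: `log_sigmoid(x) = log(1 / (1 + exp(-x)))`.
    ///
    /// Evaluated in the numerically stable form
    /// `-(m + log(exp(-m) + exp(-x - m)))` with `m = max(-x, 0)`,
    /// which avoids overflow of `exp(-x)` for large negative inputs.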
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;

        // max_elem = max(-x, 0), computed by zeroing the negative entries of -x.
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into(), bool_dtype);
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max_elem) + exp(-x - max_elem)
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        B::float_sub(max_elem_neg, B::float_log(z))
    }

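    /// Computes the gradient of [log_sigmoid](Self::log_sigmoid) with respect to `x`:
    /// `grad * (1 - sigmoid(x))`, evaluated with the same max-subtraction trick as the
    /// forward pass so that it stays finite for large `|x|`.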
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);
        let bool_dtype = get_device_settings::<B>(&device).bool_dtype;

        // max_elem = max(-x, 0), as in the forward pass.
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0f32.into(), bool_dtype);
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());

        // z = exp(-max_elem) + exp(-x - max_elem)
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // max_derive: 1 where x <= 0 (the max(-x, 0) branch is active), else 0.
        // sign: -1 where x > 0, else 1.
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());

        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}