burn_ndarray/ops/
activations.rs1use crate::{
2 element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
3 execute_with_float_dtype,
4 tensor::NdArrayTensor,
5 NdArray,
6};
7use burn_tensor::{
8 ops::{ActivationOps, FloatTensor},
9 ElementConversion,
10};
11
12impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
13 for NdArray<E, I, Q>
14{
15 fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
16 execute_with_float_dtype!(tensor, |tensor: NdArrayTensor<_>| {
17 let zero = 0.elem();
18 let array = tensor
19 .array
20 .mapv_into(|elem| match elem < zero {
21 true => zero,
22 false => elem,
23 })
24 .into_shared();
25
26 NdArrayTensor::new(array)
27 })
28 }
29}