burn_ndarray/ops/activations.rs

1use crate::{
2    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
3    execute_with_float_dtype,
4    tensor::NdArrayTensor,
5    NdArray,
6};
7use burn_tensor::{
8    ops::{ActivationOps, FloatTensor},
9    ElementConversion,
10};
11
12impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
13    for NdArray<E, I, Q>
14{
15    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
16        execute_with_float_dtype!(tensor, |tensor: NdArrayTensor<_>| {
17            let zero = 0.elem();
18            let array = tensor
19                .array
20                .mapv_into(|elem| match elem < zero {
21                    true => zero,
22                    false => elem,
23                })
24                .into_shared();
25
26            NdArrayTensor::new(array)
27        })
28    }
29}