burn_ndarray/ops/activations.rs

1use crate::{
2    NdArray, NdArrayTensor, SharedArray,
3    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
4    execute_with_numeric_dtype,
5    ops::NdArrayMathOps,
6};
7use burn_tensor::{
8    ElementConversion, TensorMetadata,
9    ops::{ActivationOps, FloatTensor},
10};
11
12impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
13    for NdArray<E, I, Q>
14where
15    NdArrayTensor: From<SharedArray<E>>,
16    NdArrayTensor: From<SharedArray<I>>,
17{
18    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
19        execute_with_numeric_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(tensor, 0.elem()))
20    }
21}