burn_ndarray/ops/
activations.rs1use crate::{
2 NdArray, NdArrayTensor, SharedArray,
3 element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
4 execute_with_numeric_dtype,
5 ops::NdArrayMathOps,
6};
7use burn_backend::{ElementConversion, TensorMetadata, ops::ActivationOps, tensor::FloatTensor};
8
9impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
10 for NdArray<E, I, Q>
11where
12 NdArrayTensor: From<SharedArray<E>>,
13 NdArrayTensor: From<SharedArray<I>>,
14{
15 fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
16 execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem()))
17 }
18}