cubecl_core/frontend/operation/
clamp.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
use half::{bf16, f16};

use crate::{
    ir::{ClampOperator, Operator},
    prelude::{CubeContext, CubePrimitive, ExpandElement},
    unexpanded,
};

use super::unary_expand;

pub trait Clamp: CubePrimitive + Sized {
    /// Clamp `input` so that it lies between `min_value` (lower bound) and
    /// `max_value` (upper bound).
    ///
    /// This is the user-facing signature; its body is only `unexpanded!()`.
    /// The instruction is actually produced by [`Clamp::__expand_clamp`].
    #[allow(unused_variables)] // parameters exist only to define the signature
    fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
        unexpanded!()
    }

    /// Expansion counterpart of [`Clamp::clamp`].
    ///
    /// Converts the three operands into [`ExpandElement`]s and emits one
    /// `Operator::Clamp` instruction via `unary_expand`. `unary_expand` supplies
    /// the `input`/`out` pair, and the two bounds are attached as
    /// `min_value`/`max_value`. The resulting output element is converted back
    /// into `Self::ExpandType`.
    fn __expand_clamp(
        context: &mut CubeContext,
        input: Self::ExpandType,
        min_value: Self::ExpandType,
        max_value: Self::ExpandType,
    ) -> Self::ExpandType {
        let value: ExpandElement = input.into();
        let lower: ExpandElement = min_value.into();
        let upper: ExpandElement = max_value.into();

        // NOTE(review): nothing here checks lower <= upper; what happens with
        // inverted bounds is decided downstream — confirm per backend.
        let clamped = unary_expand(context, value, |unary| {
            let clamp_op = ClampOperator {
                input: unary.input,
                min_value: *lower,
                max_value: *upper,
                out: unary.out,
            };
            Operator::Clamp(clamp_op)
        });

        clamped.into()
    }
}

impl Clamp for f16 {}
impl Clamp for bf16 {}
impl Clamp for f32 {}
impl Clamp for f64 {}
impl Clamp for i32 {}
impl Clamp for i64 {}
impl Clamp for u32 {}