cubecl_core/frontend/operation/
clamp.rs

1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use half::{bf16, f16};
3
4use crate::{
5    flex32,
6    ir::{Arithmetic, ClampOperator, ExpandElement, Scope},
7    prelude::CubePrimitive,
8    tf32, unexpanded,
9};
10
11use super::unary_expand;
12
13pub trait Clamp: CubePrimitive + Sized {
14    /// Clamp the input value between the max and min values provided.
15    #[allow(unused_variables)]
16    fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
17        unexpanded!()
18    }
19    fn __expand_clamp(
20        scope: &mut Scope,
21        input: Self::ExpandType,
22        min_value: Self::ExpandType,
23        max_value: Self::ExpandType,
24    ) -> Self::ExpandType {
25        let input: ExpandElement = input.into();
26        let min_value: ExpandElement = min_value.into();
27        let max_value: ExpandElement = max_value.into();
28
29        unary_expand(scope, input, |op| {
30            Arithmetic::Clamp(ClampOperator {
31                input: op.input,
32                min_value: *min_value,
33                max_value: *max_value,
34            })
35        })
36        .into()
37    }
38}
39
40impl Clamp for e2m1 {}
41impl Clamp for e4m3 {}
42impl Clamp for e5m2 {}
43impl Clamp for ue8m0 {}
44impl Clamp for f16 {}
45impl Clamp for bf16 {}
46impl Clamp for flex32 {}
47impl Clamp for tf32 {}
48impl Clamp for f32 {}
49impl Clamp for f64 {}
50impl Clamp for i8 {}
51impl Clamp for i16 {}
52impl Clamp for i32 {}
53impl Clamp for i64 {}
54impl Clamp for u8 {}
55impl Clamp for u16 {}
56impl Clamp for u32 {}
57impl Clamp for u64 {}