// cubecl_core/frontend/operation/clamp.rs

1use half::{bf16, f16};
2
3use crate::{
4    flex32,
5    ir::{Arithmetic, ClampOperator, ExpandElement, Scope},
6    prelude::CubePrimitive,
7    tf32, unexpanded,
8};
9
10use super::unary_expand;
11
12pub trait Clamp: CubePrimitive + Sized {
13    /// Clamp the input value between the max and min values provided.
14    #[allow(unused_variables)]
15    fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
16        unexpanded!()
17    }
18    fn __expand_clamp(
19        scope: &mut Scope,
20        input: Self::ExpandType,
21        min_value: Self::ExpandType,
22        max_value: Self::ExpandType,
23    ) -> Self::ExpandType {
24        let input: ExpandElement = input.into();
25        let min_value: ExpandElement = min_value.into();
26        let max_value: ExpandElement = max_value.into();
27
28        unary_expand(scope, input, |op| {
29            Arithmetic::Clamp(ClampOperator {
30                input: op.input,
31                min_value: *min_value,
32                max_value: *max_value,
33            })
34        })
35        .into()
36    }
37}
38
39impl Clamp for f16 {}
40impl Clamp for bf16 {}
41impl Clamp for flex32 {}
42impl Clamp for tf32 {}
43impl Clamp for f32 {}
44impl Clamp for f64 {}
45impl Clamp for i8 {}
46impl Clamp for i16 {}
47impl Clamp for i32 {}
48impl Clamp for i64 {}
49impl Clamp for u8 {}
50impl Clamp for u16 {}
51impl Clamp for u32 {}
52impl Clamp for u64 {}