cubecl_core/frontend/
trigonometry.rs1use cubecl_ir::{ManagedVariable, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl};
5
// Generic stand-ins used as the type arguments when expanding the `#[cube]` functions
// below from raw IR (`hypot::expand::<ElemA, SizeA>` / `rhypot::expand::<ElemA, SizeA>`).
// Each `expand_*` helper binds them to the concrete storage type and vector size of its
// `lhs` operand via `scope.register_type::<ElemA>` / `scope.register_size::<SizeA>`.
define_scalar!(ElemA);
define_size!(SizeA);
8
9#[cube]
14pub fn hypot<F: Float, N: Size>(lhs: Vector<F, N>, rhs: Vector<F, N>) -> Vector<F, N> {
15 let one = Vector::new(F::from_int(1));
16 let a = lhs.abs();
17 let b = rhs.abs();
18 let max_val = max(a, b);
19 let max_val_is_zero = max_val.equal(Vector::new(F::from_int(0)));
20 let max_val_safe = select_many(max_val_is_zero, one, max_val);
21 let min_val = min(a, b);
22 let t = min_val / max_val_safe;
23
24 max_val * fma(t, t, one).sqrt()
25}
26
27#[allow(missing_docs)]
28pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
29 scope.register_type::<ElemA>(lhs.ty.storage_type());
30 scope.register_size::<SizeA>(lhs.vector_size());
31 let res = hypot::expand::<ElemA, SizeA>(
32 scope,
33 ManagedVariable::Plain(lhs).into(),
34 ManagedVariable::Plain(rhs).into(),
35 );
36 assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
37}
38
39#[cube]
44pub fn rhypot<F: Float, N: Size>(lhs: Vector<F, N>, rhs: Vector<F, N>) -> Vector<F, N> {
45 let one = Vector::new(F::from_int(1));
46 let a = lhs.abs();
47 let b = rhs.abs();
48 let max_val = max(a, b);
49 let max_val_is_zero = max_val.equal(Vector::new(F::from_int(0)));
50 let max_val_safe = select_many(max_val_is_zero, one, max_val);
51 let min_val = min(a, b);
52 let t = min_val / max_val_safe;
53
54 fma(t, t, one).inverse_sqrt() / max_val
55}
56
57#[allow(missing_docs)]
58pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
59 scope.register_type::<ElemA>(lhs.ty.storage_type());
60 scope.register_size::<SizeA>(lhs.vector_size());
61 let res = rhypot::expand::<ElemA, SizeA>(
62 scope,
63 ManagedVariable::Plain(lhs).into(),
64 ManagedVariable::Plain(rhs).into(),
65 );
66 assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
67}