cubecl_core/frontend/
trigonometry.rs1use cubecl_ir::{ExpandElement, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl};
5
6#[cube]
11pub fn hypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
12 let one = Line::empty(lhs.size()).fill(F::from_int(1));
13 let a = lhs.abs();
14 let b = rhs.abs();
15 let max_val = max(a, b);
16 let max_val_is_zero = max_val.equal(Line::empty(lhs.size()).fill(F::from_int(0)));
17 let max_val_safe = select_many(max_val_is_zero, one, max_val);
18 let min_val = min(a, b);
19 let t = min_val / max_val_safe;
20
21 max_val * fma(t, t, one).sqrt()
22}
23
24#[allow(missing_docs)]
25pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
26 scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
27 let res = hypot::expand::<FloatExpand<0>>(
28 scope,
29 ExpandElement::Plain(lhs).into(),
30 ExpandElement::Plain(rhs).into(),
31 );
32 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
33}
34
35#[cube]
40pub fn rhypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
41 let one = Line::empty(lhs.size()).fill(F::from_int(1));
42 let a = lhs.abs();
43 let b = rhs.abs();
44 let max_val = max(a, b);
45 let max_val_is_zero = max_val.equal(Line::empty(lhs.size()).fill(F::from_int(0)));
46 let max_val_safe = select_many(max_val_is_zero, one, max_val);
47 let min_val = min(a, b);
48 let t = min_val / max_val_safe;
49
50 fma(t, t, one).inverse_sqrt() / max_val
51}
52
53#[allow(missing_docs)]
54pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
55 scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
56 let res = rhypot::expand::<FloatExpand<0>>(
57 scope,
58 ExpandElement::Plain(lhs).into(),
59 ExpandElement::Plain(rhs).into(),
60 );
61 assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
62}