cubecl_core/frontend/trigonometry.rs

1use cubecl_ir::{ExpandElement, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl};
5
6/// Computes the hypotenuse of a right triangle given the lengths of the other two sides.
7///
8/// This function computes `sqrt(x² + y²)` in a numerically stable way that avoids
9/// overflow and underflow issues.
10#[cube]
11pub fn hypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
12    let one = Line::empty(lhs.size()).fill(F::from_int(1));
13    let a = lhs.abs();
14    let b = rhs.abs();
15    let max_val = max(a, b);
16    let max_val_is_zero = max_val.equal(Line::empty(lhs.size()).fill(F::from_int(0)));
17    let max_val_safe = select_many(max_val_is_zero, one, max_val);
18    let min_val = min(a, b);
19    let t = min_val / max_val_safe;
20
21    max_val * fma(t, t, one).sqrt()
22}
23
24#[allow(missing_docs)]
25pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
26    scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
27    let res = hypot::expand::<FloatExpand<0>>(
28        scope,
29        ExpandElement::Plain(lhs).into(),
30        ExpandElement::Plain(rhs).into(),
31    );
32    assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
33}
34
35/// Computes the reciprocal of the hypotenuse of a right triangle given the lengths of the other two sides.
36///
37/// This function computes `1 / sqrt(x² + y²)` in a numerically stable way that avoids
38/// overflow and underflow issues.
39#[cube]
40pub fn rhypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
41    let one = Line::empty(lhs.size()).fill(F::from_int(1));
42    let a = lhs.abs();
43    let b = rhs.abs();
44    let max_val = max(a, b);
45    let max_val_is_zero = max_val.equal(Line::empty(lhs.size()).fill(F::from_int(0)));
46    let max_val_safe = select_many(max_val_is_zero, one, max_val);
47    let min_val = min(a, b);
48    let t = min_val / max_val_safe;
49
50    fma(t, t, one).inverse_sqrt() / max_val
51}
52
53#[allow(missing_docs)]
54pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
55    scope.register_type::<FloatExpand<0>>(lhs.ty.storage_type());
56    let res = rhypot::expand::<FloatExpand<0>>(
57        scope,
58        ExpandElement::Plain(lhs).into(),
59        ExpandElement::Plain(rhs).into(),
60    );
61    assign::expand_no_check(scope, res, ExpandElement::Plain(out).into());
62}