// cubecl_core/frontend/trigonometry.rs

1use cubecl_ir::{ManagedVariable, Variable};
2
3use crate::prelude::*;
4use crate::{self as cubecl};
5
6define_scalar!(ElemA);
7define_size!(SizeA);
8
9/// Computes the hypotenuse of a right triangle given the lengths of the other two sides.
10///
11/// This function computes `sqrt(x² + y²)` in a numerically stable way that avoids
12/// overflow and underflow issues.
13#[cube]
14pub fn hypot<F: Float, N: Size>(lhs: Vector<F, N>, rhs: Vector<F, N>) -> Vector<F, N> {
15    let one = Vector::new(F::from_int(1));
16    let a = lhs.abs();
17    let b = rhs.abs();
18    let max_val = max(a, b);
19    let max_val_is_zero = max_val.equal(Vector::new(F::from_int(0)));
20    let max_val_safe = select_many(max_val_is_zero, one, max_val);
21    let min_val = min(a, b);
22    let t = min_val / max_val_safe;
23
24    max_val * fma(t, t, one).sqrt()
25}
26
27#[allow(missing_docs)]
28pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
29    scope.register_type::<ElemA>(lhs.ty.storage_type());
30    scope.register_size::<SizeA>(lhs.vector_size());
31    let res = hypot::expand::<ElemA, SizeA>(
32        scope,
33        ManagedVariable::Plain(lhs).into(),
34        ManagedVariable::Plain(rhs).into(),
35    );
36    assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
37}
38
39/// Computes the reciprocal of the hypotenuse of a right triangle given the lengths of the other two sides.
40///
41/// This function computes `1 / sqrt(x² + y²)` in a numerically stable way that avoids
42/// overflow and underflow issues.
43#[cube]
44pub fn rhypot<F: Float, N: Size>(lhs: Vector<F, N>, rhs: Vector<F, N>) -> Vector<F, N> {
45    let one = Vector::new(F::from_int(1));
46    let a = lhs.abs();
47    let b = rhs.abs();
48    let max_val = max(a, b);
49    let max_val_is_zero = max_val.equal(Vector::new(F::from_int(0)));
50    let max_val_safe = select_many(max_val_is_zero, one, max_val);
51    let min_val = min(a, b);
52    let t = min_val / max_val_safe;
53
54    fma(t, t, one).inverse_sqrt() / max_val
55}
56
57#[allow(missing_docs)]
58pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
59    scope.register_type::<ElemA>(lhs.ty.storage_type());
60    scope.register_size::<SizeA>(lhs.vector_size());
61    let res = rhypot::expand::<ElemA, SizeA>(
62        scope,
63        ManagedVariable::Plain(lhs).into(),
64        ManagedVariable::Plain(rhs).into(),
65    );
66    assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
67}